Repository: Songluchuan/AdaSR-TalkingHead
Branch: main
Commit: 781afd1043ee
Files: 126
Total size: 4.7 MB
Directory structure:
gitextract_xyiv4zse/
├── .gitignore
├── README.md
├── animate.py
├── augmentation.py
├── config/
│ └── mix-resolution.yml
├── demo.py
├── environment.yaml
├── frames_dataset.py
├── logger.py
├── modules/
│ ├── dense_motion.py
│ ├── discriminator.py
│ ├── generator.py
│ ├── hopenet.py
│ ├── keypoint_detector.py
│ ├── model.py
│ └── util.py
├── run_demo.sh
├── sync_batchnorm/
│ ├── __init__.py
│ ├── batchnorm.py
│ ├── comm.py
│ ├── replicate.py
│ └── unittest.py
└── upsampler/
├── app_gradio.py
├── configs/
│ ├── __init__.py
│ ├── data_configs.py
│ ├── dataset_config.yml
│ ├── paths_config.py
│ └── transforms_config.py
├── criteria/
│ ├── __init__.py
│ ├── id_loss.py
│ ├── lpips/
│ │ ├── __init__.py
│ │ ├── lpips.py
│ │ ├── networks.py
│ │ └── utils.py
│ ├── moco_loss.py
│ └── w_norm.py
├── datasets/
│ ├── __init__.py
│ ├── augmentations.py
│ ├── ffhq_degradation_dataset.py
│ ├── gt_res_dataset.py
│ ├── images_dataset.py
│ └── inference_dataset.py
├── image_translation.py
├── inference_playground.ipynb
├── inversion.py
├── latent_optimization.py
├── models/
│ ├── __init__.py
│ ├── bisenet/
│ │ ├── LICENSE
│ │ ├── README.md
│ │ ├── model.py
│ │ └── resnet.py
│ ├── encoders/
│ │ ├── __init__.py
│ │ ├── helpers.py
│ │ ├── model_irse.py
│ │ └── psp_encoders.py
│ ├── mtcnn/
│ │ ├── __init__.py
│ │ ├── mtcnn.py
│ │ └── mtcnn_pytorch/
│ │ ├── __init__.py
│ │ └── src/
│ │ ├── __init__.py
│ │ ├── align_trans.py
│ │ ├── box_utils.py
│ │ ├── detector.py
│ │ ├── first_stage.py
│ │ ├── get_nets.py
│ │ ├── matlab_cp2tform.py
│ │ ├── visualization_utils.py
│ │ └── weights/
│ │ ├── onet.npy
│ │ ├── pnet.npy
│ │ └── rnet.npy
│ ├── psp.py
│ └── stylegan2/
│ ├── __init__.py
│ ├── lpips/
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── dist_model.py
│ │ ├── networks_basic.py
│ │ ├── pretrained_networks.py
│ │ └── weights/
│ │ ├── v0.0/
│ │ │ ├── alex.pth
│ │ │ ├── squeeze.pth
│ │ │ └── vgg.pth
│ │ └── v0.1/
│ │ ├── alex.pth
│ │ ├── squeeze.pth
│ │ └── vgg.pth
│ ├── model.py
│ ├── op/
│ │ ├── __init__.py
│ │ ├── conv2d_gradfix.py
│ │ ├── fused_act.py
│ │ ├── readme.md
│ │ └── upfirdn2d.py
│ ├── op2/
│ │ ├── __init__.py
│ │ ├── upfirdn2d.cpp
│ │ ├── upfirdn2d.py
│ │ └── upfirdn2d_kernel.cu
│ ├── op_old/
│ │ ├── __init__.py
│ │ ├── fused_act.py
│ │ ├── fused_bias_act.cpp
│ │ ├── fused_bias_act_kernel.cu
│ │ ├── upfirdn2d.cpp
│ │ ├── upfirdn2d.py
│ │ └── upfirdn2d_kernel.cu
│ └── simple_augment.py
├── options/
│ ├── __init__.py
│ ├── test_options.py
│ └── train_options.py
├── output/
│ └── ILip77SbmOE_inversion.pt
├── pretrained_models/
│ └── readme.md
├── scripts/
│ ├── align_all_parallel.py
│ ├── calc_id_loss_parallel.py
│ ├── calc_losses_on_images.py
│ ├── download_ffhq1280.py
│ ├── generate_sketch_data.py
│ ├── inference.py
│ ├── pretrain.py
│ ├── style_mixing.py
│ └── train.py
├── training/
│ ├── __init__.py
│ ├── coach.py
│ └── ranger.py
├── utils/
│ ├── __init__.py
│ ├── common.py
│ ├── data_utils.py
│ ├── inference_utils.py
│ ├── train_utils.py
│ └── wandb_utils.py
├── video_editing.py
└── webUI/
├── app_task.py
└── styleganex_model.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
checkpoints/mix-train.pth.tar
results_hq.mp4
================================================
FILE: README.md
================================================
# Adaptive Super Resolution For One-Shot Talking-Head Generation
The repository for the ICASSP 2024 paper "Adaptive Super Resolution For One-Shot Talking-Head Generation" (AdaSR-TalkingHead).
## Abstract
One-shot talking-head generation learns to synthesize a talking-head video from a single source portrait image, driven by a video of the same or a different identity. These methods usually require plane-based pixel transformations via Jacobian matrices or facial image warps to generate novel poses. The constraints of using a single image source and pixel displacements often compromise the clarity of the synthesized images. Some methods try to improve the quality of synthesized videos by introducing additional super-resolution modules, but this inevitably increases computational cost and disturbs the original data distribution. In this work, we propose an adaptive high-quality talking-head video generation method that synthesizes high-resolution video without additional pre-trained modules. Specifically, inspired by existing super-resolution methods, we down-sample the one-shot source image and then adaptively reconstruct high-frequency details via an encoder-decoder module, resulting in enhanced video clarity. This straightforward yet effective strategy consistently improves the quality of generated videos, as substantiated by quantitative and qualitative evaluations. The code and demo video are available at: https://github.com/Songluchuan/AdaSR-TalkingHead/
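The core idea (down-sample the one-shot source, then restore high-frequency detail with an encoder-decoder) can be sketched in a few lines of PyTorch. This is an illustrative toy, not the released AdaSR module:
```python
import torch
import torch.nn.functional as F
from torch import nn

class ToyDetailRestorer(nn.Module):
    """Illustrative encoder-decoder; NOT the actual AdaSR architecture."""
    def __init__(self, channels=3, width=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, width, 3, padding=1), nn.ReLU(),
            nn.Conv2d(width, width, 3, padding=1), nn.ReLU())
        self.decoder = nn.Conv2d(width, channels, 3, padding=1)

    def forward(self, source):
        # Down-sample the one-shot source image...
        low = F.interpolate(source, scale_factor=0.5, mode='bilinear',
                            align_corners=False)
        # ...then predict the lost high-frequency details as a residual.
        up = F.interpolate(low, size=source.shape[-2:], mode='bilinear',
                           align_corners=False)
        return up + self.decoder(self.encoder(up))

print(ToyDetailRestorer()(torch.randn(1, 3, 512, 512)).shape)  # (1, 3, 512, 512)
```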
## Updates
- [03/2024] Inference code and pretrained model are released.
- [03/2024] Arxiv Link: https://arxiv.org/abs/2403.15944.
- [COMING] Super-resolution model (based on StyleGANEX and ESRGAN).
- [COMING] Train code and processed datasets.
## Installation
**Clone this repo:**
```bash
git clone git@github.com:Songluchuan/AdaSR-TalkingHead.git
cd AdaSR-TalkingHead
```
**Dependencies:**
We have tested on:
- CUDA 11.3-11.6
- PyTorch 1.10.1
- Matplotlib 3.4.2 / 3.4.3; opencv-python 4.7.0; scikit-learn 1.0; tqdm 4.62.3
## Inference Code
1. Download the pretrained model from Google Drive: https://drive.google.com/file/d/1g58uuAyZFdny9_twvbv0AHxB9-03koko/view?usp=sharing (it is trained on the HDTF dataset) and put it under checkpoints/<br>
2. The demo video and reference image are under ```DEMO/```
3. The inference script is ```run_demo.sh```; run it with
```
bash run_demo.sh
```
4. You can set a different demo image and driving video in ```run_demo.sh```:
```
--source_image DEMO/demo_img_3.jpg
```
and
```
--driving_video DEMO/demo_video_1.mp4
```
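For reference, the full `demo.py` invocation that `run_demo.sh` wraps might look like the following sketch (all flags exist in `demo.py`'s argument parser; the checkpoint and config names are inferred from `.gitignore` and `config/`, so the exact values in the script may differ):
```bash
python demo.py \
    --config config/mix-resolution.yml \
    --checkpoint checkpoints/mix-train.pth.tar \
    --source_image DEMO/demo_img_3.jpg \
    --driving_video DEMO/demo_video_1.mp4 \
    --result_video ./results_hq.mp4 \
    --relative --adapt_scale --find_best_frame
```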
## Video
<div align="center">
<a href="https://www.youtube.com/watch?v=B_-3F51QmKE" target="_blank">
<img src="media/Teaser_video.png" alt="AdaSR Talking-Head" width="1120" style="height: auto;" />
</a>
</div>
## Citation
```bibtex
@inproceedings{song2024adaptive,
title={Adaptive Super Resolution for One-Shot Talking Head Generation},
author={Song, Luchuan and Liu, Pinxin and Yin, Guojun and Xu, Chenliang},
booktitle={IEEE International Conference on Acoustics, Speech, and Signal Processing (ICASSP)},
year={2024}
}
```
## Acknowledgments
The code is mainly developed based on [styleGANEX](https://github.com/williamyang1991/StyleGANEX), [ESRGAN](https://github.com/xinntao/ESRGAN) and [unofficial face2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis). Thanks to the authors for their contributions.
================================================
FILE: animate.py
================================================
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from logger import Logger, Visualizer
import imageio
from scipy.spatial import ConvexHull
import numpy as np
from sync_batchnorm import DataParallelWithCallback
def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
use_relative_movement=False, use_relative_jacobian=False):
if adapt_movement_scale:
source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
else:
adapt_movement_scale = 1
kp_new = {k: v for k, v in kp_driving.items()}
if use_relative_movement:
kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
kp_value_diff *= adapt_movement_scale
kp_new['value'] = kp_value_diff + kp_source['value']
if use_relative_jacobian:
jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
return kp_new
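A minimal usage sketch for `normalize_kp` (illustrative, not part of the file; the shapes assume 15 three-dimensional keypoints per face, matching `num_kp` in `config/mix-resolution.yml`):
```python
import torch
from animate import normalize_kp

# Dummy keypoint dicts in the format produced by demo.keypoint_transformation().
def dummy_kp():
    return {'value': torch.randn(1, 15, 3), 'jacobian': None}

# Relative mode transfers the driving *motion* (offset from the initial frame),
# scaled by the source/driving convex-hull ratio, onto the source keypoints.
kp_norm = normalize_kp(kp_source=dummy_kp(), kp_driving=dummy_kp(),
                       kp_driving_initial=dummy_kp(),
                       adapt_movement_scale=True,
                       use_relative_movement=True,
                       use_relative_jacobian=False)
print(kp_norm['value'].shape)  # torch.Size([1, 15, 3])
```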
================================================
FILE: augmentation.py
================================================
"""
Code from https://github.com/hassony2/torch_videovision
"""
import numbers
import random
import numpy as np
import PIL
from skimage.transform import resize, rotate
import torchvision
import warnings
from skimage import img_as_ubyte, img_as_float
def crop_clip(clip, min_h, min_w, h, w):
if isinstance(clip[0], np.ndarray):
cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
elif isinstance(clip[0], PIL.Image.Image):
cropped = [
img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
]
else:
raise TypeError('Expected numpy.ndarray or PIL.Image '
'but got list of {0}'.format(type(clip[0])))
return cropped
def pad_clip(clip, h, w):
im_h, im_w = clip[0].shape[:2]
pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
return np.pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
def resize_clip(clip, size, interpolation='bilinear'):
if isinstance(clip[0], np.ndarray):
if isinstance(size, numbers.Number):
im_h, im_w, im_c = clip[0].shape
# Min spatial dim already matches minimal size
if (im_w <= im_h and im_w == size) or (im_h <= im_w
and im_h == size):
return clip
new_h, new_w = get_resize_sizes(im_h, im_w, size)
size = (new_w, new_h)
else:
size = size[1], size[0]
scaled = [
resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
mode='constant', anti_aliasing=True) for img in clip
]
elif isinstance(clip[0], PIL.Image.Image):
if isinstance(size, numbers.Number):
im_w, im_h = clip[0].size
# Min spatial dim already matches minimal size
if (im_w <= im_h and im_w == size) or (im_h <= im_w
and im_h == size):
return clip
new_h, new_w = get_resize_sizes(im_h, im_w, size)
size = (new_w, new_h)
else:
size = size[1], size[0]
if interpolation == 'bilinear':
pil_inter = PIL.Image.BILINEAR
else:
pil_inter = PIL.Image.NEAREST
scaled = [img.resize(size, pil_inter) for img in clip]
else:
raise TypeError('Expected numpy.ndarray or PIL.Image '
'but got list of {0}'.format(type(clip[0])))
return scaled
def get_resize_sizes(im_h, im_w, size):
if im_w < im_h:
ow = size
oh = int(size * im_h / im_w)
else:
oh = size
ow = int(size * im_w / im_h)
return oh, ow
class RandomFlip(object):
def __init__(self, time_flip=False, horizontal_flip=False):
self.time_flip = time_flip
self.horizontal_flip = horizontal_flip
def __call__(self, clip):
if random.random() < 0.5 and self.time_flip:
return clip[::-1]
if random.random() < 0.5 and self.horizontal_flip:
return [np.fliplr(img) for img in clip]
return clip
class RandomResize(object):
"""Resizes a list of (H x W x C) numpy.ndarray to the final size
The larger the original image is, the more times it takes to
interpolate
Args:
interpolation (str): Can be one of 'nearest', 'bilinear'
defaults to nearest
size (tuple): (widht, height)
"""
def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
self.ratio = ratio
self.interpolation = interpolation
def __call__(self, clip):
scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
if isinstance(clip[0], np.ndarray):
im_h, im_w, im_c = clip[0].shape
elif isinstance(clip[0], PIL.Image.Image):
im_w, im_h = clip[0].size
new_w = int(im_w * scaling_factor)
new_h = int(im_h * scaling_factor)
new_size = (new_w, new_h)
resized = resize_clip(
clip, new_size, interpolation=self.interpolation)
return resized
class RandomCrop(object):
"""Extract random crop at the same location for a list of videos
Args:
size (sequence or int): Desired output size for the
crop in format (h, w)
"""
def __init__(self, size):
if isinstance(size, numbers.Number):
size = (size, size)
self.size = size
def __call__(self, clip):
"""
Args:
img (PIL.Image or numpy.ndarray): List of videos to be cropped
in format (h, w, c) in numpy.ndarray
Returns:
PIL.Image or numpy.ndarray: Cropped list of videos
"""
h, w = self.size
if isinstance(clip[0], np.ndarray):
im_h, im_w, im_c = clip[0].shape
elif isinstance(clip[0], PIL.Image.Image):
im_w, im_h = clip[0].size
else:
raise TypeError('Expected numpy.ndarray or PIL.Image '
'but got list of {0}'.format(type(clip[0])))
clip = pad_clip(clip, h, w)
im_h, im_w = clip.shape[1:3]
x1 = 0 if w == im_w else random.randint(0, im_w - w)
y1 = 0 if h == im_h else random.randint(0, im_h - h)
cropped = crop_clip(clip, y1, x1, h, w)
return cropped
class RandomRotation(object):
"""Rotate entire clip randomly by a random angle within
given bounds
Args:
degrees (sequence or int): Range of degrees to select from.
If degrees is a number instead of a sequence like (min, max),
the range of degrees will be (-degrees, +degrees).
"""
def __init__(self, degrees):
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError('If degrees is a single number, '
'it must be positive')
degrees = (-degrees, degrees)
else:
if len(degrees) != 2:
raise ValueError('If degrees is a sequence, '
'it must be of len 2.')
self.degrees = degrees
def __call__(self, clip):
"""
Args:
img (PIL.Image or numpy.ndarray): List of videos to be cropped
in format (h, w, c) in numpy.ndarray
Returns:
PIL.Image or numpy.ndarray: Cropped list of videos
"""
angle = random.uniform(self.degrees[0], self.degrees[1])
if isinstance(clip[0], np.ndarray):
rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
elif isinstance(clip[0], PIL.Image.Image):
rotated = [img.rotate(angle) for img in clip]
else:
raise TypeError('Expected numpy.ndarray or PIL.Image '
'but got list of {0}'.format(type(clip[0])))
return rotated
class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation and hue of the clip
Args:
brightness (float): How much to jitter brightness. brightness_factor
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
contrast (float): How much to jitter contrast. contrast_factor
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
saturation (float): How much to jitter saturation. saturation_factor
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
hue (float): How much to jitter hue. hue_factor is chosen uniformly from
[-hue, hue]. Should be >= 0 and <= 0.5.
"""
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
def get_params(self, brightness, contrast, saturation, hue):
if brightness > 0:
brightness_factor = random.uniform(
max(0, 1 - brightness), 1 + brightness)
else:
brightness_factor = None
if contrast > 0:
contrast_factor = random.uniform(
max(0, 1 - contrast), 1 + contrast)
else:
contrast_factor = None
if saturation > 0:
saturation_factor = random.uniform(
max(0, 1 - saturation), 1 + saturation)
else:
saturation_factor = None
if hue > 0:
hue_factor = random.uniform(-hue, hue)
else:
hue_factor = None
return brightness_factor, contrast_factor, saturation_factor, hue_factor
def __call__(self, clip):
"""
Args:
clip (list): list of PIL.Image or numpy.ndarray
Returns:
list: list of transformed frames
"""
if isinstance(clip[0], np.ndarray):
brightness, contrast, saturation, hue = self.get_params(
self.brightness, self.contrast, self.saturation, self.hue)
# Create img transform function sequence
img_transforms = []
if brightness is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
if saturation is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
if hue is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
if contrast is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
random.shuffle(img_transforms)
img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
img_as_float]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
jittered_clip = []
for img in clip:
jittered_img = img
for func in img_transforms:
jittered_img = func(jittered_img)
jittered_clip.append(jittered_img.astype('float32'))
elif isinstance(clip[0], PIL.Image.Image):
brightness, contrast, saturation, hue = self.get_params(
self.brightness, self.contrast, self.saturation, self.hue)
# Create img transform function sequence
img_transforms = []
if brightness is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
if saturation is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
if hue is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
if contrast is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
random.shuffle(img_transforms)
# Apply to all videos
jittered_clip = []
for img in clip:
jittered_img = img
for func in img_transforms:
jittered_img = func(jittered_img)
jittered_clip.append(jittered_img)
else:
raise TypeError('Expected numpy.ndarray or PIL.Image '
'but got list of {0}'.format(type(clip[0])))
return jittered_clip
class AllAugmentationTransform:
def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
self.transforms = []
if flip_param is not None:
self.transforms.append(RandomFlip(**flip_param))
if rotation_param is not None:
self.transforms.append(RandomRotation(**rotation_param))
if resize_param is not None:
self.transforms.append(RandomResize(**resize_param))
if crop_param is not None:
self.transforms.append(RandomCrop(**crop_param))
if jitter_param is not None:
self.transforms.append(ColorJitter(**jitter_param))
def __call__(self, clip):
for t in self.transforms:
clip = t(clip)
return clip
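A usage sketch (illustrative) that drives `AllAugmentationTransform` with the same `augmentation_params` as `config/mix-resolution.yml`:
```python
import numpy as np
from augmentation import AllAugmentationTransform

transform = AllAugmentationTransform(
    flip_param={'horizontal_flip': True, 'time_flip': True},
    jitter_param={'brightness': 0.1, 'contrast': 0.1,
                  'saturation': 0.1, 'hue': 0.1})

# A "clip" is a list of HxWxC float frames; every frame in the clip
# receives the same randomly sampled augmentation.
clip = [np.random.rand(512, 512, 3).astype('float32') for _ in range(2)]
augmented = transform(clip)
print(len(augmented), augmented[0].shape)  # 2 (512, 512, 3)
```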
================================================
FILE: config/mix-resolution.yml
================================================
dataset_params:
root_dir: ../../../train/cropped_clips_512_vid/
frame_shape: [512, 512, 3]
id_sampling: True
pairs_list: None
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
jitter_param:
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
model_params:
common_params:
num_kp: 15
image_channel: 3
feature_channel: 32
estimate_jacobian: False
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
reshape_channel: 16384 # 16384 = 1024 * 16
reshape_depth: 16
he_estimator_params:
block_expansion: 64
max_features: 2048
num_bins: 66
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
reshape_channel: 32
reshape_depth: 16 # 512 = 32 * 16
num_resblocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 32
max_features: 1024
num_blocks: 5
# reshape_channel: 32
reshape_depth: 16
compress: 4
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
train_params:
num_epochs: 200
num_repeats: 5
num_worker: 8
epoch_milestones: [16,]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
lr_he_estimator: 2.0e-4
gan_mode: 'hinge' # hinge or ls
batch_size: 4
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 1
hopenet_snapshot: './checkpoints/hopenet_robust_alpha1.pkl'
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 1
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 10
equivariance_jacobian: 0
keypoint: 10
headpose: 20
expression: 5
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'
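The `model_params` blocks map one-to-one onto the module constructors. A minimal sketch of how they are consumed, mirroring `load_checkpoints` in `demo.py` below:
```python
import yaml
from modules.generator import OcclusionAwareSPADEGenerator
from modules.keypoint_detector import KPDetector

with open('config/mix-resolution.yml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# common_params (num_kp, image_channel, ...) are shared by every module.
common = config['model_params']['common_params']
generator = OcclusionAwareSPADEGenerator(
    **config['model_params']['generator_params'], **common)
kp_detector = KPDetector(
    **config['model_params']['kp_detector_params'], **common)
```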
================================================
FILE: demo.py
================================================
# python demo.py --config config/vox-256-spade.yml --checkpoint checkpoints/00000189-checkpoint.pth.tar --source_image /home/cxu-serve/p61/rzhu14/lsong11_workspace/Thin-Plate-Spline-Motion-Model/assets/test.png --driving_video /home/cxu-serve/p61/rzhu14/lsong11_workspace/Thin-Plate-Spline-Motion-Model/assets/driving.mp4 --relative --adapt_scale --find_best_frame --gen spade
import matplotlib
matplotlib.use('Agg')
import os, sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
import torch.nn.functional as F
from sync_batchnorm import DataParallelWithCallback
from modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from modules.keypoint_detector import KPDetector, HEEstimator
from animate import normalize_kp
from scipy.spatial import ConvexHull
if sys.version_info[0] < 3:
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
def load_checkpoints(config_path, checkpoint_path, gen, cpu=False):
with open(config_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
if gen == 'original':
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
elif gen == 'spade':
generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
if not cpu:
generator.cuda()
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
if not cpu:
kp_detector.cuda()
he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
**config['model_params']['common_params'])
if not cpu:
he_estimator.cuda()
if cpu:
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
else:
checkpoint = torch.load(checkpoint_path)
generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])
he_estimator.load_state_dict(checkpoint['he_estimator'])
if not cpu:
generator = DataParallelWithCallback(generator)
kp_detector = DataParallelWithCallback(kp_detector)
he_estimator = DataParallelWithCallback(he_estimator)
generator.eval()
kp_detector.eval()
he_estimator.eval()
return generator, kp_detector, he_estimator
def headpose_pred_to_degree(pred):
device = pred.device
idx_tensor = [idx for idx in range(66)]
idx_tensor = torch.FloatTensor(idx_tensor).to(device)
pred = F.softmax(pred, dim=1)
# expected bin index * 3 - 99 maps the 66 head-pose bins to degrees in [-99, 96]
degree = torch.sum(pred*idx_tensor, axis=1) * 3 - 99
return degree
'''
# beta version
def get_rotation_matrix(yaw, pitch, roll):
yaw = yaw / 180 * 3.14
pitch = pitch / 180 * 3.14
roll = roll / 180 * 3.14
roll = roll.unsqueeze(1)
pitch = pitch.unsqueeze(1)
yaw = yaw.unsqueeze(1)
roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll),
torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll),
torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1)
roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch),
torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch),
-torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1)
pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw),
torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw),
torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1)
yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat)
return rot_mat
'''
def get_rotation_matrix(yaw, pitch, roll):
yaw = yaw / 180 * 3.14
pitch = pitch / 180 * 3.14
roll = roll / 180 * 3.14
roll = roll.unsqueeze(1)
pitch = pitch.unsqueeze(1)
yaw = yaw.unsqueeze(1)
pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),
torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)
pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw),
torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
-torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)
yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),
torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),
torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)
roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)
return rot_mat
def keypoint_transformation(kp_canonical, he, estimate_jacobian=True, free_view=False, yaw=0, pitch=0, roll=0):
kp = kp_canonical['value']
if not free_view:
yaw, pitch, roll = he['yaw'], he['pitch'], he['roll']
yaw = headpose_pred_to_degree(yaw)
pitch = headpose_pred_to_degree(pitch)
roll = headpose_pred_to_degree(roll)
else:
if yaw is not None:
yaw = torch.tensor([yaw]).cuda()
else:
yaw = he['yaw']
yaw = headpose_pred_to_degree(yaw)
if pitch is not None:
pitch = torch.tensor([pitch]).cuda()
else:
pitch = he['pitch']
pitch = headpose_pred_to_degree(pitch)
if roll is not None:
roll = torch.tensor([roll]).cuda()
else:
roll = he['roll']
roll = headpose_pred_to_degree(roll)
t, exp = he['t'], he['exp']
rot_mat = get_rotation_matrix(yaw, pitch, roll)
# keypoint rotation
kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
# keypoint translation
t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1)
kp_t = kp_rotated + t
# add expression deviation
exp = exp.view(exp.shape[0], -1, 3)
kp_transformed = kp_t + exp
if estimate_jacobian:
jacobian = kp_canonical['jacobian']
jacobian_transformed = torch.einsum('bmp,bkps->bkms', rot_mat, jacobian)
else:
jacobian_transformed = None
return {'value': kp_transformed, 'jacobian': jacobian_transformed}
def make_animation(source_image, driving_video, generator, kp_detector, he_estimator, relative=True, adapt_movement_scale=True, estimate_jacobian=True, cpu=False, free_view=False, yaw=0, pitch=0, roll=0):
with torch.no_grad():
predictions = []
source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
if not cpu:
source = source.cuda()
driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
kp_canonical = kp_detector(source)
he_source = he_estimator(source)
he_driving_initial = he_estimator(driving[:, :, 0])
kp_source = keypoint_transformation(kp_canonical, he_source, estimate_jacobian)
kp_driving_initial = keypoint_transformation(kp_canonical, he_driving_initial, estimate_jacobian)
# kp_driving_initial = keypoint_transformation(kp_canonical, he_driving_initial, free_view=free_view, yaw=yaw, pitch=pitch, roll=roll)
for frame_idx in tqdm(range(driving.shape[2])):
driving_frame = driving[:, :, frame_idx]
if not cpu:
driving_frame = driving_frame.cuda()
he_driving = he_estimator(driving_frame)
kp_driving = keypoint_transformation(kp_canonical, he_driving, estimate_jacobian, free_view=free_view, yaw=yaw, pitch=pitch, roll=roll)
# np.save('all_kps/%05d.npy'%frame_idx, kp_driving['value'].cpu().detach().numpy())
# import pdb; pdb.set_trace()
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
use_relative_jacobian=estimate_jacobian, adapt_movement_scale=adapt_movement_scale)
out = generator(source, frame_idx, kp_source=kp_source, kp_driving=kp_norm)
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
return predictions
def find_best_frame(source, driving, cpu=False):
import face_alignment
def normalize_kp(kp):
kp = kp - kp.mean(axis=0, keepdims=True)
area = ConvexHull(kp[:, :2]).volume
area = np.sqrt(area)
kp[:, :2] = kp[:, :2] / area
return kp
# fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
# device='cpu' if cpu else 'cuda')
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=True,
device='cpu' if cpu else 'cuda')
kp_source = fa.get_landmarks(255 * source)[0]
kp_source = normalize_kp(kp_source)
norm = float('inf')
frame_num = 0
for i, image in tqdm(enumerate(driving)):
kp_driving = fa.get_landmarks(255 * image)[0]
kp_driving = normalize_kp(kp_driving)
new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
if new_norm < norm:
norm = new_norm
frame_num = i
return frame_num
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--config", default='config/vox-256.yaml', help="path to config")
parser.add_argument("--checkpoint", default='', help="path to checkpoint to restore")
parser.add_argument("--source_image", default='', help="path to source image")
parser.add_argument("--driving_video", default='', help="path to driving video")
parser.add_argument("--result_video", default='./results_hq.mp4', help="path to output")
parser.add_argument("--gen", default="spade", choices=["original", "spade"])
parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)")
parser.add_argument("--best_frame", dest="best_frame", type=int, default=None,
help="Set frame to start from.")
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
parser.add_argument("--free_view", dest="free_view", action="store_true", help="control head pose")
parser.add_argument("--yaw", dest="yaw", type=int, default=None, help="yaw")
parser.add_argument("--pitch", dest="pitch", type=int, default=None, help="pitch")
parser.add_argument("--roll", dest="roll", type=int, default=None, help="roll")
parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False)
parser.set_defaults(free_view=False)
opt = parser.parse_args()
source_image = imageio.imread(opt.source_image)
reader = imageio.get_reader(opt.driving_video)
fps = reader.get_meta_data()['fps']
driving_video = []
try:
for im in reader:
driving_video.append(im)
except RuntimeError:
pass
reader.close()
source_image = resize(source_image, (512, 512))[..., :3]
driving_video = [resize(frame, (512, 512))[..., :3] for frame in driving_video]
generator, kp_detector, he_estimator = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, gen=opt.gen, cpu=opt.cpu)
with open(opt.config) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
estimate_jacobian = config['model_params']['common_params']['estimate_jacobian']
print(f'estimate jacobian: {estimate_jacobian}')
if opt.find_best_frame or opt.best_frame is not None:
i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)
print ("Best frame: " + str(i))
driving_forward = driving_video[i:]
driving_backward = driving_video[:(i+1)][::-1]
predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector, he_estimator, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, estimate_jacobian=estimate_jacobian, cpu=opt.cpu, free_view=opt.free_view, yaw=opt.yaw, pitch=opt.pitch, roll=opt.roll)
predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector, he_estimator, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, estimate_jacobian=estimate_jacobian, cpu=opt.cpu, free_view=opt.free_view, yaw=opt.yaw, pitch=opt.pitch, roll=opt.roll)
predictions = predictions_backward[::-1] + predictions_forward[1:]
else:
predictions = make_animation(source_image, driving_video, generator, kp_detector, he_estimator, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, estimate_jacobian=estimate_jacobian, cpu=opt.cpu, free_view=opt.free_view, yaw=opt.yaw, pitch=opt.pitch, roll=opt.roll)
imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
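A quick sanity check (illustrative) of the head-pose pipeline above: the 66 pose bins map to degrees in roughly [-99, 96], and `get_rotation_matrix` returns orthonormal rotation matrices regardless of the angle values:
```python
import torch
from demo import headpose_pred_to_degree, get_rotation_matrix

logits = torch.randn(2, 66)            # fake head-pose bin logits
deg = headpose_pred_to_degree(logits)  # expected bin index * 3 - 99
assert (-99 <= deg).all() and (deg <= 96).all()

R = get_rotation_matrix(deg, deg, deg)  # (2, 3, 3), pitch @ yaw @ roll
I = torch.eye(3).expand_as(R)
print(torch.allclose(R @ R.transpose(1, 2), I, atol=1e-5))  # True
```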
================================================
FILE: environment.yaml
================================================
name: mesh-video
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.01.10=h06a4308_0
- certifi=2022.12.7=py38h06a4308_0
- cudatoolkit=11.3.1=h9edb442_10
- flit-core=3.8.0=py38h06a4308_0
- freetype=2.12.1=h4a9f257_0
- giflib=5.2.1=h5eee18b_3
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- intel-openmp=2021.4.0=h06a4308_3561
- jpeg=9e=h5eee18b_1
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- lerc=3.0=h295c915_0
- libdeflate=1.17=h5eee18b_0
- libedit=3.1.20221030=h5eee18b_0
- libffi=3.2.1=hf484d3e_1007
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libidn2=2.3.2=h7f8727e_0
- libopus=1.3.1=h7b6447c_0
- libpng=1.6.39=h5eee18b_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.0=h6a678d5_2
- libunistring=0.9.10=h27cfd23_0
- libuv=1.44.2=h5eee18b_0
- libvpx=1.7.0=h439df22_0
- libwebp=1.2.4=h11a3e52_1
- libwebp-base=1.2.4=h5eee18b_1
- lz4-c=1.9.4=h6a678d5_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py38h7f8727e_0
- mkl_fft=1.3.1=py38hd3c417c_0
- mkl_random=1.2.2=py38h51133e4_0
- ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1
- numpy-base=1.23.5=py38h31eccc5_0
- openh264=2.1.1=h4ff587b_0
- openssl=1.1.1t=h7f8727e_0
- pillow=9.4.0=py38h6a678d5_0
- pip=23.0.1=py38h06a4308_0
- python=3.8.0=h0371630_2
- pytorch=1.10.1=py3.8_cuda11.3_cudnn8.2.0_0
- pytorch-mutex=1.0=cuda
- readline=7.0=h7b6447c_5
- setuptools=65.6.3=py38h06a4308_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.33.0=h62c20be_0
- tk=8.6.12=h1ccaba5_0
- torchaudio=0.10.1=py38_cu113
- torchvision=0.11.2=py38_cu113
- typing_extensions=4.4.0=py38h06a4308_0
- wheel=0.38.4=py38h06a4308_0
- x264=1!157.20191217=h7b6447c_0
- xz=5.2.10=h5eee18b_1
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.4=hc292b87_0
- pip:
- cffi==1.14.6
- cycler==0.10.0
- decorator==5.1.0
- face-alignment==1.3.5
- ffmpeg==1.4
- imageio==2.9.0
- imageio-ffmpeg==0.4.5
- importlib-metadata==6.0.0
- joblib==1.2.0
- kiwisolver==1.3.2
- llvmlite==0.39.1
- matplotlib==3.4.3
- networkx==2.6.3
- numba==0.56.4
- numpy==1.20.3
- nvidia-cublas-cu11==11.10.3.66
- nvidia-cuda-nvrtc-cu11==11.7.99
- nvidia-cuda-runtime-cu11==11.7.99
- nvidia-cudnn-cu11==8.5.0.96
- opencv-python==4.7.0.72
- pandas==1.3.3
- pycparser==2.20
- pyparsing==2.4.7
- python-dateutil==2.8.2
- pytube==12.1.3
- pytz==2021.1
- pywavelets==1.1.1
- pyyaml==5.4.1
- scikit-image==0.18.3
- scikit-learn==1.0
- scipy==1.7.1
- threadpoolctl==3.1.0
- tifffile==2023.2.28
- torch==1.13.1
- tqdm==4.62.3
- typing-extensions==4.5.0
- zipp==3.15.0
prefix: /home/songlc/miniconda3/envs/mesh-video
================================================
FILE: frames_dataset.py
================================================
#CUDA_VISIBLE_DEVICES=1 python run.py --config log_TH1K/finetune-th1k-spade.yml --device_ids 0 --checkpoint log_TH1K/00000001-checkpoint.pth.tar
import os
from skimage import io, img_as_float32
from skimage.color import gray2rgb
from sklearn.model_selection import train_test_split
from imageio import mimread
from functools import partial
from skimage.transform import resize
import torch
import random
import numpy as np
from torch.utils.data import Dataset
import pandas as pd
from augmentation import AllAugmentationTransform
import glob
import math
import pickle
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
def read_video(name, frame_shape):
"""
Read video which can be:
- an image of concatenated frames
- '.mp4', '.gif' or '.mov'
- a folder with frames
"""
if os.path.isdir(name):
frames = sorted(os.listdir(name))
num_frames = len(frames)
video_array = np.array(
[img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
image = io.imread(name)
if len(image.shape) == 2 or image.shape[2] == 1:
image = gray2rgb(image)
if image.shape[2] == 4:
image = image[..., :3]
image = img_as_float32(image)
video_array = np.moveaxis(image, 1, 0)
video_array = video_array.reshape((-1,) + frame_shape)
video_array = np.moveaxis(video_array, 1, 2)
elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
video = np.array(mimread(name))
if len(video.shape) == 3:
video = np.array([gray2rgb(frame) for frame in video])
if video.shape[-1] == 4:
video = video[..., :3]
video_array = img_as_float32(video)
else:
raise Exception("Unknown file extensions %s" % name)
return video_array
class FramesDataset(Dataset):
"""
Dataset of videos, each video can be represented as:
- an image of concatenated frames
- '.mp4' or '.gif'
- folder with all frames
"""
def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
random_seed=0, pairs_list=None, augmentation_params=None):
self.root_dir = root_dir
with open(os.path.join(root_dir, 'train_file_list.pickle'), 'rb') as tmp_file:
self.train_files_list = pickle.load(tmp_file)
self.videos = os.listdir(root_dir)
self.frame_shape = tuple(frame_shape)
self.pairs_list = pairs_list
self.id_sampling = id_sampling
if os.path.exists(os.path.join(root_dir, 'train')):
assert os.path.exists(os.path.join(root_dir, 'test'))
print("Use predefined train-test split.")
if id_sampling:
# train_videos = {os.path.basename(video).split('#')[0] for video in
# os.listdir(os.path.join(root_dir, 'train'))}
# train_videos = list(train_videos)
train_videos = list(self.train_files_list.keys())
else:
train_videos = os.listdir(os.path.join(root_dir, 'train'))
test_videos = os.listdir(os.path.join(root_dir, 'test'))
self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
else:
print("Use random train-test split.")
train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
if is_train:
self.videos = train_videos
else:
self.videos = test_videos
self.is_train = is_train
if self.is_train:
self.transform = AllAugmentationTransform(**augmentation_params)
#### for degradation ####
self.kernel_range = [2 * v + 1 for v in range(1,3)]
self.pulse_tensor = torch.zeros(11, 11).float()
self.pulse_tensor[5, 5] = 1
self.resize_range = [0.15, 1.5]
# blur settings for the first degradation
self.blur_kernel_size = 7
self.kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
self.kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] # a list for each kernel probability
self.blur_sigma = [0.1, 0.5]
self.betag_range = [0.2, 1] # betag used in generalized Gaussian blur kernels
self.betap_range = [0.5, 1.2] # betap used in plateau blur kernels
self.sinc_prob = 0.1 # the probability for sinc filters
# blur settings for the second degradation
self.blur_kernel_size2 = 7
self.kernel_list2 = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
self.kernel_prob2 = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
self.blur_sigma2 = [0.1, 0.5]
self.betag_range2 = [0.2, 1]
self.betap_range2 = [1, 1.2]
self.sinc_prob2 = 0.1
else:
self.transform = None
def __len__(self):
return len(self.videos)
def __getitem__(self, idx):
if self.is_train and self.id_sampling:
# name = self.videos[idx]
# path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
name = self.videos[idx]
choice_list = self.train_files_list[name]
# if len(choice_list) == 0:
# name = self.videos[idx-1]
# choice_list = self.train_files_list[name]
paths = np.random.choice(choice_list)
else:
name = self.videos[idx]
paths = os.path.join(self.root_dir, name)
video_name = os.path.basename(paths)
if self.is_train and os.path.isdir(paths):
frames = os.listdir(paths)
num_frames = len(frames)
frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
if self.frame_shape is not None:
resize_fn = partial(resize, output_shape=self.frame_shape)
else:
resize_fn = img_as_float32
video_array = [resize_fn(img_as_float32(io.imread(paths + '/' + '%06d.jpg'%(idx) ))) for idx in frame_idx]
else:
video_array = read_video(paths, frame_shape=self.frame_shape)
num_frames = len(video_array)
frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
num_frames)
video_array = video_array[frame_idx]
if self.transform is not None:
video_array = self.transform(video_array)
out = {}
if self.is_train:
source = np.array(video_array[0], dtype='float32')
driving = np.array(video_array[1], dtype='float32')
out['driving'] = driving.transpose((2, 0, 1))
out['source'] = source.transpose((2, 0, 1))
# if self.degradation:
############ run degradation ############
# ---- Generate kernels (used in the first degradation) ---- #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < 0.1:
# sinc filter (comment inherited from Real-ESRGAN assumes kernel sizes in [7, 21]; here kernel_range is [3, 5])
if kernel_size < 11:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
self.betag_range,
self.betap_range,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
# ----- Generate kernels (used in the second degradation) ---- #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < 0.1:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
self.kernel_list2,
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
self.blur_sigma2, [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
# ---- the final sinc kernel ---- #
if np.random.uniform() < 0.8:
kernel_size = random.choice(self.kernel_range)
omega_c = np.random.uniform(np.pi / 3, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=11)
sinc_kernel = torch.FloatTensor(sinc_kernel)
else:
sinc_kernel = self.pulse_tensor
# BGR to RGB, HWC to CHW, numpy to tensor
# img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
kernel = torch.FloatTensor(kernel)
kernel2 = torch.FloatTensor(kernel2)
#########################################
out['kernel'] = kernel
out['kernel2']= kernel2
out['sinc_kernel'] = sinc_kernel
else:
video = np.array(video_array, dtype='float32')
out['video'] = video.transpose((3, 0, 1, 2))
out['name'] = video_name
return out
class DatasetRepeater(Dataset):
"""
Pass several times over the same dataset for better i/o performance
"""
def __init__(self, dataset, num_repeats=100):
self.dataset = dataset
self.num_repeats = num_repeats
def __len__(self):
return self.num_repeats * self.dataset.__len__()
def __getitem__(self, idx):
return self.dataset[idx % self.dataset.__len__()]
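The kernels sampled in `__getitem__` are meant to be applied on the GPU during training, in the spirit of Real-ESRGAN's on-the-fly degradation (the training code is not released yet). A minimal sketch of applying `out['kernel']` to a batch of source frames; a plain depthwise convolution stands in for basicsr's kernel-filtering helper:
```python
import torch
import torch.nn.functional as F

def apply_blur_kernel(img, kernel):
    """img: (B, 3, H, W) in [0, 1]; kernel: (B, 21, 21) from FramesDataset."""
    b, c, h, w = img.shape
    k = kernel.shape[-1]
    # One kernel per batch element, shared across channels: grouped conv.
    weight = kernel.view(b, 1, k, k).repeat_interleave(c, dim=0)  # (B*3, 1, k, k)
    out = F.conv2d(img.view(1, b * c, h, w), weight,
                   padding=k // 2, groups=b * c)
    return out.view(b, c, h, w)

source = torch.rand(4, 3, 512, 512)
kernel = torch.rand(4, 21, 21)
kernel = kernel / kernel.sum(dim=(1, 2), keepdim=True)  # normalize to sum 1
print(apply_blur_kernel(source, kernel).shape)  # torch.Size([4, 3, 512, 512])
```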
================================================
FILE: logger.py
================================================
import numpy as np
import torch
import torch.nn.functional as F
import imageio
import os
from skimage.draw import circle
import matplotlib.pyplot as plt
import collections
class Logger:
def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
self.loss_list = []
self.cpk_dir = log_dir
self.visualizations_dir = os.path.join(log_dir, 'train-vis')
if not os.path.exists(self.visualizations_dir):
os.makedirs(self.visualizations_dir)
self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
self.zfill_num = zfill_num
self.visualizer = Visualizer(**visualizer_params)
self.checkpoint_freq = checkpoint_freq
self.epoch = 0
self.best_loss = float('inf')
self.names = None
def log_scores(self, loss_names):
loss_mean = np.array(self.loss_list).mean(axis=0)
loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
print(loss_string, file=self.log_file)
self.loss_list = []
self.log_file.flush()
def visualize_rec(self, inp, out):
image = self.visualizer.visualize(inp['driving'], inp['source'], out)
imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
def save_cpk(self, emergent=False):
cpk = {k: v.state_dict() for k, v in self.models.items()}
cpk['epoch'] = self.epoch
cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch + 1).zfill(self.zfill_num))
if not (os.path.exists(cpk_path) and emergent):
torch.save(cpk, cpk_path)
@staticmethod
def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None, he_estimator=None,
optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None, optimizer_he_estimator=None):
checkpoint = torch.load(checkpoint_path)
if generator is not None:
generator.load_state_dict(checkpoint['generator'])
if kp_detector is not None:
kp_detector.load_state_dict(checkpoint['kp_detector'])
if he_estimator is not None:
he_estimator.load_state_dict(checkpoint['he_estimator'])
if discriminator is not None:
try:
discriminator.load_state_dict(checkpoint['discriminator'])
except (KeyError, RuntimeError):
print ('No discriminator in the state-dict. Discriminator will be randomly initialized')
if optimizer_generator is not None:
optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
if optimizer_discriminator is not None:
try:
optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
except (KeyError, RuntimeError) as e:
print ('No discriminator optimizer in the state-dict. Optimizer will not be initialized')
if optimizer_kp_detector is not None:
optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
if optimizer_he_estimator is not None:
optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
return checkpoint['epoch']
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if 'models' in self.__dict__:
self.save_cpk()
self.log_file.close()
def log_iter(self, losses):
losses = collections.OrderedDict(losses.items())
if self.names is None:
self.names = list(losses.keys())
self.loss_list.append(list(losses.values()))
def log_epoch(self, epoch, models, inp, out):
self.epoch = epoch
self.models = models
if (self.epoch + 1) % self.checkpoint_freq == 0:
self.save_cpk()
self.log_scores(self.names)
self.visualize_rec(inp, out)
class Visualizer:
def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
self.kp_size = kp_size
self.draw_border = draw_border
self.colormap = plt.get_cmap(colormap)
def draw_image_with_kp(self, image, kp_array):
image = np.copy(image)
spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
kp_array = spatial_size * (kp_array + 1) / 2
num_kp = kp_array.shape[0]
for kp_ind, kp in enumerate(kp_array):
rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
return image
def create_image_column_with_kp(self, images, kp):
image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
return self.create_image_column(image_array)
def create_image_column(self, images):
if self.draw_border:
images = np.copy(images)
images[:, :, [0, -1]] = (1, 1, 1)
images[:, [0, -1]] = (1, 1, 1)
return np.concatenate(list(images), axis=0)
def create_image_grid(self, *args):
out = []
for arg in args:
if type(arg) == tuple:
out.append(self.create_image_column_with_kp(arg[0], arg[1]))
else:
out.append(self.create_image_column(arg))
return np.concatenate(out, axis=1)
def visualize(self, driving, source, out):
images = []
# Source image with keypoints
source = source.data.cpu()
kp_source = out['kp_source']['value'][:, :, :2].data.cpu().numpy() # 3d -> 2d
source = np.transpose(source, [0, 2, 3, 1])
images.append((source, kp_source))
# Equivariance visualization
if 'transformed_frame' in out:
transformed = out['transformed_frame'].data.cpu().numpy()
transformed = np.transpose(transformed, [0, 2, 3, 1])
transformed_kp = out['transformed_kp']['value'][:, :, :2].data.cpu().numpy() # 3d -> 2d
images.append((transformed, transformed_kp))
# Driving image with keypoints
kp_driving = out['kp_driving']['value'][:, :, :2].data.cpu().numpy() # 3d -> 2d
driving = driving.data.cpu().numpy()
driving = np.transpose(driving, [0, 2, 3, 1])
images.append((driving, kp_driving))
# Result
prediction = out['prediction'].data.cpu().numpy()
prediction = np.transpose(prediction, [0, 2, 3, 1])
images.append(prediction)
## Occlusion map
if 'occlusion_map' in out:
occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1)
occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
images.append(occlusion_map)
## Mask
if 'mask' in out:
for i in range(out['mask'].shape[1]):
mask = out['mask'][:, i:(i+1)].data.cpu().sum(2).repeat(1, 3, 1, 1) # (n, 3, h, w)
# mask = F.softmax(mask.view(mask.shape[0], mask.shape[1], -1), dim=2).view(mask.shape)
mask = F.interpolate(mask, size=source.shape[1:3]).numpy()
mask = np.transpose(mask, [0, 2, 3, 1])
if i != 0:
color = np.array(self.colormap((i - 1) / (out['mask'].shape[1] - 1)))[:3]
else:
color = np.array((0, 0, 0))
color = color.reshape((1, 1, 1, 3))
if i != 0:
images.append(mask * color)
else:
images.append(mask)
image = self.create_image_grid(*images)
image = (255 * image).astype(np.uint8)
return image
================================================
FILE: modules/dense_motion.py
================================================
from torch import nn
import torch.nn.functional as F
import torch
from modules.util import Hourglass, make_coordinate_grid, kp2gaussian
from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
class DenseMotionNetwork(nn.Module):
"""
Module that predicts dense motion from the sparse motion representation given by kp_source and kp_driving
"""
def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress,
estimate_occlusion_map=False):
super(DenseMotionNetwork, self).__init__()
# self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks)
self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks)
self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3)
self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1)
self.norm = BatchNorm3d(compress, affine=True)
if estimate_occlusion_map:
# self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3)
self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
else:
self.occlusion = None
self.num_kp = num_kp
def create_sparse_motions(self, feature, kp_driving, kp_source):
bs, _, d, h, w = feature.shape
identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type())
identity_grid = identity_grid.view(1, 1, d, h, w, 3)
coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3)
k = coordinate_grid.shape[1]
# if 'jacobian' in kp_driving:
if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None:
jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1)
coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
coordinate_grid = coordinate_grid.squeeze(-1)
'''
if 'rot' in kp_driving:
rot_s = kp_source['rot']
rot_d = kp_driving['rot']
rot = torch.einsum('bij, bjk->bki', rot_s, torch.inverse(rot_d))
rot = rot.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
rot = rot.repeat(1, k, d, h, w, 1, 1)
# print(rot.shape)
coordinate_grid = torch.matmul(rot, coordinate_grid.unsqueeze(-1))
coordinate_grid = coordinate_grid.squeeze(-1)
# print(coordinate_grid.shape)
'''
driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
#adding background feature
identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
# sparse_motions = driving_to_source
return sparse_motions
def create_deformed_feature(self, feature, sparse_motions):
bs, _, d, h, w = feature.shape
feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3)
sparse_deformed = F.grid_sample(feature_repeat, sparse_motions)
sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
return sparse_deformed
def create_heatmap_representations(self, feature, kp_driving, kp_source):
spatial_size = feature.shape[3:]
gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01)
gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01)
heatmap = gaussian_driving - gaussian_source
# adding background feature
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type())
heatmap = torch.cat([zeros, heatmap], dim=1)
heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
return heatmap
def forward(self, feature, kp_driving, kp_source):
bs, _, d, h, w = feature.shape
feature = self.compress(feature)
feature = self.norm(feature)
feature = F.relu(feature)
out_dict = dict()
sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source)
deformed_feature = self.create_deformed_feature(feature, sparse_motion)
heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)
input = torch.cat([heatmap, deformed_feature], dim=2)
input = input.view(bs, -1, d, h, w)
# input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w)
prediction = self.hourglass(input)
mask = self.mask(prediction)
mask = F.softmax(mask, dim=1)
out_dict['mask'] = mask
mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w)
deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
out_dict['deformation'] = deformation
if self.occlusion:
bs, c, d, h, w = prediction.shape
prediction = prediction.view(bs, -1, h, w)
occlusion_map = torch.sigmoid(self.occlusion(prediction))
out_dict['occlusion_map'] = occlusion_map
return out_dict
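A dummy forward pass with the `dense_motion_params` from `config/mix-resolution.yml`, to make the tensor shapes concrete. Illustrative only: in the real pipeline the 3D feature volume comes from the generator, and this assumes the Hourglass in `modules/util.py` pools only spatially, as in the face-vid2vid codebase:
```python
import torch
from modules.dense_motion import DenseMotionNetwork

net = DenseMotionNetwork(block_expansion=32, num_blocks=5, max_features=1024,
                         num_kp=15, feature_channel=32, reshape_depth=16,
                         compress=4, estimate_occlusion_map=True)

feature = torch.randn(1, 32, 16, 64, 64)  # (bs, feature_channel, d, h, w)
kp = lambda: {'value': torch.rand(1, 15, 3) * 2 - 1, 'jacobian': None}

out = net(feature, kp_driving=kp(), kp_source=kp())
print(out['deformation'].shape)    # (1, 16, 64, 64, 3) flow over the volume
print(out['occlusion_map'].shape)  # (1, 1, 64, 64)
```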
================================================
FILE: modules/discriminator.py
================================================
from torch import nn
import torch.nn.functional as F
from modules.util import kp2gaussian
import torch
class DownBlock2d(nn.Module):
"""
Simple block for processing video (encoder).
"""
def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
super(DownBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
if sn:
self.conv = nn.utils.spectral_norm(self.conv)
if norm:
self.norm = nn.InstanceNorm2d(out_features, affine=True)
else:
self.norm = None
self.pool = pool
def forward(self, x):
out = x
out = self.conv(out)
if self.norm:
out = self.norm(out)
out = F.leaky_relu(out, 0.2)
if self.pool:
out = F.avg_pool2d(out, (2, 2))
return out
class Discriminator(nn.Module):
"""
Discriminator similar to Pix2Pix
"""
def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
sn=False, **kwargs):
super(Discriminator, self).__init__()
down_blocks = []
for i in range(num_blocks):
down_blocks.append(
DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
min(max_features, block_expansion * (2 ** (i + 1))),
norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
self.down_blocks = nn.ModuleList(down_blocks)
self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
if sn:
self.conv = nn.utils.spectral_norm(self.conv)
def forward(self, x):
feature_maps = []
out = x
for down_block in self.down_blocks:
feature_maps.append(down_block(out))
out = feature_maps[-1]
prediction_map = self.conv(out)
return feature_maps, prediction_map
class MultiScaleDiscriminator(nn.Module):
"""
    Multi-scale discriminator (one sub-discriminator per image scale).
"""
def __init__(self, scales=(), **kwargs):
super(MultiScaleDiscriminator, self).__init__()
self.scales = scales
discs = {}
for scale in scales:
discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
self.discs = nn.ModuleDict(discs)
def forward(self, x):
out_dict = {}
for scale, disc in self.discs.items():
scale = str(scale).replace('-', '.')
key = 'prediction_' + scale
feature_maps, prediction_map = disc(x[key])
out_dict['feature_maps_' + scale] = feature_maps
out_dict['prediction_map_' + scale] = prediction_map
return out_dict
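# ----------------------------------------------------------------------------
# Editor's sketch (not part of the original file): MultiScaleDiscriminator
# consumes a dict keyed 'prediction_<scale>', which is exactly the format
# produced by ImagePyramide in modules/model.py.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    disc = MultiScaleDiscriminator(scales=[1, 0.5], num_channels=3)
    pyramid = {'prediction_1': torch.rand(2, 3, 256, 256),
               'prediction_0.5': torch.rand(2, 3, 128, 128)}
    out = disc(pyramid)
    print(sorted(out.keys()))  # feature_maps_* and prediction_map_* per scale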
================================================
FILE: modules/generator.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock
from modules.dense_motion import DenseMotionNetwork
import torchvision
class OcclusionAwareGenerator(nn.Module):
"""
Generator follows NVIDIA architecture.
"""
def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
super(OcclusionAwareGenerator, self).__init__()
if dense_motion_params is not None:
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
estimate_occlusion_map=estimate_occlusion_map,
**dense_motion_params)
else:
self.dense_motion_network = None
self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3))
down_blocks = []
for i in range(num_down_blocks):
in_features = min(max_features, block_expansion * (2 ** i))
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
self.down_blocks = nn.ModuleList(down_blocks)
self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
self.reshape_channel = reshape_channel
self.reshape_depth = reshape_depth
self.resblocks_3d = torch.nn.Sequential()
for i in range(num_resblocks):
self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
out_features = block_expansion * (2 ** (num_down_blocks))
self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
self.resblocks_2d = torch.nn.Sequential()
for i in range(num_resblocks):
self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1))
up_blocks = []
for i in range(num_down_blocks):
in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i)))
out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1)))
up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
self.up_blocks = nn.ModuleList(up_blocks)
self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3))
self.estimate_occlusion_map = estimate_occlusion_map
self.image_channel = image_channel
def deform_input(self, inp, deformation):
_, d_old, h_old, w_old, _ = deformation.shape
_, _, d, h, w = inp.shape
if d_old != d or h_old != h or w_old != w:
deformation = deformation.permute(0, 4, 1, 2, 3)
deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
deformation = deformation.permute(0, 2, 3, 4, 1)
return F.grid_sample(inp, deformation)
def forward(self, source_image, kp_driving, kp_source):
# Encoding (downsampling) part
out = self.first(source_image)
for i in range(len(self.down_blocks)):
out = self.down_blocks[i](out)
out = self.second(out)
bs, c, h, w = out.shape
# print(out.shape)
        feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)
feature_3d = self.resblocks_3d(feature_3d)
# Transforming feature representation according to deformation and occlusion
output_dict = {}
if self.dense_motion_network is not None:
dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
kp_source=kp_source)
output_dict['mask'] = dense_motion['mask']
if 'occlusion_map' in dense_motion:
occlusion_map = dense_motion['occlusion_map']
output_dict['occlusion_map'] = occlusion_map
else:
occlusion_map = None
deformation = dense_motion['deformation']
out = self.deform_input(feature_3d, deformation)
bs, c, d, h, w = out.shape
out = out.view(bs, c*d, h, w)
out = self.third(out)
out = self.fourth(out)
if occlusion_map is not None:
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
out = out * occlusion_map
# output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image
# Decoding part
out = self.resblocks_2d(out)
for i in range(len(self.up_blocks)):
out = self.up_blocks[i](out)
out = self.final(out)
        out = torch.sigmoid(out)  # torch.sigmoid replaces the deprecated F.sigmoid
output_dict["prediction"] = out
return output_dict
class SPADEDecoder(nn.Module):
def __init__(self):
super().__init__()
ic = 256
oc = 64
norm_G = 'spadespectralinstance'
label_nc = 256
self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1)
self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc)
self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc)
self.conv_img = nn.Conv2d(oc, 3, 3, padding=1)
self.up = nn.Upsample(scale_factor=2)
def forward(self, feature):
seg = feature
x = self.fc(feature)
x = self.G_middle_0(x, seg)
x = self.G_middle_1(x, seg)
x = self.G_middle_2(x, seg)
x = self.G_middle_3(x, seg)
x = self.G_middle_4(x, seg)
x = self.G_middle_5(x, seg)
x = self.up(x)
x = self.up_0(x, seg) # 256, 128, 128
x = self.up(x)
x = self.up_1(x, seg) # 64, 256, 256
x = self.conv_img(F.leaky_relu(x, 2e-1))
# x = torch.tanh(x)
        x = torch.sigmoid(x)  # torch.sigmoid replaces the deprecated F.sigmoid
return x
class OcclusionAwareSPADEGenerator(nn.Module):
def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
super(OcclusionAwareSPADEGenerator, self).__init__()
if dense_motion_params is not None:
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
estimate_occlusion_map=estimate_occlusion_map,
**dense_motion_params)
else:
self.dense_motion_network = None
self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
down_blocks = []
for i in range(num_down_blocks):
in_features = min(max_features, block_expansion * (2 ** i))
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
self.down_blocks = nn.ModuleList(down_blocks)
self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
self.reshape_channel = reshape_channel
self.reshape_depth = reshape_depth
self.resblocks_3d = torch.nn.Sequential()
for i in range(num_resblocks):
self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
out_features = block_expansion * (2 ** (num_down_blocks))
self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
self.estimate_occlusion_map = estimate_occlusion_map
self.image_channel = image_channel
self.decoder = SPADEDecoder()
def deform_input(self, inp, deformation):
_, d_old, h_old, w_old, _ = deformation.shape
_, _, d, h, w = inp.shape
if d_old != d or h_old != h or w_old != w:
deformation = deformation.permute(0, 4, 1, 2, 3)
deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
deformation = deformation.permute(0, 2, 3, 4, 1)
return F.grid_sample(inp, deformation)
    def forward(self, source_image, frame_idx=None, kp_driving=None, kp_source=None):
        # frame_idx is optional: it is only referenced by the commented-out
        # feature-dump debug lines below, and GeneratorFullModel calls the
        # generator without it.
# Encoding (downsampling) part
# import pdb; pdb.set_trace()
# import torchvision.utils as vutils, torchvision.utils.save_image(feature_3d[0][1:2,:3,], 'feature.png')
out = self.first(source_image)
# torchvision.utils.save_image(out[0][:1,], 'ablation_features/feature_1_%05d.png'%frame_idx)
for i in range(len(self.down_blocks)):
out = self.down_blocks[i](out)
# torchvision.utils.save_image(out[0][:1,], 'ablation_features/feature_2_%05d.png'%frame_idx)
out = self.second(out)
# torchvision.utils.save_image(out[0][:1,], 'ablation_features/feature_3_%05d.png'%frame_idx)
bs, c, h, w = out.shape
# print(out.shape)
        feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)
feature_3d = self.resblocks_3d(feature_3d)
# torchvision.utils.save_image(feature_3d[0][1:2,:1,], 'ablation_features/feature_4_%05d.png'%frame_idx)
# Transforming feature representation according to deformation and occlusion
output_dict = {}
if self.dense_motion_network is not None:
dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
kp_source=kp_source)
output_dict['mask'] = dense_motion['mask']
if 'occlusion_map' in dense_motion:
occlusion_map = dense_motion['occlusion_map']
output_dict['occlusion_map'] = occlusion_map
else:
occlusion_map = None
deformation = dense_motion['deformation']
out = self.deform_input(feature_3d, deformation)
# torchvision.utils.save_image(out[0][1:2,:1,], 'ablation_features/feature_5_%05d.png'%frame_idx)
bs, c, d, h, w = out.shape
out = out.view(bs, c*d, h, w)
out = self.third(out)
# torchvision.utils.save_image(out[:1,:1,], 'ablation_features/feature_6_%05d.png'%frame_idx)
out = self.fourth(out)
# torchvision.utils.save_image(out[:1,:1,], 'ablation_features/feature_7_%05d.png'%frame_idx)
if occlusion_map is not None:
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
out = out * occlusion_map
# Decoding part
# torchvision.utils.save_image(out[:1,:1,], 'ablation_features/feature_8_%05d.png'%frame_idx)
out = self.decoder(out)
# torchvision.utils.save_image(out[:1,:1,], 'ablation_features/feature_9_%05d.png'%frame_idx)
output_dict["prediction"] = out
#
return output_dict
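# ----------------------------------------------------------------------------
# Editor's sketch (not part of the original file): deform_input treats the
# deformation tensor as normalized grid_sample coordinates in [-1, 1]. With
# the identity grid from modules.util.make_coordinate_grid the features come
# back (nearly) unchanged; align_corners=True matches the grid's [-1, 1]
# endpoint convention.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    from modules.util import make_coordinate_grid
    feat = torch.rand(1, 4, 8, 32, 32)                                  # (bs, c, d, h, w)
    grid = make_coordinate_grid((8, 32, 32), feat.type()).unsqueeze(0)  # (1, d, h, w, 3)
    warped = F.grid_sample(feat, grid, align_corners=True)
    print((warped - feat).abs().max())  # ~0: the identity grid is a no-op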
================================================
FILE: modules/hopenet.py
================================================
import torch
import torch.nn as nn
from torch.autograd import Variable
import math
import torch.nn.functional as F
class Hopenet(nn.Module):
# Hopenet with 3 output layers for yaw, pitch and roll
# Predicts Euler angles by binning and regression with the expected value
def __init__(self, block, layers, num_bins):
self.inplanes = 64
super(Hopenet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7)
self.fc_yaw = nn.Linear(512 * block.expansion, num_bins)
self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
self.fc_roll = nn.Linear(512 * block.expansion, num_bins)
# Vestigial layer from previous experiments
self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
pre_yaw = self.fc_yaw(x)
pre_pitch = self.fc_pitch(x)
pre_roll = self.fc_roll(x)
return pre_yaw, pre_pitch, pre_roll
class ResNet(nn.Module):
# ResNet for regression of 3 Euler angles.
def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7)
self.fc_angles = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc_angles(x)
return x
class AlexNet(nn.Module):
# AlexNet laid out as a Hopenet - classify Euler angles in bins and
# regress the expected value.
def __init__(self, num_bins):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
)
self.fc_yaw = nn.Linear(4096, num_bins)
self.fc_pitch = nn.Linear(4096, num_bins)
self.fc_roll = nn.Linear(4096, num_bins)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), 256 * 6 * 6)
x = self.classifier(x)
yaw = self.fc_yaw(x)
pitch = self.fc_pitch(x)
roll = self.fc_roll(x)
return yaw, pitch, roll
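# ----------------------------------------------------------------------------
# Editor's sketch (not part of the original file): each head emits logits
# over num_bins pose bins; a continuous angle is recovered as the softmax
# expectation over bin indices, scaled to degrees (3 degrees per bin, offset
# so the range is roughly [-99, 99]). modules/model.py implements the same
# decoding in headpose_pred_to_degree.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    from torchvision.models.resnet import Bottleneck
    net = Hopenet(Bottleneck, [3, 4, 6, 3], num_bins=66)
    yaw_logits, pitch_logits, roll_logits = net(torch.rand(1, 3, 224, 224))
    idx = torch.arange(66, dtype=torch.float32)
    yaw_deg = (F.softmax(yaw_logits, dim=1) * idx).sum(dim=1) * 3 - 99
    print(yaw_deg)  # continuous yaw estimate in degrees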
================================================
FILE: modules/keypoint_detector.py
================================================
from torch import nn
import torch
import torch.nn.functional as F
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
from modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck
class KPDetector(nn.Module):
"""
    Detect canonical keypoints. Returns keypoint positions and, if enabled, a Jacobian near each keypoint.
"""
def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth,
num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False):
super(KPDetector, self).__init__()
self.predictor = KPHourglass(block_expansion, in_features=image_channel,
max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks)
# self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3)
self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1)
if estimate_jacobian:
self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
# self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3)
self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1)
'''
initial as:
[[1 0 0]
[0 1 0]
[0 0 1]]
'''
self.jacobian.weight.data.zero_()
self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
else:
self.jacobian = None
self.temperature = temperature
self.scale_factor = scale_factor
if self.scale_factor != 1:
self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor)
def gaussian2kp(self, heatmap):
"""
Extract the mean from a heatmap
"""
shape = heatmap.shape
heatmap = heatmap.unsqueeze(-1)
grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
value = (heatmap * grid).sum(dim=(2, 3, 4))
kp = {'value': value}
return kp
def forward(self, x):
if self.scale_factor != 1:
x = self.down(x)
feature_map = self.predictor(x)
prediction = self.kp(feature_map)
final_shape = prediction.shape
heatmap = prediction.view(final_shape[0], final_shape[1], -1)
heatmap = F.softmax(heatmap / self.temperature, dim=2)
heatmap = heatmap.view(*final_shape)
out = self.gaussian2kp(heatmap)
if self.jacobian is not None:
jacobian_map = self.jacobian(feature_map)
jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2],
final_shape[3], final_shape[4])
heatmap = heatmap.unsqueeze(2)
jacobian = heatmap * jacobian_map
jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1)
jacobian = jacobian.sum(dim=-1)
jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3)
out['jacobian'] = jacobian
return out
class HEEstimator(nn.Module):
"""
Estimating head pose and expression.
"""
def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True):
super(HEEstimator, self).__init__()
self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2)
self.norm1 = BatchNorm2d(block_expansion, affine=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1)
self.norm2 = BatchNorm2d(256, affine=True)
self.block1 = nn.Sequential()
for i in range(3):
self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1))
self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)
self.norm3 = BatchNorm2d(512, affine=True)
self.block2 = ResBottleneck(in_features=512, stride=2)
self.block3 = nn.Sequential()
for i in range(3):
self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1))
self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1)
self.norm4 = BatchNorm2d(1024, affine=True)
self.block4 = ResBottleneck(in_features=1024, stride=2)
self.block5 = nn.Sequential()
for i in range(5):
self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1))
self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1)
self.norm5 = BatchNorm2d(2048, affine=True)
self.block6 = ResBottleneck(in_features=2048, stride=2)
self.block7 = nn.Sequential()
for i in range(2):
self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1))
self.fc_roll = nn.Linear(2048, num_bins)
self.fc_pitch = nn.Linear(2048, num_bins)
self.fc_yaw = nn.Linear(2048, num_bins)
self.fc_t = nn.Linear(2048, 3)
self.fc_exp = nn.Linear(2048, 3*num_kp)
def forward(self, x):
out = self.conv1(x)
out = self.norm1(out)
out = F.relu(out)
out = self.maxpool(out)
out = self.conv2(out)
out = self.norm2(out)
out = F.relu(out)
out = self.block1(out)
out = self.conv3(out)
out = self.norm3(out)
out = F.relu(out)
out = self.block2(out)
out = self.block3(out)
out = self.conv4(out)
out = self.norm4(out)
out = F.relu(out)
out = self.block4(out)
out = self.block5(out)
out = self.conv5(out)
out = self.norm5(out)
out = F.relu(out)
out = self.block6(out)
out = self.block7(out)
out = F.adaptive_avg_pool2d(out, 1)
out = out.view(out.shape[0], -1)
        # NOTE: the yaw and roll heads are cross-wired here (fc_roll produces
        # 'yaw' and fc_yaw produces 'roll'); kept as-is for compatibility with
        # checkpoints trained under this wiring.
        yaw = self.fc_roll(out)
        pitch = self.fc_pitch(out)
        roll = self.fc_yaw(out)
t = self.fc_t(out)
exp = self.fc_exp(out)
return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
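# ----------------------------------------------------------------------------
# Editor's sketch (not part of the original file): KPDetector.gaussian2kp is
# a soft-argmax. The normalized heatmap is a distribution over the (d, h, w)
# grid, and the keypoint is its expected coordinate in [-1, 1]^3; a unit mass
# at one voxel recovers that voxel's grid coordinate exactly.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    d, h, w = 4, 8, 8
    heatmap = torch.zeros(1, 1, d, h, w)
    heatmap[0, 0, 2, 4, 6] = 1.0  # all probability mass at one voxel
    grid = make_coordinate_grid((d, h, w), heatmap.type())
    kp = (heatmap.unsqueeze(-1) * grid).sum(dim=(2, 3, 4))
    print(kp)  # (x, y, z) = (2*6/7-1, 2*4/7-1, 2*2/3-1) ~ (0.714, 0.143, 0.333)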
================================================
FILE: modules/model.py
================================================
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid_2d
from torchvision import models
import numpy as np
from torch.autograd import grad
import modules.hopenet as hopenet
from torchvision import transforms
import random
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from prefetch_generator import BackgroundGenerator
from basicsr.utils.img_process_util import filter2D
from basicsr.utils import DiffJPEG, USMSharp
class Vgg19(torch.nn.Module):
"""
Vgg19 network for perceptual loss.
"""
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
requires_grad=False)
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
requires_grad=False)
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
X = (X - self.mean) / self.std
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
class ImagePyramide(torch.nn.Module):
"""
    Create an image pyramid for computing the pyramid perceptual loss.
"""
def __init__(self, scales, num_channels):
super(ImagePyramide, self).__init__()
downs = {}
for scale in scales:
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
self.downs = nn.ModuleDict(downs)
def forward(self, x):
out_dict = {}
for scale, down_module in self.downs.items():
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
return out_dict
class Transform:
"""
    Random TPS (thin-plate spline) transformation for equivariance constraints.
"""
def __init__(self, bs, **kwargs):
noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
self.bs = bs
if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
self.tps = True
self.control_points = make_coordinate_grid_2d((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
self.control_points = self.control_points.unsqueeze(0)
self.control_params = torch.normal(mean=0,
std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
else:
self.tps = False
def transform_frame(self, frame):
grid = make_coordinate_grid_2d(frame.shape[2:], type=frame.type()).unsqueeze(0)
grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
return F.grid_sample(frame, grid, padding_mode="reflection")
def warp_coordinates(self, coordinates):
theta = self.theta.type(coordinates.type())
theta = theta.unsqueeze(1)
transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
transformed = transformed.squeeze(-1)
if self.tps:
control_points = self.control_points.type(coordinates.type())
control_params = self.control_params.type(coordinates.type())
distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
distances = torch.abs(distances).sum(-1)
result = distances ** 2
result = result * torch.log(distances + 1e-6)
result = result * control_params
result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
transformed = transformed + result
return transformed
def jacobian(self, coordinates):
new_coordinates = self.warp_coordinates(coordinates)
grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
return jacobian
def detach_kp(kp):
return {key: value.detach() for key, value in kp.items()}
def headpose_pred_to_degree(pred):
device = pred.device
idx_tensor = [idx for idx in range(66)]
idx_tensor = torch.FloatTensor(idx_tensor).to(device)
    pred = F.softmax(pred, dim=1)
    degree = torch.sum(pred * idx_tensor, dim=1) * 3 - 99
return degree
'''
# beta version
def get_rotation_matrix(yaw, pitch, roll):
yaw = yaw / 180 * 3.14
pitch = pitch / 180 * 3.14
roll = roll / 180 * 3.14
roll = roll.unsqueeze(1)
pitch = pitch.unsqueeze(1)
yaw = yaw.unsqueeze(1)
roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll),
torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll),
torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1)
roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch),
torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch),
-torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1)
pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw),
torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw),
torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1)
yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat)
return rot_mat
'''
def get_rotation_matrix(yaw, pitch, roll):
yaw = yaw / 180 * 3.14
pitch = pitch / 180 * 3.14
roll = roll / 180 * 3.14
roll = roll.unsqueeze(1)
pitch = pitch.unsqueeze(1)
yaw = yaw.unsqueeze(1)
pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),
torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)
pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw),
torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
-torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)
yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),
torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),
torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)
roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)
return rot_mat
def keypoint_transformation(kp_canonical, he, estimate_jacobian=True):
kp = kp_canonical['value'] # (bs, k, 3)
yaw, pitch, roll = he['yaw'], he['pitch'], he['roll']
t, exp = he['t'], he['exp']
yaw = headpose_pred_to_degree(yaw)
pitch = headpose_pred_to_degree(pitch)
roll = headpose_pred_to_degree(roll)
rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3)
# keypoint rotation
kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
# keypoint translation
t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1)
kp_t = kp_rotated + t
# add expression deviation
exp = exp.view(exp.shape[0], -1, 3)
kp_transformed = kp_t + exp
if estimate_jacobian:
jacobian = kp_canonical['jacobian'] # (bs, k ,3, 3)
jacobian_transformed = torch.einsum('bmp,bkps->bkms', rot_mat, jacobian)
else:
jacobian_transformed = None
return {'value': kp_transformed, 'jacobian': jacobian_transformed}
class GeneratorFullModel(torch.nn.Module):
"""
    Merge all generator-related updates into a single model for better multi-GPU usage.
"""
def __init__(self, kp_extractor, he_estimator, generator, discriminator, train_params, estimate_jacobian=True):
super(GeneratorFullModel, self).__init__()
self.kp_extractor = kp_extractor
self.he_estimator = he_estimator
self.generator = generator
self.discriminator = discriminator
self.train_params = train_params
self.scales = train_params['scales']
self.disc_scales = self.discriminator.scales
self.pyramid = ImagePyramide(self.scales, generator.image_channel)
if torch.cuda.is_available():
self.pyramid = self.pyramid.cuda()
self.loss_weights = train_params['loss_weights']
self.estimate_jacobian = estimate_jacobian
if sum(self.loss_weights['perceptual']) != 0:
self.vgg = Vgg19()
if torch.cuda.is_available():
self.vgg = self.vgg.cuda()
self.L1 = nn.L1Loss().cuda()
if self.loss_weights['headpose'] != 0:
self.hopenet = hopenet.Hopenet(models.resnet.Bottleneck, [3, 4, 6, 3], 66)
print('Loading hopenet')
hopenet_state_dict = torch.load(train_params['hopenet_snapshot'])
self.hopenet.load_state_dict(hopenet_state_dict)
if torch.cuda.is_available():
self.hopenet = self.hopenet.cuda()
self.hopenet.eval()
self.jpeger = DiffJPEG(differentiable=False).cuda()
self.usm_sharpener = USMSharp().cuda()
        self.resize_prob = [0.2, 0.7, 0.1]
        self.resize_range = [0.5, 1.2]
        self.noise_range = [1, 10]
        self.poisson_scale_range = [0.05, 1]
        self.jpeg_range = [10, 25]
        self.opt_scale = 1
        self.resize_prob2 = [0.3, 0.4, 0.3]
        self.resize_range2 = [0.5, 1.2]
        self.noise_range2 = [1, 10]
        self.poisson_scale_range2 = [0.05, 1.0]
        self.jpeg_range2 = [10, 25]
def forward(self, x, config):
kp_canonical = self.kp_extractor(x['source']) # {'value': value, 'jacobian': jacobian}
he_source = self.he_estimator(x['source']) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
he_driving = self.he_estimator(x['driving']) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
# {'value': value, 'jacobian': jacobian}
kp_source = keypoint_transformation(kp_canonical, he_source, self.estimate_jacobian)
kp_driving = keypoint_transformation(kp_canonical, he_driving, self.estimate_jacobian)
if config['train_params']['low_quality_train']:
# ----------------------- The first degradation process ----------------------- #
# blur
# x_source = self.usm_sharpener(x['source'])
x_source = x['source']
x_source = filter2D(x_source, x['kernel'])
# random resize
updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob)[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.resize_range[1])
elif updown_type == 'down':
scale = np.random.uniform(self.resize_range[0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
x_source = F.interpolate(x_source, scale_factor=scale, mode=mode)
# add noise
gray_noise_prob = 0.4
if np.random.uniform() < 0.5:
x_source = random_add_gaussian_noise_pt(
x_source, sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
x_source = random_add_poisson_noise_pt(
x_source,
scale_range=self.poisson_scale_range,
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression
jpeg_p = x_source.new_zeros(x_source.size(0)).uniform_(*self.jpeg_range)
x_source = torch.clamp(x_source, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
x_source = self.jpeger(x_source, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- #
# blur
if np.random.uniform() < 0.8:
x_source = filter2D(x_source, x['kernel2'].cuda())
# random resize
updown_type = random.choices(['up', 'down', 'keep'], self.resize_prob2)[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.resize_range2[1])
elif updown_type == 'down':
scale = np.random.uniform(self.resize_range2[0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
x_source = F.interpolate(
x_source, size=(int(config['dataset_params']['frame_shape'][0] / self.opt_scale * scale), int(config['dataset_params']['frame_shape'][1] / self.opt_scale * scale)), mode=mode)
# add noise
gray_noise_prob = 0.4
if np.random.uniform() < 0.5:
x_source = random_add_gaussian_noise_pt(
x_source, sigma_range=self.noise_range2, clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
x_source = random_add_poisson_noise_pt(
x_source,
scale_range=self.poisson_scale_range2,
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if np.random.uniform() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
x_source = F.interpolate(x_source, size=(config['dataset_params']['frame_shape'][0] // self.opt_scale, config['dataset_params']['frame_shape'][1] // self.opt_scale), mode=mode)
x_source = filter2D(x_source, x['sinc_kernel'].cuda())
# JPEG compression
jpeg_p = x_source.new_zeros(x_source.size(0)).uniform_(*self.jpeg_range2)
x_source = torch.clamp(x_source, 0, 1)
x_source = self.jpeger(x_source, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = x_source.new_zeros(x_source.size(0)).uniform_(*self.jpeg_range2)
x_source = torch.clamp(x_source, 0, 1)
x_source = self.jpeger(x_source, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
x_source = F.interpolate(x_source, size=(config['dataset_params']['frame_shape'][0] // self.opt_scale, config['dataset_params']['frame_shape'][1] // self.opt_scale), mode=mode)
x_source = filter2D(x_source, x['sinc_kernel'].cuda())
# clamp and round
lq = torch.clamp((x_source * 255.0).round(), 0, 255) / 255.
lq_img = lq.contiguous()
generated = self.generator(lq_img, kp_source=kp_source, kp_driving=kp_driving)
else:
generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)
generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
loss_values = {}
pyramide_real = self.pyramid(x['driving'])
pyramide_generated = self.pyramid(generated['prediction'])
if sum(self.loss_weights['perceptual']) != 0:
value_total = 0
for scale in self.scales:
x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
for i, weight in enumerate(self.loss_weights['perceptual']):
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
value_total += self.loss_weights['perceptual'][i] * value
loss_values['perceptual'] = value_total
if self.loss_weights['generator_gan'] != 0:
discriminator_maps_generated = self.discriminator(pyramide_generated)
discriminator_maps_real = self.discriminator(pyramide_real)
value_total = 0
for scale in self.disc_scales:
key = 'prediction_map_%s' % scale
if self.train_params['gan_mode'] == 'hinge':
value = -torch.mean(discriminator_maps_generated[key])
elif self.train_params['gan_mode'] == 'ls':
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
else:
raise ValueError('Unexpected gan_mode {}'.format(self.train_params['gan_mode']))
value_total += self.loss_weights['generator_gan'] * value
loss_values['gen_gan'] = value_total
        if sum(self.loss_weights['feature_matching']) != 0:
            # NOTE: reuses discriminator_maps_* from the generator_gan branch
            # above, so feature matching requires a non-zero generator_gan
            # weight.
value_total = 0
for scale in self.disc_scales:
key = 'feature_maps_%s' % scale
for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
if self.loss_weights['feature_matching'][i] == 0:
continue
value = torch.abs(a - b).mean()
value_total += self.loss_weights['feature_matching'][i] * value
loss_values['feature_matching'] = value_total
if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
transformed_frame = transform.transform_frame(x['driving'])
transformed_he_driving = self.he_estimator(transformed_frame)
transformed_kp = keypoint_transformation(kp_canonical, transformed_he_driving, self.estimate_jacobian)
generated['transformed_frame'] = transformed_frame
generated['transformed_kp'] = transformed_kp
## Value loss part
if self.loss_weights['equivariance_value'] != 0:
# project 3d -> 2d
kp_driving_2d = kp_driving['value'][:, :, :2]
transformed_kp_2d = transformed_kp['value'][:, :, :2]
value = torch.abs(kp_driving_2d - transform.warp_coordinates(transformed_kp_2d)).mean()
loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
## jacobian loss part
if self.loss_weights['equivariance_jacobian'] != 0:
# project 3d -> 2d
transformed_kp_2d = transformed_kp['value'][:, :, :2]
transformed_jacobian_2d = transformed_kp['jacobian'][:, :, :2, :2]
jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp_2d),
transformed_jacobian_2d)
jacobian_2d = kp_driving['jacobian'][:, :, :2, :2]
normed_driving = torch.inverse(jacobian_2d)
normed_transformed = jacobian_transformed
value = torch.matmul(normed_driving, normed_transformed)
eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
value = torch.abs(eye - value).mean()
loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value
if self.loss_weights['keypoint'] != 0:
# print(kp_driving['value'].shape) # (bs, k, 3)
value_total = 0
for i in range(kp_driving['value'].shape[1]):
for j in range(kp_driving['value'].shape[1]):
dist = F.pairwise_distance(kp_driving['value'][:, i, :], kp_driving['value'][:, j, :], p=2, keepdim=True) ** 2
dist = 0.1 - dist # set Dt = 0.1
dd = torch.gt(dist, 0)
value = (dist * dd).mean()
value_total += value
kp_mean_depth = kp_driving['value'][:, :, -1].mean(-1)
value_depth = torch.abs(kp_mean_depth - 0.33).mean() # set Zt = 0.33
value_total += value_depth
loss_values['keypoint'] = self.loss_weights['keypoint'] * value_total
if self.loss_weights['headpose'] != 0:
transform_hopenet = transforms.Compose([transforms.Resize(size=(224, 224)),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
driving_224 = transform_hopenet(x['driving'])
yaw_gt, pitch_gt, roll_gt = self.hopenet(driving_224)
yaw_gt = headpose_pred_to_degree(yaw_gt)
pitch_gt = headpose_pred_to_degree(pitch_gt)
roll_gt = headpose_pred_to_degree(roll_gt)
yaw, pitch, roll = he_driving['yaw'], he_driving['pitch'], he_driving['roll']
yaw = headpose_pred_to_degree(yaw)
pitch = headpose_pred_to_degree(pitch)
roll = headpose_pred_to_degree(roll)
value = torch.abs(yaw - yaw_gt).mean() + torch.abs(pitch - pitch_gt).mean() + torch.abs(roll - roll_gt).mean()
loss_values['headpose'] = self.loss_weights['headpose'] * value
if self.loss_weights['expression'] != 0:
value = torch.norm(he_driving['exp'], p=1, dim=-1).mean()
loss_values['expression'] = self.loss_weights['expression'] * value
loss_values['reconstruction'] = self.loss_weights['reconstruction'] * self.L1(generated['prediction'], x['driving'])
return loss_values, generated
class DiscriminatorFullModel(torch.nn.Module):
"""
    Merge all discriminator-related updates into a single model for better multi-GPU usage.
"""
def __init__(self, kp_extractor, generator, discriminator, train_params):
super(DiscriminatorFullModel, self).__init__()
self.kp_extractor = kp_extractor
self.generator = generator
self.discriminator = discriminator
self.train_params = train_params
self.scales = self.discriminator.scales
self.pyramid = ImagePyramide(self.scales, generator.image_channel)
if torch.cuda.is_available():
self.pyramid = self.pyramid.cuda()
self.loss_weights = train_params['loss_weights']
self.zero_tensor = None
def get_zero_tensor(self, input):
if self.zero_tensor is None:
self.zero_tensor = torch.FloatTensor(1).fill_(0).cuda()
self.zero_tensor.requires_grad_(False)
return self.zero_tensor.expand_as(input)
def forward(self, x, generated):
pyramide_real = self.pyramid(x['driving'])
pyramide_generated = self.pyramid(generated['prediction'].detach())
discriminator_maps_generated = self.discriminator(pyramide_generated)
discriminator_maps_real = self.discriminator(pyramide_real)
loss_values = {}
value_total = 0
for scale in self.scales:
key = 'prediction_map_%s' % scale
if self.train_params['gan_mode'] == 'hinge':
value = -torch.mean(torch.min(discriminator_maps_real[key]-1, self.get_zero_tensor(discriminator_maps_real[key]))) - torch.mean(torch.min(-discriminator_maps_generated[key]-1, self.get_zero_tensor(discriminator_maps_generated[key])))
elif self.train_params['gan_mode'] == 'ls':
value = ((1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2).mean()
else:
raise ValueError('Unexpected gan_mode {}'.format(self.train_params['gan_mode']))
value_total += self.loss_weights['discriminator_gan'] * value
loss_values['disc_gan'] = value_total
return loss_values
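# ----------------------------------------------------------------------------
# Editor's sketch (not part of the original file): sanity checks for the pose
# pipeline. The decoded rotation matrix should be close to orthonormal (3.14
# only approximates pi), and keypoint_transformation preserves the (bs, k, 3)
# keypoint layout.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    bs, k = 2, 15
    he = {'yaw': torch.randn(bs, 66), 'pitch': torch.randn(bs, 66),
          'roll': torch.randn(bs, 66), 't': torch.randn(bs, 3),
          'exp': torch.randn(bs, 3 * k)}
    kp_canonical = {'value': torch.randn(bs, k, 3)}
    R = get_rotation_matrix(headpose_pred_to_degree(he['yaw']),
                            headpose_pred_to_degree(he['pitch']),
                            headpose_pred_to_degree(he['roll']))
    print(torch.matmul(R, R.transpose(1, 2))[0])  # ~ 3x3 identity
    kp = keypoint_transformation(kp_canonical, he, estimate_jacobian=False)
    assert kp['value'].shape == (bs, k, 3)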
================================================
FILE: modules/util.py
================================================
from torch import nn
import torch.nn.functional as F
import torch
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
import torch.nn.utils.spectral_norm as spectral_norm
import re
def kp2gaussian(kp, spatial_size, kp_variance):
"""
    Transform a keypoint into a Gaussian-like representation.
"""
mean = kp['value']
coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
number_of_leading_dimensions = len(mean.shape) - 1
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
coordinate_grid = coordinate_grid.view(*shape)
repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
coordinate_grid = coordinate_grid.repeat(*repeats)
# Preprocess kp shape
shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
mean = mean.view(*shape)
mean_sub = (coordinate_grid - mean)
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
return out
def make_coordinate_grid_2d(spatial_size, type):
"""
Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
"""
h, w = spatial_size
x = torch.arange(w).type(type)
y = torch.arange(h).type(type)
x = (2 * (x / (w - 1)) - 1)
y = (2 * (y / (h - 1)) - 1)
yy = y.view(-1, 1).repeat(1, w)
xx = x.view(1, -1).repeat(h, 1)
meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
return meshed
def make_coordinate_grid(spatial_size, type):
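    """
    Create a meshgrid [-1,1] x [-1,1] x [-1,1] of the given spatial_size (d, h, w).
    """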
d, h, w = spatial_size
x = torch.arange(w).type(type)
y = torch.arange(h).type(type)
z = torch.arange(d).type(type)
x = (2 * (x / (w - 1)) - 1)
y = (2 * (y / (h - 1)) - 1)
z = (2 * (z / (d - 1)) - 1)
yy = y.view(1, -1, 1).repeat(d, 1, w)
xx = x.view(1, 1, -1).repeat(d, h, 1)
zz = z.view(-1, 1, 1).repeat(1, h, w)
meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
return meshed
class ResBottleneck(nn.Module):
def __init__(self, in_features, stride):
super(ResBottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1)
self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1)
self.norm1 = BatchNorm2d(in_features//4, affine=True)
self.norm2 = BatchNorm2d(in_features//4, affine=True)
self.norm3 = BatchNorm2d(in_features, affine=True)
self.stride = stride
if self.stride != 1:
self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride)
self.norm4 = BatchNorm2d(in_features, affine=True)
def forward(self, x):
out = self.conv1(x)
out = self.norm1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.norm2(out)
out = F.relu(out)
out = self.conv3(out)
out = self.norm3(out)
if self.stride != 1:
x = self.skip(x)
x = self.norm4(x)
out += x
out = F.relu(out)
return out
class ResBlock2d(nn.Module):
"""
Res block, preserve spatial resolution.
"""
def __init__(self, in_features, kernel_size, padding):
super(ResBlock2d, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.norm1 = BatchNorm2d(in_features, affine=True)
self.norm2 = BatchNorm2d(in_features, affine=True)
def forward(self, x):
out = self.norm1(x)
out = F.relu(out)
out = self.conv1(out)
out = self.norm2(out)
out = F.relu(out)
out = self.conv2(out)
out += x
return out
class ResBlock3d(nn.Module):
"""
Res block, preserve spatial resolution.
"""
def __init__(self, in_features, kernel_size, padding):
super(ResBlock3d, self).__init__()
self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.norm1 = BatchNorm3d(in_features, affine=True)
self.norm2 = BatchNorm3d(in_features, affine=True)
def forward(self, x):
out = self.norm1(x)
out = F.relu(out)
out = self.conv1(out)
out = self.norm2(out)
out = F.relu(out)
out = self.conv2(out)
out += x
return out
class UpBlock2d(nn.Module):
"""
Upsampling block for use in decoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(UpBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = BatchNorm2d(out_features, affine=True)
def forward(self, x):
out = F.interpolate(x, scale_factor=2)
out = self.conv(out)
out = self.norm(out)
out = F.relu(out)
return out
class UpBlock3d(nn.Module):
"""
Upsampling block for use in decoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(UpBlock3d, self).__init__()
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = BatchNorm3d(out_features, affine=True)
def forward(self, x):
# out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear')
out = F.interpolate(x, scale_factor=(1, 2, 2))
out = self.conv(out)
out = self.norm(out)
out = F.relu(out)
return out
class DownBlock2d(nn.Module):
"""
Downsampling block for use in encoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(DownBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = BatchNorm2d(out_features, affine=True)
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = F.relu(out)
out = self.pool(out)
return out
class DownBlock3d(nn.Module):
"""
Downsampling block for use in encoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(DownBlock3d, self).__init__()
'''
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups, stride=(1, 2, 2))
'''
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = BatchNorm3d(out_features, affine=True)
self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = F.relu(out)
out = self.pool(out)
return out
class SameBlock2d(nn.Module):
"""
Simple block, preserve spatial resolution.
"""
def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
super(SameBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
kernel_size=kernel_size, padding=padding, groups=groups)
self.norm = BatchNorm2d(out_features, affine=True)
if lrelu:
self.ac = nn.LeakyReLU()
else:
self.ac = nn.ReLU()
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = self.ac(out)
return out
class Encoder(nn.Module):
"""
Hourglass Encoder
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Encoder, self).__init__()
down_blocks = []
for i in range(num_blocks):
down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
min(max_features, block_expansion * (2 ** (i + 1))),
kernel_size=3, padding=1))
self.down_blocks = nn.ModuleList(down_blocks)
def forward(self, x):
outs = [x]
for down_block in self.down_blocks:
outs.append(down_block(outs[-1]))
return outs
class Decoder(nn.Module):
"""
Hourglass Decoder
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Decoder, self).__init__()
up_blocks = []
for i in range(num_blocks)[::-1]:
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
out_filters = min(max_features, block_expansion * (2 ** i))
up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
self.up_blocks = nn.ModuleList(up_blocks)
# self.out_filters = block_expansion
self.out_filters = block_expansion + in_features
self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
self.norm = BatchNorm3d(self.out_filters, affine=True)
def forward(self, x):
out = x.pop()
# for up_block in self.up_blocks[:-1]:
for up_block in self.up_blocks:
out = up_block(out)
skip = x.pop()
out = torch.cat([out, skip], dim=1)
# out = self.up_blocks[-1](out)
out = self.conv(out)
out = self.norm(out)
out = F.relu(out)
return out
class Hourglass(nn.Module):
"""
Hourglass architecture.
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Hourglass, self).__init__()
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
self.out_filters = self.decoder.out_filters
def forward(self, x):
return self.decoder(self.encoder(x))
class KPHourglass(nn.Module):
"""
Hourglass architecture.
"""
def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256):
super(KPHourglass, self).__init__()
self.down_blocks = nn.Sequential()
for i in range(num_blocks):
self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
min(max_features, block_expansion * (2 ** (i + 1))),
kernel_size=3, padding=1))
in_filters = min(max_features, block_expansion * (2 ** num_blocks))
self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1)
self.up_blocks = nn.Sequential()
for i in range(num_blocks):
in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i)))
out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1)))
self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
self.reshape_depth = reshape_depth
self.out_filters = out_filters
def forward(self, x):
out = self.down_blocks(x)
out = self.conv(out)
bs, c, h, w = out.shape
out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w)
out = self.up_blocks(out)
return out
class AntiAliasInterpolation2d(nn.Module):
"""
Band-limited downsampling, for better preservation of the input signal.
"""
def __init__(self, channels, scale):
super(AntiAliasInterpolation2d, self).__init__()
sigma = (1 / scale - 1) / 2
kernel_size = 2 * round(sigma * 4) + 1
self.ka = kernel_size // 2
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
kernel_size = [kernel_size, kernel_size]
sigma = [sigma, sigma]
# The gaussian kernel is the product of the
# gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid(
[
torch.arange(size, dtype=torch.float32)
for size in kernel_size
]
)
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / torch.sum(kernel)
# Reshape to depthwise convolutional weight
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
self.scale = scale
inv_scale = 1 / scale
self.int_inv_scale = int(inv_scale)
def forward(self, input):
if self.scale == 1.0:
return input
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
out = F.conv2d(out, weight=self.weight, groups=self.groups)
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
return out
class SPADE(nn.Module):
def __init__(self, norm_nc, label_nc):
super().__init__()
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
nhidden = 128
self.mlp_shared = nn.Sequential(
nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
nn.ReLU())
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
def forward(self, x, segmap):
normalized = self.param_free_norm(x)
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
actv = self.mlp_shared(segmap)
gamma = self.mlp_gamma(actv)
beta = self.mlp_beta(actv)
out = normalized * (1 + gamma) + beta
return out
class SPADEResnetBlock(nn.Module):
def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
super().__init__()
# Attributes
self.learned_shortcut = (fin != fout)
fmiddle = min(fin, fout)
self.use_se = use_se
# create conv layers
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
if self.learned_shortcut:
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# apply spectral norm if specified
if 'spectral' in norm_G:
self.conv_0 = spectral_norm(self.conv_0)
self.conv_1 = spectral_norm(self.conv_1)
if self.learned_shortcut:
self.conv_s = spectral_norm(self.conv_s)
# define normalization layers
self.norm_0 = SPADE(fin, label_nc)
self.norm_1 = SPADE(fmiddle, label_nc)
if self.learned_shortcut:
self.norm_s = SPADE(fin, label_nc)
def forward(self, x, seg1):
x_s = self.shortcut(x, seg1)
dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
out = x_s + dx
return out
def shortcut(self, x, seg1):
if self.learned_shortcut:
x_s = self.conv_s(self.norm_s(x, seg1))
else:
x_s = x
return x_s
def actvn(self, x):
return F.leaky_relu(x, 2e-1)
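# ----------------------------------------------------------------------------
# Editor's sketch (not part of the original file): AntiAliasInterpolation2d
# low-pass filters with a Gaussian matched to the scale factor before
# striding, so downsampling avoids aliasing. scale=0.25 maps 256x256 to
# 64x64.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    down = AntiAliasInterpolation2d(channels=3, scale=0.25)
    x = torch.rand(1, 3, 256, 256)
    print(down(x).shape)  # torch.Size([1, 3, 64, 64])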
================================================
FILE: run_demo.sh
================================================
python demo.py \
--config config/mix-resolution.yml \
--checkpoint checkpoints/mix-train.pth.tar \
--source_image DEMO/demo_img_3.jpg \
--driving_video DEMO/demo_video_1.mp4 \
--relative
================================================
FILE: sync_batchnorm/__init__.py
================================================
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .replicate import DataParallelWithCallback, patch_replication_callback
================================================
FILE: sync_batchnorm/batchnorm.py
================================================
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import collections
import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
from .comm import SyncMaster
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
def _sum_ft(tensor):
"""sum over the first and last dimention"""
return tensor.sum(dim=0).sum(dim=-1)
def _unsqueeze_ft(tensor):
"""add new dementions at the front and the tail"""
return tensor.unsqueeze(0).unsqueeze(-1)
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
class _SynchronizedBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
self._sync_master = SyncMaster(self._data_parallel_master)
self._is_parallel = False
self._parallel_id = None
self._slave_pipe = None
def forward(self, input):
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
if not (self._is_parallel and self.training):
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
# Resize the input to (B, C, -1).
input_shape = input.size()
input = input.view(input.size(0), self.num_features, -1)
# Compute the sum and square-sum.
sum_size = input.size(0) * input.size(2)
input_sum = _sum_ft(input)
input_ssum = _sum_ft(input ** 2)
# Reduce-and-broadcast the statistics.
if self._parallel_id == 0:
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
else:
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
# Compute the output.
if self.affine:
# MJY:: Fuse the multiplication for speed.
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
else:
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
# Reshape it.
return output.view(input_shape)
def __data_parallel_replicate__(self, ctx, copy_id):
self._is_parallel = True
self._parallel_id = copy_id
# parallel_id == 0 means master device.
if self._parallel_id == 0:
ctx.sync_master = self._sync_master
else:
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
def _data_parallel_master(self, intermediates):
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
# Always using the same "device order" makes the ReduceAdd operation faster.
# Thanks to:: Tete Xiao (http://tetexiao.com/)
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
to_reduce = [i[1][:2] for i in intermediates]
to_reduce = [j for i in to_reduce for j in i] # flatten
target_gpus = [i[1].sum.get_device() for i in intermediates]
sum_size = sum([i[1].sum_size for i in intermediates])
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
outputs = []
for i, rec in enumerate(intermediates):
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
return outputs
def _compute_mean_std(self, sum_, ssum, size):
"""Compute the mean and standard-deviation with sum and square-sum. This method
also maintains the moving average on the master device."""
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
mean = sum_ / size
sumvar = ssum - sum_ * mean
unbias_var = sumvar / (size - 1)
bias_var = sumvar / size
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
return mean, bias_var.clamp(self.eps) ** -0.5
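# Editor's sketch (added; not part of the original file): the arithmetic of
# _compute_mean_std for the pooled values [1., 2., 3.] (sum_=6, ssum=14, size=3).
def _demo_sync_stats():
    x = torch.tensor([1., 2., 3.])
    sum_, ssum, size = x.sum(), (x ** 2).sum(), x.numel()
    mean = sum_ / size                             # 2.0
    sumvar = ssum - sum_ * mean                    # 14 - 6*2 = 2.0
    inv_std = (sumvar / size).clamp(1e-5) ** -0.5  # (2/3) ** -0.5 ~= 1.2247
    return mean, inv_std  # the unbiased sumvar/(size-1) only updates running_var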
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
mini-batch.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm1d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
Args:
num_features: num_features from an expected input of size
`batch_size x num_features [x width]`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C)` or :math:`(N, C, L)`
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
of 3d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm2d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
of 4d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm3d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
or Spatio-temporal BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x depth x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
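# Editor's usage sketch (added; not part of the original file). On a single
# device these modules fall through to F.batch_norm; cross-device statistics
# require wrapping the model so replication registers every copy, e.g.
# (assuming two visible GPUs):
#
#     from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback
#     net = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), SynchronizedBatchNorm2d(16))
#     net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])
#     out = net(torch.randn(8, 3, 64, 64).cuda())  # stats pooled over both GPUs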
================================================
FILE: sync_batchnorm/comm.py
================================================
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
"""A thread-safe future implementation. Used only as one-to-one pipe."""
def __init__(self):
self._result = None
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
def put(self, result):
with self._lock:
assert self._result is None, 'Previous result hasn\'t been fetched.'
self._result = result
self._cond.notify()
def get(self):
with self._lock:
if self._result is None:
self._cond.wait()
res = self._result
self._result = None
return res
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
"""Pipe for master-slave communication."""
def run_slave(self, msg):
self.queue.put((self.identifier, msg))
ret = self.result.get()
self.queue.put(True)
return ret
class SyncMaster(object):
"""An abstract `SyncMaster` object.
- During replication, as data parallel triggers a callback on each module, all slave devices should
call `register(id)` and obtain a `SlavePipe` to communicate with the master.
- During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
and passed to a registered callback.
- After receiving the messages, the master device gathers the information and determines the message to be passed
back to each slave device.
"""
def __init__(self, master_callback):
"""
Args:
master_callback: a callback to be invoked after having collected messages from slave devices.
"""
self._master_callback = master_callback
self._queue = queue.Queue()
self._registry = collections.OrderedDict()
self._activated = False
def __getstate__(self):
return {'master_callback': self._master_callback}
def __setstate__(self, state):
self.__init__(state['master_callback'])
def register_slave(self, identifier):
"""
Register a slave device.
Args:
identifier: an identifier, usually the device id.
Returns: a `SlavePipe` object which can be used to communicate with the master device.
"""
if self._activated:
assert self._queue.empty(), 'Queue is not clean before next initialization.'
self._activated = False
self._registry.clear()
future = FutureResult()
self._registry[identifier] = _MasterRegistry(future)
return SlavePipe(identifier, self._queue, future)
def run_master(self, master_msg):
"""
Main entry for the master device in each forward pass.
The messages are first collected from each device (including the master device), and then
a callback is invoked to compute the message to be sent back to each device
(including the master device).
Args:
master_msg: the message that the master wants to send to itself. This will be placed as the first
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
Returns: the message to be sent back to the master device.
"""
self._activated = True
intermediates = [(0, master_msg)]
for i in range(self.nr_slaves):
intermediates.append(self._queue.get())
results = self._master_callback(intermediates)
assert results[0][0] == 0, 'The first result should belong to the master.'
for i, res in results:
if i == 0:
continue
self._registry[i].result.put(res)
for i in range(self.nr_slaves):
assert self._queue.get() is True
return results[0][1]
@property
def nr_slaves(self):
return len(self._registry)
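# Editor's sketch (added; not part of the original file): the handshake in
# miniature, with one slave on a worker thread. The master callback receives
# [(0, master_msg), (1, slave_msg)] and its first result goes back to the master.
def _demo_sync_master():
    master = SyncMaster(lambda msgs: [(i, m * 10) for i, m in msgs])
    pipe = master.register_slave(1)
    worker = threading.Thread(target=pipe.run_slave, args=(2,))  # receives 20
    worker.start()
    result = master.run_master(1)  # returns 10
    worker.join()
    return result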
================================================
FILE: sync_batchnorm/replicate.py
================================================
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import functools
from torch.nn.parallel.data_parallel import DataParallel
__all__ = [
'CallbackContext',
'execute_replication_callbacks',
'DataParallelWithCallback',
'patch_replication_callback'
]
class CallbackContext(object):
pass
def execute_replication_callbacks(modules):
"""
Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Note that, as all module copies are isomorphic, we assign each sub-module a context
(shared among multiple copies of this module on different devices).
Through this context, different copies can share some information.
We guarantee that the callback on the master copy (the first copy) will be called ahead of the callbacks
of any slave copies.
"""
master_copy = modules[0]
nr_modules = len(list(master_copy.modules()))
ctxs = [CallbackContext() for _ in range(nr_modules)]
for i, module in enumerate(modules):
for j, m in enumerate(module.modules()):
if hasattr(m, '__data_parallel_replicate__'):
m.__data_parallel_replicate__(ctxs[j], i)
class DataParallelWithCallback(DataParallel):
"""
Data Parallel with a replication callback.
A replication callback `__data_parallel_replicate__` on each module will be invoked after it is created by
the original `replicate` function.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
# sync_bn.__data_parallel_replicate__ will be invoked.
"""
def replicate(self, module, device_ids):
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
def patch_replication_callback(data_parallel):
"""
Monkey-patch an existing `DataParallel` object. Add the replication callback.
Useful when you have a customized `DataParallel` implementation.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
"""
assert isinstance(data_parallel, DataParallel)
old_replicate = data_parallel.replicate
@functools.wraps(old_replicate)
def new_replicate(module, device_ids):
modules = old_replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
data_parallel.replicate = new_replicate
================================================
FILE: sync_batchnorm/unittest.py
================================================
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import unittest
import numpy as np
from torch.autograd import Variable
def as_numpy(v):
if isinstance(v, Variable):
v = v.data
return v.cpu().numpy()
class TorchTestCase(unittest.TestCase):
def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
npa, npb = as_numpy(a), as_numpy(b)
self.assertTrue(
np.allclose(npa, npb, atol=atol),
'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
)
================================================
FILE: upsampler/app_gradio.py
================================================
from __future__ import annotations
import argparse
import pathlib
import torch
import gradio as gr
from webUI.app_task import *
from webUI.styleganex_model import Model
DESCRIPTION = '''
<div align=center>
<h1 style="font-weight: 900; margin-bottom: 7px;">
Face Manipulation with <a href="https://github.com/williamyang1991/StyleGANEX">StyleGANEX</a>
</h1>
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
<a href="https://huggingface.co/spaces/PKUWilliamYang/StyleGANEX?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>
<p/>
<img style="margin-top: 0em" src="https://raw.githubusercontent.com/williamyang1991/tmpfile/master/imgs/example.jpg" alt="example">
</div>
'''
ARTICLE = r"""
If StyleGANEX is helpful, please help to ⭐ the <a href='https://github.com/williamyang1991/StyleGANEX' target='_blank'>Github Repo</a>. Thanks!
[](https://github.com/williamyang1991/StyleGANEX)
---
📝 **Citation**
If our work is useful for your research, please consider citing:
```bibtex
@article{yang2023styleganex,
title = {StyleGANEX: StyleGAN-Based Manipulation Beyond Cropped Aligned Faces},
author = {Yang, Shuai and Jiang, Liming and Liu, Ziwei and Loy, Chen Change},
journal = {arXiv preprint arXiv:2303.06146},
year={2023},
}
```
📋 **License**
This project is licensed under <a rel="license" href="https://github.com/williamyang1991/VToonify/blob/main/LICENSE.md">S-Lab License 1.0</a>.
Redistribution and use for non-commercial purposes should follow this license.
📧 **Contact**
If you have any questions, please feel free to reach out to me at <b>williamyang@pku.edu.cn</b>.
"""
FOOTER = '<div align=center><img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.laobi.icu/badge?page_id=williamyang1991/styleganex" /></div>'
def main():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('*** Now using %s.'%(device))
model = Model(device=device)
torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/234_sketch.jpg',
'234_sketch.jpg')
torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/output/ILip77SbmOE_inversion.pt',
'ILip77SbmOE_inversion.pt')
torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/ILip77SbmOE.png',
'ILip77SbmOE.png')
torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/ILip77SbmOE_mask.png',
'ILip77SbmOE_mask.png')
torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/pexels-daniel-xavier-1239291.jpg',
'pexels-daniel-xavier-1239291.jpg')
torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/data/529_2.mp4',
'529_2.mp4')
torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/data/684.mp4',
'684.mp4')
torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/data/pexels-anthony-shkraba-production-8136210.mp4',
'pexels-anthony-shkraba-production-8136210.mp4')
with gr.Blocks(css='style.css') as demo:
gr.Markdown(DESCRIPTION)
with gr.Tabs():
with gr.TabItem('Inversion for Editing'):
create_demo_inversion(model.process_inversion, allow_optimization=True)
with gr.TabItem('Image Face Toonify'):
create_demo_toonify(model.process_toonify)
with gr.TabItem('Video Face Toonify'):
create_demo_vtoonify(model.process_vtoonify, max_frame_num=5000)
with gr.TabItem('Image Face Editing'):
create_demo_editing(model.process_editing)
with gr.TabItem('Video Face Editing'):
create_demo_vediting(model.process_vediting, max_frame_num=5000)
with gr.TabItem('Sketch2Face'):
create_demo_s2f(model.process_s2f)
with gr.TabItem('Mask2Face'):
create_demo_m2f(model.process_m2f)
with gr.TabItem('SR'):
create_demo_sr(model.process_sr)
gr.Markdown(ARTICLE)
gr.Markdown(FOOTER)
demo.queue(concurrency_count=1)
demo.launch(server_port=8088, server_name="0.0.0.0", debug=True)
if __name__ == '__main__':
main()
================================================
FILE: upsampler/configs/__init__.py
================================================
================================================
FILE: upsampler/configs/data_configs.py
================================================
from configs import transforms_config
from configs.paths_config import dataset_paths
DATASETS = {
'ffhq_encode': {
'transforms': transforms_config.EncodeTransforms,
'train_source_root': dataset_paths['ffhq'],
'train_target_root': dataset_paths['ffhq'],
'test_source_root': dataset_paths['ffhq_test'],
'test_target_root': dataset_paths['ffhq_test'],
},
'ffhq_sketch_to_face': {
'transforms': transforms_config.SketchToImageTransforms,
'train_source_root': dataset_paths['ffhq_train_sketch'],
'train_target_root': dataset_paths['ffhq'],
'test_source_root': dataset_paths['ffhq_test_sketch'],
'test_target_root': dataset_paths['ffhq_test'],
},
'ffhq_seg_to_face': {
'transforms': transforms_config.SegToImageTransforms,
'train_source_root': dataset_paths['ffhq_train_segmentation'],
'train_target_root': dataset_paths['ffhq'],
'test_source_root': dataset_paths['ffhq_test_segmentation'],
'test_target_root': dataset_paths['ffhq_test'],
},
'ffhq_super_resolution': {
'transforms': transforms_config.SuperResTransforms,
'train_source_root': dataset_paths['ffhq'],
'train_target_root': dataset_paths['ffhq1280'],
'test_source_root': dataset_paths['ffhq_test'],
'test_target_root': dataset_paths['ffhq1280_test'],
},
'toonify': {
'transforms': transforms_config.ToonifyTransforms,
'train_source_root': dataset_paths['toonify_in'],
'train_target_root': dataset_paths['toonify_out'],
'test_source_root': dataset_paths['toonify_test_in'],
'test_target_root': dataset_paths['toonify_test_out'],
},
'ffhq_edit': {
'transforms': transforms_config.EditingTransforms,
'train_source_root': dataset_paths['ffhq'],
'train_target_root': dataset_paths['ffhq'],
'test_source_root': dataset_paths['ffhq_test'],
'test_target_root': dataset_paths['ffhq_test'],
},
}
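# Editor's usage sketch (added; not part of the original file): how an entry is
# typically consumed by training code ('opts' is a hypothetical options object):
#
#     dataset_args = DATASETS['ffhq_super_resolution']
#     transforms_dict = dataset_args['transforms'](opts).get_transforms()
#     source_root = dataset_args['train_source_root']  # 320x320 inputs
#     target_root = dataset_args['train_target_root']  # 1280x1280 targets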
================================================
FILE: upsampler/configs/dataset_config.yml
================================================
# dataset and data loader settings
datasets:
train:
name: FFHQ
type: FFHQDegradationDataset
# dataroot_gt: datasets/ffhq/ffhq_512.lmdb
dataroot_gt: ../../../../share/shuaiyang/ffhq/realign1280x1280test/
io_backend:
# type: lmdb
type: disk
use_hflip: true
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 1280
scale: 4
blur_kernel_size: 41
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [4, 40]
noise_range: [0, 20]
jpeg_range: [60, 100]
# color jitter and gray
#color_jitter_prob: 0.3
#color_jitter_shift: 20
#color_jitter_pt_prob: 0.3
#gray_prob: 0.01
# If you do not want colorization, please set:
color_jitter_prob: ~
color_jitter_pt_prob: ~
gray_prob: 0.01
gt_gray: True
crop_components: true
component_path: ./pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
eye_enlarge_ratio: 1.4
# data loader
use_shuffle: true
num_worker_per_gpu: 6
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
# Please modify accordingly to use your own validation set
# Or comment out the val block if you do not need validation during training
name: validation
type: PairedImageDataset
dataroot_lq: datasets/faces/validation/input
dataroot_gt: datasets/faces/validation/reference
io_backend:
type: disk
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
scale: 1
================================================
FILE: upsampler/configs/paths_config.py
================================================
dataset_paths = {
'ffhq': 'data/train/ffhq/realign320x320/',
'ffhq_test': 'data/train/ffhq/realign320x320test/',
'ffhq1280': 'data/train/ffhq/realign1280x1280/',
'ffhq1280_test': 'data/train/ffhq/realign1280x1280test/',
'ffhq_train_sketch': 'data/train/ffhq/realign640x640sketch/',
'ffhq_test_sketch': 'data/train/ffhq/realign640x640sketchtest/',
'ffhq_train_segmentation': 'data/train/ffhq/realign320x320mask/',
'ffhq_test_segmentation': 'data/train/ffhq/realign320x320masktest/',
'toonify_in': 'data/train/pixar/trainA/',
'toonify_out': 'data/train/pixar/trainB/',
'toonify_test_in': 'data/train/pixar/testA/',
'toonify_test_out': 'data/train/testB/',
}
model_paths = {
'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
'ir_se50': 'pretrained_models/model_ir_se50.pth',
'circular_face': 'pretrained_models/CurricularFace_Backbone.pth',
'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy',
'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy',
'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy',
'shape_predictor': 'shape_predictor_68_face_landmarks.dat',
'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth.tar'
}
================================================
FILE: upsampler/configs/transforms_config.py
================================================
from abc import abstractmethod
import torchvision.transforms as transforms
from datasets import augmentations
class TransformsConfig(object):
def __init__(self, opts):
self.opts = opts
@abstractmethod
def get_transforms(self):
pass
class EncodeTransforms(TransformsConfig):
def __init__(self, opts):
super(EncodeTransforms, self).__init__(opts)
def get_transforms(self):
transforms_dict = {
'transform_gt_train': transforms.Compose([
transforms.Resize((320, 320)),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_source': None,
'transform_test': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_inference': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
}
return transforms_dict
class FrontalizationTransforms(TransformsConfig):
def __init__(self, opts):
super(FrontalizationTransforms, self).__init__(opts)
def get_transforms(self):
transforms_dict = {
'transform_gt_train': transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_source': transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_test': transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_inference': transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
}
return transforms_dict
class SketchToImageTransforms(TransformsConfig):
def __init__(self, opts):
super(SketchToImageTransforms, self).__init__(opts)
def get_transforms(self):
transforms_dict = {
'transform_gt_train': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_source': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor()]),
'transform_test': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_inference': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor()]),
}
return transforms_dict
class SegToImageTransforms(TransformsConfig):
def __init__(self, opts):
super(SegToImageTransforms, self).__init__(opts)
def get_transforms(self):
transforms_dict = {
'transform_gt_train': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_source': transforms.Compose([
transforms.Resize((320, 320)),
augmentations.ToOneHot(self.opts.label_nc),
transforms.ToTensor()]),
'transform_test': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_inference': transforms.Compose([
transforms.Resize((320, 320)),
augmentations.ToOneHot(self.opts.label_nc),
transforms.ToTensor()])
}
return transforms_dict
class SuperResTransforms(TransformsConfig):
def __init__(self, opts):
super(SuperResTransforms, self).__init__(opts)
def get_transforms(self):
if self.opts.resize_factors is None:
self.opts.resize_factors = '1,2,4,8,16,32'
factors = [int(f) for f in self.opts.resize_factors.split(",")]
print("Performing down-sampling with factors: {}".format(factors))
transforms_dict = {
'transform_gt_train': transforms.Compose([
transforms.Resize((1280, 1280)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_source': transforms.Compose([
transforms.Resize((320, 320)),
augmentations.BilinearResize(factors=factors),
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_test': transforms.Compose([
transforms.Resize((1280, 1280)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_inference': transforms.Compose([
transforms.Resize((320, 320)),
augmentations.BilinearResize(factors=factors),
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
}
return transforms_dict
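# Editor's usage sketch (added; not part of the original class): the source
# transform degrades on the fly -- resize to 320, bilinearly downsample by a
# factor drawn from resize_factors, then resize back to 320 -- while the GT
# transform only resizes to the 1280 target. 'opts' below is hypothetical:
#
#     opts = argparse.Namespace(resize_factors='4,8')
#     t = SuperResTransforms(opts).get_transforms()
#     lr = t['transform_source'](pil_img)    # degraded 320x320 tensor in [-1, 1]
#     hr = t['transform_gt_train'](pil_img)  # clean 1280x1280 tensor in [-1, 1]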
class SuperResTransforms_320(TransformsConfig):
def __init__(self, opts):
super(SuperResTransforms_320, self).__init__(opts)
def get_transforms(self):
if self.opts.resize_factors is None:
self.opts.resize_factors = '1,2,4,8,16,32'
factors = [int(f) for f in self.opts.resize_factors.split(",")]
print("Performing down-sampling with factors: {}".format(factors))
transforms_dict = {
'transform_gt_train': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_source': transforms.Compose([
transforms.Resize((320, 320)),
augmentations.BilinearResize(factors=factors),
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_test': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_inference': transforms.Compose([
transforms.Resize((320, 320)),
augmentations.BilinearResize(factors=factors),
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
}
return transforms_dict
class ToonifyTransforms(TransformsConfig):
def __init__(self, opts):
super(ToonifyTransforms, self).__init__(opts)
def get_transforms(self):
transforms_dict = {
'transform_gt_train': transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_source': transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_test': transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_inference': transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
}
return transforms_dict
class EditingTransforms(TransformsConfig):
def __init__(self, opts):
super(EditingTransforms, self).__init__(opts)
def get_transforms(self):
transforms_dict = {
'transform_gt_train': transforms.Compose([
transforms.Resize((1280, 1280)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_source': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_test': transforms.Compose([
transforms.Resize((1280, 1280)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_inference': transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
}
return transforms_dict
================================================
FILE: upsampler/criteria/__init__.py
================================================
================================================
FILE: upsampler/criteria/id_loss.py
================================================
import torch
from torch import nn
from configs.paths_config import model_paths
from models.encoders.model_irse import Backbone
class IDLoss(nn.Module):
def __init__(self):
super(IDLoss, self).__init__()
print('Loading ResNet ArcFace')
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
self.facenet.eval()
def extract_feats(self, x):
x = x[:, :, 35:223, 32:220] # Crop interesting region
x = self.face_pool(x)
x_feats = self.facenet(x)
return x_feats
def forward(self, y_hat, y, x):
n_samples = x.shape[0]
x_feats = self.extract_feats(x)
y_feats = self.extract_feats(y) # Otherwise use the feature from there
y_hat_feats = self.extract_feats(y_hat)
y_feats = y_feats.detach()
loss = 0
sim_improvement = 0
id_logs = []
count = 0
for i in range(n_samples):
diff_target = y_hat_feats[i].dot(y_feats[i])
diff_input = y_hat_feats[i].dot(x_feats[i])
diff_views = y_feats[i].dot(x_feats[i])
id_logs.append({'diff_target': float(diff_target),
'diff_input': float(diff_input),
'diff_views': float(diff_views)})
loss += 1 - diff_target
id_diff = float(diff_target) - float(diff_views)
sim_improvement += id_diff
count += 1
return loss / count, sim_improvement / count, id_logs
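# Editor's usage sketch (added; not part of the original file): the loss sums
# 1 - <e(y_hat), e(y)> over the batch, where e(.) is the ArcFace embedding of a
# fixed face crop, and also reports how much closer y_hat is to the target y
# than y itself is to the input x. The crop assumes roughly 256x256 inputs, and
# weights must exist at model_paths['ir_se50']:
#
#     id_loss = IDLoss().eval()
#     loss, sim_improvement, logs = id_loss(y_hat, y, x)  # each (N, 3, 256, 256)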
================================================
FILE: upsampler/criteria/lpips/__init__.py
================================================
================================================
FILE: upsampler/criteria/lpips/lpips.py
================================================
import torch
import torch.nn as nn
from criteria.lpips.networks import get_network, LinLayers
from criteria.lpips.utils import get_state_dict
class LPIPS(nn.Module):
r"""Creates a criterion that measures
Learned Perceptual Image Patch Similarity (LPIPS).
Arguments:
net_type (str): the network type to compare the features:
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
version (str): the version of LPIPS. Default: 0.1.
"""
def __init__(self, net_type: str = 'alex', version: str = '0.1'):
assert version in ['0.1'], 'only v0.1 is supported now'
super(LPIPS, self).__init__()
# pretrained network
self.net = get_network(net_type).to("cuda")
# linear layers
self.lin = LinLayers(self.net.n_channels_list).to("cuda")
self.lin.load_state_dict(get_state_dict(net_type, version))
def forward(self, x: torch.Tensor, y: torch.Tensor):
feat_x, feat_y = self.net(x), self.net(y)
diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
return torch.sum(torch.cat(res, 0)) / x.shape[0]
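# Editor's usage sketch (added; not part of the original file): the metric
# averages linearly weighted squared feature differences over the tap layers,
# then over the batch. Note the module is pinned to CUDA above, so a GPU is
# required as written:
#
#     lpips = LPIPS(net_type='alex')  # downloads the v0.1 linear weights
#     d = lpips(x.cuda(), y.cuda())   # x, y: NCHW images, roughly in [-1, 1]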
================================================
FILE: upsampler/criteria/lpips/networks.py
================================================
from typing import Sequence
from itertools import chain
import torch
import torch.nn as nn
from torchvision import models
from criteria.lpips.utils import normalize_activation
def get_network(net_type: str):
if net_type == 'alex':
return AlexNet()
elif net_type == 'squeeze':
return SqueezeNet()
elif net_type == 'vgg':
return VGG16()
else:
raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
class LinLayers(nn.ModuleList):
def __init__(self, n_channels_list: Sequence[int]):
super(LinLayers, self).__init__([
nn.Sequential(
nn.Identity(),
nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
) for nc in n_channels_list
])
for param in self.parameters():
param.requires_grad = False
class BaseNet(nn.Module):
def __init__(self):
super(BaseNet, self).__init__()
# register buffer
self.register_buffer(
'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
self.register_buffer(
'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
def set_requires_grad(self, state: bool):
for param in chain(self.parameters(), self.buffers()):
param.requires_grad = state
def z_score(self, x: torch.Tensor):
return (x - self.mean) / self.std
def forward(self, x: torch.Tensor):
x = self.z_score(x)
output = []
for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
x = layer(x)
if i in self.target_layers:
output.append(normalize_activation(x))
if len(output) == len(self.target_layers):
break
return output
class SqueezeNet(BaseNet):
def __init__(self):
super(SqueezeNet, self).__init__()
self.layers = models.squeezenet1_1(True).features
self.target_layers = [2, 5, 8, 10, 11, 12, 13]
self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
self.set_requires_grad(False)
class AlexNet(BaseNet):
def __init__(self):
super(AlexNet, self).__init__()
self.layers = models.alexnet(True).features
self.target_layers = [2, 5, 8, 10, 12]
self.n_channels_list = [64, 192, 384, 256, 256]
self.set_requires_grad(False)
class VGG16(BaseNet):
def __init__(self):
super(VGG16, self).__init__()
self.layers = models.vgg16(True).features
self.target_layers = [4, 9, 16, 23, 30]
self.n_channels_list = [64, 128, 256, 512, 512]
self.set_requires_grad(False)
================================================
FILE: upsampler/criteria/lpips/utils.py
================================================
from collections import OrderedDict
import torch
def normalize_activation(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
return x / (norm_factor + eps)
def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
# build url
url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+ f'master/lpips/weights/v{version}/{net_type}.pth'
# download
old_state_dict = torch.hub.load_state_dict_from_url(
url, progress=True,
map_location=None if torch.cuda.is_available() else torch.device('cpu')
)
# rename keys
new_state_dict = OrderedDict()
for key, val in old_state_dict.items():
new_key = key
new_key = new_key.replace('lin', '')
new_key = new_key.replace('model.', '')
new_state_dict[new_key] = val
return new_state_dict
================================================
FILE: upsampler/criteria/moco_loss.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
from configs.paths_config import model_paths
class MocoLoss(nn.Module):
def __init__(self):
super(MocoLoss, self).__init__()
print("Loading MOCO model from path: {}".format(model_paths["moco"]))
self.model = self.__load_model()
self.model.cuda()
self.model.eval()
@staticmethod
def __load_model():
import torchvision.models as models
model = models.__dict__["resnet50"]()
# freeze all layers but the last fc
for name, param in model.named_parameters():
if name not in ['fc.weight', 'fc.bias']:
param.requires_grad = False
checkpoint = torch.load(model_paths['moco'], map_location="cpu")
state_dict = checkpoint['state_dict']
# rename moco pre-trained keys
for k in list(state_dict.keys()):
# retain only encoder_q up to before the embedding layer
if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
# remove prefix
state_dict[k[len("module.encoder_q."):]] = state_dict[k]
# delete renamed or unused k
del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
# remove output layer
model = nn.Sequential(*list(model.children())[:-1]).cuda()
return model
def extract_feats(self, x):
x = F.interpolate(x, size=224)
x_feats = self.model(x)
x_feats = nn.functional.normalize(x_feats, dim=1)
x_feats = x_feats.squeeze()
return x_feats
def forward(self, y_hat, y, x):
n_samples = x.shape[0]
x_feats = self.extract_feats(x)
y_feats = self.extract_feats(y)
y_hat_feats = self.extract_feats(y_hat)
y_feats = y_feats.detach()
loss = 0
sim_improvement = 0
sim_logs = []
count = 0
for i in range(n_samples):
diff_target = y_hat_feats[i].dot(y_feats[i])
diff_input = y_hat_feats[i].dot(x_feats[i])
diff_views = y_feats[i].dot(x_feats[i])
sim_logs.append({'diff_target': float(diff_target),
'diff_input': float(diff_input),
'diff_views': float(diff_views)})
loss += 1 - diff_target
sim_diff = float(diff_target) - float(diff_views)
sim_improvement += sim_diff
count += 1
return loss / count, sim_improvement / count, sim_logs
================================================
FILE: upsampler/criteria/w_norm.py
================================================
import torch
from torch import nn
class WNormLoss(nn.Module):
def __init__(self, start_from_latent_avg=True):
super(WNormLoss, self).__init__()
self.start_from_latent_avg = start_from_latent_avg
def forward(self, latent, latent_avg=None):
if self.start_from_latent_avg:
latent = latent - latent_avg
return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
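# Editor's sketch (added; not part of the original file): the loss is the batch
# mean of each sample's L2 norm over the latent dimensions, optionally taken
# relative to latent_avg so codes are pulled toward the average latent.
def _demo_w_norm():
    latent = torch.ones(2, 18, 512)  # per-sample norm = sqrt(18 * 512) = 96
    loss = WNormLoss(start_from_latent_avg=False)(latent)
    assert abs(float(loss) - 96.0) < 1e-3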
================================================
FILE: upsampler/datasets/__init__.py
================================================
================================================
FILE: upsampler/datasets/augmentations.py
================================================
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
class ToOneHot(object):
""" Convert the input PIL image to a one-hot torch tensor """
def __init__(self, n_classes=None):
self.n_classes = n_classes
def onehot_initialization(self, a):
if self.n_classes is None:
self.n_classes = len(np.unique(a))
out = np.zeros(a.shape + (self.n_classes, ), dtype=int)
out[self.__all_idx(a, axis=2)] = 1
return out
def __all_idx(self, idx, axis):
grid = np.ogrid[tuple(map(slice, idx.shape))]
grid.insert(axis, idx)
return tuple(grid)
def __call__(self, img):
img = np.array(img)
one_hot = self.onehot_initialization(img)
return one_hot
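# Editor's sketch (added; not part of the original class): ToOneHot maps an HxW
# label map to an HxWxC binary array (channels last; ToTensor reorders later).
def _demo_to_one_hot():
    labels = np.array([[0, 1], [2, 1]])
    one_hot = ToOneHot(n_classes=3)(labels)
    assert one_hot.shape == (2, 2, 3)
    assert one_hot[0, 1, 1] == 1 and one_hot[1, 0, 2] == 1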
class BilinearResize(object):
def __init__(self, factors=[1, 2, 4, 8, 16, 32]):
self.factors = factors
def __call__(self, image):
factor = np.random.choice(self.factors, size=1)[0]
D = BicubicDownSample(factor=factor, cuda=False)
img_tensor = transforms.ToTensor()(image).unsqueeze(0)
img_tensor_lr = D(img_tensor)[0].clamp(0, 1)
img_low_res = transforms.ToPILImage()(img_tensor_lr)
return img_low_res
class BicubicDownSample(nn.Module):
def bicubic_kernel(self, x, a=-0.50):
"""
This equation is exactly copied from the website below:
https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
"""
abs_x = torch.abs(x)
if abs_x <= 1.:
return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1
elif 1. < abs_x < 2.:
return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
else:
return 0.0
def __init__(self, factor=4, cuda=True, padding='reflect'):
super().__init__()
self.factor = factor
size = factor * 4
k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
for i in range(size)], dtype=torch.float32)
k = k / torch.sum(k)
k1 = torch.reshape(k, shape=(1, 1, size, 1))
self.k1 = torch.cat([k1, k1, k1], dim=0)
k2 = torch.reshape(k, shape=(1, 1, 1, size))
self.k2 = torch.cat([k2, k2, k2], dim=0)
self.cuda = '.cuda' if cuda else ''
self.padding = padding
for param in self.parameters():
param.requires_grad = False
def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
filter_height = self.factor * 4
filter_width = self.factor * 4
stride = self.factor
pad_along_height = max(filter_height - stride, 0)
pad_along_width = max(filter_width - stride, 0)
filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda))
filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda))
# compute actual padding values for each side
pad_top = pad_along_height // 2
pad_bottom = pad_along_height - pad_top
pad_left = pad_along_width // 2
pad_right = pad_along_width - pad_left
# apply mirror padding
if nhwc:
x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW
# downscaling performed by 1-d convolution
x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
if clip_round:
x = torch.clamp(torch.round(x), 0.0, 255.)
x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
if clip_round:
x = torch.clamp(torch.round(x), 0.0, 255.)
if nhwc:
x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)
if byte_output:
return x.type('torch{}.ByteTensor'.format(self.cuda))
else:
return x
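# Editor's sketch (added; not part of the original class): the kernel is
# separable, so the module applies the same 1-D bicubic filter vertically and
# then horizontally; a factor-4 pass maps 256x256 to 64x64 on the CPU.
def _demo_bicubic_downsample():
    d = BicubicDownSample(factor=4, cuda=False)
    y = d(torch.rand(1, 3, 256, 256))
    assert y.shape == (1, 3, 64, 64)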
================================================
FILE: upsampler/datasets/ffhq_degradation_dataset.py
================================================
import cv2
import math
import numpy as np
import os.path as osp
import torch
import torch.utils.data as data
from basicsr.data import degradations as degradations
from basicsr.data.data_util import paths_from_folder
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
normalize)
@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):
"""FFHQ dataset for GFPGAN.
It reads high-resolution images, and then generates low-quality (LQ) images on-the-fly.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
io_backend (dict): IO backend type and other kwarg.
mean (list | tuple): Image mean.
std (list | tuple): Image std.
use_hflip (bool): Whether to horizontally flip.
Please see more options in the code.
"""
def __init__(self, opt):
super(FFHQDegradationDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.gt_folder = opt['dataroot_gt']
self.mean = opt['mean']
self.std = opt['std']
self.out_size = opt['out_size']
self.crop_components = opt.get('crop_components', False) # facial components
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions
if self.crop_components:
# load component list from a pre-process pth files
self.components_list = torch.load(opt.get('component_path'))
# file client (lmdb io backend)
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = self.gt_folder
if not self.gt_folder.endswith('.lmdb'):
raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
else:
# disk backend: scan file list from a folder
self.paths = paths_from_folder(self.gt_folder)
# degradation configurations
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob']
self.blur_sigma = opt['blur_sigma']
self.downsample_range = opt['downsample_range']
self.noise_range = opt['noise_range']
self.jpeg_range = opt['jpeg_range']
# color jitter
self.color_jitter_prob = opt.get('color_jitter_prob')
self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
self.color_jitter_shift = opt.get('color_jitter_shift', 20)
# to gray
self.gray_prob = opt.get('gray_prob')
logger = get_root_logger()
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
if self.color_jitter_prob is not None:
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
if self.gray_prob is not None:
logger.info(f'Use random gray. Prob: {self.gray_prob}')
self.color_jitter_shift /= 255.
@staticmethod
def color_jitter(img, shift):
"""jitter color: randomly jitter the RGB values, in numpy formats"""
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
img = img + jitter_val
img = np.clip(img, 0, 1)
return img
@staticmethod
def color_jitter_pt(img, brightness, contrast, saturation, hue):
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
fn_idx = torch.randperm(4)
for fn_id in fn_idx:
if fn_id == 0 and brightness is not None:
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
img = adjust_brightness(img, brightness_factor)
if fn_id == 1 and contrast is not None:
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
img = adjust_contrast(img, contrast_factor)
if fn_id == 2 and saturation is not None:
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
img = adjust_saturation(img, saturation_factor)
if fn_id == 3 and hue is not None:
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
img = adjust_hue(img, hue_factor)
return img
def get_component_coordinates(self, index, status):
"""Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
components_bbox = self.components_list[f'{index:08d}']
if status[0]: # hflip
# exchange right and left eye
tmp = components_bbox['left_eye']
components_bbox['left_eye'] = components_bbox['right_eye']
components_bbox['right_eye'] = tmp
# modify the width coordinate
components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
# get coordinates
locations = []
for part in ['left_eye', 'right_eye', 'mouth']:
mean = components_bbox[part][0:2]
mean[0] = mean[0] * 2 + 128 ########
mean[1] = mean[1] * 2 + 128 ########
half_len = components_bbox[part][2] * 2 ########
if 'eye' in part:
half_len *= self.eye_enlarge_ratio
loc = np.hstack((mean - half_len + 1, mean + half_len))
loc = torch.from_numpy(loc).float()
locations.append(loc)
return locations
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# load gt image
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
gt_path = self.paths[index]
img_bytes = self.file_client.get(gt_path)
img_gt = imfrombytes(img_bytes, float32=True)
================================================
SYMBOL INDEX (785 symbols across 83 files)
================================================
FILE: animate.py
function normalize_kp (line 14) | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_moveme...
FILE: augmentation.py
function crop_clip (line 20) | def crop_clip(clip, min_h, min_w, h, w):
function pad_clip (line 34) | def pad_clip(clip, h, w):
function resize_clip (line 42) | def resize_clip(clip, size, interpolation='bilinear'):
function get_resize_sizes (line 81) | def get_resize_sizes(im_h, im_w, size):
class RandomFlip (line 91) | class RandomFlip(object):
method __init__ (line 92) | def __init__(self, time_flip=False, horizontal_flip=False):
method __call__ (line 96) | def __call__(self, clip):
class RandomResize (line 105) | class RandomResize(object):
method __init__ (line 115) | def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
method __call__ (line 119) | def __call__(self, clip):
class RandomCrop (line 136) | class RandomCrop(object):
method __init__ (line 143) | def __init__(self, size):
method __call__ (line 149) | def __call__(self, clip):
class RandomRotation (line 175) | class RandomRotation(object):
method __init__ (line 184) | def __init__(self, degrees):
method __call__ (line 197) | def __call__(self, clip):
class ColorJitter (line 217) | class ColorJitter(object):
method __init__ (line 230) | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
method get_params (line 236) | def get_params(self, brightness, contrast, saturation, hue):
method __call__ (line 261) | def __call__(self, clip):
class AllAugmentationTransform (line 323) | class AllAugmentationTransform:
method __init__ (line 324) | def __init__(self, resize_param=None, rotation_param=None, flip_param=...
method __call__ (line 342) | def __call__(self, clip):
FILE: demo.py
function load_checkpoints (line 26) | def load_checkpoints(config_path, checkpoint_path, gen, cpu=False):
function headpose_pred_to_degree (line 72) | def headpose_pred_to_degree(pred):
function get_rotation_matrix (line 112) | def get_rotation_matrix(yaw, pitch, roll):
function keypoint_transformation (line 140) | def keypoint_transformation(kp_canonical, he, estimate_jacobian=True, fr...
function make_animation (line 187) | def make_animation(source_image, driving_video, generator, kp_detector, ...
function find_best_frame (line 219) | def find_best_frame(source, driving, cpu=False):
FILE: frames_dataset.py
function read_video (line 26) | def read_video(name, frame_shape):
class FramesDataset (line 69) | class FramesDataset(Dataset):
method __init__ (line 77) | def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=Fa...
method __len__ (line 143) | def __len__(self):
method __getitem__ (line 146) | def __getitem__(self, idx):
class DatasetRepeater (line 267) | class DatasetRepeater(Dataset):
method __init__ (line 272) | def __init__(self, dataset, num_repeats=100):
method __len__ (line 276) | def __len__(self):
method __getitem__ (line 279) | def __getitem__(self, idx):
FILE: logger.py
class Logger (line 13) | class Logger:
method __init__ (line 14) | def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=Non...
method log_scores (line 29) | def log_scores(self, loss_names):
method visualize_rec (line 39) | def visualize_rec(self, inp, out):
method save_cpk (line 43) | def save_cpk(self, emergent=False):
method load_cpk (line 51) | def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_d...
method __enter__ (line 79) | def __enter__(self):
method __exit__ (line 82) | def __exit__(self, exc_type, exc_val, exc_tb):
method log_iter (line 87) | def log_iter(self, losses):
method log_epoch (line 93) | def log_epoch(self, epoch, models, inp, out):
class Visualizer (line 102) | class Visualizer:
method __init__ (line 103) | def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbo...
method draw_image_with_kp (line 108) | def draw_image_with_kp(self, image, kp_array):
method create_image_column_with_kp (line 118) | def create_image_column_with_kp(self, images, kp):
method create_image_column (line 122) | def create_image_column(self, images):
method create_image_grid (line 129) | def create_image_grid(self, *args):
method visualize (line 138) | def visualize(self, driving, source, out):
FILE: modules/dense_motion.py
class DenseMotionNetwork (line 9) | class DenseMotionNetwork(nn.Module):
method __init__ (line 14) | def __init__(self, block_expansion, num_blocks, max_features, num_kp, ...
method create_sparse_motions (line 34) | def create_sparse_motions(self, feature, kp_driving, kp_source):
method create_deformed_feature (line 71) | def create_deformed_feature(self, feature, sparse_motions):
method create_heatmap_representations (line 80) | def create_heatmap_representations(self, feature, kp_driving, kp_source):
method forward (line 92) | def forward(self, feature, kp_driving, kp_source):
FILE: modules/discriminator.py
class DownBlock2d (line 7) | class DownBlock2d(nn.Module):
method __init__ (line 12) | def __init__(self, in_features, out_features, norm=False, kernel_size=...
method forward (line 25) | def forward(self, x):
class Discriminator (line 36) | class Discriminator(nn.Module):
method __init__ (line 41) | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, m...
method forward (line 57) | def forward(self, x):
class MultiScaleDiscriminator (line 69) | class MultiScaleDiscriminator(nn.Module):
method __init__ (line 74) | def __init__(self, scales=(), **kwargs):
method forward (line 82) | def forward(self, x):
FILE: modules/generator.py
class OcclusionAwareGenerator (line 10) | class OcclusionAwareGenerator(nn.Module):
method __init__ (line 15) | def __init__(self, image_channel, feature_channel, num_kp, block_expan...
method deform_input (line 63) | def deform_input(self, inp, deformation):
method forward (line 72) | def forward(self, source_image, kp_driving, kp_source):
class SPADEDecoder (line 124) | class SPADEDecoder(nn.Module):
method __init__ (line 125) | def __init__(self):
method forward (line 144) | def forward(self, feature):
class OcclusionAwareSPADEGenerator (line 165) | class OcclusionAwareSPADEGenerator(nn.Module):
method __init__ (line 167) | def __init__(self, image_channel, feature_channel, num_kp, block_expan...
method deform_input (line 205) | def deform_input(self, inp, deformation):
method forward (line 214) | def forward(self, source_image, frame_idx, kp_driving, kp_source):
FILE: modules/hopenet.py
class Hopenet (line 7) | class Hopenet(nn.Module):
method __init__ (line 10) | def __init__(self, block, layers, num_bins):
method _make_layer (line 38) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 55) | def forward(self, x):
class ResNet (line 74) | class ResNet(nn.Module):
method __init__ (line 76) | def __init__(self, block, layers, num_classes=1000):
method _make_layer (line 99) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 116) | def forward(self, x):
class AlexNet (line 132) | class AlexNet(nn.Module):
method __init__ (line 135) | def __init__(self, num_bins):
method forward (line 164) | def forward(self, x):
FILE: modules/keypoint_detector.py
class KPDetector (line 9) | class KPDetector(nn.Module):
method __init__ (line 14) | def __init__(self, block_expansion, feature_channel, num_kp, image_cha...
method gaussian2kp (line 44) | def gaussian2kp(self, heatmap):
method forward (line 56) | def forward(self, x):
class HEEstimator (line 85) | class HEEstimator(nn.Module):
method __init__ (line 90) | def __init__(self, block_expansion, feature_channel, num_kp, image_cha...
method forward (line 136) | def forward(self, x):
FILE: modules/model.py
class Vgg19 (line 18) | class Vgg19(torch.nn.Module):
method __init__ (line 22) | def __init__(self, requires_grad=False):
method forward (line 50) | def forward(self, X):
class ImagePyramide (line 61) | class ImagePyramide(torch.nn.Module):
method __init__ (line 65) | def __init__(self, scales, num_channels):
method forward (line 72) | def forward(self, x):
class Transform (line 79) | class Transform:
method __init__ (line 83) | def __init__(self, bs, **kwargs):
method transform_frame (line 97) | def transform_frame(self, frame):
method warp_coordinates (line 103) | def warp_coordinates(self, coordinates):
method jacobian (line 123) | def jacobian(self, coordinates):
function detach_kp (line 131) | def detach_kp(kp):
function headpose_pred_to_degree (line 135) | def headpose_pred_to_degree(pred):
function get_rotation_matrix (line 175) | def get_rotation_matrix(yaw, pitch, roll):
function keypoint_transformation (line 203) | def keypoint_transformation(kp_canonical, he, estimate_jacobian=True):
class GeneratorFullModel (line 233) | class GeneratorFullModel(torch.nn.Module):
method __init__ (line 238) | def __init__(self, kp_extractor, he_estimator, generator, discriminato...
method forward (line 289) | def forward(self, x, config):
class DiscriminatorFullModel (line 521) | class DiscriminatorFullModel(torch.nn.Module):
method __init__ (line 526) | def __init__(self, kp_extractor, generator, discriminator, train_params):
method get_zero_tensor (line 541) | def get_zero_tensor(self, input):
method forward (line 547) | def forward(self, x, generated):
FILE: modules/util.py
function kp2gaussian (line 13) | def kp2gaussian(kp, spatial_size, kp_variance):
function make_coordinate_grid_2d (line 36) | def make_coordinate_grid_2d(spatial_size, type):
function make_coordinate_grid (line 55) | def make_coordinate_grid(spatial_size, type):
class ResBottleneck (line 74) | class ResBottleneck(nn.Module):
method __init__ (line 75) | def __init__(self, in_features, stride):
method forward (line 89) | def forward(self, x):
class ResBlock2d (line 106) | class ResBlock2d(nn.Module):
method __init__ (line 111) | def __init__(self, in_features, kernel_size, padding):
method forward (line 120) | def forward(self, x):
class ResBlock3d (line 131) | class ResBlock3d(nn.Module):
method __init__ (line 136) | def __init__(self, in_features, kernel_size, padding):
method forward (line 145) | def forward(self, x):
class UpBlock2d (line 156) | class UpBlock2d(nn.Module):
method __init__ (line 161) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 168) | def forward(self, x):
class UpBlock3d (line 175) | class UpBlock3d(nn.Module):
method __init__ (line 180) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 187) | def forward(self, x):
class DownBlock2d (line 196) | class DownBlock2d(nn.Module):
method __init__ (line 201) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 208) | def forward(self, x):
class DownBlock3d (line 216) | class DownBlock3d(nn.Module):
method __init__ (line 221) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 232) | def forward(self, x):
class SameBlock2d (line 240) | class SameBlock2d(nn.Module):
method __init__ (line 245) | def __init__(self, in_features, out_features, groups=1, kernel_size=3,...
method forward (line 255) | def forward(self, x):
class Encoder (line 262) | class Encoder(nn.Module):
method __init__ (line 267) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 277) | def forward(self, x):
class Decoder (line 284) | class Decoder(nn.Module):
method __init__ (line 289) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 306) | def forward(self, x):
class Hourglass (line 320) | class Hourglass(nn.Module):
method __init__ (line 325) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 331) | def forward(self, x):
class KPHourglass (line 335) | class KPHourglass(nn.Module):
method __init__ (line 340) | def __init__(self, block_expansion, in_features, reshape_features, res...
method forward (line 361) | def forward(self, x):
class AntiAliasInterpolation2d (line 372) | class AntiAliasInterpolation2d(nn.Module):
method __init__ (line 376) | def __init__(self, channels, scale):
method forward (line 410) | def forward(self, input):
class SPADE (line 421) | class SPADE(nn.Module):
method __init__ (line 422) | def __init__(self, norm_nc, label_nc):
method forward (line 434) | def forward(self, x, segmap):
class SPADEResnetBlock (line 444) | class SPADEResnetBlock(nn.Module):
method __init__ (line 445) | def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation...
method forward (line 468) | def forward(self, x, seg1):
method shortcut (line 475) | def shortcut(self, x, seg1):
method actvn (line 482) | def actvn(self, x):
FILE: sync_batchnorm/batchnorm.py
function _sum_ft (line 24) | def _sum_ft(tensor):
function _unsqueeze_ft (line 29) | def _unsqueeze_ft(tensor):
class _SynchronizedBatchNorm (line 38) | class _SynchronizedBatchNorm(_BatchNorm):
method __init__ (line 39) | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
method forward (line 48) | def forward(self, input):
method __data_parallel_replicate__ (line 80) | def __data_parallel_replicate__(self, ctx, copy_id):
method _data_parallel_master (line 90) | def _data_parallel_master(self, intermediates):
method _compute_mean_std (line 113) | def _compute_mean_std(self, sum_, ssum, size):
class SynchronizedBatchNorm1d (line 128) | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
method _check_input_dim (line 184) | def _check_input_dim(self, input):
class SynchronizedBatchNorm2d (line 191) | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
method _check_input_dim (line 247) | def _check_input_dim(self, input):
class SynchronizedBatchNorm3d (line 254) | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
method _check_input_dim (line 311) | def _check_input_dim(self, input):
FILE: sync_batchnorm/comm.py
class FutureResult (line 18) | class FutureResult(object):
method __init__ (line 21) | def __init__(self):
method put (line 26) | def put(self, result):
method get (line 32) | def get(self):
class SlavePipe (line 46) | class SlavePipe(_SlavePipeBase):
method run_slave (line 49) | def run_slave(self, msg):
class SyncMaster (line 56) | class SyncMaster(object):
method __init__ (line 67) | def __init__(self, master_callback):
method __getstate__ (line 78) | def __getstate__(self):
method __setstate__ (line 81) | def __setstate__(self, state):
method register_slave (line 84) | def register_slave(self, identifier):
method run_master (line 102) | def run_master(self, master_msg):
method nr_slaves (line 136) | def nr_slaves(self):
FILE: sync_batchnorm/replicate.py
class CallbackContext (line 23) | class CallbackContext(object):
function execute_replication_callbacks (line 27) | def execute_replication_callbacks(modules):
class DataParallelWithCallback (line 50) | class DataParallelWithCallback(DataParallel):
method replicate (line 64) | def replicate(self, module, device_ids):
function patch_replication_callback (line 70) | def patch_replication_callback(data_parallel):
FILE: sync_batchnorm/unittest.py
function as_numpy (line 17) | def as_numpy(v):
class TorchTestCase (line 23) | class TorchTestCase(unittest.TestCase):
method assertTensorClose (line 24) | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
FILE: upsampler/app_gradio.py
function main (line 48) | def main():
FILE: upsampler/configs/transforms_config.py
class TransformsConfig (line 6) | class TransformsConfig(object):
method __init__ (line 8) | def __init__(self, opts):
method get_transforms (line 12) | def get_transforms(self):
class EncodeTransforms (line 16) | class EncodeTransforms(TransformsConfig):
method __init__ (line 18) | def __init__(self, opts):
method get_transforms (line 21) | def get_transforms(self):
class FrontalizationTransforms (line 41) | class FrontalizationTransforms(TransformsConfig):
method __init__ (line 43) | def __init__(self, opts):
method get_transforms (line 46) | def get_transforms(self):
class SketchToImageTransforms (line 70) | class SketchToImageTransforms(TransformsConfig):
method __init__ (line 72) | def __init__(self, opts):
method get_transforms (line 75) | def get_transforms(self):
class SegToImageTransforms (line 95) | class SegToImageTransforms(TransformsConfig):
method __init__ (line 97) | def __init__(self, opts):
method get_transforms (line 100) | def get_transforms(self):
class SuperResTransforms (line 122) | class SuperResTransforms(TransformsConfig):
method __init__ (line 124) | def __init__(self, opts):
method get_transforms (line 127) | def get_transforms(self):
class SuperResTransforms_320 (line 157) | class SuperResTransforms_320(TransformsConfig):
method __init__ (line 159) | def __init__(self, opts):
method get_transforms (line 162) | def get_transforms(self):
class ToonifyTransforms (line 192) | class ToonifyTransforms(TransformsConfig):
method __init__ (line 194) | def __init__(self, opts):
method get_transforms (line 197) | def get_transforms(self):
class EditingTransforms (line 218) | class EditingTransforms(TransformsConfig):
method __init__ (line 220) | def __init__(self, opts):
method get_transforms (line 223) | def get_transforms(self):
FILE: upsampler/criteria/id_loss.py
class IDLoss (line 7) | class IDLoss(nn.Module):
method __init__ (line 8) | def __init__(self):
method extract_feats (line 16) | def extract_feats(self, x):
method forward (line 22) | def forward(self, y_hat, y, x):
FILE: upsampler/criteria/lpips/lpips.py
class LPIPS (line 8) | class LPIPS(nn.Module):
method __init__ (line 16) | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
method forward (line 29) | def forward(self, x: torch.Tensor, y: torch.Tensor):
FILE: upsampler/criteria/lpips/networks.py
function get_network (line 12) | def get_network(net_type: str):
class LinLayers (line 23) | class LinLayers(nn.ModuleList):
method __init__ (line 24) | def __init__(self, n_channels_list: Sequence[int]):
class BaseNet (line 36) | class BaseNet(nn.Module):
method __init__ (line 37) | def __init__(self):
method set_requires_grad (line 46) | def set_requires_grad(self, state: bool):
method z_score (line 50) | def z_score(self, x: torch.Tensor):
method forward (line 53) | def forward(self, x: torch.Tensor):
class SqueezeNet (line 66) | class SqueezeNet(BaseNet):
method __init__ (line 67) | def __init__(self):
class AlexNet (line 77) | class AlexNet(BaseNet):
method __init__ (line 78) | def __init__(self):
class VGG16 (line 88) | class VGG16(BaseNet):
method __init__ (line 89) | def __init__(self):
FILE: upsampler/criteria/lpips/utils.py
function normalize_activation (line 6) | def normalize_activation(x, eps=1e-10):
function get_state_dict (line 11) | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
FILE: upsampler/criteria/moco_loss.py
class MocoLoss (line 7) | class MocoLoss(nn.Module):
method __init__ (line 9) | def __init__(self):
method __load_model (line 17) | def __load_model():
method extract_feats (line 40) | def extract_feats(self, x):
method forward (line 47) | def forward(self, y_hat, y, x):
FILE: upsampler/criteria/w_norm.py
class WNormLoss (line 5) | class WNormLoss(nn.Module):
method __init__ (line 7) | def __init__(self, start_from_latent_avg=True):
method forward (line 11) | def forward(self, latent, latent_avg=None):
FILE: upsampler/datasets/augmentations.py
class ToOneHot (line 8) | class ToOneHot(object):
method __init__ (line 10) | def __init__(self, n_classes=None):
method onehot_initialization (line 13) | def onehot_initialization(self, a):
method __all_idx (line 20) | def __all_idx(self, idx, axis):
method __call__ (line 25) | def __call__(self, img):
class BilinearResize (line 31) | class BilinearResize(object):
method __init__ (line 32) | def __init__(self, factors=[1, 2, 4, 8, 16, 32]):
method __call__ (line 35) | def __call__(self, image):
class BicubicDownSample (line 44) | class BicubicDownSample(nn.Module):
method bicubic_kernel (line 45) | def bicubic_kernel(self, x, a=-0.50):
method __init__ (line 58) | def __init__(self, factor=4, cuda=True, padding='reflect'):
method forward (line 74) | def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
FILE: upsampler/datasets/ffhq_degradation_dataset.py
class FFHQDegradationDataset (line 17) | class FFHQDegradationDataset(data.Dataset):
method __init__ (line 30) | def __init__(self, opt):
method color_jitter (line 89) | def color_jitter(img, shift):
method color_jitter_pt (line 97) | def color_jitter_pt(img, brightness, contrast, saturation, hue):
method get_component_coordinates (line 118) | def get_component_coordinates(self, index, status):
method __getitem__ (line 145) | def __getitem__(self, index):
method __len__ (line 234) | def __len__(self):
FILE: upsampler/datasets/gt_res_dataset.py
class GTResDataset (line 8) | class GTResDataset(Dataset):
method __init__ (line 10) | def __init__(self, root_path, gt_dir=None, transform=None, transform_t...
method __len__ (line 20) | def __len__(self):
method __getitem__ (line 23) | def __getitem__(self, index):
FILE: upsampler/datasets/images_dataset.py
class ImagesDataset (line 6) | class ImagesDataset(Dataset):
method __init__ (line 8) | def __init__(self, source_root, target_root, opts, target_transform=No...
method __len__ (line 15) | def __len__(self):
method __getitem__ (line 18) | def __getitem__(self, index):
FILE: upsampler/datasets/inference_dataset.py
class InferenceDataset (line 6) | class InferenceDataset(Dataset):
method __init__ (line 8) | def __init__(self, root, opts, transform=None):
method __len__ (line 13) | def __len__(self):
method __getitem__ (line 16) | def __getitem__(self, index):
FILE: upsampler/image_translation.py
class TestOptions (line 21) | class TestOptions():
method __init__ (line 22) | def __init__(self):
method parse (line 34) | def parse(self):
FILE: upsampler/inversion.py
class TestOptions (line 21) | class TestOptions():
method __init__ (line 22) | def __init__(self):
method parse (line 30) | def parse(self):
FILE: upsampler/latent_optimization.py
function latent_optimization (line 9) | def latent_optimization(frame, pspex, landmarkpredictor, step=500, devic...
FILE: upsampler/models/bisenet/model.py
class ConvBNReLU (line 14) | class ConvBNReLU(nn.Module):
method __init__ (line 15) | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args...
method forward (line 26) | def forward(self, x):
method init_weight (line 31) | def init_weight(self):
class BiSeNetOutput (line 37) | class BiSeNetOutput(nn.Module):
method __init__ (line 38) | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
method forward (line 44) | def forward(self, x):
method init_weight (line 49) | def init_weight(self):
method get_params (line 55) | def get_params(self):
class AttentionRefinementModule (line 67) | class AttentionRefinementModule(nn.Module):
method __init__ (line 68) | def __init__(self, in_chan, out_chan, *args, **kwargs):
method forward (line 76) | def forward(self, x):
method init_weight (line 85) | def init_weight(self):
class ContextPath (line 92) | class ContextPath(nn.Module):
method __init__ (line 93) | def __init__(self, *args, **kwargs):
method forward (line 104) | def forward(self, x):
method init_weight (line 127) | def init_weight(self):
method get_params (line 133) | def get_params(self):
class SpatialPath (line 146) | class SpatialPath(nn.Module):
method __init__ (line 147) | def __init__(self, *args, **kwargs):
method forward (line 155) | def forward(self, x):
method init_weight (line 162) | def init_weight(self):
method get_params (line 168) | def get_params(self):
class FeatureFusionModule (line 180) | class FeatureFusionModule(nn.Module):
method __init__ (line 181) | def __init__(self, in_chan, out_chan, *args, **kwargs):
method forward (line 200) | def forward(self, fsp, fcp):
method init_weight (line 212) | def init_weight(self):
method get_params (line 218) | def get_params(self):
class BiSeNet (line 230) | class BiSeNet(nn.Module):
method __init__ (line 231) | def __init__(self, n_classes, *args, **kwargs):
method forward (line 241) | def forward(self, x):
method init_weight (line 256) | def init_weight(self):
method get_params (line 262) | def get_params(self):
FILE: upsampler/models/bisenet/resnet.py
function conv3x3 (line 14) | def conv3x3(in_planes, out_planes, stride=1):
class BasicBlock (line 20) | class BasicBlock(nn.Module):
method __init__ (line 21) | def __init__(self, in_chan, out_chan, stride=1):
method forward (line 36) | def forward(self, x):
function create_layer_basic (line 51) | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
class Resnet18 (line 58) | class Resnet18(nn.Module):
method __init__ (line 59) | def __init__(self):
method forward (line 71) | def forward(self, x):
method init_weight (line 82) | def init_weight(self):
method get_params (line 90) | def get_params(self):
FILE: upsampler/models/encoders/helpers.py
class Flatten (line 10) | class Flatten(Module):
method forward (line 11) | def forward(self, input):
function l2_norm (line 15) | def l2_norm(input, axis=1):
class Bottleneck (line 21) | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
function get_block (line 25) | def get_block(in_channel, depth, num_units, stride=2):
function get_blocks (line 29) | def get_blocks(num_layers):
class SEModule (line 56) | class SEModule(Module):
method __init__ (line 57) | def __init__(self, channels, reduction):
method forward (line 65) | def forward(self, x):
class bottleneck_IR (line 75) | class bottleneck_IR(Module):
method __init__ (line 76) | def __init__(self, in_channel, depth, stride):
method forward (line 91) | def forward(self, x):
class bottleneck_IR_SE (line 97) | class bottleneck_IR_SE(Module):
method __init__ (line 98) | def __init__(self, in_channel, depth, stride):
method forward (line 116) | def forward(self, x):
FILE: upsampler/models/encoders/model_irse.py
class Backbone (line 9) | class Backbone(Module):
method __init__ (line 10) | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, ...
method forward (line 44) | def forward(self, x):
function IR_50 (line 51) | def IR_50(input_size):
function IR_101 (line 57) | def IR_101(input_size):
function IR_152 (line 63) | def IR_152(input_size):
function IR_SE_50 (line 69) | def IR_SE_50(input_size):
function IR_SE_101 (line 75) | def IR_SE_101(input_size):
function IR_SE_152 (line 81) | def IR_SE_152(input_size):
FILE: upsampler/models/encoders/psp_encoders.py
class GradualStyleBlock (line 11) | class GradualStyleBlock(Module):
method __init__ (line 12) | def __init__(self, in_c, out_c, spatial, max_pooling=False):
method forward (line 29) | def forward(self, x):
class AdaptiveInstanceNorm (line 41) | class AdaptiveInstanceNorm(nn.Module):
method __init__ (line 42) | def __init__(self, fin, style_dim=512):
method forward (line 51) | def forward(self, input, style):
class FusionLayer (line 59) | class FusionLayer(Module): ##### modified
method __init__ (line 60) | def __init__(self, inchannel, outchannel, use_skip_torgb=True, use_att...
method forward (line 87) | def forward(self, feat, out, skip, editing_w=None):
class ResnetBlock (line 102) | class ResnetBlock(nn.Module):
method __init__ (line 103) | def __init__(self, dim):
method forward (line 111) | def forward(self, x):
class ResnetGenerator (line 118) | class ResnetGenerator(nn.Module):
method __init__ (line 119) | def __init__(self, in_channel=19, res_num=2):
method forward (line 137) | def forward(self, input):
class GradualStyleEncoder (line 140) | class GradualStyleEncoder(Module):
method __init__ (line 141) | def __init__(self, num_layers, mode='ir', opts=None):
method _upsample_add (line 194) | def _upsample_add(self, x, y):
method forward (line 219) | def forward(self, x, return_feat=False, return_full=False): ##### modi...
method get_feat (line 269) | def get_feat(self, x): ##### modified
class BackboneEncoderUsingLastLayerIntoW (line 290) | class BackboneEncoderUsingLastLayerIntoW(Module):
method __init__ (line 291) | def __init__(self, num_layers, mode='ir', opts=None):
method forward (line 314) | def forward(self, x):
class BackboneEncoderUsingLastLayerIntoWPlus (line 323) | class BackboneEncoderUsingLastLayerIntoWPlus(Module):
method __init__ (line 324) | def __init__(self, num_layers, mode='ir', opts=None):
method forward (line 351) | def forward(self, x):
FILE: upsampler/models/mtcnn/mtcnn.py
class MTCNN (line 12) | class MTCNN():
method __init__ (line 13) | def __init__(self):
method align (line 23) | def align(self, img):
method align_multi (line 31) | def align_multi(self, img, limit=None, min_face_size=30.0):
method detect_faces (line 45) | def detect_faces(self, image, min_face_size=20.0,
FILE: upsampler/models/mtcnn/mtcnn_pytorch/src/align_trans.py
class FaceWarpException (line 26) | class FaceWarpException(Exception):
method __str__ (line 27) | def __str__(self):
function get_reference_facial_points (line 32) | def get_reference_facial_points(output_size=None,
function get_affine_transform_matrix (line 163) | def get_affine_transform_matrix(src_pts, dst_pts):
function warp_and_crop_face (line 210) | def warp_and_crop_face(src_img,
FILE: upsampler/models/mtcnn/mtcnn_pytorch/src/box_utils.py
function nms (line 5) | def nms(boxes, overlap_threshold=0.5, mode='union'):
function convert_to_square (line 71) | def convert_to_square(bboxes):
function calibrate_box (line 94) | def calibrate_box(bboxes, offsets):
function get_image_boxes (line 127) | def get_image_boxes(bounding_boxes, img, size=24):
function correct_bboxes (line 162) | def correct_bboxes(bboxes, width, height):
function _preprocess (line 226) | def _preprocess(img):
FILE: upsampler/models/mtcnn/mtcnn_pytorch/src/detector.py
function detect_faces (line 9) | def detect_faces(image, min_face_size=20.0,
FILE: upsampler/models/mtcnn/mtcnn_pytorch/src/first_stage.py
function run_first_stage (line 12) | def run_first_stage(image, net, scale, threshold):
function _generate_bboxes (line 51) | def _generate_bboxes(probs, offsets, scale, threshold):
FILE: upsampler/models/mtcnn/mtcnn_pytorch/src/get_nets.py
class Flatten (line 13) | class Flatten(nn.Module):
method __init__ (line 15) | def __init__(self):
method forward (line 18) | def forward(self, x):
class PNet (line 32) | class PNet(nn.Module):
method __init__ (line 34) | def __init__(self):
method forward (line 63) | def forward(self, x):
class RNet (line 78) | class RNet(nn.Module):
method __init__ (line 80) | def __init__(self):
method forward (line 107) | def forward(self, x):
class ONet (line 122) | class ONet(nn.Module):
method __init__ (line 124) | def __init__(self):
method forward (line 157) | def forward(self, x):
FILE: upsampler/models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py
class MatlabCp2tormException (line 13) | class MatlabCp2tormException(Exception):
method __str__ (line 14) | def __str__(self):
function tformfwd (line 19) | def tformfwd(trans, uv):
function tforminv (line 45) | def tforminv(trans, uv):
function findNonreflectiveSimilarity (line 68) | def findNonreflectiveSimilarity(uv, xy, options=None):
function findSimilarity (line 119) | def findSimilarity(uv, xy, options=None):
function get_similarity_transform (line 159) | def get_similarity_transform(src_pts, dst_pts, reflective=True):
function cvt_tform_mat_for_cv2 (line 199) | def cvt_tform_mat_for_cv2(trans):
function get_similarity_transform_for_cv2 (line 227) | def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
FILE: upsampler/models/mtcnn/mtcnn_pytorch/src/visualization_utils.py
function show_bboxes (line 4) | def show_bboxes(img, bounding_boxes, facial_landmarks=[]):
FILE: upsampler/models/psp.py
function get_keys (line 15) | def get_keys(d, name):
class pSp (line 22) | class pSp(nn.Module):
method __init__ (line 24) | def __init__(self, opts, ckpt=None):
method set_encoder (line 36) | def set_encoder(self):
method load_weights (line 47) | def load_weights(self, ckpt=None):
method forward (line 84) | def forward(self, x1, x2=None, resize=True, latent_mask=None, randomiz...
method set_opts (line 139) | def set_opts(self, opts):
method __load_latent_avg (line 142) | def __load_latent_avg(self, ckpt, repeat=None):
FILE: upsampler/models/stylegan2/lpips/__init__.py
class PerceptualLoss (line 14) | class PerceptualLoss(torch.nn.Module):
method __init__ (line 15) | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spat...
method forward (line 27) | def forward(self, pred, target, normalize=False):
function normalize_tensor (line 43) | def normalize_tensor(in_feat,eps=1e-10):
function l2 (line 47) | def l2(p0, p1, range=255.):
function psnr (line 50) | def psnr(p0, p1, peak=255.):
function dssim (line 53) | def dssim(p0, p1, range=255.):
function rgb2lab (line 56) | def rgb2lab(in_img,mean_cent=False):
function tensor2np (line 63) | def tensor2np(tensor_obj):
function np2tensor (line 67) | def np2tensor(np_obj):
function tensor2tensorlab (line 71) | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
function tensorlab2tensor (line 85) | def tensorlab2tensor(lab_tensor,return_inbnd=False):
function rgb2lab (line 103) | def rgb2lab(input):
function tensor2im (line 107) | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
function im2tensor (line 112) | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
function tensor2vec (line 116) | def tensor2vec(vector_tensor):
function voc_ap (line 119) | def voc_ap(rec, prec, use_07_metric=False):
function tensor2im (line 152) | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
function im2tensor (line 158) | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
FILE: upsampler/models/stylegan2/lpips/base_model.py
class BaseModel (line 8) | class BaseModel():
method __init__ (line 9) | def __init__(self):
method name (line 12) | def name(self):
method initialize (line 15) | def initialize(self, use_gpu=True, gpu_ids=[0]):
method forward (line 19) | def forward(self):
method get_image_paths (line 22) | def get_image_paths(self):
method optimize_parameters (line 25) | def optimize_parameters(self):
method get_current_visuals (line 28) | def get_current_visuals(self):
method get_current_errors (line 31) | def get_current_errors(self):
method save (line 34) | def save(self, label):
method save_network (line 38) | def save_network(self, network, path, network_label, epoch_label):
method load_network (line 44) | def load_network(self, network, network_label, epoch_label):
method update_learning_rate (line 50) | def update_learning_rate():
method get_image_paths (line 53) | def get_image_paths(self):
method save_done (line 56) | def save_done(self, flag=False):
FILE: upsampler/models/stylegan2/lpips/dist_model.py
class DistModel (line 24) | class DistModel(BaseModel):
method name (line 25) | def name(self):
method initialize (line 28) | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pn...
method forward (line 109) | def forward(self, in0, in1, retPerLayer=False):
method optimize_parameters (line 120) | def optimize_parameters(self):
method clamp_weights (line 127) | def clamp_weights(self):
method set_input (line 132) | def set_input(self, data):
method forward_train (line 148) | def forward_train(self): # run forward pass
method backward_train (line 162) | def backward_train(self):
method compute_accuracy (line 165) | def compute_accuracy(self,d0,d1,judge):
method get_current_errors (line 171) | def get_current_errors(self):
method get_current_visuals (line 180) | def get_current_visuals(self):
method save (line 195) | def save(self, path, label):
method update_learning_rate (line 202) | def update_learning_rate(self,nepoch_decay):
function score_2afc_dataset (line 212) | def score_2afc_dataset(data_loader, func, name=''):
function score_jnd_dataset (line 247) | def score_jnd_dataset(data_loader, func, name=''):
FILE: upsampler/models/stylegan2/lpips/networks_basic.py
function spatial_average (line 17) | def spatial_average(in_tens, keepdim=True):
function upsample (line 20) | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
class PNetLin (line 27) | class PNetLin(nn.Module):
method __init__ (line 28) | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, ...
method forward (line 64) | def forward(self, in0, in1, retPerLayer=False):
class ScalingLayer (line 94) | class ScalingLayer(nn.Module):
method __init__ (line 95) | def __init__(self):
method forward (line 100) | def forward(self, inp):
class NetLinLayer (line 104) | class NetLinLayer(nn.Module):
method __init__ (line 106) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
class Dist2LogitLayer (line 114) | class Dist2LogitLayer(nn.Module):
method __init__ (line 116) | def __init__(self, chn_mid=32, use_sigmoid=True):
method forward (line 128) | def forward(self,d0,d1,eps=0.1):
class BCERankingLoss (line 131) | class BCERankingLoss(nn.Module):
method __init__ (line 132) | def __init__(self, chn_mid=32):
method forward (line 138) | def forward(self, d0, d1, judge):
class FakeNet (line 144) | class FakeNet(nn.Module):
method __init__ (line 145) | def __init__(self, use_gpu=True, colorspace='Lab'):
class L2 (line 150) | class L2(FakeNet):
method forward (line 152) | def forward(self, in0, in1, retPerLayer=None):
class DSSIM (line 167) | class DSSIM(FakeNet):
method forward (line 169) | def forward(self, in0, in1, retPerLayer=None):
function print_network (line 182) | def print_network(net):
FILE: upsampler/models/stylegan2/lpips/pretrained_networks.py
class squeezenet (line 6) | class squeezenet(torch.nn.Module):
method __init__ (line 7) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 36) | def forward(self, X):
class alexnet (line 57) | class alexnet(torch.nn.Module):
method __init__ (line 58) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 81) | def forward(self, X):
class vgg16 (line 97) | class vgg16(torch.nn.Module):
method __init__ (line 98) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 121) | def forward(self, X):
class resnet (line 139) | class resnet(torch.nn.Module):
method __init__ (line 140) | def __init__(self, requires_grad=False, pretrained=True, num=18):
method forward (line 163) | def forward(self, X):
FILE: upsampler/models/stylegan2/model.py
class PixelNorm (line 11) | class PixelNorm(nn.Module):
method __init__ (line 12) | def __init__(self):
method forward (line 15) | def forward(self, input):
function make_kernel (line 19) | def make_kernel(k):
class Upsample (line 30) | class Upsample(nn.Module):
method __init__ (line 31) | def __init__(self, kernel, factor=2):
method forward (line 45) | def forward(self, input):
class Downsample (line 51) | class Downsample(nn.Module):
method __init__ (line 52) | def __init__(self, kernel, factor=2):
method forward (line 66) | def forward(self, input):
class Blur (line 72) | class Blur(nn.Module):
method __init__ (line 73) | def __init__(self, kernel, pad, upsample_factor=1):
method forward (line 85) | def forward(self, input):
class EqualConv2d (line 91) | class EqualConv2d(nn.Module):
method __init__ (line 92) | def __init__(
method forward (line 112) | def forward(self, input):
method __repr__ (line 124) | def __repr__(self):
class EqualLinear (line 131) | class EqualLinear(nn.Module):
method __init__ (line 132) | def __init__(
method forward (line 150) | def forward(self, input):
method __repr__ (line 162) | def __repr__(self):
class ScaledLeakyReLU (line 168) | class ScaledLeakyReLU(nn.Module):
method __init__ (line 169) | def __init__(self, negative_slope=0.2):
method forward (line 174) | def forward(self, input):
class ModulatedConv2d (line 180) | class ModulatedConv2d(nn.Module):
method __init__ (line 181) | def __init__(
method __repr__ (line 243) | def __repr__(self):
method forward (line 249) | def forward(self, input, style):
class NoiseInjection (line 305) | class NoiseInjection(nn.Module):
method __init__ (line 306) | def __init__(self):
method forward (line 311) | def forward(self, image, noise=None):
class ConstantInput (line 324) | class ConstantInput(nn.Module):
method __init__ (line 325) | def __init__(self, channel, size=4):
method forward (line 330) | def forward(self, input):
class StyledConv (line 337) | class StyledConv(nn.Module):
method __init__ (line 338) | def __init__(
method forward (line 365) | def forward(self, input, style, noise=None):
class ToRGB (line 373) | class ToRGB(nn.Module):
method __init__ (line 374) | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[...
method forward (line 394) | def forward(self, input, style, skip=None):
class Generator (line 411) | class Generator(nn.Module):
method __init__ (line 412) | def __init__(
method make_noise (line 498) | def make_noise(self):
method mean_latent (line 509) | def mean_latent(self, n_latent):
method get_latent (line 517) | def get_latent(self, input):
method forward (line 527) | def forward(
class ConvLayer (line 633) | class ConvLayer(nn.Sequential):
method __init__ (line 634) | def __init__(
class ResBlock (line 684) | class ResBlock(nn.Module):
method __init__ (line 685) | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
method forward (line 695) | def forward(self, input):
class Discriminator (line 705) | class Discriminator(nn.Module):
method __init__ (line 706) | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1...
method forward (line 747) | def forward(self, input):
FILE: upsampler/models/stylegan2/op/conv2d_gradfix.py
function no_weight_gradients (line 13) | def no_weight_gradients():
function conv2d (line 22) | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, gr...
function conv_transpose2d (line 45) | def conv_transpose2d(
function could_use_op (line 78) | def could_use_op(input):
function ensure_tuple (line 95) | def ensure_tuple(xs, ndim):
function conv2d_gradfix (line 104) | def conv2d_gradfix(
FILE: upsampler/models/stylegan2/op/fused_act.py
class FusedLeakyReLU (line 6) | class FusedLeakyReLU(nn.Module):
method __init__ (line 7) | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** ...
method forward (line 19) | def forward(self, inputs):
function fused_leaky_relu (line 23) | def fused_leaky_relu(inputs, bias=None, negative_slope=0.2, scale=2 ** 0...
FILE: upsampler/models/stylegan2/op/upfirdn2d.py
function upfirdn2d (line 7) | def upfirdn2d(inputs, kernel, up=1, down=1, pad=(0, 0)):
function upfirdn2d_native (line 20) | def upfirdn2d_native(
FILE: upsampler/models/stylegan2/op2/upfirdn2d.cpp
function upfirdn2d (line 17) | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor ...
function PYBIND11_MODULE (line 29) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: upsampler/models/stylegan2/op2/upfirdn2d.py
class UpFirDn2dBackward (line 20) | class UpFirDn2dBackward(Function):
method forward (line 22) | def forward(
method backward (line 64) | def backward(ctx, gradgrad_input):
class UpFirDn2d (line 89) | class UpFirDn2d(Function):
method forward (line 91) | def forward(ctx, input, kernel, up, down, pad):
method backward (line 128) | def backward(ctx, grad_output):
function upfirdn2d (line 149) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
function upfirdn2d_native (line 168) | def upfirdn2d_native(
FILE: upsampler/models/stylegan2/op_old/fused_act.py
class FusedLeakyReLUFunctionBackward (line 18) | class FusedLeakyReLUFunctionBackward(Function):
method forward (line 20) | def forward(ctx, grad_output, out, negative_slope, scale):
method backward (line 41) | def backward(ctx, gradgrad_input, gradgrad_bias):
class FusedLeakyReLUFunction (line 50) | class FusedLeakyReLUFunction(Function):
method forward (line 52) | def forward(ctx, input, bias, negative_slope, scale):
method backward (line 62) | def backward(ctx, grad_output):
class FusedLeakyReLU (line 72) | class FusedLeakyReLU(nn.Module):
method __init__ (line 73) | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
method forward (line 80) | def forward(self, input):
function fused_leaky_relu (line 84) | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
FILE: upsampler/models/stylegan2/op_old/fused_bias_act.cpp
function fused_bias_act (line 11) | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Te...
function PYBIND11_MODULE (line 19) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: upsampler/models/stylegan2/op_old/upfirdn2d.cpp
function upfirdn2d (line 12) | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor&...
function PYBIND11_MODULE (line 21) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: upsampler/models/stylegan2/op_old/upfirdn2d.py
class UpFirDn2dBackward (line 17) | class UpFirDn2dBackward(Function):
method forward (line 19) | def forward(
method backward (line 60) | def backward(ctx, gradgrad_input):
class UpFirDn2d (line 85) | class UpFirDn2d(Function):
method forward (line 87) | def forward(ctx, input, kernel, up, down, pad):
method backward (line 124) | def backward(ctx, grad_output):
function upfirdn2d (line 142) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
function upfirdn2d_native (line 150) | def upfirdn2d_native(
FILE: upsampler/models/stylegan2/simple_augment.py
function reduce_sum (line 12) | def reduce_sum(tensor):
class AdaptiveAugment (line 25) | class AdaptiveAugment:
method __init__ (line 26) | def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
method tune (line 37) | def tune(self, real_pred):
function translate_mat (line 80) | def translate_mat(t_x, t_y, device="cpu"):
function rotate_mat (line 90) | def rotate_mat(theta, device="cpu"):
function scale_mat (line 102) | def scale_mat(s_x, s_y, device="cpu"):
function translate3d_mat (line 112) | def translate3d_mat(t_x, t_y, t_z):
function rotate3d_mat (line 122) | def rotate3d_mat(axis, theta):
function scale3d_mat (line 143) | def scale3d_mat(s_x, s_y, s_z):
function luma_flip_mat (line 154) | def luma_flip_mat(axis, i):
function saturation_mat (line 164) | def saturation_mat(axis, i):
function lognormal_sample (line 175) | def lognormal_sample(size, mean=0, std=1, device="cpu"):
function category_sample (line 179) | def category_sample(size, categories, device="cpu"):
function uniform_sample (line 186) | def uniform_sample(size, low, high, device="cpu"):
function normal_sample (line 190) | def normal_sample(size, mean=0, std=1, device="cpu"):
function bernoulli_sample (line 194) | def bernoulli_sample(size, p, device="cpu"):
function random_mat_apply (line 198) | def random_mat_apply(p, transform, prev, eye, device="cpu"):
function sample_affine (line 206) | def sample_affine(p, size, height, width, device="cpu"):
function sample_color (line 265) | def sample_color(p, size):
function make_grid (line 299) | def make_grid(shape, x0, x1, y0, y1, device):
function affine_grid (line 309) | def affine_grid(grid, mat):
function get_padding (line 314) | def get_padding(G, height, width, kernel_size):
function try_sample_affine_and_pad (line 337) | def try_sample_affine_and_pad(img, p, kernel_size, G=None):
class GridSampleForward (line 352) | class GridSampleForward(autograd.Function):
method forward (line 354) | def forward(ctx, input, grid):
method backward (line 363) | def backward(ctx, grad_output):
class GridSampleBackward (line 370) | class GridSampleBackward(autograd.Function):
method forward (line 372) | def forward(ctx, grad_output, input, grid):
method backward (line 380) | def backward(ctx, grad_grad_input, grad_grad_grid):
function scale_mat_single (line 393) | def scale_mat_single(s_x, s_y):
function translate_mat_single (line 397) | def translate_mat_single(t_x, t_y):
function random_apply_affine (line 401) | def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
function apply_color (line 454) | def apply_color(img, mat):
function random_apply_color (line 465) | def random_apply_color(img, p, C=None):
function augment (line 474) | def augment(img, p, transform_matrix=(None, None)):
FILE: upsampler/options/test_options.py
class TestOptions (line 4) | class TestOptions:
method __init__ (line 6) | def __init__(self):
method initialize (line 10) | def initialize(self):
method parse (line 31) | def parse(self):
FILE: upsampler/options/train_options.py
class TrainOptions (line 5) | class TrainOptions:
method __init__ (line 7) | def __init__(self):
method initialize (line 11) | def initialize(self):
method parse (line 79) | def parse(self):
FILE: upsampler/scripts/align_all_parallel.py
function get_landmark (line 32) | def get_landmark(filepath, predictor):
function align_face (line 59) | def align_face(filepath, predictor):
function chunks (line 153) | def chunks(lst, n):
function extract_on_paths (line 159) | def extract_on_paths(file_paths):
function parse_args (line 179) | def parse_args():
function run (line 187) | def run(args):
FILE: upsampler/scripts/calc_id_loss_parallel.py
function chunks (line 22) | def chunks(lst, n):
function extract_on_paths (line 28) | def extract_on_paths(file_paths):
function parse_args (line 71) | def parse_args():
function run (line 80) | def run(args):
FILE: upsampler/scripts/calc_losses_on_images.py
function parse_args (line 18) | def parse_args():
function run (line 29) | def run(args):
FILE: upsampler/scripts/download_ffhq1280.py
function download_file (line 59) | def download_file(session, file_spec, stats, chunk_size=128, num_attempt...
function choose_bytes_unit (line 135) | def choose_bytes_unit(num_bytes):
function format_time (line 145) | def format_time(seconds):
function download_files (line 155) | def download_files(file_specs, num_threads=32, status_delay=0.2, timing_...
function _download_thread (line 209) | def _download_thread(spec_queue, exception_queue, stats, download_kwargs):
function print_statistics (line 220) | def print_statistics(json_data):
function find_coeffs (line 259) | def find_coeffs(pa, pb):
function recreate_aligned_images (line 272) | def recreate_aligned_images(json_data, source_dir, dst_dir='realign1280x...
function run (line 405) | def run(tasks, **download_kwargs):
function run_cmdline (line 437) | def run_cmdline(argv):
FILE: upsampler/scripts/generate_sketch_data.py
function sobel (line 15) | def sobel(img):
function sketch (line 21) | def sketch(frame):
function get_sketch_image (line 31) | def get_sketch_image(image_path):
FILE: upsampler/scripts/inference.py
function run (line 22) | def run():
function run_on_batch (line 112) | def run_on_batch(inputs, net, opts):
FILE: upsampler/scripts/pretrain.py
function requires_grad (line 20) | def requires_grad(model, flag=True):
class TrainOptions (line 25) | class TrainOptions():
method __init__ (line 26) | def __init__(self):
method parse (line 36) | def parse(self):
FILE: upsampler/scripts/style_mixing.py
function run (line 21) | def run():
FILE: upsampler/scripts/train.py
function main (line 16) | def main():
FILE: upsampler/training/coach.py
class Coach (line 25) | class Coach:
method __init__ (line 26) | def __init__(self, opts):
method train (line 111) | def train(self):
method validate (line 260) | def validate(self):
method checkpoint_me (line 374) | def checkpoint_me(self, loss_dict, is_best):
method configure_optimizers (line 387) | def configure_optimizers(self):
method configure_datasets (line 400) | def configure_datasets(self):
method calc_loss (line 435) | def calc_loss(self, x, y, y_hat, latent, y0_hat=None):
method log_metrics (line 480) | def log_metrics(self, metrics_dict, prefix):
method print_metrics (line 486) | def print_metrics(self, metrics_dict, prefix):
method parse_and_log_images (line 491) | def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=...
method log_images (line 505) | def log_images(self, name, im_data, subscript=None, log_latest=False):
method __get_save_dict (line 518) | def __get_save_dict(self):
method discriminator_loss (line 534) | def discriminator_loss(real_pred, fake_pred, loss_dict):
method discriminator_r1_loss (line 544) | def discriminator_r1_loss(real_pred, real_w):
method requires_grad (line 553) | def requires_grad(model, flag=True):
method train_discriminator (line 557) | def train_discriminator(self, real_img, fake_img):
method validate_discriminator (line 590) | def validate_discriminator(self, real_img, fake_img):
FILE: upsampler/training/ranger.py
class Ranger (line 29) | class Ranger(Optimizer):
method __init__ (line 31) | def __init__(self, params, lr=1e-3, # lr
method __setstate__ (line 75) | def __setstate__(self, state):
method step (line 78) | def step(self, closure=None):
FILE: upsampler/utils/common.py
function log_input_image (line 8) | def log_input_image(x, opts):
function tensor2im (line 17) | def tensor2im(var):
function tensor2map (line 26) | def tensor2map(var):
function tensor2sketch (line 36) | def tensor2sketch(var):
function get_colors (line 44) | def get_colors():
function vis_faces (line 52) | def vis_faces(log_hooks):
function vis_faces_with_id (line 67) | def vis_faces_with_id(hooks_dict, fig, gs, i):
function vis_faces_no_id (line 79) | def vis_faces_no_id(hooks_dict, fig, gs, i):
FILE: upsampler/utils/data_utils.py
function is_image_file (line 13) | def is_image_file(filename):
function make_dataset (line 17) | def make_dataset(dir):
FILE: upsampler/utils/inference_utils.py
function visualize (line 16) | def visualize(img_arr, dpi):
function save_image (line 22) | def save_image(img, filename):
function load_image (line 26) | def load_image(filename):
function get_video_crop_parameter (line 36) | def get_video_crop_parameter(filepath, predictor, padding=[256,256,256,2...
function tensor2cv2 (line 63) | def tensor2cv2(img):
function noise_regularize (line 67) | def noise_regularize(noises):
function noise_normalize_ (line 91) | def noise_normalize_(noises):
function get_lr (line 99) | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
function latent_noise (line 107) | def latent_noise(latent, strength):
function make_image (line 113) | def make_image(tensor):
function tensor2label (line 129) | def tensor2label(label_tensor, n_label, imtype=np.uint8):
function uint82bin (line 139) | def uint82bin(n, count=8):
function labelcolormap (line 143) | def labelcolormap(N):
class Colorize (line 167) | class Colorize(object):
method __init__ (line 168) | def __init__(self, n=35):
method __call__ (line 172) | def __call__(self, gray_image):
FILE: upsampler/utils/train_utils.py
function aggregate_loss_dict (line 2) | def aggregate_loss_dict(agg_loss_dict):
FILE: upsampler/utils/wandb_utils.py
class WBLogger (line 9) | class WBLogger:
method __init__ (line 11) | def __init__(self, opts):
method log_best_model (line 16) | def log_best_model():
method log (line 20) | def log(prefix, metrics_dict, global_step):
method log_dataset_wandb (line 26) | def log_dataset_wandb(dataset, dataset_name, n_images=16):
method log_images_to_wandb (line 32) | def log_images_to_wandb(x, y, y_hat, id_logs, prefix, step, opts):
FILE: upsampler/video_editing.py
class TestOptions (line 21) | class TestOptions():
method __init__ (line 22) | def __init__(self):
method parse (line 31) | def parse(self):
FILE: upsampler/webUI/app_task.py
function create_demo_sr (line 7) | def create_demo_sr(process):
function create_demo_s2f (line 46) | def create_demo_s2f(process):
function create_demo_m2f (line 78) | def create_demo_m2f(process):
function create_demo_editing (line 113) | def create_demo_editing(process):
function create_demo_toonify (line 145) | def create_demo_toonify(process):
function create_demo_vediting (line 172) | def create_demo_vediting(process, max_frame_num = 4):
function create_demo_vtoonify (line 214) | def create_demo_vtoonify(process, max_frame_num = 4):
function create_demo_inversion (line 252) | def create_demo_inversion(process, allow_optimization=False):
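Each create_demo_* function builds one gradio tab around a processing callback from webUI/styleganex_model.py; the callback signatures listed below (e.g. process_sr(input_image, resize_scale, model)) fix the inputs each tab must collect. A sketch of the super-resolution tab under that assumption (widget labels and dropdown choices are illustrative):

import gradio as gr

def create_demo_sr(process):
    with gr.Blocks() as demo:
        input_image = gr.Image(label='Input image', type='filepath')
        resize_scale = gr.Slider(label='Resize scale', minimum=1, maximum=32, step=1, value=4)
        model = gr.Dropdown(label='Model', choices=['SR', 'SR32'], value='SR')
        run_button = gr.Button('Run')
        result = gr.Image(label='Result')
        # Wire the tab's widgets straight into the process callback.
        run_button.click(fn=process, inputs=[input_image, resize_scale, model], outputs=result)
    return demo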
FILE: upsampler/webUI/styleganex_model.py
class Model (line 27) | class Model():
method __init__ (line 28) | def __init__(self, device):
method load_model (line 67) | def load_model(self, task_name: str) -> None:
method load_G_model (line 89) | def load_G_model(self, model_type: str) -> None:
method tensor2np (line 101) | def tensor2np(self, img):
method process_sr (line 105) | def process_sr(self, input_image: str, resize_scale: int, model: str) ...
method process_s2f (line 149) | def process_s2f(self, input_image: str, seed: int) -> np.ndarray:
method process_m2f (line 169) | def process_m2f(self, input_image: str, input_type: str, seed: int) ->...
method process_editing (line 221) | def process_editing(self, input_image: str, scale_factor: float, model...
method process_vediting (line 262) | def process_vediting(self, input_video: str, scale_factor: float, mode...
method process_toonify (line 321) | def process_toonify(self, input_image: str, style_type: str) -> np.nda...
method process_vtoonify (line 364) | def process_vtoonify(self, input_video: str, style_type: str, frame_nu...
method process_inversion (line 424) | def process_inversion(self, input_image: str, optimize: str, input_lat...
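Model keeps only one task checkpoint resident at a time: load_model swaps networks when the requested task changes, and each process_* method preprocesses its input, runs the resident network, and converts outputs back to numpy for gradio. A sketch of that lazy-loading skeleton (checkpoint path pattern, attribute names, and the tensor scaling are assumptions):

import numpy as np
import torch

class Model():
    def __init__(self, device):
        self.device = torch.device(device)
        self.task_name = None
        self.net = None  # one task network resident at a time

    def load_model(self, task_name: str) -> None:
        if task_name == self.task_name:
            return  # requested task is already loaded
        path = f'pretrained_models/styleganex_{task_name}.pt'  # assumed path pattern
        self.net = torch.load(path, map_location=self.device)
        self.task_name = task_name

    def tensor2np(self, img):
        # Assumed scaling: [-1, 1] CHW tensor -> HWC uint8 array for display.
        arr = img.detach().cpu().numpy().transpose(1, 2, 0)
        return ((np.clip(arr, -1, 1) + 1) * 127.5).astype(np.uint8)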
Condensed preview — 126 files, each showing path, character count, and a content snippet (5,021K chars of structured content in full).
[
{
"path": ".gitignore",
"chars": 44,
"preview": "checkpoints/mix-train.pth.tar\nresults_hq.mp4"
},
{
"path": "README.md",
"chars": 3456,
"preview": "# Adaptive Super Resolution For One-Shot Talking-Head Generation\nThe repository for ICASSP2024 Adaptive Super Resolution"
},
{
"path": "animate.py",
"chars": 1252,
"preview": "import os\nfrom tqdm import tqdm\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom logger import Logger, Visual"
},
{
"path": "augmentation.py",
"chars": 12547,
"preview": "\"\"\"\nCode from https://github.com/hassony2/torch_videovision\n\"\"\"\n\nimport numbers\n\nimport random\nimport numpy as np\nimport"
},
{
"path": "config/mix-resolution.yml",
"chars": 2085,
"preview": "dataset_params:\n root_dir: ../../../train/cropped_clips_512_vid/\n frame_shape: [512, 512, 3]\n id_sampling: True\n pai"
},
{
"path": "demo.py",
"chars": 14212,
"preview": "# python demo.py --config config/vox-256-spade.yml --checkpoint checkpoints/00000189-checkpoint.pth.tar --source_image "
},
{
"path": "environment.yaml",
"chars": 3002,
"preview": "name: mesh-video\nchannels:\n - pytorch\n - conda-forge\n - defaults\ndependencies:\n - _libgcc_mutex=0.1=main\n - _openmp"
},
{
"path": "frames_dataset.py",
"chars": 10851,
"preview": "#CUDA_VISIBLE_DEVICES=1 python run.py --config log_TH1K/finetune-th1k-spade.yml --device_ids 0 --checkpoint log_TH1K/000"
},
{
"path": "logger.py",
"chars": 7984,
"preview": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nimport imageio\n\nimport os\nfrom skimage.draw import circl"
},
{
"path": "modules/dense_motion.py",
"chars": 6247,
"preview": "from torch import nn\nimport torch.nn.functional as F\nimport torch\nfrom modules.util import Hourglass, make_coordinate_gr"
},
{
"path": "modules/discriminator.py",
"chars": 2861,
"preview": "from torch import nn\nimport torch.nn.functional as F\nfrom modules.util import kp2gaussian\nimport torch\n\n\nclass DownBlock"
},
{
"path": "modules/generator.py",
"chars": 12576,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom modules.util import ResBlock2d, SameBlock2d, UpBl"
},
{
"path": "modules/hopenet.py",
"chars": 6503,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\nimport math\nimport torch.nn.functional as F\n\nclas"
},
{
"path": "modules/keypoint_detector.py",
"chars": 6656,
"preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\n\nfrom sync_batchnorm import SynchronizedBatchNorm2d as"
},
{
"path": "modules/model.py",
"chars": 26035,
"preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import AntiAliasInterpolation2d, mak"
},
{
"path": "modules/util.py",
"chars": 16726,
"preview": "from torch import nn\n\nimport torch.nn.functional as F\nimport torch\n\nfrom sync_batchnorm import SynchronizedBatchNorm2d a"
},
{
"path": "run_demo.sh",
"chars": 207,
"preview": "python demo.py \\\n --config config/mix-resolution.yml \\\n --checkpoint checkpoints/mix-train.pth.tar \\\n --source_"
},
{
"path": "sync_batchnorm/__init__.py",
"chars": 449,
"preview": "# -*- coding: utf-8 -*-\n# File : __init__.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2"
},
{
"path": "sync_batchnorm/batchnorm.py",
"chars": 12973,
"preview": "# -*- coding: utf-8 -*-\n# File : batchnorm.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/"
},
{
"path": "sync_batchnorm/comm.py",
"chars": 4449,
"preview": "# -*- coding: utf-8 -*-\n# File : comm.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2018\n"
},
{
"path": "sync_batchnorm/replicate.py",
"chars": 3226,
"preview": "# -*- coding: utf-8 -*-\n# File : replicate.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/"
},
{
"path": "sync_batchnorm/unittest.py",
"chars": 835,
"preview": "# -*- coding: utf-8 -*-\n# File : unittest.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2"
},
{
"path": "upsampler/app_gradio.py",
"chars": 4808,
"preview": "from __future__ import annotations\r\n\r\nimport argparse\r\nimport pathlib\r\nimport torch\r\nimport gradio as gr\r\n\r\nfrom webUI.a"
},
{
"path": "upsampler/configs/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/configs/data_configs.py",
"chars": 2083,
"preview": "from configs import transforms_config\r\nfrom configs.paths_config import dataset_paths\r\n\r\n\r\nDATASETS = {\r\n 'ffhq_encod"
},
{
"path": "upsampler/configs/dataset_config.yml",
"chars": 1510,
"preview": "# dataset and data loader settings\ndatasets:\n train:\n name: FFHQ\n type: FFHQDegradationDataset\n # dataroot_gt:"
},
{
"path": "upsampler/configs/paths_config.py",
"chars": 1234,
"preview": "dataset_paths = {\r\n 'ffhq': 'data/train/ffhq/realign320x320/',\r\n 'ffhq_test': 'data/train/ffhq/realign320x320test/"
},
{
"path": "upsampler/configs/transforms_config.py",
"chars": 9754,
"preview": "from abc import abstractmethod\nimport torchvision.transforms as transforms\nfrom datasets import augmentations\n\n\nclass Tr"
},
{
"path": "upsampler/criteria/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/criteria/id_loss.py",
"chars": 1705,
"preview": "import torch\r\nfrom torch import nn\r\nfrom configs.paths_config import model_paths\r\nfrom models.encoders.model_irse import"
},
{
"path": "upsampler/criteria/lpips/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/criteria/lpips/lpips.py",
"chars": 1238,
"preview": "import torch\r\nimport torch.nn as nn\r\n\r\nfrom criteria.lpips.networks import get_network, LinLayers\r\nfrom criteria.lpips.u"
},
{
"path": "upsampler/criteria/lpips/networks.py",
"chars": 2762,
"preview": "from typing import Sequence\r\n\r\nfrom itertools import chain\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nfrom torchvision impo"
},
{
"path": "upsampler/criteria/lpips/utils.py",
"chars": 915,
"preview": "from collections import OrderedDict\r\n\r\nimport torch\r\n\r\n\r\ndef normalize_activation(x, eps=1e-10):\r\n norm_factor = torc"
},
{
"path": "upsampler/criteria/moco_loss.py",
"chars": 2638,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom configs.paths_config import model_paths\n\n\nclass M"
},
{
"path": "upsampler/criteria/w_norm.py",
"chars": 379,
"preview": "import torch\nfrom torch import nn\n\n\nclass WNormLoss(nn.Module):\n\n\tdef __init__(self, start_from_latent_avg=True):\n\t\tsupe"
},
{
"path": "upsampler/datasets/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/datasets/augmentations.py",
"chars": 3661,
"preview": "import numpy as np\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\nfrom torchvision import tr"
},
{
"path": "upsampler/datasets/ffhq_degradation_dataset.py",
"chars": 10675,
"preview": "import cv2\nimport math\nimport numpy as np\nimport os.path as osp\nimport torch\nimport torch.utils.data as data\nfrom basics"
},
{
"path": "upsampler/datasets/gt_res_dataset.py",
"chars": 910,
"preview": "#!/usr/bin/python\r\n# encoding: utf-8\r\nimport os\r\nfrom torch.utils.data import Dataset\r\nfrom PIL import Image\r\n\r\n\r\nclass "
},
{
"path": "upsampler/datasets/images_dataset.py",
"chars": 1014,
"preview": "from torch.utils.data import Dataset\r\nfrom PIL import Image\r\nfrom utils import data_utils\r\n\r\n\r\nclass ImagesDataset(Datas"
},
{
"path": "upsampler/datasets/inference_dataset.py",
"chars": 581,
"preview": "from torch.utils.data import Dataset\nfrom PIL import Image\nfrom utils import data_utils\n\n\nclass InferenceDataset(Dataset"
},
{
"path": "upsampler/image_translation.py",
"chars": 8033,
"preview": "import os\r\n#os.environ['CUDA_VISIBLE_DEVICES'] = \"0\"\r\n\r\nfrom models.psp import pSp\r\nimport torch\r\nimport dlib\r\nimport cv"
},
{
"path": "upsampler/inference_playground.ipynb",
"chars": 4349424,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"Jpeb3w3R1Bxx\"\n },\n \"sou"
},
{
"path": "upsampler/inversion.py",
"chars": 4500,
"preview": "import os\r\n#os.environ['CUDA_VISIBLE_DEVICES'] = \"0\"\r\n\r\nfrom models.psp import pSp\r\nimport torch\r\nimport dlib\r\nimport cv"
},
{
"path": "upsampler/latent_optimization.py",
"chars": 3788,
"preview": "import models.stylegan2.lpips as lpips\r\nfrom torch import autograd, optim\r\nfrom torchvision import transforms, utils\r\nfr"
},
{
"path": "upsampler/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/models/bisenet/LICENSE",
"chars": 1060,
"preview": "MIT License\n\nCopyright (c) 2019 zll\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof thi"
},
{
"path": "upsampler/models/bisenet/README.md",
"chars": 1686,
"preview": "# face-parsing.PyTorch\n\n<p align=\"center\">\n\t<a href=\"https://github.com/zllrunning/face-parsing.PyTorch\">\n <img class"
},
{
"path": "upsampler/models/bisenet/model.py",
"chars": 10608,
"preview": "#!/usr/bin/python\n# -*- encoding: utf-8 -*-\n\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport "
},
{
"path": "upsampler/models/bisenet/resnet.py",
"chars": 3648,
"preview": "#!/usr/bin/python\n# -*- encoding: utf-8 -*-\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport t"
},
{
"path": "upsampler/models/encoders/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/models/encoders/helpers.py",
"chars": 3556,
"preview": "from collections import namedtuple\nimport torch\nfrom torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2"
},
{
"path": "upsampler/models/encoders/model_irse.py",
"chars": 2920,
"preview": "from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module\r\nfrom models.encoders."
},
{
"path": "upsampler/models/encoders/psp_encoders.py",
"chars": 15242,
"preview": "import numpy as np\r\nimport torch\r\nimport torch.nn.functional as F\r\nfrom torch import nn\r\nfrom torch.nn import Linear, Co"
},
{
"path": "upsampler/models/mtcnn/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/models/mtcnn/mtcnn.py",
"chars": 6376,
"preview": "import numpy as np\r\nimport torch\r\nfrom PIL import Image\r\nfrom models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet,"
},
{
"path": "upsampler/models/mtcnn/mtcnn_pytorch/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/models/mtcnn/mtcnn_pytorch/src/__init__.py",
"chars": 82,
"preview": "from .visualization_utils import show_bboxes\r\nfrom .detector import detect_faces\r\n"
},
{
"path": "upsampler/models/mtcnn/mtcnn_pytorch/src/align_trans.py",
"chars": 11340,
"preview": "# -*- coding: utf-8 -*-\r\n\"\"\"\r\nCreated on Mon Apr 24 15:43:29 2017\r\n@author: zhaoy\r\n\"\"\"\r\nimport numpy as np\r\nimport cv2\r\n"
},
{
"path": "upsampler/models/mtcnn/mtcnn_pytorch/src/box_utils.py",
"chars": 7174,
"preview": "import numpy as np\r\nfrom PIL import Image\r\n\r\n\r\ndef nms(boxes, overlap_threshold=0.5, mode='union'):\r\n \"\"\"Non-maximum "
},
{
"path": "upsampler/models/mtcnn/mtcnn_pytorch/src/detector.py",
"chars": 4495,
"preview": "import numpy as np\r\nimport torch\r\nfrom torch.autograd import Variable\r\nfrom .get_nets import PNet, RNet, ONet\r\nfrom .box"
},
{
"path": "upsampler/models/mtcnn/mtcnn_pytorch/src/first_stage.py",
"chars": 3284,
"preview": "import torch\r\nfrom torch.autograd import Variable\r\nimport math\r\nfrom PIL import Image\r\nimport numpy as np\r\nfrom .box_uti"
},
{
"path": "upsampler/models/mtcnn/mtcnn_pytorch/src/get_nets.py",
"chars": 5166,
"preview": "import torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nfrom collections import OrderedDict\r\nimport numpy "
},
{
"path": "upsampler/models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py",
"chars": 8912,
"preview": "# -*- coding: utf-8 -*-\r\n\"\"\"\r\nCreated on Tue Jul 11 06:54:28 2017\r\n\r\n@author: zhaoyafei\r\n\"\"\"\r\n\r\nimport numpy as np\r\nfrom"
},
{
"path": "upsampler/models/mtcnn/mtcnn_pytorch/src/visualization_utils.py",
"chars": 817,
"preview": "from PIL import ImageDraw\r\n\r\n\r\ndef show_bboxes(img, bounding_boxes, facial_landmarks=[]):\r\n \"\"\"Draw bounding boxes an"
},
{
"path": "upsampler/models/psp.py",
"chars": 7254,
"preview": "\"\"\"\r\nThis file defines the core research contribution\r\n\"\"\"\r\nimport matplotlib\r\nmatplotlib.use('Agg')\r\nimport math\r\n\r\nimp"
},
{
"path": "upsampler/models/stylegan2/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/models/stylegan2/lpips/__init__.py",
"chars": 5804,
"preview": "\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport nu"
},
{
"path": "upsampler/models/stylegan2/lpips/base_model.py",
"chars": 1618,
"preview": "import os\nimport numpy as np\nimport torch\nfrom torch.autograd import Variable\nfrom pdb import set_trace as st\nfrom IPyth"
},
{
"path": "upsampler/models/stylegan2/lpips/dist_model.py",
"chars": 11833,
"preview": "\nfrom __future__ import absolute_import\n\nimport sys\nimport numpy as np\nimport torch\nfrom torch import nn\nimport os\nfrom "
},
{
"path": "upsampler/models/stylegan2/lpips/networks_basic.py",
"chars": 7521,
"preview": "\nfrom __future__ import absolute_import\n\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom"
},
{
"path": "upsampler/models/stylegan2/lpips/pretrained_networks.py",
"chars": 6533,
"preview": "from collections import namedtuple\nimport torch\nfrom torchvision import models as tv\nfrom IPython import embed\n\nclass sq"
},
{
"path": "upsampler/models/stylegan2/model.py",
"chars": 24941,
"preview": "import math\r\nimport random\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\nimport numpy as np"
},
{
"path": "upsampler/models/stylegan2/op/__init__.py",
"chars": 89,
"preview": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\nfrom .upfirdn2d import upfirdn2d\n"
},
{
"path": "upsampler/models/stylegan2/op/conv2d_gradfix.py",
"chars": 6646,
"preview": "import contextlib\r\nimport warnings\r\n\r\nimport torch\r\nfrom torch import autograd\r\nfrom torch.nn import functional as F\r\n\r\n"
},
{
"path": "upsampler/models/stylegan2/op/fused_act.py",
"chars": 982,
"preview": "import torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\n\r\n\r\nclass FusedLeakyReLU(nn.Module):\r\n def "
},
{
"path": "upsampler/models/stylegan2/op/readme.md",
"chars": 516,
"preview": "Code from [rosinality-stylegan2-pytorch-cp](https://github.com/senior-sigan/rosinality-stylegan2-pytorch-cpu)\n\nScripts t"
},
{
"path": "upsampler/models/stylegan2/op/upfirdn2d.py",
"chars": 1857,
"preview": "from collections import abc\r\n\r\nimport torch\r\nfrom torch.nn import functional as F\r\n\r\n\r\ndef upfirdn2d(inputs, kernel, up="
},
{
"path": "upsampler/models/stylegan2/op2/__init__.py",
"chars": 33,
"preview": "from .upfirdn2d import upfirdn2d\n"
},
{
"path": "upsampler/models/stylegan2/op2/upfirdn2d.cpp",
"chars": 1343,
"preview": "#include <ATen/ATen.h>\r\n#include <torch/extension.h>\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor &input,\r\n "
},
{
"path": "upsampler/models/stylegan2/op2/upfirdn2d.py",
"chars": 6134,
"preview": "from collections import abc\r\nimport os\r\n\r\nimport torch\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import"
},
{
"path": "upsampler/models/stylegan2/op2/upfirdn2d_kernel.cu",
"chars": 12070,
"preview": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Sou"
},
{
"path": "upsampler/models/stylegan2/op_old/__init__.py",
"chars": 89,
"preview": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\nfrom .upfirdn2d import upfirdn2d\n"
},
{
"path": "upsampler/models/stylegan2/op_old/fused_act.py",
"chars": 2463,
"preview": "import os\r\n\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.autograd import Function\r\nfrom torch.utils.cpp_extension imp"
},
{
"path": "upsampler/models/stylegan2/op_old/fused_bias_act.cpp",
"chars": 826,
"preview": "#include <torch/extension.h>\n\n\ntorch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, co"
},
{
"path": "upsampler/models/stylegan2/op_old/fused_bias_act_kernel.cu",
"chars": 2777,
"preview": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\n//\n// This work is made available under the Nvidia Sourc"
},
{
"path": "upsampler/models/stylegan2/op_old/upfirdn2d.cpp",
"chars": 966,
"preview": "#include <torch/extension.h>\n\n\ntorch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,\n "
},
{
"path": "upsampler/models/stylegan2/op_old/upfirdn2d.py",
"chars": 5202,
"preview": "import os\n\nimport torch\nfrom torch.autograd import Function\nfrom torch.utils.cpp_extension import load\n\nmodule_path = os"
},
{
"path": "upsampler/models/stylegan2/op_old/upfirdn2d_kernel.cu",
"chars": 8953,
"preview": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\n//\n// This work is made available under the Nvidia Sourc"
},
{
"path": "upsampler/models/stylegan2/simple_augment.py",
"chars": 14346,
"preview": "import math\r\n\r\nimport torch\r\nfrom torch import autograd\r\nfrom torch.nn import functional as F\r\nimport numpy as np\r\n\r\nfro"
},
{
"path": "upsampler/options/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/options/test_options.py",
"chars": 1830,
"preview": "from argparse import ArgumentParser\n\n\nclass TestOptions:\n\n\tdef __init__(self):\n\t\tself.parser = ArgumentParser()\n\t\tself.i"
},
{
"path": "upsampler/options/train_options.py",
"chars": 7222,
"preview": "from argparse import ArgumentParser\nfrom configs.paths_config import model_paths\n\n\nclass TrainOptions:\n\n def __init__"
},
{
"path": "upsampler/pretrained_models/readme.md",
"chars": 469,
"preview": "## pretained model for testing\n\nstyleganex_toonify_arcane.pt\n\nstyleganex_toonify_cartoon.pt \n\nstyleganex_toonify_pixar.p"
},
{
"path": "upsampler/scripts/align_all_parallel.py",
"chars": 7603,
"preview": "\"\"\"\r\nbrief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)\r\nauthor: lzhbrian (https://lzhbrian"
},
{
"path": "upsampler/scripts/calc_id_loss_parallel.py",
"chars": 3392,
"preview": "from argparse import ArgumentParser\r\nimport time\r\nimport numpy as np\r\nimport os\r\nimport json\r\nimport sys\r\nfrom PIL impor"
},
{
"path": "upsampler/scripts/calc_losses_on_images.py",
"chars": 2573,
"preview": "from argparse import ArgumentParser\r\nimport os\r\nimport json\r\nimport sys\r\nfrom tqdm import tqdm\r\nimport numpy as np\r\nimpo"
},
{
"path": "upsampler/scripts/download_ffhq1280.py",
"chars": 24759,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons\n# At"
},
{
"path": "upsampler/scripts/generate_sketch_data.py",
"chars": 1839,
"preview": "from torchvision import transforms\nfrom torchvision.utils import save_image\nfrom torch.utils.serialization import load_l"
},
{
"path": "upsampler/scripts/inference.py",
"chars": 5640,
"preview": "import os\r\nfrom argparse import Namespace\r\n\r\nfrom tqdm import tqdm\r\nimport time\r\nimport numpy as np\r\nimport torch\r\nfrom "
},
{
"path": "upsampler/scripts/pretrain.py",
"chars": 5565,
"preview": "import os\r\nimport sys\r\nimport torch\r\nimport dlib\r\nimport cv2\r\nimport PIL\r\nimport argparse\r\nfrom tqdm import tqdm\r\nimport"
},
{
"path": "upsampler/scripts/style_mixing.py",
"chars": 3639,
"preview": "import os\nfrom argparse import Namespace\n\nfrom tqdm import tqdm\nimport numpy as np\nfrom PIL import Image\nimport torch\nfr"
},
{
"path": "upsampler/scripts/train.py",
"chars": 667,
"preview": "\"\"\"\r\nThis file runs the main training/val loop\r\n\"\"\"\r\nimport os\r\nimport json\r\nimport sys\r\nimport pprint\r\n\r\nsys.path.appen"
},
{
"path": "upsampler/training/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/training/coach.py",
"chars": 32846,
"preview": "import os\r\nimport matplotlib\r\nimport matplotlib.pyplot as plt\r\n\r\nmatplotlib.use('Agg')\r\n\r\nimport torch\r\nfrom torch impor"
},
{
"path": "upsampler/training/ranger.py",
"chars": 5899,
"preview": "# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.\n\n# https://"
},
{
"path": "upsampler/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "upsampler/utils/common.py",
"chars": 2699,
"preview": "import cv2\r\nimport numpy as np\r\nfrom PIL import Image\r\nimport matplotlib.pyplot as plt\r\n\r\n\r\n# Log images\r\ndef log_input_"
},
{
"path": "upsampler/utils/data_utils.py",
"chars": 670,
"preview": "\"\"\"\nCode adopted from pix2pixHD:\nhttps://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py\n\"\"\"\nimport os\n\nIMG"
},
{
"path": "upsampler/utils/inference_utils.py",
"chars": 6209,
"preview": "import numpy as np\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nimport cv2\nimport random\nimport math\nimport arg"
},
{
"path": "upsampler/utils/train_utils.py",
"chars": 377,
"preview": "\ndef aggregate_loss_dict(agg_loss_dict):\n\tmean_vals = {}\n\tfor output in agg_loss_dict:\n\t\tfor key in output:\n\t\t\tmean_vals"
},
{
"path": "upsampler/utils/wandb_utils.py",
"chars": 1720,
"preview": "import datetime\nimport os\nimport numpy as np\nimport wandb\n\nfrom utils import common\n\n\nclass WBLogger:\n\n def __init__("
},
{
"path": "upsampler/video_editing.py",
"chars": 4987,
"preview": "import os\r\n#os.environ['CUDA_VISIBLE_DEVICES'] = \"0\"\r\n\r\nfrom models.psp import pSp\r\nimport torch\r\nimport dlib\r\nimport cv"
},
{
"path": "upsampler/webUI/app_task.py",
"chars": 16230,
"preview": "from __future__ import annotations\r\nfrom huggingface_hub import hf_hub_download\r\nimport numpy as np\r\nimport gradio as gr"
},
{
"path": "upsampler/webUI/styleganex_model.py",
"chars": 26310,
"preview": "from __future__ import annotations\r\nimport numpy as np\r\nimport gradio as gr\r\n\r\nimport os\r\nimport pathlib\r\nimport gc\r\nimp"
}
]
// ... and 10 more files
About this extraction
This page contains the full source code of the Songluchuan/AdaSR-TalkingHead GitHub repository, extracted and formatted as plain text: 126 files (4.7 MB, approximately 1.2M tokens) plus a symbol index of 785 extracted functions, classes, methods, constants, and types.