Full Code of semchan/HyperLips for AI

main 3276904865bb cached
70 files
87.3 MB
630.3k tokens
529 symbols
1 requests
Download .txt
Showing preview only (1,675K chars total). Download the full file or copy to clipboard to get everything.
Repository: semchan/HyperLips
Branch: main
Commit: 3276904865bb
Files: 70
Total size: 87.3 MB

Directory structure:
gitextract_hwq25lj6/

├── GFPGAN.py
├── Gen_hyperlipsbase_videos.py
├── HYPERLIPS.py
├── Inference_hyperlips.py
├── README.md
├── Train_data/
│   └── video_clips/
│       └── MEAD/
│           └── readme.txt
├── Train_hyperlipsBase.py
├── Train_hyperlipsHR.py
├── audio.py
├── checkpoint
├── checkpoints/
│   └── readme.txt
├── color_syncnet_trainv3.py
├── conv.py
├── datasets/
│   └── MEAD/
│       └── readme.txt
├── environment.yml
├── face_detection/
│   ├── README.md
│   ├── __init__.py
│   ├── api.py
│   ├── detection/
│   │   ├── __init__.py
│   │   ├── core.py
│   │   └── sfd/
│   │       ├── __init__.py
│   │       ├── bbox.py
│   │       ├── detect.py
│   │       ├── net_s3fd.py
│   │       ├── s3fd.pth
│   │       └── sfd_detector.py
│   ├── models.py
│   └── utils.py
├── face_parsing/
│   ├── README.md
│   ├── __init__.py
│   ├── model.py
│   ├── resnet.py
│   └── swap.py
├── filelists/
│   ├── train.txt
│   └── val.txt
├── filelists_lrs2/
│   ├── README.md
│   ├── test.txt
│   ├── train.txt
│   └── val.txt
├── filelists_mead/
│   ├── README.md
│   ├── test.txt
│   ├── train.txt
│   └── val.txt
├── gfpgan/
│   ├── gfpganv1_clean_arch.py
│   └── stylegan2_clean_arch.py
├── hparams.py
├── hparams_Base.py
├── hparams_HR.py
├── inference.py
├── models/
│   ├── __init__.py
│   ├── audio_v.py
│   ├── conv.py
│   ├── decoder.py
│   ├── deep_guided_filter.py
│   ├── gfpganv1_clean_arch.py
│   ├── guided_filter_pytorch/
│   │   ├── __init__.py
│   │   ├── box_filter.py
│   │   └── guided_filter.py
│   ├── hyperlayers.py
│   ├── hypernetwork.py
│   ├── layers.py
│   ├── lraspp.py
│   ├── memory.py
│   ├── mobilenetv3.py
│   ├── model.py
│   ├── model_hyperlips.py
│   ├── resnet.py
│   └── syncnet.py
├── preprocess.py
└── requirements.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: GFPGAN.py
================================================
import cv2
import os
import torch
from basicsr.utils import img2tensor
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize
import time
from gfpgan.gfpganv1_clean_arch import GFPGANv1Clean
import time
import numpy as np
import torch.nn.functional as F

def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert a torch Tensor (or a list of Tensors) into image numpy arrays.

    Each tensor is clamped to ``min_max``, rescaled to [0, 1], multiplied by
    255 and rounded, then cast to ``out_type``. The input tensor itself is
    never modified.

    Args:
        tensor (Tensor or list[Tensor]): Channel-first image tensor(s), either
            3D (C x H x W) or 4D with a leading batch dimension of size 1
            (1 x C x H x W). Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to reverse the channel axis (RGB -> BGR).
            Default: True.
        out_type (numpy type): dtype of the returned array(s). Values are
            always scaled to [0, 255] before the cast. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list[ndarray]): One H x W x C array when a single tensor
        is given, otherwise a list with one array per input tensor.

    Raises:
        TypeError: If ``tensor`` is neither a Tensor nor a list of Tensors.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    tensors = [tensor] if torch.is_tensor(tensor) else tensor
    low, high = min_max

    result = []
    for item in tensors:
        # Out-of-place clamp: float()/detach() may share storage with the
        # caller's tensor, so an in-place clamp_ would silently alter it.
        _tensor = item.squeeze(0).float().detach().clamp(low, high)
        _tensor = (_tensor - low) / (high - low)
        _tensor = _tensor.permute(1, 2, 0)  # C x H x W -> H x W x C

        img_np = (_tensor * 255.0).round().cpu().numpy()
        if rgb2bgr:
            img_np = img_np[:, :, ::-1]
        result.append(img_np.astype(out_type))

    if len(result) == 1:
        result = result[0]
    return result

class GFPGANer():
    """Helper for restoration with GFPGAN.

    Faces are detected, aligned and cropped to 512x512 through a
    ``FaceRestoreHelper``, restored by the GFPGAN network and, unless the
    input was already aligned, pasted back onto the (optionally upsampled)
    input image.

    Args:
        device: Device the network and the face helper run on; also used as
            ``map_location`` when the checkpoint is loaded.
        model_path (str): Path to the GFPGAN checkpoint (read with ``torch.load``).
        upscale (float): The upscale of the final output. Default: 2.
        arch (str): The GFPGAN architecture. Only ``'clean'`` is supported. Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background. Default: None.

    Raises:
        ValueError: If ``arch`` is anything other than ``'clean'``.
    """

    def __init__(self, device,model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
        # Fail early and clearly: only the 'clean' network is built below, so
        # any other value used to leave self.gfpgan undefined and crash later.
        if arch != 'clean':
            raise ValueError(f"Unsupported GFPGAN arch '{arch}': only 'clean' is available.")

        self.upscale = upscale
        self.bg_upsampler = bg_upsampler
        self.device = device

        # initialize the GFP-GAN
        self.gfpgan = GFPGANv1Clean(
            out_size=512,
            num_style_feat=512,
            channel_multiplier=channel_multiplier,
            decoder_load_path=None,
            fix_decoder=False,
            num_mlp=8,
            input_is_latent=True,
            different_w=True,
            narrow=1,
            sft_half=True)

        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            use_parse=True,
            device=self.device)

        loadnet = torch.load(model_path, map_location=device)
        # Prefer the EMA weights when the checkpoint provides them.
        keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)
        print('GFPGAN model loaded')

    def _restore_face(self, cropped_face):
        """Run GFPGAN on one cropped BGR face and return the restored uint8 face."""
        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        # Map [0, 1] to [-1, 1]; the output is converted back with min_max=(-1, 1).
        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

        try:
            output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
            restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
        except RuntimeError as error:
            # Best effort: keep the unrestored crop when inference fails.
            print(f'\tFailed inference for GFPGAN: {error}.')
            restored_face = cropped_face
        return restored_face.astype('uint8')

    @torch.no_grad()
    def enhance_allimg(self, img, has_aligned=False, only_center_face=False, paste_back=True):
        """Restore the faces of ``img``; same pipeline and return value as ``enhance``.

        Kept as a separate entry point for existing callers.
        """
        return self.enhance(img, has_aligned=has_aligned, only_center_face=only_center_face,
                            paste_back=paste_back)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
        """Restore the face(s) found in ``img``.

        Args:
            img: Input image; when ``has_aligned`` is true it is resized to
                512x512 and treated as a single face crop.
            has_aligned (bool): Skip detection/alignment and use ``img`` as the face.
            only_center_face (bool): Forwarded to the landmark detection step.
            paste_back (bool): Paste restored faces back onto the input image
                (only applies when ``has_aligned`` is false).

        Returns:
            tuple: ``(cropped_faces, restored_faces, restored_img)`` where
            ``restored_img`` is ``None`` unless faces were pasted back.
        """
        self.face_helper.clean_all()
        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            self.face_helper.add_restored_face(self._restore_face(cropped_face))

        restored_img = None
        if not has_aligned and paste_back:
            # upsample the background
            bg_img = None
            if self.bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)

        return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img



def GFPGANInit(device,face_enhancement_path):
    """Create the GFPGAN restorer used for face enhancement.

    Args:
        device: Device handed to ``GFPGANer``.
        face_enhancement_path: Path of the GFPGAN checkpoint to load.

    Returns:
        A ``GFPGANer`` configured with the 'clean' architecture, no output
        upscaling (``upscale=1``) and no background upsampler.
    """
    return GFPGANer(
        device=device,
        model_path=face_enhancement_path,
        upscale=1,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None,
    )

def GFPGANInfer(img, restorer, aligned):
    """Enhance ``img`` with ``restorer`` and return the relevant image.

    Args:
        img: Image passed on to the restorer.
        restorer: Object exposing ``enhance`` and ``enhance_allimg``.
        aligned: Truthy when ``img`` is already an aligned face crop.

    Returns:
        The full restored image when ``aligned == False``, otherwise the
        first restored face.
    """
    # Aligned crops go through `enhance`, whole frames through `enhance_allimg`.
    run = restorer.enhance if aligned else restorer.enhance_allimg
    _cropped, faces, pasted = run(
        img, has_aligned=aligned, only_center_face=True, paste_back=True)

    # The equality test (rather than plain truthiness) is kept on purpose so
    # values such as None select the same result as before.
    if aligned == False:
        return pasted
    return faces[0]




================================================
FILE: Gen_hyperlipsbase_videos.py
================================================
from HYPERLIPS import Hyperlips
import argparse
import os



# Command-line options for running HyperLips over every file of a folder.
parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using HyperLipsBase or HyperLipsHR models')
parser.add_argument('--checkpoint_path_BASE', type=str,help='Name of saved HyperLipsBase checkpoint to load weights from', default="checkpoints/require_grad_checkpoint_step000169000.pth")
parser.add_argument('--checkpoint_path_HR', type=str,help='Name of saved HyperLipsHR checkpoint to load weights from', default=None)
parser.add_argument('--videos', type=str,
                    help='Filepath of video/image that contains faces to use', default="datasets/MEAD")
parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
                    default='hyperlips_base_results')
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
                    help='Padding (top, bottom, left, right). Please adjust to include chin at least')
parser.add_argument('--filter_window', default=None, type=int,
                    help='real window is 2*T+1')
parser.add_argument('--hyper_batch_size', type=int, help='Batch size for hyperlips model(s)', default=128)
parser.add_argument('--resize_factor', default=1, type=int,
                    help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
parser.add_argument('--img_size', default=128, type=int)
parser.add_argument('--segmentation_path', type=str,
                    help='Name of saved checkpoint of segmentation network', default="checkpoints/face_segmentation.pth")
# Leave unset to skip face enhancement, e.g. "checkpoints/GFPGANv1.3.pth" to enable it.
parser.add_argument('--face_enhancement_path', type=str,
                    help='Name of saved checkpoint of face enhancement (GFPGAN) network', default=None)
# NOTE(review): this flag is parsed but not read anywhere in this script.
parser.add_argument('--no_faceenhance', default=False, action='store_true',
                    help='Prevent using face enhancement')
# The id is interpolated into a 'cuda:<id>' device string, so it has to be an
# integer: parsed as float, `--gpu_id 0` became 0.0 and produced 'cuda:0.0'.
parser.add_argument('--gpu_id', type=int, help='gpu id (default: 0)',
                    default=0, required=False)
args = parser.parse_args()


def inference_list():
    """Lip-sync every video found in ``args.videos``.

    Each video is used both as the face source and as the audio source, and
    the result is written under ``args.outfile`` with the same file name.
    Directory entries that are not regular files with a video extension
    (for example a ``readme.txt`` or a subfolder) are skipped instead of
    being handed to the inference pipeline.
    """
    video_extensions = ('.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv', '.m4v', '.mpg', '.mpeg')

    Hyperlips_executor = Hyperlips(checkpoint_path_BASE=args.checkpoint_path_BASE,
                                    checkpoint_path_HR=args.checkpoint_path_HR,
                                    segmentation_path=args.segmentation_path,
                                    face_enhancement_path = args.face_enhancement_path,
                                    gpu_id = args.gpu_id,
                                    window =args.filter_window,
                                    hyper_batch_size=args.hyper_batch_size,
                                    img_size = args.img_size,
                                    resize_factor = args.resize_factor,
                                    pad = args.pads)
    Hyperlips_executor._HyperlipsLoadModels()

    # Create the (possibly nested) output folder up front.
    os.makedirs(args.outfile, exist_ok=True)

    # Sorted so that runs process the files in a reproducible order.
    for name in sorted(os.listdir(args.videos)):
        source = os.path.join(args.videos, name)
        if not os.path.isfile(source) or not name.lower().endswith(video_extensions):
            print('Skipping non-video entry: {}'.format(source))
            continue
        outputfile = os.path.join(args.outfile, name)
        Hyperlips_executor._HyperlipsInference(source, source, outputfile)


# Script entry point: process every file of --videos when run directly.
if __name__ == '__main__':

    inference_list()

================================================
FILE: HYPERLIPS.py
================================================
import cv2, os, sys,audio
import subprocess, random, string
from tqdm import tqdm
import torch, face_detection
from models.model_hyperlips import HyperLips_inference
from GFPGAN import *
from face_parsing import init_parser,swap_regions
import shutil


def get_smoothened_boxes(boxes, T):
    """Smooth a sequence of boxes in place with a forward-looking mean.

    Entry ``i`` is replaced by the mean of the ``T`` entries starting at
    ``i``; near the end of the sequence, where fewer than ``T`` entries
    remain, the last ``T`` entries are averaged instead. Because the update
    is done in place, windows near the end include already-smoothed entries.

    Args:
        boxes: Indexable, sliceable sequence of boxes (modified in place).
        T: Window length.

    Returns:
        The same ``boxes`` object.
    """
    count = len(boxes)
    for idx in range(count):
        runs_past_end = idx + T > count
        window = boxes[count - T:] if runs_past_end else boxes[idx:idx + T]
        boxes[idx] = np.mean(window, axis=0)
    return boxes

def face_detect(images, detector,pad):
    """Detect one face per image and return the padded, smoothed crops.

    Detection runs in batches of 8 images; when the detector raises a
    ``RuntimeError`` the batch size is halved and the whole pass is retried.

    Args:
        images: Sequence of frames (indexed as ``image[y1:y2, x1:x2]``).
        detector: Object exposing ``get_detections_for_batch``.
        pad: ``(top, bottom, left, right)`` padding added to each detection.

    Returns:
        list: One ``[crop, (y1, y2, x1, x2)]`` entry per image.

    Raises:
        RuntimeError: If detection still fails with a batch size of 1.
        ValueError: If no face is detected in one of the images.
    """
    chunk_size = 8
    while True:
        detections = []
        try:
            for offset in range(0, len(images), chunk_size):
                chunk = np.array(images[offset:offset + chunk_size])
                detections.extend(detector.get_detections_for_batch(chunk))
        except RuntimeError as e:
            print(e)
            if chunk_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
            chunk_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(chunk_size))
        else:
            break

    pad_top, pad_bottom, pad_left, pad_right = pad  # [0, 10, 0, 0]
    padded = []
    for rect, image in zip(detections, images):
        if rect is None:
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        # Grow the detection by the padding, clipped to the image borders.
        frame_h, frame_w = image.shape[0], image.shape[1]
        padded.append([
            max(0, rect[0] - pad_left),
            max(0, rect[1] - pad_top),
            min(frame_w, rect[2] + pad_right),
            min(frame_h, rect[3] + pad_bottom),
        ])

    # Temporal smoothing over a 5-frame window.
    smoothed = get_smoothened_boxes(np.array(padded), T=5)

    crops = []
    for image, (x1, y1, x2, y2) in zip(images, smoothed):
        crops.append([image[y1: y2, x1:x2], (y1, y2, x1, x2)])
    return crops

def _assemble_batch(img_batch, mel_batch, ref_batch, img_size):
    """Stack per-frame lists into the image and mel arrays of one model batch."""
    img_batch, mel_batch, ref_batch = np.asarray(img_batch), np.asarray(mel_batch), np.asarray(ref_batch)

    # Zero the rows from img_size // 2 downwards (lower half of every face crop).
    img_masked = img_batch.copy()
    img_masked[:, img_size // 2:] = 0

    # Masked face and reference face are concatenated on the channel axis.
    img_batch = np.concatenate((img_masked, ref_batch), axis=3) / 255.
    mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
    return img_batch, mel_batch


def datagen(mels, detector,frames,img_size,hyper_batch_size,pads):
    """Yield model-ready batches, one entry per mel chunk.

    Faces are detected once for all ``frames``; the crop of the first frame
    is used as the reference image of every sample. When there are more mel
    chunks than frames the frames are reused from the start (the video is
    looped) instead of running off the end of the list.

    Args:
        mels: Sequence of mel-spectrogram chunks.
        detector: Face detector forwarded to ``face_detect``.
        frames: Sequence of video frames.
        img_size: Side length the face crops are resized to.
        hyper_batch_size: Maximum number of samples per yielded batch.
        pads: Padding forwarded to ``face_detect``.

    Yields:
        tuple: ``(img_batch, mel_batch, frame_batch, coords_batch)``.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError('No video frames available for inference.')

    face_det_results = face_detect(frames,detector,pads)
    ref, _ = face_det_results[0]
    ref =  cv2.resize(ref, (img_size, img_size))

    img_batch, mel_batch, frame_batch, coords_batch,ref_batch = [], [], [], [],[]
    for i, m in enumerate(mels):
        idx = i % len(frames)  # loop the video when the audio is longer
        face, coords = face_det_results[idx]
        face = cv2.resize(face, (img_size, img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frames[idx].copy())
        ref_batch.append(ref)
        coords_batch.append(coords)

        if len(img_batch) >= hyper_batch_size:
            img_batch, mel_batch = _assemble_batch(img_batch, mel_batch, ref_batch, img_size)
            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch,ref_batch = [], [], [], [],[]

    # Flush the last, possibly smaller, batch.
    if len(img_batch) > 0:
        img_batch, mel_batch = _assemble_batch(img_batch, mel_batch, ref_batch, img_size)
        yield img_batch, mel_batch, frame_batch, coords_batch

    
def load_HyperLips(window,rescaling,path,path_hr,device):
    """Build the HyperLips inference model, move it to ``device`` and return it in eval mode.

    Args:
        window: Passed to the model as ``window_T``.
        rescaling: Passed to the model as ``rescaling``.
        path: Base model checkpoint path.
        path_hr: HR decoder checkpoint path.
        device: Target device for the model.
    """
    net = HyperLips_inference(
        window_T=window,
        rescaling=rescaling,
        base_model_checkpoint=path,
        HRDecoder_model_checkpoint=path_hr,
    ).to(device)
    print("HyperLipsHR model loaded")
    return net.eval()
    
def main():
    """Command-line entry point: lip-sync one video with HyperLips.

    The previous body built ``Hyperlips()`` without a GPU id and called
    ``_HyperlipsInference()`` without its required face/audio/output paths,
    so it could only fail. The paths and checkpoints are now read from the
    command line.
    """
    import argparse  # local import: only needed when run as a script

    cli = argparse.ArgumentParser(description='Lip-sync a video with HyperLips')
    cli.add_argument('--face', type=str, required=True,
                     help='Filepath of video that contains faces to use')
    cli.add_argument('--audio', type=str, default=None,
                     help='Filepath of video/audio file to use as raw audio source (defaults to --face)')
    cli.add_argument('--outfile', type=str, default='results/result_voice.mp4',
                     help='Video path to save result')
    cli.add_argument('--checkpoint_path_BASE', type=str, default='checkpoints/hyperlipsbase_mead.pth')
    cli.add_argument('--checkpoint_path_HR', type=str, default=None)
    cli.add_argument('--segmentation_path', type=str, default='checkpoints/face_segmentation.pth')
    cli.add_argument('--face_enhancement_path', type=str, default=None)
    cli.add_argument('--gpu_id', type=int, default=0)
    cli.add_argument('--filter_window', type=int, default=None)
    cli.add_argument('--hyper_batch_size', type=int, default=128)
    cli.add_argument('--img_size', type=int, default=128)
    cli.add_argument('--resize_factor', type=int, default=1)
    cli.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0])
    opts = cli.parse_args()

    Hyperlips_executor = Hyperlips(checkpoint_path_BASE=opts.checkpoint_path_BASE,
                                   checkpoint_path_HR=opts.checkpoint_path_HR,
                                   segmentation_path=opts.segmentation_path,
                                   face_enhancement_path=opts.face_enhancement_path,
                                   gpu_id=opts.gpu_id,
                                   window=opts.filter_window,
                                   hyper_batch_size=opts.hyper_batch_size,
                                   img_size=opts.img_size,
                                   resize_factor=opts.resize_factor,
                                   pad=opts.pads)
    Hyperlips_executor._HyperlipsLoadModels()
    Hyperlips_executor._HyperlipsInference(opts.face, opts.audio or opts.face, opts.outfile)

class Hyperlips():
    """HyperLips lip-sync inference pipeline.

    Usage: construct, call ``_HyperlipsLoadModels()`` once, then call
    ``_HyperlipsInference(face, audio, outfile)`` for every video.

    Args:
        checkpoint_path_BASE: HyperLips base model checkpoint path.
        checkpoint_path_HR: HyperLips HR decoder checkpoint path.
        segmentation_path: Face segmentation checkpoint path.
        face_enhancement_path: GFPGAN checkpoint path; ``None`` disables
            face enhancement.
        gpu_id: Index of the CUDA device. ``None`` selects device 0; other
            values are converted with ``int`` so that e.g. ``0.0`` does not
            end up as the invalid device string ``'cuda:0.0'``.
        window: Filter window forwarded to the model.
        hyper_batch_size: Batch size used for the HyperLips model.
        img_size: Face crop size; must be 128, 256 or 512.
        resize_factor: Frames are downscaled by this factor when it is > 1.
        pad: ``(top, bottom, left, right)`` face padding; defaults to
            ``[0, 10, 0, 0]``.

    Raises:
        ValueError: If ``img_size`` is not 128, 256 or 512.
    """

    def __init__(self,checkpoint_path_BASE=None,
                 checkpoint_path_HR=None,
                 segmentation_path=None,
                 face_enhancement_path = None,
                 gpu_id = None,
                 window =None,
                 hyper_batch_size=128,
                 img_size = 128,
                 resize_factor = 1,
                 pad = None
                 ):
        rescaling_by_size = {128: 1, 256: 2, 512: 4}
        if img_size not in rescaling_by_size:
            raise ValueError(
                f'Init error! img_size should be 128 256 or 512!')

        self.checkpoint_path_BASE = checkpoint_path_BASE
        self.checkpoint_path_HR = checkpoint_path_HR
        self.parser_path = segmentation_path
        self.face_enhancement_path = face_enhancement_path
        self.batch_size = hyper_batch_size #128
        self.mel_step_size = 16
        self.gpu_id = 0 if gpu_id is None else int(gpu_id)
        self.img_size = img_size
        self.resize_factor = resize_factor
        # A fresh list per instance: a list literal as default argument would
        # be shared (and mutable) across all instances.
        self.pad = [0, 10, 0, 0] if pad is None else list(pad)
        self.rescaling = rescaling_by_size[img_size]
        self.window = window

    def _HyperlipsLoadModels(self):
        """Load the HyperLips, segmentation and (optional) GFPGAN models on the GPU.

        Raises:
            ValueError: If CUDA is unavailable or ``gpu_id`` is out of range.
        """
        gpu_id = int(self.gpu_id)
        if not torch.cuda.is_available() or (gpu_id > (torch.cuda.device_count() - 1)):
            raise ValueError(
                f'Existing gpu configuration problem.(gpu.is_available={torch.cuda.is_available()}| gpu.device_count={torch.cuda.device_count()})')
        self.device = torch.device(f'cuda:{gpu_id}')
        print('Using {} for inference.'.format(self.device))
        if self.face_enhancement_path is not None:
            self.restorer = GFPGANInit(self.device, self.face_enhancement_path)
        self.model = load_HyperLips(self.window,self.rescaling,self.checkpoint_path_BASE, self.checkpoint_path_HR,self.device)
        self.seg_net = init_parser(self.parser_path, self.device)
        print(' models init successed...')

    def _read_video_frames(self, face):
        """Read every frame of ``face`` (downscaled by ``resize_factor``) and return ``(frames, fps)``."""
        video_stream = cv2.VideoCapture(face)
        try:
            fps = video_stream.get(cv2.CAP_PROP_FPS)
            full_frames = []
            while 1:
                still_reading, frame = video_stream.read()
                if not still_reading:
                    break
                if self.resize_factor > 1:
                    frame = cv2.resize(frame, (frame.shape[1]//self.resize_factor, frame.shape[0]//self.resize_factor))
                full_frames.append(frame)
        finally:
            video_stream.release()
        return full_frames, fps

    def _prepare_audio(self, audiopath, temp_save_path):
        """Return a .wav path for ``audiopath``, extracting it with ffmpeg when needed."""
        if audiopath.endswith('.wav'):
            return audiopath

        print('Extracting raw audio...')
        wav_path = os.path.join(temp_save_path, 'temp.wav')
        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed to ffmpeg unchanged.
        status = subprocess.call(['ffmpeg', '-y', '-i', audiopath, '-strict', '-2', wav_path])
        if status != 0:
            raise RuntimeError('ffmpeg could not extract audio from {}'.format(audiopath))
        return wav_path

    def _compute_mel_chunks(self, wav_path, fps):
        """Split the mel spectrogram of ``wav_path`` into one chunk per video frame."""
        wav = audio.load_wav(wav_path, 16000)
        mel = audio.melspectrogram(wav)
        if np.isnan(mel.reshape(-1)).sum() > 0:
            raise ValueError(
                'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

        mel_chunks = []
        mel_idx_multiplier = 80. / fps
        i = 0
        while 1:
            start_idx = int(i * mel_idx_multiplier)
            if start_idx + self.mel_step_size > len(mel[0]):
                # Last chunk: take the final mel_step_size columns.
                mel_chunks.append(mel[:, len(mel[0]) - self.mel_step_size:])
                break
            mel_chunks.append(mel[:, start_idx: start_idx + self.mel_step_size])
            i += 1
        return mel_chunks

    def _blend_prediction(self, p, f, c, kernel):
        """Blend one predicted face crop into its frame and return the uint8 frame."""
        y1, y2, x1, x2 = c
        mask_temp = np.zeros_like(f)
        p = p.cpu().numpy().transpose(1,2,0) * 255.
        f_background = f.copy()
        p,mask_out = swap_regions(f[y1:y2, x1:x2], p, self.seg_net)
        p = cv2.resize(p, (x2 - x1, y2 - y1)).astype(np.uint8)

        # Keep only the lower half and the central 70% of the width of the mask.
        mask_out=mask_out*255
        mask_out[:mask_out.shape[0]//2, :, :] = 0.
        mask_out[:,:int(mask_out.shape[1]*0.15),:] = 0.
        mask_out[:,int(mask_out.shape[1]*0.85):,:] = 0.

        # np.float was removed in NumPy 1.24; it was an alias of the builtin
        # float, i.e. float64.
        mask_temp[y1:y2, x1:x2] = mask_out.astype(np.float64)
        mask_temp = cv2.erode(mask_temp,kernel,iterations = 1)
        mask_temp = cv2.GaussianBlur(mask_temp, (75, 75), 0,0,cv2.BORDER_DEFAULT)
        mask_temp = mask_temp.astype(np.float64)

        f[y1:y2, x1:x2] = p
        # Feathered blend between the untouched frame and the pasted face.
        f = f_background*(1-mask_temp/255.0)+f*(mask_temp/255.0)
        if self.face_enhancement_path is not None:
            f = GFPGANInfer(f, self.restorer,aligned=False)
        return f.astype(np.uint8)

    def _run_inference(self, face, audiopath, outfile, temp_save_path):
        """Generate the lip-synced video for one input inside ``temp_save_path`` and mux it to ``outfile``."""
        detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,flip_input=False, device='cuda:{}'.format(self.gpu_id))

        full_frames, fps = self._read_video_frames(face)
        print ("Number of frames available for inference: "+str(len(full_frames)))
        if len(full_frames) == 0:
            raise ValueError('No frames could be read from {}'.format(face))
        if not fps or fps <= 0:
            raise ValueError('Could not determine the frame rate of {}'.format(face))

        audiopath = self._prepare_audio(audiopath, temp_save_path)
        mel_chunks = self._compute_mel_chunks(audiopath, fps)
        print("Length of mel chunks: {}".format(len(mel_chunks)))
        full_frames = full_frames[:len(mel_chunks)]

        # Size the writer from the frames actually written: with
        # resize_factor > 1 they are smaller than the source stream.
        frame_height, frame_width = full_frames[0].shape[:2]
        result_path = os.path.join(temp_save_path, 'result.avi')
        out = cv2.VideoWriter(result_path, cv2.VideoWriter_fourcc(*'DIVX'),
                              fps, (frame_width, frame_height))
        kernel = np.ones((5,5),np.uint8)
        try:
            gen = datagen(mel_chunks, detector, full_frames, self.img_size,self.batch_size,self.pad)
            total = int(np.ceil(float(len(mel_chunks))/ self.batch_size))
            for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, total=total)):
                img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device)
                mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device)
                with torch.no_grad():
                    pred = self.model(mel_batch, img_batch)
                for p, f, c in zip(pred, frames, coords):
                    out.write(self._blend_prediction(p, f, c, kernel))
        finally:
            out.release()

        status = subprocess.call(['ffmpeg', '-y', '-i', audiopath, '-i', result_path,
                                  '-strict', '-2', '-q:v', '1', outfile])
        if status != 0:
            print('ffmpeg failed to write {} (exit code {})'.format(outfile, status))

    def _HyperlipsInference(self,face_path,audio_path,outfile_path):
        """Lip-sync ``face_path`` to the audio of ``audio_path`` and write ``outfile_path``.

        Intermediate files live in a private temporary directory created next
        to the output file; it is removed when the call ends, whether or not
        inference succeeded.

        Raises:
            ValueError: If ``face_path`` is not a file, no frame can be read,
                the frame rate is unknown, or the mel spectrogram contains NaN.
            RuntimeError: If ffmpeg cannot extract the audio track.
        """
        import tempfile  # local import: only needed here

        print("The input video path is {}, The intput audio path is {}".format(face_path, audio_path))
        if not os.path.isfile(face_path):
            raise ValueError('--face argument must be a valid path to video/image file')

        outfile = os.path.abspath(outfile_path)
        rest_root_path = os.path.dirname(os.path.realpath(outfile))
        os.makedirs(rest_root_path, exist_ok=True)

        # A unique scratch directory: deriving it from the output name could
        # collide with (and later delete) an existing user directory.
        temp_save_path = tempfile.mkdtemp(prefix='hyperlips_', dir=rest_root_path)
        try:
            self._run_inference(face_path, audio_path, outfile, temp_save_path)
        finally:
            shutil.rmtree(temp_save_path, ignore_errors=True)
            torch.cuda.empty_cache()

# Script entry point: delegate to main() when this module is executed directly.
if __name__ == '__main__':
    main()


================================================
FILE: Inference_hyperlips.py
================================================
import cv2, os, sys, argparse, audio
import subprocess, random, string
from tqdm import tqdm
import torch, face_detection
from models.model_hyperlips import HyperLipsBase,HyperLipsHR
from GFPGAN import *
from face_parsing import init_parser, swap_regions_img
import shutil


# Command-line options for lip-syncing a single video.
parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using HyperLipsBase or HyperLipsHR models')

parser.add_argument('--checkpoint_path_BASE', type=str,help='Name of saved HyperLipsBase checkpoint to load weights from', default="checkpoints/hyperlipsbase_mead.pth")
parser.add_argument('--checkpoint_path_HR', type=str,help='Name of saved HyperLipsHR checkpoint to load weights from', default="checkpoints/hyperlipshr_mead_128.pth")
parser.add_argument('--modelname', type=str,
                    help='Choosing HyperLipsBase or HyperLipsHR', default="HyperLipsHR")
parser.add_argument('--face', type=str,
                    help='Filepath of video/image that contains faces to use', default="test/video/M003-002.mp4")
parser.add_argument('--audio', type=str,
                    help='Filepath of video/audio file to use as raw audio source', default="test/audio/M003-002.mp4")
parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
                    default='results/result_voice.mp4')
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
                    help='Padding (top, bottom, left, right). Please adjust to include chin at least')
parser.add_argument('--filter_window', default=2, type=int,
                    help='real window is 2*T+1')
parser.add_argument('--face_det_batch_size', type=int,
                    help='Batch size for face detection', default=8)
parser.add_argument('--hyper_batch_size', type=int, help='Batch size for hyperlips model(s)', default=128)

parser.add_argument('--resize_factor', default=1, type=int,
                    help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
parser.add_argument('--segmentation_path', type=str,
                    help='Name of saved checkpoint of segmentation network', default="checkpoints/face_segmentation.pth")
parser.add_argument('--face_enhancement_path', type=str,
                    help='Name of saved checkpoint of face enhancement (GFPGAN) network', default="checkpoints/GFPGANv1.3.pth")
# NOTE(review): with default=True and action='store_true' this option is
# always True, so it cannot be switched off from the command line. The
# default is left untouched here because the code reading it is not part of
# this section; confirm the intended default before changing it.
parser.add_argument('--no_faceenhance', default=True, action='store_true',
                    help='Prevent using face enhancement')
# A device index is an integer; parsed as float, `--gpu_id 0` became 0.0,
# which does not format into a valid 'cuda:<id>' device name.
parser.add_argument('--gpu_id', type=int, help='gpu id (default: 0)',
                    default=0, required=False)
args = parser.parse_args()
# The face crop size is fixed for this script.
args.img_size = 128


def get_smoothened_mels(mel_chunks, T):
    """Smooth mel chunks in place with a centred window of ``2*T + 1`` chunks.

    Chunk ``i`` is replaced by the mean of chunks ``i-T .. i+T`` (inclusive).
    The first and last ``T`` chunks, which lack a full window, are left
    untouched. The update is sequential and in place, so chunks to the left
    of ``i`` have already been smoothed when chunk ``i`` is computed.

    The slice used to stop at ``i + T`` (exclusive), giving an asymmetric
    window of only ``2*T`` chunks and, for ``T == 0``, the mean of an empty
    slice (NaN).

    Args:
        mel_chunks: Mutable sequence of equally shaped mel chunks.
        T: Half window size; ``None`` or a value <= 0 disables smoothing.

    Returns:
        The same ``mel_chunks`` object.
    """
    if T is None or T <= 0:
        return mel_chunks

    for i in range(T, len(mel_chunks) - T):
        mel_chunks[i] = np.mean(mel_chunks[i - T: i + T + 1], axis=0)
    return mel_chunks

def face_detect(images, detector,pad):
    """Detect the face of a single image and return its padded crop.

    At most one image is accepted. A ``RuntimeError`` raised by the detector
    halves the batch size and retries, until the batch size reaches 1.

    Args:
        images: List holding at most one frame.
        detector: Object exposing ``get_detections_for_batch``.
        pad: ``(top, bottom, left, right)`` padding added to the detection.

    Returns:
        list: ``[crop, (y1, y2, x1, x2)]`` entries, one per image.

    Raises:
        RuntimeError: If more than one image is given, or detection still
            fails with a batch size of 1.
        ValueError: If no face is detected.
    """
    if len(images) > 1:
        print('error')
        raise RuntimeError('leng(imgaes')

    chunk_size = 16
    while True:
        detections = []
        try:
            for offset in range(0, len(images), chunk_size):
                chunk = np.array(images[offset:offset + chunk_size])
                detections.extend(detector.get_detections_for_batch(chunk))
        except RuntimeError as e:
            print(e)
            if chunk_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
            chunk_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(chunk_size))
        else:
            break

    pad_top, pad_bottom, pad_left, pad_right = pad  # [0, 10, 0, 0]
    padded = []
    for rect, image in zip(detections, images):
        if rect is None:
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        # Grow the detection by the padding, clipped to the image borders.
        frame_h, frame_w = image.shape[0], image.shape[1]
        padded.append([
            max(0, rect[0] - pad_left),
            max(0, rect[1] - pad_top),
            min(frame_w, rect[2] + pad_right),
            min(frame_h, rect[3] + pad_bottom),
        ])

    crops = []
    for image, (x1, y1, x2, y2) in zip(images, np.array(padded)):
        crops.append([image[y1: y2, x1:x2], (y1, y2, x1, x2)])
    return crops

def datagen(mels, detector, face_path, resize_factor):
    """Yield inference batches pairing every mel chunk with a video frame.

    Frames are read sequentially from ``face_path``; when the video ends
    before the mel chunks do, reading restarts from the first frame. For each
    frame the face is detected (using ``args.pads``), cropped and resized to
    128x128.

    Args:
        mels: iterable of mel-spectrogram chunks (2-D arrays).
        detector: face detector forwarded to ``face_detect``.
        face_path: path of the input video.
        resize_factor: down-scaling factor forwarded to ``read_frames``.

    Yields:
        ``(img_batch, mel_batch, frame_batch, coords_batch)`` where
        ``img_batch`` is the half-masked crops concatenated with the
        unmasked crops along the last axis and divided by 255, ``mel_batch``
        has a trailing singleton axis, ``frame_batch`` is the list of full
        frames and ``coords_batch`` the list of ``(y1, y2, x1, x2)`` boxes.
        Batches hold ``args.hyper_batch_size`` items, except possibly the
        last one.

    Raises:
        RuntimeError: if no frame can be read from ``face_path``.
    """
    img_size = 128
    hyper_batch_size = args.hyper_batch_size

    def _pack(faces, mel_list):
        # Turn the accumulated lists into the arrays the model consumes.
        faces, mel_list = np.asarray(faces), np.asarray(mel_list)

        masked = faces.copy()
        # Zero the rows from img_size // 2 onward of every crop.
        masked[:, img_size // 2:] = 0

        faces = np.concatenate((masked, faces), axis=3) / 255.
        mel_list = np.reshape(mel_list, [len(mel_list), mel_list.shape[1], mel_list.shape[2], 1])
        return faces, mel_list

    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    reader = read_frames(face_path, resize_factor)
    for m in mels:
        try:
            frame_to_save = next(reader)
        except StopIteration:
            # Video is shorter than the audio: loop back to the first frame.
            reader = read_frames(face_path, resize_factor)
            try:
                frame_to_save = next(reader)
            except StopIteration:
                # Without this guard the StopIteration would leak out of the
                # generator and surface as an opaque RuntimeError.
                raise RuntimeError(
                    'No frames could be read from {}'.format(face_path)) from None
        face, coords = face_detect([frame_to_save], detector, args.pads)[0]
        face = cv2.resize(face, (img_size, img_size))
        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= hyper_batch_size:
            img_batch, mel_batch = _pack(img_batch, mel_batch)
            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    # Flush the final, possibly smaller, batch.
    if len(img_batch) > 0:
        img_batch, mel_batch = _pack(img_batch, mel_batch)
        yield img_batch, mel_batch, frame_batch, coords_batch

def _load(checkpoint_path, device):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    return checkpoint


def load_HyperLipsHR(path,path_hr,device):
    """Build a ``HyperLipsHR`` model on ``device`` and return it in eval mode.

    ``path`` and ``path_hr`` are handed to the constructor as the base-model
    and HR-decoder checkpoints; the filter window comes from
    ``args.filter_window``.
    """
    hr_model = HyperLipsHR(
        window_T=args.filter_window,
        rescaling=1,
        base_model_checkpoint=path,
        HRDecoder_model_checkpoint=path_hr,
    )
    hr_model = hr_model.to(device)
    print("HyperLipsHR model loaded")
    return hr_model.eval()

def load_HyperLipsBase(path, device):
    """Create a ``HyperLipsBase`` model, load its weights and return it in eval mode.

    The weights are read from the ``"state_dict"`` entry of the checkpoint
    stored at ``path``; the model is then moved to ``device``.
    """
    net = HyperLipsBase()
    weights = _load(path, device)["state_dict"]
    net.load_state_dict(weights)
    net = net.to(device)
    print("HyperLipsBase model loaded")
    return net.eval()

def read_frames(face_path, resize_factor):
    """Lazily yield the frames of the video at ``face_path``.

    Args:
        face_path: path handed to ``cv2.VideoCapture``.
        resize_factor: when greater than 1, each frame's width and height are
            floor-divided by this factor before the frame is yielded.

    Yields:
        One frame array per successful read, in order.
    """
    video_stream = cv2.VideoCapture(face_path)

    print('Reading video frames from start...')
    try:
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                break
            if resize_factor > 1:
                # int() keeps the size tuple integral even if the factor is
                # not an int.
                frame = cv2.resize(frame, (int(frame.shape[1] // resize_factor),
                                           int(frame.shape[0] // resize_factor)))
            yield frame
    finally:
        # Also runs when the consumer stops iterating early (generator close),
        # so the capture handle is never leaked.
        video_stream.release()

def main():
    """Create the pipeline object, load its models, then run inference once."""
    executor = Hyperlips()
    executor._HyperlipsLoadModels()
    executor._HyperlipsInference()

class Hyperlips():
    """Lip-sync inference pipeline configured from the module-level ``args``.

    Typical use is ``_HyperlipsLoadModels()`` followed by
    ``_HyperlipsInference()``, which reads ``args.face`` and ``args.audio``
    and writes the muxed result next to ``args.outfile``.
    """

    def __init__(self):
        """Copy checkpoint paths and batch size from ``args``."""
        self.checkpoint_path_BASE = args.checkpoint_path_BASE
        self.checkpoint_path_HR = args.checkpoint_path_HR
        self.parser_path = args.segmentation_path
        self.batch_size = args.hyper_batch_size
        # Number of mel frames fed to the model per video frame.
        self.mel_step_size = 16

    def _HyperlipsLoadModels(self):
        """Select the CUDA device and load the restorer, generator and parser.

        Raises:
            ValueError: if CUDA is unavailable, ``args.gpu_id`` is out of
                range, or ``args.modelname`` is not a supported model name.
        """
        # int() guards against a float id, which would give 'cuda:0.0'.
        gpu_id = int(args.gpu_id)
        if not torch.cuda.is_available() or (gpu_id > (torch.cuda.device_count() - 1)):
            raise ValueError(
                f'Existing gpu configuration problem.(gpu.is_available={torch.cuda.is_available()}| gpu.device_count={torch.cuda.device_count()})')
        if args.modelname not in ("HyperLipsBase", "HyperLipsHR"):
            # Fail early instead of leaving self.model undefined.
            raise ValueError(
                "Unknown modelname {!r}; expected 'HyperLipsBase' or 'HyperLipsHR'".format(args.modelname))
        self.device = torch.device(f'cuda:{gpu_id}')
        print('Using {} for inference.'.format(self.device))
        self.restorer = GFPGANInit(self.device, args.face_enhancement_path)
        if args.modelname == "HyperLipsBase":
            self.model = load_HyperLipsBase(self.checkpoint_path_BASE, self.device)
        else:
            self.model = load_HyperLipsHR(self.checkpoint_path_BASE, self.checkpoint_path_HR, self.device)
        self.seg_net = init_parser(self.parser_path, self.device)
        print(' models init successed...')

    def _HyperlipsInference(self):
        """Generate the lip-synced video for ``args.face`` and ``args.audio``.

        Frames are written to a temporary ``result.avi`` inside a scratch
        directory named after the output file (extension stripped); ffmpeg
        then muxes that video with the audio into the final output and the
        scratch directory is removed.

        Raises:
            ValueError: if ``args.face`` is not an existing file or the mel
                spectrogram contains NaN values.
        """
        face = args.face
        audiopath = args.audio
        print("The input video path is {}, The input audio path is {}".format(face, audiopath))

        # Validate the input before creating directories or loading models.
        if not os.path.isfile(face):
            raise ValueError('--face argument must be a valid path to video/image file')

        outfile = os.path.abspath(args.outfile)
        rest_root_path = os.path.dirname(os.path.realpath(outfile))
        temp_save_path = outfile.rsplit('.', 1)[0]
        # makedirs also creates missing parent directories.
        os.makedirs(rest_root_path, exist_ok=True)
        os.makedirs(temp_save_path, exist_ok=True)
        detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
                                                device='cuda:{}'.format(int(args.gpu_id)))

        video_stream = cv2.VideoCapture(face)
        fps = video_stream.get(cv2.CAP_PROP_FPS)
        frame_width = int(video_stream.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(video_stream.get(cv2.CAP_PROP_FRAME_HEIGHT))
        video_stream.release()

        if args.resize_factor > 1:
            # read_frames() shrinks every frame by this factor; the writer has
            # to be opened with the same size as the frames it receives.
            frame_width = int(frame_width // args.resize_factor)
            frame_height = int(frame_height // args.resize_factor)

        result_avi = os.path.join(temp_save_path, 'result.avi')
        out = cv2.VideoWriter(result_avi, cv2.VideoWriter_fourcc(*'DIVX'),
                              fps, (frame_width, frame_height))

        if not audiopath.endswith('.wav'):
            print('Extracting raw audio...')

            temp_wav = os.path.join(temp_save_path, 'temp.wav')
            # Argument list (no shell) so paths with spaces or shell
            # metacharacters are passed through verbatim.
            subprocess.call(['ffmpeg', '-y', '-i', audiopath, '-strict', '-2', temp_wav])
            audiopath = temp_wav
        wav = audio.load_wav(audiopath, 16000)
        mel = audio.melspectrogram(wav)
        if np.isnan(mel.reshape(-1)).sum() > 0:
            raise ValueError(
                'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

        # Cut one mel window per video frame; the last window is aligned to
        # the end of the spectrogram.
        mel_chunks = []
        mel_idx_multiplier = 80. / fps
        i = 0
        while True:
            start_idx = int(i * mel_idx_multiplier)
            if start_idx + self.mel_step_size > len(mel[0]):
                mel_chunks.append(mel[:, len(mel[0]) - self.mel_step_size:])
                break
            mel_chunks.append(mel[:, start_idx: start_idx + self.mel_step_size])
            i += 1
        if args.filter_window is not None:
            mel_chunks = get_smoothened_mels(mel_chunks, T=args.filter_window)
        print("Length of mel chunks: {}".format(len(mel_chunks)))

        gen = datagen(mel_chunks, detector, face, args.resize_factor)
        total_batches = int(np.ceil(float(len(mel_chunks)) / self.batch_size))
        for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, total=total_batches)):

            # Channels-last arrays from datagen -> channels-first tensors.
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device)
            with torch.no_grad():
                pred = self.model(mel_batch, img_batch)
            for p, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                mask_temp = np.zeros_like(f)
                p = p.cpu().numpy().transpose(1, 2, 0) * 255.

                if not args.no_faceenhance:
                    ori_f = f.copy()
                    p = cv2.resize(p, (x2 - x1, y2 - y1)).astype(np.uint8)
                    f[y1:y2, x1:x2] = p
                    Code_img = GFPGANInfer(f, self.restorer, aligned=False)
                    p, mask_out = swap_regions_img(ori_f[y1:y2, x1:x2], Code_img[y1:y2, x1:x2], self.seg_net)
                    p = cv2.resize(p, (x2 - x1, y2 - y1)).astype(np.uint8)
                    # np.float64 replaces the removed np.float alias (same dtype).
                    mask_out = cv2.resize(mask_out.astype(np.float64) * 255.0, (x2 - x1, y2 - y1)).astype(np.uint8)
                    f[y1:y2, x1:x2] = p

                else:
                    p, mask_out = swap_regions_img(f[y1:y2, x1:x2], p, self.seg_net)
                    p = cv2.resize(p, (x2 - x1, y2 - y1)).astype(np.uint8)
                    mask_out = cv2.resize(mask_out.astype(np.float64) * 255.0, (x2 - x1, y2 - y1)).astype(np.uint8)

                # Erode and blur the mask, then alpha-blend the generated
                # region over the frame.
                mask_temp[y1:y2, x1:x2] = mask_out
                kernel = np.ones((5, 5), np.uint8)
                mask_temp = cv2.erode(mask_temp, kernel, iterations=1)
                mask_temp = cv2.GaussianBlur(mask_temp, (75, 75), 0, 0, cv2.BORDER_DEFAULT)
                f_background = f.copy()
                f[y1:y2, x1:x2] = p
                f = f_background * (1 - mask_temp / 255.0) + f * (mask_temp / 255.0)
                f = f.astype(np.uint8)
                out.write(f)

        out.release()
        outfile_dfl = os.path.join(rest_root_path, os.path.basename(outfile))
        subprocess.call(['ffmpeg', '-y', '-i', audiopath, '-i', result_avi,
                         '-strict', '-2', '-q:v', '1', outfile_dfl])
        if os.path.exists(temp_save_path):
            shutil.rmtree(temp_save_path)

        torch.cuda.empty_cache()

# Run the inference pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()


================================================
FILE: README.md
================================================
# HyperLips: Hyper Control Lips with High Resolution Decoder for Talking Face Generation
Official PyTorch implementation of our paper "HyperLips: Hyper Control Lips with High Resolution Decoder for Talking Face Generation".

<img src='./hyperlips_net.png' width=900>

[[Paper]](https://arxiv.org/abs/2310.05720) [[Demo Video]](https://www.youtube.com/watch?v=j4GdJoTF0wY)

## Requirements
- Python 3.8.16
- torch 1.10.1+cu113
- torchvision 0.11.2+cu113
- ffmpeg

We recommend installing [PyTorch](https://pytorch.org/) first, and then installing the remaining dependencies by running
```
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple
```
You can also use `environment.yml` to create the environment by running
```
conda env create -f environment.yml
```
## Getting the weights
Download the pre-trained models from [BaiduYun](https://pan.baidu.com/s/1wy986BiROq5bkXweHxSvVA?pwd=6666) and place them in the `checkpoints` folder.

## Inference
We provide a model pre-trained on the LRS2 dataset. You can quickly try it by running:
```
python inference.py --checkpoint_path_BASE=checkpoints/hyperlipsbase_lrs2.pth 
```
The result is saved (by default) in `results/result_video.mp4`. To run inference on other videos, specify the `--face` and `--audio` options; see the code for more details.

## Train
### 1. Download the MEAD dataset
Our models are trained on MEAD. Please go to the [MEAD](https://www.robots.ox.ac.uk/~vgg/data/lip_reading/lrs2.html) website to download the dataset. We select videos with neutral emotion and a frontal view as the MEAD-Neutral dataset and resample all split videos to 25 fps using [software](http://www.pcfreetime.com/formatfactory/cn/index.html). All the resampled videos are put into `datasets/MEAD/`.
The folder structure of the MEAD-Neutral dataset is as follows.
```
data_root (datasets)
├── name of dataset(MEAD)
|	├── videos ending with(.mp4)
```

### 2.Preprocess for hyperlips_base
Extract the face images and raw audio from the video files, and generate the file lists `train.txt` and `val.txt`, by running:
```
python preprocess.py --origin_data_root=datasets/MEAD --clip_flag=0 --Function=base --hyperlips_train_dataset=Train_data
```
### 3.Train lipsync expert
train the lipsync expert by running:
```
python color_syncnet_trainv3.py --data_root=Train_data/imgs  --checkpoint_dir=checkpoints_lipsync_expert
```
You can use the pre-trained weights saved at `checkpoints/pretrain_sync_expert.pth`  if you want to skip this step.

### 4.Train hyperlips base
train the hyperlips base by running:
```
python Train_hyperlipsBase.py --data_root=Train_data/imgs  --checkpoint_dir=checkpoints_hyperlips_base --syncnet_checkpoint_path=checkpoints/pretrain_sync_expert.pth
```
### 5.Generate hyperlips base videos
generate hyperlips base videos by running:
```
python Gen_hyperlipsbase_videos.py --checkpoint_path_BASE=checkpoints_hyperlips_base/xxxxxxxxx.pth --video=datasets --outfile=hyperlips_base_results
```
### 6.preprocess for hyperlips_HR
Extract images, sketches and lip masks from the original videos, and extract images and sketches from the videos generated by HyperLips base, by running:
```
python preprocess.py --origin_data_root=datasets/MEAD --Function=HR --hyperlips_train_dataset=Train_data --hyperlipsbase_video_root=hyperlips_base_results 
```
### 7.Train hyperlips HR
train hyperlips HR by running:
```
python Train_hyperlipsHR.py -hyperlips_trian_dataset=Train_data/HR_Train_Dateset --checkpoint_dir=checkpoints_hyperlips_HR --batch_size=28 --img_size=128
```
You can also train HR_256 and HR_512 by changing `--img_size`. More details can be found in the code.


## Acknowledgement
This project is built upon the publicly available code [Wav2Lip](https://github.com/Rudrabha/Wav2Lip/tree/master) and [IP_LAP](https://github.com/Weizhi-Zhong/IP_LAP). Thank the authors of these works for making their excellent work and codes publicly available.


## Citation and Star
Please cite the following paper and star this project if you use this repository in your research. Thank you!
```
@InProceedings{
    author    = {Yaosen Chen, Yu Yao, Zhiqiang Li, Wei Wang, Yanru Zhang, Han Yang, Xuming Wen},
    title     = {HyperLips: Hyper Control Lips with High Resolution Decoder for Talking Face Generation},
    year      = {2023},

}
```


================================================
FILE: Train_data/video_clips/MEAD/readme.txt
================================================
Put video here.

================================================
FILE: Train_hyperlipsBase.py
================================================
from os.path import dirname, join, basename, isfile
from tqdm import tqdm
from models import SyncNet_color as SyncNet
from models.model_hyperlips import HyperLipsBase, HyperCtrolDiscriminator
import audio

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np
from glob import glob
import os, random, cv2, argparse
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
from hparams_Base import hparams, get_image_list

# Command-line options for training the base generator together with its
# visual-quality discriminator. Parsed at import time into ``args``.
parser = argparse.ArgumentParser(description='Code to train the Hyperbase model WITH the visual quality discriminator')
parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", default='Train_data/imgs')
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', default="checkpoints_hyperlips_base", type=str)
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained audio-visual sync module', default="checkpoints/pretrain_sync_expert.pth", type=str)
parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str)
parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str)
args = parser.parse_args()


# Training progress counters; updated by train() and, when resuming, by
# load_checkpoint().
global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

# Number of consecutive frames in one training window.
syncnet_T = 5
# Number of mel-spectrogram rows cropped for one audio window.
syncnet_mel_step_size = 16

class Dataset(object):
    """Randomly sampled frame windows with their matching mel windows.

    Every entry of ``all_videos`` is a directory holding numbered ``.jpg``
    frames and an ``audio.wav`` file. ``__getitem__`` ignores the requested
    index and keeps drawing random videos/frames until a usable sample is
    found.
    """

    def __init__(self, split):
        """Collect the video directories of ``split`` under ``args.data_root``."""
        self.all_videos = get_image_list(args.data_root, split)

    def get_frame_id(self, frame):
        """Return the integer in front of the first dot of the file name."""
        stem = basename(frame).split('.')[0]
        return int(stem)

    def get_window(self, start_frame):
        """Return the paths of ``syncnet_T`` consecutive frames, or ``None``.

        The window starts at ``start_frame``; ``None`` is returned as soon as
        one of the expected files does not exist.
        """
        first_id = self.get_frame_id(start_frame)
        folder = dirname(start_frame)

        paths = []
        for offset in range(syncnet_T):
            candidate = join(folder, '{}.jpg'.format(first_id + offset))
            if not isfile(candidate):
                return None
            paths.append(candidate)
        return paths

    def read_window(self, window_fnames):
        """Load and resize every frame of a window.

        Returns ``None`` when ``window_fnames`` is ``None`` or when any frame
        cannot be read or resized to ``hparams.img_size``.
        """
        if window_fnames is None:
            return None

        frames = []
        for path in window_fnames:
            frame = cv2.imread(path)
            if frame is None:
                return None
            try:
                frame = cv2.resize(frame, (hparams.img_size, hparams.img_size))
            except Exception:
                return None
            frames.append(frame)

        return frames

    def crop_audio_window(self, spec, start_frame):
        """Slice ``syncnet_mel_step_size`` rows of ``spec`` for a frame.

        ``start_frame`` is either a frame number (``int``) or a frame path.
        The first row is ``int(80 * frame_number / hparams.fps)``.
        """
        frame_num = start_frame if type(start_frame) == int else self.get_frame_id(start_frame)
        begin = int(80. * (frame_num / float(hparams.fps)))

        return spec[begin: begin + syncnet_mel_step_size, :]

    def get_segmented_mels(self, spec, start_frame):
        """Return one transposed mel window per frame of the window, stacked.

        Returns ``None`` when the window would start before the spectrogram
        or when any crop is shorter than ``syncnet_mel_step_size``.
        """
        assert syncnet_T == 5
        first = self.get_frame_id(start_frame) + 1  # 0-indexing ---> 1-indexing
        if first - 2 < 0:
            return None

        chunks = []
        for frame_num in range(first, first + syncnet_T):
            chunk = self.crop_audio_window(spec, frame_num - 2)
            if chunk.shape[0] != syncnet_mel_step_size:
                return None
            chunks.append(chunk.T)

        return np.asarray(chunks)

    def prepare_window(self, window):
        """Scale frames to [0, 1] and reorder axes from T,H,W,C to C,T,H,W."""
        scaled = np.asarray(window) / 255.
        return np.transpose(scaled, (3, 0, 1, 2))

    def __len__(self):
        """Number of video directories."""
        return len(self.all_videos)

    def __getitem__(self, idx):
        """Return a random sample ``(x, indiv_mels, mel, y)`` as float tensors.

        ``x`` stacks the target window with its lower half zeroed and a
        different ("wrong") window of the same video along the channel axis;
        ``y`` is the unmasked target window. ``idx`` is overwritten with a
        random index, so the argument has no influence on the result.
        """
        while True:
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]
            img_names = list(glob(join(vidname, '*.jpg')))
            if len(img_names) <= 3 * syncnet_T:
                continue

            # Pick a target frame and a different reference frame.
            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)

            target_paths = self.get_window(img_name)
            wrong_paths = self.get_window(wrong_img_name)
            if target_paths is None or wrong_paths is None:
                continue

            target_frames = self.read_window(target_paths)
            if target_frames is None:
                continue

            wrong_frames = self.read_window(wrong_paths)
            if wrong_frames is None:
                continue

            try:
                wavpath = join(vidname, "audio.wav")
                wav = audio.load_wav(wavpath, hparams.sample_rate)

                orig_mel = audio.melspectrogram(wav).T
            except Exception:
                # Unreadable audio: try another sample.
                continue

            mel = self.crop_audio_window(orig_mel.copy(), img_name)
            if mel.shape[0] != syncnet_mel_step_size:
                continue

            indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
            if indiv_mels is None:
                continue

            masked = self.prepare_window(target_frames)
            y = masked.copy()
            # Zero the lower half (height axis) of the target window.
            masked[:, :, masked.shape[2]//2:] = 0.

            reference = self.prepare_window(wrong_frames)
            x = np.concatenate([masked, reference], axis=0)

            x = torch.FloatTensor(x)
            mel = torch.FloatTensor(mel.T).unsqueeze(0)
            indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
            y = torch.FloatTensor(y)
            return x, indiv_mels, mel, y

def save_sample_images(x, g, gt, global_step, checkpoint_dir):
    """Write side-by-side sample frames for visual inspection.

    Each tensor is moved to the CPU, reordered from (B, C, T, H, W) to
    (B, T, H, W, C), scaled by 255 and cast to ``uint8``. For every batch
    item and time step, the reference channels (``x[..., 3:]``), the input
    channels (``x[..., :3]``), the generated frame and the ground truth are
    concatenated along the width axis and saved as
    ``<checkpoint_dir>/samples_step<global_step>/<batch>_<t>.jpg``.
    """
    def _to_uint8(batch):
        # (B, C, T, H, W) float tensor -> (B, T, H, W, C) uint8 array.
        return (batch.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)

    x = _to_uint8(x)
    g = _to_uint8(g)
    gt = _to_uint8(gt)

    refs, inps = x[..., 3:], x[..., :3]
    folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
    if not os.path.exists(folder):
        os.mkdir(folder)
    collage = np.concatenate((refs, inps, g, gt), axis=-2)
    for batch_idx, clip in enumerate(collage):
        for t, frame in enumerate(clip):
            cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), frame)

logloss = nn.BCELoss()


def cosine_loss(a, v, y):
    """Binary cross-entropy between the cosine similarity of ``a``/``v`` and ``y``.

    The similarity is computed per row and reshaped to a column so that it
    lines up with the ``(batch, 1)`` target ``y``.
    """
    similarity = nn.functional.cosine_similarity(a, v)
    return logloss(similarity.unsqueeze(1), y)

# Module-level sync network used by get_sync_loss(); its weights are loaded
# later in the __main__ section.
device = torch.device("cuda" if use_cuda else "cpu")
syncnet = SyncNet().to(device)
# Freeze the sync network: no gradients are computed for its parameters.
for p in syncnet.parameters():
    p.requires_grad = False

# L1 reconstruction loss between generated and ground-truth frames.
recon_loss = nn.L1Loss()
def get_sync_loss(mel, g):
    """Score generated frames against ``mel`` with the frozen sync network.

    The lower half (height axis) of ``g`` is kept, its ``syncnet_T`` time
    steps are stacked along the channel axis, and the result is resized to
    64x128 before being passed to ``syncnet`` together with ``mel``. The
    loss is ``cosine_loss`` against an all-ones target.
    """
    lower_half = g[:, :, :, g.size(3)//2:]
    # B, 3 * T, H//2, W
    stacked = torch.cat([lower_half[:, :, t] for t in range(syncnet_T)], dim=1)
    stacked = torch.nn.functional.interpolate(stacked, (64, 128), mode='bilinear', align_corners=False)

    a, v = syncnet(mel, stacked)
    target = torch.ones(stacked.size(0), 1).float().to(device)
    return cosine_loss(a, v, target)

def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    """Alternate generator and discriminator updates until ``nepochs`` is reached.

    Per step the generator is optimised on a weighted sum of sync loss,
    perceptual (discriminator) loss and L1 reconstruction loss, using the
    weights in ``hparams``; the discriminator is then optimised on real
    (``gt``) and detached generated frames. Sample images and checkpoints are
    written every ``checkpoint_interval`` steps, and the model is evaluated
    every ``hparams.eval_interval`` steps.

    Args:
        device: device the batches are moved to.
        model: generator, called as ``model(indiv_mels, x)``.
        disc: discriminator; may be wrapped in ``nn.DataParallel``.
        train_data_loader: yields ``(x, indiv_mels, mel, gt)`` batches.
        test_data_loader: loader handed to ``eval_model``.
        optimizer: optimizer of the generator.
        disc_optimizer: optimizer of the discriminator.
        checkpoint_dir: directory for checkpoints and sample images.
        checkpoint_interval: number of steps between checkpoints.
        nepochs: value of ``global_epoch`` at which training stops.
    """
    global global_step, global_epoch

    # nn.DataParallel does not expose custom methods of the wrapped module,
    # so perceptual_forward has to be reached through ``.module``.
    disc_core = disc.module if isinstance(disc, nn.DataParallel) else disc

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss, running_perceptual_loss = 0., 0., 0.
        running_disc_real_loss, running_disc_fake_loss = 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            disc.train()
            model.train()

            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            ### Train generator now. Remove ALL grads.
            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            g = model(indiv_mels, x)

            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            if hparams.disc_wt > 0.:
                perceptual_loss = disc_core.perceptual_forward(g)
            else:
                perceptual_loss = 0.

            l1loss = recon_loss(g, gt)

            loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
                                    (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss

            loss.backward()
            optimizer.step()

            ### Remove all gradients before Training disc
            disc_optimizer.zero_grad()

            pred = disc(gt)
            disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
            disc_real_loss.backward()

            # detach(): the discriminator update must not reach the generator.
            pred = disc(g.detach())
            disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
            disc_fake_loss.backward()

            disc_optimizer.step()

            running_disc_real_loss += disc_real_loss.item()
            running_disc_fake_loss += disc_fake_loss.item()

            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g, gt, global_step, checkpoint_dir)

            # Logs
            global_step += 1

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.:
                running_sync_loss += sync_loss.item()

            if hparams.disc_wt > 0.:
                running_perceptual_loss += perceptual_loss.item()

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(
                    model, optimizer, global_step, checkpoint_dir, global_epoch)
                save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_')

            if global_step % hparams.eval_interval == 0:
                with torch.no_grad():
                    average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)

                    # Switch the sync-loss weight to 0.03 once the evaluated
                    # sync loss drops below 0.75.
                    if average_sync_loss < .75:
                        hparams.set_hparam('syncnet_wt', 0.03)

            prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1),
                                                                                        running_sync_loss / (step + 1),
                                                                                        running_perceptual_loss / (step + 1),
                                                                                        running_disc_fake_loss / (step + 1),
                                                                                        running_disc_real_loss / (step + 1)))

        global_epoch += 1

def eval_model(test_data_loader, global_step, device, model, disc):
    """Evaluate generator and discriminator and return the mean sync loss.

    Iterates over ``test_data_loader`` until it is exhausted or the step
    index exceeds ``eval_steps``, prints the mean L1, sync, perceptual and
    discriminator losses, and returns the mean sync loss. Both networks are
    left in eval mode.

    Args:
        test_data_loader: yields ``(x, indiv_mels, mel, gt)`` batches.
        global_step: current training step (not used by the evaluation).
        device: device the batches are moved to.
        model: generator, called as ``model(indiv_mels, x)``.
        disc: discriminator; may be wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: if ``test_data_loader`` yields no batches.
    """
    eval_steps = 300
    print('Evaluating for {} steps'.format(eval_steps))
    running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []

    # nn.DataParallel does not expose custom methods of the wrapped module,
    # so perceptual_forward has to be reached through ``.module``.
    disc_core = disc.module if isinstance(disc, nn.DataParallel) else disc

    model.eval()
    disc.eval()
    for step, (x, indiv_mels, mel, gt) in enumerate(test_data_loader):
        x = x.to(device)
        mel = mel.to(device)
        indiv_mels = indiv_mels.to(device)
        gt = gt.to(device)

        pred = disc(gt)
        disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))

        g = model(indiv_mels, x)
        pred = disc(g)
        disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))

        running_disc_real_loss.append(disc_real_loss.item())
        running_disc_fake_loss.append(disc_fake_loss.item())

        sync_loss = get_sync_loss(mel, g)

        if hparams.disc_wt > 0.:
            running_perceptual_loss.append(disc_core.perceptual_forward(g).item())
        else:
            running_perceptual_loss.append(0.)

        running_l1_loss.append(recon_loss(g, gt).item())
        running_sync_loss.append(sync_loss.item())

        # NOTE(review): the check runs after the batch is processed and uses
        # '>', so up to eval_steps + 2 batches are evaluated. Kept as is.
        if step > eval_steps: break

    if not running_sync_loss:
        raise ValueError('eval_model received an empty test_data_loader')

    print('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(sum(running_l1_loss) / len(running_l1_loss),
                                                        sum(running_sync_loss) / len(running_sync_loss),
                                                        sum(running_perceptual_loss) / len(running_perceptual_loss),
                                                        sum(running_disc_fake_loss) / len(running_disc_fake_loss),
                                                         sum(running_disc_real_loss) / len(running_disc_real_loss)))
    return sum(running_sync_loss) / len(running_sync_loss)


def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''):
    """Save model (and optionally optimizer) state for the given step.

    The file is written to
    ``<checkpoint_dir>/<prefix>checkpoint_step<step, 9 digits>.pth`` and
    contains the keys ``state_dict``, ``optimizer``, ``global_step`` and
    ``global_epoch``. The optimizer state is stored only when
    ``hparams.save_optimizer_state`` is true; otherwise ``None`` is saved.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose state may be saved.
        step: training step recorded in the file and in its name.
        checkpoint_dir: existing directory to write into.
        epoch: epoch recorded in the file.
        prefix: optional file-name prefix (e.g. ``'disc_'``).
    """
    # Use the ``step`` argument (not the module-level counter) so the file
    # name always matches the "global_step" stored inside the checkpoint.
    checkpoint_path = join(
        checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, step))
    optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer_state,
        "global_step": step,
        "global_epoch": epoch,
    }, checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)

def _load(checkpoint_path):
    """Load a checkpoint; without CUDA, keep every storage on the CPU."""
    if not use_cuda:
        return torch.load(checkpoint_path,
                          map_location=lambda storage, loc: storage)
    return torch.load(checkpoint_path)


def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
    """Load model weights from ``path`` and optionally restore the counters.

    Only ``checkpoint["state_dict"]`` is applied to ``model``. When
    ``overwrite_global_states`` is true, the module-level ``global_step`` and
    ``global_epoch`` are replaced by the stored values. ``optimizer`` and
    ``reset_optimizer`` are accepted but not used: optimizer state is not
    restored.

    Returns:
        The same ``model`` object.
    """
    global global_step, global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    model.load_state_dict(checkpoint["state_dict"])

    if not overwrite_global_states:
        return model

    global_step = checkpoint["global_step"]
    global_epoch = checkpoint["global_epoch"]
    return model




# Training entry point: build the data loaders, models and optimizers,
# optionally resume from checkpoints, load the sync network, then train.
if __name__ == "__main__":
    checkpoint_dir = args.checkpoint_dir

    # Dataset and Dataloader setup
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    # NOTE: Dataset.__getitem__ draws a random sample regardless of the index
    # it receives, so the loader's ordering does not select specific samples.
    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.batch_size,
        num_workers=4)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = HyperLipsBase()
    # Wrap in DataParallel when more than one CUDA device is visible.
    if torch.cuda.device_count() > 1:
                print("Let's use", torch.cuda.device_count(), "GPUs!")
                model = nn.DataParallel(model)
    model = model.to(device)

    disc = HyperCtrolDiscriminator()
    if torch.cuda.device_count() > 1:
                print("Let's use", torch.cuda.device_count(), "GPUs!")
                disc = nn.DataParallel(disc)
    disc = disc.to(device)


    print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))

    # Separate Adam optimizers for generator and discriminator.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
    disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
                           lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))

    # Resuming the generator also restores global_step / global_epoch.
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)

    # The discriminator checkpoint must not overwrite the global counters.
    if args.disc_checkpoint_path is not None:
        load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer, 
                                reset_optimizer=False, overwrite_global_states=False)
        
    # Load the weights of the frozen module-level sync network.
    load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, 
                                overwrite_global_states=False)

    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # Train!
    train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
              checkpoint_dir=checkpoint_dir,
              checkpoint_interval=hparams.checkpoint_interval,
              nepochs=hparams.nepochs)

================================================
FILE: Train_hyperlipsHR.py
================================================
from os.path import dirname, join, basename, isfile
from tqdm import tqdm

from models import SyncNet_color as SyncNet
from models.model_hyperlips import HRDecoder,HRDecoder_disc_qual
import audio
import lpips
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np
from torchvision.models.vgg import vgg19
from glob import glob
mseloss = nn.MSELoss()
import os, random, cv2, argparse
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from hparams_HR import hparams, get_image_list

# Command-line interface for the HR-decoder training script.
parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model WITH the visual quality discriminator')

# NOTE: this option is registered with a single leading dash, so it must be
# passed as `-hyperlips_trian_dataset`; it is read as args.hyperlips_trian_dataset.
parser.add_argument("-hyperlips_trian_dataset", help="Root folder of the preprocessed LRS2 dataset", default='Train_data/HR_Train_Dateset')
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', default="checkpoints_hyperlips_HR", type=str)
parser.add_argument('--batch_size', type=int, help='Batch size for hyperlips model(s)', default=28)
# The sketch readers in Dataset only define blur kernels for 128, 256 and 512.
parser.add_argument('--img_size', type=int, help='imgsize for hyperlips model(s)', default=128)
parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str)
parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str)

# Parsed at import time, so importing this module consumes sys.argv.
args = parser.parse_args()


# Training progress counters; updated by train() and load_checkpoint().
global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

# Number of consecutive frames per sample window (see Dataset.get_window).
syncnet_T = 5
syncnet_mel_step_size = 16




class Dataset(object):
    """Randomly sampled frame windows for training the HR decoder.

    The dataset root (``args.hyperlips_trian_dataset``) is expected to contain
    five parallel trees -- ``GT_IMG``, ``GT_SKETCH``, ``GT_MASK``, ``HYPER_IMG``
    and ``HYPER_SKETCH`` -- each holding ``<video>/<frame_id>.jpg`` files.
    Every sample is a window of ``syncnet_T`` consecutive frames taken from the
    same position in all five trees.
    """

    # Gaussian-blur kernel size used before binarising a sketch, per image size.
    _SKETCH_KERNEL_SIZES = {128: 5, 256: 7, 512: 11}
    # Fixed resolution at which the HYPER_* inputs are loaded.
    _BASE_SIZE = 128

    def __init__(self, split):
        """Load the list of videos belonging to ``split`` from the GT_IMG tree."""
        gt_img_root = os.path.join(args.hyperlips_trian_dataset, 'GT_IMG')
        self.gt_img = get_image_list(gt_img_root, split)

    def get_frame_id(self, frame):
        """Return the integer frame index encoded in a ``<id>.jpg`` file name."""
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        """Return ``syncnet_T`` consecutive frame paths starting at ``start_frame``.

        Returns ``None`` if any frame of the window is missing on disk.
        """
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.jpg'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def _read_resized(self, window_fnames, size):
        """Read every frame and resize it to ``size`` x ``size``; ``None`` on any failure."""
        if window_fnames is None:
            return None
        window = []
        for fname in window_fnames:
            img = cv2.imread(fname)
            if img is None:
                return None
            try:
                img = cv2.resize(img, (size, size))
            except Exception:
                return None
            window.append(img)
        return window

    def _read_sketch(self, window_fnames, size):
        """Read sketch frames, then resize, blur and binarise them.

        Returns ``None`` if a frame cannot be read or processed.

        Raises:
            ValueError: if ``size`` has no configured blur kernel.  Previously
                this case was swallowed and surfaced later as an unrelated
                ``TypeError``.
        """
        if window_fnames is None:
            return None
        kernel_size = self._SKETCH_KERNEL_SIZES.get(size)
        if kernel_size is None:
            raise ValueError(
                'Unsupported img_size {}; expected one of {}'.format(
                    size, sorted(self._SKETCH_KERNEL_SIZES)))
        window = []
        for fname in window_fnames:
            img = cv2.imread(fname)
            if img is None:
                return None
            try:
                img = cv2.resize(img, (size, size))
                img = cv2.GaussianBlur(img, (kernel_size, kernel_size), 0,0,cv2.BORDER_DEFAULT)
                # Every non-zero pixel becomes 255.
                _, img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY)
            except Exception:
                return None
            window.append(img)
        return window

    def read_window(self, window_fnames):
        """Read frames resized to ``args.img_size``."""
        return self._read_resized(window_fnames, args.img_size)

    def read_window_base(self, window_fnames):
        """Read frames resized to the fixed 128-pixel base resolution."""
        return self._read_resized(window_fnames, self._BASE_SIZE)

    def read_window_sketch(self, window_fnames):
        """Read binarised sketches at ``args.img_size``."""
        return self._read_sketch(window_fnames, args.img_size)

    def read_window_sketch_base(self, window_fnames):
        """Read binarised sketches at the fixed 128-pixel base resolution."""
        return self._read_sketch(window_fnames, self._BASE_SIZE)

    def read_coord(self, window_fnames):
        """Return the bounding box of the 255-valued mask pixels of each frame.

        Each entry is ``[x_min, x_max, y_min, y_max]`` where ``x`` indexes the
        first (row) axis and ``y`` the second (column) axis of the resized
        mask.  Returns ``None`` if a frame is unreadable or its first channel
        contains no pixel equal to 255.
        """
        if window_fnames is None:
            return None

        coords = []
        for fname in window_fnames:
            img = cv2.imread(fname)
            if img is None:
                return None
            try:
                img = cv2.resize(img, (args.img_size, args.img_size))
            except Exception:
                return None
            index = np.argwhere(img[:, :, 0] == 255)
            if index.size == 0:
                # Empty mask: there is no region to crop, let the caller skip it.
                return None
            x_min, y_min = index.min(axis=0)
            x_max, y_max = index.max(axis=0)
            coords.append([x_min, x_max, y_min, y_max])
        return coords

    def prepare_window(self, window):
        """Stack frames, scale to [0, 1] and move channels first: 3 x T x H x W."""
        x = np.asarray(window) / 255.
        x = np.transpose(x, (3, 0, 1, 2))

        return x

    def __len__(self):
        return len(self.gt_img)

    def __getitem__(self, idx):
        """Return one randomly drawn sample; the incoming ``idx`` is ignored.

        Candidates are re-drawn until every one of the five windows exists and
        loads successfully.

        Returns:
            ``(gt_img, gt_sketch, gt_mask, hyper_img, hyper_sketch, coords)``
            as float tensors.
        """
        dataset_root = args.hyperlips_trian_dataset
        roots = [os.path.join(dataset_root, name)
                 for name in ('GT_IMG', 'GT_SKETCH', 'GT_MASK', 'HYPER_IMG', 'HYPER_SKETCH')]

        while True:
            idx = random.randint(0, len(self.gt_img) - 1)
            parts = self.gt_img[idx].split('/')
            vidname = os.path.join(parts[-2], parts[-1])

            frame_lists = [glob(join(root, vidname, '*.jpg')) for root in roots]
            gt_img_names = frame_lists[0]
            # All five trees must hold the same number of frames for this video.
            if any(len(names) != len(gt_img_names) for names in frame_lists):
                continue
            if len(gt_img_names) <= 3 * syncnet_T:
                continue

            img_name = basename(random.choice(gt_img_names))
            windows = [self.get_window(join(root, vidname, img_name)) for root in roots]
            # Skip the candidate if the window is incomplete in ANY tree, not
            # only in GT_IMG; otherwise prepare_window() would receive None.
            if any(w is None for w in windows):
                continue
            (gt_img_fnames, gt_sketch_fnames, gt_mask_fnames,
             hyper_img_fnames, hyper_sketch_fnames) = windows

            coords = self.read_coord(gt_mask_fnames)
            if coords is None:
                continue

            loaded = (
                self.read_window(gt_img_fnames),
                self.read_window_sketch(gt_sketch_fnames),
                self.read_window(gt_mask_fnames),
                self.read_window_base(hyper_img_fnames),
                self.read_window_sketch_base(hyper_sketch_fnames),
            )
            if any(window is None for window in loaded):
                continue

            gt_img, gt_sketch, gt_mask, hyper_img, hyper_sketch = (
                torch.FloatTensor(self.prepare_window(window)) for window in loaded)
            coords = torch.FloatTensor(coords)
            return gt_img, gt_sketch, gt_mask, hyper_img, hyper_sketch, coords



def save_sample_images(x, g, gt,m, global_step, checkpoint_dir):
    """Write side-by-side sample collages for the current training step.

    Each of ``x``, ``g``, ``gt`` and ``m`` is a 5-D tensor whose axis 1 is
    moved to the end before saving (assumed to be B x C x T x H x W with values
    in [0, 1] -- TODO confirm against callers).  The four panels are joined
    along the width axis and one JPEG per (batch item, time step) is written
    to ``<checkpoint_dir>/samples_step<global_step>/``.
    """
    def _as_uint8_frames(tensor):
        array = tensor.detach().cpu().numpy()
        return (array.transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)

    panels = tuple(_as_uint8_frames(tensor) for tensor in (x, g, gt, m))

    folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
    if not os.path.exists(folder):
        os.mkdir(folder)

    # After the transpose, axis -2 is the image width.
    collage = np.concatenate(panels, axis=-2)
    for batch_idx, clip in enumerate(collage):
        for t, frame in enumerate(clip):
            cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), frame)

class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG-19 feature maps of two images."""

    def __init__(self):
        super().__init__()

        backbone = vgg19(pretrained=True)
        # Keep only the first 35 layers of the feature extractor, frozen and
        # in eval mode.
        feature_layers = list(backbone.features.children())[:35]
        extractor = nn.Sequential(*feature_layers).eval()
        for weight in extractor.parameters():
            weight.requires_grad = False

        self.loss_network = extractor
        self.l1_loss = nn.L1Loss()

    def forward(self, high_resolution, fake_high_resolution):
        """Return the L1 loss between the VGG features of both inputs."""
        real_features = self.loss_network(high_resolution)
        fake_features = self.loss_network(fake_high_resolution)
        return self.l1_loss(real_features, fake_features)

logloss = nn.BCELoss()


def cosine_loss(a, v, y):
    """Binary cross-entropy between cosine similarity of ``a``/``v`` and ``y``.

    The per-row cosine similarity is reshaped to a column so it lines up with
    a target ``y`` of shape (N, 1).
    """
    similarity = F.cosine_similarity(a, v)
    return logloss(similarity[:, None], y)

device = torch.device("cuda" if use_cuda else "cpu")

# LPIPS (VGG backbone) metric used for the lip-region perceptual loss.  It is
# placed on `device` rather than via an unconditional .cuda() so that importing
# this module does not fail on CPU-only machines.
loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
# Plain L1 reconstruction loss used on the lip crops and during evaluation.
recon_loss = nn.L1Loss()


def train(device, model, disc,train_data_loader, test_data_loader, optimizer,disc_optimizer, checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    """Adversarially train the HR decoder ``model`` against ``disc``.

    Every iteration performs one generator update followed by one
    discriminator update.  The generator loss is the sum of an adversarial
    term, a VGG perceptual term, a full-frame L1 term and LPIPS + L1 terms
    computed on the mask bounding box of every frame.  The adversarial terms
    compare each score against the mean score of the opposite set.

    Args:
        device: torch device the batches are moved to.
        model: generator; called with the base image and sketch concatenated
            along the channel axis.
        disc: discriminator returning one score per image.
        train_data_loader: yields ``(gt_img, gt_sketch, gt_mask, hyper_img,
            hyper_sketch, coords)`` batches.
        test_data_loader: accepted for interface compatibility; unused here.
        optimizer, disc_optimizer: optimizers of ``model`` and ``disc``.
        checkpoint_dir: where samples and checkpoints are written.
        checkpoint_interval: number of steps between samples/checkpoints.
        nepochs: training stops once ``global_epoch`` reaches this value.
    """
    global global_step, global_epoch

    adversarial_criterion = nn.BCEWithLogitsLoss().to(device)
    content_criterion = nn.L1Loss().to(device)
    perception_criterion = PerceptualLoss().to(device)

    def _fold_time(tensor):
        # (B, C, T, H, W) -> (T*B, C, H, W): time steps are concatenated along
        # the batch axis, e.g. [2, 6, 5, 512, 512] -> [10, 6, 512, 512].
        return torch.cat([tensor[:, :, i] for i in range(tensor.size(2))], dim=0)

    def _unfold_time(tensor, batch_size):
        # Inverse of _fold_time: (T*B, C, H, W) -> (B, C, T, H, W).
        return torch.stack(torch.split(tensor, batch_size, dim=0), dim=2)

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_adv_loss, running_perc_loss, running_content_loss = 0., 0., 0.
        running_lip_recon_loss, running_lip_lpips_loss, running_disc_loss = 0., 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (gt_img, gt_sketch, gt_mask,hyper_img,hyper_sketch,coords) in prog_bar:
            disc.train()
            model.train()
            hyper_img = hyper_img.to(device)
            hyper_sketch = hyper_sketch.to(device)
            gt_mask = gt_mask.to(device)
            gt_sketch = gt_sketch.to(device)
            gt_img = gt_img.to(device)
            B = hyper_img.size(0)

            input_dim_size = len(hyper_img.size())
            if input_dim_size > 4:
                hyper_img = _fold_time(hyper_img)
                hyper_sketch = _fold_time(hyper_sketch)
                gt_mask = _fold_time(gt_mask)
                gt_sketch = _fold_time(gt_sketch)
                gt_img = _fold_time(gt_img)
                # coords is (B, T, 4); fold it in the same time-major order.
                coords_t = torch.cat([coords[:, i] for i in range(coords.size(1))], dim=0)
            else:
                # Already one row of coordinates per image (previously
                # coords_t was left undefined on this path).
                # NOTE(review): assumes coords is (N, 4) here -- confirm.
                coords_t = coords
            real_labels = torch.ones((gt_img.size()[0], 1)).to(device)
            fake_labels = torch.zeros((gt_img.size()[0], 1)).to(device)

            ######################
            # training generator #
            ######################
            input_temp = torch.cat((hyper_img,hyper_sketch), dim=1)
            optimizer.zero_grad()
            g = model(input_temp)

            # Extra supervision restricted to the mask bounding box of each frame.
            lip_lpips_loss = 0
            lip_recons_loss_temp = 0
            for i in range(gt_img.shape[0]):
                x_min,x_max,y_min,y_max = int(coords_t[i,0]),int(coords_t[i,1]),int(coords_t[i,2]),int(coords_t[i,3])
                gt_t_i = gt_img[i,:,x_min:x_max,y_min:y_max]
                g_t_i = g[i,:,x_min:x_max,y_min:y_max]
                lip_recons_loss_temp = lip_recons_loss_temp + recon_loss(g_t_i, gt_t_i)
                lip_lpips_loss = lip_lpips_loss + loss_fn_vgg(g_t_i, gt_t_i)
            lip_lpips_loss = lip_lpips_loss/gt_img.shape[0]
            lip_recons_loss_temp = lip_recons_loss_temp/gt_img.shape[0]

            score_real = disc(gt_img)
            score_fake = disc(g)
            discriminator_rf = score_real - score_fake.mean()
            discriminator_fr = score_fake - score_real.mean()

            # Generator objective: labels are swapped w.r.t. the discriminator's.
            adversarial_loss_rf = adversarial_criterion(discriminator_rf, fake_labels)
            adversarial_loss_fr = adversarial_criterion(discriminator_fr, real_labels)
            adversarial_loss = (adversarial_loss_fr + adversarial_loss_rf) / 2

            perceptual_loss = perception_criterion(gt_img, g)
            content_loss = content_criterion(g, gt_img)

            loss = adversarial_loss  + perceptual_loss  + content_loss +lip_lpips_loss+lip_recons_loss_temp

            loss.backward()
            optimizer.step()

            ##########################
            # training discriminator #
            ##########################
            # zero_grad also discards the gradients the generator step left
            # on the discriminator parameters.
            disc_optimizer.zero_grad()

            score_real = disc(gt_img)
            score_fake = disc(g.detach())
            discriminator_rf = score_real - score_fake.mean()
            discriminator_fr = score_fake - score_real.mean()

            adversarial_loss_rf = adversarial_criterion(discriminator_rf, real_labels)
            adversarial_loss_fr = adversarial_criterion(discriminator_fr, fake_labels)
            discriminator_loss = (adversarial_loss_fr + adversarial_loss_rf) / 2

            discriminator_loss.backward()
            disc_optimizer.step()

            if global_step % checkpoint_interval == 0:
                target_size = (gt_img.size()[2], gt_img.size()[3])
                hyper_img_temp = torch.nn.functional.interpolate(hyper_img, target_size, mode='bilinear', align_corners=False)
                hyper_sketch_temp = torch.nn.functional.interpolate(hyper_sketch, target_size, mode='bilinear', align_corners=False)
                # Inputs are 5-D during training and 4-D at test time (T merged
                # into B); save_sample_images always needs 5-D tensors.
                if input_dim_size > 4:
                    outputs1 = _unfold_time(g, B)
                    hyper_img_temp = _unfold_time(hyper_img_temp, B)
                    hyper_sketch_temp = _unfold_time(hyper_sketch_temp, B)
                    gt_img_sample = _unfold_time(gt_img, B)
                else:
                    # Add a singleton time axis (this branch used to reference
                    # an undefined name and pass 4-D tensors on).
                    outputs1 = g.unsqueeze(2)
                    hyper_img_temp = hyper_img_temp.unsqueeze(2)
                    hyper_sketch_temp = hyper_sketch_temp.unsqueeze(2)
                    gt_img_sample = gt_img.unsqueeze(2)

                save_sample_images(hyper_img_temp, hyper_sketch_temp, outputs1, gt_img_sample, global_step, checkpoint_dir)

            # Logs
            global_step += 1

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(model, optimizer, global_step, checkpoint_dir, global_epoch)
                save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_')

            running_adv_loss += adversarial_loss.item()
            running_perc_loss += perceptual_loss.item()
            running_content_loss += content_loss.item()
            running_lip_recon_loss += lip_recons_loss_temp.item()
            running_lip_lpips_loss += lip_lpips_loss.item()
            running_disc_loss += discriminator_loss.item()

            seen = step + 1
            prog_bar.set_description('ad_loss: {}, perc_loss: {},cont_loss: {},lipc_loss: {},lipl_loss: {},disc_loss: {}'.format(
                running_adv_loss / seen,
                running_perc_loss / seen,
                running_content_loss / seen,
                running_lip_recon_loss / seen,
                running_lip_lpips_loss / seen,
                running_disc_loss / seen,
            ))
        global_epoch += 1

def eval_model(test_data_loader, global_step, device, model):
    """Return the mean L1 reconstruction error of ``model`` on the test loader.

    The forward pass mirrors ``train``: the base image and sketch are
    concatenated along the channel axis and the output is compared with the
    ground-truth image.  The previous body referenced an undefined
    ``sync_loss`` and called the model with two arguments, so it could never
    run; this script computes no sync loss, so only L1 is reported.

    Args:
        test_data_loader: yields ``(gt_img, gt_sketch, gt_mask, hyper_img,
            hyper_sketch, coords)`` batches.
        global_step: accepted for interface compatibility; unused.
        device: torch device the batches are moved to.
        model: generator to evaluate; left in eval mode on return.

    Returns:
        Average L1 loss over the evaluated batches, or NaN if the loader
        yielded nothing.
    """
    eval_steps = 300
    print('Evaluating for {} steps'.format(eval_steps))
    running_l1_loss = []

    def _fold_time(tensor):
        # (B, C, T, H, W) -> (T*B, C, H, W), same layout as in train().
        return torch.cat([tensor[:, :, i] for i in range(tensor.size(2))], dim=0)

    model.eval()
    with torch.no_grad():
        for step, (gt_img, gt_sketch, gt_mask, hyper_img, hyper_sketch, coords) in enumerate(test_data_loader):
            hyper_img = hyper_img.to(device)
            hyper_sketch = hyper_sketch.to(device)
            gt_img = gt_img.to(device)

            if hyper_img.dim() > 4:
                hyper_img = _fold_time(hyper_img)
                hyper_sketch = _fold_time(hyper_sketch)
                gt_img = _fold_time(gt_img)

            g = model(torch.cat((hyper_img, hyper_sketch), dim=1))
            running_l1_loss.append(F.l1_loss(g, gt_img).item())

            if step > eval_steps:
                break

    if not running_l1_loss:
        print('No evaluation batches available')
        return float('nan')

    averaged_l1 = sum(running_l1_loss) / len(running_l1_loss)
    print('L1: {}'.format(averaged_l1))
    return averaged_l1


def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''):
    """Serialise ``model`` (and optionally ``optimizer``) to ``checkpoint_dir``.

    The file is named ``<prefix>checkpoint_step<step>.pth``.  The optimizer
    state is stored only when ``hparams.save_optimizer_state`` is truthy;
    otherwise ``None`` is written in its place.

    Args:
        model: module whose ``state_dict`` is saved.
        optimizer: optimizer whose state may be saved.
        step: training step recorded in the file and in its name.
        checkpoint_dir: existing directory to write into.
        epoch: epoch recorded in the file.
        prefix: file-name prefix, e.g. ``'disc_'`` for the discriminator.
    """
    # Name the file after the `step` argument rather than the module-level
    # global_step, so the file name always matches the stored "global_step".
    checkpoint_path = join(
        checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, step))
    optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer_state,
        "global_step": step,
        "global_epoch": epoch,
    }, checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)

def _load(checkpoint_path):
    """Load a checkpoint file, mapping its tensors to CPU when CUDA is unavailable."""
    # map_location=None is torch.load's default, i.e. tensors are restored to
    # the devices they were saved from.
    map_location = None if use_cuda else (lambda storage, loc: storage)
    return torch.load(checkpoint_path, map_location=map_location)


def load_checkpoint(path, model, reset_optimizer=False, overwrite_global_states=True):
    """Restore model weights (and optionally the step/epoch counters) from ``path``.

    Args:
        path: checkpoint file written by ``save_checkpoint``.
        model: module to load into; may be wrapped in ``nn.DataParallel``.
        reset_optimizer: accepted for interface compatibility; unused because
            this function does not receive an optimizer.
        overwrite_global_states: when true, ``global_step`` and
            ``global_epoch`` are replaced by the values stored in the file.

    Returns:
        The same ``model`` object that was passed in.
    """
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)

    # Drop the leading "module." that DataParallel adds to parameter names.
    # Only the prefix is removed, so names that merely contain "module."
    # (e.g. "submodule.weight") are left intact.
    prefix = 'module.'
    new_s = {}
    for k, v in checkpoint["state_dict"].items():
        while k.startswith(prefix):
            k = k[len(prefix):]
        new_s[k] = v

    # Load into the wrapped module when the model is parallelised.  The
    # wrapper's own keys carry the "module." prefix, so loading the stripped
    # keys into it with strict=False would silently restore nothing.
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    target = model.module if isinstance(model, wrappers) else model
    target.load_state_dict(new_s, strict=False)

    if overwrite_global_states:
        global_step = checkpoint["global_step"]
        global_epoch = checkpoint["global_epoch"]

    return model


if __name__ == "__main__":
    checkpoint_dir = args.checkpoint_dir

    # Dataset and Dataloader setup
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=args.batch_size,
        num_workers=4)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Factor handed to HRDecoder for each supported output size; any other
    # size falls back to 1.
    rescaling = {512: 4, 256: 2}.get(args.img_size, 1)

    # Generator
    model = HRDecoder(rescaling)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)

    # Discriminator
    disc = HRDecoder_disc_qual()
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        disc = nn.DataParallel(disc)
    disc = disc.to(device)

    print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))

    # Only weights (and, for the generator, the step/epoch counters) are
    # restored; optimizer state is not.
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, model, reset_optimizer=False)

    if args.disc_checkpoint_path is not None:
        load_checkpoint(args.disc_checkpoint_path, disc, reset_optimizer=False, overwrite_global_states=False)

    # Each optimizer is built exactly once (the discriminator optimizer used
    # to be created twice, the first instance being discarded).
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate, betas=(0.5, 0.999))

    disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
                           lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))

    # makedirs also handles nested paths and an already existing directory.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Train!
    train(device, model,disc, train_data_loader, test_data_loader, optimizer,disc_optimizer, checkpoint_dir=checkpoint_dir,
              checkpoint_interval=hparams.checkpoint_interval,
              nepochs=hparams.nepochs)

================================================
FILE: audio.py
================================================
import librosa
import librosa.filters
import numpy as np
# import tensorflow as tf
from scipy import signal
from scipy.io import wavfile
from hparams import hparams as hp

def load_wav(path, sr):
    """Load ``path`` with librosa at sampling rate ``sr`` and return only the samples."""
    samples, _ = librosa.core.load(path, sr=sr)
    return samples

def save_wav(wav, path, sr):
    """Peak-normalise ``wav`` and write it to ``path`` as 16-bit PCM.

    The signal is scaled so that its largest absolute sample maps to 32767;
    the peak is floored at 0.01 so near-silent input is not amplified without
    bound.  The caller's array is left untouched (the scaling used to be
    applied in place, which also failed for integer arrays).

    Args:
        wav: 1-D (or channel-last) array of samples.
        path: destination file.
        sr: sampling rate in Hz.
    """
    peak = max(0.01, np.max(np.abs(wav)))
    scaled = np.asarray(wav) * (32767 / peak)
    #proposed by @dsmiller
    wavfile.write(path, sr, scaled.astype(np.int16))

def save_wavenet_wav(wav, path, sr):
    """Write ``wav`` to ``path`` without rescaling or changing its dtype.

    ``librosa.output.write_wav`` no longer exists in the librosa versions
    whose keyword-only API the rest of this module uses, so the file is
    written with ``scipy.io.wavfile`` directly.

    Args:
        wav: array of samples; a ``(2, n)`` stereo array is transposed to the
            ``(n, 2)`` layout scipy expects.
        path: destination file.
        sr: sampling rate in Hz.
    """
    wav = np.asarray(wav)
    if wav.ndim > 1 and wav.shape[0] == 2:
        wav = wav.T
    wavfile.write(path, sr, wav)

def preemphasis(wav, k, preemphasize=True):
    """Apply the filter ``y[n] = x[n] - k * x[n-1]``; no-op when disabled."""
    if not preemphasize:
        return wav
    numerator, denominator = [1, -k], [1]
    return signal.lfilter(numerator, denominator, wav)

def inv_preemphasis(wav, k, inv_preemphasize=True):
    """Apply the filter ``y[n] = x[n] + k * y[n-1]``; no-op when disabled."""
    if not inv_preemphasize:
        return wav
    numerator, denominator = [1], [1, -k]
    return signal.lfilter(numerator, denominator, wav)

def get_hop_size():
    """Return ``hp.hop_size``, deriving it from ``hp.frame_shift_ms`` when unset."""
    configured = hp.hop_size
    if configured is not None:
        return configured
    assert hp.frame_shift_ms is not None
    # Milliseconds -> seconds -> samples.
    return int(hp.frame_shift_ms / 1000 * hp.sample_rate)

def linearspectrogram(wav):
    """Return the dB-scaled linear-frequency spectrogram of ``wav``.

    The result is shifted by ``hp.ref_level_db`` and passed through
    ``_normalize`` when ``hp.signal_normalization`` is set.
    """
    emphasized = preemphasis(wav, hp.preemphasis, hp.preemphasize)
    magnitudes = np.abs(_stft(emphasized))
    S = _amp_to_db(magnitudes) - hp.ref_level_db
    return _normalize(S) if hp.signal_normalization else S

def melspectrogram(wav):
    """Return the dB-scaled mel spectrogram of ``wav``.

    The result is shifted by ``hp.ref_level_db`` and passed through
    ``_normalize`` when ``hp.signal_normalization`` is set.
    """
    emphasized = preemphasis(wav, hp.preemphasis, hp.preemphasize)
    mel_magnitudes = _linear_to_mel(np.abs(_stft(emphasized)))
    S = _amp_to_db(mel_magnitudes) - hp.ref_level_db
    return _normalize(S) if hp.signal_normalization else S

def _lws_processor():
    """Build an ``lws`` processor from the current hyper-parameters."""
    # Imported lazily so the module does not require lws unless it is used.
    import lws
    processor = lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
    return processor

def _stft(y):
    """Return the complex STFT of ``y`` via lws or librosa, depending on ``hp.use_lws``."""
    if hp.use_lws:
        # _lws_processor() takes no arguments; passing hp raised a TypeError.
        return _lws_processor().stft(y).T
    return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)

##########################################################
# These are only correct when using lws! (This was hurting WaveNet quality for a long time.)
def num_frames(length, fsize, fshift):
    """Compute the number of spectrogram time frames.

    ``length`` is the signal length, ``fsize`` the frame size and ``fshift``
    the hop between frames.
    """
    pad = fsize - fshift
    # One extra frame is needed when the length is not a multiple of the hop.
    extra = 1 if length % fshift == 0 else 2
    return (length + pad * 2 - fsize) // fshift + extra


def pad_lr(x, fsize, fshift):
    """Compute the left and right padding for ``x``.

    Returns a ``(left, right)`` tuple; ``right`` additionally covers the part
    of the last frame that extends past the symmetrically padded signal.
    """
    n_samples = len(x)
    base_pad = fsize - fshift
    frames = num_frames(n_samples, fsize, fshift)
    padded_length = n_samples + 2 * base_pad
    remainder = (frames - 1) * fshift + fsize - padded_length
    return base_pad, base_pad + remainder
##########################################################
#Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
    """Return ``(0, right)`` padding that extends ``x`` to the next multiple of ``fshift``.

    ``fsize`` is unused.  When the length is already a multiple of ``fshift``
    a full extra hop is added.
    """
    n_samples = x.shape[0]
    target_length = (n_samples // fshift + 1) * fshift
    return 0, target_length - n_samples

# Conversions
# Mel filter bank, built on first use and then reused.
_mel_basis = None

def _linear_to_mel(spectogram):
    """Project a linear-frequency spectrogram onto the cached mel filter bank."""
    global _mel_basis
    basis = _mel_basis
    if basis is None:
        basis = _mel_basis = _build_mel_basis()
    return np.dot(basis, spectogram)

def _build_mel_basis():
    """Create the mel filter bank described by the hyper-parameters."""
    # The upper band edge must not exceed the Nyquist frequency.
    assert hp.fmax <= hp.sample_rate // 2
    mel_kwargs = dict(
        sr=hp.sample_rate,
        n_fft=hp.n_fft,
        n_mels=hp.num_mels,
        fmin=hp.fmin,
        fmax=hp.fmax,
    )
    return librosa.filters.mel(**mel_kwargs)
def _amp_to_db(x):
    """Convert amplitudes to decibels, flooring them at ``hp.min_level_db``."""
    # Amplitude equivalent of hp.min_level_db, i.e. 10 ** (min_level_db / 20).
    floor = np.exp(hp.min_level_db / 20 * np.log(10))
    clamped = np.maximum(floor, x)
    return 20 * np.log10(clamped)

def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)

def _normalize(S):
    """Map a dB spectrogram from ``[hp.min_level_db, 0]`` to the configured range.

    The target range is ``[-max_abs_value, max_abs_value]`` when
    ``hp.symmetric_mels`` is set and ``[0, max_abs_value]`` otherwise.
    Out-of-range values are clipped when
    ``hp.allow_clipping_in_normalization`` is set; otherwise the input is
    asserted to lie within range.
    """
    clipping = hp.allow_clipping_in_normalization
    if not clipping:
        assert S.max() <= 0 and S.min() - hp.min_level_db >= 0

    # Position of each value within [min_level_db, 0], as a fraction.
    fraction = (S - hp.min_level_db) / (-hp.min_level_db)
    if hp.symmetric_mels:
        scaled = (2 * hp.max_abs_value) * fraction - hp.max_abs_value
        lower = -hp.max_abs_value
    else:
        scaled = hp.max_abs_value * fraction
        lower = 0

    if clipping:
        return np.clip(scaled, lower, hp.max_abs_value)
    return scaled

def _denormalize(D):
    """Invert ``_normalize``: map normalised values back to decibels.

    The input is first clipped to the normalised range when
    ``hp.allow_clipping_in_normalization`` is set.
    """
    max_abs = hp.max_abs_value
    if hp.allow_clipping_in_normalization:
        lower = -max_abs if hp.symmetric_mels else 0
        D = np.clip(D, lower, max_abs)

    if hp.symmetric_mels:
        return ((D + max_abs) * -hp.min_level_db / (2 * max_abs)) + hp.min_level_db
    return (D * -hp.min_level_db / max_abs) + hp.min_level_db


================================================
FILE: checkpoint
================================================



================================================
FILE: checkpoints/readme.txt
================================================
Put checkpoint here.


================================================
FILE: color_syncnet_trainv3.py
================================================
from os.path import dirname, join, basename, isfile
from tqdm import tqdm
from models import SyncNet_color as SyncNet
import audio
import torch
from torch import nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np
from glob import glob
import os, random, cv2, argparse
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from hparams import hparams, get_image_list

# Command-line interface for the lip-sync expert training script.
parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator')


parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", default='Train_data/imgs')
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', default="./checkpoints_lipsync_expert", type=str)
parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str)


# Parsed at import time, so importing this module consumes sys.argv.
args = parser.parse_args()


# Module-level training progress counters.
global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()

ema_decay = 0.5 ** (32 / (10 * 1000))

# Number of consecutive frames per window (see Dataset.get_window).
syncnet_T = 5   
# Number of spectrogram rows cut out per sample (see Dataset.crop_audio_window).
syncnet_mel_step_size = 16    

class Dataset(object):
    """Random sampler of (frame window, mel window, label) training triples.

    ``__getitem__`` ignores the index it receives and draws a random video
    every time, so ``__len__`` only determines how many samples a DataLoader
    pulls per epoch.
    """

    def __init__(self, split):
        # One entry per video directory, as listed by get_image_list() for `split`.
        self.all_videos = get_image_list(args.data_root, split)
        # Constant frame offset added to the audio position (0 means no shift).
        self.av_offset_shift = 0

    def get_frame_id(self, frame):
        """Return the integer frame number encoded in the frame's file name."""
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        """Return paths of ``syncnet_T`` consecutive frames starting at ``start_frame``.

        Returns None as soon as one of the expected ``<id>.jpg`` files is missing.
        """
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.jpg'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def crop_audio_window(self, spec, start_frame):
        """Slice ``syncnet_mel_step_size`` rows of ``spec`` aligned with ``start_frame``.

        The result may be shorter than requested near the end of ``spec``;
        the caller checks the length.
        """
        start_frame_num = self.get_frame_id(start_frame) + self.av_offset_shift
        # 80. converts seconds to spectrogram rows -- presumably 80 mel frames
        # per second of audio; TODO confirm against the audio hparams.
        start_idx = int(80. * (start_frame_num / float(hparams.fps)))
        end_idx = start_idx + syncnet_mel_step_size
        return spec[start_idx: end_idx, :]

    def read_window(self, window_fnames, flip_flag=False):
        """Load and resize every frame of a window.

        Returns a list of ``hparams.img_size`` square images, horizontally
        mirrored when ``flip_flag`` is set, or None if ``window_fnames`` is
        None or any frame cannot be read or resized.
        """
        if window_fnames is None:
            return None
        window = []
        for fname in window_fnames:
            img = cv2.imread(fname)
            if img is None:
                return None
            try:
                img = cv2.resize(img, (hparams.img_size, hparams.img_size))
            except Exception:
                return None

            if flip_flag:
                img = np.flip(img, axis=1).copy()
            window.append(img)

        return window

    def __len__(self):
        return len(self.all_videos)

    def __getitem__(self, idx):
        """Return ``(x, mel, y)`` for one randomly drawn sample.

        ``x`` stacks the frames of the window along the channel axis and keeps
        only the lower half of every image, ``mel`` is the audio window taken
        at the reference frame, and ``y`` is 1 when the frames start at that
        same reference frame and 0 when they start at a different frame of
        the same video. Candidates that cannot be loaded are silently
        resampled.
        """
        while 1:
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]

            img_names = list(glob(join(vidname, '*.jpg')))
            if len(img_names) <= 3 * syncnet_T:
                continue
            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)
            if random.choice([True, False]):
                y = torch.ones(1).float()
                chosen = img_name
            else:
                y = torch.zeros(1).float()
                chosen = wrong_img_name

            window_fnames = self.get_window(chosen)
            if window_fnames is None:
                continue

            # NOTE(review): frames are always mirrored here (flip_flag=True).
            window = self.read_window(window_fnames, flip_flag=True)
            if window is None:
                # read_window() reports unreadable frames with None; resample
                # instead of crashing in np.concatenate below.
                continue

            try:
                wavpath = join(vidname, "audio.wav")
                wav = audio.load_wav(wavpath, hparams.sample_rate)

                orig_mel = audio.melspectrogram(wav).T
            except Exception:
                # Best effort: a missing or broken audio track just skips this video.
                continue

            # The audio is always cropped at the reference frame, even for
            # negative samples whose frames come from wrong_img_name.
            mel = self.crop_audio_window(orig_mel.copy(), img_name)

            if mel.shape[0] != syncnet_mel_step_size:
                continue

            # H x W x 3 * T
            x = np.concatenate(window, axis=2) / 255.
            x = x.transpose(2, 0, 1)
            # Keep rows from H // 2 downwards, i.e. the lower half of each frame.
            x = x[:, x.shape[1]//2:]
            x = torch.FloatTensor(x)
            mel = torch.FloatTensor(mel.T).unsqueeze(0)

            return x, mel, y

# Binary cross-entropy applied to the cosine similarity of the two embeddings.
logloss = nn.BCELoss()


def cosine_loss(a, v, y):
    """Return the BCE between cosine_similarity(a, v) and the sync label y.

    The similarity is reshaped to a column so that it lines up with ``y``.
    NOTE(review): BCELoss expects inputs in [0, 1], so this presumably relies
    on the model never producing negatively correlated embeddings -- confirm.
    """
    similarity = nn.functional.cosine_similarity(a, v)
    return logloss(similarity.unsqueeze(1), y)

def train(device, model, train_data_loader, test_data_loader, optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    """Train ``model`` until ``global_epoch`` reaches ``nepochs``.

    Updates the module-level ``global_step`` / ``global_epoch`` counters,
    saves a checkpoint on the first step and every ``checkpoint_interval``
    steps, and runs validation every ``hparams.syncnet_eval_interval`` steps.
    """
    global global_step, global_epoch

    while global_epoch < nepochs:
        running_loss = 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, mel, y) in prog_bar:
            # eval_model() switches the network to eval mode, so training
            # mode is re-armed at the start of every step.
            model.train()
            optimizer.zero_grad()

            # Transform data to CUDA device
            x = x.to(device)
            mel = mel.to(device)
            y = y.to(device)

            a, v = model(mel, x)

            loss = cosine_loss(a, v, y)
            loss.backward()
            optimizer.step()

            global_step += 1
            running_loss += loss.item()

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(
                    model, optimizer, global_step, checkpoint_dir, global_epoch)

            # Validate only on the configured interval; evaluating after every
            # single step would dominate the training time.
            if global_step % hparams.syncnet_eval_interval == 0:
                with torch.no_grad():
                    eval_model(test_data_loader, global_step, device, model, checkpoint_dir)

            lr_temp = optimizer.param_groups[0]['lr']
            prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1))+'  '+'lr: {}'.format(lr_temp))

        global_epoch += 1
        print(global_epoch)

def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
    """Print and return the mean sync loss over at most ``eval_steps`` batches.

    The model is left in eval mode. ``global_step`` and ``checkpoint_dir`` are
    accepted for call compatibility but not used. Returns None when the
    loader yields no batch. Gradient tracking is not disabled here; callers
    wrap the call in ``torch.no_grad()``.
    """
    eval_steps = 1400
    print('Evaluating for {} steps'.format(eval_steps))
    losses = []
    model.eval()

    for step, (x, mel, y) in enumerate(test_data_loader):
        # Transform data to CUDA device
        x = x.to(device)
        mel = mel.to(device)
        y = y.to(device)

        a, v = model(mel, x)

        loss = cosine_loss(a, v, y)
        losses.append(loss.item())

        # Stop once exactly eval_steps batches have been scored.
        if step + 1 >= eval_steps:
            break

    if not losses:
        print('No validation batches available, skipping evaluation')
        return None

    averaged_loss = sum(losses) / len(losses)
    print(averaged_loss)

    return averaged_loss

def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
    """Write model (and optionally optimizer) state to ``checkpoint_dir``.

    The file is named after ``step``. The optimizer state is stored only when
    ``hparams.save_optimizer_state`` is truthy; otherwise None is stored.
    """
    # Name the file after the `step` argument so the filename always matches
    # the "global_step" value stored inside the checkpoint.
    checkpoint_path = join(
        checkpoint_dir, "checkpoint_step{:09d}.pth".format(step))
    optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer_state,
        "global_step": step,
        "global_epoch": epoch,
    }, checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)

def _load(checkpoint_path):
    """Deserialize a checkpoint, keeping its storages on CPU when CUDA is unavailable."""
    if not use_cuda:
        # Returning the storage unchanged leaves every tensor on the CPU.
        return torch.load(checkpoint_path,
                          map_location=lambda storage, loc: storage)
    return torch.load(checkpoint_path)

def load_checkpoint(path, model, optimizer, reset_optimizer=False):
    """Restore model weights, optionally optimizer state, and the global counters.

    The optimizer is restored only when ``reset_optimizer`` is False and the
    checkpoint actually carries an optimizer state.
    """
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    saved = _load(path)
    model.load_state_dict(saved["state_dict"])

    if not reset_optimizer and saved["optimizer"] is not None:
        print("Load optimizer state from {}".format(path))
        optimizer.load_state_dict(saved["optimizer"])

    global_step = saved["global_step"]
    global_epoch = saved["global_epoch"]


if __name__ == "__main__":
    checkpoint_dir = args.checkpoint_dir
    checkpoint_path = args.checkpoint_path

    # makedirs also creates missing parent directories and tolerates an
    # already existing directory.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Dataset and Dataloader setup
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    # Validation runs one sample per batch with a single worker.
    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=1,
        num_workers=1)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = SyncNet().to(device)
    print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.syncnet_lr)

    # Resume weights, optimizer state and step/epoch counters when requested.
    if checkpoint_path is not None:
        load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False)

    train(device, model, train_data_loader, test_data_loader, optimizer,
          checkpoint_dir=checkpoint_dir,
          checkpoint_interval=hparams.syncnet_checkpoint_interval,
          nepochs=hparams.nepochs)


================================================
FILE: conv.py
================================================
import torch
from torch import nn
from torch.nn import functional as F

class Conv2d(nn.Module):
    """Convolution + batch norm + ReLU, with an optional identity skip.

    When ``residual`` is True the block input is added to the normalized
    convolution output before the activation, so input and output shapes
    must match.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        convolution = nn.Conv2d(cin, cout, kernel_size, stride, padding)
        normalization = nn.BatchNorm2d(cout)
        self.conv_block = nn.Sequential(convolution, normalization)
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        features = self.conv_block(x)
        if not self.residual:
            return self.act(features)
        # In-place add of the skip connection onto the freshly computed features.
        features += x
        return self.act(features)

class nonorm_Conv2d(nn.Module):
    """Convolution followed by LeakyReLU(0.01), without normalization.

    ``residual`` is accepted for signature parity with ``Conv2d`` but is not
    used by this block.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        convolution = nn.Conv2d(cin, cout, kernel_size, stride, padding)
        self.conv_block = nn.Sequential(convolution)
        self.act = nn.LeakyReLU(0.01, inplace=True)

    def forward(self, x):
        return self.act(self.conv_block(x))

class Conv2dTranspose(nn.Module):
    """Transposed convolution + batch norm + ReLU."""

    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        upsample = nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding)
        normalization = nn.BatchNorm2d(cout)
        self.conv_block = nn.Sequential(upsample, normalization)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.conv_block(x))


================================================
FILE: datasets/MEAD/readme.txt
================================================
Put all the training MEAD .mp4 files here.

================================================
FILE: environment.yml
================================================
name: hyperlips
channels:
  - http://mirrors.ustc.edu.cn/anaconda/pkgs/free/
  - http://mirrors.ustc.edu.cn/anaconda/cloud/msys2/
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - ca-certificates=2023.05.30=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.10=h7f8727e_1
  - pip=23.2.1=py38h06a4308_0
  - python=3.8.16=h955ad1f_4
  - readline=8.2=h5eee18b_0
  - setuptools=68.0.0=py38h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - wheel=0.38.4=py38h06a4308_0
  - xz=5.4.2=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
    - absl-py==1.4.0
    - addict==2.4.0
    - attrs==23.1.0
    - audioread==3.0.0
    - basicsr==1.4.2
    - cachetools==5.3.1
    - certifi==2023.7.22
    - cffi==1.15.1
    - charset-normalizer==3.2.0
    - contourpy==1.1.0
    - cycler==0.11.0
    - decorator==5.1.1
    - facexlib==0.2.5
    - filterpy==1.4.5
    - flatbuffers==23.5.26
    - fonttools==4.42.1
    - future==0.18.3
    - google-auth==2.22.0
    - google-auth-oauthlib==1.0.0
    - grpcio==1.57.0
    - idna==3.4
    - imageio==2.31.1
    - importlib-metadata==6.8.0
    - importlib-resources==6.0.1
    - joblib==1.3.2
    - kiwisolver==1.4.4
    - lazy-loader==0.3
    - librosa==0.9.2
    - llvmlite==0.39.1
    - lmdb==1.4.1
    - markdown==3.4.4
    - markupsafe==2.1.3
    - matplotlib==3.7.2
    - mediapipe==0.10.1
    - networkx==3.1
    - numba==0.56.4
    - numpy==1.21.5
    - oauthlib==3.2.2
    - opencv-contrib-python==4.7.0.72
    - opencv-python==4.7.0.72
    - packaging==23.1
    - pillow==10.0.0
    - platformdirs==3.10.0
    - pooch==1.7.0
    - protobuf==3.20.3
    - pyasn1==0.5.0
    - pyasn1-modules==0.3.0
    - pycparser==2.21
    - pyparsing==3.0.9
    - python-dateutil==2.8.2
    - pywavelets==1.4.1
    - pyyaml==6.0.1
    - requests==2.31.0
    - requests-oauthlib==1.3.1
    - resampy==0.4.2
    - rsa==4.9
    - scikit-image==0.21.0
    - scikit-learn==1.3.0
    - scipy==1.10.1
    - six==1.16.0
    - sounddevice==0.4.6
    - soundfile==0.12.1
    - tb-nightly==2.14.0a20230808
    - tensorboard-data-server==0.7.1
    - threadpoolctl==3.2.0
    - tifffile==2023.7.10
    - tomli==2.0.1
    - torch==1.10.1+cu113
    - torchvision==0.11.2+cu113
    - tqdm==4.65.0
    - typing-extensions==4.7.1
    - urllib3==1.26.16
    - werkzeug==2.3.7
    - yapf==0.40.1
    - zipp==3.16.2
prefix: /root/anaconda3/envs/hyperlips


================================================
FILE: face_detection/README.md
================================================
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time. 

================================================
FILE: face_detection/__init__.py
================================================
# -*- coding: utf-8 -*-

__author__ = """Adrian Bulat"""
__email__ = 'adrian.bulat@nottingham.ac.uk'
__version__ = '1.0.1'

from .api import FaceAlignment, LandmarksType, NetworkSize


================================================
FILE: face_detection/api.py
================================================
from __future__ import print_function
import os
import torch
from torch.utils.model_zoo import load_url
from enum import Enum
import numpy as np
import cv2
try:
    import urllib.request as request_file
except BaseException:
    import urllib as request_file

from .models import FAN, ResNetDepth
from .utils import *


class LandmarksType(Enum):
    """Enum class defining the type of landmarks to detect.

    ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
    ``_2halfD`` - these points represent the projection of the 3D points into 2D
    ``_3D`` - detect the points ``(x,y,z)`` in a 3D space

    """
    _2D = 1
    _2halfD = 2
    _3D = 3


class NetworkSize(Enum):
    """Available network sizes; only ``LARGE`` is enabled.

    Members convert to their numeric value through ``int()``.
    """
    # TINY = 1
    # SMALL = 2
    # MEDIUM = 3
    LARGE = 4

    def __new__(cls, value):
        # Build the member explicitly and store the raw value on it.
        instance = object.__new__(cls)
        instance._value_ = value
        return instance

    def __int__(self):
        return self._value_

ROOT = os.path.dirname(os.path.abspath(__file__))

class FaceAlignment:
    """Thin wrapper that loads a face detector and runs it on image batches."""

    def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False):
        self.device = device
        self.flip_input = flip_input
        self.landmarks_type = landmarks_type
        self.verbose = verbose

        network_size = int(network_size)

        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True

        # Get the face detector: import face_detection.detection.<name> and
        # instantiate the FaceDetector class it exposes.
        detector_module_name = 'face_detection.detection.' + face_detector
        detector_module = __import__(detector_module_name,
                                     globals(), locals(), [face_detector], 0)
        self.face_detector = detector_module.FaceDetector(device=device, verbose=verbose)

    def get_detections_for_batch(self, images):
        """Return one ``(x1, y1, x2, y2)`` tuple per image, or None without a face.

        The last axis of ``images`` is reversed before detection. Only the
        first detection of every image is used; its coordinates are clipped
        at zero and truncated to int.
        """
        reversed_channels = images[..., ::-1]
        per_image = self.face_detector.detect_from_batch(reversed_channels.copy())

        boxes = []
        for candidates in per_image:
            if len(candidates) == 0:
                boxes.append(None)
                continue
            first = np.clip(candidates[0], 0, None)

            # The trailing element is dropped (presumably the score) before
            # converting the coordinates.
            x1, y1, x2, y2 = map(int, first[:-1])
            boxes.append((x1, y1, x2, y2))

        return boxes

================================================
FILE: face_detection/detection/__init__.py
================================================
from .core import FaceDetector

================================================
FILE: face_detection/detection/core.py
================================================
import logging
import glob
from tqdm import tqdm
import numpy as np
import torch
import cv2


class FaceDetector(object):
    """An abstract class representing a face detector.

    Any other face detection implementation must subclass it. All subclasses
    must implement ``detect_from_image``, that return a list of detected
    bounding boxes. Optionally, for speed considerations detect from path is
    recommended.
    """

    def __init__(self, device, verbose):
        """Store the configuration and validate ``device``.

        Arguments:
            device {string} -- must contain ``cpu`` or ``cuda``.
            verbose {bool} -- log a warning/error about the device when True.

        Raises:
            ValueError -- if ``device`` names neither ``cpu`` nor ``cuda``.
        """
        self.device = device
        self.verbose = verbose

        # Bound unconditionally so that every verbose branch below can use it;
        # otherwise an invalid device with verbose=True would fail with an
        # unbound-name error instead of the intended ValueError.
        logger = logging.getLogger(__name__)

        if verbose and 'cpu' in device:
            logger.warning("Detection running on CPU, this may be potentially slow.")

        if 'cpu' not in device and 'cuda' not in device:
            if verbose:
                logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
            raise ValueError("Expected values for device are: {cpu, cuda} but got: %s" % (device,))

    def detect_from_image(self, tensor_or_path):
        """Detects faces in a given image.

        This function detects the faces present in a provided BGR(usually)
        image. The input can be either the image itself or the path to it.

        Arguments:
            tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
            to an image or the image itself.

        Example::

            >>> path_to_image = 'data/image_01.jpg'
            ...   detected_faces = detect_from_image(path_to_image)
            [A list of bounding boxes (x1, y1, x2, y2)]
            >>> image = cv2.imread(path_to_image)
            ...   detected_faces = detect_from_image(image)
            [A list of bounding boxes (x1, y1, x2, y2)]

        """
        raise NotImplementedError

    def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
        """Detects faces from all the images present in a given directory.

        Arguments:
            path {string} -- a string containing a path that points to the folder containing the images

        Keyword Arguments:
            extensions {list} -- list of string containing the extensions to be
            consider in the following format: ``.extension_name`` (default:
            {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
            folder recursively (default: {False}) show_progress_bar {bool} --
            display a progressbar (default: {True})

        Raises:
            ValueError -- if ``extensions`` is empty.

        Example:
        >>> directory = 'data'
        ...   detected_faces = detect_from_directory(directory)
        {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}

        """
        logger = logging.getLogger(__name__)

        if len(extensions) == 0:
            if self.verbose:
                logger.error("Expected at list one extension, but none was received.")
            raise ValueError

        if self.verbose:
            logger.info("Constructing the list of images.")
        additional_pattern = '/**/*' if recursive else '/*'
        files = []
        for extension in extensions:
            files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))

        if self.verbose:
            logger.info("Finished searching for images. %s images found", len(files))
            logger.info("Preparing to run the detection.")

        # Maps every image path to whatever detect_from_image() returns for it.
        predictions = {}
        for image_path in tqdm(files, disable=not show_progress_bar):
            if self.verbose:
                logger.info("Running the face detector on image: %s", image_path)
            predictions[image_path] = self.detect_from_image(image_path)

        if self.verbose:
            logger.info("The detector was successfully run on all %s images", len(files))

        return predictions

    @property
    def reference_scale(self):
        raise NotImplementedError

    @property
    def reference_x_shift(self):
        raise NotImplementedError

    @property
    def reference_y_shift(self):
        raise NotImplementedError

    @staticmethod
    def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
        """Convert path (represented as a string) or torch.tensor to a numpy.ndarray

        Arguments:
            tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself

        Keyword Arguments:
            rgb {bool} -- for a path, reverse the last axis of the loaded
            image when True; for a tensor or an array, reverse the last axis
            when False and return the data unchanged when True.

        Raises:
            TypeError -- if the input is none of the supported types.
        """
        if isinstance(tensor_or_path, str):
            # NOTE: cv2.imread returns None for an unreadable path.
            return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
        elif torch.is_tensor(tensor_or_path):
            # Call cpu in case its coming from cuda
            return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
        elif isinstance(tensor_or_path, np.ndarray):
            return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
        else:
            raise TypeError


================================================
FILE: face_detection/detection/sfd/__init__.py
================================================
from .sfd_detector import SFDDetector as FaceDetector

================================================
FILE: face_detection/detection/sfd/bbox.py
================================================
from __future__ import print_function
import os
import sys
import cv2
import random
import datetime
import time
import math
import argparse
import numpy as np
import torch

try:
    from iou import IOU
except BaseException:
    # IOU cython speedup 10x
    def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
        """Pure-Python intersection-over-union of boxes a and b (corner form).

        Returns 0.0 when the boxes do not overlap along either axis.
        """
        overlap_w = min(ax2, bx2) - max(ax1, bx1)
        overlap_h = min(ay2, by2) - max(ay1, by1)
        if overlap_w < 0 or overlap_h < 0:
            return 0.0
        area_a = abs((ax2 - ax1) * (ay2 - ay1))
        area_b = abs((bx2 - bx1) * (by2 - by1))
        return 1.0 * overlap_w * overlap_h / (area_a + area_b - overlap_w * overlap_h)


def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
    """Encode a corner-form box relative to an anchor (center axc/ayc, size aww/ahh).

    Returns ``(dx, dy, dw, dh)``: center offsets in anchor units and the
    natural log of the size ratios.
    """
    width, height = x2 - x1, y2 - y1
    center_x, center_y = (x2 + x1) / 2, (y2 + y1) / 2
    dx = (center_x - axc) / aww
    dy = (center_y - ayc) / ahh
    dw = math.log(width / aww)
    dh = math.log(height / ahh)
    return dx, dy, dw, dh


def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
    """Invert ``bboxlog``: turn anchor-relative offsets back into ``(x1, y1, x2, y2)``."""
    center_x = dx * aww + axc
    center_y = dy * ahh + ayc
    width = math.exp(dw) * aww
    height = math.exp(dh) * ahh
    left, right = center_x - width / 2, center_x + width / 2
    top, bottom = center_y - height / 2, center_y + height / 2
    return left, top, right, bottom


def nms(dets, thresh):
    """Greedy non-maximum suppression.

    ``dets`` holds one ``x1, y1, x2, y2, score`` row per box. Boxes are
    visited from highest to lowest score; a box is dropped when its overlap
    with an already kept box exceeds ``thresh``. Returns the kept row indices
    (an empty list for empty input).
    """
    if len(dets) == 0:
        return []

    x1, y1, x2, y2, scores = (dets[:, column] for column in range(5))
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(best)

        # Intersection rectangle between the best box and every remaining one.
        left = np.maximum(x1[best], x1[rest])
        top = np.maximum(y1[best], y1[rest])
        right = np.minimum(x2[best], x2[rest])
        bottom = np.minimum(y2[best], y2[rest])

        w = np.maximum(0.0, right - left + 1)
        h = np.maximum(0.0, bottom - top + 1)
        ovr = w * h / (areas[best] + areas[rest] - w * h)

        # Only boxes that overlap the best one by at most thresh survive.
        order = rest[ovr <= thresh]

    return keep


def encode(matched, priors, variances):
    """Encode matched ground-truth boxes as regression targets w.r.t. priors.

    Args:
        matched: (tensor) Ground-truth coords for each prior in point-form,
            Shape: [num_priors, 4].
        priors: (tensor) Prior boxes in center-offset form,
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes.
    Return:
        encoded boxes (tensor), Shape: [num_priors, 4]
    """
    prior_centers, prior_sizes = priors[:, :2], priors[:, 2:]
    top_left, bottom_right = matched[:, :2], matched[:, 2:]

    # Center distance between match and prior, scaled by variance and prior size.
    center_offset = (top_left + bottom_right) / 2 - prior_centers
    center_offset = center_offset / (variances[0] * prior_sizes)

    # Log of the size ratio between match and prior, scaled by variance.
    size_ratio = (bottom_right - top_left) / prior_sizes
    log_size = torch.log(size_ratio) / variances[1]

    # Target layout for the smooth L1 loss: [num_priors, 4].
    return torch.cat([center_offset, log_size], 1)


def decode(loc, priors, variances):
    """Turn location predictions back into corner-form boxes using the priors.

    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors, 4]
        priors (tensor): Prior boxes in center-offset form,
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes.
    Return:
        decoded bounding box predictions as (x1, y1, x2, y2) rows
    """
    centers = priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:]
    sizes = priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])

    # Convert center/size into top-left and bottom-right corners.
    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return torch.cat((top_left, bottom_right), 1)

def batch_decode(loc, priors, variances):
    """Batched variant of ``decode`` operating on 3-D tensors.

    Args:
        loc (tensor): location predictions, indexed as [batch, num_priors, 4].
        priors (tensor): Prior boxes in center-offset form, indexed as
            [batch, num_priors, 4] (size-1 axes broadcast against ``loc``).
        variances: (list[float]) Variances of priorboxes.
    Return:
        decoded bounding box predictions as (x1, y1, x2, y2) along the last axis
    """
    centers = priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:]
    sizes = priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])

    # Convert center/size into top-left and bottom-right corners.
    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return torch.cat((top_left, bottom_right), 2)


================================================
FILE: face_detection/detection/sfd/detect.py
================================================
import torch
import torch.nn.functional as F

import os
import sys
import cv2
import random
import datetime
import math
import argparse
import numpy as np

import scipy.io as sio
import zipfile
from .net_s3fd import s3fd
from .bbox import *


def detect(net, img, device):
    """Run ``net`` on a single image and return the raw candidate boxes.

    Arguments:
        net -- callable returning a list that alternates class-score and
            box-regression maps (consumed pairwise below).
        img -- numpy image laid out as H x W x C (it is transposed to C x H x W);
            presumably 3 channels in BGR order -- TODO confirm.
        device -- device string; cudnn benchmark mode is enabled when it
            contains ``cuda``.

    Returns:
        numpy array with one ``x1, y1, x2, y2, score`` row per location whose
        class-1 score exceeds 0.05, or ``np.zeros((1, 5))`` when there is none.
        No non-maximum suppression is applied here.
    """
    # Subtract per-channel constants, then reorder to C x H x W and add a batch axis.
    img = img - np.array([104, 117, 123])
    img = img.transpose(2, 0, 1)
    img = img.reshape((1,) + img.shape)

    if 'cuda' in device:
        torch.backends.cudnn.benchmark = True

    img = torch.from_numpy(img).float().to(device)
    BB, CC, HH, WW = img.size()
    with torch.no_grad():
        olist = net(img)

    bboxlist = []
    # Even entries hold class scores: turn them into probabilities.
    for i in range(len(olist) // 2):
        olist[i * 2] = F.softmax(olist[i * 2], dim=1)
    olist = [oelem.data.cpu() for oelem in olist]
    # Walk the (scores, regression) pairs, one pair per detection scale.
    for i in range(len(olist) // 2):
        ocls, oreg = olist[i * 2], olist[i * 2 + 1]
        FB, FC, FH, FW = ocls.size()  # feature map size
        stride = 2**(i + 2)    # 4,8,16,32,64,128
        anchor = stride * 4
        # Locations whose class-1 probability passes the 0.05 threshold.
        poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
        for Iindex, hindex, windex in poss:
            # Anchor center of this feature-map cell in input-image pixels.
            axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
            score = ocls[0, 1, hindex, windex]
            loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
            # Square prior centered on the cell with side 4 * stride.
            priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
            variances = [0.1, 0.2]
            box = decode(loc, priors, variances)
            x1, y1, x2, y2 = box[0] * 1.0
            # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
            bboxlist.append([x1, y1, x2, y2, score])
    bboxlist = np.array(bboxlist)
    if 0 == len(bboxlist):
        # Placeholder row so callers always receive a 2-D array.
        bboxlist = np.zeros((1, 5))

    return bboxlist

def batch_detect(net, imgs, device):
    """Run ``net`` on a batch of images and return the raw candidate boxes.

    Arguments:
        net -- callable returning a list that alternates class-score and
            box-regression maps (consumed pairwise below).
        imgs -- numpy batch laid out as B x H x W x C (it is transposed to
            B x C x H x W); presumably 3 channels in BGR order -- TODO confirm.
        device -- device string; cudnn benchmark mode is enabled when it
            contains ``cuda``.

    Returns:
        numpy array indexed as [candidate, batch item, 5] with
        ``x1, y1, x2, y2, score`` along the last axis, or
        ``np.zeros((1, BB, 5))`` when no location passes the threshold.
        No non-maximum suppression is applied here.
    """
    # Subtract per-channel constants, then reorder to B x C x H x W.
    imgs = imgs - np.array([104, 117, 123])
    imgs = imgs.transpose(0, 3, 1, 2)

    if 'cuda' in device:
        torch.backends.cudnn.benchmark = True

    imgs = torch.from_numpy(imgs).float().to(device)
    BB, CC, HH, WW = imgs.size()
    with torch.no_grad():
        olist = net(imgs)

    bboxlist = []
    # Even entries hold class scores: turn them into probabilities.
    for i in range(len(olist) // 2):
        olist[i * 2] = F.softmax(olist[i * 2], dim=1)
    olist = [oelem.data.cpu() for oelem in olist]
    # Walk the (scores, regression) pairs, one pair per detection scale.
    for i in range(len(olist) // 2):
        ocls, oreg = olist[i * 2], olist[i * 2 + 1]
        FB, FC, FH, FW = ocls.size()  # feature map size
        stride = 2**(i + 2)    # 4,8,16,32,64,128
        anchor = stride * 4
        # A location qualifies when ANY batch item passes the threshold there.
        poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
        for Iindex, hindex, windex in poss:
            # Anchor center of this feature-map cell in input-image pixels.
            axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
            # Scores and regressions are gathered for every batch item at this
            # location, not just for the item that triggered it (Iindex).
            score = ocls[:, 1, hindex, windex]
            loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
            # Square prior centered on the cell with side 4 * stride.
            priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
            variances = [0.1, 0.2]
            box = batch_decode(loc, priors, variances)
            box = box[:, 0] * 1.0
            # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
            bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
    bboxlist = np.array(bboxlist)
    if 0 == len(bboxlist):
        # Placeholder entry so callers always receive a 3-D array.
        bboxlist = np.zeros((1, BB, 5))

    return bboxlist

def flip_detect(net, img, device):
    """Detect on the horizontally mirrored image and map boxes back.

    The x coordinates are reflected around the image width (swapping x1 and
    x2); y coordinates and scores are copied unchanged.
    """
    mirrored = cv2.flip(img, 1)
    dets = detect(net, mirrored, device)
    width = mirrored.shape[1]

    restored = np.zeros(dets.shape)
    restored[:, 0] = width - dets[:, 2]
    restored[:, 2] = width - dets[:, 0]
    for column in (1, 3, 4):
        restored[:, column] = dets[:, column]
    return restored


def pts_to_bb(pts):
    """Return the bounding box ``[min_x, min_y, max_x, max_y]`` of 2-D points."""
    lower = np.min(pts, axis=0)
    upper = np.max(pts, axis=0)
    min_x, min_y = lower
    max_x, max_y = upper
    return np.array([min_x, min_y, max_x, max_y])


================================================
FILE: face_detection/detection/sfd/net_s3fd.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class L2Norm(nn.Module):
    """Channel-wise L2 normalization with a learnable per-channel scale.

    Every spatial position is divided by the L2 norm taken over the channel
    axis (plus ``eps``) and multiplied by ``weight``, which starts at ``scale``.
    """

    def __init__(self, n_channels, scale=1.0):
        super(L2Norm, self).__init__()
        self.n_channels = n_channels
        self.scale = scale
        self.eps = 1e-10
        # Build the weight from a defined value. Zeroing an uninitialized
        # tensor by multiplication would keep any NaN/inf found in that memory.
        self.weight = nn.Parameter(torch.full((self.n_channels,), float(self.scale)))

    def forward(self, x):
        # Norm over the channel axis; eps guards against division by zero.
        norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
        x = x / norm * self.weight.view(1, -1, 1, 1)
        return x


class s3fd(nn.Module):
    """Convolutional face-detection network with six detection heads.

    A stack of 3x3 convolutions (conv1_1 .. conv5_3) followed by ``fc6``/``fc7``
    and two extra stride-2 stages (conv6_x, conv7_x) produces six feature
    maps. Each feature map feeds a classification ("conf") and a box
    regression ("loc") convolution.

    NOTE(review): the attribute names are presumably the keys of the
    pretrained checkpoint loaded elsewhere -- do not rename them without
    checking the weights file.
    """

    def __init__(self):
        super(s3fd, self).__init__()
        # Backbone convolutions, grouped by stage.
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        # "Fully connected" layers expressed as convolutions.
        # NOTE(review): fc6 uses padding=3 with a 3x3 kernel and no dilation,
        # which enlarges the feature map -- confirm this matches the weights.
        self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
        self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)

        # Extra stages: 1x1 channel reduction followed by a stride-2 3x3 conv.
        self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
        self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)

        self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
        self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)

        # L2 normalisation (with learnable scale) for the three shallowest heads.
        self.conv3_3_norm = L2Norm(256, scale=10)
        self.conv4_3_norm = L2Norm(512, scale=8)
        self.conv5_3_norm = L2Norm(512, scale=5)

        # Detection heads. The first classification head emits 4 channels,
        # which forward() reduces to 2; all other heads emit 2 directly.
        self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
        self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
        self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
        self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
        self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
        self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)

        self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
        self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
        self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
        self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
        self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
        self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Compute the raw outputs of all six detection heads.

        Arguments:
            x {torch.tensor} -- input batch with 3 channels (conv1_1 expects 3)

        Returns:
            list -- ``[cls1, reg1, ..., cls6, reg6]``; every ``cls`` has 2
            channels and every ``reg`` has 4 channels
        """
        h = F.relu(self.conv1_1(x))
        h = F.relu(self.conv1_2(h))
        h = F.max_pool2d(h, 2, 2)

        h = F.relu(self.conv2_1(h))
        h = F.relu(self.conv2_2(h))
        h = F.max_pool2d(h, 2, 2)

        h = F.relu(self.conv3_1(h))
        h = F.relu(self.conv3_2(h))
        h = F.relu(self.conv3_3(h))
        f3_3 = h  # tapped after two 2x2 max-pools
        h = F.max_pool2d(h, 2, 2)

        h = F.relu(self.conv4_1(h))
        h = F.relu(self.conv4_2(h))
        h = F.relu(self.conv4_3(h))
        f4_3 = h  # tapped after three 2x2 max-pools
        h = F.max_pool2d(h, 2, 2)

        h = F.relu(self.conv5_1(h))
        h = F.relu(self.conv5_2(h))
        h = F.relu(self.conv5_3(h))
        f5_3 = h  # tapped after four 2x2 max-pools
        h = F.max_pool2d(h, 2, 2)

        h = F.relu(self.fc6(h))
        h = F.relu(self.fc7(h))
        ffc7 = h
        h = F.relu(self.conv6_1(h))
        h = F.relu(self.conv6_2(h))
        f6_2 = h
        h = F.relu(self.conv7_1(h))
        h = F.relu(self.conv7_2(h))
        f7_2 = h

        # Only the three shallowest feature maps are L2-normalised.
        f3_3 = self.conv3_3_norm(f3_3)
        f4_3 = self.conv4_3_norm(f4_3)
        f5_3 = self.conv5_3_norm(f5_3)

        cls1 = self.conv3_3_norm_mbox_conf(f3_3)
        reg1 = self.conv3_3_norm_mbox_loc(f3_3)
        cls2 = self.conv4_3_norm_mbox_conf(f4_3)
        reg2 = self.conv4_3_norm_mbox_loc(f4_3)
        cls3 = self.conv5_3_norm_mbox_conf(f5_3)
        reg3 = self.conv5_3_norm_mbox_loc(f5_3)
        cls4 = self.fc7_mbox_conf(ffc7)
        reg4 = self.fc7_mbox_loc(ffc7)
        cls5 = self.conv6_2_mbox_conf(f6_2)
        reg5 = self.conv6_2_mbox_loc(f6_2)
        cls6 = self.conv7_2_mbox_conf(f7_2)
        reg6 = self.conv7_2_mbox_loc(f7_2)

        # max-out background label
        # cls1 has 4 channels: the element-wise maximum of the first three
        # becomes channel 0 and the fourth is kept as channel 1.
        chunk = torch.chunk(cls1, 4, 1)
        bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
        cls1 = torch.cat([bmax, chunk[3]], dim=1)

        return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]


================================================
FILE: face_detection/detection/sfd/s3fd.pth
================================================
[File too large to display: 85.7 MB]

================================================
FILE: face_detection/detection/sfd/sfd_detector.py
================================================
import os
import cv2
from torch.utils.model_zoo import load_url

from ..core import FaceDetector

from .net_s3fd import s3fd
from .bbox import *
from .detect import *

# Download locations of pretrained weights, keyed by detector name.
# SFDDetector falls back to these when no local checkpoint file is found.
models_urls = {
    's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
}


class SFDDetector(FaceDetector):
    """Face detector backed by the ``s3fd`` network.

    Arguments:
        device -- device the network is moved to and on which detection runs
        path_to_detector {str} -- checkpoint file; when it does not exist the
            weights are downloaded from ``models_urls['s3fd']``
            (default: ``s3fd.pth`` next to this module)
        verbose {bool} -- forwarded to the ``FaceDetector`` base class
    """

    def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
        super(SFDDetector, self).__init__(device, verbose)

        # Initialise the face detector.
        # The weights are always deserialised onto the CPU: a checkpoint that
        # was saved from a CUDA device would otherwise fail to load on a
        # CPU-only host. The model is moved to ``device`` right below.
        if not os.path.isfile(path_to_detector):
            model_weights = load_url(models_urls['s3fd'], map_location='cpu')
        else:
            model_weights = torch.load(path_to_detector, map_location='cpu')

        self.face_detector = s3fd()
        self.face_detector.load_state_dict(model_weights)
        self.face_detector.to(device)
        self.face_detector.eval()

    def detect_from_image(self, tensor_or_path):
        """Detect faces in a single image.

        Candidates are filtered with ``nms`` (threshold 0.3) and only boxes
        whose last element (the score) exceeds 0.5 are returned.

        Returns:
            list -- one array per kept detection
        """
        image = self.tensor_or_path_to_ndarray(tensor_or_path)

        bboxlist = detect(self.face_detector, image, device=self.device)
        keep = nms(bboxlist, 0.3)
        bboxlist = bboxlist[keep, :]
        bboxlist = [x for x in bboxlist if x[-1] > 0.5]

        return bboxlist

    def detect_from_batch(self, images):
        """Detect faces in a batch of images.

        ``batch_detect`` returns an array indexed as
        ``[candidate, image, box]``; NMS (threshold 0.3) and the 0.5 score
        filter are applied to every image independently.

        Returns:
            list -- one list of kept detections per image
        """
        bboxlists = batch_detect(self.face_detector, images, device=self.device)
        keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
        bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
        bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]

        return bboxlists

    @property
    def reference_scale(self):
        """Constant reference scale reported for this detector (195)."""
        return 195

    @property
    def reference_x_shift(self):
        """Constant horizontal reference shift reported for this detector (0)."""
        return 0

    @property
    def reference_y_shift(self):
        """Constant vertical reference shift reported for this detector (0)."""
        return 0


================================================
FILE: face_detection/models.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
    """Create a 3x3 ``nn.Conv2d``.

    Arguments:
        in_planes {int} -- number of input channels
        out_planes {int} -- number of output channels

    Keyword Arguments:
        strd {int} -- convolution stride (default: {1})
        padding {int} -- zero padding on every side (default: {1})
        bias {bool} -- whether the layer has a bias term (default: {False})
    """
    options = dict(kernel_size=3, stride=strd, padding=padding, bias=bias)
    return nn.Conv2d(in_planes, out_planes, **options)


class ConvBlock(nn.Module):
    """Residual block that concatenates three successive 3x3 conv branches.

    The branches produce ``out_planes / 2``, ``out_planes / 4`` and
    ``out_planes / 4`` channels; their concatenation is added to the
    (optionally projected) input.

    Arguments:
        in_planes {int} -- number of input channels
        out_planes {int} -- number of output channels
    """

    def __init__(self, in_planes, out_planes):
        super(ConvBlock, self).__init__()
        half_planes = int(out_planes / 2)
        quarter_planes = int(out_planes / 4)

        # Pre-activation layout: BatchNorm -> ReLU -> conv for every branch.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, half_planes)
        self.bn2 = nn.BatchNorm2d(half_planes)
        self.conv2 = conv3x3(half_planes, quarter_planes)
        self.bn3 = nn.BatchNorm2d(quarter_planes)
        self.conv3 = conv3x3(quarter_planes, quarter_planes)

        # A 1x1 projection is only needed when the channel count changes.
        if in_planes == out_planes:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.BatchNorm2d(in_planes),
                nn.ReLU(True),
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=1, bias=False),
            )

    def forward(self, x):
        """Return the concatenated branch outputs plus the skip connection."""
        branch1 = self.conv1(F.relu(self.bn1(x), True))
        branch2 = self.conv2(F.relu(self.bn2(branch1), True))
        branch3 = self.conv3(F.relu(self.bn3(branch2), True))

        merged = torch.cat((branch1, branch2, branch3), 1)

        shortcut = x if self.downsample is None else self.downsample(x)
        merged += shortcut

        return merged


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions.

    The last convolution widens the features to ``planes * 4`` channels.

    Arguments:
        inplanes {int} -- number of input channels
        planes {int} -- number of channels of the two inner convolutions

    Keyword Arguments:
        stride {int} -- stride of the 3x3 convolution (default: {1})
        downsample {nn.Module} -- optional projection applied to the input
            before it is added to the block output (default: {None})
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        widened = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three conv stages and add the skip connection."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut

        return self.relu(out)


class HourGlass(nn.Module):
    """Recursive hourglass module built from ``ConvBlock`` units.

    Each level has an upper branch at the current resolution and a lower
    branch that works on a 2x average-pooled copy, recurses, and is upsampled
    back before both branches are summed.

    Arguments:
        num_modules {int} -- stored on the instance, not used by this class
        depth {int} -- number of recursion levels
        num_features {int} -- channel count of every ``ConvBlock``
    """

    def __init__(self, num_modules, depth, num_features):
        super(HourGlass, self).__init__()
        self.num_modules = num_modules
        self.depth = depth
        self.features = num_features

        self._generate_network(self.depth)

    def _generate_network(self, level):
        """Register the blocks of ``level`` and, recursively, of all lower levels."""
        tag = str(level)
        width = self.features

        self.add_module('b1_' + tag, ConvBlock(width, width))
        self.add_module('b2_' + tag, ConvBlock(width, width))

        if level > 1:
            self._generate_network(level - 1)
        else:
            # Innermost level: an extra block replaces the recursion.
            self.add_module('b2_plus_' + tag, ConvBlock(width, width))

        self.add_module('b3_' + tag, ConvBlock(width, width))

    def _forward(self, level, inp):
        """Evaluate the hourglass from ``level`` downwards on ``inp``."""
        tag = str(level)

        # Upper branch keeps the input resolution.
        upper = getattr(self, 'b1_' + tag)(inp)

        # Lower branch operates at half resolution.
        lower = getattr(self, 'b2_' + tag)(F.avg_pool2d(inp, 2, stride=2))
        if level > 1:
            lower = self._forward(level - 1, lower)
        else:
            lower = getattr(self, 'b2_plus_' + tag)(lower)
        lower = getattr(self, 'b3_' + tag)(lower)

        restored = F.interpolate(lower, scale_factor=2, mode='nearest')

        return upper + restored

    def forward(self, x):
        """Run the full-depth hourglass on ``x``."""
        return self._forward(self.depth, x)


class FAN(nn.Module):
    """Stacked-hourglass network that predicts 68 heatmaps per stack.

    Keyword Arguments:
        num_modules {int} -- number of stacked hourglass stages (default: {1})
    """

    def __init__(self, num_modules=1):
        super(FAN, self).__init__()
        self.num_modules = num_modules

        # Stem that brings the image to 256 feature channels.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = ConvBlock(64, 128)
        self.conv3 = ConvBlock(128, 128)
        self.conv4 = ConvBlock(128, 256)

        # One group of modules per hourglass stage, named with the stage index.
        final_stage = self.num_modules - 1
        for stage in range(self.num_modules):
            suffix = str(stage)
            self.add_module('m' + suffix, HourGlass(1, 4, 256))
            self.add_module('top_m_' + suffix, ConvBlock(256, 256))
            self.add_module('conv_last' + suffix,
                            nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
            self.add_module('bn_end' + suffix, nn.BatchNorm2d(256))
            self.add_module('l' + suffix,
                            nn.Conv2d(256, 68, kernel_size=1, stride=1, padding=0))

            # Every stage but the last feeds features and heatmaps forward.
            if stage < final_stage:
                self.add_module('bl' + suffix,
                                nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
                self.add_module('al' + suffix,
                                nn.Conv2d(68, 256, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        """Return a list with the heatmaps predicted by every stage."""
        x = F.relu(self.bn1(self.conv1(x)), True)
        x = F.avg_pool2d(self.conv2(x), 2, stride=2)
        x = self.conv4(self.conv3(x))

        previous = x
        outputs = []
        final_stage = self.num_modules - 1
        for stage in range(self.num_modules):
            suffix = str(stage)

            features = getattr(self, 'm' + suffix)(previous)
            features = getattr(self, 'top_m_' + suffix)(features)
            features = getattr(self, 'conv_last' + suffix)(features)
            features = F.relu(getattr(self, 'bn_end' + suffix)(features), True)

            # Predict heatmaps
            heatmaps = getattr(self, 'l' + suffix)(features)
            outputs.append(heatmaps)

            if stage < final_stage:
                remapped_features = getattr(self, 'bl' + suffix)(features)
                remapped_heatmaps = getattr(self, 'al' + suffix)(heatmaps)
                previous = previous + remapped_features + remapped_heatmaps

        return outputs


class ResNetDepth(nn.Module):
    """ResNet-style regressor taking a 71-channel (3 + 68) input.

    Keyword Arguments:
        block -- residual block class; must expose ``expansion``
            (default: {Bottleneck})
        layers {list of integers} -- number of blocks in each of the four
            stages (default: {[3, 8, 36, 3]})
        num_classes {int} -- size of the output vector (default: {68})
    """

    def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
        super(ResNetDepth, self).__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        stage_planes = (64, 128, 256, 512)
        stage_strides = (1, 2, 2, 2)
        for index, (planes, stride) in enumerate(zip(stage_planes, stage_strides)):
            stage = self._make_layer(block, planes, layers[index], stride=stride)
            setattr(self, 'layer{}'.format(index + 1), stage)

        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight initialisation: normal with std sqrt(2 / (k_h * k_w * out))
        # for convolutions, identity affine transform for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage made of ``blocks`` residual blocks."""
        out_planes = planes * block.expansion

        # The first block needs a projection when it changes resolution or width.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Run stem, the four stages, pooling and the final linear layer."""
        pipeline = (self.conv1, self.bn1, self.relu, self.maxpool,
                    self.layer1, self.layer2, self.layer3, self.layer4,
                    self.avgpool)
        for step in pipeline:
            x = step(x)

        flattened = x.view(x.size(0), -1)
        return self.fc(flattened)


================================================
FILE: face_detection/utils.py
================================================
from __future__ import print_function
import os
import sys
import time
import torch
import math
import numpy as np
import cv2


def _gaussian(
        size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
        height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
        mean_vert=0.5):
    # handle some defaults
    if width is None:
        width = size
    if height is None:
        height = size
    if sigma_horz is None:
        sigma_horz = sigma
    if sigma_vert is None:
        sigma_vert = sigma
    center_x = mean_horz * width + 0.5
    center_y = mean_vert * height + 0.5
    gauss = np.empty((height, width), dtype=np.float32)
    # generate kernel
    for i in range(height):
        for j in range(width):
            gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
                sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
    if normalize:
        gauss = gauss / np.sum(gauss)
    return gauss


def draw_gaussian(image, point, sigma):
    """Add a gaussian blob centred on ``point`` to ``image`` in place.

    The blob is a ``(6 * sigma + 1)``-sized kernel from ``_gaussian``; the
    part that falls inside the image is added to it and the result is
    clipped to a maximum of 1. All box coordinates below are 1-based, hence
    the ``- 1`` on every slice start.

    Arguments:
        image {numpy.array} -- 2D map that is modified and returned
        point -- (x, y) position of the blob centre
        sigma -- blob radius parameter; ``6 * sigma + 1`` is used as an array
            dimension, so it presumably has to be an integer -- TODO confirm

    Returns:
        numpy.array -- the same ``image`` object
    """
    # Check if the gaussian is inside
    # ul / br: upper-left and bottom-right corners of the blob's bounding box.
    ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
    br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
    if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
        # Box lies entirely outside the image: nothing to draw.
        return image
    size = 6 * sigma + 1
    g = _gaussian(size)
    # g_x / g_y: [start, end] range of the kernel that overlaps the image.
    g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
    g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
    # img_x / img_y: [start, end] range of the image that receives the kernel.
    img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
    img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
    # NOTE(review): this checks g_x[0] and g_y[1]; g_y[0] may have been
    # intended -- confirm before relying on it.
    assert (g_x[0] > 0 and g_y[1] > 0)
    image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
          ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
    # Clip overlapping blobs to the maximum heatmap value.
    image[image > 1] = 1
    return image


def transform(point, center, scale, resolution, invert=False):
    """Map a 2D point through the crop's affine transformation.

    The transformation scales by ``resolution / (200 * scale)`` and
    translates so that ``center`` lands in the middle of a
    ``resolution``-sized output. With ``invert`` set, the inverse mapping is
    applied instead.

    Arguments:
        point {torch.tensor} -- the input 2D point
        center {torch.tensor or numpy.array} -- the center around which to perform the transformations
        scale {float} -- the scale of the face/object
        resolution {float} -- the output resolution

    Keyword Arguments:
        invert {bool} -- whether to apply the inverse transformation
        (default: {False})

    Returns:
        torch.tensor -- the transformed point, truncated to integers
    """
    # Homogeneous coordinates of the point.
    homogeneous = torch.ones(3)
    homogeneous[0] = point[0]
    homogeneous[1] = point[1]

    box_size = 200.0 * scale
    zoom = resolution / box_size

    matrix = torch.eye(3)
    matrix[0, 0] = zoom
    matrix[1, 1] = zoom
    matrix[0, 2] = resolution * (-center[0] / box_size + 0.5)
    matrix[1, 2] = resolution * (-center[1] / box_size + 0.5)

    if invert:
        matrix = torch.inverse(matrix)

    mapped = torch.matmul(matrix, homogeneous)
    return mapped[0:2].int()


def crop(image, center, scale, resolution=256.0):
    """Center-crop an image (or a 2D map) and resize it to a square.

    The crop window is obtained by mapping the corners of the output square
    back into the input with the inverse of ``transform``. Parts of the
    window that fall outside the input are left zero-filled.

    Arguments:
        image {numpy.array} -- input of shape (H, W, C) or (H, W)
        center {numpy.array} -- the center of the object, usually the same as of the bounding box
        scale {float} -- scale of the face

    Keyword Arguments:
        resolution {float} -- the size of the output cropped image (default: {256.0})

    Returns:
        numpy.array -- uint8 crop resized to ``resolution`` x ``resolution``
    """
    ul = transform([1, 1], center, scale, resolution, True)
    br = transform([resolution, resolution], center, scale, resolution, True)
    # Plain Python ints keep the index arithmetic below free of tensor types.
    ul_x, ul_y = int(ul[0]), int(ul[1])
    br_x, br_y = int(br[0]), int(br[1])

    # The previous 2D branch used ``np.int`` (removed from NumPy); a shape
    # list built from ints works for both the 2D and the 3D case.
    new_shape = [br_y - ul_y, br_x - ul_x]
    if image.ndim > 2:
        new_shape.append(image.shape[2])
    newImg = np.zeros(new_shape, dtype=np.uint8)

    ht = image.shape[0]
    wd = image.shape[1]
    # 1-based [start, end] ranges: new* index the crop, old* index the input.
    newX = (max(1, -ul_x + 1), min(br_x, wd) - ul_x)
    newY = (max(1, -ul_y + 1), min(br_y, ht) - ul_y)
    oldX = (max(1, ul_x + 1), min(br_x, wd))
    oldY = (max(1, ul_y + 1), min(br_y, ht))

    # No trailing ``, :`` on the right-hand side: it raised IndexError for
    # 2D inputs and is a no-op for 3D ones.
    newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
           ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1]]
    newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
                        interpolation=cv2.INTER_LINEAR)
    return newImg


def get_preds_fromhm(hm, center=None, scale=None):
    """Obtain (x,y) coordinates given a set of N heatmaps. If the center
    and the scale is provided the function will return the points also in
    the original coordinate frame.

    The arg-max of every heatmap is refined by a quarter pixel towards the
    larger neighbour on each axis; peaks on the heatmap border are not
    refined because they lack a neighbour on one side.

    Arguments:
        hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, H, W]

    Keyword Arguments:
        center {torch.tensor} -- the center of the bounding box (default: {None})
        scale {float} -- face scale (default: {None})

    Returns:
        tuple -- ``(preds, preds_orig)``, both of shape [B, N, 2];
        ``preds_orig`` stays all-zero unless center and scale are given
    """
    batch, n_points, height, width = hm.size()
    _, idx = torch.max(hm.view(batch, n_points, height * width), 2)
    idx += 1
    preds = idx.view(batch, n_points, 1).repeat(1, 1, 2).float()
    # Decode the 1-based flat index (row-major) into 1-based (x, y).
    # Both the column and the row are derived from the row length (width),
    # so non-square heatmaps decode correctly, and plain tensor arithmetic
    # avoids the CPU-only ``apply_``.
    preds[..., 0].sub_(1).remainder_(width).add_(1)
    preds[..., 1].sub_(1).div_(width).floor_().add_(1)

    for i in range(batch):
        for j in range(n_points):
            hm_ = hm[i, j, :]
            pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
            # Bounds follow the actual heatmap size instead of a fixed 64x64.
            if 0 < pX < width - 1 and 0 < pY < height - 1:
                diff = torch.tensor(
                    [float(hm_[pY, pX + 1] - hm_[pY, pX - 1]),
                     float(hm_[pY + 1, pX] - hm_[pY - 1, pX])],
                    dtype=preds.dtype, device=preds.device)
                preds[i, j].add_(diff.sign_().mul_(.25))

    preds.add_(-.5)

    preds_orig = torch.zeros(preds.size())
    if center is not None and scale is not None:
        for i in range(batch):
            for j in range(n_points):
                preds_orig[i, j] = transform(
                    preds[i, j], center, scale, hm.size(2), True)

    return preds, preds_orig

def get_preds_fromhm_batch(hm, centers=None, scales=None):
    """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
    and the scales is provided the function will return the points also in
    the original coordinate frame.

    The arg-max of every heatmap is refined by a quarter pixel towards the
    larger neighbour on each axis; peaks on the heatmap border are not
    refined because they lack a neighbour on one side.

    Arguments:
        hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, H, W]

    Keyword Arguments:
        centers {torch.tensor} -- the centers of the bounding box, one per
        batch element (default: {None})
        scales {float} -- face scales, one per batch element (default: {None})

    Returns:
        tuple -- ``(preds, preds_orig)``, both of shape [B, N, 2];
        ``preds_orig`` stays all-zero unless centers and scales are given
    """
    batch, n_points, height, width = hm.size()
    _, idx = torch.max(hm.view(batch, n_points, height * width), 2)
    idx += 1
    preds = idx.view(batch, n_points, 1).repeat(1, 1, 2).float()
    # Decode the 1-based flat index (row-major) into 1-based (x, y).
    # Both the column and the row are derived from the row length (width),
    # so non-square heatmaps decode correctly, and plain tensor arithmetic
    # avoids the CPU-only ``apply_``.
    preds[..., 0].sub_(1).remainder_(width).add_(1)
    preds[..., 1].sub_(1).div_(width).floor_().add_(1)

    for i in range(batch):
        for j in range(n_points):
            hm_ = hm[i, j, :]
            pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
            # Bounds follow the actual heatmap size instead of a fixed 64x64.
            if 0 < pX < width - 1 and 0 < pY < height - 1:
                diff = torch.tensor(
                    [float(hm_[pY, pX + 1] - hm_[pY, pX - 1]),
                     float(hm_[pY + 1, pX] - hm_[pY - 1, pX])],
                    dtype=preds.dtype, device=preds.device)
                preds[i, j].add_(diff.sign_().mul_(.25))

    preds.add_(-.5)

    preds_orig = torch.zeros(preds.size())
    if centers is not None and scales is not None:
        for i in range(batch):
            for j in range(n_points):
                # Each batch element has its own center and scale.
                preds_orig[i, j] = transform(
                    preds[i, j], centers[i], scales[i], hm.size(2), True)

    return preds, preds_orig

def shuffle_lr(parts, pairs=None):
    """Reorder the point channels so that left and right points swap places.

    Arguments:
        parts {torch.tensor} -- a 3D or 4D object containing the
        heatmaps; the points live on dim 0 (3D) or dim 1 (otherwise).

    Keyword Arguments:
        pairs {list of integers} -- source index for every output position;
        defaults to the built-in 68-point ordering (default: {None})
    """
    if pairs is None:
        pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
                 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
                 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
                 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
                 62, 61, 60, 67, 66, 65]

    # Without a batch dimension the points are on the leading axis.
    has_batch_dim = parts.ndimension() != 3
    return parts[:, pairs, ...] if has_batch_dim else parts[pairs, ...]


def flip(tensor, is_label=False):
    """Mirror an image or a set of heatmaps along the last dimension.

    Arguments:
        tensor {numpy.array or torch.tensor} -- the input image or heatmaps;
        numpy input is converted with ``torch.from_numpy``

    Keyword Arguments:
        is_label {bool} -- when True the input is treated as heatmaps and its
        channels are additionally reordered with ``shuffle_lr`` (default: {False})

    Returns:
        torch.tensor -- the flipped data
    """
    data = tensor if torch.is_tensor(tensor) else torch.from_numpy(tensor)
    last_axis = data.ndimension() - 1

    if is_label:
        # Heatmaps also need their left/right channels exchanged.
        data = shuffle_lr(data)

    return data.flip(last_axis)

# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)


def appdata_dir(appname=None, roaming=False):
    """ appdata_dir(appname=None, roaming=False)

    Return the directory in which applications may store user specific
    files (e.g. configurations).
    If appname is given, a subdir is appended (and created if necessary).
    If roaming is True, the roaming directory is preferred on Windows.

    The base user directory can be overridden with the
    FACEALIGNMENT_USERDIR environment variable. A writable ``settings``
    directory next to the interpreter prefix (or next to a frozen
    executable) takes precedence over the system location.
    """

    # Define default user directory
    user_home = os.getenv('FACEALIGNMENT_USERDIR', None)
    if user_home is None:
        user_home = os.path.expanduser('~')
        if not os.path.isdir(user_home):  # pragma: no cover
            user_home = '/var/tmp'  # issue #54

    # Get system app data dir
    platform = sys.platform
    base = None
    if platform.startswith('win'):
        local_data = os.getenv('LOCALAPPDATA')
        roaming_data = os.getenv('APPDATA')
        if roaming:
            base = roaming_data or local_data
        else:
            base = local_data or roaming_data
    elif platform.startswith('darwin'):
        base = os.path.join(user_home, 'Library', 'Application Support')
    # On Linux and as fallback
    if not base or not os.path.isdir(base):
        base = user_home

    # A portable distribution or frozen application may keep its settings
    # next to the executable; use that directory when it is writable.
    if getattr(sys, 'frozen', None):
        install_root = os.path.abspath(os.path.dirname(sys.executable))
    else:
        install_root = sys.prefix
    for relative in ('settings', '../settings'):
        candidate = os.path.abspath(os.path.join(install_root, relative))
        if not os.path.isdir(candidate):
            continue
        probe = os.path.join(candidate, 'test.write')  # pragma: no cover
        try:
            open(probe, 'wb').close()
            os.remove(probe)
        except IOError:
            continue  # We cannot write in this directory
        base = candidate
        break

    # Get path specific for this app
    if appname:
        if base == user_home:
            appname = '.' + appname.lstrip('.')  # Make it a hidden directory
        base = os.path.join(base, appname)
        if not os.path.isdir(base):  # pragma: no cover
            os.mkdir(base)

    # Done
    return base


================================================
FILE: face_parsing/README.md
================================================
Most of the code in this folder was taken from the awesome [face parsing](https://github.com/zllrunning/face-parsing.PyTorch.git) repository.

================================================
FILE: face_parsing/__init__.py
================================================
from .swap import init_parser,swap_regions

================================================
FILE: face_parsing/model.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from .resnet import Resnet18
# from modules.bn import InPlaceABNSync as BatchNorm2d


class ConvBNReLU(nn.Module):
    """Bias-free convolution followed by batch normalisation and ReLU.

    Extra positional and keyword arguments are accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        conv_options = dict(kernel_size=ks, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(in_chan, out_chan, **conv_options)
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU."""
        return F.relu(self.bn(self.conv(x)))

    def init_weight(self):
        """Kaiming-normal (a=1) init for direct Conv2d children; biases set to 0."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

class BiSeNetOutput(nn.Module):
    """Output head: a 3x3 ``ConvBNReLU`` followed by a 1x1 classifier conv.

    Extra positional and keyword arguments are accepted and ignored.
    """

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        """Return per-class logits with ``n_classes`` channels."""
        return self.conv_out(self.conv(x))

    def init_weight(self):
        """Kaiming-normal (a=1) init for direct Conv2d children; biases set to 0."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def get_params(self):
        """Split parameters into weight-decay and no-weight-decay groups.

        Linear/Conv2d weights go to the first list; their biases and all
        BatchNorm2d parameters go to the second.
        """
        wd_params = []
        nowd_params = []
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params.extend(module.parameters())
        return wd_params, nowd_params


class AttentionRefinementModule(nn.Module):
    """Re-weight features with a per-channel attention vector.

    The attention is computed from the globally average-pooled features via
    a 1x1 convolution, batch normalisation and a sigmoid. Extra positional
    and keyword arguments are accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        """Return the conv features scaled by their channel attention."""
        feat = self.conv(x)
        # Global average pooling: the kernel covers the whole spatial extent.
        pooled = F.avg_pool2d(feat, feat.size()[2:])
        attention = self.sigmoid_atten(self.bn_atten(self.conv_atten(pooled)))
        return torch.mul(feat, attention)

    def init_weight(self):
        """Kaiming-normal (a=1) init for direct Conv2d children; biases set to 0."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)


class ContextPath(nn.Module):
    """ResNet-18 backbone with attention refinement of its two deepest feature maps.

    Args:
        device: forwarded unchanged to ``Resnet18``.
    """

    def __init__(self, device, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18(device)
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        """Return ``(feat8, feat16_up, feat32_up)``.

        ``feat8`` is the backbone's first returned map, untouched.
        ``feat32_up`` is the refined deepest map resized to ``feat16``'s
        spatial size; ``feat16_up`` is the refined middle map (summed with
        ``feat32_up``) resized to ``feat8``'s spatial size.
        """
        feat8, feat16, feat32 = self.resnet(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        # Global context: pool feat32 down to 1x1, project to 128 channels,
        # then expand back to feat32's spatial size so it can be added.
        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16

    def init_weight(self):
        """Kaiming-initialise direct ``nn.Conv2d`` children; zero any bias.

        NOTE(review): none of the attributes assigned in ``__init__`` is built
        as a bare ``nn.Conv2d``, so this loop probably matches nothing --
        confirm that ``ConvBNReLU`` is not a ``Conv2d`` subclass.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Return ``(wd_params, nowd_params)`` over all sub-modules.

        Linear/Conv2d weights go into the first list; their biases and every
        BatchNorm2d parameter go into the second.
        """
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


### This is not used, since I replace this with the resnet feature with the same size
class SpatialPath(nn.Module):
    """Three stride-2 ConvBNReLU stages followed by a 1x1 ConvBNReLU to 128 channels."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        """Run the four stages in sequence and return the final feature map."""
        feat = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv_out):
            feat = stage(feat)
        return feat

    def init_weight(self):
        """Kaiming-initialise direct ``nn.Conv2d`` children; zero any bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Return ``(wd_params, nowd_params)``.

        Linear/Conv2d weights land in the first list; their biases and all
        BatchNorm2d parameters land in the second.
        """
        decayed, undecayed = [], []
        for _, mod in self.named_modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                decayed.append(mod.weight)
                if mod.bias is not None:
                    undecayed.append(mod.bias)
            elif isinstance(mod, nn.BatchNorm2d):
                undecayed.extend(mod.parameters())
        return decayed, undecayed


class FeatureFusionModule(nn.Module):
    """Fuse two feature maps by concatenation plus a channel-attention residual.

    The inputs are concatenated on the channel axis and projected with a 1x1
    ConvBNReLU. A gate is computed from the pooled result (1x1 conv to
    ``out_chan // 4``, ReLU, 1x1 conv back to ``out_chan``, sigmoid) and the
    output is ``feat * gate + feat``.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        reduced = out_chan // 4
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan, reduced,
                               kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(reduced, out_chan,
                               kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        """Return the fused features for the two inputs ``fsp`` and ``fcp``."""
        fused = self.convblk(torch.cat([fsp, fcp], dim=1))
        # Kernel equals the map size, so pooling leaves one value per channel.
        squeezed = F.avg_pool2d(fused, fused.size()[2:])
        gate = self.sigmoid(self.conv2(self.relu(self.conv1(squeezed))))
        return fused * gate + fused

    def init_weight(self):
        """Kaiming-initialise direct ``nn.Conv2d`` children; zero any bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Return ``(wd_params, nowd_params)``.

        Linear/Conv2d weights land in the first list; their biases and all
        BatchNorm2d parameters land in the second.
        """
        decayed, undecayed = [], []
        for _, mod in self.named_modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                decayed.append(mod.weight)
                if mod.bias is not None:
                    undecayed.append(mod.bias)
            elif isinstance(mod, nn.BatchNorm2d):
                undecayed.extend(mod.parameters())
        return decayed, undecayed


class BiSeNet(nn.Module):
    """Segmentation network: ContextPath, a FeatureFusionModule and three output heads.

    Args:
        device: forwarded to ``ContextPath``.
        n_classes: number of output channels produced by each head.
    """

    def __init__(self, device, n_classes, *args, **kwargs):
        super().__init__()
        self.cp = ContextPath(device)
        # No separate spatial path: the backbone feature is used in its place.
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        self.init_weight()

    def forward(self, x):
        """Return three class maps, each resized to the input's height and width.

        The first comes from the fused features; the other two come directly
        from the two context-path outputs.
        """
        height, width = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)
        # The backbone feature stands in for the spatial-path feature.
        fused = self.ffm(feat_res8, feat_cp8)

        heads = (
            self.conv_out(fused),
            self.conv_out16(feat_cp8),
            self.conv_out32(feat_cp16),
        )
        return tuple(
            F.interpolate(head, (height, width), mode='bilinear', align_corners=True)
            for head in heads
        )

    def init_weight(self):
        """Kaiming-initialise direct ``nn.Conv2d`` children; zero any bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Gather the children's parameter lists into four groups.

        Returns:
            ``(wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params)``.
            Parameters of FeatureFusionModule and BiSeNetOutput children go
            into the two ``lr_mul_*`` lists; all other children feed the
            first two lists.
        """
        base_wd, base_nowd, mul_wd, mul_nowd = [], [], [], []
        for _, child in self.named_children():
            child_wd, child_nowd = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                mul_wd += child_wd
                mul_nowd += child_nowd
            else:
                base_wd += child_wd
                base_nowd += child_nowd
        return base_wd, base_nowd, mul_wd, mul_nowd





================================================
FILE: face_parsing/resnet.py
================================================
#!/usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo

# from modules.bn import InPlaceABNSync as BatchNorm2d

# Download URL of the ResNet-18 checkpoint. Its file name matches the local
# checkpoint path that Resnet18.init_weight loads from ./checkpoints/.
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm stages added to a shortcut.

    The shortcut is the input itself, or a 1x1 conv + BatchNorm projection
    when the channel count or the stride changes.
    """

    def __init__(self, in_chan, out_chan, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        needs_projection = stride != 1 or in_chan != out_chan
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_chan, out_chan,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        """Return ``relu(shortcut(x) + residual(x))``."""
        branch = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        identity = x if self.downsample is None else self.downsample(x)
        return self.relu(identity + branch)


def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    """Stack ``bnum`` BasicBlocks into an ``nn.Sequential``.

    Only the first block changes the channel count and applies ``stride``;
    the remaining ``bnum - 1`` blocks keep ``out_chan`` channels at stride 1.
    """
    blocks = [BasicBlock(in_chan, out_chan, stride=stride)]
    blocks.extend(BasicBlock(out_chan, out_chan, stride=1) for _ in range(bnum - 1))
    return nn.Sequential(*blocks)


class Resnet18(nn.Module):
    """ResNet-18 feature extractor that returns its last three stage outputs.

    Args:
        device: device onto which the checkpoint tensors are mapped while
            loading (passed to ``torch.load`` as ``map_location``).
    """

    def __init__(self, device):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
        self.init_weight(device)

    def forward(self, x):
        """Return ``(feat8, feat16, feat32)``: the outputs of layer2, layer3 and layer4."""
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x)  # 1/8
        feat16 = self.layer3(feat8)  # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32

    def init_weight(self, device):
        """Load backbone weights from the local ResNet-18 checkpoint.

        The checkpoint path is relative to the current working directory.
        Entries whose key contains ``'fc'`` are skipped; this module defines
        no such layer.

        Args:
            device: ``map_location`` for ``torch.load``. Previously the
                argument was accepted but ignored, so the checkpoint was
                always restored to the devices recorded in the file.
        """
        print('load resnet18 model from dir')
        checkpoint_path = "./checkpoints/resnet18-5c106cde.pth"
        # Read from disk rather than downloading from resnet18_url.
        state_dict = torch.load(checkpoint_path, map_location=device)
        self_state_dict = self.state_dict()
        self_state_dict.update(
            {k: v for k, v in state_dict.items() if 'fc' not in k}
        )
        self.load_state_dict(self_state_dict)

    def get_params(self):
        """Return ``(wd_params, nowd_params)`` over all sub-modules.

        Linear/Conv2d weights go into the first list; their biases and every
        BatchNorm2d parameter go into the second.
        """
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


if __name__ == "__main__":
    # Smoke test: build the backbone, push a random batch through it and print
    # the three feature-map sizes. Resnet18 requires a device argument, so
    # calling it with none raises TypeError; the CPU is used here.
    net = Resnet18('cpu')
    x = torch.randn(16, 3, 224, 224)
    out = net(x)
    for feat in out[:3]:
        print(feat.size())
    net.get_params()


================================================
FILE: face_parsing/swap.py
================================================
import torch
import torchvision.transforms as transforms
import cv2
import numpy as np
import torch.nn.functional as F
from .model import BiSeNet
from torchvision.transforms.functional import normalize
from torchvision.transforms import Resize
def init_parser(pth_path, device):
    """Build a 19-class BiSeNet, load its weights from ``pth_path`` and return it in eval mode.

    Args:
        pth_path: path of the state-dict checkpoint to load.
        device: device the network is moved to and the checkpoint is mapped onto.
    """
    parser = BiSeNet(device, n_classes=19)
    parser.to(device)

    weights = torch.load(pth_path, map_location=device)
    parser.load_state_dict(weights)
    parser.eval().to(device)
    print('Parser model loaded')
    return parser

def image_to_parsing_img(img, net):
    """Return the 512x512 per-pixel class map of ``img`` as a NumPy array.

    This used to be a copy of ``image_to_parsing`` that moved the input with
    a hard-coded ``.cuda()`` call, which fails on CPU-only hosts or when
    ``net`` lives on another device. It now delegates to
    ``image_to_parsing``, which runs the same preprocessing and places the
    input on the device of ``net``'s parameters. The unused timing variables
    are gone.

    Args:
        img: image array accepted by ``cv2.resize`` (H x W x 3).
        net: parsing network; the first element of its output is used.
    """
    return image_to_parsing(img, net)


def image_to_parsing(img, net):
    """Return the 512x512 per-pixel class map of ``img`` as a NumPy array.

    The image is resized to 512x512, its channel order is reversed
    (presumably BGR -> RGB -- confirm against the callers), converted to a
    normalised tensor and fed to ``net`` on the device of its parameters.
    The arg-max over the channel axis of the first network output is returned.

    Args:
        img: image array accepted by ``cv2.resize`` (H x W x 3).
        net: parsing network with at least one parameter.
    """
    img = cv2.resize(img, (512, 512))
    img = img[:, :, ::-1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    # Materialise the reversed view before handing it to ToTensor.
    img = transform(img.copy())
    img = torch.unsqueeze(img, 0)
    device = next(net.parameters()).device
    with torch.no_grad():
        img = img.to(device)
        out = net(img)[0]
        parsing = out.squeeze(0).argmax(0)
        return parsing.cpu().numpy()


def image_to_parsing2(img, net):
    """Tensor-only variant: return the 512x512 class map of ``img`` as a tensor.

    ``img`` is converted to float32, divided by 255, normalised in place,
    given a batch axis and bilinearly resized to 512x512 before being fed to
    ``net``. The input is not moved between devices here.
    """
    scaled = img.to(dtype=torch.float32).div(255)
    # div() produced a fresh tensor, so the in-place normalise leaves the caller's data intact.
    normalize(scaled, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True)
    batch = F.interpolate(scaled.unsqueeze(0), (512, 512), mode='bilinear', align_corners=True)
    with torch.no_grad():
        logits = net(batch)[0]   #15ms
        return logits.squeeze(0).argmax(0)


def get_mask(parsing, classes):
    """Return a mask that is true wherever ``parsing`` equals any label in ``classes``.

    ``classes`` must be a non-empty indexable sequence of labels.
    """
    selected = parsing == classes[0]
    for label in classes[1:]:
        # In-place add of boolean masks accumulates the union.
        selected += (parsing == label)
    return selected

def swap_regions_img(source, target, net):
    """Paste the parsed regions 1, 11, 12 and 13 of ``target`` onto ``source``.

    Both images are processed at 512x512. ``source`` is parsed with
    ``image_to_parsing_img``; pixels whose label is in the list are taken
    from ``target`` and all others from ``source``.

    Returns:
        ``(result, mask)``: the blended ``uint8`` image and the 3-channel
        float mask, both 512x512.
    """
    parsing = image_to_parsing_img(source, net)  #13ms
    source = cv2.resize(source, (512, 512))
    face_classes = [1, 11, 12, 13]
    mask = get_mask(parsing, face_classes)
    mask = np.repeat(np.expand_dims(mask, axis=2), 3, 2)
    # np.float (an alias of the builtin float) was removed in NumPy 1.24.
    mask = mask.astype(np.float64)
    result = (1 - mask) * source + mask * cv2.resize(target, (512, 512))
    # NOTE(review): ``source`` was rebound to its 512x512 copy above, so this
    # resize keeps the output at 512x512, unlike swap_regions, which restores
    # the original size -- confirm callers expect that.
    result = cv2.resize(result.astype(np.uint8), (source.shape[1], source.shape[0]))
    return result, mask



def swap_regions(source, target, net):
    """Paste the parsed regions 1, 11, 12 and 13 of ``target`` onto ``source``.

    Blending happens at 512x512; the blended image and the 3-channel mask are
    then resized back to ``source``'s original width and height.

    Returns:
        ``(result, mask)``, both ``uint8`` arrays at ``source``'s size.
    """
    face_labels = [1, 11, 12, 13]
    region = get_mask(image_to_parsing(source, net), face_labels)  #13ms
    region = np.repeat(np.expand_dims(region, axis=2), 3, 2)

    src_512 = cv2.resize(source, (512, 512))
    tgt_512 = cv2.resize(target, (512, 512))
    blended = (1 - region) * src_512 + region * tgt_512

    out_size = (source.shape[1], source.shape[0])  # cv2 expects (width, height)
    result = cv2.resize(blended.astype(np.uint8), out_size)
    mask = cv2.resize(region.astype(np.uint8), out_size)
    return result, mask


================================================
FILE: filelists/train.txt
================================================
MEAD/M003-002
MEAD/M003-004
MEAD/M003-003
MEAD/M005-001
MEAD/M003-005
MEAD/M003-001


================================================
FILE: filelists/val.txt
================================================


================================================
FILE: filelists_lrs2/README.md
================================================
Place LRS2 (and any other) filelists here for training.

================================================
FILE: filelists_lrs2/test.txt
================================================
00001\00002

================================================
FILE: filelists_lrs2/train.txt
================================================
6251513448847229551/00029
6251513448847229551/00015
6251513448847229551/00003
6251513448847229551/00032
6251513448847229551/00009
6251513448847229551/00010
6251513448847229551/00033
6251513448847229551/00031
6251513448847229551/00013
6251513448847229551/00022
6251513448847229551/00026
6251513448847229551/00001
6251513448847229551/00018
6251513448847229551/00012
6251513448847229551/00014
6251513448847229551/00016
6251513448847229551/00011
6251513448847229551/00034
6251513448847229551/00002
6251513448847229551/00025
6251513448847229551/00004
5939542915844728475/00015
5939542915844728475/00006
5939542915844728475/00046
5939542915844728475/00010
5939542915844728475/00033
5939542915844728475/00013
5939542915844728475/00022
5939542915844728475/00026
5939542915844728475/00044
5939542915844728475/00050
5939542915844728475/00001
5939542915844728475/00023
5939542915844728475/00048
5939542915844728475/00042
5939542915844728475/00041
5939542915844728475/00021
5939542915844728475/00014
5939542915844728475/00039
5939542915844728475/00020
5939542915844728475/00035
5939542915844728475/00027
5939542915844728475/00034
5939542915844728475/00008
5939542915844728475/00040
5939542915844728475/00007
5939542915844728475/00043
6287945508935485444/00003
6287945508935485444/00006
6287945508935485444/00010
6287945508935485444/00012
6287945508935485444/00014
6287945508935485444/00016
6287945508935485444/00002
6287945508935485444/00005
6287945508935485444/00004
6101595038399976173/00015
6101595038399976173/00003
6101595038399976173/00009
6101595038399976173/00010
6101595038399976173/00013
6101595038399976173/00001
6101595038399976173/00012
6101595038399976173/00014
6101595038399976173/00016
6101595038399976173/00011
6101595038399976173/00002
6101595038399976173/00007
6101595038399976173/00005
6101595038399976173/00004
5944231731641774858/00029
5944231731641774858/00015
5944231731641774858/00003
5944231731641774858/00010
5944231731641774858/00013
5944231731641774858/00026
5944231731641774858/00023
5944231731641774858/00018
5944231731641774858/00021
5944231731641774858/00016
5944231731641774858/00019
5944231731641774858/00011
5944231731641774858/00002
5944231731641774858/00008
5944231731641774858/00004
5944231731641774858/00028
6077845587240023489/00003
6077845587240023489/00010
6077845587240023489/00013
6077845587240023489/00001
6077845587240023489/00008
6077845587240023489/00007
6077845587240023489/00005
6251200345731287347/00006
6251200345731287347/00010
6251200345731287347/00018
6251200345731287347/00012
6251200345731287347/00014
6251200345731287347/00011
6251200345731287347/00002
6251200345731287347/00008
6251200345731287347/00007
6251200345731287347/00005
6092441604098670412/00060
6092441604098670412/00029
6092441604098670412/00015
6092441604098670412/00003
6092441604098670412/00006
6092441604098670412/00140
6092441604098670412/00046
6092441604098670412/00107
6092441604098670412/00032
6092441604098670412/00113
6092441604098670412/00109
6092441604098670412/00122
6092441604098670412/00070
6092441604098670412/00090
6092441604098670412/00010
6092441604098670412/00033
6092441604098670412/00123
6092441604098670412/00024
6092441604098670412/00141
6092441604098670412/00053
6092441604098670412/00118
6092441604098670412/00116
6092441604098670412/00030
6092441604098670412/00071
6092441604098670412/00058
6092441604098670412/00121
6092441604098670412/00062
6092441604098670412/00142
6092441604098670412/00013
6092441604098670412/00137
6092441604098670412/00022
6092441604098670412/00096
6092441604098670412/00094
6092441604098670412/00026
6092441604098670412/00044
6092441604098670412/00130
6092441604098670412/00099
6092441604098670412/00120
6092441604098670412/00103
6092441604098670412/00050
6092441604098670412/00105
6092441604098670412/00063
6092441604098670412/00110
6092441604098670412/00001
6092441604098670412/00045
6092441604098670412/00064
6092441604098670412/00136
6092441604098670412/00049
6092441604098670412/00072
6092441604098670412/00114
6092441604098670412/00038
6092441604098670412/00048
6092441604098670412/00134
6092441604098670412/00108
6092441604098670412/00145
6092441604098670412/00018
6092441604098670412/00085
6092441604098670412/00135
6092441604098670412/00041
6092441604098670412/00126
6092441604098670412/00065
6092441604098670412/00129
6092441604098670412/00144
6092441604098670412/00083
6092441604098670412/00021
6092441604098670412/00075
6092441604098670412/00079
6092441604098670412/00014
6092441604098670412/00127
6092441604098670412/00101
6092441604098670412/00047
6092441604098670412/00039
6092441604098670412/00124
6092441604098670412/00020
6092441604098670412/00035
6092441604098670412/00051
6092441604098670412/00017
6092441604098670412/00076
6092441604098670412/00080
6092441604098670412/00091
6092441604098670412/00055
6092441604098670412/00016
6092441604098670412/00095
6092441604098670412/00138
6092441604098670412/00111
6092441604098670412/00059
6092441604098670412/00119
6092441604098670412/00027
6092441604098670412/00115
6092441604098670412/00034
6092441604098670412/00100
6092441604098670412/00002
6092441604098670412/00133
6092441604098670412/00106
6092441604098670412/00089
6092441604098670412/00092
6092441604098670412/00008
6092441604098670412/00131
6092441604098670412/00040
6092441604098670412/00007
6092441604098670412/00068
6092441604098670412/00143
6092441604098670412/00132
6092441604098670412/00139
6092441604098670412/00054
6092441604098670412/00043
6092441604098670412/00005
6092441604098670412/00004
6092441604098670412/00028
5898692622899070396/00003
5898692622899070396/00006
5898692622899070396/00009
5898692622899070396/00001
5898692622899070396/00012
5898692622899070396/00007
5898692622899070396/00004
6117277252487848891/00009
6117277252487848891/00002
6117277252487848891/00004
6238985458741516828/00003
6238985458741516828/00009
6238985458741516828/00010
6238985458741516828/00017
6238985458741516828/00016
6238985458741516828/00011
6238985458741516828/00008
6238985458741516828/00007
6091575738691860542/00003
6091575738691860542/00006
6091575738691860542/00009
6091575738691860542/00002
6091575738691860542/00007
6091575738691860542/00004
5983702051595260731/00003
5983702051595260731/00006
5983702051595260731/00009
5983702051595260731/00033
5983702051595260731/00024
5983702051595260731/00030
5983702051595260731/00013
5983702051595260731/00001
5983702051595260731/00023
5983702051595260731/00018
5983702051595260731/00012
5983702051595260731/00020
5983702051595260731/00035
5983702051595260731/00016
5983702051595260731/00019
5983702051595260731/00034
5983702051595260731/00002
5983702051595260731/00008
5983702051595260731/00007
5983702051595260731/00005
5983702051595260731/00004
5983702051595260731/00028
6209650402613110582/00003
6209650402613110582/00032
6209650402613110582/00009
6209650402613110582/00010
6209650402613110582/00033
6209650402613110582/00024
6209650402613110582/00030
6209650402613110582/00013
6209650402613110582/00022
6209650402613110582/00044
6209650402613110582/00036
6209650402613110582/00045
6209650402613110582/00023
6209650402613110582/00038
6209650402613110582/00018
6209650402613110582/00042
6209650402613110582/00041
6209650402613110582/00037
6209650402613110582/00039
6209650402613110582/00020
6209650402613110582/00035
6209650402613110582/00017
6209650402613110582/00016
6209650402613110582/00019
6209650402613110582/00011
6209650402613110582/00027
6209650402613110582/00034
6209650402613110582/00002
6209650402613110582/00008
6209650402613110582/00040
6209650402613110582/00007
6209650402613110582/00025
6209650402613110582/00043
5930153687838935272/00006
5930153687838935272/00010
5930153687838935272/00008
5930153687838935272/00005
5930153687838935272/00004
6345850258020171127/00006
6117922786072437795/00003
6117922786072437795/00010
6117922786072437795/00001
6117922786072437795/00011
6117922786072437795/00008
6117922786072437795/00004
6233055826892590338/00015
6233055826892590338/00003
6233055826892590338/00006
6233055826892590338/00010
6233055826892590338/00013
6233055826892590338/00001
6233055826892590338/00012
6233055826892590338/00014
6233055826892590338/00017
6233055826892590338/00016
6233055826892590338/00019
6233055826892590338/00011
6233055826892590338/00002
6233055826892590338/00008
6233055826892590338/00007
6233055826892590338/00005
5943860646467464220/00015
5943860646467464220/00006
5943860646467464220/00010
5943860646467464220/00013
5943860646467464220/00012
5943860646467464220/00016
5943860646467464220/00011
5943860646467464220/00004
6365141533126941637/00060
6365141533126941637/00056
6365141533126941637/00057
6365141533126941637/00013
6365141533126941637/00012
6365141533126941637/00051
6188213791342693604/00002
5687503926993979657/00003
5687503926993979657/00006
5687503926993979657/00001
5687503926993979657/00002
5687503926993979657/00004
6122793278986045329/00015
6122793278986045329/00006
6122793278986045329/00013
6122793278986045329/00023
6122793278986045329/00012
6122793278986045329/00017
6122793278986045329/00008
6122793278986045329/00007
6106395952843454846/00003
6106395952843454846/00009
6106395952843454846/00010
6106395952843454846/00005
5977512144728289991/00006
5977512144728289991/00009
5977512144728289991/00010
5977512144728289991/00001
5977512144728289991/00012
5977512144728289991/00011
5977512144728289991/00002
5977512144728289991/00005
5977512144728289991/00004
6284219195309469824/00015
6284219195309469824/00003
6284219195309469824/00006
6284219195309469824/00009
6284219195309469824/00013
6284219195309469824/00001
6284219195309469824/00014
6284219195309469824/00016
6284219195309469824/00011
6284219195309469824/00002
6284219195309469824/00008
6284219195309469824/00007
6284219195309469824/00005
6086751631424989624/00015
6086751631424989624/00006
6086751631424989624/00009
6086751631424989624/00010
6086751631424989624/00013
6086751631424989624/00001
6086751631424989624/00018
6086751631424989624/00014
6086751631424989624/00020
6086751631424989624/00017
6086751631424989624/00016
6086751631424989624/00019
6086751631424989624/00002
6086751631424989624/00008
6086751631424989624/00007
6086751631424989624/00005
6086751631424989624/00004
6259402874273258680/00003
6259402874273258680/00006
6259402874273258680/00001
6259402874273258680/00005
5987277611869252470/00015
5987277611869252470/00006
5987277611869252470/00009
5987277611869252470/00013
5987277611869252470/00012
5987277611869252470/00021
5987277611869252470/00014
5987277611869252470/00016
5987277611869252470/00011
5987277611869252470/00007
5971320949371164880/00024
5971320949371164880/00013
5971320949371164880/00026
5971320949371164880/00018
5971320949371164880/00012
5971320949371164880/00021
5971320949371164880/00014
5971320949371164880/00019
5971320949371164880/00011
5971320949371164880/00002
5971320949371164880/00008
5971320949371164880/00007
5971320949371164880/00025
5971320949371164880/00005
5971320949371164880/00004
5995503333234560715/00029
5995503333234560715/00015
5995503333234560715/00003
5995503333234560715/00006
5995503333234560715/00009
5995503333234560715/00010
5995503333234560715/00024
5995503333234560715/00001
5995503333234560715/00021
5995503333234560715/00017
5995503333234560715/00016
5995503333234560715/00027
5995503333234560715/00002
5995503333234560715/00008
5995503333234560715/00007
5995503333234560715/00025
5995503333234560715/00004
5568055732531473539/00029
5568055732531473539/00015
5568055732531473539/00003
5568055732531473539/00032
5568055732531473539/00009
5568055732531473539/00024
5568055732531473539/00030
5568055732531473539/00013
5568055732531473539/00022
5568055732531473539/00026
5568055732531473539/00044
5568055732531473539/00050
5568055732531473539/00036
5568055732531473539/00001
5568055732531473539/00023
5568055732531473539/00038
5568055732531473539/00048
5568055732531473539/00018
5568055732531473539/00042
5568055732531473539/00041
5568055732531473539/00021
5568055732531473539/00037
5568055732531473539/00047
5568055732531473539/00020
5568055732531473539/00035
5568055732531473539/00016
5568055732531473539/00019
5568055732531473539/00011
5568055732531473539/00027
5568055732531473539/00034
5568055732531473539/00002
5568055732531473539/00008
5568055732531473539/00040
5568055732531473539/00007
5568055732531473539/00025
5568055732531473539/00043
5568055732531473539/00005
5568055732531473539/00004
5568055732531473539/00028
6079283542290698524/00015
6079283542290698524/00001
6079283542290698524/00012
6079283542290698524/00014
6079283542290698524/00027
6079283542290698524/00002
6079283542290698524/00025
6079283542290698524/00004
6140640156591170371/00004
6253625284266609126/00015
6253625284266609126/00009
6253625284266609126/00026
6253625284266609126/00001
6253625284266609126/00023
6253625284266609126/00018
6253625284266609126/00014
6253625284266609126/00020
6253625284266609126/00016
6253625284266609126/00002
6253625284266609126/00007
6253625284266609126/00025
6253625284266609126/00004
6121278014524018920/00024
6121278014524018920/00031
6121278014524018920/00013
6121278014524018920/00012
6121278014524018920/00014
6121278014524018920/00017
6121278014524018920/00027
5586505623544917789/00003
5586505623544917789/00006
5586505623544917789/00011
5586505623544917789/00002
5586505623544917789/00007
5586505623544917789/00004
6260918138735224811/00003
6260918138735224811/00006
6260918138735224811/00009
6260918138735224811/00010
6260918138735224811/00012
6260918138735224811/00011
6260918138735224811/00002
6260918138735224811/00008
6260918138735224811/00007
6260918138735224811/00005
6260918138735224811/00004
6125383144265536003/00005
6125383144265536003/00004
5541186846624433836/00003
5541186846624433836/00006
5541186846624433836/00032
5541186846624433836/00010
5541186846624433836/00033
5541186846624433836/00030
5541186846624433836/00031
5541186846624433836/00013
5541186846624433836/00022
5541186846624433836/00026
5541186846624433836/00001
5541186846624433836/00023
5541186846624433836/00012
5541186846624433836/00014
5541186846624433836/00020
5541186846624433836/00017
5541186846624433836/00016
5541186846624433836/00019
5541186846624433836/00011
5541186846624433836/00027
5541186846624433836/00002
5541186846624433836/00025
5541186846624433836/00005
5541186846624433836/00004
5541186846624433836/00028
6123214615277782961/00001
6123214615277782961/00002
6123214615277782961/00005
6064123166729284457/00003
6064123166729284457/00006
6064123166729284457/00009
6064123166729284457/00013
6064123166729284457/00005
6064123166729284457/00004
6075674481271828339/00029
6075674481271828339/00046
6075674481271828339/00009
6075674481271828339/00024
6075674481271828339/00030
6075674481271828339/00026
6075674481271828339/00036
6075674481271828339/00045
6075674481271828339/00023
6075674481271828339/00037
6075674481271828339/00020
6075674481271828339/00035
6075674481271828339/00017
6075674481271828339/00016
6075674481271828339/00019
6075674481271828339/00027
6075674481271828339/00034
6075674481271828339/00043
6075674481271828339/00004
6215243738522697556/00006
6215243738522697556/00002
6215243738522697556/00008
5551623617153720258/00015
5551623617153720258/00003
5551623617153720258/00001
5551623617153720258/00012
5551623617153720258/00014
5551623617153720258/00017
5551623617153720258/00016
5551623617153720258/00019
5551623617153720258/00002
5551623617153720258/00008
5551623617153720258/00007
5551623617153720258/00005
6130667242529814803/00003
6130667242529814803/00002
6102070491279641383/00006
6102070491279641383/00012
6102070491279641383/00011
6324230681142278360/00006
6324230681142278360/00009
6324230681142278360/00010
6324230681142278360/00013
6324230681142278360/00016
6324230681142278360/00002
6324230681142278360/00007
6077772143299170798/00015
6077772143299170798/00003
6077772143299170798/00009
6077772143299170798/00013
6077772143299170798/00001
6077772143299170798/00012
6077772143299170798/00014
6077772143299170798/00016
6077772143299170798/00019
6077772143299170798/00011
6077772143299170798/00002
6077772143299170798/00008
5860790395505332421/00003
5860790395505332421/00009
5860790395505332421/00010
5860790395505332421/00001
5860790395505332421/00023
5860790395505332421/00020
5860790395505332421/00011
5860790395505332421/00002
5860790395505332421/00007
5860790395505332421/00025
5860790395505332421/00005
6139542362950251255/00006
6139542362950251255/00009
6139542362950251255/00010
6139542362950251255/00007
6139542362950251255/00005
6139542362950251255/00004
6041181598917637143/00009
6041181598917637143/00010
6041181598917637143/00013
6041181598917637143/00018
6041181598917637143/00021
6041181598917637143/00014
6041181598917637143/00020
6041181598917637143/00016
6041181598917637143/00011
6041181598917637143/00002
6041181598917637143/00008
6041181598917637143/00007
6041181598917637143/00004
6085394851386601911/00006
6085394851386601911/00046
6085394851386601911/00032
6085394851386601911/00010
6085394851386601911/00033
6085394851386601911/00024
6085394851386601911/00031
6085394851386601911/00013
6085394851386601911/00022
6085394851386601911/00050
6085394851386601911/00001
6085394851386601911/00045
6085394851386601911/00023
6085394851386601911/00048
6085394851386601911/00018
6085394851386601911/00012
6085394851386601911/00014
6085394851386601911/00039
6085394851386601911/00017
6085394851386601911/00016
6085394851386601911/00011
6085394851386601911/00027
6085394851386601911/00007
5687545158680046865/00003
5687545158680046865/00024
5687545158680046865/00013
5687545158680046865/00022
5687545158680046865/00001
5687545158680046865/00023
5687545158680046865/00018
5687545158680046865/00021
5687545158680046865/00014
5687545158680046865/00020
5687545158680046865/00017
5687545158680046865/00008
5687545158680046865/00007
5687545158680046865/00025
5687545158680046865/00005
5687545158680046865/00004
5687545158680046865/00028
5960558190824112934/00029
5960558190824112934/00015
5960558190824112934/00009
5960558190824112934/00031
5960558190824112934/00013
5960558190824112934/00019
5960558190824112934/00008
5538635636050605931/00003
5538635636050605931/00005
5538635636050605931/00004
6212592025714146595/00029
6212592025714146595/00003
6212592025714146595/00046
6212592025714146595/00010
6212592025714146595/00024
6212592025714146595/00030
6212592025714146595/00031
6212592025714146595/00013
6212592025714146595/00026
6212592025714146595/00044
6212592025714146595/00036
6212592025714146595/00001
6212592025714146595/00023
6212592025714146595/00018
6212592025714146595/00042
6212592025714146595/00021
6212592025714146595/00014
6212592025714146595/00037
6212592025714146595/00017
6212592025714146595/00011
6212592025714146595/00027
6212592025714146595/00034
6212592025714146595/00002
6212592025714146595/00040
6212592025714146595/00025
6212592025714146595/00005
6212592025714146595/00004
6212592025714146595/00028
6151015079591335531/00003
6151015079591335531/00006
6151015079591335531/00010
6151015079591335531/00001
6151015079591335531/00012
6151015079591335531/00011
6151015079591335531/00002
6151015079591335531/00007
6151015079591335531/00005
6151015079591335531/00004
6247520417752134275/00006
6247520417752134275/00009
6247520417752134275/00010
6247520417752134275/00018
6247520417752134275/00012
6247520417752134275/00019
6247520417752134275/00011
6247520417752134275/00002
6247520417752134275/00008
6247520417752134275/00007
6247520417752134275/00004
6212920590712289083/00003
6212920590712289083/00006
6212920590712289083/00002
6212920590712289083/00007
6212920590712289083/00005
6030412397919638139/00003
6030412397919638139/00006
6030412397919638139/00008
6030412397919638139/00007
6030412397919638139/00004
6244462830534109905/00003
6244462830534109905/00006
6244462830534109905/00009
6244462830534109905/00010
6244462830534109905/00001
6244462830534109905/00012
6244462830534109905/00014
6244462830534109905/00002
6244462830534109905/00007
6244462830534109905/00004
6096747738309648778/00010
6096747738309648778/00001
6096747738309648778/00012
6096747738309648778/00008
6096747738309648778/00007
6096747738309648778/00005
5984459683827001604/00015
5984459683827001604/00003
5984459683827001604/00006
5984459683827001604/00013
5984459683827001604/00001
5984459683827001604/00012
5984459683827001604/00014
5984459683827001604/00017
5984459683827001604/00008
5984459683827001604/00007
5984459683827001604/00005
6203140950179222464/00125
6203140950179222464/00060
6203140950179222464/00003
6203140950179222464/00009
6203140950179222464/00090
6203140950179222464/00010
6203140950179222464/00057
6203140950179222464/00078
6203140950179222464/00030
6203140950179222464/00071
6203140950179222464/00058
6203140950179222464/00142
6203140950179222464/00086
6203140950179222464/00013
6203140950179222464/00069
6203140950179222464/00026
6203140950179222464/00044
6203140950179222464/00077
6203140950179222464/00093
6203140950179222464/00103
6203140950179222464/00117
6203140950179222464/00001
6203140950179222464/00064
6203140950179222464/00072
6203140950179222464/00038
6203140950179222464/00088
6203140950179222464/00135
6203140950179222464/00012
6203140950179222464/00126
6203140950179222464/00065
6203140950179222464/00144
6203140950179222464/00021
6203140950179222464/00104
6203140950179222464/00037
6203140950179222464/00101
6203140950179222464/00124
6203140950179222464/00020
6203140950179222464/00052
6203140950179222464/00091
6203140950179222464/00095
6203140950179222464/00084
6203140950179222464/00082
6203140950179222464/00106
6203140950179222464/00089
6203140950179222464/00092
6203140950179222464/00040
6203140950179222464/00102
6203140950179222464/00068
6203140950179222464/00132
6203140950179222464/00054
6203140950179222464/00043
6203140950179222464/00004
6098013035675050600/00009
6098013035675050600/00010
6098013035675050600/00001
6098013035675050600/00007
6115801931221609643/00015
6115801931221609643/00001
6115801931221609643/00018
6115801931221609643/00017
6115801931221609643/00004
6211490366602652865/00015
6211490366602652865/00003
6211490366602652865/00033
6211490366602652865/00024
6211490366602652865/00030
6211490366602652865/00013
6211490366602652865/00022
6211490366602652865/00026
6211490366602652865/00049
6211490366602652865/00038
6211490366602652865/00041
6211490366602652865/00021
6211490366602652865/00014
6211490366602652865/00039
6211490366602652865/00020
6211490366602652865/00027
6211490366602652865/00034
6211490366602652865/00002
6211490366602652865/00040
6211490366602652865/00025
6037519709801129291/00003
6037519709801129291/00010
6037519709801129291/00012
6037519709801129291/00014
6037519709801129291/00017
6037519709801129291/00016
6037519709801129291/00019
6037519709801129291/00011
6037519709801129291/00002
6037519709801129291/00008
5996956750167467138/00009
5996956750167467138/00010
5996956750167467138/00007
5996956750167467138/00004
6249375843624008338/00013
6249375843624008338/00001
6249375843624008338/00012
6249375843624008338/00021
6249375843624008338/00011
5951349351444679853/00003
5951349351444679853/00009
5951349351444679853/00010
5951349351444679853/00013
5951349351444679853/00011
5951349351444679853/00002
5951349351444679853/00007
5951349351444679853/00005
5951349351444679853/00004
5918933515274853764/00029
5918933515274853764/00015
5918933515274853764/00056
5918933515274853764/00032
5918933515274853764/00057
5918933515274853764/00030
5918933515274853764/00044
5918933515274853764/00050
5918933515274853764/00049
5918933515274853764/00023
5918933515274853764/00018
5918933515274853764/00042
5918933515274853764/00012
5918933515274853764/00014
5918933515274853764/00047
5918933515274853764/00035
5918933515274853764/00017
5918933515274853764/00016
5918933515274853764/00019
5918933515274853764/00011
5918933515274853764/00034
5918933515274853764/00008
5918933515274853764/00040
5918933515274853764/00054
5687877589148733737/00029
5687877589148733737/00015
5687877589148733737/00003
5687877589148733737/00006
5687877589148733737/00010
5687877589148733737/00033
5687877589148733737/00024
5687877589148733737/00018
5687877589148733737/00012
5687877589148733737/00020
5687877589148733737/00017
5687877589148733737/00016
5687877589148733737/00019
5687877589148733737/00011
5687877589148733737/00027
5687877589148733737/00002
5687877589148733737/00040
5687877589148733737/00025
5687877589148733737/00005
5687877589148733737/00028
5891394614469626449/00015
5891394614469626449/00003
5891394614469626449/00006
5891394614469626449/00046
5891394614469626449/00024
5891394614469626449/00031
5891394614469626449/00013
5891394614469626449/00022
5891394614469626449/00026
5891394614469626449/00050
5891394614469626449/00036
5891394614469626449/00049
5891394614469626449/00023
5891394614469626449/00018
5891394614469626449/00012
5891394614469626449/00021
5891394614469626449/00014
5891394614469626449/00037
5891394614469626449/00047
5891394614469626449/00020
5891394614469626449/00035
5891394614469626449/00017
5891394614469626449/00011
5891394614469626449/00034
5891394614469626449/00008
Download .txt
gitextract_hwq25lj6/

├── GFPGAN.py
├── Gen_hyperlipsbase_videos.py
├── HYPERLIPS.py
├── Inference_hyperlips.py
├── README.md
├── Train_data/
│   └── video_clips/
│       └── MEAD/
│           └── readme.txt
├── Train_hyperlipsBase.py
├── Train_hyperlipsHR.py
├── audio.py
├── checkpoint
├── checkpoints/
│   └── readme.txt
├── color_syncnet_trainv3.py
├── conv.py
├── datasets/
│   └── MEAD/
│       └── readme.txt
├── environment.yml
├── face_detection/
│   ├── README.md
│   ├── __init__.py
│   ├── api.py
│   ├── detection/
│   │   ├── __init__.py
│   │   ├── core.py
│   │   └── sfd/
│   │       ├── __init__.py
│   │       ├── bbox.py
│   │       ├── detect.py
│   │       ├── net_s3fd.py
│   │       ├── s3fd.pth
│   │       └── sfd_detector.py
│   ├── models.py
│   └── utils.py
├── face_parsing/
│   ├── README.md
│   ├── __init__.py
│   ├── model.py
│   ├── resnet.py
│   └── swap.py
├── filelists/
│   ├── train.txt
│   └── val.txt
├── filelists_lrs2/
│   ├── README.md
│   ├── test.txt
│   ├── train.txt
│   └── val.txt
├── filelists_mead/
│   ├── README.md
│   ├── test.txt
│   ├── train.txt
│   └── val.txt
├── gfpgan/
│   ├── gfpganv1_clean_arch.py
│   └── stylegan2_clean_arch.py
├── hparams.py
├── hparams_Base.py
├── hparams_HR.py
├── inference.py
├── models/
│   ├── __init__.py
│   ├── audio_v.py
│   ├── conv.py
│   ├── decoder.py
│   ├── deep_guided_filter.py
│   ├── gfpganv1_clean_arch.py
│   ├── guided_filter_pytorch/
│   │   ├── __init__.py
│   │   ├── box_filter.py
│   │   └── guided_filter.py
│   ├── hyperlayers.py
│   ├── hypernetwork.py
│   ├── layers.py
│   ├── lraspp.py
│   ├── memory.py
│   ├── mobilenetv3.py
│   ├── model.py
│   ├── model_hyperlips.py
│   ├── resnet.py
│   └── syncnet.py
├── preprocess.py
└── requirements.txt
Download .txt
SYMBOL INDEX (529 symbols across 44 files)

FILE: GFPGAN.py
  function tensor2img (line 13) | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
  class GFPGANer (line 58) | class GFPGANer():
    method __init__ (line 74) | def __init__(self, device,model_path, upscale=2, arch='clean', channel...
    method enhance_allimg (line 114) | def enhance_allimg(self, img, has_aligned=False, only_center_face=Fals...
    method enhance (line 174) | def enhance(self, img, has_aligned=False, only_center_face=False, past...
  function GFPGANInit (line 225) | def GFPGANInit(device,face_enhancement_path):
  function GFPGANInfer (line 247) | def GFPGANInfer(img, restorer, aligned):

FILE: Gen_hyperlipsbase_videos.py
  function inference_list (line 33) | def inference_list():

FILE: HYPERLIPS.py
  function get_smoothened_boxes (line 11) | def get_smoothened_boxes(boxes, T):
  function face_detect (line 20) | def face_detect(images, detector,pad):
  function datagen (line 59) | def datagen(mels, detector,frames,img_size,hyper_batch_size,pads):
  function load_HyperLips (line 101) | def load_HyperLips(window,rescaling,path,path_hr,device):
  function main (line 107) | def main():
  class Hyperlips (line 112) | class Hyperlips():
    method __init__ (line 113) | def __init__(self,checkpoint_path_BASE=None,
    method _HyperlipsLoadModels (line 145) | def _HyperlipsLoadModels(self):
    method _HyperlipsInference (line 158) | def _HyperlipsInference(self,face_path,audio_path,outfile_path):

FILE: Inference_hyperlips.py
  function get_smoothened_mels (line 45) | def get_smoothened_mels(mel_chunks, T):
  function face_detect (line 55) | def face_detect(images, detector,pad):
  function datagen (line 95) | def datagen(mels, detector,face_path, resize_factor):
  function _load (line 138) | def _load(checkpoint_path, device):
  function load_HyperLipsHR (line 143) | def load_HyperLipsHR(path,path_hr,device):
  function load_HyperLipsBase (line 149) | def load_HyperLipsBase(path, device):
  function read_frames (line 158) | def read_frames(face_path, resize_factor):
  function main (line 172) | def main():
  class Hyperlips (line 177) | class Hyperlips():
    method __init__ (line 178) | def __init__(self):
    method _HyperlipsLoadModels (line 185) | def _HyperlipsLoadModels(self):
    method _HyperlipsInference (line 200) | def _HyperlipsInference(self):

FILE: Train_hyperlipsBase.py
  class Dataset (line 36) | class Dataset(object):
    method __init__ (line 37) | def __init__(self, split):
    method get_frame_id (line 40) | def get_frame_id(self, frame):
    method get_window (line 43) | def get_window(self, start_frame):
    method read_window (line 55) | def read_window(self, window_fnames):
    method crop_audio_window (line 71) | def crop_audio_window(self, spec, start_frame):
    method get_segmented_mels (line 82) | def get_segmented_mels(self, spec, start_frame):
    method prepare_window (line 97) | def prepare_window(self, window):
    method __len__ (line 104) | def __len__(self):
    method __getitem__ (line 107) | def __getitem__(self, idx):
  function save_sample_images (line 162) | def save_sample_images(x, g, gt, global_step, checkpoint_dir):
  function cosine_loss (line 176) | def cosine_loss(a, v, y):
  function get_sync_loss (line 188) | def get_sync_loss(mel, g):
  function train (line 199) | def train(device, model, disc, train_data_loader, test_data_loader, opti...
  function eval_model (line 297) | def eval_model(test_data_loader, global_step, device, model, disc):
  function save_checkpoint (line 352) | def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefi...
  function _load (line 364) | def _load(checkpoint_path):
  function load_checkpoint (line 373) | def load_checkpoint(path, model, optimizer, reset_optimizer=False, overw...

FILE: Train_hyperlipsHR.py
  class Dataset (line 45) | class Dataset(object):
    method __init__ (line 46) | def __init__(self, split):
    method get_frame_id (line 53) | def get_frame_id(self, frame):
    method get_window (line 56) | def get_window(self, start_frame):
    method read_window (line 68) | def read_window(self, window_fnames):
    method read_window_base (line 85) | def read_window_base(self, window_fnames):
    method read_window_sketch (line 102) | def read_window_sketch(self, window_fnames):
    method read_window_sketch_base (line 129) | def read_window_sketch_base(self, window_fnames):
    method read_coord (line 155) | def read_coord(self,window_fnames):
    method prepare_window (line 174) | def prepare_window(self, window):
    method __len__ (line 181) | def __len__(self):
    method __getitem__ (line 184) | def __getitem__(self, idx):
  function save_sample_images (line 247) | def save_sample_images(x, g, gt,m, global_step, checkpoint_dir):
  class PerceptualLoss (line 259) | class PerceptualLoss(nn.Module):
    method __init__ (line 260) | def __init__(self):
    method forward (line 270) | def forward(self, high_resolution, fake_high_resolution):
  function cosine_loss (line 275) | def cosine_loss(a, v, y):
  function train (line 287) | def train(device, model, disc,train_data_loader, test_data_loader, optim...
  function eval_model (line 429) | def eval_model(test_data_loader, global_step, device, model):
  function save_checkpoint (line 470) | def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefi...
  function _load (line 482) | def _load(checkpoint_path):
  function load_checkpoint (line 491) | def load_checkpoint(path, model, reset_optimizer=False, overwrite_global...

FILE: audio.py
  function load_wav (line 9) | def load_wav(path, sr):
  function save_wav (line 12) | def save_wav(wav, path, sr):
  function save_wavenet_wav (line 17) | def save_wavenet_wav(wav, path, sr):
  function preemphasis (line 20) | def preemphasis(wav, k, preemphasize=True):
  function inv_preemphasis (line 25) | def inv_preemphasis(wav, k, inv_preemphasize=True):
  function get_hop_size (line 30) | def get_hop_size():
  function linearspectrogram (line 37) | def linearspectrogram(wav):
  function melspectrogram (line 45) | def melspectrogram(wav):
  function _lws_processor (line 53) | def _lws_processor():
  function _stft (line 57) | def _stft(y):
  function num_frames (line 66) | def num_frames(length, fsize, fshift):
  function pad_lr (line 77) | def pad_lr(x, fsize, fshift):
  function librosa_pad_lr (line 87) | def librosa_pad_lr(x, fsize, fshift):
  function _linear_to_mel (line 93) | def _linear_to_mel(spectogram):
  function _build_mel_basis (line 99) | def _build_mel_basis():
  function _amp_to_db (line 105) | def _amp_to_db(x):
  function _db_to_amp (line 109) | def _db_to_amp(x):
  function _normalize (line 112) | def _normalize(S):
  function _denormalize (line 126) | def _denormalize(D):

FILE: color_syncnet_trainv3.py
  class Dataset (line 36) | class Dataset(object):
    method __init__ (line 37) | def __init__(self, split):
    method get_frame_id (line 40) | def get_frame_id(self, frame):
    method get_window (line 43) | def get_window(self, start_frame):
    method crop_audio_window (line 57) | def crop_audio_window(self, spec, start_frame):
    method read_window (line 71) | def read_window(self, window_fnames, flip_flag=False):
    method __len__ (line 90) | def __len__(self):
    method __getitem__ (line 93) | def __getitem__(self, idx):
  function cosine_loss (line 143) | def cosine_loss(a, v, y):
  function train (line 149) | def train(device, model, train_data_loader, test_data_loader, optimizer,
  function eval_model (line 202) | def eval_model(test_data_loader, global_step, device, model, checkpoint_...
  function save_checkpoint (line 233) | def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
  function _load (line 246) | def _load(checkpoint_path):
  function load_checkpoint (line 254) | def load_checkpoint(path, model, optimizer, reset_optimizer=False):

FILE: conv.py
  class Conv2d (line 5) | class Conv2d(nn.Module):
    method __init__ (line 6) | def __init__(self, cin, cout, kernel_size, stride, padding, residual=F...
    method forward (line 15) | def forward(self, x):
  class nonorm_Conv2d (line 21) | class nonorm_Conv2d(nn.Module):
    method __init__ (line 22) | def __init__(self, cin, cout, kernel_size, stride, padding, residual=F...
    method forward (line 29) | def forward(self, x):
  class Conv2dTranspose (line 33) | class Conv2dTranspose(nn.Module):
    method __init__ (line 34) | def __init__(self, cin, cout, kernel_size, stride, padding, output_pad...
    method forward (line 42) | def forward(self, x):

FILE: face_detection/api.py
  class LandmarksType (line 17) | class LandmarksType(Enum):
  class NetworkSize (line 30) | class NetworkSize(Enum):
    method __new__ (line 36) | def __new__(cls, value):
    method __int__ (line 41) | def __int__(self):
  class FaceAlignment (line 46) | class FaceAlignment:
    method __init__ (line 47) | def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
    method get_detections_for_batch (line 64) | def get_detections_for_batch(self, images):

FILE: face_detection/detection/core.py
  class FaceDetector (line 9) | class FaceDetector(object):
    method __init__ (line 18) | def __init__(self, device, verbose):
    method detect_from_image (line 32) | def detect_from_image(self, tensor_or_path):
    method detect_from_directory (line 54) | def detect_from_directory(self, path, extensions=['.jpg', '.png'], rec...
    method reference_scale (line 104) | def reference_scale(self):
    method reference_x_shift (line 108) | def reference_x_shift(self):
    method reference_y_shift (line 112) | def reference_y_shift(self):
    method tensor_or_path_to_ndarray (line 116) | def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):

FILE: face_detection/detection/sfd/bbox.py
  function IOU (line 17) | def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
  function bboxlog (line 30) | def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
  function bboxloginv (line 37) | def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
  function nms (line 44) | def nms(dets, thresh):
  function encode (line 67) | def encode(matched, priors, variances):
  function decode (line 91) | def decode(loc, priors, variances):
  function batch_decode (line 111) | def batch_decode(loc, priors, variances):

FILE: face_detection/detection/sfd/detect.py
  function detect (line 19) | def detect(net, img, device):
  function batch_detect (line 58) | def batch_detect(net, imgs, device):
  function flip_detect (line 96) | def flip_detect(net, img, device):
  function pts_to_bb (line 109) | def pts_to_bb(pts):

FILE: face_detection/detection/sfd/net_s3fd.py
  class L2Norm (line 6) | class L2Norm(nn.Module):
    method __init__ (line 7) | def __init__(self, n_channels, scale=1.0):
    method forward (line 16) | def forward(self, x):
  class s3fd (line 22) | class s3fd(nn.Module):
    method __init__ (line 23) | def __init__(self):
    method forward (line 70) | def forward(self, x):

FILE: face_detection/detection/sfd/sfd_detector.py
  class SFDDetector (line 16) | class SFDDetector(FaceDetector):
    method __init__ (line 17) | def __init__(self, device, path_to_detector=os.path.join(os.path.dirna...
    method detect_from_image (line 31) | def detect_from_image(self, tensor_or_path):
    method detect_from_batch (line 41) | def detect_from_batch(self, images):
    method reference_scale (line 50) | def reference_scale(self):
    method reference_x_shift (line 54) | def reference_x_shift(self):
    method reference_y_shift (line 58) | def reference_y_shift(self):

FILE: face_detection/models.py
  function conv3x3 (line 7) | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
  class ConvBlock (line 13) | class ConvBlock(nn.Module):
    method __init__ (line 14) | def __init__(self, in_planes, out_planes):
    method forward (line 33) | def forward(self, x):
  class Bottleneck (line 58) | class Bottleneck(nn.Module):
    method __init__ (line 62) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 75) | def forward(self, x):
  class HourGlass (line 98) | class HourGlass(nn.Module):
    method __init__ (line 99) | def __init__(self, num_modules, depth, num_features):
    method _generate_network (line 107) | def _generate_network(self, level):
    method _forward (line 119) | def _forward(self, level, inp):
    method forward (line 141) | def forward(self, x):
  class FAN (line 145) | class FAN(nn.Module):
    method __init__ (line 147) | def __init__(self, num_modules=1):
    method forward (line 174) | def forward(self, x):
  class ResNetDepth (line 204) | class ResNetDepth(nn.Module):
    method __init__ (line 206) | def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes...
    method _make_layer (line 229) | def _make_layer(self, block, planes, blocks, stride=1):
    method forward (line 246) | def forward(self, x):

FILE: face_detection/utils.py
  function _gaussian (line 11) | def _gaussian(
  function draw_gaussian (line 37) | def draw_gaussian(image, point, sigma):
  function transform (line 56) | def transform(point, center, scale, resolution, invert=False):
  function crop (line 92) | def crop(image, center, scale, resolution=256.0):
  function get_preds_fromhm (line 132) | def get_preds_fromhm(hm, center=None, scale=None):
  function get_preds_fromhm_batch (line 172) | def get_preds_fromhm_batch(hm, centers=None, scales=None):
  function shuffle_lr (line 212) | def shuffle_lr(parts, pairs=None):
  function flip (line 237) | def flip(tensor, is_label=False):
  function appdata_dir (line 259) | def appdata_dir(appname=None, roaming=False):

FILE: face_parsing/model.py
  class ConvBNReLU (line 14) | class ConvBNReLU(nn.Module):
    method __init__ (line 15) | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args...
    method forward (line 26) | def forward(self, x):
    method init_weight (line 31) | def init_weight(self):
  class BiSeNetOutput (line 37) | class BiSeNetOutput(nn.Module):
    method __init__ (line 38) | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
    method forward (line 44) | def forward(self, x):
    method init_weight (line 49) | def init_weight(self):
    method get_params (line 55) | def get_params(self):
  class AttentionRefinementModule (line 67) | class AttentionRefinementModule(nn.Module):
    method __init__ (line 68) | def __init__(self, in_chan, out_chan, *args, **kwargs):
    method forward (line 76) | def forward(self, x):
    method init_weight (line 85) | def init_weight(self):
  class ContextPath (line 92) | class ContextPath(nn.Module):
    method __init__ (line 93) | def __init__(self, device,*args, **kwargs):
    method forward (line 104) | def forward(self, x):
    method init_weight (line 127) | def init_weight(self):
    method get_params (line 133) | def get_params(self):
  class SpatialPath (line 146) | class SpatialPath(nn.Module):
    method __init__ (line 147) | def __init__(self, *args, **kwargs):
    method forward (line 155) | def forward(self, x):
    method init_weight (line 162) | def init_weight(self):
    method get_params (line 168) | def get_params(self):
  class FeatureFusionModule (line 180) | class FeatureFusionModule(nn.Module):
    method __init__ (line 181) | def __init__(self, in_chan, out_chan, *args, **kwargs):
    method forward (line 200) | def forward(self, fsp, fcp):
    method init_weight (line 212) | def init_weight(self):
    method get_params (line 218) | def get_params(self):
  class BiSeNet (line 230) | class BiSeNet(nn.Module):
    method __init__ (line 231) | def __init__(self,device, n_classes, *args, **kwargs):
    method forward (line 241) | def forward(self, x):
    method init_weight (line 256) | def init_weight(self):
    method get_params (line 262) | def get_params(self):

FILE: face_parsing/resnet.py
  function conv3x3 (line 14) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 20) | class BasicBlock(nn.Module):
    method __init__ (line 21) | def __init__(self, in_chan, out_chan, stride=1):
    method forward (line 36) | def forward(self, x):
  function create_layer_basic (line 51) | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
  class Resnet18 (line 58) | class Resnet18(nn.Module):
    method __init__ (line 59) | def __init__(self,device):
    method forward (line 71) | def forward(self, x):
    method init_weight (line 82) | def init_weight(self,device):
    method get_params (line 93) | def get_params(self):

FILE: face_parsing/swap.py
  function init_parser (line 9) | def init_parser(pth_path, device):
  function image_to_parsing_img (line 20) | def image_to_parsing_img(img, net):
  function image_to_parsing (line 46) | def image_to_parsing(img, net):
  function image_to_parsing2 (line 66) | def image_to_parsing2(img, net):
  function get_mask (line 77) | def get_mask(parsing, classes):
  function swap_regions_img (line 83) | def swap_regions_img(source, target, net):
  function swap_regions (line 97) | def swap_regions(source, target, net):

FILE: gfpgan/gfpganv1_clean_arch.py
  class StyleGAN2GeneratorCSFT (line 12) | class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    method __init__ (line 26) | def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_mu...
    method forward (line 35) | def forward(self,
  class ResBlock (line 121) | class ResBlock(nn.Module):
    method __init__ (line 130) | def __init__(self, in_channels, out_channels, mode='down'):
    method forward (line 141) | def forward(self, x):
  class GFPGANv1Clean (line 154) | class GFPGANv1Clean(nn.Module):
    method __init__ (line 175) | def __init__(
    method forward (line 278) | def forward(self, x, return_latents=False, return_rgb=True, randomize_...

FILE: gfpgan/stylegan2_clean_arch.py
  class NormStyleCode (line 10) | class NormStyleCode(nn.Module):
    method forward (line 12) | def forward(self, x):
  class ModulatedConv2d (line 24) | class ModulatedConv2d(nn.Module):
    method __init__ (line 39) | def __init__(self,
    method forward (line 65) | def forward(self, x, style):
    method __repr__ (line 101) | def __repr__(self):
  class StyleConv (line 106) | class StyleConv(nn.Module):
    method __init__ (line 118) | def __init__(self, in_channels, out_channels, kernel_size, num_style_f...
    method forward (line 126) | def forward(self, x, style, noise=None):
  class ToRGB (line 141) | class ToRGB(nn.Module):
    method __init__ (line 150) | def __init__(self, in_channels, num_style_feat, upsample=True):
    method forward (line 157) | def forward(self, x, style, skip=None):
  class ConstantInput (line 177) | class ConstantInput(nn.Module):
    method __init__ (line 185) | def __init__(self, num_channel, size):
    method forward (line 189) | def forward(self, batch):
  class StyleGAN2GeneratorClean (line 195) | class StyleGAN2GeneratorClean(nn.Module):
    method __init__ (line 206) | def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_mu...
    method make_noise (line 279) | def make_noise(self):
    method get_latent (line 290) | def get_latent(self, x):
    method mean_latent (line 293) | def mean_latent(self, num_latent):
    method forward (line 298) | def forward(self,

FILE: hparams.py
  function get_image_list (line 3) | def get_image_list(data_root, split):
  class HParams (line 15) | class HParams:
    method __init__ (line 16) | def __init__(self, **kwargs):
    method __getattr__ (line 22) | def __getattr__(self, key):
    method set_hparam (line 27) | def set_hparam(self, key, value):
  function hparams_debug_string (line 98) | def hparams_debug_string():

FILE: hparams_Base.py
  function get_image_list (line 3) | def get_image_list(data_root, split):
  class HParams (line 15) | class HParams:
    method __init__ (line 16) | def __init__(self, **kwargs):
    method __getattr__ (line 22) | def __getattr__(self, key):
    method set_hparam (line 27) | def set_hparam(self, key, value):
  function hparams_debug_string (line 99) | def hparams_debug_string():

FILE: hparams_HR.py
  function get_image_list (line 3) | def get_image_list(data_root, split):
  class HParams (line 15) | class HParams:
    method __init__ (line 16) | def __init__(self, **kwargs):
    method __getattr__ (line 22) | def __getattr__(self, key):
    method set_hparam (line 27) | def set_hparam(self, key, value):
  function hparams_debug_string (line 100) | def hparams_debug_string():

FILE: inference.py
  function inference_single (line 35) | def inference_single():

FILE: models/audio_v.py
  function load_wav (line 8) | def load_wav(path, sr):
  function save_wav (line 12) | def save_wav(wav, path, sr):
  function save_wavenet_wav (line 19) | def save_wavenet_wav(wav, path, sr):
  function preemphasis (line 23) | def preemphasis(wav, k, preemphasize=True):
  function inv_preemphasis (line 29) | def inv_preemphasis(wav, k, inv_preemphasize=True):
  function start_and_end_indices (line 36) | def start_and_end_indices(quantized, silence_threshold=2):
  function get_hop_size (line 50) | def get_hop_size(hparams):
  function linearspectrogram (line 58) | def linearspectrogram(wav, hparams):
  function melspectrogram (line 67) | def melspectrogram(wav, hparams):
  function inv_linear_spectrogram (line 76) | def inv_linear_spectrogram(linear_spectrogram, hparams):
  function inv_mel_spectrogram (line 94) | def inv_mel_spectrogram(mel_spectrogram, hparams):
  function _lws_processor (line 112) | def _lws_processor(hparams):
  function _griffin_lim (line 117) | def _griffin_lim(S, hparams):
  function _stft (line 130) | def _stft(y, hparams):
  function _istft (line 137) | def _istft(y, hparams):
  function num_frames (line 144) | def num_frames(length, fsize, fshift):
  function pad_lr (line 155) | def pad_lr(x, fsize, fshift):
  function librosa_pad_lr (line 167) | def librosa_pad_lr(x, fsize, fshift):
  function _linear_to_mel (line 176) | def _linear_to_mel(spectogram, hparams):
  function _mel_to_linear (line 183) | def _mel_to_linear(mel_spectrogram, hparams):
  function _build_mel_basis (line 190) | def _build_mel_basis(hparams):
  function _amp_to_db (line 196) | def _amp_to_db(x, hparams):
  function _db_to_amp (line 201) | def _db_to_amp(x):
  function _normalize (line 205) | def _normalize(S, hparams):
  function _denormalize (line 223) | def _denormalize(D, hparams):

FILE: models/conv.py
  class Conv2d (line 5) | class Conv2d(nn.Module):
    method __init__ (line 6) | def __init__(self, cin, cout, kernel_size, stride, padding, residual=F...
    method forward (line 15) | def forward(self, x):
  class nonorm_Conv2d (line 21) | class nonorm_Conv2d(nn.Module):
    method __init__ (line 22) | def __init__(self, cin, cout, kernel_size, stride, padding, residual=F...
    method forward (line 29) | def forward(self, x):
  class Conv2dTranspose (line 33) | class Conv2dTranspose(nn.Module):
    method __init__ (line 34) | def __init__(self, cin, cout, kernel_size, stride, padding, output_pad...
    method forward (line 42) | def forward(self, x):

FILE: models/decoder.py
  class RecurrentDecoder (line 7) | class RecurrentDecoder(nn.Module):
    method __init__ (line 8) | def __init__(self, feature_channels, decoder_channels):
    method forward (line 17) | def forward(self,
  class AvgPool (line 30) | class AvgPool(nn.Module):
    method __init__ (line 31) | def __init__(self):
    method forward_single_frame (line 35) | def forward_single_frame(self, s0):
    method forward_time_series (line 41) | def forward_time_series(self, s0):
    method forward (line 50) | def forward(self, s0):
  class BottleneckBlock (line 57) | class BottleneckBlock(nn.Module):
    method __init__ (line 58) | def __init__(self, channels):
    method forward (line 63) | def forward(self, x, r: Optional[Tensor]):
  class UpsamplingBlock (line 70) | class UpsamplingBlock(nn.Module):
    method __init__ (line 71) | def __init__(self, in_channels, skip_channels, src_channels, out_chann...
    method forward_single_frame (line 82) | def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
    method forward_time_series (line 92) | def forward_time_series(self, x, f, s, r: Optional[Tensor]):
    method forward (line 107) | def forward(self, x, f, s, r: Optional[Tensor]):
  class OutputBlock (line 114) | class OutputBlock(nn.Module):
    method __init__ (line 115) | def __init__(self, in_channels, src_channels, out_channels):
    method forward_single_frame (line 127) | def forward_single_frame(self, x, s):
    method forward_time_series (line 134) | def forward_time_series(self, x, s):
    method forward (line 145) | def forward(self, x, s):
  class ConvGRU (line 152) | class ConvGRU(nn.Module):
    method __init__ (line 153) | def __init__(self,
    method forward_single_frame (line 168) | def forward_single_frame(self, x, h):
    method forward_time_series (line 174) | def forward_time_series(self, x, h):
    method forward (line 182) | def forward(self, x, h: Optional[Tensor]):
  class Projection (line 193) | class Projection(nn.Module):
    method __init__ (line 194) | def __init__(self, in_channels, out_channels):
    method forward_single_frame (line 198) | def forward_single_frame(self, x):
    method forward_time_series (line 201) | def forward_time_series(self, x):
    method forward (line 205) | def forward(self, x):

FILE: models/deep_guided_filter.py
  class DeepGuidedFilterRefiner (line 9) | class DeepGuidedFilterRefiner(nn.Module):
    method __init__ (line 10) | def __init__(self, in_channels=4,hid_channels=16):
    method forward_single_frame (line 24) | def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha,...
    method forward_time_series (line 45) | def forward_time_series(self, fine_src, base_src, base_fgr, base_pha, ...
    method forward (line 57) | def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):

FILE: models/gfpganv1_clean_arch.py
  class StyleGAN2GeneratorCSFT (line 12) | class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    method __init__ (line 26) | def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_mu...
    method forward (line 35) | def forward(self,
  class ResBlock (line 121) | class ResBlock(nn.Module):
    method __init__ (line 130) | def __init__(self, in_channels, out_channels, mode='down'):
    method forward (line 141) | def forward(self, x):
  class GFPGANv1Clean (line 154) | class GFPGANv1Clean(nn.Module):
    method __init__ (line 175) | def __init__(
    method forward (line 261) | def forward(self, x, return_latents=False, return_rgb=True, randomize_...

FILE: models/guided_filter_pytorch/box_filter.py
  function diff_x (line 4) | def diff_x(input, r):
  function diff_y (line 15) | def diff_y(input, r):
  class BoxFilter (line 26) | class BoxFilter(nn.Module):
    method __init__ (line 27) | def __init__(self, r):
    method forward (line 32) | def forward(self, x):

FILE: models/guided_filter_pytorch/guided_filter.py
  class FastGuidedFilter (line 8) | class FastGuidedFilter(nn.Module):
    method __init__ (line 9) | def __init__(self, r, eps=1e-8):
    method forward (line 17) | def forward(self, lr_x, lr_y, hr_x):
  class GuidedFilter (line 51) | class GuidedFilter(nn.Module):
    method __init__ (line 52) | def __init__(self, r, eps=1e-8):
    method forward (line 60) | def forward(self, x, y):
  class ConvGuidedFilter (line 92) | class ConvGuidedFilter(nn.Module):
    method __init__ (line 93) | def __init__(self, radius=1, norm=nn.BatchNorm2d):
    method forward (line 106) | def forward(self, x_lr, y_lr, x_hr):

FILE: models/hyperlayers.py
  class FCLayer (line 8) | class FCLayer(nn.Module):
    method __init__ (line 9) | def __init__(self, in_features, out_features):
    method forward (line 16) | def forward(self, input):
  class FCBlock (line 19) | class FCBlock(nn.Module):
    method __init__ (line 20) | def __init__(self,
    method __getitem__ (line 42) | def __getitem__(self,item):
    method init_weights (line 45) | def init_weights(self, m):
    method forward (line 49) | def forward(self, input):
  function partialclass (line 53) | def partialclass(cls, *args, **kwds):
  class HyperLayer (line 60) | class HyperLayer(nn.Module):
    method __init__ (line 62) | def __init__(self,
    method forward (line 80) | def forward(self, hyper_input):
  class HyperFC (line 88) | class HyperFC(nn.Module):
    method __init__ (line 91) | def __init__(self,
    method forward (line 120) | def forward(self, hyper_input):
  class BatchLinear (line 132) | class BatchLinear(nn.Module):
    method __init__ (line 133) | def __init__(self,
    method __repr__ (line 146) | def __repr__(self):
    method forward (line 149) | def forward(self, input):
  function last_hyper_layer_init (line 155) | def last_hyper_layer_init(m):
  class HyperLinear (line 161) | class HyperLinear(nn.Module):
    method __init__ (line 163) | def __init__(self,
    method forward (line 181) | def forward(self, hyper_input):#([1, 131072])
  class HyperConv (line 193) | class HyperConv(nn.Module):
    method __init__ (line 200) | def __init__(self,
    method forward (line 233) | def forward(self, x):

FILE: models/hypernetwork.py
  class HyperNetwork (line 3) | class HyperNetwork(nn.Module):
    method __init__ (line 5) | def __init__(self, in_dim=1, h_dim=32):
    method forward (line 21) | def forward(self, x):

FILE: models/layers.py
  class Upsample (line 14) | class Upsample(nn.Module):
    method __init__ (line 16) | def __init__(self, scale_factor, mode, align_corners):
    method forward (line 21) | def forward(self, x):
  class MultiSequential (line 24) | class MultiSequential(nn.Sequential):
    method forward (line 25) | def forward(self, *inputs):
  class Conv2d (line 35) | class Conv2d(nn.Module):
    method __init__ (line 36) | def __init__(self, in_channels, out_channels, kernel_size=3, padding=0):
    method forward (line 39) | def forward(self, x, hyp_out=None):
  class BatchConv2d (line 42) | class BatchConv2d(nn.Module):
    method __init__ (line 50) | def __init__(self, in_channels, out_channels, hyp_out_units, stride=1,
    method forward (line 66) | def forward(self, x, hyp_out, include_bias=True):
    method get_kernel (line 95) | def get_kernel(self):
    method get_bias (line 97) | def get_bias(self):
    method get_kernel_shape (line 99) | def get_kernel_shape(self):
    method get_bias_shape (line 101) | def get_bias_shape(self):
  class ClipByPercentile (line 104) | class ClipByPercentile(object):
    method __init__ (line 106) | def __init__(self, perc=99):
    method __call__ (line 109) | def __call__(self, img):
  class ZeroPad (line 117) | class ZeroPad(object):
    method __init__ (line 118) | def __init__(self, final_size):
    method __call__ (line 121) | def __call__(self, img):

FILE: models/lraspp.py
  class LRASPP (line 3) | class LRASPP(nn.Module):
    method __init__ (line 4) | def __init__(self, in_channels, out_channels):
    method forward_single_frame (line 17) | def forward_single_frame(self, x):
    method forward_time_series (line 20) | def forward_time_series(self, x):
    method forward (line 25) | def forward(self, x):

FILE: models/memory.py
  class Memory (line 7) | class Memory(nn.Module):
    method __init__ (line 8) | def __init__(self, radius=16.0, n_slot=96):
    method forward (line 21) | def forward(self, query, value=None, inference=False):

FILE: models/mobilenetv3.py
  function _make_divisible (line 17) | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = N...
  function _log_api_usage_once (line 33) | def _log_api_usage_once(obj: str) -> None:  # type: ignore
  class Conv2d (line 47) | class Conv2d(torch.nn.Conv2d):
    method __init__ (line 48) | def __init__(self, *args, **kwargs):
  class ConvTranspose2d (line 56) | class ConvTranspose2d(torch.nn.ConvTranspose2d):
    method __init__ (line 57) | def __init__(self, *args, **kwargs):
  class BatchNorm2d (line 65) | class BatchNorm2d(torch.nn.BatchNorm2d):
    method __init__ (line 66) | def __init__(self, *args, **kwargs):
  class FrozenBatchNorm2d (line 78) | class FrozenBatchNorm2d(torch.nn.Module):
    method __init__ (line 85) | def __init__(
    method _load_from_state_dict (line 103) | def _load_from_state_dict(
    method forward (line 124) | def forward(self, x: Tensor) -> Tensor:
    method __repr__ (line 136) | def __repr__(self) -> str:
  class ConvNormActivation (line 142) | class ConvNormActivation(torch.nn.Sequential):
    method __init__ (line 157) | def __init__(
  class SElayer (line 197) | class SElayer(torch.nn.Module):
    method __init__ (line 207) | def __init__(
    method _scale (line 223) | def _scale(self, input: Tensor) -> Tensor:
    method forward (line 231) | def forward(self, input: Tensor) -> Tensor:
  class InvertedResidualConfig (line 235) | class InvertedResidualConfig:
    method __init__ (line 237) | def __init__(
    method adjust_channels (line 259) | def adjust_channels(channels: int, width_mult: float):
  class InvertedResidual (line 262) | class InvertedResidual(nn.Module):
    method __init__ (line 264) | def __init__(
    method forward (line 320) | def forward(self, input: Tensor) -> Tensor:
  class MobileNetV3 (line 328) | class MobileNetV3(nn.Module):
    method __init__ (line 329) | def __init__(
    method _forward_impl (line 419) | def _forward_impl(self, x: Tensor) -> Tensor:
    method forward (line 429) | def forward(self, x: Tensor) -> Tensor:
  class MobileNetV3LargeEncoder (line 435) | class MobileNetV3LargeEncoder(MobileNetV3):
    method __init__ (line 436) | def __init__(self, pretrained: bool = False,in_ch: int = 3):
    method forward_single_frame (line 474) | def forward_single_frame(self, x):
    method forward_time_series (line 501) | def forward_time_series(self, x):
    method forward (line 507) | def forward(self, x):

FILE: models/model.py
  class MattingNetwork (line 26) | class MattingNetwork(nn.Module):
    method __init__ (line 27) | def __init__(self,
    method forward (line 52) | def forward(self,
    method _interpolate (line 83) | def _interpolate(self, x: Tensor, scale_factor: float):

FILE: models/model_hyperlips.py
  function preprocess_sketch (line 66) | def preprocess_sketch(skecth,hr_size):
  function get_smoothened_landmarks (line 81) | def get_smoothened_landmarks(all_landmarks,windows_T,hr_size,base_size,):
  class FastGuidedFilterRefiner (line 155) | class FastGuidedFilterRefiner(nn.Module):
    method __init__ (line 156) | def __init__(self, *args, **kwargs):
    method forward_single_frame (line 160) | def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha):
    method forward_time_series (line 171) | def forward_time_series(self, fine_src, base_src, base_fgr, base_pha):
    method forward (line 182) | def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
  class FastGuidedFilter (line 189) | class FastGuidedFilter(nn.Module):
    method __init__ (line 190) | def __init__(self, r: int, eps: float = 1e-5):
    method forward (line 196) | def forward(self, lr_x, lr_y, hr_x):
  class BoxFilter (line 208) | class BoxFilter(nn.Module):
    method __init__ (line 209) | def __init__(self, r):
    method forward (line 213) | def forward(self, x):
  class Conv2d (line 226) | class Conv2d(nn.Module):
    method __init__ (line 227) | def __init__(self, cin, cout, kernel_size, stride, padding, residual=F...
    method forward (line 236) | def forward(self, x):
  class nonorm_Conv2d (line 242) | class nonorm_Conv2d(nn.Module):
    method __init__ (line 243) | def __init__(self, cin, cout, kernel_size, stride, padding, residual=F...
    method forward (line 250) | def forward(self, x):
  class Conv2dTranspose (line 254) | class Conv2dTranspose(nn.Module):
    method __init__ (line 255) | def __init__(self, cin, cout, kernel_size, stride, padding, output_pad...
    method forward (line 263) | def forward(self, x):
  class HyperFCNet (line 268) | class HyperFCNet(nn.Module):
    method __init__ (line 271) | def __init__(self,
    method forward (line 295) | def forward(self, x,f1, f2, f3, f4):#([1, 512])
    method double_conv (line 324) | def double_conv(self, in_channels, out_channels,hnet_hdim):
  class HyperLipsBase (line 369) | class HyperLipsBase(nn.Module):
    method __init__ (line 370) | def __init__(self):
    method forward (line 388) | def forward(self,audio_sequences: Tensor,face_sequences: Tensor):
    method _interpolate (line 415) | def _interpolate(self, x: Tensor, scale_factor: float):
  class HRDecoder (line 426) | class HRDecoder(nn.Module):
    method __init__ (line 427) | def __init__(self,rescaling=1):
    method forward (line 458) | def forward(self,x):
  class HRDecoder_disc_qual (line 466) | class HRDecoder_disc_qual(nn.Module):
    method __init__ (line 467) | def __init__(self):
    method forward (line 490) | def forward(self, face_sequences):
  class HyperLips_inference (line 501) | class HyperLips_inference(nn.Module):
    method __init__ (line 502) | def __init__(self,window_T,rescaling=1,base_model_checkpoint="",HRDeco...
    method forward (line 537) | def forward(self,
  class HyperCtrolDiscriminator (line 581) | class HyperCtrolDiscriminator(nn.Module):
    method __init__ (line 582) | def __init__(self):
    method get_lower_half (line 611) | def get_lower_half(self, face_sequences):
    method to_2d (line 614) | def to_2d(self, face_sequences):
    method perceptual_forward (line 619) | def perceptual_forward(self, false_face_sequences):
    method forward (line 632) | def forward(self, face_sequences):

FILE: models/resnet.py
  class ResNet50Encoder (line 5) | class ResNet50Encoder(ResNet):
    method __init__ (line 6) | def __init__(self, pretrained: bool = False,in_ch: int = 3):
    method forward_single_frame (line 22) | def forward_single_frame(self, x):
    method forward_time_series (line 37) | def forward_time_series(self, x):
    method forward (line 43) | def forward(self, x):

FILE: models/syncnet.py
  class SyncNet_color2 (line 7) | class SyncNet_color2(nn.Module):
    method __init__ (line 8) | def __init__(self):
    method forward (line 55) | def forward(self, audio_sequences, face_sequences): # audio_sequences ...
  class SyncNet_color (line 71) | class SyncNet_color(nn.Module):
    method __init__ (line 72) | def __init__(self):
    method forward (line 122) | def forward(self, audio_sequences, face_sequences):

FILE: preprocess.py
  function split_video_5s (line 94) | def split_video_5s(args):
  function get_sketch (line 153) | def get_sketch(hight,width,image,savepath):
  function get_landmarks (line 175) | def get_landmarks(image, face_mesh,hight,width):
  function get_mask (line 191) | def get_mask(hight,width,image,savepath):
  function data_process_hyper_base (line 214) | def data_process_hyper_base(args):
  function split_train_test_text (line 257) | def split_train_test_text(args):
  function data_process_hyper_hq_module (line 292) | def data_process_hyper_hq_module(args):
Condensed preview — 70 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,786K chars).
[
  {
    "path": "GFPGAN.py",
    "chars": 10566,
    "preview": "import cv2\r\nimport os\r\nimport torch\r\nfrom basicsr.utils import img2tensor\r\nfrom facexlib.utils.face_restoration_helper i"
  },
  {
    "path": "Gen_hyperlipsbase_videos.py",
    "chars": 3167,
    "preview": "from HYPERLIPS import Hyperlips\r\nimport argparse\r\nimport os\r\n\r\n\r\n\r\nparser = argparse.ArgumentParser(description='Inferen"
  },
  {
    "path": "HYPERLIPS.py",
    "chars": 12250,
    "preview": "import cv2, os, sys,audio\r\nimport subprocess, random, string\r\nfrom tqdm import tqdm\r\nimport torch, face_detection\r\nfrom "
  },
  {
    "path": "Inference_hyperlips.py",
    "chars": 14213,
    "preview": "import cv2, os, sys, argparse, audio\nimport subprocess, random, string\nfrom tqdm import tqdm\nimport torch, face_detectio"
  },
  {
    "path": "README.md",
    "chars": 4374,
    "preview": "# HyperLips: Hyper Control Lips with High Resolution Decoder for Talking Face Generation\r\nPytorch official implementatio"
  },
  {
    "path": "Train_data/video_clips/MEAD/readme.txt",
    "chars": 15,
    "preview": "Put video here."
  },
  {
    "path": "Train_hyperlipsBase.py",
    "chars": 17774,
    "preview": "from os.path import dirname, join, basename, isfile\r\nfrom tqdm import tqdm\r\nfrom models import SyncNet_color as SyncNet\r"
  },
  {
    "path": "Train_hyperlipsHR.py",
    "chars": 24854,
    "preview": "from os.path import dirname, join, basename, isfile\r\nfrom tqdm import tqdm\r\n\r\nfrom models import SyncNet_color as SyncNe"
  },
  {
    "path": "audio.py",
    "chars": 4887,
    "preview": "import librosa\r\nimport librosa.filters\r\nimport numpy as np\r\n# import tensorflow as tf\r\nfrom scipy import signal\r\nfrom sc"
  },
  {
    "path": "checkpoint",
    "chars": 2,
    "preview": "\r\n"
  },
  {
    "path": "checkpoints/readme.txt",
    "chars": 22,
    "preview": "Put checkpoint here.\r\n"
  },
  {
    "path": "color_syncnet_trainv3.py",
    "chars": 9831,
    "preview": "from os.path import dirname, join, basename, isfile\r\nfrom tqdm import tqdm\r\nfrom models import SyncNet_color as SyncNet\r"
  },
  {
    "path": "conv.py",
    "chars": 1664,
    "preview": "import torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\n\r\nclass Conv2d(nn.Module):\r\n    def __init__(s"
  },
  {
    "path": "datasets/MEAD/readme.txt",
    "chars": 40,
    "preview": "Put all the traning Mead .mp4 file here."
  },
  {
    "path": "environment.yml",
    "chars": 2671,
    "preview": "name: hyperlips\r\nchannels:\r\n  - http://mirrors.ustc.edu.cn/anaconda/pkgs/free/\r\n  - http://mirrors.ustc.edu.cn/anaconda/"
  },
  {
    "path": "face_detection/README.md",
    "chars": 209,
    "preview": "The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrian"
  },
  {
    "path": "face_detection/__init__.py",
    "chars": 190,
    "preview": "# -*- coding: utf-8 -*-\r\n\r\n__author__ = \"\"\"Adrian Bulat\"\"\"\r\n__email__ = 'adrian.bulat@nottingham.ac.uk'\r\n__version__ = '"
  },
  {
    "path": "face_detection/api.py",
    "chars": 2344,
    "preview": "from __future__ import print_function\r\nimport os\r\nimport torch\r\nfrom torch.utils.model_zoo import load_url\r\nfrom enum im"
  },
  {
    "path": "face_detection/detection/__init__.py",
    "chars": 30,
    "preview": "from .core import FaceDetector"
  },
  {
    "path": "face_detection/detection/core.py",
    "chars": 4998,
    "preview": "import logging\r\nimport glob\r\nfrom tqdm import tqdm\r\nimport numpy as np\r\nimport torch\r\nimport cv2\r\n\r\n\r\nclass FaceDetector"
  },
  {
    "path": "face_detection/detection/sfd/__init__.py",
    "chars": 53,
    "preview": "from .sfd_detector import SFDDetector as FaceDetector"
  },
  {
    "path": "face_detection/detection/sfd/bbox.py",
    "chars": 4408,
    "preview": "from __future__ import print_function\r\nimport os\r\nimport sys\r\nimport cv2\r\nimport random\r\nimport datetime\r\nimport time\r\ni"
  },
  {
    "path": "face_detection/detection/sfd/detect.py",
    "chars": 3881,
    "preview": "import torch\r\nimport torch.nn.functional as F\r\n\r\nimport os\r\nimport sys\r\nimport cv2\r\nimport random\r\nimport datetime\r\nimpo"
  },
  {
    "path": "face_detection/detection/sfd/net_s3fd.py",
    "chars": 5420,
    "preview": "import torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\n\r\nclass L2Norm(nn.Module):\r\n    def __init__(sel"
  },
  {
    "path": "face_detection/detection/sfd/sfd_detector.py",
    "chars": 1868,
    "preview": "import os\r\nimport cv2\r\nfrom torch.utils.model_zoo import load_url\r\n\r\nfrom ..core import FaceDetector\r\n\r\nfrom .net_s3fd i"
  },
  {
    "path": "face_detection/models.py",
    "chars": 8880,
    "preview": "import torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport math\r\n\r\n\r\ndef conv3x3(in_planes, out_planes"
  },
  {
    "path": "face_detection/utils.py",
    "chars": 12121,
    "preview": "from __future__ import print_function\r\nimport os\r\nimport sys\r\nimport time\r\nimport torch\r\nimport math\r\nimport numpy as np"
  },
  {
    "path": "face_parsing/README.md",
    "chars": 141,
    "preview": "Most of the code in this folder was taken from the awesome [face parsing](https://github.com/zllrunning/face-parsing.PyT"
  },
  {
    "path": "face_parsing/__init__.py",
    "chars": 42,
    "preview": "from .swap import init_parser,swap_regions"
  },
  {
    "path": "face_parsing/model.py",
    "chars": 10689,
    "preview": "#!/usr/bin/python\r\n# -*- encoding: utf-8 -*-\r\n\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n"
  },
  {
    "path": "face_parsing/resnet.py",
    "chars": 3940,
    "preview": "#!/usr/bin/python\r\n# -*- encoding: utf-8 -*-\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nim"
  },
  {
    "path": "face_parsing/swap.py",
    "chars": 3402,
    "preview": "import torch\r\nimport torchvision.transforms as transforms\r\nimport cv2\r\nimport numpy as np\r\nimport torch.nn.functional as"
  },
  {
    "path": "filelists/train.txt",
    "chars": 90,
    "preview": "MEAD/M003-002\r\nMEAD/M003-004\r\nMEAD/M003-003\r\nMEAD/M005-001\r\nMEAD/M003-005\r\nMEAD/M003-001\r\n"
  },
  {
    "path": "filelists/val.txt",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "filelists_lrs2/README.md",
    "chars": 55,
    "preview": "Place LRS2 (and any other) filelists here for training."
  },
  {
    "path": "filelists_lrs2/test.txt",
    "chars": 11,
    "preview": "00001\\00002"
  },
  {
    "path": "filelists_lrs2/train.txt",
    "chars": 1176093,
    "preview": "6251513448847229551/00029\r\n6251513448847229551/00015\r\n6251513448847229551/00003\r\n6251513448847229551/00032\r\n625151344884"
  },
  {
    "path": "filelists_lrs2/val.txt",
    "chars": 124362,
    "preview": "6364798794736712994/00010\r\n6364798794736712994/00012\r\n6364798794736712994/00008\r\n6055669382600582283/00001\r\n610387566603"
  },
  {
    "path": "filelists_mead/README.md",
    "chars": 55,
    "preview": "Place MEAD (and any other) filelists here for training."
  },
  {
    "path": "filelists_mead/test.txt",
    "chars": 11,
    "preview": "00001\\00002"
  },
  {
    "path": "filelists_mead/train.txt",
    "chars": 20998,
    "preview": "mead/W009-023\r\nmead/M028-012\r\nmead/M013-039\r\nmead/M009-025\r\nmead/M028-028\r\nmead/W014-031\r\nmead/M035-016\r\nmead/W019-014\r\n"
  },
  {
    "path": "filelists_mead/val.txt",
    "chars": 3148,
    "preview": "mead/M005-025\r\nmead/W016-009\r\nmead/M034-024\r\nmead/W037-039\r\nmead/M012-017\r\nmead/W015-019\r\nmead/W024-077\r\nmead/M030-014\r\n"
  },
  {
    "path": "gfpgan/gfpganv1_clean_arch.py",
    "chars": 13954,
    "preview": "import math\r\nimport random\r\nimport torch\r\nfrom basicsr.utils.registry import ARCH_REGISTRY\r\nfrom torch import nn\r\nfrom t"
  },
  {
    "path": "gfpgan/stylegan2_clean_arch.py",
    "chars": 14685,
    "preview": "import math\r\nimport random\r\nimport torch\r\nfrom basicsr.archs.arch_util import default_init_weights\r\nfrom basicsr.utils.r"
  },
  {
    "path": "hparams.py",
    "chars": 2198,
    "preview": "from glob import glob\r\nimport os\r\ndef get_image_list(data_root, split):\r\n\tfilelist = []\r\n\r\n\twith open('filelists/{}.txt'"
  },
  {
    "path": "hparams_Base.py",
    "chars": 2283,
    "preview": "from glob import glob\r\nimport os\r\ndef get_image_list(data_root, split):\r\n\tfilelist = []\r\n\r\n\twith open('filelists/{}.txt'"
  },
  {
    "path": "hparams_HR.py",
    "chars": 2395,
    "preview": "from glob import glob\r\nimport os\r\ndef get_image_list(data_root, split):\r\n\tfilelist = []\r\n\r\n\twith open('filelists/{}.txt'"
  },
  {
    "path": "inference.py",
    "chars": 3265,
    "preview": "from HYPERLIPS import Hyperlips\r\nimport argparse\r\nimport os\r\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\r\n\r\n\r\nparser = argp"
  },
  {
    "path": "models/__init__.py",
    "chars": 34,
    "preview": "from .syncnet import SyncNet_color"
  },
  {
    "path": "models/audio_v.py",
    "chars": 8233,
    "preview": "import librosa\r\nimport librosa.filters\r\nimport numpy as np\r\nfrom scipy import signal\r\nfrom scipy.io import wavfile\r\n\r\n\r\n"
  },
  {
    "path": "models/conv.py",
    "chars": 1664,
    "preview": "import torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\n\r\nclass Conv2d(nn.Module):\r\n    def __init__(s"
  },
  {
    "path": "models/decoder.py",
    "chars": 7379,
    "preview": "import torch\r\nfrom torch import Tensor\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\nfrom typing import T"
  },
  {
    "path": "models/deep_guided_filter.py",
    "chars": 2583,
    "preview": "import torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\n\r\n\"\"\"\r\nAdopted from <https://github.com/wuhuik"
  },
  {
    "path": "models/gfpganv1_clean_arch.py",
    "chars": 12623,
    "preview": "import math\r\nimport random\r\nimport torch\r\nfrom basicsr.utils.registry import ARCH_REGISTRY\r\nfrom torch import nn\r\nfrom t"
  },
  {
    "path": "models/guided_filter_pytorch/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "models/guided_filter_pytorch/box_filter.py",
    "chars": 1007,
    "preview": "import torch\r\nfrom torch import nn\r\n\r\ndef diff_x(input, r):\r\n    assert input.dim() == 4\r\n\r\n    left   = input[:, :,    "
  },
  {
    "path": "models/guided_filter_pytorch/guided_filter.py",
    "chars": 4301,
    "preview": "import torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Variable\r\n\r\nfrom .b"
  },
  {
    "path": "models/hyperlayers.py",
    "chars": 9895,
    "preview": "'''Pytorch implementations of hyper-network modules.'''\r\nimport torch\r\nimport torch.nn as nn\r\nimport functools\r\nimport t"
  },
  {
    "path": "models/hypernetwork.py",
    "chars": 648,
    "preview": "import torch.nn as nn\r\n\r\nclass HyperNetwork(nn.Module):\r\n  \"\"\"Hypernetwork architecture.\"\"\"\r\n  def __init__(self, in_dim"
  },
  {
    "path": "models/layers.py",
    "chars": 4288,
    "preview": "\"\"\"\r\nLayers for HyperRecon\r\nFor more details, please read:\r\n  Alan Q. Wang, Adrian V. Dalca, and Mert R. Sabuncu. \r\n  \"R"
  },
  {
    "path": "models/lraspp.py",
    "chars": 924,
    "preview": "from torch import nn\r\n\r\nclass LRASPP(nn.Module):\r\n    def __init__(self, in_channels, out_channels):\r\n        super().__"
  },
  {
    "path": "models/memory.py",
    "chars": 2525,
    "preview": "import torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\n\r\n\r\n\r\nclass Memory(nn.Module):\r\n    def __init"
  },
  {
    "path": "models/mobilenetv3.py",
    "chars": 19298,
    "preview": "from torch import nn\r\n\r\nfrom torch.hub import load_state_dict_from_url\r\nfrom torchvision.transforms.functional import no"
  },
  {
    "path": "models/model.py",
    "chars": 3808,
    "preview": "import torch\r\nfrom torch import Tensor\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\nfrom typing import O"
  },
  {
    "path": "models/model_hyperlips.py",
    "chars": 28600,
    "preview": "import os, random, cv2, argparse\r\n# os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\r\nimport torch\r\nfrom torch import Tensor\r\nfr"
  },
  {
    "path": "models/resnet.py",
    "chars": 1640,
    "preview": "from torch import nn\r\nfrom torchvision.models.resnet import ResNet, Bottleneck\r\n# from torchvision.models.utils import l"
  },
  {
    "path": "models/syncnet.py",
    "chars": 6454,
    "preview": "import torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\n\r\nfrom .conv import Conv2d\r\n\r\nclass SyncNet_co"
  },
  {
    "path": "preprocess.py",
    "chars": 14875,
    "preview": "import sys\r\nif sys.version_info[0] < 3 and sys.version_info[1] < 2:\r\n\traise Exception(\"Must be using >= Python 3.2\")\r\nfr"
  },
  {
    "path": "requirements.txt",
    "chars": 186,
    "preview": "librosa==0.9.2\r\nnumpy==1.21.5 \r\nopencv-contrib-python==4.7.0.72\r\nopencv-python==4.7.0.72\r\ntqdm==4.65.0\r\nnumba==0.56.4\r\nm"
  }
]

// ... and 1 more files (download for full content)

About this extraction

This page contains the full source code of the semchan/HyperLips GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 70 files (87.3 MB), approximately 630.3k tokens, and a symbol index with 529 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!