[
  {
    "path": ".gitignore",
    "content": ".idea/\n__pycache__\n**/__pycache__\ncache\nsubmit*\npretrain_models/*\n*.mp3\n*.mp4\ntmp*\noutput/*"
  },
  {
    "path": "DAWN_256.yaml",
    "content": "input_size: 256\nmax_n_frames: 200\nrandom_seed: 1234\nmean: [0.0, 0.0, 0.0]\nwin_width: 40\nsampling_step: 20\nddim_sampling_eta: 1.0\ncond_scale: 1.0\n\nmodel_config:\n  is_train: true\n  pose_dim: 6\n  config_pth: './config/hdtf256.yaml'\n  ae_pretrained_pth: './pretrain_models/LFG_256_400ep.pth'\n  diffusion_pretrained_pth: './pretrain_models/DAWN_256.pth'"
  },
  {
    "path": "DM_3/datasets_hdtf_wpose_lmk_block_lmk.py",
    "content": "# dataset for HDTF, stage 1\nfrom os import name\nimport sys\nsys.path.append('your_path')\n\nimport os\nimport random\nimport torch\n\nimport numpy as np\nimport torch.utils.data as data\nimport torch.nn.functional as Ft\nimport imageio.v2 as imageio\n\nimport cv2\nimport torchvision.transforms.functional as F\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nfrom scipy.interpolate import interp1d\nimport decord\nfrom torchvision.transforms.functional import to_pil_image\nfrom torchvision import transforms\nimport time\nimport pickle as pkl\n\ndecord.bridge.set_bridge('torch')\n\n\ndef resize(im, desired_size, interpolation):\n    old_size = im.shape[:2]\n    ratio = float(desired_size)/max(old_size)\n    new_size = tuple(int(x*ratio) for x in old_size)\n\n    im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation)\n    delta_w = desired_size - new_size[1]\n    delta_h = desired_size - new_size[0]\n    top, bottom = delta_h//2, delta_h-(delta_h//2)\n    left, right = delta_w//2, delta_w-(delta_w//2)\n\n    color = [0, 0, 0]\n    new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)\n\n    return new_im\n\nclass HDTF(data.Dataset):\n    def __init__(self, data_dir, pose_dir, eye_blink_dir, max_num_frames=80, image_size=128,mode='train',\n                 mean=(128, 128, 128), color_jitter=True):\n\n        super(HDTF, self).__init__()\n        self.mean = torch.tensor(mean)[None,:,None,None]\n        self.data_dir = data_dir\n        self.pose_dir = pose_dir\n        self.eye_blink_dir = eye_blink_dir\n        self.is_jitter = color_jitter\n        self.max_num_frames = max_num_frames\n        self.image_size = image_size\n        self.mode = mode\n\n        vid_list = []\n        # # crema\n        # self.hubert_dir = '/train20/intern/permanent/lmlin2/data/crema_wav_hubert'\n        # if mode == 'train':\n        #     for id_name in os.listdir(data_dir):\n        #         if id_name in ['s64','s76','s88','s90','s91']:\n        #             continue\n        #         vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ])\n        # if mode == 'test':\n        #     for id_name in ['s64','s76','s88','s90','s91']:\n        #         vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ])\n        # self.videos = vid_list\n\n        # hdtf  \n        vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\\\n                            'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\\\n                            'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\\\n                            'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\\\n                            'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\\\n                            'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\\\n                            'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\\\n                            
'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\\\n                            'WRA_MarcoRubio_000']\n\n        bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000','RD_Radio39_000']\n\n        # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list]\n        # bad_id_name = [item + '.mp4' for item in bad_id_name]\n        # hdtf  \n        self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate_chunk'  \n        self.mouth_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/mouth_ratio_bar'\n        self.lmk_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/lmk_25hz_chunk'\n        with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f:\n            self.len_dict = pkl.load(f)\n        # vid_id_name_list = ['RD_Radio47_000','WDA_CatherineCortezMasto_000','WDA_JoeNeguse_001','WDA_MichelleLujanGrisham_000','WRA_ErikPaulsen_002', \\\n        #                     'WDA_ZoeLofgren_000','WRA_JebHensarling2_003','WRA_MichaelSteele_000', 'WRA_ToddYoung_000', 'WRA_VickyHartzler_000']\n        if mode == 'train':\n            for id_name in os.listdir(data_dir):\n                # id_name = id_name[:-4]\n                if id_name in vid_id_name_list or id_name in bad_id_name:\n                    continue\n                vid_list.append(id_name)\n            self.videos = vid_list\n        if mode == 'test':\n            self.videos = vid_id_name_list\n\n    def check_head(self, frame_list, video_name, start, end):\n        '''\n        Check if the desired pose address exists.\n        '''\n        start_path = self.get_pose_path(frame_list, video_name, start)\n        end_path = self.get_pose_path(frame_list, video_name, end)\n\n        if os.path.exists(start_path) and os.path.exists(end_path):\n            return True\n        else:\n            return False\n\n\n    def get_block_data_for_two(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = block_st % 25\n        ed_pos = block_ed % 25\n\n        block_st_name = 'chunk_%04d.npy' % (block_st)\n        block_ed_name = 'chunk_%04d.npy' % (block_ed)\n\n        if block_st != block_ed:\n            block_st_path = os.path.join(path, block_st_name)\n            block_ed_path = os.path.join(path, block_ed_name)\n            block_st = np.load(block_st_path)\n            block_ed = np.load(block_ed_path)\n\n            return np.concatenate((block_st[st_pos:], block_ed[:ed_pos]))\n        else:\n            block_st_path = os.path.join(path, block_st_name)\n            block_st = np.load(block_st_path)\n            return block_st[st_pos, ed_pos]\n\n    def get_block_data(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = start % 25\n        ed_pos = end % 25\n\n        block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)]\n\n        if block_st != block_ed:\n            arr_list = []\n            block_st = np.load(block_list[0])\n            arr_list.append(block_st[st_pos:])\n        
    for path in block_list[1:-1]:\n                arr_list.append(np.load(path))\n\n            block_ed = np.load(block_list[-1])\n            arr_list.append(block_ed[:ed_pos])\n\n            return np.concatenate(arr_list)\n        else:\n            block_st_path = os.path.join(path, block_list[0])\n            block_st = np.load(block_st_path)\n            return block_st[st_pos: ed_pos]\n            \n\n    def check_len(self, name):\n        \n        return self.len_dict[name]\n\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, idx):\n        video_name = self.videos[idx]\n        path = os.path.join(self.data_dir, video_name)\n        hubert_path = os.path.join(self.hubert_dir, video_name)\n        lmk_path = os.path.join(self.lmk_dir, video_name)\n        pose_path = os.path.join(self.pose_dir, video_name)\n        eye_blink_path = os.path.join(self.eye_blink_dir, video_name)\n\n        total_num_frames = self.check_len(video_name)\n\n        \n        \n\n        if total_num_frames <= self.max_num_frames:\n            sample_frames = total_num_frames\n            start = 0\n        else:\n            sample_frames = self.max_num_frames\n            start = np.random.randint(total_num_frames-self.max_num_frames)\n        start=start\n        stop=sample_frames+start\n\n        sample_frame_npy = self.get_block_data(path = path, start = start, end = stop)\n        sample_hubert_feature_npy = self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32)\n        sample_pose_list_npy = self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32)\n        sample_eye_blink_list_npy = self.get_block_data(path = eye_blink_path, start = start, end = stop).astype(np.float32)\n\n\n        sample_frame_list = torch.tensor(sample_frame_npy).permute(0,3,1,2)\n        sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy)\n        sample_frame_list = sample_frame_list - self.mean # 20, 3, 128, 128\n        # sample_frame_list = [np.transpose(x, (2, 0, 1)) for x in sample_frame_list]\n        # sample_frame_list_npy = np.stack(sample_frame_list, axis=1) \n        # sample_pose_list_npy = np.stack(sample_pose_list, axis = 1)\n        # sample_eye_blink_list_npy = np.stack(sample_eye_blink_list, axis = 1)\n        # change to float32\n        sample_frame_list = sample_frame_list.permute(1, 0, 2, 3)\n        # sample_frame_list = np.array(sample_frame_list/255.0, dtype=np.float32)  #3, 40, 128, 128\n        # sample_frame_list = sample_frame_list/255.  
# put to model forward\n        # added to change the video_name of crema\n        video_name = video_name.replace('/','_')\n\n        sample_pose_list_npy = sample_pose_list_npy.transpose(1,0)  # for compatibility\n        sample_eye_blink_list_npy = sample_eye_blink_list_npy.transpose(1,0)\n        \n        # if __debug__:\n        #     end_time = time.time()  # end\n        #     print(f'process time {end_time- start_time}')  # spend lot of time\n        #     start_time = end_time\n        if self.mode == 'test':\n            return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, video_name, start\n        return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, video_name, total_num_frames\n\n\n\n\nif __name__ == \"__main__\":\n    # hdtf\n    data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n    pose_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/pose\"\n    # eye_blink_dir is required by HDTF.__init__; path below assumed to mirror the stage-2 script and may need adjusting\n    eye_blink_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/eye_blink_bbox_from_xpc_bar_2_chunk\"\n    # crema\n    # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'\n    dataset = HDTF(data_dir=data_dir, pose_dir=pose_dir, eye_blink_dir=eye_blink_dir, mode='train')\n    for i in range(10):\n        dataset.__getitem__(i)\n        print('------')\n\n    test_dataset = data.DataLoader(dataset=dataset,\n                                    batch_size=10,\n                                    num_workers=8,\n                                    shuffle=False)\n    for i, batch in enumerate(test_dataset):\n        print(i)\n"
  },
  {
    "path": "DM_3/datasets_hdtf_wpose_lmk_block_lmk_rand.py",
    "content": "# dataset for HDTF, stage 2\nfrom os import name\nimport sys\nsys.path.append('your_path')\n\nimport os\nimport random\nimport torch\n\nimport numpy as np\nimport torch.utils.data as data\nimport torch.nn.functional as Ft\nimport imageio.v2 as imageio\n\nimport cv2\nimport torchvision.transforms.functional as F\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nfrom scipy.interpolate import interp1d\nimport decord\nfrom torchvision.transforms.functional import to_pil_image\nfrom torchvision import transforms\nimport time\nimport pickle as pkl\n\ndecord.bridge.set_bridge('torch')\n\n\ndef resize(im, desired_size, interpolation):\n    old_size = im.shape[:2]\n    ratio = float(desired_size)/max(old_size)\n    new_size = tuple(int(x*ratio) for x in old_size)\n\n    im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation)\n    delta_w = desired_size - new_size[1]\n    delta_h = desired_size - new_size[0]\n    top, bottom = delta_h//2, delta_h-(delta_h//2)\n    left, right = delta_w//2, delta_w-(delta_w//2)\n\n    color = [0, 0, 0]\n    new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)\n\n    return new_im\n\nclass HDTF(data.Dataset):\n    def __init__(self, data_dir, pose_dir, eye_blink_dir, max_num_frames=80, image_size=128, audio_dir=None, ref_id = None, mode='train',\n                 mean=(128, 128, 128), color_jitter=True):\n\n        super(HDTF, self).__init__()\n        self.mean = torch.tensor(mean)[None,:,None,None]\n        self.data_dir = data_dir\n        self.pose_dir = pose_dir\n        self.eye_blink_dir = eye_blink_dir\n        self.is_jitter = color_jitter\n        self.max_num_frames = max_num_frames\n        self.image_size = image_size\n        self.mode = mode\n\n        vid_list = []\n        # # crema\n        # self.hubert_dir = '/train20/intern/permanent/lmlin2/data/crema_wav_hubert'\n        # if mode == 'train':\n        #     for id_name in os.listdir(data_dir):\n        #         if id_name in ['s64','s76','s88','s90','s91']:\n        #             continue\n        #         vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ])\n        # if mode == 'test':\n        #     for id_name in ['s64','s76','s88','s90','s91']:\n        #         vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ])\n        # self.videos = vid_list\n\n        # hdtf  \n        vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\\\n                            'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\\\n                            'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\\\n                            'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\\\n                            'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\\\n                            'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\\\n                            'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\\\n                            
'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\\\n                            'WRA_MarcoRubio_000']\n\n        bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000']\n\n        # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list]\n        # bad_id_name = [item + '.mp4' for item in bad_id_name]\n        # hdtf  \n        if audio_dir == None:\n            self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate_chunk'  \n        else:\n            self.hubert_dir = audio_dir\n\n        self.ref_id = ref_id\n        self.mouth_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/mouth_ratio_bar'\n        self.lmk_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/lmk_25hz_chunk'\n        with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f:\n            self.len_dict = pkl.load(f)\n        # vid_id_name_list = ['RD_Radio47_000','WDA_CatherineCortezMasto_000','WDA_JoeNeguse_001','WDA_MichelleLujanGrisham_000','WRA_ErikPaulsen_002', \\\n        #                     'WDA_ZoeLofgren_000','WRA_JebHensarling2_003','WRA_MichaelSteele_000', 'WRA_ToddYoung_000', 'WRA_VickyHartzler_000']\n        if mode == 'train':\n            for id_name in os.listdir(data_dir):\n                # id_name = id_name[:-4]\n                if id_name in vid_id_name_list or id_name in bad_id_name:\n                    continue\n                vid_list.append(id_name)\n            self.videos = vid_list\n        if mode == 'test':\n            self.videos = vid_id_name_list\n\n    def check_head(self, frame_list, video_name, start, end):\n        '''\n        Check if the desired pose address exists.\n        '''\n        start_path = self.get_pose_path(frame_list, video_name, start)\n        end_path = self.get_pose_path(frame_list, video_name, end)\n\n        if os.path.exists(start_path) and os.path.exists(end_path):\n            return True\n        else:\n            return False\n\n\n    def get_block_data_for_two(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = block_st % 25\n        ed_pos = block_ed % 25\n\n        block_st_name = 'chunk_%04d.npy' % (block_st)\n        block_ed_name = 'chunk_%04d.npy' % (block_ed)\n\n        if block_st != block_ed:\n            block_st_path = os.path.join(path, block_st_name)\n            block_ed_path = os.path.join(path, block_ed_name)\n            block_st = np.load(block_st_path)\n            block_ed = np.load(block_ed_path)\n\n            return np.concatenate((block_st[st_pos:], block_ed[:ed_pos]))\n        else:\n            block_st_path = os.path.join(path, block_st_name)\n            block_st = np.load(block_st_path)\n            return block_st[st_pos, ed_pos]\n\n    def get_block_data(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = start % 25\n        ed_pos = end % 25\n\n        block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)]\n\n        if block_st != block_ed:\n            arr_list = 
[]\n            block_st = np.load(block_list[0])\n            arr_list.append(block_st[st_pos:])\n            for path in block_list[1:-1]:\n                arr_list.append(np.load(path))\n\n            block_ed = np.load(block_list[-1])\n            arr_list.append(block_ed[:ed_pos])\n\n            return np.concatenate(arr_list)\n        else:\n            block_st_path = os.path.join(path, block_list[0])\n            block_st = np.load(block_st_path)\n            return block_st[st_pos: ed_pos]\n            \n\n    def check_len(self, name):\n        \n        return self.len_dict[name]\n\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, idx):\n        video_name = self.videos[idx]\n        path = os.path.join(self.data_dir, video_name)\n        hubert_path = os.path.join(self.hubert_dir, video_name)\n        lmk_path = os.path.join(self.lmk_dir, video_name)\n        pose_path = os.path.join(self.pose_dir, video_name)\n        eye_blink_path = os.path.join(self.eye_blink_dir, video_name)\n\n        total_num_frames = self.check_len(video_name) \n        \n\n        if total_num_frames <= self.max_num_frames:\n            sample_frames = total_num_frames\n            start = 0\n        else:\n            sample_frames = self.max_num_frames\n            start = np.random.randint(total_num_frames-self.max_num_frames)\n        start=start\n        stop=sample_frames+start\n        if self.ref_id == None:\n            ref_id = np.random.randint(total_num_frames)\n        elif self.ref_id == \"clip\":\n            ref_id = np.random.randint(sample_frames) + start\n        else:\n            ref_id = 0\n        \n\n        sample_frame_npy = self.get_block_data(path = path, start = start, end = stop)\n        sample_hubert_feature_npy = self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32)\n        sample_lmk_npy = self.get_block_data(path = lmk_path, start = start, end = stop).astype(np.float32)\n        sample_pose_list_npy = self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32)\n        sample_eye_blink_list_npy = self.get_block_data(path = eye_blink_path, start = start, end = stop).astype(np.float32)\n\n        ref_frame_npy = self.get_block_data(path = path, start = ref_id, end = ref_id + 1)\n        ref_hubert_feature_npy = self.get_block_data(path = hubert_path, start = ref_id, end = ref_id + 1).astype(np.float32)\n        ref_pose_list_npy = self.get_block_data(path = pose_path, start = ref_id, end = ref_id + 1).astype(np.float32)\n        ref_eye_blink_list_npy = self.get_block_data(path = eye_blink_path, start = ref_id, end =  ref_id + 1).astype(np.float32)\n\n        # mouth_path = os.path.join(self.mouth_dir, video_name+'.npy')\n        # mouth_seq = np.load(mouth_path).astype(np.float32)\n        # ref_mouth = mouth_seq[ref_id]\n        # mouth_seq = mouth_seq[start:stop]\n        mouth_lmk_tensor = torch.tensor(sample_lmk_npy[:,48:67])\n        \n\n        sample_frame_list = torch.tensor(sample_frame_npy).permute(0,3,1,2)\n        sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy)\n        sample_frame_list = sample_frame_list - self.mean # 20, 3, 128, 128\n\n        ref_frame_npy = torch.tensor(ref_frame_npy).permute(0,3,1,2)\n        ref_hubert_feature_npy = torch.tensor(ref_hubert_feature_npy)\n        ref_frame_npy = ref_frame_npy - self.mean # 20, 3, 128, 128\n\n        sample_hubert_feature_tensor = torch.concat([ref_hubert_feature_npy, 
sample_hubert_feature_tensor], dim = 0)\n        sample_frame_list = torch.concat([ref_frame_npy, sample_frame_list], dim = 0)\n        # sample_frame_list = [np.transpose(x, (2, 0, 1)) for x in sample_frame_list]\n        # sample_frame_list_npy = np.stack(sample_frame_list, axis=1) \n        # sample_pose_list_npy = np.stack(sample_pose_list, axis = 1)\n        # sample_eye_blink_list_npy = np.stack(sample_eye_blink_list, axis = 1)\n        # change to float32\n        sample_frame_list = sample_frame_list.permute(1, 0, 2, 3)\n        # sample_frame_list = np.array(sample_frame_list/255.0, dtype=np.float32)  #3, 40, 128, 128\n        # sample_frame_list = sample_frame_list/255.  # put to mode l forward\n        # added to change the video_name of crema\n        video_name = video_name.replace('/','_')\n\n        sample_pose_list_npy = np.concatenate([ref_pose_list_npy, sample_pose_list_npy], axis = 0)\n        sample_eye_blink_list_npy = np.concatenate([ref_eye_blink_list_npy, sample_eye_blink_list_npy], axis = 0)\n        sample_pose_list_npy = sample_pose_list_npy.transpose(1,0)  # for compatibility\n        sample_eye_blink_list_npy = sample_eye_blink_list_npy.transpose(1,0)\n\n        # mouth_seq = np.concatenate([ref_mouth[None], mouth_seq], axis = 0)\n        # mouth_seq_npy = mouth_seq.transpose(1,0)\n        \n        # if __debug__:\n        #     end_time = time.time()  # end\n        #     print(f'load data time {end_time- start_time}')  # spend lot of time\n        #     start_time = end_time\n        if self.mode == 'test':\n            return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, mouth_lmk_tensor, video_name, start\n        return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, mouth_lmk_tensor, video_name, total_num_frames\n\n\n\n\nif __name__ == \"__main__\":\n    # hdtf\n    data_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/images_25hz_128_chunk\"\n    pose_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar_chunk\"\n    eye_blink_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/eye_blink_bbox_from_xpc_bar_2_chunk\"\n    # crema\n    # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'\n    dataset = HDTF(data_dir=data_dir,\n                                       pose_dir=pose_dir,\n                                       eye_blink_dir = eye_blink_dir,\n                                       image_size=128,\n                                       max_num_frames=30,\n                                       color_jitter=True)\n    for i in range(10):\n        dataset.__getitem__(i)\n        print('------')    \n\n    test_dataset = data.DataLoader(dataset=dataset,\n                                    batch_size=10,\n                                    num_workers=8,\n                                    shuffle=False)\n    for i, batch in enumerate(test_dataset):\n        print(i)\n"
  },
  {
    "path": "DM_3/modules/local_attention.py",
    "content": "import sys\n# sys.path.append('your/path/DAWN-pytorch')\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport time\nfrom einops import rearrange\nfrom rotary_embedding_torch import RotaryEmbedding\nfrom einops_exts import rearrange_many\nimport math\nimport concurrent.futures\n# from local_attn_cuda_pkg.test_cuda import attn_forward\n# from local_attn_cuda_pkg.test_cuda import attn_forward, compute_res_forward\n# def attn_forward(x, y, batch_size, hw, seq_len, k_size, head, d, device):\n#     attn = torch.zeros(batch_size, hw, seq_len, k_size, head, device=device)\n#     local_attn_res.attn_cuda(x, y, attn, batch_size, hw, seq_len, k_size, head, d)\n#     return attn\n\n# def compute_res_forward(attn, z, batch_size, hw, seq_len, head, head_dim, k_size, device):\n#     res = torch.zeros(batch_size, hw, seq_len, head, head_dim, device=device)\n#     local_attn_res.compute_res_cuda(attn, z, res, batch_size, hw, seq_len, head, head_dim, k_size)\n#     return res\n\ndef exists(x):\n    return x is not None\n\ndef to_mask(x, mask, mode='mul'):\n    if mask is None:\n        return x\n    else:\n        while x.dim() > mask.dim():\n            mask = mask.unsqueeze(-1)\n        if mode == 'mul':\n            return x * mask\n        else:\n            return x + mask\n\n# def extract_seq_patches(x, kernel_size, rate):\n#     \"\"\"x.shape = [batch_size, seq_len, seq_dim]\"\"\"\n#     seq_len = x.size(1)\n#     seq_dim = x.size(2)\n#     k_size = kernel_size + (rate - 1) * (kernel_size - 1)\n#     p_right = (k_size - 1) // 2\n#     p_left = k_size - 1 - p_right\n#     x = F.pad(x, (0, 0, p_left, p_right), mode='constant', value=0)\n#     xs = [x[:, i: i + seq_len] for i in range(0, k_size, rate)]\n#     x = torch.cat(xs, dim=2)\n#     return x.reshape(-1, seq_len, kernel_size, seq_dim)\n\ndef extract_seq_patches(x, kernel_size, rate):\n    \"\"\"x.shape = [batch_size, hw, seq_len, seq_dim]\"\"\"\n    # batch_size, hw, seq_len, seq_dim = x.size()\n\n    # Calculate the size of the expanded kernel and the number of padding to be added on both sides.\n    k_size = kernel_size + (rate - 1) * (kernel_size - 1)\n    p_right = (k_size - 1) // 2\n    p_left = k_size - 1 - p_right\n\n    # padding\n    x = F.pad(x, (0, 0, p_left, p_right), mode='constant', value=0)  # pad only the second dimension\n\n    # Use the unfold method to extract sliding windows.\n    x_unfold = x.unfold(dimension=2, size=k_size, step=rate)  # x, window, k_size, step, rate\n    x_unfold = x_unfold.transpose(-1, -2)\n    \n    #  reshape (batch_size, hw, seq_len, kernel_size, seq_dim)\n    x_patches = x_unfold[:, :, :, ::rate]\n\n    return x_patches\n\ndef window_attn(x, y, z, kernel_size, mask, rate):\n    \"\"\"y.shape x.shape = [batch_size, hw, seq_len, self.heads, dim_head]\"\"\"\n    batch_size, hw, seq_len, head, head_dim = x.size()\n    device = x.device\n\n    # Calculate the size of the expanded kernel and the number of padding to be added on both sides.\n    k_size = kernel_size + (rate - 1) * (kernel_size - 1)\n    p_right = (k_size - 1) // 2\n    p_left = k_size - 1 - p_right\n\n    # padding\n    y = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)  # pad only the second dimension\n    z = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)\n\n    attn = torch.zeros(batch_size, hw, seq_len, k_size, head).to(device)\n    for i in range(seq_len):\n        # torch.matmul(x[:,:,i].unsqueeze(2), y[:,:,i:i + k_size].transpose()) # b, hw, 1, d ; b, hw, w, 
d\n        attn[:,:, i] = torch.einsum('b n h d, b n w h d -> b n w h', x[:,:,i], y[:,:,i:i + k_size]) \n    # reshape (batch_size, hw, seq_len, kernel_size, seq_dim)\n    # res = rearrange(res, 'b n l w h -> b n h l w')\n    attn = to_mask(attn, mask.unsqueeze(0), 'add')\n    attn = attn - attn.amax(dim=-2, keepdim=True).detach()\n    attn = F.softmax(attn, dim=-2)\n    res = torch.zeros(batch_size, hw, seq_len, head, head_dim).to(device)\n\n    for i in range(seq_len):\n        res[:,:,i] = torch.einsum('b n w h, b n w h d -> b n h d', attn[:,:,i], z[:,:,i : i +k_size])  # attn[:,:,i] * z[:,:,i : i +k_size]\n    res = res.view(batch_size, hw, seq_len, -1)\n    return res\n\n\ndef window_attn_2(x, y, z, kernel_size, mask, rate):  # bad optimization\n    \"\"\"\n    The optimized window_attn function eliminates two explicit for loops and utilizes tensor operations for parallel computation.\n    \n    param:\n        x (Tensor): [batch_size, hw, seq_len, heads, dim_head]\n        y (Tensor): [batch_size, hw, seq_len, heads, dim_head]\n        z (Tensor): [batch_size, hw, seq_len, heads, dim_head]\n        kernel_size (int): window size\n        mask (Tensor)\n        rate (int)\n    \n    return:\n        Tensor: [batch_size, hw, seq_len, heads * dim_head]\n    \"\"\"\n    batch_size, hw, seq_len, head, head_dim = x.size()\n\n    k_size = kernel_size + (rate - 1) * (kernel_size - 1)\n    p_right = (k_size - 1) // 2\n    p_left = k_size - 1 - p_right\n\n    y_padded = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)  # [batch_size, hw, seq_len + p_left + p_right, heads, dim_head]\n    z_padded = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)  \n\n    # y_windows  z_windows  [batch_size, hw, seq_len, k_size, heads, dim_head]\n    y_windows = y_padded.as_strided(\n        size=(batch_size, hw, seq_len, k_size, head, head_dim),\n        stride=(\n            y_padded.stride(0), \n            y_padded.stride(1), \n            y_padded.stride(2), \n            y_padded.stride(2),  \n            y_padded.stride(3), \n            y_padded.stride(4)\n        )\n    )\n    \n    z_windows = z_padded.as_strided(\n        size=(batch_size, hw, seq_len, k_size, head, head_dim),\n        stride=(\n            z_padded.stride(0), \n            z_padded.stride(1), \n            z_padded.stride(2), \n            z_padded.stride(2), \n            z_padded.stride(3), \n            z_padded.stride(4)\n        )\n    )\n\n    # x: [batch_size, hw, seq_len, heads, dim_head] -> [batch_size, hw, seq_len, 1, heads, dim_head]\n    x_expanded = x #.unsqueeze(3)  #  [batch_size, hw, seq_len, 1, heads, dim_head]\n\n    attn_scores = torch.einsum('b n l h d, b n l s h d -> b n l s h', x_expanded, y_windows)  \n\n    attn = to_mask(attn_scores, mask.unsqueeze(0), 'add')  \n\n    attn = attn - attn.amax(dim=-2, keepdim=True).detach()  \n    attn = F.softmax(attn, dim=-2)  #    Softmax on k_size\n\n    res = (attn.unsqueeze(-1) * z_windows).sum(dim=-3)  \n\n    res = res.view(batch_size, hw, seq_len, head * head_dim)\n\n    return res\n\ndef window_attn_stream(x, y, z, kernel_size, mask, rate):  # bad optimization\n    \"\"\"y.shape x.shape = [batch_size, hw, seq_len, self.heads, dim_head]\"\"\"\n    batch_size, hw, seq_len, head, head_dim = x.size()\n\n    # Calculate the size of the expanded kernel and the number of padding to be added on both sides.\n    k_size = kernel_size + (rate - 1) * (kernel_size - 1)\n    p_right = (k_size - 1) // 2\n    p_left = k_size - 1 - p_right\n\n    # 
padding\n    y = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)  # pad only the second dimension\n    z = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)\n\n    attn = torch.zeros(batch_size, hw, seq_len, k_size, head, device=x.device)\n    res = torch.zeros(batch_size, hw, seq_len, head, head_dim, device=x.device)\n\n    streams = [torch.cuda.Stream() for _ in range(seq_len)]\n\n    def compute_attn(i):\n        with torch.cuda.stream(streams[i]):\n            attn[:, :, i] = torch.einsum('b n h d, b n w h d -> b n w h', x[:, :, i], y[:, :, i:i + k_size])\n\n    def compute_res(i):\n        with torch.cuda.stream(streams[i]):\n            res[:, :, i] = torch.einsum('b n w h, b n w h d -> b n h d', attn[:, :, i], z[:, :, i:i + k_size])\n\n    for i in range(seq_len):\n        compute_attn(i)\n\n    for stream in streams:\n        stream.synchronize()\n\n    attn = to_mask(attn, mask.unsqueeze(0), 'add')\n    attn = attn - attn.amax(dim=-2, keepdim=True).detach()\n    attn = F.softmax(attn, dim=-2)\n\n    for i in range(seq_len):\n        compute_res(i)\n\n    for stream in streams:\n        stream.synchronize()\n\n    res = res.view(batch_size, hw, seq_len, -1)\n    return res\n\ndef create_sliding_window_mask(x, win_size, rate):\n    #  mask (len, len, head)\n    # assert mask.dim() == 3, \"The input mask must be of shape (len, len, head)\"\n\n    k_size = win_size + (rate - 1) * (win_size - 1)\n    p_right = (k_size - 1) // 2\n    p_left = k_size - 1 - p_right\n\n    # padding\n    x = F.pad(x, (p_left, p_right), mode='constant', value=-1e10)  # pad only the second dimension\n    res = []\n    for i in range(x.shape[1]):\n        res.append(x[:, i , i :i +k_size])\n\n    return torch.stack(res, dim = 1) # len k_size, head\n\nclass OurLayer(nn.Module):\n\n    def reuse(self, layer, *args, **kwargs):\n        outputs = layer(*args, **kwargs)\n        return outputs\n    \n \ndef heavy_computation(x, y, attn, k_size, i):\n        attn[:,:, i] = torch.einsum('b n h d, b n w h d -> b n w h', x[:,:,i], y[:,:,i:i + k_size]) \n\ndef heavy_computation2(res, z, attn, k_size, i):\n        res[:,:,i] = torch.einsum('b n w h, b n w h d -> b n h d', attn[:,:,i], z[:,:,i : i +k_size])  # attn[:,:,i] * z[:,:,i : i +k_size]\n\nfrom functools import partial\ndef window_attn_mp(x, y, z, kernel_size, mask, rate):\n    \"\"\"y.shape x.shape = [batch_size, hw, seq_len, self.heads, dim_head]\"\"\"\n    batch_size, hw, seq_len, head, head_dim = x.size()\n    device = x.device\n    # Calculate the size of the expanded kernel and the number of padding to be added on both sides.\n    k_size = kernel_size + (rate - 1) * (kernel_size - 1)\n    p_right = (k_size - 1) // 2\n    p_left = k_size - 1 - p_right\n    \n    # padding\n    y = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)  # pad only the second dimension\n    z = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)\n\n    attn =  torch.zeros(batch_size, hw, seq_len, k_size, head).to(device)\n\n    \n\n    unary = partial(heavy_computation, x,y,attn, k_size)\n    with concurrent.futures.ProcessPoolExecutor() as executor:\n        executor.map(unary, list(range(seq_len)))\n    # reshape (batch_size, hw, seq_len, kernel_size, seq_dim)\n    # res = rearrange(res, 'b n l w h -> b n h l w')\n    attn = to_mask(attn, mask.unsqueeze(0), 'add')\n    attn = attn - attn.amax(dim=-2, keepdim=True).detach()\n    attn = F.softmax(attn, dim=-2)\n    res = torch.zeros(batch_size, hw, seq_len, head, 
head_dim).to(device)\n\n    unary2 = partial(heavy_computation2, res,z,attn, k_size)\n    with concurrent.futures.ProcessPoolExecutor() as executor:\n        executor.map(unary2, list(range(seq_len)))\n    res = res.view(batch_size, hw, seq_len, -1)\n    return res\n\nclass LocalSelfAttention_opt(OurLayer):\n\n    def __init__(self, d_model, heads, size_per_head, neighbors=3, rate=1, rotary_emb=None,\n                 key_size=None, mask_right=False):\n        super(LocalSelfAttention_opt, self).__init__()\n        self.heads = heads\n        self.size_per_head = size_per_head\n        self.out_dim = heads * size_per_head\n        self.key_size = key_size if key_size else size_per_head\n        self.neighbors = neighbors\n        self.rate = rate\n        self.mask_right = mask_right\n\n        self.rotary_emb = rotary_emb\n        # self.q_dense = nn.Linear(self.key_size * self.heads, self.key_size * self.heads, bias=False)\n        # self.k_dense = nn.Linear(self.key_size * self.heads, self.key_size * self.heads, bias=False)\n        # self.v_dense = nn.Linear(self.key_size * self.heads, self.key_size * self.heads, bias=False)\n        # self.q_dense.weight.data.fill_(1)\n        # self.k_dense.weight.data.fill_(1)\n        # self.v_dense.weight.data.fill_(1)\n        self.to_qkv = nn.Linear(d_model, self.key_size * self.heads * 3, bias=False)\n        self.to_out = nn.Linear(self.key_size * self.heads, d_model, bias=False)\n        # self.to_qkv.weight.data.fill_(1)\n        # self.to_out.weight.data.fill_(1)\n\n    def forward(self, inputs, pos_bias,  focus_present_mask=None,):\n        # if isinstance(inputs, list):\n        #     x, x_mask = inputs\n        # else:\n        #     x, x_mask = inputs, None\n        x = inputs\n        x_mask = pos_bias\n\n        kernel_size = 1 + 2 * self.neighbors\n\n        # if x_mask is not None:\n        #     xp_mask = create_sliding_window_mask(x_mask, kernel_size, self.rate) # b, hw, seq, d_model -> b, hw, seq, win, d_model\n\n        batch_size, hw, seq_len, seq_dim = x.size()\n\n        if x_mask is not None:\n            xp_mask = x_mask.unsqueeze(0) # b, hw, seq, win, 1\n            v_mask = xp_mask\n        else:\n            v_mask = None\n\n        # k = self.k_dense(x)\n        # v = self.v_dense(x)\n        qw, k, v = self.to_qkv(x).chunk(3, dim=-1) # qw: b, hw, seq_len, d_model\n        \n        qw = qw/ (self.key_size ** 0.5)\n        qw = qw.view(batch_size, hw, seq_len, self.heads, self.key_size)\n        k = k.view(batch_size, hw, seq_len, self.heads, self.key_size) # b, hw,  seq_len,h, d_head\n        v = v.view(batch_size, hw, seq_len, self.heads, self.key_size)\n        st = time.time()\n        if exists(self.rotary_emb):\n            qw = self.rotary_emb.rotate_queries_or_keys(qw)\n            k = self.rotary_emb.rotate_queries_or_keys(k)\n        ed = time.time()\n        # print(\"rope local: \", ed - st)\n        st = time.time()\n        # qw = qw.view(batch_size * hw, seq_len, seq_dim) # b * hw, seq, d_model\n        # k = k.view(batch_size, hw, seq_len, self.key_size * self.heads)\n        \n        res = window_attn(qw, k, v, kernel_size, v_mask.permute(0, 2, 3, 1), rate = 1)\n        ed = time.time()\n        # print(\"rope local: \", ed - st)\n        return self.to_out(res)\n    \n\nclass MultiHeadLocalAttention(nn.Module):\n    def __init__(self, d_model, num_heads, window_size):\n        super(MultiHeadLocalAttention, self).__init__()\n        self.d_model = d_model\n        self.num_heads = num_heads\n        
self.window_size = window_size\n\n        assert d_model % num_heads == 0\n\n        self.depth = d_model // num_heads\n\n        self.query = nn.Linear(d_model, d_model)\n        self.key = nn.Linear(d_model, d_model)\n        self.value = nn.Linear(d_model, d_model)\n        # self.out = nn.Linear(d_model, d_model)\n\n        self.query.weight.data.fill_(1)\n        self.key.weight.data.fill_(1)\n        self.value.weight.data.fill_(1)\n        self.query.bias.data.fill_(0)\n        self.key.bias.data.fill_(0)\n        self.value.bias.data.fill_(0)\n\n\n    def split_heads(self, x, batch_size):\n        \"\"\"Split the last dimension into (num_heads, depth).\"\"\"\n        x = x.reshape(batch_size, -1, self.num_heads, self.depth)\n        return x.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, depth)\n\n    def forward(self, x):\n        batch_size, seq_len, d_model = x.size()\n        assert d_model == self.d_model\n\n        Q = self.split_heads(self.query(x), batch_size)\n        K = self.split_heads(self.key(x), batch_size)\n        V = self.split_heads(self.value(x), batch_size)\n\n        # Create the attention scores\n        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.depth ** 0.5)  # (batch_size, num_heads, seq_len, seq_len)\n\n        # Create the mask\n        mask = (torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)).abs()\n        mask = (mask > self.window_size).unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)\n        mask = mask.to(x.device)\n\n        # Apply the mask to the attention scores\n        attn_scores = attn_scores.masked_fill(mask, float('-inf'))\n\n        # Compute the attention weights\n        attn_weights = F.softmax(attn_scores, dim=-1)  # (batch_size, num_heads, seq_len, seq_len)\n\n        # Compute the output\n        output = torch.matmul(attn_weights, V)  # (batch_size, num_heads, seq_len, depth)\n        output = output.permute(0, 2, 1, 3) # (batch_size, seq_len, num_heads, depth)\n        output = output.reshape(batch_size, seq_len, d_model)\n\n        return output\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            heads=4,\n            dim_head=32,\n            rotary_emb=None\n    ):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n\n        self.rotary_emb = rotary_emb\n        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)\n        self.to_out = nn.Linear(hidden_dim, dim, bias=False)\n\n        self.to_qkv.weight.data.fill_(1)\n        self.to_out.weight.data.fill_(1)\n\n    def forward(\n            self,\n            x,\n            pos_bias=None,\n            focus_present_mask=None\n    ):\n        n, device = x.shape[-2], x.device\n\n        qkv = self.to_qkv(x).chunk(3, dim=-1)\n\n        if exists(focus_present_mask) and focus_present_mask.all():\n            # if all batch samples are focusing on present\n            # it would be equivalent to passing that token's values through to the output\n            values = qkv[-1]\n            return self.to_out(values)\n\n        # split out heads\n\n        q, k, v = rearrange_many(qkv, '... n (h d) -> ... 
h n d', h=self.heads)\n\n        # scale\n\n        q = q * self.scale\n\n        # rotate positions into queries and keys for time attention\n\n        st = time.time()\n        if exists(self.rotary_emb):\n            q = self.rotary_emb.rotate_queries_or_keys(q)\n            k = self.rotary_emb.rotate_queries_or_keys(k)\n\n        ed = time.time()\n        print(\"rope normal: \", ed - st)\n        # similarity\n\n        sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)\n\n        # relative positional bias\n\n        if exists(pos_bias):\n            sim = sim + pos_bias\n\n        if exists(focus_present_mask) and not (~focus_present_mask).all():\n            attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool)\n            attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)\n\n            mask = torch.where(\n                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),\n                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),\n                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),\n            )\n\n            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)\n\n        # numerical stability\n\n        sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n        attn = sim.softmax(dim=-1)\n\n        # aggregate values\n\n        out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)\n        out = rearrange(out, '... h n d -> ... n (h d)')\n        # return self.to_out(out)\n        return out\n    \nclass RelativePositionBias(nn.Module):\n    def __init__(\n            self,\n            heads=8,\n            num_buckets=32,\n            max_distance=128\n    ):\n        super().__init__()\n        self.num_buckets = num_buckets\n        self.max_distance = max_distance\n        self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n\n    @staticmethod\n    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):\n        ret = 0\n        n = -relative_position\n\n        num_buckets //= 2\n        ret += (n < 0).long() * num_buckets\n        n = torch.abs(n)\n\n        max_exact = num_buckets // 2\n        is_small = n < max_exact\n\n        val_if_large = max_exact + (\n                torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)\n        ).long()\n        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n\n        ret += torch.where(is_small, n, val_if_large)\n        return ret\n\n    def forward(self, n, device):\n        q_pos = torch.arange(n, dtype=torch.long, device=device)\n        k_pos = torch.arange(n, dtype=torch.long, device=device)\n        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,\n                                                   max_distance=self.max_distance)\n        mask = - (((rel_pos > 20) + (rel_pos < - 20)) * (1e10))\n        values = self.relative_attention_bias(rp_bucket)\n        return mask +  rearrange(values, 'i j h -> h i j')\n    \nif __name__ == \"__main__\":\n    # Example usage:\n    d_model = 256\n    window_size = 20\n    seq_len = 200\n    batch_size = 1\n    head = 4\n    res_pos = RelativePositionBias(heads=4, max_distance=32)\n    rope = RotaryEmbedding(min(64, d_model//head), seq_before_head_dim = True)\n    rope2 = RotaryEmbedding(min(64, d_model//head))\n    model = LocalSelfAttention_opt(d_model, head, 
d_model//head, window_size, rotary_emb=rope)\n\n    model_2 = Attention(d_model, head, dim_head=d_model//head, rotary_emb=rope2)\n\n    rp = res_pos(seq_len, 'cpu')\n    xp_mask = create_sliding_window_mask(rp, 2 * window_size + 1, 1)\n    for i in range(5):\n        x = torch.randn(batch_size, 9, seq_len, d_model)\n        st = time.time()\n        output = model(x, xp_mask)  # forward() takes the input and the windowed position bias as separate arguments\n        ed = time.time()\n        print(\"optimized: \", ed - st)\n        st = time.time()\n        output_2 = model_2(x, pos_bias=rp)\n        ed = time.time()\n        print(\"origin: \", ed - st)\n        print(((output - output_2)**2).mean())"
  },
  {
    "path": "DM_3/modules/text.py",
    "content": "# the code from https://github.com/lucidrains/video-diffusion-pytorch\nimport torch\nfrom einops import rearrange\n\n\ndef exists(val):\n    return val is not None\n\n\n# singleton globals\n\nMODEL = None\nTOKENIZER = None\nHUBERT_MODEL_DIM = 20*1024\n# BERT_MODEL_DIM = 768\n\n\ndef get_tokenizer():\n    global TOKENIZER\n    if not exists(TOKENIZER):\n        TOKENIZER = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')\n    return TOKENIZER\n\n\ndef get_bert():\n    global MODEL\n    if not exists(MODEL):\n        MODEL = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased')\n        if torch.cuda.is_available():\n            MODEL = MODEL.cuda()\n\n    return MODEL\n\n\n# tokenize\n\ndef tokenize(texts, add_special_tokens=True):\n    if not isinstance(texts, (list, tuple)):\n        texts = [texts]\n\n    tokenizer = get_tokenizer()\n\n    encoding = tokenizer.batch_encode_plus(\n        texts,\n        add_special_tokens=add_special_tokens,\n        padding=True,\n        return_tensors='pt'\n    )\n\n    token_ids = encoding.input_ids\n    return token_ids\n\n\n# embedding function\n\n@torch.no_grad()\ndef bert_embed(\n        token_ids,\n        return_cls_repr=False,\n        eps=1e-8,\n        pad_id=0.\n):\n    model = get_bert()\n    mask = token_ids != pad_id\n\n    if torch.cuda.is_available():\n        token_ids = token_ids.cuda()\n        mask = mask.cuda()\n\n    outputs = model(\n        input_ids=token_ids,\n        attention_mask=mask,\n        output_hidden_states=True\n    )\n\n    hidden_state = outputs.hidden_states[-1]\n\n    if return_cls_repr:\n        return hidden_state[:, 0]  # return [cls] as representation\n\n    if not exists(mask):\n        return hidden_state.mean(dim=1)\n\n    mask = mask[:, 1:]  # mean all tokens excluding [cls], accounting for length\n    mask = rearrange(mask, 'b n -> b n 1')\n\n    numer = (hidden_state[:, 1:] * mask).sum(dim=1)\n    denom = mask.sum(dim=1)\n    masked_mean = numer / (denom + eps)\n    return masked_mean\n"
  },
  {
    "path": "DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_6D.py",
    "content": "'''\nstage 1: using 0th as the reference, short and fixed clip\n\nwith lip loss, 6D pose, conditioned by cross attention\n\n'''\nimport os\nimport torch\nimport torch.nn as nn\nimport sys\nsys.path.append('your/path')\nfrom LFG.modules.generator import Generator\nfrom LFG.modules.bg_motion_predictor import BGMotionPredictor\nfrom LFG.modules.region_predictor import RegionPredictor\nfrom DM_3.modules.video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi import DynamicNfUnet3D, DynamicNfGaussianDiffusion\nimport yaml\nfrom sync_batchnorm import DataParallelWithCallback\nfrom filter_fourier import *\n\nfrom DM_v0_loss_exp.modules.util import AntiAliasInterpolation2d\nfrom torchvision import models\nimport numpy as np\nimport time\nfrom einops import rearrange\n\nclass Attention(nn.Module):\n    def __init__(self, params):\n        super(Attention, self).__init__()\n        self.fc_query = nn.Linear(params['n'], params['dim_attention'], bias=False)\n        self.fc_attention = nn.Linear(params['dim_attention'], 1)\n    \n    def forward(self, ctx_val, ctx_key, ctx_mask, ht_query):\n\n        ht_query = self.fc_query(ht_query)\n\n        attention_score = torch.tanh(ctx_key + ht_query[:, None, None, :])\n        attention_score = self.fc_attention(attention_score).squeeze(3)\n        \n        attention_score = attention_score - attention_score.max()\n        attention_score = torch.exp(attention_score) * ctx_mask\n        attention_score = attention_score / (attention_score.sum(2).sum(1)[:, None, None] + 1e-10)\n\n        ct = (ctx_val * attention_score[:, None, :, :]).sum(3).sum(2)\n\n        return ct, attention_score\n        \nclass Face_loc_Encoder(nn.Module):\n    def __init__(self, dim = 1):\n        super(Face_loc_Encoder, self).__init__()\n        self.conv1 = nn.Conv2d(dim, 8, kernel_size=3, stride=2, padding=1)\n        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = nn.functional.relu(x)\n        x = self.conv2(x)\n        x = nn.functional.relu(x)\n        return x\n\nclass Vgg19(torch.nn.Module):\n    \"\"\"\n    Vgg19 network for perceptual loss.\n    \"\"\"\n\n    def __init__(self, requires_grad=False):\n        super(Vgg19, self).__init__()\n        vgg_pretrained_features = models.vgg19(pretrained=True).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        for x in range(2):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(2, 7):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(7, 12):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(12, 21):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(21, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n\n        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),\n                                       requires_grad=False)\n        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),\n                                      requires_grad=False)\n\n        if not requires_grad:\n            for param in self.parameters():\n       
         param.requires_grad = False\n\n    def forward(self, x):\n        x = (x - self.mean) / self.std\n        h_relu1 = self.slice1(x)\n        h_relu2 = self.slice2(h_relu1)\n        h_relu3 = self.slice3(h_relu2)\n        h_relu4 = self.slice4(h_relu3)\n        h_relu5 = self.slice5(h_relu4)\n        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]\n        return out\n\nclass ImagePyramide(torch.nn.Module):\n    \"\"\"\n    Create image pyramide for computing pyramide perceptual loss.\n    \"\"\"\n\n    def __init__(self, scales, num_channels):\n        super(ImagePyramide, self).__init__()\n        downs = {}\n        for scale in scales:\n            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)\n        self.downs = nn.ModuleDict(downs)\n\n    def forward(self, x):\n        out_dict = {}\n        for scale, down_module in self.downs.items():\n            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)\n        return out_dict\n\nclass FlowDiffusion(nn.Module):\n    def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250,\n                 null_cond_prob=0.1,\n                 ddim_sampling_eta=1.,\n                 dim_mults=(1, 2, 4, 8),\n                 is_train=True,\n                 use_residual_flow=False,\n                 learn_null_cond=False,\n                 use_deconv=True,\n                 padding_mode=\"zeros\",\n                 pretrained_pth=\"\",\n                 config_pth=\"\"):\n        super(FlowDiffusion, self).__init__()\n        self.use_residual_flow = use_residual_flow\n\n        checkpoint = torch.load(pretrained_pth)\n        with open(config_pth) as f:\n            config = yaml.safe_load(f)\n\n        self.generator = Generator(num_regions=config['model_params']['num_regions'],\n                                   num_channels=config['model_params']['num_channels'],\n                                   revert_axis_swap=config['model_params']['revert_axis_swap'],\n                                   **config['model_params']['generator_params']).cuda()\n        self.generator.load_state_dict(checkpoint['generator'])\n        self.generator.eval()\n        self.set_requires_grad(self.generator, False)\n\n        self.region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'],\n                                                num_channels=config['model_params']['num_channels'],\n                                                estimate_affine=config['model_params']['estimate_affine'],\n                                                **config['model_params']['region_predictor_params']).cuda()\n        self.region_predictor.load_state_dict(checkpoint['region_predictor'])\n        self.region_predictor.eval()\n        self.set_requires_grad(self.region_predictor, False)\n\n        self.bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'],\n                                              **config['model_params']['bg_predictor_params'])\n        self.bg_predictor.load_state_dict(checkpoint['bg_predictor'])\n        self.bg_predictor.eval()\n        self.set_requires_grad(self.bg_predictor, False)\n\n        self.scales = config['train_params']['scales']\n        self.pyramid = ImagePyramide(self.scales, self.generator.num_channels)\n        if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        # self.vgg_loss_weights = config['train_params']['loss_weights']['rec_vgg']\n        # if 
sum(self.vgg_loss_weights) != 0:\n        #     self.vgg = Vgg19()\n        #     if torch.cuda.is_available():\n        #         self.vgg = self.vgg.cuda()\n\n        self.unet = DynamicNfUnet3D(dim=64,\n                           cond_dim=1024 + 6 + 2,\n                           cond_aud=1024,\n                           cond_pose=6,\n                           cond_eye=2,\n                           num_frames=num_frames,\n                           channels=3 + 256 + 16,\n                           out_grid_dim=2,\n                           out_conf_dim=1,\n                           dim_mults=dim_mults,\n                           use_hubert_audio_cond=True,\n                           learn_null_cond=learn_null_cond,\n                           use_final_activation=False,\n                           use_deconv=use_deconv,\n                           padding_mode=padding_mode)\n\n        self.diffusion = DynamicNfGaussianDiffusion(\n            denoise_fn = self.unet,\n            num_frames=num_frames,\n            image_size=img_size,\n            sampling_timesteps=sampling_timesteps,\n            timesteps=1000,  # number of steps\n            loss_type='l2',  # L1 or L2\n            use_dynamic_thres=True,\n            null_cond_prob=null_cond_prob,\n            ddim_sampling_eta=ddim_sampling_eta\n        )\n\n        self.face_loc_emb = Face_loc_Encoder()\n\n        # training\n        self.is_train = is_train\n        if self.is_train:\n            self.unet.train()\n            self.diffusion.train()\n\n    def update_num_frames(self, new_num_frames):\n        # to update num_frames of Unet3D and GaussianDiffusion\n        self.unet.update_num_frames(new_num_frames)\n        self.diffusion.update_num_frames(new_num_frames)\n\n    def generate_bbox_mask(self, bbox, size = 32):\n        # b = bbox.shape[0]\n\n        b, c, fn = bbox.size()\n        bbox = bbox[:,:,0]  # b, c, fn\n        bbox[:, :2] = (bbox[:, :2]/bbox[:, 4].unsqueeze(1)) * size  # rescale to 32* 32 for 128, and 64 * 64 for 256\n        bbox[:,2:4] = (bbox[:, 2:4]/bbox[:, 5].unsqueeze(1) )* size\n\n        bbox_left_top = bbox[:, :4:2].to(torch.int32)  # left up\n        bbox_right_bottom = (bbox[:, 1:4:2] +1).to(torch.int32)  # right down\n\n        # generating 2D index\n        row_indices = torch.arange(size).view(1, size, 1).expand(b, size, size).to(torch.uint8).cuda()\n        col_indices = torch.arange(size).view(1, 1, size).expand(b, size, size).to(torch.uint8).cuda()\n\n        # set the face bbox as 1, the first channel is y， the second is x\n        mask = (row_indices >= bbox_left_top[:, 1].view(b, 1, 1)) & (row_indices <= bbox_right_bottom[:, 1].view(b, 1, 1)) & \\\n            (col_indices >= bbox_left_top[:, 0].view(b, 1, 1)) & (col_indices <= bbox_right_bottom[:, 0].view(b, 1, 1))\n        # mask : b,32,32\n\n        bbox_mask = mask.unsqueeze(1).float()  # b, 1, 32, 32\n\n        return bbox_mask\n\n    def generate_mouth_mask(self, mouth_lmk, origin_size, size = 32):\n        b, fn, pn, c = mouth_lmk.size()  # b, fn  12,  2\n\n        origin_size = origin_size.unsqueeze(1)\n        mouth_lmk = (mouth_lmk/origin_size) * size\n\n        ld_coner = mouth_lmk.max(dim=-2)[0]\n        ru_coner = mouth_lmk.min(dim=-2)[0]\n\n        row_indices = torch.arange(size).view(1, size, 1).expand(b, fn, size, size).to(torch.uint8).cuda()\n        col_indices = torch.arange(size).view(1, 1, size).expand(b, fn, size, size).to(torch.uint8).cuda()\n\n        mask = (row_indices >= ru_coner[:,:, 1].view(b, fn, 
1, 1)) & (row_indices <= ld_coner[:,:, 1].view(b, fn, 1, 1)) & \\\n            (col_indices >= ru_coner[:,:, 0].view(b, fn, 1, 1)) & (col_indices <= ld_coner[:,:, 0].view(b, fn, 1, 1))\n\n        bbox_mask = (mask).float()  # b, 1, 32, 32\n\n        return bbox_mask\n\n    def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink, bbox, mouth_lmk, is_eval=False, ref_id = 0):  \n        if True:\n            b,c,f,h,w = real_vid.size()\n            real_vid = rearrange(real_vid, 'b c f h w -> (b f) c h w')\n            bright = 64. / 255\n            contrast = 0.25\n            sat = 0.25\n            hue = 0.04\n\n            color_jitters = transforms.ColorJitter(hue = (-hue, hue), \\\n                                                   contrast = (max(0, 1 - contrast), 1 + contrast), \n                                                   saturation = (max(0, 1 - sat), 1 + sat), \n                                                   brightness = (max(0, 1 - bright), 1 + bright))\n\n            # mast have shape :  [..., 1 or 3, H, W]\n            real_vid = real_vid/255.  # because the img are floats, so need to scale to 0-1\n            real_vid = color_jitters(real_vid)  # shape need be checked\n            real_vid = rearrange(real_vid, '(b f) c h w -> b c f h w', b = b, f = f)\n            ref_img = real_vid[:,:,ref_id,:,:].clone().detach()\n\n        b, _, nf, H, W = real_vid.size()\n\n        \n\n        ref_pose = ref_pose.squeeze(1).permute(0, 2, 1)[:, :, :-1]\n        ref_eye_blink = ref_eye_blink.squeeze(1).permute(0, 2, 1)\n\n        init_pose = ref_pose[:, ref_id].unsqueeze(1).repeat(1, nf, 1)     # b, fn, 7  init state\n        init_eye = ref_eye_blink[:, ref_id].unsqueeze(1).repeat(1, nf, 1) # b, fn, 2\n\n        ref_text = torch.concat([ref_text, (ref_pose-init_pose), (ref_eye_blink-init_eye)], dim=-1)\n\n        bbox_mask = self.generate_bbox_mask(bbox, size = real_vid.shape[-1])    # b, 1, 32, 32\n\n        bbox_mask = self.face_loc_emb(bbox_mask)  # conv encoder for face mask\n\n        mouth_mask = self.generate_mouth_mask(mouth_lmk, bbox[:,None,-2:,0], size = real_vid.shape[-1]//4)\n\n        real_grid_list = []\n        real_conf_list = []\n        real_out_img_list = []\n        real_warped_img_list = []\n        output_dict = {}\n        with torch.no_grad():\n\n            b,c,f,h,w = real_vid.size()\n            real_vid_tmp = rearrange(real_vid, 'b c f h w -> (b f) c h w') # real_vid.reshape(b * f, c, h,  w) \n            ref_img_tmp = ref_img.unsqueeze(1).repeat(1,f,1,1,1).reshape(-1, 3, h, w)\n            source_region_params = self.region_predictor(ref_img_tmp)\n           \n            driving_region_params = self.region_predictor(real_vid_tmp)\n            bg_params = self.bg_predictor(ref_img_tmp, real_vid_tmp)\n            generated = self.generator(ref_img_tmp, source_region_params=source_region_params,\n                                        driving_region_params=driving_region_params, bg_params=bg_params)\n            output_dict[\"real_vid_grid\"] = rearrange(generated[\"optical_flow\"], '(b f) h w c -> b c f h w', b = b, f = f)\n            output_dict[\"real_vid_conf\"] = rearrange(generated[\"occlusion_map\"], '(b f) c h w -> b c f h w', b = b, f = f)\n            output_dict[\"real_out_vid\"] = rearrange(generated[\"prediction\"], '(b f) c h w -> b c f h w', b = b, f = f)\n            output_dict[\"real_warped_vid\"] = rearrange(generated[\"deformed\"], '(b f) c h w -> b c f h w', b = b, f = f)\n        ref_img_fea = 
generated[\"bottle_neck_feat\"][::f].clone().detach()       #bs, 256, 32, 32\n        del real_vid_tmp, ref_img_tmp\n        del generated\n\n\n        if self.is_train:\n            if self.use_residual_flow:\n                h, w, = H // 4, W // 4\n                identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda()\n                output_dict[\"loss\"], output_dict[\"null_cond_mask\"] = self.diffusion(\n                    torch.cat((output_dict[\"real_vid_grid\"] - identity_grid,\n                               output_dict[\"real_vid_conf\"] * 2 - 1), dim=1),\n                    ref_img_fea,\n                    bbox_mask,\n                    ref_text)\n            else:\n                output_dict[\"loss\"], output_dict[\"null_cond_mask\"] = self.diffusion(\n                    torch.cat((output_dict[\"real_vid_grid\"],\n                               output_dict[\"real_vid_conf\"] * 2 - 1), dim=1),\n                    ref_img_fea,\n                    bbox_mask,\n                    ref_text)\n            \n            pred = self.diffusion.pred_x0\n            pred_flow = pred[:, :2, :, :, :]\n            pred_conf = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5\n            # loss_high_freq = hf_loss(fea = pred_flow, mask = self.gaussian_mask.cuda(), dim = 2)\n            loss_high_freq = nn.MSELoss(reduce = False)(pred_flow, output_dict[\"real_vid_grid\"]) + nn.MSELoss(reduce = False)(pred_conf, output_dict[\"real_vid_conf\"])\n            output_dict['mouth_loss'] = ((output_dict[\"loss\"] * mouth_mask.unsqueeze(1)).sum())/(mouth_mask.sum())\n            output_dict[\"loss\"] = output_dict[\"loss\"].mean(1)\n            output_dict[\"floss\"] = loss_high_freq.mean(1)\n\n            \n            if(is_eval):\n                with torch.no_grad():\n                    fake_out_img_list = []\n                    fake_warped_img_list = []\n                    pred = self.diffusion.pred_x0  # bs, 3, nf, 32, 32\n                    if self.use_residual_flow:\n                        output_dict[\"fake_vid_grid\"] = pred[:, :2, :, :, :] + identity_grid\n                    else:\n                        output_dict[\"fake_vid_grid\"] = pred[:, :2, :, :, :]  # optical flow predicted by DM_2   bs, 2, nf, 32, 32\n                    output_dict[\"fake_vid_conf\"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5 # occlusion map  predicted by DM_2   bs, 1, nf, 32, 32\n                    for idx in range(nf):\n                        fake_grid = output_dict[\"fake_vid_grid\"][:, :, idx, :, :].permute(0, 2, 3, 1) #bs, 32, 32, 2\n                        fake_conf = output_dict[\"fake_vid_conf\"][:, :, idx, :, :]                     #bs, 1, 32, 32\n                        # predict fake out image and fake warped image\n                        generated = self.generator.forward_with_flow(source_image=ref_img,\n                                                                        optical_flow=fake_grid,\n                                                                        occlusion_map=fake_conf)\n                        fake_out_img_list.append(generated[\"prediction\"])\n                        fake_warped_img_list.append(generated[\"deformed\"].detach())\n                        del generated\n                    output_dict[\"fake_out_vid\"] = torch.stack(fake_out_img_list, dim=2)\n                    output_dict[\"fake_warped_vid\"] = torch.stack(fake_warped_img_list, dim=2).detach()\n                    # output_dict[\"rec_loss\"] = 
nn.L1Loss(reduce=False)(real_vid, output_dict[\"fake_out_vid\"])\n                    # output_dict[\"rec_warp_loss\"] = nn.L1Loss(reduce=False)(real_vid, output_dict[\"fake_warped_vid\"])\n\n                    # b,c,f,h,w = real_vid.size()\n                    # real_vid_tensor = real_vid.permute(0,2,1,3,4).reshape(b*f,c,h,w).detach()\n                    # fake_out_vid_tensor  = output_dict[\"fake_out_vid\"].permute(0,2,1,3,4).reshape(b*f,c,h,w)\n\n                    \n\n                    # if sum(self.vgg_loss_weights) != 0:\n                    #     pyramide_real = self.pyramid(real_vid_tensor)\n                    #     pyramide_generated = self.pyramid(fake_out_vid_tensor)\n                    #     value_total = 0\n                    #     for scale in self.scales:\n                    #         x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n                    #         y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n                    #         for i, weight in enumerate(self.vgg_loss_weights):\n                    #             value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n                    #             value_total += self.vgg_loss_weights[i] * value\n                    #     output_dict['rec_vgg_loss'] = value_total\n                    # if __debug__:\n                    #     end_time = time.time()  # end\n                    #     # print(f'forward for eval part time {end_time- start_time}')\n                    #     start_time = end_time\n\n\n        return output_dict\n\n    def sample_one_video(self, real_vid, sample_img, sample_audio_hubert, sample_pose, sample_eye, sample_bbox, cond_scale):\n        output_dict = {} \n        sample_img_fea = self.generator.compute_fea(sample_img) # sample_img: bs,3,128,128 sample_img_fea: 1,256,32,32\n        bbox_mask = self.generate_bbox_mask(sample_bbox, size = sample_img.shape[-1])\n\n        bbox_mask = self.face_loc_emb(bbox_mask)  # conv encoder for face mask\n\n\n        ref_pose = sample_pose.permute(0, 2, 1)[:,:,:-1]\n        ref_eye_blink = sample_eye.permute(0, 2, 1)\n\n        init_pose = ref_pose[:, 0].unsqueeze(1).repeat(1,ref_pose.shape[1], 1)\n        init_eye = ref_eye_blink[:, 0].unsqueeze(1).repeat(1,ref_eye_blink.shape[1], 1) # b, fn, 2\n\n        ref_text = torch.concat([sample_audio_hubert, (ref_pose - init_pose), (ref_eye_blink - init_eye)], dim=-1)\n\n        bs = sample_img_fea.size(0)\n        # if cond_scale = 1.0, not using unconditional model\n        # pred bs, 3, nf, 32, 32\n        pred = self.diffusion.sample(sample_img_fea, bbox_mask, cond=ref_text,\n                                     batch_size=bs, cond_scale=cond_scale)\n        if self.use_residual_flow:\n            b, _, nf, h, w = pred[:, :2, :, :, :].size()\n            identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda()\n            output_dict[\"sample_vid_grid\"] = pred[:, :2, :, :, :] + identity_grid\n        else:\n            output_dict[\"sample_vid_grid\"] = pred[:, :2, :, :, :]  # bs, 2, nf, 32, 32\n        output_dict[\"sample_vid_conf\"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5  # bs, 1, nf, 32, 32\n        nf = output_dict[\"sample_vid_grid\"].size(2)\n        with torch.no_grad():\n            sample_out_img_list = []\n            sample_warped_img_list = []\n            for idx in range(nf):\n                sample_grid = output_dict[\"sample_vid_grid\"][:, :, idx, :, :].permute(0, 2, 3, 1)\n                sample_conf = 
output_dict[\"sample_vid_conf\"][:, :, idx, :, :]\n                # predict fake out image and fake warped image\n                generated = self.generator.forward_with_flow(source_image=sample_img,\n                                                             optical_flow=sample_grid,\n                                                             occlusion_map=sample_conf)\n                sample_out_img_list.append(generated[\"prediction\"])\n                sample_warped_img_list.append(generated[\"deformed\"])\n        output_dict[\"sample_out_vid\"] = torch.stack(sample_out_img_list, dim=2)\n        output_dict[\"sample_warped_vid\"] = torch.stack(sample_warped_img_list, dim=2)\n\n        output_dict[\"rec_loss\"] = nn.L1Loss(reduce=False)(real_vid, output_dict[\"sample_out_vid\"])\n        output_dict[\"rec_warp_loss\"] = nn.L1Loss(reduce=False)(real_vid, output_dict[\"sample_warped_vid\"])\n        \n        # b,c,f,h,w = real_vid[0].unsqueeze(dim=0).size()\n        # real_vid_tensor = real_vid[0].unsqueeze(dim=0).permute(0,2,1,3,4).reshape(b*f,c,h,w)\n        # fake_out_vid_tensor  = output_dict[\"sample_out_vid\"].permute(0,2,1,3,4).reshape(b*f,c,h,w)\n\n        # pyramide_real = self.pyramid(real_vid_tensor)\n        # pyramide_generated = self.pyramid(fake_out_vid_tensor)\n\n        # if sum(self.vgg_loss_weights) != 0:\n        #     value_total = 0\n        #     for scale in self.scales:\n        #         x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n        #         y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n        #         for i, weight in enumerate(self.vgg_loss_weights):\n        #             value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n        #             value_total += self.vgg_loss_weights[i] * value\n        #     output_dict['rec_vgg_loss'] = value_total\n\n        return output_dict\n\n    def get_grid(self, b, nf, H, W, normalize=True):\n        if normalize:\n            h_range = torch.linspace(-1, 1, H)\n            w_range = torch.linspace(-1, 1, W)\n        else:\n            h_range = torch.arange(0, H)\n            w_range = torch.arange(0, W)\n        grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b, 1, 1, 1).flip(3).float()  # flip h,w to x,y\n        return grid.permute(0, 3, 1, 2).unsqueeze(dim=2).repeat(1, 1, nf, 1, 1)\n\n    def set_requires_grad(self, nets, requires_grad=False):\n        \"\"\"Set requies_grad=Fasle for all the networks to avoid unnecessary computations\n        Parameters:\n            nets (network list)   -- a list of networks\n            requires_grad (bool)  -- whether the networks require gradients or not\n        \"\"\"\n        if not isinstance(nets, list):\n            nets = [nets]\n        for net in nets:\n            if net is not None:\n                for param in net.parameters():\n\n                    param.requires_grad = requires_grad\n\n\nif __name__ == \"__main__\":\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n    bs = 5\n    img_size = 128\n    num_frames = 10\n    ref_text = [\"play basketball\"] * bs\n    ref_img = torch.rand((bs, 3, img_size, img_size), dtype=torch.float32).cuda()\n    real_vid = torch.rand((bs, 3, num_frames, img_size, img_size), dtype=torch.float32).cuda()\n    model = FlowDiffusion(num_frames=num_frames, use_residual_flow=False, sampling_timesteps=10, dim_mults=(1, 2, 4, 8, 16))\n    model.cuda()\n    # embedding ref_text\n    # cond = bert_embed(tokenize(ref_text), 
return_cls_repr=model.diffusion.text_use_bert_cls).cuda()\n\n    # simulate a HuBERT audio embedding\n    cond = torch.rand((bs,10,1024), dtype=torch.float32).cuda()\n    model = DataParallelWithCallback(model)\n    # NOTE: forward() also expects ref_pose, ref_eye_blink, bbox and mouth_lmk;\n    # supply real conditioning tensors when running this smoke test.\n    output_dict = model.forward(real_vid=real_vid, ref_img=ref_img, ref_text=cond)\n    # NOTE: sample_one_video() also expects real_vid, sample_pose, sample_eye and sample_bbox.\n    model.module.sample_one_video(sample_img=ref_img[0].unsqueeze(dim=0),\n                                  sample_audio_hubert=cond[0].unsqueeze(dim=0),\n                                  cond_scale=1.0)\n"
  },
  {
    "path": "DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_rand_6D.py",
    "content": "'''\nstage 2: using random reference, long and dynamic clip\n\nwith lip loss, 6D pose, conditioned by cross attention\n\n'''\nimport os\nimport torch\nimport torch.nn as nn\nimport sys\nsys.path.append('your/path')\nfrom LFG.modules.generator import Generator\nfrom LFG.modules.bg_motion_predictor import BGMotionPredictor\nfrom LFG.modules.region_predictor import RegionPredictor\nfrom DM_3.modules.video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi import DynamicNfUnet3D, DynamicNfGaussianDiffusion\nimport yaml\nfrom sync_batchnorm import DataParallelWithCallback\nfrom filter_fourier import *\n\nfrom torchvision import models\nimport numpy as np\nimport time\nfrom einops import rearrange\n\nclass Attention(nn.Module):\n    def __init__(self, params):\n        super(Attention, self).__init__()\n        self.fc_query = nn.Linear(params['n'], params['dim_attention'], bias=False)\n        self.fc_attention = nn.Linear(params['dim_attention'], 1)\n    \n    def forward(self, ctx_val, ctx_key, ctx_mask, ht_query):\n\n        ht_query = self.fc_query(ht_query)\n\n        attention_score = torch.tanh(ctx_key + ht_query[:, None, None, :])\n        attention_score = self.fc_attention(attention_score).squeeze(3)\n        \n        attention_score = attention_score - attention_score.max()\n        attention_score = torch.exp(attention_score) * ctx_mask\n        attention_score = attention_score / (attention_score.sum(2).sum(1)[:, None, None] + 1e-10)\n\n        ct = (ctx_val * attention_score[:, None, :, :]).sum(3).sum(2)\n\n        return ct, attention_score\n        \nclass Face_loc_Encoder(nn.Module):\n    def __init__(self, dim = 1):\n        super(Face_loc_Encoder, self).__init__()\n        self.conv1 = nn.Conv2d(dim, 8, kernel_size=3, stride=2, padding=1)\n        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = nn.functional.relu(x)\n        x = self.conv2(x)\n        x = nn.functional.relu(x)\n        return x\n\nclass Vgg19(torch.nn.Module):\n    \"\"\"\n    Vgg19 network for perceptual loss.\n    \"\"\"\n\n    def __init__(self, requires_grad=False):\n        super(Vgg19, self).__init__()\n        vgg_pretrained_features = models.vgg19(pretrained=True).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        for x in range(2):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(2, 7):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(7, 12):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(12, 21):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(21, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n\n        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),\n                                       requires_grad=False)\n        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),\n                                      requires_grad=False)\n\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, x):\n  
      x = (x - self.mean) / self.std\n        h_relu1 = self.slice1(x)\n        h_relu2 = self.slice2(h_relu1)\n        h_relu3 = self.slice3(h_relu2)\n        h_relu4 = self.slice4(h_relu3)\n        h_relu5 = self.slice5(h_relu4)\n        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]\n        return out\n\n\nclass FlowDiffusion(nn.Module):\n    def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250,\n                 null_cond_prob=0.1,\n                 ddim_sampling_eta=1.,\n                 dim_mults=(1, 2, 4, 8),\n                 is_train=True,\n                 use_residual_flow=False,\n                 learn_null_cond=False,\n                 use_deconv=True,\n                 padding_mode=\"zeros\",\n                 pretrained_pth=\"your_path/data/log-hdtf/hdtf128_2023-11-17_20:13/snapshots/RegionMM.pth\",\n                 config_pth=\"your/path/DAWN-pytorch/config/hdtf128.yaml\"):\n        super(FlowDiffusion, self).__init__()\n        self.use_residual_flow = use_residual_flow\n\n        checkpoint = torch.load(pretrained_pth)\n        with open(config_pth) as f:\n            config = yaml.safe_load(f)\n\n        self.generator = Generator(num_regions=config['model_params']['num_regions'],\n                                   num_channels=config['model_params']['num_channels'],\n                                   revert_axis_swap=config['model_params']['revert_axis_swap'],\n                                   **config['model_params']['generator_params']).cuda()\n        self.generator.load_state_dict(checkpoint['generator'])\n        self.generator.eval()\n        self.set_requires_grad(self.generator, False)\n\n        self.region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'],\n                                                num_channels=config['model_params']['num_channels'],\n                                                estimate_affine=config['model_params']['estimate_affine'],\n                                                **config['model_params']['region_predictor_params']).cuda()\n        self.region_predictor.load_state_dict(checkpoint['region_predictor'])\n        self.region_predictor.eval()\n        self.set_requires_grad(self.region_predictor, False)\n\n        self.bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'],\n                                              **config['model_params']['bg_predictor_params'])\n        self.bg_predictor.load_state_dict(checkpoint['bg_predictor'])\n        self.bg_predictor.eval()\n        self.set_requires_grad(self.bg_predictor, False)\n\n        self.scales = config['train_params']['scales']\n\n        self.unet = DynamicNfUnet3D(dim=64,\n                           cond_dim=1024 + 6 + 2,\n                           cond_aud=1024,\n                           cond_pose=6,\n                           cond_eye=2,\n                           num_frames=num_frames,\n                           channels=3 + 256 + 16,\n                           out_grid_dim=2,\n                           out_conf_dim=1,\n                           dim_mults=dim_mults,\n                           use_hubert_audio_cond=True,\n                           learn_null_cond=learn_null_cond,\n                           use_final_activation=False,\n                           use_deconv=use_deconv,\n                           padding_mode=padding_mode)\n\n        self.diffusion = DynamicNfGaussianDiffusion(\n            denoise_fn = self.unet,\n            
num_frames=num_frames,\n            image_size=img_size,\n            sampling_timesteps=sampling_timesteps,\n            timesteps=1000,  # number of steps\n            loss_type='l2',  # L1 or L2\n            use_dynamic_thres=True,\n            null_cond_prob=null_cond_prob,\n            ddim_sampling_eta=ddim_sampling_eta\n        )\n\n        self.face_loc_emb = Face_loc_Encoder()\n\n        # training\n        self.is_train = is_train\n        if self.is_train:\n            self.unet.train()\n            self.diffusion.train()\n\n    def update_num_frames(self, new_num_frames):\n        # to update num_frames of Unet3D and GaussianDiffusion\n        self.unet.update_num_frames(new_num_frames)\n        self.diffusion.update_num_frames(new_num_frames)\n\n    def generate_bbox_mask(self, bbox, size = 32):\n        # b = bbox.shape[0]\n\n        b, c, fn = bbox.size()\n        bbox = bbox[:,:,0]  # b, c, fn\n        bbox[:, :2] = (bbox[:, :2]/bbox[:, 4].unsqueeze(1)) * size  # rescale 32* 32\n        bbox[:,2:4] = (bbox[:, 2:4]/bbox[:, 5].unsqueeze(1) )* size\n\n        bbox_left_top = bbox[:, :4:2].to(torch.int32)  \n        bbox_right_bottom = (bbox[:, 1:4:2] +1).to(torch.int32)  \n\n        row_indices = torch.arange(size).view(1, size, 1).expand(b, size, size).to(torch.uint8).cuda()\n        col_indices = torch.arange(size).view(1, 1, size).expand(b, size, size).to(torch.uint8).cuda()\n\n        mask = (row_indices >= bbox_left_top[:, 1].view(b, 1, 1)) & (row_indices <= bbox_right_bottom[:, 1].view(b, 1, 1)) & \\\n            (col_indices >= bbox_left_top[:, 0].view(b, 1, 1)) & (col_indices <= bbox_right_bottom[:, 0].view(b, 1, 1))\n\n        bbox_mask = mask.unsqueeze(1).float()  # b, 1, 32, 32\n\n        return bbox_mask\n\n    def generate_mouth_mask(self, mouth_lmk, origin_size, size = 32):\n        b, fn, pn, c = mouth_lmk.size()  # b, fn  12,  2\n\n        origin_size = origin_size.unsqueeze(1)\n        mouth_lmk = (mouth_lmk/origin_size) * size\n\n        ld_coner = mouth_lmk.max(dim=-2)[0]\n        ru_coner = mouth_lmk.min(dim=-2)[0]\n\n        row_indices = torch.arange(size).view(1, size, 1).expand(b, fn, size, size).to(torch.uint8).cuda()\n        col_indices = torch.arange(size).view(1, 1, size).expand(b, fn, size, size).to(torch.uint8).cuda()\n\n        mask = (row_indices >= ru_coner[:,:, 1].view(b, fn, 1, 1)) & (row_indices <= ld_coner[:,:, 1].view(b, fn, 1, 1)) & \\\n            (col_indices >= ru_coner[:,:, 0].view(b, fn, 1, 1)) & (col_indices <= ld_coner[:,:, 0].view(b, fn, 1, 1))\n\n        bbox_mask = (mask).float()  # b, 1, 32, 32\n\n        return bbox_mask\n\n    def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink, bbox, mouth_lmk, is_eval=False):\n        if True:\n            b,c,f,h,w = real_vid.size()\n            real_vid = rearrange(real_vid, 'b c f h w -> (b f) c h w')\n            bright = 64. 
/ 255\n            contrast = 0.25\n            sat = 0.25\n            hue = 0.04\n            # bright_f = random.uniform(max(0, 1 - bright), 1 + bright)\n            # contrast_f = random.uniform(max(0, 1 - contrast), 1 + contrast)\n            # sat_f = random.uniform(max(0, 1 - sat), 1 + sat)\n            # hue_f = random.uniform(-hue, hue)\n\n            color_jitters = transforms.ColorJitter(hue = (-hue, hue), \\\n                                                   contrast = (max(0, 1 - contrast), 1 + contrast), \n                                                   saturation = (max(0, 1 - sat), 1 + sat), \n                                                   brightness = (max(0, 1 - bright), 1 + bright))\n\n            # mast have shape :  [..., 1 or 3, H, W]\n            real_vid = real_vid/255.  # because the img are floats, so need to scale to 0-1\n            real_vid = color_jitters(real_vid)  # shape need be checked\n            real_vid = rearrange(real_vid, '(b f) c h w -> b c f h w', b = b, f = f)\n            ref_img = real_vid[:,:,0,:,:].clone().detach()\n            real_vid = real_vid[:,:,1:,:,:]\n\n            # if __debug__:\n            #     end_time = time.time()  # end\n            #     print(f'data augment time 1 {end_time- start_time}')\n            #     start_time = end_time\n\n            # sample_frame_list = color_jitters(sample_frame_list)\n            # if __debug__:\n            #     end_time = time.time()  # end\n            #     print(f'data augment time 2 {end_time- start_time}')\n            #     start_time = end_time\n        # else:\n        #     real_vid = rearrange(real_vid, 'b f c h w -> b c f h w')\n        # else:\n        #     real_vid = real_vid/255.\n        #     ref_img = ref_img/255.\n        b, _, _, H, W = real_vid.size()\n        _, nf, _ = ref_text.size()\n        ref_pose = ref_pose.squeeze(1).permute(0, 2, 1)[:, :, :-1]\n        ref_eye_blink = ref_eye_blink.squeeze(1).permute(0, 2, 1)\n\n        init_pose = ref_pose[:, 0].unsqueeze(1).repeat(1, nf, 1)     # b, fn, 7  init state\n        init_eye = ref_eye_blink[:, 0].unsqueeze(1).repeat(1, nf, 1) # b, fn, 2\n\n        ref_text = torch.concat([ref_text, (ref_pose-init_pose), (ref_eye_blink-init_eye)], dim=-1)\n        ref_text = ref_text[:, 1:]\n        \n        bbox_mask = self.generate_bbox_mask(bbox, size = real_vid.shape[-1])    # b, 1, 32, 32\n\n        bbox_mask = self.face_loc_emb(bbox_mask)  # conv encoder for face mask\n\n        mouth_mask = self.generate_mouth_mask(mouth_lmk, bbox[:,None,-2:,0], size = real_vid.shape[-1]//4)\n\n        real_grid_list = []\n        real_conf_list = []\n        real_out_img_list = []\n        real_warped_img_list = []\n        output_dict = {}\n\n        # if __debug__:\n        #     end_time = time.time()  # end\n        #     # print(f'forward process time {end_time- start_time}')\n        #     start_time = end_time\n        with torch.no_grad():\n            \n            # for idx in range(nf):\n                # driving_region_params = self.region_predictor(real_vid[:, :, idx, :, :])\n                # bg_params = self.bg_predictor(ref_img, real_vid[:, :, idx, :, :])\n                # generated = self.generator(ref_img, source_region_params=source_region_params,\n                #                            driving_region_params=driving_region_params, bg_params=bg_params)\n                # generated.update({'source_region_params': source_region_params,\n                #                   'driving_region_params': 
driving_region_params})\n                # real_grid_list.append(generated[\"optical_flow\"].permute(0, 3, 1, 2))\n                # # normalized occlusion map\n                # real_conf_list.append(generated[\"occlusion_map\"])\n                # real_out_img_list.append(generated[\"prediction\"])\n                # real_warped_img_list.append(generated[\"deformed\"])\n            b,c,f,h,w = real_vid.size()\n            real_vid_tmp = rearrange(real_vid, 'b c f h w -> (b f) c h w')# real_vid.reshape(b * f, c, h,  w) \n            ref_img_tmp = ref_img.unsqueeze(1).repeat(1,f,1,1,1).reshape(-1, 3, h, w)\n            source_region_params = self.region_predictor(ref_img_tmp)\n           \n            driving_region_params = self.region_predictor(real_vid_tmp)\n            bg_params = self.bg_predictor(ref_img_tmp, real_vid_tmp)\n            generated = self.generator(ref_img_tmp, source_region_params=source_region_params,\n                                        driving_region_params=driving_region_params, bg_params=bg_params)\n            output_dict[\"real_vid_grid\"] = rearrange(generated[\"optical_flow\"], '(b f) h w c -> b c f h w', b = b, f = f) # .permute(0,3,1,2).reshape(b, 2, f, 32, 32)\n            output_dict[\"real_vid_conf\"] = rearrange(generated[\"occlusion_map\"], '(b f) c h w -> b c f h w', b = b, f = f) # generated[\"occlusion_map\"].reshape(b, 1, f, 32, 32)\n            output_dict[\"real_out_vid\"] = rearrange(generated[\"prediction\"], '(b f) c h w -> b c f h w', b = b, f = f) # generated[\"prediction\"].reshape(b, 3, f, h, w)\n            output_dict[\"real_warped_vid\"] = rearrange(generated[\"deformed\"], '(b f) c h w -> b c f h w', b = b, f = f) # generated[\"deformed\"].reshape(b, 3, f, h, w)\n\n        # output_dict[\"real_vid_grid\"] = torch.stack(real_grid_list, dim=2)          # bs,2,num_frames,32,32\n        # output_dict[\"real_vid_conf\"] = torch.stack(real_conf_list, dim=2)          # bs,1,num_frames,32,32\n        # output_dict[\"real_out_vid\"] = torch.stack(real_out_img_list, dim=2)        # bs,3,num_frames,128,128\n        # output_dict[\"real_warped_vid\"] = torch.stack(real_warped_img_list, dim=2)  # bs,3,num_frames,128,128\n        # reference images are the same for different time steps, just pick the final one\n        # ref_img_fea = generated[\"bottle_neck_feat\"].clone().detach()       #bs, 256, 32, 32\n        ref_img_fea = generated[\"bottle_neck_feat\"][::f].clone().detach()       #bs, 256, 32, 32\n        del real_vid_tmp, ref_img_tmp\n        del generated\n\n        # if __debug__:\n        #     end_time = time.time()  # end\n        #     # print(f'generate gt flow time {end_time- start_time}')\n        #     start_time = end_time\n\n        if self.is_train:\n            if self.use_residual_flow:\n                h, w, = H // 4, W // 4\n                identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda()\n                output_dict[\"loss\"], output_dict[\"null_cond_mask\"] = self.diffusion(\n                    torch.cat((output_dict[\"real_vid_grid\"] - identity_grid,\n                               output_dict[\"real_vid_conf\"] * 2 - 1), dim=1),\n                    ref_img_fea,\n                    bbox_mask,\n                    ref_text)\n            else:\n                output_dict[\"loss\"], output_dict[\"null_cond_mask\"] = self.diffusion(\n                    torch.cat((output_dict[\"real_vid_grid\"],\n                               output_dict[\"real_vid_conf\"] * 2 - 1), dim=1),\n                    
ref_img_fea,\n                    bbox_mask,\n                    ref_text)\n            \n            pred = self.diffusion.pred_x0\n            pred_flow = pred[:, :2, :, :, :]\n            pred_conf = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5\n            # loss_high_freq = hf_loss(fea = pred_flow, mask = self.gaussian_mask.cuda(), dim = 2)\n            # loss_high_freq = hf_loss_2(pred_flow, output_dict[\"real_vid_grid\"], dim=2)\n            loss_high_freq = nn.MSELoss(reduce = False)(pred_flow, output_dict[\"real_vid_grid\"]) + nn.MSELoss(reduce = False)(pred_conf, output_dict[\"real_vid_conf\"])\n            output_dict['mouth_loss'] = ((output_dict[\"loss\"] * mouth_mask.unsqueeze(1)).sum())/(mouth_mask.sum())\n            output_dict[\"loss\"] = output_dict[\"loss\"].mean(1)\n            output_dict[\"floss\"] = loss_high_freq.mean(1)\n\n            # if __debug__:\n            #     end_time = time.time()  # end\n            #     # print(f'forward diffusion time {end_time- start_time}')\n            #     start_time = end_time\n            \n            if(is_eval):\n                with torch.no_grad():\n                    fake_out_img_list = []\n                    fake_warped_img_list = []\n                    pred = self.diffusion.pred_x0  # bs, 3, nf, 32, 32\n                    if self.use_residual_flow:\n                        output_dict[\"fake_vid_grid\"] = pred[:, :2, :, :, :] + identity_grid\n                    else:\n                        output_dict[\"fake_vid_grid\"] = pred[:, :2, :, :, :]  # optical flow predicted by DM_2   bs, 2, nf, 32, 32\n                    output_dict[\"fake_vid_conf\"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5 # occlusion map  predicted by DM_2   bs, 1, nf, 32, 32\n                    for idx in range(nf - 1):\n                        fake_grid = output_dict[\"fake_vid_grid\"][:, :, idx, :, :].permute(0, 2, 3, 1) #bs, 32, 32, 2\n                        fake_conf = output_dict[\"fake_vid_conf\"][:, :, idx, :, :]                     #bs, 1, 32, 32\n                        # predict fake out image and fake warped image\n                        generated = self.generator.forward_with_flow(source_image=ref_img,\n                                                                        optical_flow=fake_grid,\n                                                                        occlusion_map=fake_conf)\n                        fake_out_img_list.append(generated[\"prediction\"])\n                        fake_warped_img_list.append(generated[\"deformed\"].detach())\n                        del generated\n                    output_dict[\"fake_out_vid\"] = torch.stack(fake_out_img_list, dim=2)\n                    output_dict[\"fake_warped_vid\"] = torch.stack(fake_warped_img_list, dim=2).detach()\n                    # output_dict[\"rec_loss\"] = nn.L1Loss(reduce=False)(real_vid, output_dict[\"fake_out_vid\"])\n                    # output_dict[\"rec_warp_loss\"] = nn.L1Loss(reduce=False)(real_vid, output_dict[\"fake_warped_vid\"])\n\n                    # b,c,f,h,w = real_vid.size()\n                    # if __debug__:\n                    #     end_time = time.time()  # end\n                    #     # print(f'forward for eval part time {end_time- start_time}')\n                    #     start_time = end_time\n\n\n        return output_dict\n\n    def sample_one_video(self, real_vid, sample_img, sample_audio_hubert, sample_pose, sample_eye, sample_bbox, cond_scale):\n        output_dict = {} \n        sample_img_fea = 
self.generator.compute_fea(sample_img) # sample_img: bs,3,128,128 sample_img_fea: 1,256,32,32\n        bbox_mask = self.generate_bbox_mask(sample_bbox, size = sample_img.shape[-1])\n\n        bbox_mask = self.face_loc_emb(bbox_mask)  # conv encoder for face mask\n\n\n        ref_pose = sample_pose.permute(0, 2, 1)[:,:,:-1]\n        ref_eye_blink = sample_eye.permute(0, 2, 1)\n\n        init_pose = ref_pose[:, 0].unsqueeze(1).repeat(1,ref_pose.shape[1], 1)\n        init_eye = ref_eye_blink[:, 0].unsqueeze(1).repeat(1,ref_eye_blink.shape[1], 1) # b, fn, 2\n\n        ref_text = torch.concat([sample_audio_hubert, (ref_pose - init_pose), (ref_eye_blink - init_eye)], dim=-1)\n        ref_text = ref_text[:, 1:]\n        bs = sample_img_fea.size(0)\n        # if cond_scale = 1.0, not using unconditional model\n        # pred bs, 3, nf, 32, 32\n        pred = self.diffusion.sample(sample_img_fea, bbox_mask, cond=ref_text,\n                                     batch_size=bs, cond_scale=cond_scale)\n        if self.use_residual_flow:\n            b, _, nf, h, w = pred[:, :2, :, :, :].size()\n            identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda()\n            output_dict[\"sample_vid_grid\"] = pred[:, :2, :, :, :] + identity_grid\n        else:\n            output_dict[\"sample_vid_grid\"] = pred[:, :2, :, :, :]  # bs, 2, nf, 32, 32\n        output_dict[\"sample_vid_conf\"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5  # bs, 1, nf, 32, 32\n        nf = output_dict[\"sample_vid_grid\"].size(2)\n        with torch.no_grad():\n            sample_out_img_list = []\n            sample_warped_img_list = []\n            for idx in range(nf):\n                sample_grid = output_dict[\"sample_vid_grid\"][:, :, idx, :, :].permute(0, 2, 3, 1)\n                sample_conf = output_dict[\"sample_vid_conf\"][:, :, idx, :, :]\n                # predict fake out image and fake warped image\n                generated = self.generator.forward_with_flow(source_image=sample_img,\n                                                             optical_flow=sample_grid,\n                                                             occlusion_map=sample_conf)\n                sample_out_img_list.append(generated[\"prediction\"])\n                sample_warped_img_list.append(generated[\"deformed\"])\n        output_dict[\"sample_out_vid\"] = torch.stack(sample_out_img_list, dim=2)\n        output_dict[\"sample_warped_vid\"] = torch.stack(sample_warped_img_list, dim=2)\n\n        output_dict[\"rec_loss\"] = nn.L1Loss(reduce=False)(real_vid, output_dict[\"sample_out_vid\"])\n        output_dict[\"rec_warp_loss\"] = nn.L1Loss(reduce=False)(real_vid, output_dict[\"sample_warped_vid\"])\n        \n\n        return output_dict\n\n    def get_grid(self, b, nf, H, W, normalize=True):\n        if normalize:\n            h_range = torch.linspace(-1, 1, H)\n            w_range = torch.linspace(-1, 1, W)\n        else:\n            h_range = torch.arange(0, H)\n            w_range = torch.arange(0, W)\n        grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b, 1, 1, 1).flip(3).float()  # flip h,w to x,y\n        return grid.permute(0, 3, 1, 2).unsqueeze(dim=2).repeat(1, 1, nf, 1, 1)\n\n    def set_requires_grad(self, nets, requires_grad=False):\n        \"\"\"Set requies_grad=Fasle for all the networks to avoid unnecessary computations\n        Parameters:\n            nets (network list)   -- a list of networks\n            requires_grad (bool)  -- whether the networks require gradients or 
not\n        \"\"\"\n        if not isinstance(nets, list):\n            nets = [nets]\n        for net in nets:\n            if net is not None:\n                for param in net.parameters():\n                    param.requires_grad = requires_grad\n\n\nif __name__ == \"__main__\":\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n    bs = 5\n    img_size = 128\n    num_frames = 10\n    ref_text = [\"play basketball\"] * bs\n    ref_img = torch.rand((bs, 3, img_size, img_size), dtype=torch.float32).cuda()\n    real_vid = torch.rand((bs, 3, num_frames, img_size, img_size), dtype=torch.float32).cuda()\n    model = FlowDiffusion(num_frames=num_frames, use_residual_flow=False, sampling_timesteps=10, dim_mults=(1, 2, 4, 8, 16))\n    model.cuda()\n    # embedding ref_text\n    # cond = bert_embed(tokenize(ref_text), return_cls_repr=model.diffusion.text_use_bert_cls).cuda()\n\n    # simulate a HuBERT audio embedding\n    cond = torch.rand((bs,10,1024), dtype=torch.float32).cuda()\n    model = DataParallelWithCallback(model)\n    # NOTE: forward() also expects ref_pose, ref_eye_blink, bbox and mouth_lmk;\n    # supply real conditioning tensors when running this smoke test.\n    output_dict = model.forward(real_vid=real_vid, ref_img=ref_img, ref_text=cond)\n    # NOTE: sample_one_video() also expects real_vid, sample_pose, sample_eye and sample_bbox.\n    model.module.sample_one_video(sample_img=ref_img[0].unsqueeze(dim=0),\n                                  sample_audio_hubert=cond[0].unsqueeze(dim=0),\n                                  cond_scale=1.0)\n"
  },
  {
    "path": "DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_test.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\nimport sys\nfrom LFG.modules.generator import Generator\nfrom LFG.modules.bg_motion_predictor import BGMotionPredictor\nfrom LFG.modules.region_predictor import RegionPredictor\nfrom DM_3.modules.video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test import DynamicNfUnet3D, DynamicNfGaussianDiffusion\nimport yaml\nfrom sync_batchnorm import DataParallelWithCallback\nfrom filter_fourier import *\n\nfrom torchvision import models\nimport numpy as np\nimport time\nfrom einops import rearrange\n\nclass Attention(nn.Module):\n    def __init__(self, params):\n        super(Attention, self).__init__()\n        self.fc_query = nn.Linear(params['n'], params['dim_attention'], bias=False)\n        self.fc_attention = nn.Linear(params['dim_attention'], 1)\n    \n    def forward(self, ctx_val, ctx_key, ctx_mask, ht_query):\n\n        ht_query = self.fc_query(ht_query)\n\n        attention_score = torch.tanh(ctx_key + ht_query[:, None, None, :])\n        attention_score = self.fc_attention(attention_score).squeeze(3)\n        \n        attention_score = attention_score - attention_score.max()\n        attention_score = torch.exp(attention_score) * ctx_mask\n        attention_score = attention_score / (attention_score.sum(2).sum(1)[:, None, None] + 1e-10)\n\n        ct = (ctx_val * attention_score[:, None, :, :]).sum(3).sum(2)\n\n        return ct, attention_score\n        \nclass Face_loc_Encoder(nn.Module):\n    def __init__(self, dim = 1):\n        super(Face_loc_Encoder, self).__init__()\n        self.conv1 = nn.Conv2d(dim, 8, kernel_size=3, stride=2, padding=1)\n        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = nn.functional.relu(x)\n        x = self.conv2(x)\n        x = nn.functional.relu(x)\n        return x\n\nclass Vgg19(torch.nn.Module):\n    \"\"\"\n    Vgg19 network for perceptual loss.\n    \"\"\"\n\n    def __init__(self, requires_grad=False):\n        super(Vgg19, self).__init__()\n        vgg_pretrained_features = models.vgg19(pretrained=True).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        for x in range(2):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(2, 7):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(7, 12):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(12, 21):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(21, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n\n        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),\n                                       requires_grad=False)\n        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),\n                                      requires_grad=False)\n\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, x):\n        x = (x - self.mean) / self.std\n        h_relu1 = self.slice1(x)\n        h_relu2 = self.slice2(h_relu1)\n        h_relu3 = 
self.slice3(h_relu2)\n        h_relu4 = self.slice4(h_relu3)\n        h_relu5 = self.slice5(h_relu4)\n        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]\n        return out\n\n\nclass FlowDiffusion(nn.Module):\n    def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250, win_width = 40,\n                 null_cond_prob=0.1,\n                 ddim_sampling_eta=1.,\n                 pose_dim = 7,\n                 dim_mults=(1, 2, 4, 8),\n                 is_train=True,\n                 use_residual_flow=False,\n                 learn_null_cond=False,\n                 use_deconv=True,\n                 padding_mode=\"zeros\",\n                 pretrained_pth=\"your_path/data/log-hdtf/hdtf128_2023-11-17_20:13/snapshots/RegionMM.pth\",\n                 config_pth=\"your/path/DAWN-pytorch/config/hdtf128.yaml\"):\n        super(FlowDiffusion, self).__init__()\n        self.use_residual_flow = use_residual_flow\n\n        checkpoint = torch.load(pretrained_pth)\n        with open(config_pth) as f:\n            config = yaml.safe_load(f)\n\n        self.generator = Generator(num_regions=config['model_params']['num_regions'],\n                                   num_channels=config['model_params']['num_channels'],\n                                   revert_axis_swap=config['model_params']['revert_axis_swap'],\n                                   **config['model_params']['generator_params']).cuda()\n        self.generator.load_state_dict(checkpoint['generator'])\n        self.generator.eval()\n        self.set_requires_grad(self.generator, False)\n\n        self.region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'],\n                                                num_channels=config['model_params']['num_channels'],\n                                                estimate_affine=config['model_params']['estimate_affine'],\n                                                **config['model_params']['region_predictor_params']).cuda()\n        self.region_predictor.load_state_dict(checkpoint['region_predictor'])\n        self.region_predictor.eval()\n        self.set_requires_grad(self.region_predictor, False)\n\n        self.bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'],\n                                              **config['model_params']['bg_predictor_params'])\n        self.bg_predictor.load_state_dict(checkpoint['bg_predictor'])\n        self.bg_predictor.eval()\n        self.set_requires_grad(self.bg_predictor, False)\n\n        self.scales = config['train_params']['scales']\n        self.pose_dim = pose_dim\n        self.unet = DynamicNfUnet3D(dim=64,\n                           cond_dim=1024 + self.pose_dim + 2,\n                           cond_aud=1024,\n                           cond_pose=self.pose_dim,\n                           cond_eye=2,\n                           num_frames=num_frames,\n                           channels=3 + 256 + 16,\n                           out_grid_dim=2,\n                           out_conf_dim=1,\n                           dim_mults=dim_mults,\n                           use_hubert_audio_cond=True,\n                           learn_null_cond=learn_null_cond,\n                           use_final_activation=False,\n                           use_deconv=use_deconv,\n                           padding_mode=padding_mode, \n                           win_width = win_width)\n\n        self.diffusion = DynamicNfGaussianDiffusion(\n            denoise_fn = self.unet,\n       
     num_frames=num_frames,\n            image_size=img_size,\n            sampling_timesteps=sampling_timesteps,\n            timesteps=1000,  # number of steps\n            loss_type='l2',  # L1 or L2\n            use_dynamic_thres=True,\n            null_cond_prob=null_cond_prob,\n            ddim_sampling_eta=ddim_sampling_eta\n        )\n\n        self.face_loc_emb = Face_loc_Encoder()\n\n        # training\n        self.is_train = is_train\n        if self.is_train:\n            self.unet.train()\n            self.diffusion.train()\n\n    def update_num_frames(self, new_num_frames):\n        # to update num_frames of Unet3D and GaussianDiffusion\n        self.unet.update_num_frames(new_num_frames)\n        self.diffusion.update_num_frames(new_num_frames)\n\n    def generate_bbox_mask(self, bbox, size = 32):\n        # b = bbox.shape[0]\n\n        b, c, fn = bbox.size()\n        bbox = bbox[:,:,0]  # b, c, fn\n        bbox[:, :2] = (bbox[:, :2]/bbox[:, 4].unsqueeze(1)) * size  \n        bbox[:,2:4] = (bbox[:, 2:4]/bbox[:, 5].unsqueeze(1) )* size\n\n        bbox_left_top = bbox[:, :4:2].to(torch.int32)  \n        bbox_right_bottom = (bbox[:, 1:4:2] +1).to(torch.int32) \n\n        row_indices = torch.arange(size).view(1, size, 1).expand(b, size, size).to(torch.uint8).cuda()\n        col_indices = torch.arange(size).view(1, 1, size).expand(b, size, size).to(torch.uint8).cuda()\n\n        mask = (row_indices >= bbox_left_top[:, 1].view(b, 1, 1)) & (row_indices <= bbox_right_bottom[:, 1].view(b, 1, 1)) & \\\n            (col_indices >= bbox_left_top[:, 0].view(b, 1, 1)) & (col_indices <= bbox_right_bottom[:, 0].view(b, 1, 1))\n      \n        bbox_mask = mask.unsqueeze(1).float()  # b, 1, 32, 32\n\n        return bbox_mask\n\n    def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink, bbox, is_eval=False, ref_id = 0):\n        if True:\n            b,c,f,h,w = real_vid.size()\n            real_vid = rearrange(real_vid, 'b c f h w -> (b f) c h w')\n            bright = 64. / 255\n            contrast = 0.25\n            sat = 0.25\n            hue = 0.04\n\n            color_jitters = transforms.ColorJitter(hue = (-hue, hue), \\\n                                                   contrast = (max(0, 1 - contrast), 1 + contrast), \n                                                   saturation = (max(0, 1 - sat), 1 + sat), \n                                                   brightness = (max(0, 1 - bright), 1 + bright))\n\n            # mast have shape :  [..., 1 or 3, H, W]\n            real_vid = real_vid/255.  
# because the img are floats, so need to scale to 0-1\n            real_vid = color_jitters(real_vid)  # shape need be checked\n            real_vid = rearrange(real_vid, '(b f) c h w -> b c f h w', b = b, f = f)\n            ref_img = real_vid[:,:,ref_id,:,:].clone().detach()\n\n\n        b, _, nf, H, W = real_vid.size()\n\n        \n\n        ref_pose = ref_pose.squeeze(1).permute(0, 2, 1)\n        ref_eye_blink = ref_eye_blink.squeeze(1).permute(0, 2, 1)\n\n        init_pose = ref_pose[:, ref_id].unsqueeze(1).repeat(1, nf, 1)     # b, fn, 7  init state\n        init_eye = ref_eye_blink[:, ref_id].unsqueeze(1).repeat(1, nf, 1) # b, fn, 2\n\n        ref_text = torch.concat([ref_text, (ref_pose-init_pose), (ref_eye_blink-init_eye)], dim=-1)\n\n        bbox_mask = self.generate_bbox_mask(bbox, size = H)    # b, 1, 32, 32\n\n        bbox_mask = self.face_loc_emb(bbox_mask)  # conv encoder for face mask\n\n\n        real_grid_list = []\n        real_conf_list = []\n        real_out_img_list = []\n        real_warped_img_list = []\n        output_dict = {}\n\n     \n        with torch.no_grad():\n            \n            b,c,f,h,w = real_vid.size()\n            real_vid_tmp = rearrange(real_vid, 'b c f h w -> (b f) c h w')# real_vid.reshape(b * f, c, h,  w) \n            ref_img_tmp = ref_img.unsqueeze(1).repeat(1,f,1,1,1).reshape(-1, 3, 128, 128)\n            source_region_params = self.region_predictor(ref_img_tmp)\n           \n            driving_region_params = self.region_predictor(real_vid_tmp)\n            bg_params = self.bg_predictor(ref_img_tmp, real_vid_tmp)\n            generated = self.generator(ref_img_tmp, source_region_params=source_region_params,\n                                        driving_region_params=driving_region_params, bg_params=bg_params)\n            output_dict[\"real_vid_grid\"] = rearrange(generated[\"optical_flow\"], '(b f) h w c -> b c f h w', b = b, f = f) # .permute(0,3,1,2).reshape(b, 2, f, 32, 32)\n            output_dict[\"real_vid_conf\"] = rearrange(generated[\"occlusion_map\"], '(b f) c h w -> b c f h w', b = b, f = f) # generated[\"occlusion_map\"].reshape(b, 1, f, 32, 32)\n            output_dict[\"real_out_vid\"] = rearrange(generated[\"prediction\"], '(b f) c h w -> b c f h w', b = b, f = f) # generated[\"prediction\"].reshape(b, 3, f, h, w)\n            output_dict[\"real_warped_vid\"] = rearrange(generated[\"deformed\"], '(b f) c h w -> b c f h w', b = b, f = f) # generated[\"deformed\"].reshape(b, 3, f, h, w)\n\n        ref_img_fea = generated[\"bottle_neck_feat\"][::f].clone().detach()       #bs, 256, 32, 32\n        del real_vid_tmp, ref_img_tmp\n        del generated\n\n\n        if self.is_train:\n            if self.use_residual_flow:\n                h, w, = H // 4, W // 4\n                identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda()\n                output_dict[\"loss\"], output_dict[\"null_cond_mask\"] = self.diffusion(\n                    torch.cat((output_dict[\"real_vid_grid\"] - identity_grid,\n                               output_dict[\"real_vid_conf\"] * 2 - 1), dim=1),\n                    ref_img_fea,\n                    bbox_mask,\n                    ref_text)\n            else:\n                output_dict[\"loss\"], output_dict[\"null_cond_mask\"] = self.diffusion(\n                    torch.cat((output_dict[\"real_vid_grid\"],\n                               output_dict[\"real_vid_conf\"] * 2 - 1), dim=1),\n                    ref_img_fea,\n                    bbox_mask,\n                    
ref_text)\n            \n            pred = self.diffusion.pred_x0\n            pred_flow = pred[:, :2, :, :, :]\n            # loss_high_freq = hf_loss(fea = pred_flow, mask = self.gaussian_mask.cuda(), dim = 2)\n            loss_high_freq = hf_loss_2(pred_flow, output_dict[\"real_vid_grid\"], dim=2)\n            output_dict[\"loss\"] = output_dict[\"loss\"].mean(1)\n            output_dict[\"floss\"] = loss_high_freq.mean(1)\n\n            # if __debug__:\n            #     end_time = time.time()  # end\n            #     # print(f'forward diffusion time {end_time- start_time}')\n            #     start_time = end_time\n            \n            if(is_eval):\n                with torch.no_grad():\n                    fake_out_img_list = []\n                    fake_warped_img_list = []\n                    pred = self.diffusion.pred_x0  # bs, 3, nf, 32, 32\n                    if self.use_residual_flow:\n                        output_dict[\"fake_vid_grid\"] = pred[:, :2, :, :, :] + identity_grid\n                    else:\n                        output_dict[\"fake_vid_grid\"] = pred[:, :2, :, :, :]  # optical flow predicted by DM_2   bs, 2, nf, 32, 32\n                    output_dict[\"fake_vid_conf\"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5 # occlusion map  predicted by DM_2   bs, 1, nf, 32, 32\n                    for idx in range(nf):\n                        fake_grid = output_dict[\"fake_vid_grid\"][:, :, idx, :, :].permute(0, 2, 3, 1) #bs, 32, 32, 2\n                        fake_conf = output_dict[\"fake_vid_conf\"][:, :, idx, :, :]                     #bs, 1, 32, 32\n                        # predict fake out image and fake warped image\n                        generated = self.generator.forward_with_flow(source_image=ref_img,\n                                                                        optical_flow=fake_grid,\n                                                                        occlusion_map=fake_conf)\n                        fake_out_img_list.append(generated[\"prediction\"])\n                        fake_warped_img_list.append(generated[\"deformed\"].detach())\n                        del generated\n                    output_dict[\"fake_out_vid\"] = torch.stack(fake_out_img_list, dim=2)\n                    output_dict[\"fake_warped_vid\"] = torch.stack(fake_warped_img_list, dim=2).detach()\n\n\n        return output_dict\n\n    def sample_one_video(self, sample_img, sample_audio_hubert, sample_pose, sample_eye, sample_bbox, cond_scale, init_pose = None, init_eye = None, real_vid = None):\n        output_dict = {} \n        sample_img_fea = self.generator.compute_fea(sample_img) # sample_img: bs,3,128,128 sample_img_fea: 1,256,32,32\n        bbox_mask = self.generate_bbox_mask(sample_bbox, size = sample_img.shape[-1])\n\n        bbox_mask = self.face_loc_emb(bbox_mask)  # conv encoder for face mask\n\n        sample_pose = sample_pose[:,:self.pose_dim]\n        \n        ref_pose = sample_pose.permute(0, 2, 1)\n        ref_eye_blink = sample_eye.permute(0, 2, 1)\n\n        if init_pose == None:\n            init_pose = ref_pose[:, 0].unsqueeze(1).repeat(1,ref_pose.shape[1], 1)\n        else:\n            init_pose = init_pose.unsqueeze(1).repeat(1,ref_pose.shape[1], 1)\n        \n        init_pose = init_pose[:,:,:self.pose_dim]\n        if init_eye == None:\n            init_eye = ref_eye_blink[:, 0].unsqueeze(1).repeat(1,ref_eye_blink.shape[1], 1)\n        else:\n            init_eye = init_eye.unsqueeze(1).repeat(1,ref_eye_blink.shape[1], 1)\n\n 
       if ref_pose.shape[-1] != init_pose.shape[-1]:\n            ref_pose = torch.concat([ref_pose, init_pose[:,:,-1].unsqueeze(-1)], dim = -1)\n        ref_text = torch.concat([sample_audio_hubert, (ref_pose - init_pose), (ref_eye_blink - init_eye)], dim=-1)\n\n        bs = sample_img_fea.size(0)\n        # if cond_scale = 1.0, not using unconditional model\n        # pred bs, 3, nf, 32, 32\n        start_time = time.time()  # end\n        start_time_total = time.time()  # end\n\n        pred = self.diffusion.sample(sample_img_fea, bbox_mask, cond=ref_text,\n                                     batch_size=bs, cond_scale=cond_scale)\n        if self.use_residual_flow:\n            b, _, nf, h, w = pred[:, :2, :, :, :].size()\n            identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda()\n            output_dict[\"sample_vid_grid\"] = pred[:, :2, :, :, :] + identity_grid\n        else:\n            output_dict[\"sample_vid_grid\"] = pred[:, :2, :, :, :]  # bs, 2, nf, 32, 32\n        output_dict[\"sample_vid_conf\"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5  # bs, 1, nf, 32, 32\n        nf = output_dict[\"sample_vid_grid\"].size(2)\n\n        end_time = time.time()  # end\n        print(f'DDIM time {end_time- start_time}')\n        start_time = end_time\n        with torch.no_grad():\n            sample_out_img_list = []\n            sample_warped_img_list = []\n            for idx in range(nf):\n                sample_grid = output_dict[\"sample_vid_grid\"][:, :, idx, :, :].permute(0, 2, 3, 1)\n                sample_conf = output_dict[\"sample_vid_conf\"][:, :, idx, :, :]\n                # predict fake out image and fake warped image\n                generated = self.generator.forward_with_flow(source_image=sample_img,\n                                                             optical_flow=sample_grid,\n                                                             occlusion_map=sample_conf)\n                sample_out_img_list.append(generated[\"prediction\"])\n                sample_warped_img_list.append(generated[\"deformed\"])\n        output_dict[\"sample_out_vid\"] = torch.stack(sample_out_img_list, dim=2)\n        output_dict[\"sample_warped_vid\"] = torch.stack(sample_warped_img_list, dim=2)\n\n        # real_vid_tmp = rearrange(real_vids, 'b c f h w -> (b f) c h w')# real_vid.reshape(b * f, c, h,  w)\n        # with torch.no_grad():\n        #     sample_grid = output_dict[\"sample_vid_grid\"]\n        #     sample_grid = rearrange(sample_grid, 'b c f h w -> (b f) h w c')\n        #     sample_conf = output_dict[\"sample_vid_conf\"]\n        #     sample_conf = rearrange(sample_conf, 'b c f h w -> (b f) c h w')\n        #     sample_img = sample_img.repeat(nf, 1, 1, 1)\n        #     generated = self.generator.forward_with_flow(source_image=sample_img,\n        #                                                      optical_flow=sample_grid,\n        #                                                      occlusion_map=sample_conf)\n        #     output_dict[\"sample_out_vid\"] =  rearrange(generated[\"prediction\"], '(b f) c h w -> b c f h w', b = 1, f =nf)\n        end_time = time.time()  # end\n        # with open('your/path/DAWN-pytorch/speed_test.txt', 'a') as f:\n        #     f.write(f'AE time {end_time- start_time}\\n')\n        #     f.write(f'Total time {end_time- start_time_total}')\n        #     print(f'AE time {end_time- start_time}')\n        #     print(f'Total time {end_time- start_time_total}')\n        start_time = end_time\n\n        
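# output_dict now holds the sampled flow grid / occlusion map and the decoded + warped frame sequences\n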
        return output_dict\n\n    def get_grid(self, b, nf, H, W, normalize=True):\n        if normalize:\n            h_range = torch.linspace(-1, 1, H)\n            w_range = torch.linspace(-1, 1, W)\n        else:\n            h_range = torch.arange(0, H)\n            w_range = torch.arange(0, W)\n        grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b, 1, 1, 1).flip(3).float()  # flip h,w to x,y\n        return grid.permute(0, 3, 1, 2).unsqueeze(dim=2).repeat(1, 1, nf, 1, 1)\n\n    def set_requires_grad(self, nets, requires_grad=False):\n        \"\"\"Set requires_grad=False for all the networks to avoid unnecessary computations\n        Parameters:\n            nets (network list)   -- a list of networks\n            requires_grad (bool)  -- whether the networks require gradients or not\n        \"\"\"\n        if not isinstance(nets, list):\n            nets = [nets]\n        for net in nets:\n            if net is not None:\n                for param in net.parameters():\n                    param.requires_grad = requires_grad\n\n\nif __name__ == \"__main__\":\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n    bs = 5\n    img_size = 128\n    num_frames = 10\n    ref_text = [\"play basketball\"] * bs\n    ref_img = torch.rand((bs, 3, img_size, img_size), dtype=torch.float32).cuda()\n    real_vid = torch.rand((bs, 3, num_frames, img_size, img_size), dtype=torch.float32).cuda()\n    model = FlowDiffusion(num_frames=num_frames, use_residual_flow=False, sampling_timesteps=10, dim_mults=(1, 2, 4, 8, 16))\n    model.cuda()\n    # embedding ref_text\n    # cond = bert_embed(tokenize(ref_text), return_cls_repr=model.diffusion.text_use_bert_cls).cuda()\n\n    # random tensor to simulate a hubert audio embedding (bs, nf, 1024)\n    cond = torch.rand((bs,10,1024), dtype=torch.float32).cuda()\n    model = DataParallelWithCallback(model)\n    output_dict = model.forward(real_vid=real_vid, ref_img=ref_img, ref_text=cond)\n    model.module.sample_one_video(sample_img=ref_img[0].unsqueeze(dim=0),\n                                  sample_audio_hubert=cond[0].unsqueeze(dim=0),\n                                  cond_scale=1.0)\n"
  },
  {
    "path": "DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi.py",
    "content": "'''\nadding pose condtioning on baseline\nusing cross attention to add different condition\n\nfor training\n'''\nimport math\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom functools import partial\n\nfrom torchvision import transforms as T\nfrom PIL import Image\n\nfrom tqdm import tqdm\nfrom einops import rearrange, repeat, reduce, pack, unpack\nfrom einops_exts import rearrange_many\n\nfrom rotary_embedding_torch import RotaryEmbedding\n\n# from DM.modules.text import tokenize, bert_embed, HUBERT_MODEL_DIM\n\n\n# helpers functions\n\ndef exists(x):\n    return x is not None\n\n\ndef noop(*args, **kwargs):\n    pass\n\n\ndef is_odd(n):\n    return (n % 2) == 1\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if callable(d) else d\n\n\ndef cycle(dl):\n    while True:\n        for data in dl:\n            yield data\n\n\ndef num_to_groups(num, divisor):\n    groups = num // divisor\n    remainder = num % divisor\n    arr = [divisor] * groups\n    if remainder > 0:\n        arr.append(remainder)\n    return arr\n\n\ndef prob_mask_like(shape, prob, device):\n    if prob == 1:\n        return torch.ones(shape, device=device, dtype=torch.bool)\n    elif prob == 0:\n        return torch.zeros(shape, device=device, dtype=torch.bool)\n    else:\n        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob\n\n\ndef is_list_str(x):\n    if not isinstance(x, (list, tuple)):\n        return False\n    return all([type(el) == str for el in x])\n\n\n# relative positional bias\n\nclass RelativePositionBias(nn.Module):\n    def __init__(\n            self,\n            heads=8,\n            num_buckets=32,\n            max_distance=128\n    ):\n        super().__init__()\n        self.num_buckets = num_buckets\n        self.max_distance = max_distance\n        self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n\n    @staticmethod\n    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):\n        ret = 0\n        n = -relative_position\n\n        num_buckets //= 2\n        ret += (n < 0).long() * num_buckets\n        n = torch.abs(n)\n\n        max_exact = num_buckets // 2\n        is_small = n < max_exact\n\n        val_if_large = max_exact + (\n                torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)\n        ).long()\n        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n\n        ret += torch.where(is_small, n, val_if_large)\n        return ret\n\n    def forward(self, n, device):\n        q_pos = torch.arange(n, dtype=torch.long, device=device)\n        k_pos = torch.arange(n, dtype=torch.long, device=device)\n        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,\n                                                   max_distance=self.max_distance)\n        # mask = -(((rel_pos > 35) + (rel_pos < -35)) * (1e8)) # -(((rp_bucket ==15) + (rp_bucket >= 30)) * (1e8))\n        values = self.relative_attention_bias(rp_bucket)\n        return rearrange(values, 'i j h -> h i j') # + mask\n\n\n\n# small helper modules\n\nclass EMA():\n    def __init__(self, beta):\n        super().__init__()\n        self.beta = beta\n\n    def update_model_average(self, ma_model, current_model):\n        for current_params, ma_params in zip(current_model.parameters(), 
ma_model.parameters()):\n            old_weight, up_weight = ma_params.data, current_params.data\n            ma_params.data = self.update_average(old_weight, up_weight)\n\n    def update_average(self, old, new):\n        if old is None:\n            return new\n        return old * self.beta + (1 - self.beta) * new\n\n\nclass Residual(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, x, *args, **kwargs):\n        return self.fn(x, *args, **kwargs) + x\n\n\nclass SinusoidalPosEmb(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.dim = dim\n\n    def forward(self, x):\n        device = x.device\n        half_dim = self.dim // 2\n        emb = math.log(10000) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)\n        emb = x[:, None] * emb[None, :]\n        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)\n        return emb\n\n\ndef Upsample(dim, use_deconv=True, padding_mode=\"reflect\"):\n    if use_deconv:\n        return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))\n    else:\n        return nn.Sequential(\n            nn.Upsample(scale_factor=(1, 2, 2), mode='nearest'),\n            nn.Conv3d(dim, dim, (1, 3, 3), (1, 1, 1), (0, 1, 1), padding_mode=padding_mode)\n        )\n\n\ndef Downsample(dim):\n    return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))\n\n\nclass LayerNorm(nn.Module):\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.eps = eps\n        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))\n\n    def forward(self, x):\n        var = torch.var(x, dim=1, unbiased=False, keepdim=True)\n        mean = torch.mean(x, dim=1, keepdim=True)\n        return (x - mean) / (var + self.eps).sqrt() * self.gamma\n\nclass LayerNorm_img(nn.Module):\n    def __init__(self, dim, stable = False):\n        super().__init__()\n        self.stable = stable\n        self.g = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        if self.stable:\n            x = x / x.amax(dim = -1, keepdim = True).detach()\n\n        eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)\n        mean = torch.mean(x, dim = -1, keepdim = True)\n        return (x - mean) * (var + eps).rsqrt() * self.g\n\nclass PreNorm(nn.Module):\n    def __init__(self, dim, fn):\n        super().__init__()\n        self.fn = fn\n        self.norm = LayerNorm(dim)\n\n    def forward(self, x, **kwargs):\n        x = self.norm(x)\n        return self.fn(x, **kwargs)\n\nclass Identity(nn.Module):\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n\n    def forward(self, x, *args, **kwargs):\n        return x\n\ndef l2norm(t):\n    return F.normalize(t, dim = -1)\n# building block modules\n\nclass Block(nn.Module):\n    def __init__(self, dim, dim_out, groups=8):\n        super().__init__()\n        self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1))\n        self.norm = nn.GroupNorm(groups, dim_out)\n        self.act = nn.SiLU()\n\n    def forward(self, x, time_scale_shift=None, audio_scale_shift=None):\n        x = self.proj(x)\n        x = self.norm(x)\n\n        if exists(time_scale_shift):\n            time_scale, time_shift = time_scale_shift\n            x = x * (time_scale + 1) + time_shift\n        \n        # added by lml to change the control method of audio embedding, inspired by diffusedhead\n        # if 
exists(audio_scale_shift):\n        #     # audio_scale and audio_shift:(bs, 64, nf, 1, 1) \n        #     # x:(bs, 64, nf, 32, 32)\n        #     audio_scale, audio_shift = audio_scale_shift\n        #     x = x * (audio_scale + 1) + audio_shift\n\n        return self.act(x)\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8):\n        super().__init__()\n        self.time_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(time_emb_dim, dim_out * 2)\n        ) if exists(time_emb_dim) else None\n\n        self.audio_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(audio_emb_dim, dim_out * 2)\n        ) if exists(audio_emb_dim) else None\n\n        self.block1 = Block(dim, dim_out, groups=groups)\n        self.block2 = Block(dim_out, dim_out, groups=groups)\n        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n\n    def forward(self, x, time_emb=None, audio_emb=None):\n        time_scale_shift = None\n        audio_scale_shift = None\n        if exists(self.time_mlp):\n            assert exists(time_emb), 'time emb must be passed in'\n            time_emb = self.time_mlp(time_emb)\n            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 \n            time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 \n\n        # added by lml to get audio embedding\n        if exists(self.audio_mlp):\n            assert exists(audio_emb), 'audio emb must be passed in'\n            audio_emb = self.audio_mlp(audio_emb)\n            audio_emb = rearrange(audio_emb, 'b n c -> b c n 1 1') # bs, 128, nf, 1, 1 \n            audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1\n\n        h = self.block1(x, time_scale_shift=time_scale_shift, audio_scale_shift=audio_scale_shift)\n\n        h = self.block2(h)\n        return h + self.res_conv(x)\n\nclass ResnetBlock_ca(nn.Module):\n    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8):\n        super().__init__()\n        self.time_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(time_emb_dim, dim_out * 2)\n        ) if exists(time_emb_dim) else None\n\n        self.audio_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(audio_emb_dim, dim_out * 2)\n        ) if exists(audio_emb_dim) else None\n\n        # self.audio_mlp_2 = nn.Sequential(\n        #     nn.SiLU(),\n        #     nn.Linear(dim_out, dim_out * 2)\n        # ) if exists(audio_emb_dim) else None\n\n        attn_klass = CrossAttention\n\n        self.cross_attn = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2\n            )\n\n\n        self.block1 = Block(dim, dim_out, groups=groups)\n        self.block2 = Block(dim_out, dim_out, groups=groups)\n        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n\n    def forward(self, x, time_emb=None, audio_emb=None):\n        time_scale_shift = None\n        audio_scale_shift = None\n        b, c, f, H, W = x.size()\n        if exists(self.time_mlp):\n            assert exists(time_emb), 'time emb must be passed in'\n            time_emb = self.time_mlp(time_emb)\n            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 \n            time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 \n\n        # added by lml to get audio embedding\n        if exists(self.audio_mlp):\n            assert exists(audio_emb), 
'audio emb must be passed in'\n            audio_emb = self.audio_mlp(audio_emb)\n\n            if exists(self.cross_attn):\n                # h = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h, ps = pack([h], 'b * c')\n                # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...')\n                # audio_emb = self.cross_attn(h, context = audio_emb)\n\n                # # h, = unpack(h, ps, 'b * c')\n                # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c)\n                # # audio_emb = self.audio_mlp_2(audio_emb)\n                # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f)\n                assert exists(audio_emb)\n                h = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h = rearrange(x, 'b c ... -> b ... c')\n                h, ps = pack([h], 'b * c')\n\n                h = self.cross_attn(h, context = audio_emb) + h\n\n                h, = unpack(h, ps, 'b * c')\n                # h = rearrange(h, 'b ... c -> b c ...')\n                h = rearrange(h, '(b f) ... c -> b f c ...', b = b, f = f)\n\n            # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 \n            # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1\n\n        h = self.block1(x, time_scale_shift=time_scale_shift)\n\n        h = self.block2(h)\n        return h + self.res_conv(x)\n\nclass ResnetBlock_ca_mul(nn.Module):\n    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, pose_emb_dim=None, eye_emb_dim=None, groups=8):\n        super().__init__()\n        self.time_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(time_emb_dim, dim_out * 2)\n        ) if exists(time_emb_dim) else None\n\n        self.audio_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(audio_emb_dim, dim_out * 2)\n        ) if exists(audio_emb_dim) else None\n\n        self.pose_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(pose_emb_dim, dim_out * 2)\n        ) if exists(pose_emb_dim) else None\n\n        self.eye_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(eye_emb_dim, dim_out * 2)\n        ) if exists(eye_emb_dim) else None\n\n        self.audio_emb_dim = audio_emb_dim\n        self.pose_emb_dim = pose_emb_dim\n        self.eye_emb_dim = eye_emb_dim\n        # self.audio_mlp_2 = nn.Sequential(\n        #     nn.SiLU(),\n        #     nn.Linear(dim_out, dim_out * 2)\n        # ) if exists(audio_emb_dim) else None\n\n        attn_klass = CrossAttention\n\n        self.cross_attn_aud = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2,\n                out_dim = dim_out\n            )\n\n        self.cross_attn_pose = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2,\n                out_dim = dim_out\n            )\n        \n        self.cross_attn_eye = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2,\n                out_dim = dim_out\n            )\n\n\n        self.block1 = Block(dim, dim_out, groups=groups)\n        self.block2 = Block(dim_out, dim_out, groups=groups)\n        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n\n    def forward(self, x, time_emb=None, audio_emb=None):\n        time_scale_shift = None\n        audio_scale_shift = None\n        '''\n            need seperate 3 diffiserent condition\n        '''\n       
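 # audio_emb packs [hubert feats | pose delta | eye-blink delta] along the last dim and is split into the three conditions below\n       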
 if exists(audio_emb):\n            pose_emb = audio_emb[:,:,self.audio_emb_dim:self.audio_emb_dim + self.pose_emb_dim]\n            eye_emb = audio_emb[:,:,self.audio_emb_dim + self.pose_emb_dim: ]\n            audio_emb = audio_emb[:,:,:self.audio_emb_dim]\n\n        b, c, f, H, W = x.size()\n        if exists(self.time_mlp):\n            assert exists(time_emb), 'time emb must be passed in'\n            time_emb = self.time_mlp(time_emb)\n            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 \n            time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 \n\n        # added by lml to get audio embedding\n        if exists(self.audio_mlp):  # mouth lmk + audio emb\n            assert exists(audio_emb), 'audio emb must be passed in'\n            audio_emb = self.audio_mlp(audio_emb)\n            pose_emb = self.pose_mlp(pose_emb)  # TODO: embedding\n            eye_emb = self.eye_mlp(eye_emb)\n            if exists(self.cross_attn_aud):\n                # h = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h, ps = pack([h], 'b * c')\n                # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...')\n                # audio_emb = self.cross_attn(h, context = audio_emb)\n\n                # # h, = unpack(h, ps, 'b * c')\n                # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c)\n                # # audio_emb = self.audio_mlp_2(audio_emb)\n                # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f)\n                assert exists(audio_emb)\n                h_cond = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h = rearrange(x, 'b c ... -> b ... c')\n                h_cond, ps = pack([h_cond], 'b * c')\n\n\n                h_pose = self.cross_attn_pose(h_cond, context = pose_emb)\n                h_aud = self.cross_attn_aud(h_cond, context = audio_emb)\n                h_eye = self.cross_attn_eye(h_cond, context = eye_emb)\n\n                h_cond = h_pose + h_aud + h_eye\n\n\n                h_cond, = unpack(h_cond, ps, 'b * c')\n                # h = rearrange(h, 'b ... c -> b c ...')\n                h_cond = rearrange(h_cond, '(b f) ... 
c -> b c f ...', b = b, f = f)\n\n            # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 \n            # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1\n\n        h = self.block1(x, time_scale_shift=time_scale_shift)\n\n        if exists(self.audio_mlp):\n            h = h_cond + h\n\n        h = self.block2(h)\n        return h + self.res_conv(x)\n\nclass CrossAttention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        out_dim,\n        *,\n        context_dim = None,\n        dim_head = 8,\n        heads = 8,\n        norm_context = False,\n        scale = 8\n    ):\n        super().__init__()\n        self.scale = scale\n\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        context_dim = default(context_dim, dim)\n\n        self.norm = LayerNorm_img(dim)\n        self.norm_context = LayerNorm_img(context_dim) if norm_context else Identity()\n\n        self.null_kv = nn.Parameter(torch.randn(2, dim_head))\n        self.to_q = nn.Linear(dim, inner_dim, bias = False)\n        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)\n\n        self.q_scale = nn.Parameter(torch.ones(dim_head))\n        self.k_scale = nn.Parameter(torch.ones(dim_head))\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, out_dim, bias = False),\n            LayerNorm_img(out_dim)\n        )\n\n    def forward(self, x, context, mask = None):\n        b, n, device = *x.shape[:2], x.device\n\n        x = self.norm(x)  # bn * fn ?\n        # context: b, fn, c\n        context = rearrange(context, 'b f c -> (b f) c')\n        context = self.norm_context(context)\n\n        q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1))\n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))\n\n        # add null key / value for classifier free guidance in prior net\n\n        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))\n\n        k = torch.cat((nk, k), dim = -2)\n        v = torch.cat((nv, v), dim = -2)\n\n        # cosine sim attention\n\n        q, k = map(l2norm, (q, k))\n        q = q * self.q_scale\n        k = k * self.k_scale\n\n        # similarities\n\n        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale\n\n        # masking\n\n        max_neg_value = -torch.finfo(sim.dtype).max\n\n        if exists(mask):\n            mask = F.pad(mask, (1, 0), value = True)\n            mask = rearrange(mask, 'b j -> b 1 1 j')\n            sim = sim.masked_fill(~mask, max_neg_value)\n\n        attn = sim.softmax(dim = -1, dtype = torch.float32)\n        attn = attn.to(sim.dtype)\n\n        out = einsum('b h i j, b h j d -> b h i d', attn, v)\n        out = rearrange(out, 'b h n d -> b n (h d)')\n        return self.to_out(out)\n\nclass LinearCrossAttention(CrossAttention):\n    def forward(self, x, context, mask = None):\n        b, n, device = *x.shape[:2], x.device\n\n        x = self.norm(x)\n        context = rearrange(context, 'b f c -> (b f) c')\n        context = self.norm_context(context)\n\n        q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1))\n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))\n        # add null key / value for classifier free guidance in prior net\n\n        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads,  b = b), 
self.null_kv.unbind(dim = -2))\n\n        k = torch.cat((nk, k), dim = -2)  # b * nf * h, 2, c//h\n        v = torch.cat((nv, v), dim = -2)\n\n        # masking\n\n        max_neg_value = -torch.finfo(x.dtype).max\n\n        if exists(mask):\n            mask = F.pad(mask, (1, 0), value = True)\n            mask = rearrange(mask, 'b n -> b 1 n')\n            k = k.masked_fill(~mask, max_neg_value)\n            v = v.masked_fill(~mask, 0.)\n\n        # linear attention\n\n        q = q.softmax(dim = -1) # # b * nf * h, 32*32, c//h,\n        k = k.softmax(dim = -2)\n\n        q = q * self.scale\n\n        context = einsum('b n d, b n e -> b d e', k, v) # b * nf * h, 2, c//h,  b * nf * h, 2, c//h\n        out = einsum('b n d, b d e -> b n e', q, context)\n        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)\n        return self.to_out(out)\n\nclass SpatialLinearAttention(nn.Module):\n    def __init__(self, dim, heads=4, dim_head=32):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)\n        self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n\n    def forward(self, x):\n        b, c, f, h, w = x.shape\n        x = rearrange(x, 'b c f h w -> (b f) c h w')\n\n        qkv = self.to_qkv(x).chunk(3, dim=1)\n        q, k, v = rearrange_many(qkv, 'b (h c) x y -> b h c (x y)', h=self.heads)\n\n        q = q.softmax(dim=-2)\n        k = k.softmax(dim=-1)\n\n        q = q * self.scale\n        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)\n\n        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)\n        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)\n        out = self.to_out(out)\n        return rearrange(out, '(b f) c h w -> b c f h w', b=b)\n\n\n# attention along space and time\n\nclass EinopsToAndFrom(nn.Module):\n    def __init__(self, from_einops, to_einops, fn):\n        super().__init__()\n        self.from_einops = from_einops\n        self.to_einops = to_einops\n        self.fn = fn\n\n    def forward(self, x, **kwargs):\n        shape = x.shape\n        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))\n        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')\n        x = self.fn(x, **kwargs)\n        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            heads=4,\n            dim_head=32,\n            rotary_emb=None\n    ):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n\n        self.rotary_emb = rotary_emb\n        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)\n        self.to_out = nn.Linear(hidden_dim, dim, bias=False)\n\n    def forward(\n            self,\n            x,\n            pos_bias=None,\n            focus_present_mask=None\n    ):  # temperal: 'b (h w) f c'  ; spatial :  'b f (h w) c'\n        n, device = x.shape[-2], x.device\n\n        qkv = self.to_qkv(x).chunk(3, dim=-1)\n\n        if exists(focus_present_mask) and focus_present_mask.all():\n            # if all batch samples are focusing on present\n            # it would be equivalent to passing that token's values through to the output\n            values = qkv[-1]\n            return self.to_out(values)\n\n        # 
split out heads\n\n        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)\n\n        # scale\n\n        q = q * self.scale\n\n        # rotate positions into queries and keys for time attention\n\n        if exists(self.rotary_emb):\n            q = self.rotary_emb.rotate_queries_or_keys(q)\n            k = self.rotary_emb.rotate_queries_or_keys(k)\n\n        # similarity\n\n        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)\n\n        # relative positional bias\n\n        if exists(pos_bias):\n            sim = sim + pos_bias\n\n        if exists(focus_present_mask) and not (~focus_present_mask).all():\n            attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool)\n            attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)\n\n            mask = torch.where(\n                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),\n                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),\n                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),\n            )\n\n            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)\n\n        # numerical stability\n\n        sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n        attn = sim.softmax(dim=-1)\n\n        # aggregate values\n\n        out = einsum('... h i j, ... h j d -> ... h i d', attn, v)\n        out = rearrange(out, '... h n d -> ... n (h d)')\n        return self.to_out(out)\n\n# model\nclass Unet3D(nn.Module):\n    def __init__(\n            self,\n            dim,\n            cond_aud=1024,\n            cond_pose=7,\n            cond_eye=2,\n            cond_dim=None,\n            out_grid_dim=2,\n            out_conf_dim=1,\n            num_frames=40,\n            dim_mults=(1, 2, 4, 8),\n            channels=3,\n            attn_heads=8,\n            attn_dim_head=32,\n            use_hubert_audio_cond=False,\n            init_dim=None,\n            init_kernel_size=7,\n            use_sparse_linear_attn=True,\n            resnet_groups=8,\n            use_final_activation=False,\n            learn_null_cond=False,\n            use_deconv=True,\n            padding_mode=\"zeros\",\n    ):\n        super().__init__()\n        self.null_cond_mask = None\n        self.channels = channels\n        self.num_frames = num_frames\n        self.HUBERT_MODEL_DIM = 1024\n        # temporal attention and its relative positional encoding\n\n        rotary_emb = RotaryEmbedding(min(32, attn_dim_head))\n\n        temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', 'b (h w) f c',\n                                                    Attention(dim, heads=attn_heads, dim_head=attn_dim_head,\n                                                              rotary_emb=rotary_emb))\n\n        self.time_rel_pos_bias = RelativePositionBias(heads=attn_heads,\n                                                      max_distance=32)  # realistically will not be able to generate that many frames of video... 
yet\n\n        # initial conv\n\n        init_dim = default(init_dim, dim)\n        assert is_odd(init_kernel_size)\n\n        init_padding = init_kernel_size // 2\n        self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, init_kernel_size),\n                                   padding=(0, init_padding, init_padding))\n\n        self.init_temporal_attn = Residual(PreNorm(init_dim, temporal_attn(init_dim)))\n\n        # dimensions\n\n        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]\n        in_out = list(zip(dims[:-1], dims[1:]))\n\n        # time conditioning\n\n        time_dim = dim * 4\n        self.time_mlp = nn.Sequential(\n            SinusoidalPosEmb(dim),\n            nn.Linear(dim, time_dim),\n            nn.GELU(),\n            nn.Linear(time_dim, time_dim)\n        )\n\n        # audio conditioning\n\n        self.has_cond = exists(cond_dim) or use_hubert_audio_cond\n        self.cond_dim = cond_dim\n        self.cond_aud_dim = cond_aud\n        self.cond_pose_dim = cond_pose\n        self.cond_eye_dim = cond_eye\n\n        # modified by lml\n        self.learn_null_cond = learn_null_cond\n\n\n        # cat(t,cond) is not suitable\n        # cond_dim = time_dim + int(cond_dim or 0)\n\n        # layers\n\n        self.downs = nn.ModuleList([])\n        self.ups = nn.ModuleList([])\n\n        num_resolutions = len(in_out)\n\n        # block type\n\n        block_klass = partial(ResnetBlock_ca_mul, groups=resnet_groups)\n        block_klass_cond = partial(block_klass, time_emb_dim=time_dim, audio_emb_dim=self.cond_aud_dim, pose_emb_dim=self.cond_pose_dim, eye_emb_dim=self.cond_eye_dim)\n        # block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) # cat embedding\n\n        # modules for all layers\n\n        for ind, (dim_in, dim_out) in enumerate(in_out):\n            is_last = ind >= (num_resolutions - 1)\n\n            self.downs.append(nn.ModuleList([\n                block_klass_cond(dim_in, dim_out),\n                block_klass_cond(dim_out, dim_out),\n                Residual(PreNorm(dim_out, SpatialLinearAttention(dim_out,\n                                                                 heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),\n                Residual(PreNorm(dim_out, temporal_attn(dim_out))),\n                Downsample(dim_out) if not is_last else nn.Identity()\n            ]))\n\n        mid_dim = dims[-1]\n        self.mid_block1 = block_klass_cond(mid_dim, mid_dim)\n\n        spatial_attn = EinopsToAndFrom('b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads))\n\n        self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn))\n        self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_attn(mid_dim)))\n\n        self.mid_block2 = block_klass_cond(mid_dim, mid_dim)\n\n        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):\n            is_last = ind >= (num_resolutions - 1)\n\n            self.ups.append(nn.ModuleList([\n                block_klass_cond(dim_out * 2, dim_in),\n                block_klass_cond(dim_in, dim_in),\n                Residual(PreNorm(dim_in, SpatialLinearAttention(dim_in,\n                                                                heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),\n                Residual(PreNorm(dim_in, temporal_attn(dim_in))),\n                Upsample(dim_in, use_deconv, padding_mode) if not is_last else nn.Identity()\n            ]))\n\n        # out_dim = default(out_grid_dim, channels)\n        
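# two output heads: final_conv predicts the out_grid_dim-channel flow grid and occlusion_map predicts the out_conf_dim-channel occlusion logits\n        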
self.final_conv = nn.Sequential(\n            block_klass(dim * 2, dim),\n            nn.Conv3d(dim, out_grid_dim, 1)\n        )\n\n        # added by nhm\n        self.use_final_activation = use_final_activation\n        if self.use_final_activation:\n            self.final_activation = nn.Tanh()\n        else:\n            self.final_activation = nn.Identity()\n\n        # added by nhm for predicting occlusion mask\n        self.occlusion_map = nn.Sequential(\n            block_klass(dim * 2, dim),\n            nn.Conv3d(dim, out_conf_dim, 1)\n        )\n\n    def forward_with_cond_scale(\n            self,\n            *args,\n            cond_scale=2.,\n            **kwargs\n    ):\n        logits = self.forward(*args, null_cond_prob=0., **kwargs)\n        if cond_scale == 1 or not self.has_cond:\n            return logits\n\n        null_logits = self.forward(*args, null_cond_prob=1., **kwargs)\n        return null_logits + (logits - null_logits) * cond_scale\n\n    def forward(\n            self,\n            x,\n            time,\n            cond=None,\n            null_cond_prob=0.,\n            focus_present_mask=None,\n            prob_focus_present=0.\n            # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)\n    ):\n        assert not (self.has_cond and not exists(cond)), 'cond must be passed in if cond_dim specified'\n        batch, device = x.shape[0], x.device\n\n        focus_present_mask = default(focus_present_mask,\n                                     lambda: prob_mask_like((batch,), prob_focus_present, device=device))\n\n        time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)\n\n        x = self.init_conv(x)\n        r = x.clone()\n\n        x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias)\n\n        t = self.time_mlp(time) if exists(self.time_mlp) else None\n\n        if self.learn_null_cond:\n            self.null_cond_emb = nn.Parameter(torch.randn(1, self.num_frames, self.cond_dim)) if self.has_cond else None\n        else:\n            self.null_cond_emb = torch.zeros(1, self.num_frames, self.cond_dim) if self.has_cond else None\n        # classifier free guidance\n\n        if self.has_cond:\n            batch, device = x.shape[0], x.device\n            self.null_cond_mask = prob_mask_like((batch, self.num_frames,), null_cond_prob, device=device)\n            cond = torch.where(rearrange(self.null_cond_mask, 'b n -> b n 1'), self.null_cond_emb.to(cond.device), cond) \n            # t (bs, 256)  cond (bs, nf*1024)->(bs, nf, 1024) in this version\n            \n            # it's the original cond embedding method used in LFDM\n            # t = torch.cat((t, cond), dim=-1)\n\n        h = []\n\n        for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:\n            x = block1(x, t, cond)\n            x = block2(x, t, cond)\n            x = spatial_attn(x)\n            x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)\n            h.append(x)\n            x = downsample(x)\n\n        x = self.mid_block1(x, t, cond)\n        x = self.mid_spatial_attn(x)\n        x = self.mid_temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)\n        x = self.mid_block2(x, t, cond)\n\n        for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:\n            x = torch.cat((x, h.pop()), dim=1)\n            x = block1(x, t, cond)\n            x = 
block2(x, t, cond)\n            x = spatial_attn(x)\n            x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)\n            x = upsample(x)\n\n        x = torch.cat((x, r), dim=1)\n        return torch.cat((self.final_conv(x), self.occlusion_map(x)), dim=1)\n\n# to dynamically change num_frames of Unet3D\nclass DynamicNfUnet3D(Unet3D):\n    def __init__(self, default_num_frames=20, *args, **kwargs):\n        super(DynamicNfUnet3D, self).__init__(*args, **kwargs)\n        self.default_num_frames = default_num_frames\n        self.num_frames = default_num_frames\n    def update_num_frames(self, new_num_frames):\n        self.num_frames = new_num_frames\n\n# gaussian diffusion trainer class\n\ndef extract(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef cosine_beta_schedule(timesteps, s=0.008):\n    \"\"\"\n    cosine schedule\n    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ\n    \"\"\"\n    steps = timesteps + 1\n    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)\n    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2\n    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n    return torch.clip(betas, 0, 0.9999)\n\n\nclass GaussianDiffusion(nn.Module):\n    def __init__(\n            self,\n            denoise_fn,\n            *,\n            image_size,\n            num_frames,\n            text_use_bert_cls=False,\n            channels=3,\n            timesteps=1000,\n            sampling_timesteps=250,\n            ddim_sampling_eta=1.,\n            loss_type='l1',\n            use_dynamic_thres=False,  # from the Imagen paper\n            dynamic_thres_percentile=0.9,\n            null_cond_prob=0.1\n    ):\n        super().__init__()\n        self.null_cond_prob = null_cond_prob\n        self.channels = channels\n        self.image_size = image_size\n        self.num_frames = num_frames\n        self.denoise_fn = denoise_fn\n\n        betas = cosine_beta_schedule(timesteps)\n\n        alphas = 1. - betas\n        alphas_cumprod = torch.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)\n\n        timesteps, = betas.shape\n        self.num_timesteps = int(timesteps)\n        self.loss_type = loss_type\n\n        self.sampling_timesteps = default(sampling_timesteps,\n                                          timesteps)\n        self.is_ddim_sampling = self.sampling_timesteps < timesteps\n        self.ddim_sampling_eta = ddim_sampling_eta\n\n        # register buffer helper function that casts float64 to float32\n\n        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))\n\n        register_buffer('betas', betas)\n        register_buffer('alphas_cumprod', alphas_cumprod)\n        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n\n        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))\n        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))\n        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))\n        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))\n        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. 
/ alphas_cumprod - 1))\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n\n        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)\n\n        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n\n        register_buffer('posterior_variance', posterior_variance)\n\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n\n        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))\n        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))\n        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))\n\n        # text conditioning parameters\n\n        self.text_use_bert_cls = text_use_bert_cls\n\n        # dynamic thresholding when sampling\n\n        self.use_dynamic_thres = use_dynamic_thres\n        self.dynamic_thres_percentile = dynamic_thres_percentile\n\n    def q_mean_variance(self, x_start, t):\n        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)\n        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)\n        return mean, variance, log_variance\n\n    def predict_start_from_noise(self, x_t, t, noise):\n        return (\n                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -\n                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n        )\n\n    def q_posterior(self, x_start, x_t, t):\n        posterior_mean = (\n                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +\n                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = extract(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n    def p_mean_variance(self, x, t, fea, clip_denoised: bool, cond=None, cond_scale=1.):\n        fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1)\n        x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn.forward_with_cond_scale(torch.cat([x, fea], dim=1),\n                                                                                                      t,\n                                                                                                      cond=cond,\n                                                                                                      cond_scale=cond_scale))\n\n        if clip_denoised:\n            s = 1.\n            if self.use_dynamic_thres:\n                s = torch.quantile(\n                    rearrange(x_recon, 'b ... 
-> b (...)').abs(),\n                    self.dynamic_thres_percentile,\n                    dim=-1\n                )\n\n                s.clamp_(min=1.)\n                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))\n\n            # clip by threshold, depending on whether static or dynamic\n            x_recon = x_recon.clamp(-s, s) / s\n\n        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)\n        return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.inference_mode()\n    def p_sample(self, x, t, fea, cond=None, cond_scale=1., clip_denoised=True):\n        b, *_, device = *x.shape, x.device\n        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, fea=fea,\n                                                                 clip_denoised=clip_denoised, cond=cond,\n                                                                 cond_scale=cond_scale)\n        noise = torch.randn_like(x)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n\n    @torch.inference_mode()\n    def p_sample_loop(self, fea, shape, cond=None, cond_scale=1.):\n        device = self.betas.device\n\n        b = shape[0]\n        img = torch.randn(shape, device=device)\n\n        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):\n            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), fea, cond=cond,\n                                cond_scale=cond_scale)\n\n        return img\n        # return unnormalize_img(img)\n\n    @torch.inference_mode()\n    def sample(self, fea, bbox_mask, cond=None, cond_scale=1., batch_size=16):\n        # text bert: cond 1,768\n        # device = next(self.denoise_fn.parameters()).device\n        # if is_list_str(cond):\n        #     cond = torch.rand((1 ,768), dtype=torch.float32).cuda()  #used to debug\n            # cond = bert_embed(tokenize(cond), return_cls_repr=self.text_use_bert_cls).to(device)\n\n        batch_size = cond.shape[0] if exists(cond) else batch_size\n        # batch_size = 1 if exists(cond) else batch_size\n        image_size = self.image_size\n        channels = self.channels\n        num_frames = self.num_frames\n        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample\n        fea = torch.cat([fea, bbox_mask], dim=1)\n        return sample_fn(fea, (batch_size, channels, num_frames, image_size, image_size), cond=cond,\n                         cond_scale=cond_scale)\n\n    # add by nhm\n    @torch.no_grad()\n    def ddim_sample(self, fea, shape, cond=None, cond_scale=1., clip_denoised=True):\n\n        batch, device, total_timesteps, sampling_timesteps, eta = \\\n            shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta\n\n        times = torch.linspace(0., total_timesteps, steps=sampling_timesteps + 2)[:-1]\n        times = list(reversed(times.int().tolist()))\n        time_pairs = list(zip(times[:-1], times[1:]))\n\n        img = torch.randn(shape, device=device) # bs, 3, nf, 32, 32\n        fea = fea.unsqueeze(dim=2).repeat(1, 1, img.size(2), 1, 1) #bs, 256, nf, 32, 32\n\n        for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):\n            alpha = self.alphas_cumprod_prev[time]\n            alpha_next = 
self.alphas_cumprod_prev[time_next]\n\n            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)\n\n            # pred_noise, x_start, *_ = self.model_predictions(img, time_cond, fea)\n            pred_noise = self.denoise_fn.forward_with_cond_scale(\n                torch.cat([img, fea], dim=1),\n                time_cond,\n                cond=cond,\n                cond_scale=cond_scale)\n            x_start = self.predict_start_from_noise(img, t=time_cond, noise=pred_noise)\n\n            if clip_denoised:\n                s = 1.\n                if self.use_dynamic_thres:\n                    s = torch.quantile(\n                        rearrange(x_start, 'b ... -> b (...)').abs(),\n                        self.dynamic_thres_percentile,\n                        dim=-1\n                    )\n\n                    s.clamp_(min=1.)\n                    s = s.view(-1, *((1,) * (x_start.ndim - 1)))\n\n                # clip by threshold, depending on whether static or dynamic\n                x_start = x_start.clamp(-s, s) / s\n\n            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n            c = ((1 - alpha_next) - sigma ** 2).sqrt()\n\n            noise = torch.randn_like(img) if time_next > 0 else 0.\n\n            img = x_start * alpha_next.sqrt() + \\\n                  c * pred_noise + \\\n                  sigma * noise\n\n        # img = unnormalize_to_zero_to_one(img)\n        return img\n\n    @torch.inference_mode()\n    def interpolate(self, x1, x2, t=None, lam=0.5):\n        b, *_, device = *x1.shape, x1.device\n        t = default(t, self.num_timesteps - 1)\n\n        assert x1.shape == x2.shape\n\n        t_batched = torch.stack([torch.tensor(t, device=device)] * b)\n        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))\n\n        img = (1 - lam) * xt1 + lam * xt2\n        for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):\n            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))\n\n        return img\n\n    def q_sample(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n\n        return (\n                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise\n        )\n\n    def p_losses(self, x_start, t, fea, bbox_mask, cond=None, noise=None, clip_denoised=True, **kwargs):\n        # x_start: bs, 3, num_frame, 32, 32\n        # t: bs\n        # fea: bs, 256, num_frame, 32, 32\n        # cond: bs, 768\n        b, c, f, h, w, device = *x_start.shape, x_start.device\n        noise = default(noise, lambda: torch.randn_like(x_start)) # bs, 3, nf, 32, 32\n\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)# bs, 3, nf, 32, 32\n\n        pred_noise = self.denoise_fn.forward(torch.cat([x_noisy, fea, bbox_mask], dim=1), t, cond=cond,\n                                             null_cond_prob=self.null_cond_prob,\n                                             **kwargs)\n\n        if self.loss_type == 'l1':\n            loss = F.l1_loss(noise, pred_noise, reduce=False)\n        elif self.loss_type == 'l2':\n            loss = F.mse_loss(noise, pred_noise, reduce=False)\n        else:\n            raise NotImplementedError()\n  \n        pred_x0 = self.predict_start_from_noise(x_noisy, t, pred_noise)\n\n        if clip_denoised:\n            s = 1.\n            if 
self.use_dynamic_thres:\n                s = torch.quantile(\n                    rearrange(pred_x0, 'b ... -> b (...)').abs(),\n                    self.dynamic_thres_percentile,\n                    dim=-1\n                )\n\n                s.clamp_(min=1.)\n                s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))\n\n            # clip by threshold, depending on whether static or dynamic\n            self.pred_x0 = pred_x0.clamp(-s, s) / s\n\n        return loss, self.denoise_fn.null_cond_mask\n\n    def forward(self, x, fea, bbox_mask, cond, *args, **kwargs):\n        b, device, img_size, = x.shape[0], x.device, self.image_size\n        # check_shape(x, 'b c f h w', c=self.channels, f=self.num_frames, h=img_size, w=img_size)\n        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()\n        fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1)\n        bbox_mask = bbox_mask.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1)\n\n        return self.p_losses(x, t, fea, bbox_mask, cond, *args, **kwargs)\n\n\n# trainer class\n\nCHANNELS_TO_MODE = {\n    1: 'L',\n    3: 'RGB',\n    4: 'RGBA'\n}\n\n\ndef seek_all_images(img, channels=3):\n    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'\n    mode = CHANNELS_TO_MODE[channels]\n\n    i = 0\n    while True:\n        try:\n            img.seek(i)\n            yield img.convert(mode)\n        except EOFError:\n            break\n        i += 1\n\n# to dynamically change num_frames of GaussianDiffusion\nclass DynamicNfGaussianDiffusion(GaussianDiffusion):\n    def __init__(self, default_num_frames=20, *args, **kwargs):\n        super(DynamicNfGaussianDiffusion, self).__init__(*args, **kwargs)\n        self.default_num_frames = default_num_frames\n        self.num_frames = default_num_frames\n    def update_num_frames(self, new_num_frames):\n        self.num_frames = new_num_frames\n\n# tensor of shape (channels, frames, height, width) -> gif\n\ndef video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):\n    images = map(T.ToPILImage(), tensor.unbind(dim=1))\n    first_img, *rest_imgs = images\n    first_img.save(path, save_all=True, append_images=rest_imgs, duration=duration, loop=loop, optimize=optimize)\n    return images\n\n\n# gif -> (channels, frame, height, width) tensor\n\ndef gif_to_tensor(path, channels=3, transform=T.ToTensor()):\n    img = Image.open(path)\n    tensors = tuple(map(transform, seek_all_images(img, channels=channels)))\n    return torch.stack(tensors, dim=1)\n\n\ndef identity(t, *args, **kwargs):\n    return t\n\n\ndef normalize_img(t):\n    return t * 2 - 1\n\n\n# def unnormalize_img(t):\n#     return (t + 1) * 0.5\n\n\ndef cast_num_frames(t, *, frames):\n    f = t.shape[1]\n\n    if f == frames:\n        return t\n\n    if f > frames:\n        return t[:, :frames]\n\n    return F.pad(t, (0, 0, 0, 0, 0, frames - f))\n\n\n"
  },
  {
    "path": "DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py",
    "content": "'''\nadding pose condtioning on baseline\nusing cross attention to add different condition\n\nusing local attention, for inference, faster cost more ram\n'''\nimport math\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom functools import partial\n\nfrom torchvision import transforms as T\nfrom PIL import Image\n\nfrom tqdm import tqdm\nfrom einops import rearrange, repeat, reduce, pack, unpack\nfrom einops_exts import rearrange_many\n\nfrom rotary_embedding_torch import RotaryEmbedding\n\n# from DM.modules.text import tokenize, bert_embed, HUBERT_MODEL_DIM\n\n\n# helpers functions\n\ndef exists(x):\n    return x is not None\n\n\ndef noop(*args, **kwargs):\n    pass\n\n\ndef is_odd(n):\n    return (n % 2) == 1\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if callable(d) else d\n\n\ndef cycle(dl):\n    while True:\n        for data in dl:\n            yield data\n\n\ndef num_to_groups(num, divisor):\n    groups = num // divisor\n    remainder = num % divisor\n    arr = [divisor] * groups\n    if remainder > 0:\n        arr.append(remainder)\n    return arr\n\n\ndef prob_mask_like(shape, prob, device):\n    if prob == 1:\n        return torch.ones(shape, device=device, dtype=torch.bool)\n    elif prob == 0:\n        return torch.zeros(shape, device=device, dtype=torch.bool)\n    else:\n        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob\n\n\ndef is_list_str(x):\n    if not isinstance(x, (list, tuple)):\n        return False\n    return all([type(el) == str for el in x])\n\n\n# relative positional bias\n\nclass RelativePositionBias(nn.Module):\n    def __init__(\n            self,\n            heads=8,\n            num_buckets=32,\n            max_distance=128,\n            window_width = 20\n    ):\n        super().__init__()\n        self.num_buckets = num_buckets\n        self.max_distance = max_distance\n        self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n        self.window_width = window_width\n\n    @staticmethod\n    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):\n        ret = 0\n        n = -relative_position\n\n        num_buckets //= 2\n        ret += (n < 0).long() * num_buckets\n        n = torch.abs(n)\n\n        max_exact = num_buckets // 2\n        is_small = n < max_exact\n\n        val_if_large = max_exact + (\n                torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)\n        ).long()\n        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n\n        ret += torch.where(is_small, n, val_if_large)\n        return ret\n\n    def forward(self, n, device):\n        q_pos = torch.arange(n, dtype=torch.long, device=device)\n        k_pos = torch.arange(n, dtype=torch.long, device=device)\n        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,\n                                                   max_distance=self.max_distance)\n        mask = -(((rel_pos > self.window_width) + (rel_pos < -self.window_width)) * (1e8)) # -(((rp_bucket ==15) + (rp_bucket >= 30)) * (1e8))\n        values = self.relative_attention_bias(rp_bucket)\n        return rearrange(values, 'i j h -> h i j') + mask\n\n\n\n# small helper modules\n\nclass EMA():\n    def __init__(self, beta):\n        super().__init__()\n        self.beta = 
beta\n\n    def update_model_average(self, ma_model, current_model):\n        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):\n            old_weight, up_weight = ma_params.data, current_params.data\n            ma_params.data = self.update_average(old_weight, up_weight)\n\n    def update_average(self, old, new):\n        if old is None:\n            return new\n        return old * self.beta + (1 - self.beta) * new\n\n\nclass Residual(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, x, *args, **kwargs):\n        return self.fn(x, *args, **kwargs) + x\n\n\nclass SinusoidalPosEmb(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.dim = dim\n\n    def forward(self, x):\n        device = x.device\n        half_dim = self.dim // 2\n        emb = math.log(10000) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)\n        emb = x[:, None] * emb[None, :]\n        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)\n        return emb\n\n\ndef Upsample(dim, use_deconv=True, padding_mode=\"reflect\"):\n    if use_deconv:\n        return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))\n    else:\n        return nn.Sequential(\n            nn.Upsample(scale_factor=(1, 2, 2), mode='nearest'),\n            nn.Conv3d(dim, dim, (1, 3, 3), (1, 1, 1), (0, 1, 1), padding_mode=padding_mode)\n        )\n\n\ndef Downsample(dim):\n    return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))\n\n\nclass LayerNorm(nn.Module):\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.eps = eps\n        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))\n\n    def forward(self, x):\n        var = torch.var(x, dim=1, unbiased=False, keepdim=True)\n        mean = torch.mean(x, dim=1, keepdim=True)\n        return (x - mean) / (var + self.eps).sqrt() * self.gamma\n\nclass LayerNorm_img(nn.Module):\n    def __init__(self, dim, stable = False):\n        super().__init__()\n        self.stable = stable\n        self.g = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        if self.stable:\n            x = x / x.amax(dim = -1, keepdim = True).detach()\n\n        eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)\n        mean = torch.mean(x, dim = -1, keepdim = True)\n        return (x - mean) * (var + eps).rsqrt() * self.g\n\nclass PreNorm(nn.Module):\n    def __init__(self, dim, fn):\n        super().__init__()\n        self.fn = fn\n        self.norm = LayerNorm(dim)\n\n    def forward(self, x, **kwargs):\n        x = self.norm(x)\n        return self.fn(x, **kwargs)\n\nclass Identity(nn.Module):\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n\n    def forward(self, x, *args, **kwargs):\n        return x\n\ndef l2norm(t):\n    return F.normalize(t, dim = -1)\n# building block modules\n\nclass Block(nn.Module):\n    def __init__(self, dim, dim_out, groups=8):\n        super().__init__()\n        self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1))\n        self.norm = nn.GroupNorm(groups, dim_out)\n        self.act = nn.SiLU()\n\n    def forward(self, x, time_scale_shift=None, audio_scale_shift=None):\n        x = self.proj(x)\n        x = self.norm(x)\n\n        if exists(time_scale_shift):\n            time_scale, time_shift = time_scale_shift\n            x = x * (time_scale + 1) + 
time_shift\n        \n        # added by lml to change the control method of audio embedding, inspired by diffusedhead\n        # if exists(audio_scale_shift):\n        #     # audio_scale and audio_shift:(bs, 64, nf, 1, 1) \n        #     # x:(bs, 64, nf, 32, 32)\n        #     audio_scale, audio_shift = audio_scale_shift\n        #     x = x * (audio_scale + 1) + audio_shift\n\n        return self.act(x)\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8):\n        super().__init__()\n        self.time_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(time_emb_dim, dim_out * 2)\n        ) if exists(time_emb_dim) else None\n\n        self.audio_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(audio_emb_dim, dim_out * 2)\n        ) if exists(audio_emb_dim) else None\n\n        self.block1 = Block(dim, dim_out, groups=groups)\n        self.block2 = Block(dim_out, dim_out, groups=groups)\n        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n\n    def forward(self, x, time_emb=None, audio_emb=None):\n        time_scale_shift = None\n        audio_scale_shift = None\n        if exists(self.time_mlp):\n            assert exists(time_emb), 'time emb must be passed in'\n            time_emb = self.time_mlp(time_emb)\n            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 \n            time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 \n\n        # added by lml to get audio embedding\n        if exists(self.audio_mlp):\n            assert exists(audio_emb), 'audio emb must be passed in'\n            audio_emb = self.audio_mlp(audio_emb)\n            audio_emb = rearrange(audio_emb, 'b n c -> b c n 1 1') # bs, 128, nf, 1, 1 \n            audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1\n\n        h = self.block1(x, time_scale_shift=time_scale_shift, audio_scale_shift=audio_scale_shift)\n\n        h = self.block2(h)\n        return h + self.res_conv(x)\n\nclass ResnetBlock_ca(nn.Module):\n    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8):\n        super().__init__()\n        self.time_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(time_emb_dim, dim_out * 2)\n        ) if exists(time_emb_dim) else None\n\n        self.audio_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(audio_emb_dim, dim_out * 2)\n        ) if exists(audio_emb_dim) else None\n\n        # self.audio_mlp_2 = nn.Sequential(\n        #     nn.SiLU(),\n        #     nn.Linear(dim_out, dim_out * 2)\n        # ) if exists(audio_emb_dim) else None\n\n        attn_klass = CrossAttention\n\n        self.cross_attn = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2\n            )\n\n\n        self.block1 = Block(dim, dim_out, groups=groups)\n        self.block2 = Block(dim_out, dim_out, groups=groups)\n        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n\n    def forward(self, x, time_emb=None, audio_emb=None):\n        time_scale_shift = None\n        audio_scale_shift = None\n        b, c, f, H, W = x.size()\n        if exists(self.time_mlp):\n            assert exists(time_emb), 'time emb must be passed in'\n            time_emb = self.time_mlp(time_emb)\n            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 \n            time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 
1, 1 \n\n        # added by lml to get audio embedding\n        if exists(self.audio_mlp):\n            assert exists(audio_emb), 'audio emb must be passed in'\n            audio_emb = self.audio_mlp(audio_emb)\n\n            if exists(self.cross_attn):\n                # h = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h, ps = pack([h], 'b * c')\n                # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...')\n                # audio_emb = self.cross_attn(h, context = audio_emb)\n\n                # # h, = unpack(h, ps, 'b * c')\n                # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c)\n                # # audio_emb = self.audio_mlp_2(audio_emb)\n                # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f)\n                assert exists(audio_emb)\n                h = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h = rearrange(x, 'b c ... -> b ... c')\n                h, ps = pack([h], 'b * c')\n\n                h = self.cross_attn(h, context = audio_emb) + h\n\n                h, = unpack(h, ps, 'b * c')\n                # h = rearrange(h, 'b ... c -> b c ...')\n                h = rearrange(h, '(b f) ... c -> b f c ...', b = b, f = f)\n\n            # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 \n            # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1\n\n        h = self.block1(x, time_scale_shift=time_scale_shift)\n\n        h = self.block2(h)\n        return h + self.res_conv(x)\n\nclass ResnetBlock_ca_mul(nn.Module):\n    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, pose_emb_dim=None, eye_emb_dim=None, groups=8):\n        super().__init__()\n        self.time_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(time_emb_dim, dim_out * 2)\n        ) if exists(time_emb_dim) else None\n\n        self.audio_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(audio_emb_dim, dim_out * 2)\n        ) if exists(audio_emb_dim) else None\n\n        self.pose_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(pose_emb_dim, dim_out * 2)\n        ) if exists(pose_emb_dim) else None\n\n        self.eye_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(eye_emb_dim, dim_out * 2)\n        ) if exists(eye_emb_dim) else None\n\n        self.audio_emb_dim = audio_emb_dim\n        self.pose_emb_dim = pose_emb_dim\n        self.eye_emb_dim = eye_emb_dim\n        # self.audio_mlp_2 = nn.Sequential(\n        #     nn.SiLU(),\n        #     nn.Linear(dim_out, dim_out * 2)\n        # ) if exists(audio_emb_dim) else None\n\n        attn_klass = CrossAttention\n\n        self.cross_attn_aud = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2,\n                out_dim = dim_out\n            )\n\n        self.cross_attn_pose = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2,\n                out_dim = dim_out\n            )\n        \n        self.cross_attn_eye = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2,\n                out_dim = dim_out\n            )\n\n\n        self.block1 = Block(dim, dim_out, groups=groups)\n        self.block2 = Block(dim_out, dim_out, groups=groups)\n        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n\n    def forward(self, x, time_emb=None, audio_emb=None):\n        
time_scale_shift = None\n        audio_scale_shift = None\n        '''\n            need seperate 3 diffiserent condition\n        '''\n        if exists(audio_emb):\n            pose_emb = audio_emb[:,:,self.audio_emb_dim:self.audio_emb_dim + self.pose_emb_dim]\n            eye_emb = audio_emb[:,:,self.audio_emb_dim + self.pose_emb_dim: ]\n            audio_emb = audio_emb[:,:,:self.audio_emb_dim]\n\n        b, c, f, H, W = x.size()\n        if exists(self.time_mlp):\n            assert exists(time_emb), 'time emb must be passed in'\n            time_emb = self.time_mlp(time_emb)\n            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 \n            time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 \n\n        # added by lml to get audio embedding\n        if exists(self.audio_mlp):  # mouth lmk + audio emb\n            assert exists(audio_emb), 'audio emb must be passed in'\n            audio_emb = self.audio_mlp(audio_emb)\n            pose_emb = self.pose_mlp(pose_emb)  # TODO: embedding\n            eye_emb = self.eye_mlp(eye_emb)\n            if exists(self.cross_attn_aud):\n                # h = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h, ps = pack([h], 'b * c')\n                # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...')\n                # audio_emb = self.cross_attn(h, context = audio_emb)\n\n                # # h, = unpack(h, ps, 'b * c')\n                # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c)\n                # # audio_emb = self.audio_mlp_2(audio_emb)\n                # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f)\n                assert exists(audio_emb)\n                h_cond = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h = rearrange(x, 'b c ... -> b ... c')\n                h_cond, ps = pack([h_cond], 'b * c')\n\n\n                h_pose = self.cross_attn_pose(h_cond, context = pose_emb)\n                h_aud = self.cross_attn_aud(h_cond, context = audio_emb)\n                h_eye = self.cross_attn_eye(h_cond, context = eye_emb)\n\n                h_cond = h_pose + h_aud + h_eye\n\n\n                h_cond, = unpack(h_cond, ps, 'b * c')\n                # h = rearrange(h, 'b ... c -> b c ...')\n                h_cond = rearrange(h_cond, '(b f) ... 
c -> b c f ...', b = b, f = f)\n\n            # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 \n            # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1\n\n        h = self.block1(x, time_scale_shift=time_scale_shift)\n\n        if exists(self.audio_mlp):\n            h = h_cond + h\n\n        h = self.block2(h)\n        return h + self.res_conv(x)\n\nclass CrossAttention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        out_dim,\n        *,\n        context_dim = None,\n        dim_head = 8,\n        heads = 8,\n        norm_context = False,\n        scale = 8\n    ):\n        super().__init__()\n        self.scale = scale\n\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        context_dim = default(context_dim, dim)\n\n        self.norm = LayerNorm_img(dim)\n        self.norm_context = LayerNorm_img(context_dim) if norm_context else Identity()\n\n        self.null_kv = nn.Parameter(torch.randn(2, dim_head))\n        self.to_q = nn.Linear(dim, inner_dim, bias = False)\n        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)\n\n        self.q_scale = nn.Parameter(torch.ones(dim_head))\n        self.k_scale = nn.Parameter(torch.ones(dim_head))\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, out_dim, bias = False),\n            LayerNorm_img(out_dim)\n        )\n\n    def forward(self, x, context, mask = None):\n        b, n, device = *x.shape[:2], x.device\n\n        x = self.norm(x)  # bn * fn ?\n        # context: b, fn, c\n        context = rearrange(context, 'b f c -> (b f) c')\n        context = self.norm_context(context)\n\n        q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1))\n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))\n\n        # add null key / value for classifier free guidance in prior net\n\n        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))\n\n        k = torch.cat((nk, k), dim = -2)\n        v = torch.cat((nv, v), dim = -2)\n\n        # cosine sim attention\n\n        q, k = map(l2norm, (q, k))\n        q = q * self.q_scale\n        k = k * self.k_scale\n\n        # similarities\n\n        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale\n\n        # masking\n\n        max_neg_value = -torch.finfo(sim.dtype).max\n\n        if exists(mask):\n            mask = F.pad(mask, (1, 0), value = True)\n            mask = rearrange(mask, 'b j -> b 1 1 j')\n            sim = sim.masked_fill(~mask, max_neg_value)\n\n        attn = sim.softmax(dim = -1, dtype = torch.float32)\n        attn = attn.to(sim.dtype)\n\n        out = einsum('b h i j, b h j d -> b h i d', attn, v)\n        out = rearrange(out, 'b h n d -> b n (h d)')\n        return self.to_out(out)\n\nclass LinearCrossAttention(CrossAttention):\n    def forward(self, x, context, mask = None):\n        b, n, c = x.size()\n        b, n, device = *x.shape[:2], x.device   # x : b * fn, 32*32, c\n\n        x = self.norm(x)\n        context = self.norm_context(context)\n\n        q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1))  # b*fn, 32*32, c, b*fn, 1, c * 2, \n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads, d = c//self.heads), (q, k, v)) # head * b*fn, n, c//head\n\n        # add null key / value for classifier free guidance in prior net\n\n        
nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))\n\n        k = torch.cat((nk, k), dim = -2)  # b * nf * h, 2, c//h\n        v = torch.cat((nv, v), dim = -2)\n\n        # masking\n\n        max_neg_value = -torch.finfo(x.dtype).max\n\n        if exists(mask):\n            mask = F.pad(mask, (1, 0), value = True)\n            mask = rearrange(mask, 'b n -> b n 1')\n            k = k.masked_fill(~mask, max_neg_value)\n            v = v.masked_fill(~mask, 0.)\n\n        # linear attention\n\n        q = q.softmax(dim = -1) # # b * nf * h, 32*32, c//h,\n        k = k.softmax(dim = -2)\n\n        q = q * self.scale\n\n        context = einsum('b n d, b n e -> b d e', k, v) # b * nf * h, 2, c//h,  b * nf * h, 2, c//h\n        out = einsum('b n d, b d e -> b n e', q, context)\n        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)\n        return self.to_out(out)\n\nclass SpatialLinearAttention(nn.Module):\n    def __init__(self, dim, heads=4, dim_head=32):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)\n        self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n\n    def forward(self, x):\n        b, c, f, h, w = x.shape\n        x = rearrange(x, 'b c f h w -> (b f) c h w')\n\n        qkv = self.to_qkv(x).chunk(3, dim=1)\n        q, k, v = rearrange_many(qkv, 'b (h c) x y -> b h c (x y)', h=self.heads)\n\n        q = q.softmax(dim=-2)\n        k = k.softmax(dim=-1)\n\n        q = q * self.scale\n        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)\n\n        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)\n        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)\n        out = self.to_out(out)\n        return rearrange(out, '(b f) c h w -> b c f h w', b=b)\n\n\n# attention along space and time\n\nclass EinopsToAndFrom(nn.Module):\n    def __init__(self, from_einops, to_einops, fn):\n        super().__init__()\n        self.from_einops = from_einops\n        self.to_einops = to_einops\n        self.fn = fn\n\n    def forward(self, x, **kwargs):\n        shape = x.shape\n        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))\n        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')\n        x = self.fn(x, **kwargs)\n        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            heads=4,\n            dim_head=32,\n            rotary_emb=None\n    ):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n\n        self.rotary_emb = rotary_emb\n        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)\n        self.to_out = nn.Linear(hidden_dim, dim, bias=False)\n\n    def forward(\n            self,\n            x,\n            pos_bias=None,\n            focus_present_mask=None\n    ):  # temperal: 'b (h w) f c'  ; spatial :  'b f (h w) c'\n        n, device = x.shape[-2], x.device\n\n        qkv = self.to_qkv(x).chunk(3, dim=-1)\n\n        if exists(focus_present_mask) and focus_present_mask.all():\n            # if all batch samples are focusing on present\n            # it would be equivalent to passing that token's values through to the output\n        
    values = qkv[-1]\n            return self.to_out(values)\n\n        # split out heads\n\n        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)\n\n        # scale\n\n        q = q * self.scale\n\n        # rotate positions into queries and keys for time attention\n\n        if exists(self.rotary_emb):\n            q = self.rotary_emb.rotate_queries_or_keys(q)\n            k = self.rotary_emb.rotate_queries_or_keys(k)\n\n        # similarity\n\n        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)\n\n        # relative positional bias\n\n        if exists(pos_bias):\n            sim = sim + pos_bias\n\n        if exists(focus_present_mask) and not (~focus_present_mask).all():\n            attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool)\n            attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)\n\n            mask = torch.where(\n                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),\n                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),\n                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),\n            )\n\n            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)\n\n        # numerical stability\n\n        sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n        attn = sim.softmax(dim=-1)\n\n        # aggregate values\n\n        out = einsum('... h i j, ... h j d -> ... h i d', attn, v)\n        out = rearrange(out, '... h n d -> ... n (h d)')\n        return self.to_out(out)\n\n# model\nclass Unet3D(nn.Module):\n    def __init__(\n            self,\n            dim,\n            cond_aud=1024,\n            cond_pose=7,\n            cond_eye=2,\n            cond_dim=None,\n            out_grid_dim=2,\n            out_conf_dim=1,\n            num_frames=40,\n            dim_mults=(1, 2, 4, 8),\n            channels=3,\n            attn_heads=8,\n            attn_dim_head=32,\n            use_hubert_audio_cond=False,\n            init_dim=None,\n            init_kernel_size=7,\n            use_sparse_linear_attn=True,\n            resnet_groups=8,\n            use_final_activation=False,\n            learn_null_cond=False,\n            use_deconv=True,\n            padding_mode=\"zeros\",\n            win_width = 20\n    ):\n        super().__init__()\n        self.null_cond_mask = None\n        self.channels = channels\n        self.num_frames = num_frames\n        self.HUBERT_MODEL_DIM = 1024\n        # temporal attention and its relative positional encoding\n\n        rotary_emb = RotaryEmbedding(min(32, attn_dim_head))\n\n        temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', 'b (h w) f c',\n                                                    Attention(dim, heads=attn_heads, dim_head=attn_dim_head,\n                                                              rotary_emb=rotary_emb))\n\n        self.time_rel_pos_bias = RelativePositionBias(heads=attn_heads,\n                                                      max_distance=32, window_width = win_width)  # realistically will not be able to generate that many frames of video... 
yet\n\n        # initial conv\n\n        init_dim = default(init_dim, dim)\n        assert is_odd(init_kernel_size)\n\n        init_padding = init_kernel_size // 2\n        self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, init_kernel_size),\n                                   padding=(0, init_padding, init_padding))\n\n        self.init_temporal_attn = Residual(PreNorm(init_dim, temporal_attn(init_dim)))\n\n        # dimensions\n\n        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]\n        in_out = list(zip(dims[:-1], dims[1:]))\n\n        # time conditioning\n\n        time_dim = dim * 4\n        self.time_mlp = nn.Sequential(\n            SinusoidalPosEmb(dim),\n            nn.Linear(dim, time_dim),\n            nn.GELU(),\n            nn.Linear(time_dim, time_dim)\n        )\n\n        # audio conditioning\n\n        self.has_cond = exists(cond_dim) or use_hubert_audio_cond\n        self.cond_dim = cond_dim\n        self.cond_aud_dim = cond_aud\n        self.cond_pose_dim = cond_pose\n        self.cond_eye_dim = cond_eye\n\n        # modified by lml\n        self.learn_null_cond = learn_null_cond\n\n\n        # cat(t,cond) is not suitable\n        # cond_dim = time_dim + int(cond_dim or 0)\n\n        # layers\n\n        self.downs = nn.ModuleList([])\n        self.ups = nn.ModuleList([])\n\n        num_resolutions = len(in_out)\n\n        # block type\n\n        block_klass = partial(ResnetBlock_ca_mul, groups=resnet_groups)\n        block_klass_cond = partial(block_klass, time_emb_dim=time_dim, audio_emb_dim=self.cond_aud_dim, pose_emb_dim=self.cond_pose_dim, eye_emb_dim=self.cond_eye_dim)\n        # block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) # cat embedding\n\n        # modules for all layers\n\n        for ind, (dim_in, dim_out) in enumerate(in_out):\n            is_last = ind >= (num_resolutions - 1)\n\n            self.downs.append(nn.ModuleList([\n                block_klass_cond(dim_in, dim_out),\n                block_klass_cond(dim_out, dim_out),\n                Residual(PreNorm(dim_out, SpatialLinearAttention(dim_out,\n                                                                 heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),\n                Residual(PreNorm(dim_out, temporal_attn(dim_out))),\n                Downsample(dim_out) if not is_last else nn.Identity()\n            ]))\n\n        mid_dim = dims[-1]\n        self.mid_block1 = block_klass_cond(mid_dim, mid_dim)\n\n        spatial_attn = EinopsToAndFrom('b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads))\n\n        self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn))\n        self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_attn(mid_dim)))\n\n        self.mid_block2 = block_klass_cond(mid_dim, mid_dim)\n\n        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):\n            is_last = ind >= (num_resolutions - 1)\n\n            self.ups.append(nn.ModuleList([\n                block_klass_cond(dim_out * 2, dim_in),\n                block_klass_cond(dim_in, dim_in),\n                Residual(PreNorm(dim_in, SpatialLinearAttention(dim_in,\n                                                                heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),\n                Residual(PreNorm(dim_in, temporal_attn(dim_in))),\n                Upsample(dim_in, use_deconv, padding_mode) if not is_last else nn.Identity()\n            ]))\n\n        # out_dim = default(out_grid_dim, channels)\n        
self.final_conv = nn.Sequential(\n            block_klass(dim * 2, dim),\n            nn.Conv3d(dim, out_grid_dim, 1)\n        )\n\n        # added by nhm\n        self.use_final_activation = use_final_activation\n        if self.use_final_activation:\n            self.final_activation = nn.Tanh()\n        else:\n            self.final_activation = nn.Identity()\n\n        # added by nhm for predicting occlusion mask\n        self.occlusion_map = nn.Sequential(\n            block_klass(dim * 2, dim),\n            nn.Conv3d(dim, out_conf_dim, 1)\n        )\n\n    def forward_with_cond_scale(\n            self,\n            *args,\n            cond_scale=2.,\n            **kwargs\n    ):\n        logits = self.forward(*args, null_cond_prob=0., **kwargs)\n        if cond_scale == 1 or not self.has_cond:\n            return logits\n\n        null_logits = self.forward(*args, null_cond_prob=1., **kwargs)\n        return null_logits + (logits - null_logits) * cond_scale\n\n    def forward(\n            self,\n            x,\n            time,\n            cond=None,\n            null_cond_prob=0.,\n            focus_present_mask=None,\n            prob_focus_present=0.\n            # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)\n    ):\n        assert not (self.has_cond and not exists(cond)), 'cond must be passed in if cond_dim specified'\n        batch, device = x.shape[0], x.device\n\n        focus_present_mask = default(focus_present_mask,\n                                     lambda: prob_mask_like((batch,), prob_focus_present, device=device))\n\n        time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)\n\n        x = self.init_conv(x)\n        r = x.clone()\n\n        x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias)\n\n        t = self.time_mlp(time) if exists(self.time_mlp) else None\n\n        if self.learn_null_cond:\n            self.null_cond_emb = nn.Parameter(torch.randn(1, self.num_frames, self.cond_dim)) if self.has_cond else None\n        else:\n            self.null_cond_emb = torch.zeros(1, self.num_frames, self.cond_dim) if self.has_cond else None\n        # classifier free guidance\n\n        if self.has_cond:\n            batch, device = x.shape[0], x.device\n            self.null_cond_mask = prob_mask_like((batch, self.num_frames,), null_cond_prob, device=device)\n            cond = torch.where(rearrange(self.null_cond_mask, 'b n -> b n 1'), self.null_cond_emb.to(cond.device), cond) \n            # t (bs, 256)  cond (bs, nf*1024)->(bs, nf, 1024) in this version\n            \n            # it's the original cond embedding method used in LFDM\n            # t = torch.cat((t, cond), dim=-1)\n\n        h = []\n\n        for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:\n            x = block1(x, t, cond)\n            x = block2(x, t, cond)\n            x = spatial_attn(x)\n            x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)\n            h.append(x)\n            x = downsample(x)\n\n        x = self.mid_block1(x, t, cond)\n        x = self.mid_spatial_attn(x)\n        x = self.mid_temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)\n        x = self.mid_block2(x, t, cond)\n\n        for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:\n            x = torch.cat((x, h.pop()), dim=1)\n            x = block1(x, t, cond)\n            x = 
block2(x, t, cond)\n            x = spatial_attn(x)\n            x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)\n            x = upsample(x)\n\n        x = torch.cat((x, r), dim=1)\n        return torch.cat((self.final_conv(x), self.occlusion_map(x)), dim=1)\n\n# to dynamically change num_frames of Unet3D\nclass DynamicNfUnet3D(Unet3D):\n    def __init__(self, default_num_frames=20, *args, **kwargs):\n        super(DynamicNfUnet3D, self).__init__(*args, **kwargs)\n        self.default_num_frames = default_num_frames\n        self.num_frames = default_num_frames\n    def update_num_frames(self, new_num_frames):\n        self.num_frames = new_num_frames\n\n# gaussian diffusion trainer class\n\ndef extract(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef cosine_beta_schedule(timesteps, s=0.008):\n    \"\"\"\n    cosine schedule\n    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ\n    \"\"\"\n    steps = timesteps + 1\n    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)\n    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2\n    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n    return torch.clip(betas, 0, 0.9999)\n\n\nclass GaussianDiffusion(nn.Module):\n    def __init__(\n            self,\n            denoise_fn,\n            *,\n            image_size,\n            num_frames,\n            text_use_bert_cls=False,\n            channels=3,\n            timesteps=1000,\n            sampling_timesteps=250,\n            ddim_sampling_eta=1.,\n            loss_type='l1',\n            use_dynamic_thres=False,  # from the Imagen paper\n            dynamic_thres_percentile=0.9,\n            null_cond_prob=0.1\n    ):\n        super().__init__()\n        self.null_cond_prob = null_cond_prob\n        self.channels = channels\n        self.image_size = image_size\n        self.num_frames = num_frames\n        self.denoise_fn = denoise_fn\n\n        betas = cosine_beta_schedule(timesteps)\n\n        alphas = 1. - betas\n        alphas_cumprod = torch.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)\n\n        timesteps, = betas.shape\n        self.num_timesteps = int(timesteps)\n        self.loss_type = loss_type\n\n        self.sampling_timesteps = default(sampling_timesteps,\n                                          timesteps)\n        self.is_ddim_sampling = self.sampling_timesteps < timesteps\n        self.ddim_sampling_eta = ddim_sampling_eta\n\n        # register buffer helper function that casts float64 to float32\n\n        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))\n\n        register_buffer('betas', betas)\n        register_buffer('alphas_cumprod', alphas_cumprod)\n        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n\n        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))\n        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))\n        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))\n        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))\n        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. 
/ alphas_cumprod - 1))\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n\n        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)\n\n        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n\n        register_buffer('posterior_variance', posterior_variance)\n\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n\n        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))\n        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))\n        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))\n\n        # text conditioning parameters\n\n        self.text_use_bert_cls = text_use_bert_cls\n\n        # dynamic thresholding when sampling\n\n        self.use_dynamic_thres = use_dynamic_thres\n        self.dynamic_thres_percentile = dynamic_thres_percentile\n\n    def q_mean_variance(self, x_start, t):\n        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)\n        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)\n        return mean, variance, log_variance\n\n    def predict_start_from_noise(self, x_t, t, noise):\n        return (\n                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -\n                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n        )\n\n    def q_posterior(self, x_start, x_t, t):\n        posterior_mean = (\n                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +\n                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = extract(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n    def p_mean_variance(self, x, t, fea, clip_denoised: bool, cond=None, cond_scale=1.):\n        fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1)\n        x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn.forward_with_cond_scale(torch.cat([x, fea], dim=1),\n                                                                                                      t,\n                                                                                                      cond=cond,\n                                                                                                      cond_scale=cond_scale))\n\n        if clip_denoised:\n            s = 1.\n            if self.use_dynamic_thres:\n                s = torch.quantile(\n                    rearrange(x_recon, 'b ... 
-> b (...)').abs(),\n                    self.dynamic_thres_percentile,\n                    dim=-1\n                )\n\n                s.clamp_(min=1.)\n                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))\n\n            # clip by threshold, depending on whether static or dynamic\n            x_recon = x_recon.clamp(-s, s) / s\n\n        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)\n        return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.inference_mode()\n    def p_sample(self, x, t, fea, cond=None, cond_scale=1., clip_denoised=True):\n        b, *_, device = *x.shape, x.device\n        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, fea=fea,\n                                                                 clip_denoised=clip_denoised, cond=cond,\n                                                                 cond_scale=cond_scale)\n        noise = torch.randn_like(x)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n\n    @torch.inference_mode()\n    def p_sample_loop(self, fea, shape, cond=None, cond_scale=1.):\n        device = self.betas.device\n\n        b = shape[0]\n        img = torch.randn(shape, device=device)\n\n        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):\n            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), fea, cond=cond,\n                                cond_scale=cond_scale)\n\n        return img\n        # return unnormalize_img(img)\n\n    @torch.inference_mode()\n    def sample(self, fea, bbox_mask, cond=None, cond_scale=1., batch_size=16):\n        # text bert: cond 1,768\n        # device = next(self.denoise_fn.parameters()).device\n        # if is_list_str(cond):\n        #     cond = torch.rand((1 ,768), dtype=torch.float32).cuda()  #used to debug\n            # cond = bert_embed(tokenize(cond), return_cls_repr=self.text_use_bert_cls).to(device)\n\n        batch_size = cond.shape[0] if exists(cond) else batch_size\n        # batch_size = 1 if exists(cond) else batch_size\n        image_size = self.image_size\n        channels = self.channels\n        num_frames = self.num_frames\n        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample\n        fea = torch.cat([fea, bbox_mask], dim=1)\n        return sample_fn(fea, (batch_size, channels, num_frames, fea.shape[-1], fea.shape[-1]), cond=cond,\n                         cond_scale=cond_scale)\n\n    # add by nhm\n    @torch.no_grad()\n    def ddim_sample(self, fea, shape, cond=None, cond_scale=1., clip_denoised=True):\n\n        batch, device, total_timesteps, sampling_timesteps, eta = \\\n            shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta\n\n        times = torch.linspace(0., total_timesteps, steps=sampling_timesteps + 2)[:-1]\n        times = list(reversed(times.int().tolist()))\n        time_pairs = list(zip(times[:-1], times[1:]))\n\n        img = torch.randn(shape, device=device) # bs, 3, nf, 32, 32\n        fea = fea.unsqueeze(dim=2).repeat(1, 1, img.size(2), 1, 1) #bs, 256, nf, 32, 32\n\n        for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):\n            alpha = self.alphas_cumprod_prev[time]\n            alpha_next = 
self.alphas_cumprod_prev[time_next]\n\n            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)\n\n            # pred_noise, x_start, *_ = self.model_predictions(img, time_cond, fea)\n            pred_noise = self.denoise_fn.forward_with_cond_scale(\n                torch.cat([img, fea], dim=1),\n                time_cond,\n                cond=cond,\n                cond_scale=cond_scale)\n            x_start = self.predict_start_from_noise(img, t=time_cond, noise=pred_noise)\n\n            if clip_denoised:\n                s = 1.\n                if self.use_dynamic_thres:\n                    s = torch.quantile(\n                        rearrange(x_start, 'b ... -> b (...)').abs(),\n                        self.dynamic_thres_percentile,\n                        dim=-1\n                    )\n\n                    s.clamp_(min=1.)\n                    s = s.view(-1, *((1,) * (x_start.ndim - 1)))\n\n                # clip by threshold, depending on whether static or dynamic\n                x_start = x_start.clamp(-s, s) / s\n\n            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n            c = ((1 - alpha_next) - sigma ** 2).sqrt()\n\n            noise = torch.randn_like(img) if time_next > 0 else 0.\n\n            img = x_start * alpha_next.sqrt() + \\\n                  c * pred_noise + \\\n                  sigma * noise\n\n        # img = unnormalize_to_zero_to_one(img)\n        return img\n\n    @torch.inference_mode()\n    def interpolate(self, x1, x2, t=None, lam=0.5):\n        b, *_, device = *x1.shape, x1.device\n        t = default(t, self.num_timesteps - 1)\n\n        assert x1.shape == x2.shape\n\n        t_batched = torch.stack([torch.tensor(t, device=device)] * b)\n        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))\n\n        img = (1 - lam) * xt1 + lam * xt2\n        for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):\n            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))\n\n        return img\n\n    def q_sample(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n\n        return (\n                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise\n        )\n\n    def p_losses(self, x_start, t, fea, bbox_mask, cond=None, noise=None, clip_denoised=True, **kwargs):\n        # x_start: bs, 3, num_frame, 32, 32\n        # t: bs\n        # fea: bs, 256, num_frame, 32, 32\n        # cond: bs, 768\n        b, c, f, h, w, device = *x_start.shape, x_start.device\n        noise = default(noise, lambda: torch.randn_like(x_start)) # bs, 3, nf, 32, 32\n\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)# bs, 3, nf, 32, 32\n\n        pred_noise = self.denoise_fn.forward(torch.cat([x_noisy, fea, bbox_mask], dim=1), t, cond=cond,\n                                             null_cond_prob=self.null_cond_prob,\n                                             **kwargs)\n\n        if self.loss_type == 'l1':\n            loss = F.l1_loss(noise, pred_noise, reduce=False)\n        elif self.loss_type == 'l2':\n            loss = F.mse_loss(noise, pred_noise, reduce=False)\n        else:\n            raise NotImplementedError()\n  \n        pred_x0 = self.predict_start_from_noise(x_noisy, t, pred_noise)\n\n        if clip_denoised:\n            s = 1.\n            if 
self.use_dynamic_thres:\n                s = torch.quantile(\n                    rearrange(pred_x0, 'b ... -> b (...)').abs(),\n                    self.dynamic_thres_percentile,\n                    dim=-1\n                )\n\n                s.clamp_(min=1.)\n                s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))\n\n            # clip by threshold, depending on whether static or dynamic\n            self.pred_x0 = pred_x0.clamp(-s, s) / s\n\n        return loss, self.denoise_fn.null_cond_mask\n\n    def forward(self, x, fea, bbox_mask, cond, *args, **kwargs):\n        b, device, img_size, = x.shape[0], x.device, self.image_size\n        # check_shape(x, 'b c f h w', c=self.channels, f=self.num_frames, h=img_size, w=img_size)\n        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()\n        fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1)\n        bbox_mask = bbox_mask.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1)\n\n        return self.p_losses(x, t, fea, bbox_mask, cond, *args, **kwargs)\n\n\n# trainer class\n\nCHANNELS_TO_MODE = {\n    1: 'L',\n    3: 'RGB',\n    4: 'RGBA'\n}\n\n\ndef seek_all_images(img, channels=3):\n    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'\n    mode = CHANNELS_TO_MODE[channels]\n\n    i = 0\n    while True:\n        try:\n            img.seek(i)\n            yield img.convert(mode)\n        except EOFError:\n            break\n        i += 1\n\n# to dynamically change num_frames of GaussianDiffusion\nclass DynamicNfGaussianDiffusion(GaussianDiffusion):\n    def __init__(self, default_num_frames=20, *args, **kwargs):\n        super(DynamicNfGaussianDiffusion, self).__init__(*args, **kwargs)\n        self.default_num_frames = default_num_frames\n        self.num_frames = default_num_frames\n    def update_num_frames(self, new_num_frames):\n        self.num_frames = new_num_frames\n\n# tensor of shape (channels, frames, height, width) -> gif\n\ndef video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):\n    images = map(T.ToPILImage(), tensor.unbind(dim=1))\n    first_img, *rest_imgs = images\n    first_img.save(path, save_all=True, append_images=rest_imgs, duration=duration, loop=loop, optimize=optimize)\n    return images\n\n\n# gif -> (channels, frame, height, width) tensor\n\ndef gif_to_tensor(path, channels=3, transform=T.ToTensor()):\n    img = Image.open(path)\n    tensors = tuple(map(transform, seek_all_images(img, channels=channels)))\n    return torch.stack(tensors, dim=1)\n\n\ndef identity(t, *args, **kwargs):\n    return t\n\n\ndef normalize_img(t):\n    return t * 2 - 1\n\n\n# def unnormalize_img(t):\n#     return (t + 1) * 0.5\n\n\ndef cast_num_frames(t, *, frames):\n    f = t.shape[1]\n\n    if f == frames:\n        return t\n\n    if f > frames:\n        return t[:, :frames]\n\n    return F.pad(t, (0, 0, 0, 0, 0, frames - f))\n\n\n"
  },
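  {
    "path": "DM_3/modules/_usage_sketch_sampling.py",
    "content": "'''\nIllustrative usage sketch (hypothetical file, not part of the original repo):\nwires up Unet3D and GaussianDiffusion from\nvideo_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py and runs a few\nDDIM steps on random tensors as a smoke test. Every size below (feature channels,\nframe count, latent resolution, condition dims) is an assumed toy value; the real\npipeline feeds LFG appearance features, HuBERT audio embeddings, pose and blink\nsequences at their true dimensions.\n'''\nimport torch\n\nfrom DM_3.modules.video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test import (\n    Unet3D, GaussianDiffusion)\n\nif __name__ == '__main__':\n    nf, latent, fea_ch = 8, 16, 8                        # assumed toy sizes\n    cond_aud, cond_pose, cond_eye = 64, 6, 2             # per-frame condition dims (assumed)\n\n    # denoiser input = noisy flow latent (3 ch) + appearance feature map + bbox mask\n    unet = Unet3D(\n        dim=32,\n        channels=3 + fea_ch + 1,\n        cond_aud=cond_aud, cond_pose=cond_pose, cond_eye=cond_eye,\n        cond_dim=cond_aud + cond_pose + cond_eye,\n        out_grid_dim=2, out_conf_dim=1,\n        num_frames=nf)\n\n    # sampling_timesteps < timesteps selects the DDIM sampler; 4 steps keeps the smoke test fast\n    diffusion = GaussianDiffusion(\n        unet, image_size=latent, num_frames=nf,\n        channels=3, timesteps=1000, sampling_timesteps=4)\n\n    fea = torch.randn(1, fea_ch, latent, latent)         # per-video appearance feature map\n    bbox = torch.zeros(1, 1, latent, latent)             # mouth-region mask\n    cond = torch.randn(1, nf, cond_aud + cond_pose + cond_eye)  # audio | pose | blink, per frame\n\n    flow = diffusion.sample(fea, bbox, cond=cond, cond_scale=1.0)\n    print(flow.shape)  # (1, 3, nf, latent, latent): 2 flow channels + 1 occlusion map\n"
  },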
  {
    "path": "DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py",
    "content": "'''\nadding pose condtioning on baseline\nusing cross attention to add different condition\n\nusing ram optimized local attention, for inference (slower, costing less ram)\n'''\nimport math\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom functools import partial\n\nfrom torchvision import transforms as T\nfrom PIL import Image\n\nfrom tqdm import tqdm\nfrom einops import rearrange, repeat, reduce, pack, unpack\nfrom einops_exts import rearrange_many\n\nfrom rotary_embedding_torch import RotaryEmbedding\nfrom DM_3.modules.local_attention import LocalSelfAttention_opt, create_sliding_window_mask\n\n# from DM.modules.text import tokenize, bert_embed, HUBERT_MODEL_DIM\n\n\n# helpers functions\n\ndef exists(x):\n    return x is not None\n\n\ndef noop(*args, **kwargs):\n    pass\n\n\ndef is_odd(n):\n    return (n % 2) == 1\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if callable(d) else d\n\n\ndef cycle(dl):\n    while True:\n        for data in dl:\n            yield data\n\n\ndef num_to_groups(num, divisor):\n    groups = num // divisor\n    remainder = num % divisor\n    arr = [divisor] * groups\n    if remainder > 0:\n        arr.append(remainder)\n    return arr\n\n\ndef prob_mask_like(shape, prob, device):\n    if prob == 1:\n        return torch.ones(shape, device=device, dtype=torch.bool)\n    elif prob == 0:\n        return torch.zeros(shape, device=device, dtype=torch.bool)\n    else:\n        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob\n\n\ndef is_list_str(x):\n    if not isinstance(x, (list, tuple)):\n        return False\n    return all([type(el) == str for el in x])\n\n\n# relative positional bias\n\nclass RelativePositionBias(nn.Module):\n    def __init__(\n            self,\n            heads=8,\n            num_buckets=32,\n            max_distance=128,\n            window_width = 20\n    ):\n        super().__init__()\n        self.num_buckets = num_buckets\n        self.max_distance = max_distance\n        self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n        self.window_width = window_width\n\n    @staticmethod\n    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):\n        ret = 0\n        n = -relative_position\n\n        num_buckets //= 2\n        ret += (n < 0).long() * num_buckets\n        n = torch.abs(n)\n\n        max_exact = num_buckets // 2\n        is_small = n < max_exact\n\n        val_if_large = max_exact + (\n                torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)\n        ).long()\n        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n\n        ret += torch.where(is_small, n, val_if_large)\n        return ret\n\n    def forward(self, n, device):\n        q_pos = torch.arange(n, dtype=torch.long, device=device)\n        k_pos = torch.arange(n, dtype=torch.long, device=device)\n        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,\n                                                   max_distance=self.max_distance)\n        mask = -(((rel_pos > self.window_width) + (rel_pos < -self.window_width)) * (1e8)) # -(((rp_bucket ==15) + (rp_bucket >= 30)) * (1e8))\n        values = self.relative_attention_bias(rp_bucket)\n        return rearrange(values, 'i j h -> h i j') + mask\n\n\n\n# small 
helper modules\n\nclass EMA():\n    def __init__(self, beta):\n        super().__init__()\n        self.beta = beta\n\n    def update_model_average(self, ma_model, current_model):\n        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):\n            old_weight, up_weight = ma_params.data, current_params.data\n            ma_params.data = self.update_average(old_weight, up_weight)\n\n    def update_average(self, old, new):\n        if old is None:\n            return new\n        return old * self.beta + (1 - self.beta) * new\n\n\nclass Residual(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, x, *args, **kwargs):\n        return self.fn(x, *args, **kwargs) + x\n\n\nclass SinusoidalPosEmb(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.dim = dim\n\n    def forward(self, x):\n        device = x.device\n        half_dim = self.dim // 2\n        emb = math.log(10000) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)\n        emb = x[:, None] * emb[None, :]\n        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)\n        return emb\n\n\ndef Upsample(dim, use_deconv=True, padding_mode=\"reflect\"):\n    if use_deconv:\n        return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))\n    else:\n        return nn.Sequential(\n            nn.Upsample(scale_factor=(1, 2, 2), mode='nearest'),\n            nn.Conv3d(dim, dim, (1, 3, 3), (1, 1, 1), (0, 1, 1), padding_mode=padding_mode)\n        )\n\n\ndef Downsample(dim):\n    return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))\n\n\nclass LayerNorm(nn.Module):\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.eps = eps\n        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))\n\n    def forward(self, x):\n        var = torch.var(x, dim=1, unbiased=False, keepdim=True)\n        mean = torch.mean(x, dim=1, keepdim=True)\n        return (x - mean) / (var + self.eps).sqrt() * self.gamma\n\nclass LayerNorm_img(nn.Module):\n    def __init__(self, dim, stable = False):\n        super().__init__()\n        self.stable = stable\n        self.g = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        if self.stable:\n            x = x / x.amax(dim = -1, keepdim = True).detach()\n\n        eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)\n        mean = torch.mean(x, dim = -1, keepdim = True)\n        return (x - mean) * (var + eps).rsqrt() * self.g\n\nclass PreNorm(nn.Module):\n    def __init__(self, dim, fn):\n        super().__init__()\n        self.fn = fn\n        self.norm = LayerNorm(dim)\n\n    def forward(self, x, **kwargs):\n        x = self.norm(x)\n        return self.fn(x, **kwargs)\n\nclass Identity(nn.Module):\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n\n    def forward(self, x, *args, **kwargs):\n        return x\n\ndef l2norm(t):\n    return F.normalize(t, dim = -1)\n# building block modules\n\nclass Block(nn.Module):\n    def __init__(self, dim, dim_out, groups=8):\n        super().__init__()\n        self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1))\n        self.norm = nn.GroupNorm(groups, dim_out)\n        self.act = nn.SiLU()\n\n    def forward(self, x, time_scale_shift=None, audio_scale_shift=None):\n        x = self.proj(x)\n        x = self.norm(x)\n\n        if 
exists(time_scale_shift):\n            time_scale, time_shift = time_scale_shift\n            x = x * (time_scale + 1) + time_shift\n        \n        # added by lml to change the control method of audio embedding, inspired by diffusedhead\n        # if exists(audio_scale_shift):\n        #     # audio_scale and audio_shift:(bs, 64, nf, 1, 1) \n        #     # x:(bs, 64, nf, 32, 32)\n        #     audio_scale, audio_shift = audio_scale_shift\n        #     x = x * (audio_scale + 1) + audio_shift\n\n        return self.act(x)\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8):\n        super().__init__()\n        self.time_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(time_emb_dim, dim_out * 2)\n        ) if exists(time_emb_dim) else None\n\n        self.audio_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(audio_emb_dim, dim_out * 2)\n        ) if exists(audio_emb_dim) else None\n\n        self.block1 = Block(dim, dim_out, groups=groups)\n        self.block2 = Block(dim_out, dim_out, groups=groups)\n        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n\n    def forward(self, x, time_emb=None, audio_emb=None):\n        time_scale_shift = None\n        audio_scale_shift = None\n        if exists(self.time_mlp):\n            assert exists(time_emb), 'time emb must be passed in'\n            time_emb = self.time_mlp(time_emb)\n            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 \n            time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 \n\n        # added by lml to get audio embedding\n        if exists(self.audio_mlp):\n            assert exists(audio_emb), 'audio emb must be passed in'\n            audio_emb = self.audio_mlp(audio_emb)\n            audio_emb = rearrange(audio_emb, 'b n c -> b c n 1 1') # bs, 128, nf, 1, 1 \n            audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1\n\n        h = self.block1(x, time_scale_shift=time_scale_shift, audio_scale_shift=audio_scale_shift)\n\n        h = self.block2(h)\n        return h + self.res_conv(x)\n\nclass ResnetBlock_ca(nn.Module):\n    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8):\n        super().__init__()\n        self.time_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(time_emb_dim, dim_out * 2)\n        ) if exists(time_emb_dim) else None\n\n        self.audio_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(audio_emb_dim, dim_out * 2)\n        ) if exists(audio_emb_dim) else None\n\n        # self.audio_mlp_2 = nn.Sequential(\n        #     nn.SiLU(),\n        #     nn.Linear(dim_out, dim_out * 2)\n        # ) if exists(audio_emb_dim) else None\n\n        attn_klass = CrossAttention\n\n        self.cross_attn = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2\n            )\n\n\n        self.block1 = Block(dim, dim_out, groups=groups)\n        self.block2 = Block(dim_out, dim_out, groups=groups)\n        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n\n    def forward(self, x, time_emb=None, audio_emb=None):\n        time_scale_shift = None\n        audio_scale_shift = None\n        b, c, f, H, W = x.size()\n        if exists(self.time_mlp):\n            assert exists(time_emb), 'time emb must be passed in'\n            time_emb = self.time_mlp(time_emb)\n            time_emb = 
rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 \n            time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 \n\n        # added by lml to get audio embedding\n        if exists(self.audio_mlp):\n            assert exists(audio_emb), 'audio emb must be passed in'\n            audio_emb = self.audio_mlp(audio_emb)\n\n            if exists(self.cross_attn):\n                # h = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h, ps = pack([h], 'b * c')\n                # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...')\n                # audio_emb = self.cross_attn(h, context = audio_emb)\n\n                # # h, = unpack(h, ps, 'b * c')\n                # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c)\n                # # audio_emb = self.audio_mlp_2(audio_emb)\n                # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f)\n                assert exists(audio_emb)\n                h = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h = rearrange(x, 'b c ... -> b ... c')\n                h, ps = pack([h], 'b * c')\n\n                h = self.cross_attn(h, context = audio_emb) + h\n\n                h, = unpack(h, ps, 'b * c')\n                # h = rearrange(h, 'b ... c -> b c ...')\n                h = rearrange(h, '(b f) ... c -> b f c ...', b = b, f = f)\n\n            # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 \n            # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1\n\n        h = self.block1(x, time_scale_shift=time_scale_shift)\n\n        h = self.block2(h)\n        return h + self.res_conv(x)\n\nclass ResnetBlock_ca_mul(nn.Module):\n    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, pose_emb_dim=None, eye_emb_dim=None, groups=8):\n        super().__init__()\n        self.time_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(time_emb_dim, dim_out * 2)\n        ) if exists(time_emb_dim) else None\n\n        self.audio_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(audio_emb_dim, dim_out * 2)\n        ) if exists(audio_emb_dim) else None\n\n        self.pose_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(pose_emb_dim, dim_out * 2)\n        ) if exists(pose_emb_dim) else None\n\n        self.eye_mlp = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(eye_emb_dim, dim_out * 2)\n        ) if exists(eye_emb_dim) else None\n\n        self.audio_emb_dim = audio_emb_dim\n        self.pose_emb_dim = pose_emb_dim\n        self.eye_emb_dim = eye_emb_dim\n        # self.audio_mlp_2 = nn.Sequential(\n        #     nn.SiLU(),\n        #     nn.Linear(dim_out, dim_out * 2)\n        # ) if exists(audio_emb_dim) else None\n\n        attn_klass = CrossAttention\n\n        self.cross_attn_aud = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2,\n                out_dim = dim_out\n            )\n\n        self.cross_attn_pose = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2,\n                out_dim = dim_out\n            )\n        \n        self.cross_attn_eye = attn_klass(\n                dim = dim,\n                context_dim = dim_out * 2,\n                out_dim = dim_out\n            )\n\n\n        self.block1 = Block(dim, dim_out, groups=groups)\n        self.block2 = Block(dim_out, dim_out, groups=groups)\n        self.res_conv = nn.Conv3d(dim, 
dim_out, 1) if dim != dim_out else nn.Identity()\n\n    def forward(self, x, time_emb=None, audio_emb=None):\n        time_scale_shift = None\n        audio_scale_shift = None\n        '''\n            need seperate 3 diffiserent condition\n        '''\n        if exists(audio_emb):\n            pose_emb = audio_emb[:,:,self.audio_emb_dim:self.audio_emb_dim + self.pose_emb_dim]\n            eye_emb = audio_emb[:,:,self.audio_emb_dim + self.pose_emb_dim: ]\n            audio_emb = audio_emb[:,:,:self.audio_emb_dim]\n\n        b, c, f, H, W = x.size()\n        if exists(self.time_mlp):\n            assert exists(time_emb), 'time emb must be passed in'\n            time_emb = self.time_mlp(time_emb)\n            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 \n            time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 \n\n        # added by lml to get audio embedding\n        if exists(self.audio_mlp):  # mouth lmk + audio emb\n            assert exists(audio_emb), 'audio emb must be passed in'\n            audio_emb = self.audio_mlp(audio_emb)\n            pose_emb = self.pose_mlp(pose_emb)  # TODO: embedding\n            eye_emb = self.eye_mlp(eye_emb)\n            if exists(self.cross_attn_aud):\n                # h = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h, ps = pack([h], 'b * c')\n                # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...')\n                # audio_emb = self.cross_attn(h, context = audio_emb)\n\n                # # h, = unpack(h, ps, 'b * c')\n                # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c)\n                # # audio_emb = self.audio_mlp_2(audio_emb)\n                # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f)\n                assert exists(audio_emb)\n                h_cond = rearrange(x, 'b c f ... -> (b f) ... c')\n                # h = rearrange(x, 'b c ... -> b ... c')\n                h_cond, ps = pack([h_cond], 'b * c')\n\n\n                h_pose = self.cross_attn_pose(h_cond, context = pose_emb)\n                h_aud = self.cross_attn_aud(h_cond, context = audio_emb)\n                h_eye = self.cross_attn_eye(h_cond, context = eye_emb)\n\n                h_cond = h_pose + h_aud + h_eye\n\n\n                h_cond, = unpack(h_cond, ps, 'b * c')\n                # h = rearrange(h, 'b ... c -> b c ...')\n                h_cond = rearrange(h_cond, '(b f) ... 
c -> b c f ...', b = b, f = f)\n\n            # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 \n            # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1\n\n        h = self.block1(x, time_scale_shift=time_scale_shift)\n\n        if exists(self.audio_mlp):\n            h = h_cond + h\n\n        h = self.block2(h)\n        return h + self.res_conv(x)\n\nclass CrossAttention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        out_dim,\n        *,\n        context_dim = None,\n        dim_head = 8,\n        heads = 8,\n        norm_context = False,\n        scale = 8\n    ):\n        super().__init__()\n        self.scale = scale\n\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        context_dim = default(context_dim, dim)\n\n        self.norm = LayerNorm_img(dim)\n        self.norm_context = LayerNorm_img(context_dim) if norm_context else Identity()\n\n        self.null_kv = nn.Parameter(torch.randn(2, dim_head))\n        self.to_q = nn.Linear(dim, inner_dim, bias = False)\n        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)\n\n        self.q_scale = nn.Parameter(torch.ones(dim_head))\n        self.k_scale = nn.Parameter(torch.ones(dim_head))\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, out_dim, bias = False),\n            LayerNorm_img(out_dim)\n        )\n\n    def forward(self, x, context, mask = None):\n        b, n, device = *x.shape[:2], x.device\n\n        x = self.norm(x)  # bn * fn ?\n        # context: b, fn, c\n        context = rearrange(context, 'b f c -> (b f) c')\n        context = self.norm_context(context)\n\n        q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1))\n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))\n\n        # add null key / value for classifier free guidance in prior net\n\n        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))\n\n        k = torch.cat((nk, k), dim = -2)\n        v = torch.cat((nv, v), dim = -2)\n\n        # cosine sim attention\n\n        q, k = map(l2norm, (q, k))\n        q = q * self.q_scale\n        k = k * self.k_scale\n\n        # similarities\n\n        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale\n\n        # masking\n\n        max_neg_value = -torch.finfo(sim.dtype).max\n\n        if exists(mask):\n            mask = F.pad(mask, (1, 0), value = True)\n            mask = rearrange(mask, 'b j -> b 1 1 j')\n            sim = sim.masked_fill(~mask, max_neg_value)\n\n        attn = sim.softmax(dim = -1, dtype = torch.float32)\n        attn = attn.to(sim.dtype)\n\n        out = einsum('b h i j, b h j d -> b h i d', attn, v)\n        out = rearrange(out, 'b h n d -> b n (h d)')\n        return self.to_out(out)\n\nclass LinearCrossAttention(CrossAttention):\n    def forward(self, x, context, mask = None):\n        b, n, c = x.size()\n        b, n, device = *x.shape[:2], x.device   # x : b * fn, 32*32, c\n\n        x = self.norm(x)\n        context = self.norm_context(context)\n\n        q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1))  # b*fn, 32*32, c, b*fn, 1, c * 2, \n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads, d = c//self.heads), (q, k, v)) # head * b*fn, n, c//head\n\n        # add null key / value for classifier free guidance in prior net\n\n        
nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))\n\n        k = torch.cat((nk, k), dim = -2)  # b * nf * h, 2, c//h\n        v = torch.cat((nv, v), dim = -2)\n\n        # masking\n\n        max_neg_value = -torch.finfo(x.dtype).max\n\n        if exists(mask):\n            mask = F.pad(mask, (1, 0), value = True)\n            mask = rearrange(mask, 'b n -> b n 1')\n            k = k.masked_fill(~mask, max_neg_value)\n            v = v.masked_fill(~mask, 0.)\n\n        # linear attention\n\n        q = q.softmax(dim = -1) # # b * nf * h, 32*32, c//h,\n        k = k.softmax(dim = -2)\n\n        q = q * self.scale\n\n        context = einsum('b n d, b n e -> b d e', k, v) # b * nf * h, 2, c//h,  b * nf * h, 2, c//h\n        out = einsum('b n d, b d e -> b n e', q, context)\n        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)\n        return self.to_out(out)\n\nclass SpatialLinearAttention(nn.Module):\n    def __init__(self, dim, heads=4, dim_head=32):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)\n        self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n\n    def forward(self, x):\n        b, c, f, h, w = x.shape\n        x = rearrange(x, 'b c f h w -> (b f) c h w')\n\n        qkv = self.to_qkv(x).chunk(3, dim=1)\n        q, k, v = rearrange_many(qkv, 'b (h c) x y -> b h c (x y)', h=self.heads)\n\n        q = q.softmax(dim=-2)\n        k = k.softmax(dim=-1)\n\n        q = q * self.scale\n        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)\n\n        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)\n        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)\n        out = self.to_out(out)\n        return rearrange(out, '(b f) c h w -> b c f h w', b=b)\n\n\n# attention along space and time\n\nclass EinopsToAndFrom(nn.Module):\n    def __init__(self, from_einops, to_einops, fn):\n        super().__init__()\n        self.from_einops = from_einops\n        self.to_einops = to_einops\n        self.fn = fn\n\n    def forward(self, x, **kwargs):\n        shape = x.shape\n        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))\n        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')\n        x = self.fn(x, **kwargs)\n        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            heads=4,\n            dim_head=32,\n            rotary_emb=None\n    ):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n\n        self.rotary_emb = rotary_emb\n        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)\n        self.to_out = nn.Linear(hidden_dim, dim, bias=False)\n\n    def forward(\n            self,\n            x,\n            pos_bias=None,\n            focus_present_mask=None\n    ):  # temperal: 'b (h w) f c'  ; spatial :  'b f (h w) c'\n        n, device = x.shape[-2], x.device\n\n        qkv = self.to_qkv(x).chunk(3, dim=-1)\n\n        if exists(focus_present_mask) and focus_present_mask.all():\n            # if all batch samples are focusing on present\n            # it would be equivalent to passing that token's values through to the output\n        
    values = qkv[-1]\n            return self.to_out(values)\n\n        # split out heads\n\n        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)\n\n        # scale\n\n        q = q * self.scale\n\n        # rotate positions into queries and keys for time attention\n\n        if exists(self.rotary_emb):\n            q = self.rotary_emb.rotate_queries_or_keys(q)\n            k = self.rotary_emb.rotate_queries_or_keys(k)\n\n        # similarity\n\n        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)\n\n        # relative positional bias\n\n        if exists(pos_bias):\n            sim = sim + pos_bias\n\n        if exists(focus_present_mask) and not (~focus_present_mask).all():\n            attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool)\n            attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)\n\n            mask = torch.where(\n                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),\n                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),\n                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),\n            )\n\n            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)\n\n        # numerical stability\n\n        sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n        attn = sim.softmax(dim=-1)\n\n        # aggregate values\n\n        out = einsum('... h i j, ... h j d -> ... h i d', attn, v)\n        out = rearrange(out, '... h n d -> ... n (h d)')\n        return self.to_out(out)\n\n# model\nclass Unet3D(nn.Module):\n    def __init__(\n            self,\n            dim,\n            cond_aud=1024,\n            cond_pose=7,\n            cond_eye=2,\n            cond_dim=None,\n            out_grid_dim=2,\n            out_conf_dim=1,\n            num_frames=40,\n            dim_mults=(1, 2, 4, 8),\n            channels=3,\n            attn_heads=8,\n            attn_dim_head=32,\n            use_hubert_audio_cond=False,\n            init_dim=None,\n            init_kernel_size=7,\n            use_sparse_linear_attn=True,\n            resnet_groups=8,\n            use_final_activation=False,\n            learn_null_cond=False,\n            use_deconv=True,\n            padding_mode=\"zeros\",\n            win_width = 20\n    ):\n        super().__init__()\n        self.null_cond_mask = None\n        self.channels = channels\n        self.num_frames = num_frames\n        self.HUBERT_MODEL_DIM = 1024\n        self.win_width = win_width\n        # temporal attention and its relative positional encoding\n\n        rotary_emb = RotaryEmbedding(min(32, attn_dim_head), seq_before_head_dim = True)\n\n        temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', 'b (h w) f c',\n                                                    LocalSelfAttention_opt(dim, heads=attn_heads, size_per_head=attn_dim_head, neighbors=self.win_width,\n                                                              rotary_emb=rotary_emb))\n        \n        self.time_rel_pos_bias = RelativePositionBias(heads=attn_heads,\n                                                      max_distance=32, window_width = self.win_width )  # realistically will not be able to generate that many frames of video... 
yet\n\n        # initial conv\n\n        init_dim = default(init_dim, dim)\n        assert is_odd(init_kernel_size)\n\n        init_padding = init_kernel_size // 2\n        self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, init_kernel_size),\n                                   padding=(0, init_padding, init_padding))\n\n        self.init_temporal_attn = Residual(PreNorm(init_dim, temporal_attn(init_dim)))\n\n        # dimensions\n\n        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]\n        in_out = list(zip(dims[:-1], dims[1:]))\n\n        # time conditioning\n\n        time_dim = dim * 4\n        self.time_mlp = nn.Sequential(\n            SinusoidalPosEmb(dim),\n            nn.Linear(dim, time_dim),\n            nn.GELU(),\n            nn.Linear(time_dim, time_dim)\n        )\n\n        # audio conditioning\n\n        self.has_cond = exists(cond_dim) or use_hubert_audio_cond\n        self.cond_dim = cond_dim\n        self.cond_aud_dim = cond_aud\n        self.cond_pose_dim = cond_pose\n        self.cond_eye_dim = cond_eye\n\n        # modified by lml\n        self.learn_null_cond = learn_null_cond\n\n\n        # cat(t,cond) is not suitable\n        # cond_dim = time_dim + int(cond_dim or 0)\n\n        # layers\n\n        self.downs = nn.ModuleList([])\n        self.ups = nn.ModuleList([])\n\n        num_resolutions = len(in_out)\n\n        # block type\n\n        block_klass = partial(ResnetBlock_ca_mul, groups=resnet_groups)\n        block_klass_cond = partial(block_klass, time_emb_dim=time_dim, audio_emb_dim=self.cond_aud_dim, pose_emb_dim=self.cond_pose_dim, eye_emb_dim=self.cond_eye_dim)\n        # block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) # cat embedding\n\n        # modules for all layers\n\n        for ind, (dim_in, dim_out) in enumerate(in_out):\n            is_last = ind >= (num_resolutions - 1)\n\n            self.downs.append(nn.ModuleList([\n                block_klass_cond(dim_in, dim_out),\n                block_klass_cond(dim_out, dim_out),\n                Residual(PreNorm(dim_out, SpatialLinearAttention(dim_out,\n                                                                 heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),\n                Residual(PreNorm(dim_out, temporal_attn(dim_out))),\n                Downsample(dim_out) if not is_last else nn.Identity()\n            ]))\n\n        mid_dim = dims[-1]\n        self.mid_block1 = block_klass_cond(mid_dim, mid_dim)\n\n        spatial_attn = EinopsToAndFrom('b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads))\n\n        self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn))\n        self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_attn(mid_dim)))\n\n        self.mid_block2 = block_klass_cond(mid_dim, mid_dim)\n\n        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):\n            is_last = ind >= (num_resolutions - 1)\n\n            self.ups.append(nn.ModuleList([\n                block_klass_cond(dim_out * 2, dim_in),\n                block_klass_cond(dim_in, dim_in),\n                Residual(PreNorm(dim_in, SpatialLinearAttention(dim_in,\n                                                                heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),\n                Residual(PreNorm(dim_in, temporal_attn(dim_in))),\n                Upsample(dim_in, use_deconv, padding_mode) if not is_last else nn.Identity()\n            ]))\n\n        # out_dim = default(out_grid_dim, channels)\n        
self.final_conv = nn.Sequential(\n            block_klass(dim * 2, dim),\n            nn.Conv3d(dim, out_grid_dim, 1)\n        )\n\n        # added by nhm\n        self.use_final_activation = use_final_activation\n        if self.use_final_activation:\n            self.final_activation = nn.Tanh()\n        else:\n            self.final_activation = nn.Identity()\n\n        # added by nhm for predicting occlusion mask\n        self.occlusion_map = nn.Sequential(\n            block_klass(dim * 2, dim),\n            nn.Conv3d(dim, out_conf_dim, 1)\n        )\n\n    def forward_with_cond_scale(\n            self,\n            *args,\n            cond_scale=2.,\n            **kwargs\n    ):\n        logits = self.forward(*args, null_cond_prob=0., **kwargs)\n        if cond_scale == 1 or not self.has_cond:\n            return logits\n\n        null_logits = self.forward(*args, null_cond_prob=1., **kwargs)\n        return null_logits + (logits - null_logits) * cond_scale\n\n    def forward(\n            self,\n            x,\n            time,\n            cond=None,\n            null_cond_prob=0.,\n            focus_present_mask=None,\n            prob_focus_present=0.\n            # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)\n    ):\n        assert not (self.has_cond and not exists(cond)), 'cond must be passed in if cond_dim specified'\n        batch, device = x.shape[0], x.device\n\n        focus_present_mask = default(focus_present_mask,\n                                     lambda: prob_mask_like((batch,), prob_focus_present, device=device))\n\n        time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)\n        time_rel_pos_bias = create_sliding_window_mask(time_rel_pos_bias, 2 * self.win_width  + 1, 1)\n        x = self.init_conv(x)\n        r = x.clone()\n\n        x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias)\n\n        t = self.time_mlp(time) if exists(self.time_mlp) else None\n\n        if self.learn_null_cond:\n            self.null_cond_emb = nn.Parameter(torch.randn(1, self.num_frames, self.cond_dim)) if self.has_cond else None\n        else:\n            self.null_cond_emb = torch.zeros(1, self.num_frames, self.cond_dim) if self.has_cond else None\n        # classifier free guidance\n\n        if self.has_cond:\n            batch, device = x.shape[0], x.device\n            self.null_cond_mask = prob_mask_like((batch, self.num_frames,), null_cond_prob, device=device)\n            cond = torch.where(rearrange(self.null_cond_mask, 'b n -> b n 1'), self.null_cond_emb.to(cond.device), cond) \n            # t (bs, 256)  cond (bs, nf*1024)->(bs, nf, 1024) in this version\n            \n            # it's the original cond embedding method used in LFDM\n            # t = torch.cat((t, cond), dim=-1)\n\n        h = []\n\n        for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:\n            x = block1(x, t, cond)\n            x = block2(x, t, cond)\n            x = spatial_attn(x)\n            x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)\n            h.append(x)\n            x = downsample(x)\n\n        x = self.mid_block1(x, t, cond)\n        x = self.mid_spatial_attn(x)\n        x = self.mid_temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)\n        x = self.mid_block2(x, t, cond)\n\n        for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:\n      
      x = torch.cat((x, h.pop()), dim=1)\n            x = block1(x, t, cond)\n            x = block2(x, t, cond)\n            x = spatial_attn(x)\n            x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)\n            x = upsample(x)\n\n        x = torch.cat((x, r), dim=1)\n        return torch.cat((self.final_conv(x), self.occlusion_map(x)), dim=1)\n\n# to dynamically change num_frames of Unet3D\nclass DynamicNfUnet3D(Unet3D):\n    def __init__(self, default_num_frames=20, *args, **kwargs):\n        super(DynamicNfUnet3D, self).__init__(*args, **kwargs)\n        self.default_num_frames = default_num_frames\n        self.num_frames = default_num_frames\n    def update_num_frames(self, new_num_frames):\n        self.num_frames = new_num_frames\n\n# gaussian diffusion trainer class\n\ndef extract(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef cosine_beta_schedule(timesteps, s=0.008):\n    \"\"\"\n    cosine schedule\n    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ\n    \"\"\"\n    steps = timesteps + 1\n    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)\n    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2\n    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n    return torch.clip(betas, 0, 0.9999)\n\n\nclass GaussianDiffusion(nn.Module):\n    def __init__(\n            self,\n            denoise_fn,\n            *,\n            image_size,\n            num_frames,\n            text_use_bert_cls=False,\n            channels=3,\n            timesteps=1000,\n            sampling_timesteps=250,\n            ddim_sampling_eta=1.,\n            loss_type='l1',\n            use_dynamic_thres=False,  # from the Imagen paper\n            dynamic_thres_percentile=0.9,\n            null_cond_prob=0.1\n    ):\n        super().__init__()\n        self.null_cond_prob = null_cond_prob\n        self.channels = channels\n        self.image_size = image_size\n        self.num_frames = num_frames\n        self.denoise_fn = denoise_fn\n\n        betas = cosine_beta_schedule(timesteps)\n\n        alphas = 1. - betas\n        alphas_cumprod = torch.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)\n\n        timesteps, = betas.shape\n        self.num_timesteps = int(timesteps)\n        self.loss_type = loss_type\n\n        self.sampling_timesteps = default(sampling_timesteps,\n                                          timesteps)\n        self.is_ddim_sampling = self.sampling_timesteps < timesteps\n        self.ddim_sampling_eta = ddim_sampling_eta\n\n        # register buffer helper function that casts float64 to float32\n\n        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))\n\n        register_buffer('betas', betas)\n        register_buffer('alphas_cumprod', alphas_cumprod)\n        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n\n        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))\n        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))\n        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))\n        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. 
/ alphas_cumprod))\n        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n\n        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)\n\n        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n\n        register_buffer('posterior_variance', posterior_variance)\n\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n\n        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))\n        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))\n        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))\n\n        # text conditioning parameters\n\n        self.text_use_bert_cls = text_use_bert_cls\n\n        # dynamic thresholding when sampling\n\n        self.use_dynamic_thres = use_dynamic_thres\n        self.dynamic_thres_percentile = dynamic_thres_percentile\n\n    def q_mean_variance(self, x_start, t):\n        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)\n        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)\n        return mean, variance, log_variance\n\n    def predict_start_from_noise(self, x_t, t, noise):\n        return (\n                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -\n                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n        )\n\n    def q_posterior(self, x_start, x_t, t):\n        posterior_mean = (\n                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +\n                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = extract(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n    def p_mean_variance(self, x, t, fea, clip_denoised: bool, cond=None, cond_scale=1.):\n        fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1)\n        x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn.forward_with_cond_scale(torch.cat([x, fea], dim=1),\n                                                                                                      t,\n                                                                                                      cond=cond,\n                                                                                                      cond_scale=cond_scale))\n\n        if clip_denoised:\n            s = 1.\n            if self.use_dynamic_thres:\n                s = torch.quantile(\n                    rearrange(x_recon, 'b ... 
-> b (...)').abs(),\n                    self.dynamic_thres_percentile,\n                    dim=-1\n                )\n\n                s.clamp_(min=1.)\n                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))\n\n            # clip by threshold, depending on whether static or dynamic\n            x_recon = x_recon.clamp(-s, s) / s\n\n        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)\n        return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.inference_mode()\n    def p_sample(self, x, t, fea, cond=None, cond_scale=1., clip_denoised=True):\n        b, *_, device = *x.shape, x.device\n        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, fea=fea,\n                                                                 clip_denoised=clip_denoised, cond=cond,\n                                                                 cond_scale=cond_scale)\n        noise = torch.randn_like(x)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n\n    @torch.inference_mode()\n    def p_sample_loop(self, fea, shape, cond=None, cond_scale=1.):\n        device = self.betas.device\n\n        b = shape[0]\n        img = torch.randn(shape, device=device)\n\n        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):\n            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), fea, cond=cond,\n                                cond_scale=cond_scale)\n\n        return img\n        # return unnormalize_img(img)\n\n    @torch.inference_mode()\n    def sample(self, fea, bbox_mask, cond=None, cond_scale=1., batch_size=16):\n        # text bert: cond 1,768\n        # device = next(self.denoise_fn.parameters()).device\n        # if is_list_str(cond):\n        #     cond = torch.rand((1 ,768), dtype=torch.float32).cuda()  #used to debug\n            # cond = bert_embed(tokenize(cond), return_cls_repr=self.text_use_bert_cls).to(device)\n\n        batch_size = cond.shape[0] if exists(cond) else batch_size\n        # batch_size = 1 if exists(cond) else batch_size\n        image_size = self.image_size\n        channels = self.channels\n        num_frames = self.num_frames\n        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample\n        fea = torch.cat([fea, bbox_mask], dim=1)\n        return sample_fn(fea, (batch_size, channels, num_frames, fea.shape[-1], fea.shape[-1]), cond=cond,\n                         cond_scale=cond_scale)\n\n    # add by nhm\n    @torch.no_grad()\n    def ddim_sample(self, fea, shape, cond=None, cond_scale=1., clip_denoised=True):\n\n        batch, device, total_timesteps, sampling_timesteps, eta = \\\n            shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta\n\n        times = torch.linspace(0., total_timesteps, steps=sampling_timesteps + 2)[:-1]\n        times = list(reversed(times.int().tolist()))\n        time_pairs = list(zip(times[:-1], times[1:]))\n\n        img = torch.randn(shape, device=device) # bs, 3, nf, 32, 32\n        fea = fea.unsqueeze(dim=2).repeat(1, 1, img.size(2), 1, 1) #bs, 256, nf, 32, 32\n\n        for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):\n            alpha = self.alphas_cumprod_prev[time]\n            alpha_next = 
self.alphas_cumprod_prev[time_next]\n\n            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)\n\n            # pred_noise, x_start, *_ = self.model_predictions(img, time_cond, fea)\n            pred_noise = self.denoise_fn.forward_with_cond_scale(\n                torch.cat([img, fea], dim=1),\n                time_cond,\n                cond=cond,\n                cond_scale=cond_scale)\n            x_start = self.predict_start_from_noise(img, t=time_cond, noise=pred_noise)\n\n            if clip_denoised:\n                s = 1.\n                if self.use_dynamic_thres:\n                    s = torch.quantile(\n                        rearrange(x_start, 'b ... -> b (...)').abs(),\n                        self.dynamic_thres_percentile,\n                        dim=-1\n                    )\n\n                    s.clamp_(min=1.)\n                    s = s.view(-1, *((1,) * (x_start.ndim - 1)))\n\n                # clip by threshold, depending on whether static or dynamic\n                x_start = x_start.clamp(-s, s) / s\n\n            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n            c = ((1 - alpha_next) - sigma ** 2).sqrt()\n\n            noise = torch.randn_like(img) if time_next > 0 else 0.\n\n            img = x_start * alpha_next.sqrt() + \\\n                  c * pred_noise + \\\n                  sigma * noise\n\n        # img = unnormalize_to_zero_to_one(img)\n        return img\n\n    @torch.inference_mode()\n    def interpolate(self, x1, x2, t=None, lam=0.5):\n        b, *_, device = *x1.shape, x1.device\n        t = default(t, self.num_timesteps - 1)\n\n        assert x1.shape == x2.shape\n\n        t_batched = torch.stack([torch.tensor(t, device=device)] * b)\n        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))\n\n        img = (1 - lam) * xt1 + lam * xt2\n        for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):\n            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))\n\n        return img\n\n    def q_sample(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n\n        return (\n                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise\n        )\n\n    def p_losses(self, x_start, t, fea, bbox_mask, cond=None, noise=None, clip_denoised=True, **kwargs):\n        # x_start: bs, 3, num_frame, 32, 32\n        # t: bs\n        # fea: bs, 256, num_frame, 32, 32\n        # cond: bs, 768\n        b, c, f, h, w, device = *x_start.shape, x_start.device\n        noise = default(noise, lambda: torch.randn_like(x_start)) # bs, 3, nf, 32, 32\n\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)# bs, 3, nf, 32, 32\n\n        pred_noise = self.denoise_fn.forward(torch.cat([x_noisy, fea, bbox_mask], dim=1), t, cond=cond,\n                                             null_cond_prob=self.null_cond_prob,\n                                             **kwargs)\n\n        if self.loss_type == 'l1':\n            loss = F.l1_loss(noise, pred_noise, reduce=False)\n        elif self.loss_type == 'l2':\n            loss = F.mse_loss(noise, pred_noise, reduce=False)\n        else:\n            raise NotImplementedError()\n  \n        pred_x0 = self.predict_start_from_noise(x_noisy, t, pred_noise)\n\n        if clip_denoised:\n            s = 1.\n            if 
self.use_dynamic_thres:\n                s = torch.quantile(\n                    rearrange(pred_x0, 'b ... -> b (...)').abs(),\n                    self.dynamic_thres_percentile,\n                    dim=-1\n                )\n\n                s.clamp_(min=1.)\n                s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))\n\n            # clip by threshold, depending on whether static or dynamic\n            self.pred_x0 = pred_x0.clamp(-s, s) / s\n\n        return loss, self.denoise_fn.null_cond_mask\n\n    def forward(self, x, fea, bbox_mask, cond, *args, **kwargs):\n        b, device, img_size, = x.shape[0], x.device, self.image_size\n        # check_shape(x, 'b c f h w', c=self.channels, f=self.num_frames, h=img_size, w=img_size)\n        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()\n        fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1)\n        bbox_mask = bbox_mask.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1)\n\n        return self.p_losses(x, t, fea, bbox_mask, cond, *args, **kwargs)\n\n\n# trainer class\n\nCHANNELS_TO_MODE = {\n    1: 'L',\n    3: 'RGB',\n    4: 'RGBA'\n}\n\n\ndef seek_all_images(img, channels=3):\n    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'\n    mode = CHANNELS_TO_MODE[channels]\n\n    i = 0\n    while True:\n        try:\n            img.seek(i)\n            yield img.convert(mode)\n        except EOFError:\n            break\n        i += 1\n\n# to dynamically change num_frames of GaussianDiffusion\nclass DynamicNfGaussianDiffusion(GaussianDiffusion):\n    def __init__(self, default_num_frames=20, *args, **kwargs):\n        super(DynamicNfGaussianDiffusion, self).__init__(*args, **kwargs)\n        self.default_num_frames = default_num_frames\n        self.num_frames = default_num_frames\n    def update_num_frames(self, new_num_frames):\n        self.num_frames = new_num_frames\n\n# tensor of shape (channels, frames, height, width) -> gif\n\ndef video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):\n    images = map(T.ToPILImage(), tensor.unbind(dim=1))\n    first_img, *rest_imgs = images\n    first_img.save(path, save_all=True, append_images=rest_imgs, duration=duration, loop=loop, optimize=optimize)\n    return images\n\n\n# gif -> (channels, frame, height, width) tensor\n\ndef gif_to_tensor(path, channels=3, transform=T.ToTensor()):\n    img = Image.open(path)\n    tensors = tuple(map(transform, seek_all_images(img, channels=channels)))\n    return torch.stack(tensors, dim=1)\n\n\ndef identity(t, *args, **kwargs):\n    return t\n\n\ndef normalize_img(t):\n    return t * 2 - 1\n\n\n# def unnormalize_img(t):\n#     return (t + 1) * 0.5\n\n\ndef cast_num_frames(t, *, frames):\n    f = t.shape[1]\n\n    if f == frames:\n        return t\n\n    if f > frames:\n        return t[:, :frames]\n\n    return F.pad(t, (0, 0, 0, 0, 0, frames - f))\n\n\n"
  },
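  {
    "path": "DM_3/example_ddim_step_sketch.py",
    "content": "# Illustrative sketch only -- this file is not part of the original repo. It re-implements, in isolation,\n# two small pieces of the diffusion module above so they can be sanity-checked without building the full\n# Unet3D: the cosine beta schedule and a single DDIM update step. The toy tensors below stand in for the\n# real latents and for the Unet3D noise prediction.\nimport torch\n\n\ndef cosine_beta_schedule(timesteps, s=0.008):\n    # same cosine schedule formulation used by GaussianDiffusion above\n    steps = timesteps + 1\n    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)\n    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2\n    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n    return torch.clip(betas, 0, 0.9999)\n\n\nbetas = cosine_beta_schedule(1000)\nalphas_cumprod = torch.cumprod(1. - betas, dim=0)\nprint('beta range:', betas.min().item(), betas.max().item())\n\n# one DDIM update with eta = 1.0, mirroring the inner loop of GaussianDiffusion.ddim_sample\neta = 1.0\nt, t_next = 500, 450\nalpha, alpha_next = alphas_cumprod[t], alphas_cumprod[t_next]\nimg = torch.randn(1, 3, 8, 32, 32)   # (bs, c, num_frames, h, w) latent, as in the module\npred_noise = torch.randn_like(img)   # stand-in for the Unet3D noise prediction\nx_start = (img - (1 - alpha).sqrt() * pred_noise) / alpha.sqrt()  # predict_start_from_noise\nsigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\nc = ((1 - alpha_next) - sigma ** 2).sqrt()\nimg = x_start * alpha_next.sqrt() + c * pred_noise + sigma * torch.randn_like(img)\nprint('updated latent shape:', tuple(img.shape))"
  },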
  {
    "path": "DM_3/test_lr.py",
    "content": "# quick sanity check: print the CosineAnnealingLR learning-rate schedule over 100 epochs\nimport torch\nimport torch.optim as optim\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\n\nmodel = torch.nn.Linear(2, 4)\noptimizer = optim.Adam(model.parameters(), lr=2e-5)\n\nscheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-5)\n\nfor epoch in range(100):\n    optimizer.step()  # step the optimizer before the scheduler (PyTorch >= 1.1 ordering)\n    scheduler.step()\n    print(optimizer.param_groups[0]['lr'])"
  },
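  {
    "path": "DM_3/example_cfg_guidance_sketch.py",
    "content": "# Illustrative sketch only -- this file is not part of the original repo. It shows, with toy tensors, the\n# two classifier-free guidance mechanisms used by Unet3D above: (1) randomly replacing the per-frame\n# condition with a null embedding during training, and (2) blending conditional and unconditional\n# predictions at sampling time as null_logits + (logits - null_logits) * cond_scale.\n# prob_mask_like below is a local stand-in for the helper the module imports elsewhere.\nimport torch\n\n\ndef prob_mask_like(shape, prob, device):\n    # True where the condition should be swapped for the null embedding\n    return torch.zeros(shape, device=device).uniform_(0, 1) < prob\n\n\nbs, num_frames, cond_dim = 2, 8, 1024\ncond = torch.randn(bs, num_frames, cond_dim)          # e.g. per-frame HuBERT features\nnull_cond_emb = torch.zeros(1, num_frames, cond_dim)  # non-learned variant (learn_null_cond=False)\n\n# (1) training-time condition dropout, as in Unet3D.forward with null_cond_prob=0.1\nnull_cond_mask = prob_mask_like((bs, num_frames), prob=0.1, device='cpu')\ncond = torch.where(null_cond_mask.unsqueeze(-1), null_cond_emb, cond)\n\n# (2) sampling-time guidance, as in Unet3D.forward_with_cond_scale\ncond_scale = 2.0\nlogits = torch.randn(bs, 3, num_frames, 32, 32)   # stand-in for the conditional prediction\nnull_logits = torch.randn_like(logits)            # stand-in for the unconditional prediction\nguided = null_logits + (logits - null_logits) * cond_scale\nprint(guided.shape)"
  },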
  {
    "path": "DM_3/train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D.py",
    "content": "import sys\nsys.path.append('your/path/')\n\nimport argparse\nfrom datetime import datetime, time\n\nimport imageio\nimport torch\nfrom torch.utils import data\nimport numpy as np\nimport torch.backends.cudnn as cudnn\nimport os\nimport os.path as osp\nimport timeit\nimport math\nfrom PIL import Image\nfrom misc import Logger, grid2fig, conf2fig\nfrom DM_3.datasets_hdtf_wpose_lmk_block_lmk import HDTF \nimport sys\nimport random\nfrom torch.utils.tensorboard import SummaryWriter \nfrom DM_3.utils import MultiEpochsDataLoader as DataLoader\nimport time\n\nfrom DM_3.modules.video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_6D import FlowDiffusion\nfrom torch.optim.lr_scheduler import MultiStepLR\nfrom sync_batchnorm import DataParallelWithCallback\nimport torch.multiprocessing as mp\n\n\nstart = timeit.default_timer()\nBATCH_SIZE = 20\n# crema settings\n# MAX_EPOCH = 300\n# epoch_milestones = [800, 1000]\n# hdtf\nMAX_EPOCH = 500 * 25\nepoch_milestones = [800000, 10000000]\nroot_dir = 'your/path'\ndata_dir = \"your/image/path\"\npose_dir = \"your/pose/path\"\neye_blink_dir = \"your/blink/path\"\nGPU = \"2\"\npostfix = \"-j-of\"\njoint = \"joint\" in postfix or \"-j\" in postfix  # allow joint training with unconditional model\nonly_use_flow = \"onlyflow\" in postfix or \"-of\" in postfix  # whether only use flow loss\nvgg_weight = 0\nfloss_weight = 0.15\nif joint:\n    null_cond_prob = 0.1\nelse:\n    null_cond_prob = 0.0\nif \"upconv\" in postfix:\n    use_deconv = False\n    padding_mode = \"reflect\"\nelse:\n    use_deconv = True\n    padding_mode = \"zeros\"\nuse_residual_flow = \"-rf\" in postfix\nlearn_null_cond = \"-lnc\" in postfix\nINPUT_SIZE = 128\nMAX_N_FRAMES = 20  \nLEARNING_RATE = 2e-4\nRANDOM_SEED = 1234\nclip_c = 2.\nprint('use grad clip, clip = ', clip_c)\nMEAN = (0.0, 0.0, 0.0)\nconfig_pth = \"./config/hdtf128.yaml\"\n# PATH of LFG checkpoint\nAE_RESTORE_FROM = 'LFG/path'   \nRESTORE_FROM = '' # use existing checkpoint\nDM_LOG_PATH = os.path.join(root_dir,'data','HDTF_wpose_faceemb_newae_6Dpose', 'ca_init_cond_liploss','stage1_0ref_1000epae_v0_lr_N='+str(MAX_N_FRAMES))\nprint(DM_LOG_PATH)\nSNAPSHOT_DIR = os.path.join(DM_LOG_PATH, 'snapshots' + postfix)\nIMGSHOT_DIR = os.path.join(DM_LOG_PATH, 'imgshots' + postfix)\nVIDSHOT_DIR = os.path.join(DM_LOG_PATH, \"vidshots\" + postfix)\nSAMPLE_DIR = os.path.join(DM_LOG_PATH, 'sample' + postfix)\nNUM_EXAMPLES_PER_EPOCH = 400\nNUM_STEPS_PER_EPOCH = math.ceil(NUM_EXAMPLES_PER_EPOCH / float(BATCH_SIZE))\nMAX_ITER = max(NUM_EXAMPLES_PER_EPOCH * MAX_EPOCH + 1,\n               NUM_STEPS_PER_EPOCH * BATCH_SIZE * MAX_EPOCH + 1)\nSAVE_MODEL_EVERY = int(250000)\nSAVE_VID_EVERY = 4000\nSAMPLE_VID_EVERY = 2000\nUPDATE_MODEL_EVERY = 500\n\nos.makedirs(SNAPSHOT_DIR, exist_ok=True)\nos.makedirs(IMGSHOT_DIR, exist_ok=True)\nos.makedirs(VIDSHOT_DIR, exist_ok=True)\nos.makedirs(SAMPLE_DIR, exist_ok=True)\n\nLOG_PATH = SNAPSHOT_DIR + \"/B\" + format(BATCH_SIZE, \"04d\") + \"E\" + format(MAX_EPOCH, \"04d\") + \".log\"\nsys.stdout = Logger(LOG_PATH, sys.stdout)\nprint(root_dir)\nprint(\"update saved model every:\", UPDATE_MODEL_EVERY)\nprint(\"save model every:\", SAVE_MODEL_EVERY)\nprint(\"save video every:\", SAVE_VID_EVERY)\nprint(\"sample video every:\", SAMPLE_VID_EVERY)\nprint(postfix)\nprint(\"RESTORE_FROM\", RESTORE_FROM)\nprint(\"num examples per epoch:\", NUM_EXAMPLES_PER_EPOCH)\nprint(\"max epoch:\", MAX_EPOCH)\nprint(\"image size\", INPUT_SIZE)\nprint(\"epoch milestones:\", 
epoch_milestones)\nprint(\"only use flow loss:\", only_use_flow)\nprint(\"null_cond_prob:\", null_cond_prob)\nprint(\"use residual flow:\", use_residual_flow)\nprint(\"learn null cond:\", learn_null_cond)\nprint(\"use deconv:\", use_deconv)\n\n\ndef get_arguments():\n    \"\"\"Parse all the arguments provided from the CLI.\n\n    Returns:\n      A list of parsed arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"Flow Diffusion\")\n    parser.add_argument(\"--fine-tune\", default=False)\n    parser.add_argument(\"--set-start\", default=True)\n    parser.add_argument(\"--start-step\", default=0, type=int)\n    parser.add_argument(\"--img-dir\", type=str, default=IMGSHOT_DIR,\n                        help=\"Where to save images of the model.\")\n    parser.add_argument(\"--num-workers\", default=8)\n    parser.add_argument(\"--final-step\", type=int, default=int(NUM_STEPS_PER_EPOCH * MAX_EPOCH),\n                        help=\"Number of training steps.\")\n    parser.add_argument(\"--gpu\", default=GPU,\n                        help=\"choose gpu device.\")\n    parser.add_argument('--print-freq', '-p', default=2, type=int,\n                        metavar='N', help='print frequency')\n    parser.add_argument('--save-img-freq', default=2000, type=int,\n                        metavar='N', help='save image frequency')\n    parser.add_argument('--save-vid-freq', default=SAVE_VID_EVERY, type=int)\n    parser.add_argument('--sample-vid-freq', default=SAMPLE_VID_EVERY, type=int)\n    parser.add_argument(\"--batch-size\", type=int, default=BATCH_SIZE,\n                        help=\"Number of images sent to the network in one step.\")\n    parser.add_argument(\"--input-size\", type=str, default=INPUT_SIZE,\n                        help=\"Comma-separated string with height and width of images.\")\n    parser.add_argument(\"--learning-rate\", type=float, default=LEARNING_RATE,\n                        help=\"Base learning rate for training with polynomial decay.\")\n    parser.add_argument(\"--random-seed\", type=int, default=RANDOM_SEED,\n                        help=\"Random seed to have reproducible results.\")\n    parser.add_argument(\"--restore-from\", default=RESTORE_FROM)\n    parser.add_argument(\"--save-pred-every\", type=int, default=SAVE_MODEL_EVERY,\n                        help=\"Save checkpoint every often.\")\n    parser.add_argument(\"--update-pred-every\", type=int, default=UPDATE_MODEL_EVERY)\n    parser.add_argument(\"--snapshot-dir\", type=str, default=SNAPSHOT_DIR,\n                        help=\"Where to save snapshots of the model.\")\n    parser.add_argument(\"--fp16\", default=True)\n    parser.add_argument(\"--cosin\", default=True)\n    return parser.parse_args()\n\n\nargs = get_arguments()\n\n\ndef sample_img(rec_img_batch, idx=0):\n    rec_img = rec_img_batch[idx].permute(1, 2, 0).data.cpu().numpy().copy()\n    rec_img += np.array(MEAN) / 255.0\n    rec_img[rec_img < 0] = 0\n    rec_img[rec_img > 1] = 1\n    rec_img *= 255\n    return np.array(rec_img, np.uint8)\n\n\ndef main():\n    \"\"\"Create the model and start the training.\"\"\"\n\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = args.gpu\n\n    cudnn.enabled = True\n    cudnn.benchmark = True\n    setup_seed(args.random_seed)\n    writer = SummaryWriter(os.path.join(DM_LOG_PATH, 'tensorboard'))\n\n\n    model = FlowDiffusion(is_train=True,\n                          img_size=INPUT_SIZE // 4,\n                          num_frames=MAX_N_FRAMES,\n                          
null_cond_prob=null_cond_prob,\n                          sampling_timesteps=20,\n                          use_residual_flow=use_residual_flow,\n                          learn_null_cond=learn_null_cond,\n                          use_deconv=use_deconv,\n                          padding_mode=padding_mode,\n                          config_pth=config_pth,\n                          pretrained_pth=AE_RESTORE_FROM)\n    model.cuda()\n    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)\n\n    # Not set model to be train mode! Because pretrained flow autoenc need to be eval (BatchNorm)\n\n    # create optimizer\n    optimizer_diff = torch.optim.Adam(model.diffusion.parameters(),\n                                      lr=LEARNING_RATE, betas=(0.9, 0.99))\n\n    if args.fine_tune:\n        pass\n    elif args.restore_from:\n        if os.path.isfile(args.restore_from):\n            print(\"=> loading checkpoint '{}'\".format(args.restore_from))\n            checkpoint = torch.load(args.restore_from)\n            if args.set_start:\n                args.start_step = int(math.ceil(checkpoint['example'] / args.batch_size))\n            model_ckpt = model.diffusion.state_dict()\n            for name, _ in model_ckpt.items():\n                model_ckpt[name].copy_(checkpoint['diffusion'][name])\n            model.diffusion.load_state_dict(model_ckpt)\n            print(\"=> loaded checkpoint '{}'\".format(args.restore_from))\n            if args.set_start:\n                if \"optimizer_diff\" in list(checkpoint.keys()):\n                    optimizer_diff.load_state_dict(checkpoint['optimizer_diff'])\n        else:\n            print(\"=> no checkpoint found at '{}'\".format(args.restore_from))\n    else:\n        print(\"NO checkpoint found!\")\n\n    # enable the usage of multi-GPU\n    model = DataParallelWithCallback(model)\n\n    setup_seed(args.random_seed)\n    trainloader = DataLoader(HDTF(data_dir=data_dir,\n                                       pose_dir=pose_dir,\n                                       eye_blink_dir = eye_blink_dir,\n                                       image_size=INPUT_SIZE,\n                                       max_num_frames=MAX_N_FRAMES,\n                                       color_jitter=True,\n                                       mean=MEAN),\n                                  batch_size=args.batch_size,\n                                  shuffle=True, num_workers=args.num_workers,# args.num_workers,\n                                  pin_memory=True)\n\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n\n    losses = AverageMeter()\n    losses_rec = AverageMeter()\n    losses_warp = AverageMeter()\n    losses_vgg = AverageMeter()\n\n    cnt = 0\n    actual_step = args.start_step\n    start_epoch = int(math.ceil((args.start_step * args.batch_size) / NUM_EXAMPLES_PER_EPOCH))\n    epoch_cnt = start_epoch\n\n    if(not args.cosin):\n        scheduler = MultiStepLR(optimizer_diff, epoch_milestones, gamma=0.1, last_epoch=start_epoch - 1)\n    else:\n        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_diff, T_max=MAX_EPOCH, eta_min=1e-6)\n    print(\"epoch %d, lr= %.7f\" % (epoch_cnt, optimizer_diff.param_groups[0][\"lr\"]))\n\n\n    # start = time.time()\n    torch.inverse(torch.ones((1,1), device = \"cuda:0\"))\n    while actual_step < args.final_step:\n        iter_end = timeit.default_timer()\n        start_time = time.time()  # start\n        load_sum = 0\n        calculate_sum = 0\n\n        for i_iter, batch in 
enumerate(trainloader):\n            \n            if __debug__:\n                end_time = time.time()  # end\n                # print(f'load time {end_time- start_time}')\n                load_sum += end_time - start_time\n                if end_time - start_time > 1:\n                    print('unnormal load: \\t',i_iter)\n                start_time = end_time\n\n            actual_step = int(args.start_step + cnt)\n            data_time.update(timeit.default_timer() - iter_end)\n\n            real_vids, ref_hubert, real_poses, real_blink_bbox, mouth_lmk_tensor, real_names, _ = batch\n            # ref_hubert, real_poses, real_blink_bbox : b, c, fn\n            # use first frame of each video as reference frame\n            ref_id = 0 # random.randint(0, real_vids.shape[2] - 1)\n            ref_imgs = real_vids[:, :, ref_id, :, :].clone().detach()\n            bs = real_vids.size(0)\n            new_num_frames = real_vids.size(2)\n            model.module.update_num_frames(new_num_frames)\n\n            # end_time = time.time()  # end\n            # print(f'preprocess time {end_time- start_time}')\n            # start_time = end_time\n\n            # encode text\n            # cond = bert_embed(tokenize(ref_texts), return_cls_repr=model.module.diffusion.text_use_bert_cls).cuda()\n            is_eval = actual_step % args.save_vid_freq == 0 or actual_step % args.sample_vid_freq == 0\n            with torch.cuda.amp.autocast(enabled=args.fp16):\n                train_output_dict = model.forward(real_vid=real_vids, ref_img=ref_imgs, ref_text=ref_hubert, ref_pose=real_poses, ref_eye_blink = real_blink_bbox[:, :2], bbox=real_blink_bbox[:, 2:], mouth_lmk = mouth_lmk_tensor, is_eval = is_eval, ref_id = ref_id)\n\n            # optimize model\n            \n            optimizer_diff.zero_grad()\n            # if only_use_flow:\n            #     scaler.scale(train_output_dict[\"loss\"].mean()).backward()\n            # else:\n            #     scaler.scale((train_output_dict[\"loss\"].mean() + train_output_dict[\"rec_loss\"].mean() +\n            #      train_output_dict[\"rec_warp_loss\"].mean())).backward()\n    \n            scaler.scale(train_output_dict[\"loss\"].mean() + floss_weight * train_output_dict['floss'].mean()  + 0.15 * train_output_dict['mouth_loss'].mean()).backward()\n            # optimizer_diff.step()\n            scaler.unscale_(optimizer_diff)\n            # loss.backward()\n            if clip_c > 0.:\n                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_c)\n            scaler.step(optimizer_diff)\n            scaler.update()\n\n            batch_time.update(timeit.default_timer() - iter_end)\n            iter_end = timeit.default_timer()\n\n            losses.update(train_output_dict[\"loss\"].mean().item(), bs)\n            losses_rec.update(train_output_dict[\"floss\"].mean().item(), bs)\n            losses_warp.update(train_output_dict[\"mouth_loss\"].mean().item(), bs)\n            # losses_vgg.update(train_output_dict[\"rec_vgg_loss\"].mean().item(), bs)\n\n            writer.add_scalar('train/loss', train_output_dict[\"loss\"].mean().item(),actual_step)\n            writer.add_scalar('train/floss', train_output_dict[\"floss\"].mean().item(),actual_step)\n            writer.add_scalar('train/mouth_loss', train_output_dict[\"mouth_loss\"].mean().item(),actual_step)\n            # writer.add_scalar('train/rec_loss', train_output_dict[\"rec_loss\"].mean().item(),actual_step)\n            # writer.add_scalar('train/rec_warp_loss', 
train_output_dict[\"rec_warp_loss\"].mean().item(),actual_step)\n            # writer.add_scalar('train/rec_vgg_loss', train_output_dict[\"rec_vgg_loss\"].mean().item(),actual_step)\n            if __debug__:\n                end_time = time.time()  # end\n                # print(f'forward time {end_time- start_time}')\n                calculate_sum += end_time - start_time\n                start_time = end_time\n\n            # if actual_step % 100 == 0:\n            #     end = time.time()\n            #     print(\"100 iter time:{0}\".format(end-start))\n            if actual_step % args.print_freq == 0:\n                current_time = datetime.now()\n                current_time_str = current_time.strftime(\"%Y-%m-%d %H:%M:%S\")\n                print(\"Current time is:\", current_time_str)\n                print('iter: [{0}]{1}/{2}\\t'\n                      'loss {loss.val:.7f} ({loss.avg:.7f})\\t'\n                      'loss_rec {loss_rec.val:.4f} ({loss_rec.avg:.4f})\\t'\n                      'loss_warp {loss_warp.val:.4f} ({loss_warp.avg:.4f})'\n                    .format(\n                    cnt, actual_step, args.final_step,\n                    batch_time=batch_time,\n                    data_time=data_time,\n                    loss=losses,\n                    loss_rec=losses_rec,\n                    loss_warp=losses_warp,\n                ))\n\n            null_cond_mask = np.array(train_output_dict[\"null_cond_mask\"].data.cpu().numpy(),\n                                      dtype=np.uint8)\n\n            if actual_step % args.save_vid_freq == 0: # and cnt != 0:\n                print(\"saving video...\")\n                num_frames = real_vids.size(2)\n                msk_size = ref_imgs.shape[-1]\n                new_im_arr_list = []\n                save_src_img = sample_img(ref_imgs/255.)\n                for nf in range(num_frames):\n                    save_tar_img = sample_img(real_vids[:, :, nf, :, :]/255.) 
# adapt fast version\n                    save_real_out_img = sample_img(train_output_dict[\"real_out_vid\"][:, :, nf, :, :])\n                    save_real_warp_img = sample_img(train_output_dict[\"real_warped_vid\"][:, :, nf, :, :])\n                    save_fake_out_img = sample_img(train_output_dict[\"fake_out_vid\"][:, :, nf, :, :])\n                    save_fake_warp_img = sample_img(train_output_dict[\"fake_warped_vid\"][:, :, nf, :, :])\n                    save_real_grid = grid2fig(\n                        train_output_dict[\"real_vid_grid\"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(),\n                        grid_size=32, img_size=msk_size)\n                    save_fake_grid = grid2fig(\n                        train_output_dict[\"fake_vid_grid\"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(),\n                        grid_size=32, img_size=msk_size)\n                    save_real_conf = conf2fig(train_output_dict[\"real_vid_conf\"][0, :, nf])\n                    save_fake_conf = conf2fig(train_output_dict[\"fake_vid_conf\"][0, :, nf])\n                    new_im = Image.new('RGB', (msk_size * 5, msk_size * 2))\n                    new_im.paste(Image.fromarray(save_src_img, 'RGB'), (0, 0))\n                    new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, msk_size))\n                    new_im.paste(Image.fromarray(save_real_out_img, 'RGB'), (msk_size, 0))\n                    new_im.paste(Image.fromarray(save_real_warp_img, 'RGB'), (msk_size, msk_size))\n                    new_im.paste(Image.fromarray(save_fake_out_img, 'RGB'), (msk_size * 2, 0))\n                    new_im.paste(Image.fromarray(save_fake_warp_img, 'RGB'), (msk_size * 2, msk_size))\n                    new_im.paste(Image.fromarray(save_real_grid, 'RGB'), (msk_size * 3, 0))\n                    new_im.paste(Image.fromarray(save_fake_grid, 'RGB'), (msk_size * 3, msk_size))\n                    new_im.paste(Image.fromarray(save_real_conf, 'L'), (msk_size * 4, 0))\n                    new_im.paste(Image.fromarray(save_fake_conf, 'L'), (msk_size * 4, msk_size))\n                    new_im_arr = np.array(new_im)\n                    new_im_arr_list.append(new_im_arr)\n                new_vid_name = 'B' + format(args.batch_size, \"04d\") + '_S' + format(actual_step, \"06d\") \\\n                               + '_' + real_names[0] + \"_%d.gif\" % (null_cond_mask[0][0])\n                new_vid_file = os.path.join(VIDSHOT_DIR, new_vid_name)\n                imageio.mimsave(new_vid_file, new_im_arr_list)\n                new_im_arr_list = None\n                new_im_arr = None\n                new_im = None\n                del new_im_arr_list, new_im_arr, new_im\n\n            # sampling\n            if actual_step % args.sample_vid_freq == 0:\n                print(\"sampling video...\")\n                with torch.no_grad():\n                    # cond = torch.concat([ref_hubert[0].unsqueeze(dim=0), real_poses[0].permute(1,0).unsqueeze(0), real_blink_bbox[0][:2].permute(1,0).unsqueeze(0)], dim=-1).cuda()\n                    sample_output_dict = model.module.sample_one_video(real_vid=real_vids.cuda()/255.,\n                                                                    sample_img=ref_imgs[0].unsqueeze(dim=0).cuda()/255.,\n                                                                    sample_audio_hubert = ref_hubert[0].unsqueeze(dim=0).cuda(),\n                                                                    sample_pose = real_poses[0].unsqueeze(0).cuda(),\n                           
                                         sample_eye =  real_blink_bbox[0][:2].unsqueeze(0).cuda(),\n                                                                    sample_bbox = real_blink_bbox[0,2:].unsqueeze(0).cuda(),\n                                                                    cond_scale=1.0)\n                num_frames = real_vids.size(2)\n                msk_size = ref_imgs.shape[-1]\n                new_im_arr_list = []\n                save_src_img = sample_img(ref_imgs/255.)\n                for nf in range(num_frames):\n                    save_tar_img = sample_img(real_vids[:, :, nf, :, :]/255.)\n                    save_real_out_img = sample_img(train_output_dict[\"real_out_vid\"][:, :, nf, :, :])\n                    save_real_warp_img = sample_img(train_output_dict[\"real_warped_vid\"][:, :, nf, :, :])\n                    save_sample_out_img = sample_img(sample_output_dict[\"sample_out_vid\"][:, :, nf, :, :])\n                    save_sample_warp_img = sample_img(sample_output_dict[\"sample_warped_vid\"][:, :, nf, :, :])\n                    save_real_grid = grid2fig(\n                        train_output_dict[\"real_vid_grid\"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(),\n                        grid_size=32, img_size=msk_size)\n                    save_fake_grid = grid2fig(\n                        sample_output_dict[\"sample_vid_grid\"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(),\n                        grid_size=32, img_size=msk_size)\n                    save_real_conf = conf2fig(train_output_dict[\"real_vid_conf\"][0, :, nf])\n                    save_fake_conf = conf2fig(sample_output_dict[\"sample_vid_conf\"][0, :, nf])\n                    new_im = Image.new('RGB', (msk_size * 5, msk_size * 2))\n                    new_im.paste(Image.fromarray(save_src_img, 'RGB'), (0, 0))\n                    new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, msk_size))\n                    new_im.paste(Image.fromarray(save_real_out_img, 'RGB'), (msk_size, 0))\n                    new_im.paste(Image.fromarray(save_real_warp_img, 'RGB'), (msk_size, msk_size))\n                    new_im.paste(Image.fromarray(save_sample_out_img, 'RGB'), (msk_size * 2, 0))\n                    new_im.paste(Image.fromarray(save_sample_warp_img, 'RGB'), (msk_size * 2, msk_size))\n                    new_im.paste(Image.fromarray(save_real_grid, 'RGB'), (msk_size * 3, 0))\n                    new_im.paste(Image.fromarray(save_fake_grid, 'RGB'), (msk_size * 3, msk_size))\n                    new_im.paste(Image.fromarray(save_real_conf, 'L'), (msk_size * 4, 0))\n                    new_im.paste(Image.fromarray(save_fake_conf, 'L'), (msk_size * 4, msk_size))\n                    new_im_arr = np.array(new_im)\n                    new_im_arr_list.append(new_im_arr)\n                new_vid_name = 'B' + format(args.batch_size, \"04d\") + '_S' + format(actual_step, \"06d\") \\\n                               + '_' + real_names[0] + \".gif\"\n                new_vid_file = os.path.join(SAMPLE_DIR, new_vid_name)\n                imageio.mimsave(new_vid_file, new_im_arr_list)\n                new_im_arr_list = None\n                new_im_arr = None\n                new_im = None\n                del new_im_arr_list, new_im_arr, new_im\n\n            # save model at i-th step\n            if actual_step % args.save_pred_every == 0 and cnt != 0:\n                print('taking snapshot ...')\n                torch.save({'example': actual_step * args.batch_size,\n                            
'diffusion': model.module.diffusion.state_dict(),\n                            'optimizer_diff': optimizer_diff.state_dict()},\n                           osp.join(args.snapshot_dir,\n                                    'flowdiff_' + format(args.batch_size, \"04d\") + '_S' + format(actual_step,\n                                                                                                 \"06d\") + '.pth'))\n\n            # update saved model\n            if actual_step % args.update_pred_every == 0 and cnt != 0:\n                print('updating saved snapshot ...')\n                torch.save({'example': actual_step * args.batch_size,\n                            'diffusion': model.module.diffusion.state_dict(),\n                            'optimizer_diff': optimizer_diff.state_dict()},\n                           osp.join(args.snapshot_dir, 'flowdiff.pth'))\n\n            if actual_step >= args.final_step:\n                break\n\n            cnt += 1\n\n            # if __debug__:\n            #     end_time = time.time()  # end\n            #     print(f'orther time 1: {end_time- start_time}')\n            #     start_time = end_time\n\n            del real_vids, ref_imgs, ref_hubert, real_names, null_cond_mask\n            del train_output_dict\n            del batch\n            # torch.cuda.empty_cache()\n\n            # if __debug__:\n            #     end_time = time.time()  # end\n            #     print(f'orther time time {end_time- start_time}')\n            #     start_time = end_time\n\n        scheduler.step()\n        epoch_cnt += 1\n        print(\"epoch %d, lr= %.7f\" % (epoch_cnt, optimizer_diff.param_groups[0][\"lr\"]))\n\n        if __debug__:\n            print('load_sum: ', load_sum)\n            print('calculate_sum: ', calculate_sum)\n\n    print('save the final model ...')\n    torch.save({'example': actual_step * args.batch_size,\n                'diffusion': model.module.diffusion.state_dict(),\n                'optimizer_diff': optimizer_diff.state_dict()},\n               osp.join(args.snapshot_dir,\n                        'flowdiff_' + format(args.batch_size, \"04d\") + '_S' + format(actual_step, \"06d\") + '.pth'))\n    end = timeit.default_timer()\n    print(end - start, 'seconds')\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef setup_seed(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    # mp.set_start_method('spawn')\n    # torch.multiprocessing.set_start_method(\"spawn\")\n    main()\n"
  },
  {
    "path": "DM_3/train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D_s2.py",
    "content": "import sys\nsys.path.append('your/path/')\n\nimport argparse\nfrom datetime import datetime, time\n\nimport imageio\nimport torch\nfrom torch.utils import data\nimport numpy as np\nimport torch.backends.cudnn as cudnn\nimport os\nimport os.path as osp\nimport timeit\nimport math\nfrom PIL import Image\nfrom misc import Logger, grid2fig, conf2fig\nfrom DM_3.datasets_hdtf_wpose_lmk_block_lmk_rand import HDTF \nimport sys\nimport random\nfrom torch.utils.tensorboard import SummaryWriter \nfrom DM_3.utils import MultiEpochsDataLoader as DataLoader\n\nimport time\n\nfrom DM_3.modules.video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_rand_6D import FlowDiffusion\nfrom torch.optim.lr_scheduler import MultiStepLR\nfrom sync_batchnorm import DataParallelWithCallback\nimport torch.multiprocessing as mp\n\n\nstart = timeit.default_timer()\nBATCH_SIZE = 40\n# crema settings\n# MAX_EPOCH = 300\n# epoch_milestones = [800, 1000]\n# hdtf\nMAX_EPOCH = 500 * 25\nepoch_milestones = [800000, 10000000]\nroot_dir = 'your/path'\ndata_dir = \"your/image/path\"\npose_dir = \"your/pose/path\"\neye_blink_dir = \"your/blink/path\"\nGPU = \"2\"\npostfix = \"-j-of-s2\"\njoint = \"joint\" in postfix or \"-j\" in postfix  # allow joint training with unconditional model\nonly_use_flow = \"onlyflow\" in postfix or \"-of\" in postfix  # whether only use flow loss\nDYNAMIC_FRAMES = '-s2' in postfix\nvgg_weight = 0\nfloss_weight = 0.15\nif joint:\n    null_cond_prob = 0.1\nelse:\n    null_cond_prob = 0.0\nif \"upconv\" in postfix:\n    use_deconv = False\n    padding_mode = \"reflect\"\nelse:\n    use_deconv = True\n    padding_mode = \"zeros\"\nuse_residual_flow = \"-rf\" in postfix\nlearn_null_cond = \"-lnc\" in postfix\nINPUT_SIZE = 128\nif DYNAMIC_FRAMES:\n    MAX_N_FRAMES = 40\n    MIN_N_FRAMES = 30\nelse:\n    MAX_N_FRAMES = 20  \nLEARNING_RATE = 2e-4\nRANDOM_SEED = 1234\nclip_c = 2.\nprint('use grad clip, clip = ', clip_c)\nMEAN = (0.0, 0.0, 0.0)\nconfig_pth = \"./config/hdtf128.yaml\"\n# PATH of LFG checkpoint\nAE_RESTORE_FROM = 'LFG/path'   \nRESTORE_FROM = '' # use existing checkpoint\nDM_LOG_PATH = os.path.join(root_dir,'data','HDTF_wpose_faceemb_newae_6Dpose', 'ca_init_cond_liploss','fromstart_rand_df_liploss_lr_N='+str(MAX_N_FRAMES))\nprint(DM_LOG_PATH)\nSNAPSHOT_DIR = os.path.join(DM_LOG_PATH, 'snapshots' + postfix)\nIMGSHOT_DIR = os.path.join(DM_LOG_PATH, 'imgshots' + postfix)\nVIDSHOT_DIR = os.path.join(DM_LOG_PATH, \"vidshots\" + postfix)\nSAMPLE_DIR = os.path.join(DM_LOG_PATH, 'sample' + postfix)\nNUM_EXAMPLES_PER_EPOCH = 400\nNUM_STEPS_PER_EPOCH = math.ceil(NUM_EXAMPLES_PER_EPOCH / float(BATCH_SIZE))\nMAX_ITER = max(NUM_EXAMPLES_PER_EPOCH * MAX_EPOCH + 1,\n               NUM_STEPS_PER_EPOCH * BATCH_SIZE * MAX_EPOCH + 1)\nSAVE_MODEL_EVERY = int(100000/10)\nSAVE_VID_EVERY = 4000\nSAMPLE_VID_EVERY = 2000\nUPDATE_MODEL_EVERY = 500\n\nos.makedirs(SNAPSHOT_DIR, exist_ok=True)\nos.makedirs(IMGSHOT_DIR, exist_ok=True)\nos.makedirs(VIDSHOT_DIR, exist_ok=True)\nos.makedirs(SAMPLE_DIR, exist_ok=True)\n\nLOG_PATH = SNAPSHOT_DIR + \"/B\" + format(BATCH_SIZE, \"04d\") + \"E\" + format(MAX_EPOCH, \"04d\") + \".log\"\nsys.stdout = Logger(LOG_PATH, sys.stdout)\nprint(root_dir)\nprint(\"update saved model every:\", UPDATE_MODEL_EVERY)\nprint(\"save model every:\", SAVE_MODEL_EVERY)\nprint(\"save video every:\", SAVE_VID_EVERY)\nprint(\"sample video every:\", SAMPLE_VID_EVERY)\nprint(postfix)\nprint(\"RESTORE_FROM\", RESTORE_FROM)\nprint(\"num examples per 
epoch:\", NUM_EXAMPLES_PER_EPOCH)\nprint(\"max epoch:\", MAX_EPOCH)\nprint(\"image size\", INPUT_SIZE)\nprint(\"epoch milestones:\", epoch_milestones)\nprint(\"only use flow loss:\", only_use_flow)\nprint(\"null_cond_prob:\", null_cond_prob)\nprint(\"use residual flow:\", use_residual_flow)\nprint(\"learn null cond:\", learn_null_cond)\nprint(\"use deconv:\", use_deconv)\n\n\ndef get_arguments():\n    \"\"\"Parse all the arguments provided from the CLI.\n\n    Returns:\n      A list of parsed arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"Flow Diffusion\")\n    parser.add_argument(\"--fine-tune\", default=False)\n    parser.add_argument(\"--set-start\", default=True)\n    parser.add_argument(\"--start-step\", default=0, type=int)\n    parser.add_argument(\"--img-dir\", type=str, default=IMGSHOT_DIR,\n                        help=\"Where to save images of the model.\")\n    parser.add_argument(\"--num-workers\", default=8)\n    parser.add_argument(\"--final-step\", type=int, default=int(NUM_STEPS_PER_EPOCH * MAX_EPOCH),\n                        help=\"Number of training steps.\")\n    parser.add_argument(\"--gpu\", default=GPU,\n                        help=\"choose gpu device.\")\n    parser.add_argument('--print-freq', '-p', default=2, type=int,\n                        metavar='N', help='print frequency')\n    parser.add_argument('--save-img-freq', default=2000, type=int,\n                        metavar='N', help='save image frequency')\n    parser.add_argument('--save-vid-freq', default=SAVE_VID_EVERY, type=int)\n    parser.add_argument('--sample-vid-freq', default=SAMPLE_VID_EVERY, type=int)\n    parser.add_argument(\"--batch-size\", type=int, default=BATCH_SIZE,\n                        help=\"Number of images sent to the network in one step.\")\n    parser.add_argument(\"--input-size\", type=str, default=INPUT_SIZE,\n                        help=\"Comma-separated string with height and width of images.\")\n    parser.add_argument(\"--learning-rate\", type=float, default=LEARNING_RATE,\n                        help=\"Base learning rate for training with polynomial decay.\")\n    parser.add_argument(\"--random-seed\", type=int, default=RANDOM_SEED,\n                        help=\"Random seed to have reproducible results.\")\n    parser.add_argument(\"--restore-from\", default=RESTORE_FROM)\n    parser.add_argument(\"--save-pred-every\", type=int, default=SAVE_MODEL_EVERY,\n                        help=\"Save checkpoint every often.\")\n    parser.add_argument(\"--update-pred-every\", type=int, default=UPDATE_MODEL_EVERY)\n    parser.add_argument(\"--snapshot-dir\", type=str, default=SNAPSHOT_DIR,\n                        help=\"Where to save snapshots of the model.\")\n    parser.add_argument(\"--fp16\", default=True)\n    parser.add_argument(\"--cosin\", default=True)\n    return parser.parse_args()\n\n\nargs = get_arguments()\n\n\ndef sample_img(rec_img_batch, idx=0):\n    rec_img = rec_img_batch[idx].permute(1, 2, 0).data.cpu().numpy().copy()\n    rec_img += np.array(MEAN) / 255.0\n    rec_img[rec_img < 0] = 0\n    rec_img[rec_img > 1] = 1\n    rec_img *= 255\n    return np.array(rec_img, np.uint8)\n\n\ndef main():\n    \"\"\"Create the model and start the training.\"\"\"\n\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = args.gpu\n\n    cudnn.enabled = True\n    cudnn.benchmark = True\n    setup_seed(args.random_seed)\n    writer = SummaryWriter(os.path.join(DM_LOG_PATH, 'tensorboard'))\n\n\n    model = FlowDiffusion(is_train=True,\n                          
img_size=INPUT_SIZE // 4,\n                          num_frames=MAX_N_FRAMES,\n                          null_cond_prob=null_cond_prob,\n                          sampling_timesteps=20,\n                          use_residual_flow=use_residual_flow,\n                          learn_null_cond=learn_null_cond,\n                          use_deconv=use_deconv,\n                          padding_mode=padding_mode,\n                          config_pth=config_pth,\n                          pretrained_pth=AE_RESTORE_FROM)\n    model.cuda()\n    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)\n\n    # Do not set the whole model to train mode: the pretrained flow autoencoder must stay in eval mode (BatchNorm).\n\n    # create optimizer\n    optimizer_diff = torch.optim.Adam(model.diffusion.parameters(),\n                                      lr=LEARNING_RATE, betas=(0.9, 0.99))\n\n    if args.fine_tune:\n        pass\n    elif args.restore_from:\n        if os.path.isfile(args.restore_from):\n            print(\"=> loading checkpoint '{}'\".format(args.restore_from))\n            checkpoint = torch.load(args.restore_from)\n            if args.set_start:\n                args.start_step = int(math.ceil(checkpoint['example'] / args.batch_size))\n            model_ckpt = model.diffusion.state_dict()\n            for name, _ in model_ckpt.items():\n                model_ckpt[name].copy_(checkpoint['diffusion'][name])\n            model.diffusion.load_state_dict(model_ckpt)\n            print(\"=> loaded checkpoint '{}'\".format(args.restore_from))\n            if args.set_start:\n                if \"optimizer_diff\" in list(checkpoint.keys()):\n                    optimizer_diff.load_state_dict(checkpoint['optimizer_diff'])\n        else:\n            print(\"=> no checkpoint found at '{}'\".format(args.restore_from))\n    else:\n        print(\"No checkpoint specified, training from scratch.\")\n\n    # enable the usage of multi-GPU\n    model = DataParallelWithCallback(model)\n\n    setup_seed(args.random_seed)\n    trainloader = DataLoader(HDTF(data_dir=data_dir,\n                                       pose_dir=pose_dir,\n                                       eye_blink_dir = eye_blink_dir,\n                                       image_size=INPUT_SIZE,\n                                       max_num_frames=MAX_N_FRAMES,\n                                       color_jitter=True,\n                                       mean=MEAN),\n                                  batch_size=args.batch_size,\n                                  shuffle=True, num_workers=args.num_workers,\n                                  pin_memory=True)\n\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n\n    losses = AverageMeter()\n    losses_rec = AverageMeter()\n    losses_warp = AverageMeter()\n    losses_vgg = AverageMeter()\n\n    cnt = 0\n    actual_step = args.start_step\n    start_epoch = int(math.ceil((args.start_step * args.batch_size) / NUM_EXAMPLES_PER_EPOCH))\n    epoch_cnt = start_epoch\n\n    if not args.cosin:\n        scheduler = MultiStepLR(optimizer_diff, epoch_milestones, gamma=0.1, last_epoch=start_epoch - 1)\n    else:\n        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_diff, T_max=MAX_EPOCH, eta_min=1e-6)\n    print(\"epoch %d, lr= %.7f\" % (epoch_cnt, optimizer_diff.param_groups[0][\"lr\"]))\n\n\n    # start = time.time()\n    # dummy call to warm up CUDA's linear-algebra backend before the training loop starts\n    torch.inverse(torch.ones((1,1), device = \"cuda:0\"))\n    while actual_step < args.final_step:\n        iter_end = timeit.default_timer()\n        start_time = 
time.time()  \n        load_sum = 0\n        calculate_sum = 0\n\n        for i_iter, batch in enumerate(trainloader):\n            \n            if __debug__:\n                end_time = time.time()  # end\n                # print(f'load time {end_time- start_time}')\n                load_sum += end_time - start_time\n                if end_time - start_time > 1:\n                    print('unnormal load: \\t',i_iter)\n                start_time = end_time\n\n            actual_step = int(args.start_step + cnt)\n            data_time.update(timeit.default_timer() - iter_end)\n\n            real_vids, ref_hubert, real_poses, real_blink_bbox, mouth_lmk_tensor, real_names, _ = batch\n            \n            if(DYNAMIC_FRAMES == True):\n                selct_frames = random.randint(MIN_N_FRAMES, MAX_N_FRAMES) + 1\n                selct_start = 0\n\n                real_vids = real_vids[:,:,selct_start:selct_start+selct_frames,:,:]\n                ref_hubert = ref_hubert[:,selct_start:selct_start+selct_frames,:]\n                real_poses = real_poses[:,:,selct_start:selct_start+selct_frames]\n                mouth_lmk_tensor = mouth_lmk_tensor[:,selct_start:selct_start+selct_frames -1]\n                real_blink_bbox = real_blink_bbox[:,:,selct_start:selct_start+selct_frames]\n            \n            \n            # ref_hubert, real_poses, real_blink_bbox : b, c, fn\n            # use first frame of each video as reference frame\n            ref_id = 0# random.randint(0, real_vids.shape[2] - 1)\n            ref_imgs = real_vids[:, :, ref_id, :, :].clone().detach()\n            bs = real_vids.size(0)\n            new_num_frames = real_vids.size(2) -1\n            model.module.update_num_frames(new_num_frames)\n\n            # end_time = time.time()  # end\n            # print(f'preprocess time {end_time- start_time}')\n            # start_time = end_time\n\n            # encode text\n            # cond = bert_embed(tokenize(ref_texts), return_cls_repr=model.module.diffusion.text_use_bert_cls).cuda()\n            is_eval = actual_step % args.save_vid_freq == 0 or actual_step % args.sample_vid_freq == 0\n            with torch.cuda.amp.autocast(enabled=args.fp16):\n                train_output_dict = model.forward(real_vid=real_vids, ref_img=ref_imgs, ref_text=ref_hubert, ref_pose=real_poses, ref_eye_blink = real_blink_bbox[:, :2], bbox=real_blink_bbox[:, 2:], mouth_lmk = mouth_lmk_tensor, is_eval = is_eval)\n            \n            # optimize model\n            \n            optimizer_diff.zero_grad()\n            # if only_use_flow:\n            #     scaler.scale(train_output_dict[\"loss\"].mean()).backward()\n            # else:\n            #     scaler.scale((train_output_dict[\"loss\"].mean() + train_output_dict[\"rec_loss\"].mean() +\n            #      train_output_dict[\"rec_warp_loss\"].mean())).backward()\n    \n            scaler.scale(train_output_dict[\"loss\"].mean() + floss_weight * train_output_dict['floss'].mean()  + 0.15 * train_output_dict['mouth_loss'].mean()).backward()\n            # optimizer_diff.step()\n            scaler.unscale_(optimizer_diff)\n            # loss.backward()\n            if clip_c > 0.:\n                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_c)\n            \n            has_nan_grad = False\n\n            for name, param in model.named_parameters():\n                if param.grad != None and torch.isnan(param.grad).any():\n                    has_nan_grad = True\n                    print(name)\n                    
break\n\n            if has_nan_grad:\n                print(\"grad NaN, dont update param\")\n                scaler.update()\n                continue\n            \n            scaler.step(optimizer_diff)\n            scaler.update()\n\n            batch_time.update(timeit.default_timer() - iter_end)\n            iter_end = timeit.default_timer()\n\n            losses.update(train_output_dict[\"loss\"].mean().item(), bs)\n            losses_rec.update(train_output_dict[\"floss\"].mean().item(), bs)\n            losses_warp.update(train_output_dict[\"mouth_loss\"].mean().item(), bs)\n            # losses_vgg.update(train_output_dict[\"rec_vgg_loss\"].mean().item(), bs)\n\n            writer.add_scalar('train/loss', train_output_dict[\"loss\"].mean().item(),actual_step)\n            writer.add_scalar('train/floss', train_output_dict[\"floss\"].mean().item(),actual_step)\n            writer.add_scalar('train/mouth_loss', train_output_dict[\"mouth_loss\"].mean().item(),actual_step)\n            # writer.add_scalar('train/rec_loss', train_output_dict[\"rec_loss\"].mean().item(),actual_step)\n            # writer.add_scalar('train/rec_warp_loss', train_output_dict[\"rec_warp_loss\"].mean().item(),actual_step)\n            # writer.add_scalar('train/rec_vgg_loss', train_output_dict[\"rec_vgg_loss\"].mean().item(),actual_step)\n            if __debug__:\n                end_time = time.time()  # end\n                # print(f'forward time {end_time- start_time}')\n                calculate_sum += end_time - start_time\n                start_time = end_time\n\n            # if actual_step % 100 == 0:\n            #     end = time.time()\n            #     print(\"100 iter time:{0}\".format(end-start))\n            if actual_step % args.print_freq == 0:\n                current_time = datetime.now()\n                current_time_str = current_time.strftime(\"%Y-%m-%d %H:%M:%S\")\n                print(\"Current time is:\", current_time_str)\n                print('iter: [{0}]{1}/{2}\\t'\n                      'loss {loss.val:.7f} ({loss.avg:.7f})\\t'\n                      'loss_mse {loss_rec.val:.4f} ({loss_rec.avg:.4f})\\t'\n                      'loss_lip {loss_warp.val:.4f} ({loss_warp.avg:.4f})'\n                    .format(\n                    cnt, actual_step, args.final_step,\n                    batch_time=batch_time,\n                    data_time=data_time,\n                    loss=losses,\n                    loss_rec=losses_rec,\n                    loss_warp=losses_warp,\n                ))\n\n            null_cond_mask = np.array(train_output_dict[\"null_cond_mask\"].data.cpu().numpy(),\n                                      dtype=np.uint8)\n\n            if actual_step % args.save_vid_freq == 0: # and cnt != 0:\n                torch.cuda.empty_cache()\n                print(\"saving video...\")\n                num_frames = real_vids.size(2) - 1\n                msk_size = ref_imgs.shape[-1]\n                new_im_arr_list = []\n                save_src_img = sample_img(ref_imgs/255.)\n                for nf in range(num_frames):\n                    save_tar_img = sample_img(real_vids[:, :, nf, :, :]/255.) 
# adapt fast version\n                    save_real_out_img = sample_img(train_output_dict[\"real_out_vid\"][:, :, nf, :, :])\n                    save_real_warp_img = sample_img(train_output_dict[\"real_warped_vid\"][:, :, nf, :, :])\n                    save_fake_out_img = sample_img(train_output_dict[\"fake_out_vid\"][:, :, nf, :, :])\n                    save_fake_warp_img = sample_img(train_output_dict[\"fake_warped_vid\"][:, :, nf, :, :])\n                    save_real_grid = grid2fig(\n                        train_output_dict[\"real_vid_grid\"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(),\n                        grid_size=32, img_size=msk_size)\n                    save_fake_grid = grid2fig(\n                        train_output_dict[\"fake_vid_grid\"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(),\n                        grid_size=32, img_size=msk_size)\n                    save_real_conf = conf2fig(train_output_dict[\"real_vid_conf\"][0, :, nf])\n                    save_fake_conf = conf2fig(train_output_dict[\"fake_vid_conf\"][0, :, nf])\n                    new_im = Image.new('RGB', (msk_size * 5, msk_size * 2))\n                    new_im.paste(Image.fromarray(save_src_img, 'RGB'), (0, 0))\n                    new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, msk_size))\n                    new_im.paste(Image.fromarray(save_real_out_img, 'RGB'), (msk_size, 0))\n                    new_im.paste(Image.fromarray(save_real_warp_img, 'RGB'), (msk_size, msk_size))\n                    new_im.paste(Image.fromarray(save_fake_out_img, 'RGB'), (msk_size * 2, 0))\n                    new_im.paste(Image.fromarray(save_fake_warp_img, 'RGB'), (msk_size * 2, msk_size))\n                    new_im.paste(Image.fromarray(save_real_grid, 'RGB'), (msk_size * 3, 0))\n                    new_im.paste(Image.fromarray(save_fake_grid, 'RGB'), (msk_size * 3, msk_size))\n                    new_im.paste(Image.fromarray(save_real_conf, 'L'), (msk_size * 4, 0))\n                    new_im.paste(Image.fromarray(save_fake_conf, 'L'), (msk_size * 4, msk_size))\n                    new_im_arr = np.array(new_im)\n                    new_im_arr_list.append(new_im_arr)\n                new_vid_name = 'B' + format(args.batch_size, \"04d\") + '_S' + format(actual_step, \"06d\") \\\n                               + '_' + real_names[0] + \"_%d.gif\" % (null_cond_mask[0][0])\n                new_vid_file = os.path.join(VIDSHOT_DIR, new_vid_name)\n                imageio.mimsave(new_vid_file, new_im_arr_list)\n                new_im_arr_list = None\n                new_im_arr = None\n                new_im = None\n                del new_im_arr_list, new_im_arr, new_im\n\n            # sampling\n            if actual_step % args.sample_vid_freq == 0:\n                torch.cuda.empty_cache()\n                real_vids = real_vids[:,:,1:,:,:]\n                print(\"sampling video...\")\n                with torch.no_grad():\n                    # cond = torch.concat([ref_hubert[0].unsqueeze(dim=0), real_poses[0].permute(1,0).unsqueeze(0), real_blink_bbox[0][:2].permute(1,0).unsqueeze(0)], dim=-1).cuda()\n                    sample_output_dict = model.module.sample_one_video(real_vid=real_vids.cuda()/255.,\n                                                                    sample_img=ref_imgs[0].unsqueeze(dim=0).cuda()/255.,\n                                                                    sample_audio_hubert = ref_hubert[0].unsqueeze(dim=0).cuda(),\n                                                    
                sample_pose = real_poses[0].unsqueeze(0).cuda(),\n                                                                    sample_eye =  real_blink_bbox[0][:2].unsqueeze(0).cuda(),\n                                                                    sample_bbox = real_blink_bbox[0,2:].unsqueeze(0).cuda(),\n                                                                    cond_scale=1.0)\n                num_frames = real_vids.size(2)\n                msk_size = ref_imgs.shape[-1]\n                new_im_arr_list = []\n                save_src_img = sample_img(ref_imgs/255.)\n                for nf in range(num_frames):\n                    save_tar_img = sample_img(real_vids[:, :, nf, :, :]/255.)\n                    save_real_out_img = sample_img(train_output_dict[\"real_out_vid\"][:, :, nf, :, :])\n                    save_real_warp_img = sample_img(train_output_dict[\"real_warped_vid\"][:, :, nf, :, :])\n                    save_sample_out_img = sample_img(sample_output_dict[\"sample_out_vid\"][:, :, nf, :, :])\n                    save_sample_warp_img = sample_img(sample_output_dict[\"sample_warped_vid\"][:, :, nf, :, :])\n                    save_real_grid = grid2fig(\n                        train_output_dict[\"real_vid_grid\"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(),\n                        grid_size=32, img_size=msk_size)\n                    save_fake_grid = grid2fig(\n                        sample_output_dict[\"sample_vid_grid\"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(),\n                        grid_size=32, img_size=msk_size)\n                    save_real_conf = conf2fig(train_output_dict[\"real_vid_conf\"][0, :, nf])\n                    save_fake_conf = conf2fig(sample_output_dict[\"sample_vid_conf\"][0, :, nf])\n                    new_im = Image.new('RGB', (msk_size * 5, msk_size * 2))\n                    new_im.paste(Image.fromarray(save_src_img, 'RGB'), (0, 0))\n                    new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, msk_size))\n                    new_im.paste(Image.fromarray(save_real_out_img, 'RGB'), (msk_size, 0))\n                    new_im.paste(Image.fromarray(save_real_warp_img, 'RGB'), (msk_size, msk_size))\n                    new_im.paste(Image.fromarray(save_sample_out_img, 'RGB'), (msk_size * 2, 0))\n                    new_im.paste(Image.fromarray(save_sample_warp_img, 'RGB'), (msk_size * 2, msk_size))\n                    new_im.paste(Image.fromarray(save_real_grid, 'RGB'), (msk_size * 3, 0))\n                    new_im.paste(Image.fromarray(save_fake_grid, 'RGB'), (msk_size * 3, msk_size))\n                    new_im.paste(Image.fromarray(save_real_conf, 'L'), (msk_size * 4, 0))\n                    new_im.paste(Image.fromarray(save_fake_conf, 'L'), (msk_size * 4, msk_size))\n                    new_im_arr = np.array(new_im)\n                    new_im_arr_list.append(new_im_arr)\n                new_vid_name = 'B' + format(args.batch_size, \"04d\") + '_S' + format(actual_step, \"06d\") \\\n                               + '_' + real_names[0] + \".gif\"\n                new_vid_file = os.path.join(SAMPLE_DIR, new_vid_name)\n                imageio.mimsave(new_vid_file, new_im_arr_list)\n                new_im_arr_list = None\n                new_im_arr = None\n                new_im = None\n                del new_im_arr_list, new_im_arr, new_im\n\n            # save model at i-th step\n            if actual_step % args.save_pred_every == 0 and cnt != 0 and actual_step > args.final_step // 2:\n             
   print('taking snapshot ...')\n                torch.save({'example': actual_step * args.batch_size,\n                            'diffusion': model.module.diffusion.state_dict(),\n                            'optimizer_diff': optimizer_diff.state_dict()},\n                           osp.join(args.snapshot_dir,\n                                    'flowdiff_' + format(args.batch_size, \"04d\") + '_S' + format(actual_step,\n                                                                                                 \"06d\") + '.pth'))\n\n            # update saved model\n            if actual_step % args.update_pred_every == 0 and cnt != 0:\n                print('updating saved snapshot ...')\n                torch.save({'example': actual_step * args.batch_size,\n                            'diffusion': model.module.diffusion.state_dict(),\n                            'optimizer_diff': optimizer_diff.state_dict()},\n                           osp.join(args.snapshot_dir, 'flowdiff.pth'))\n\n            if actual_step >= args.final_step:\n                break\n\n            cnt += 1\n\n\n            del real_vids, ref_imgs, ref_hubert, real_names, null_cond_mask\n            del train_output_dict\n            del batch\n            # torch.cuda.empty_cache()\n\n\n        scheduler.step()\n        epoch_cnt += 1\n        print(\"epoch %d, lr= %.7f\" % (epoch_cnt, optimizer_diff.param_groups[0][\"lr\"]))\n\n        if __debug__:\n            print('load_sum: ', load_sum)\n            print('calculate_sum: ', calculate_sum)\n\n    print('save the final model ...')\n    torch.save({'example': actual_step * args.batch_size,\n                'diffusion': model.module.diffusion.state_dict(),\n                'optimizer_diff': optimizer_diff.state_dict()},\n               osp.join(args.snapshot_dir,\n                        'flowdiff_' + format(args.batch_size, \"04d\") + '_S' + format(actual_step, \"06d\") + '.pth'))\n    end = timeit.default_timer()\n    print(end - start, 'seconds')\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef setup_seed(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    # mp.set_start_method('spawn')\n    # torch.multiprocessing.set_start_method(\"spawn\")\n    main()\n"
  },
  {
    "path": "DM_3/utils.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass MultiEpochsDataLoader(torch.utils.data.DataLoader):\n    \"\"\"DataLoader that keeps its worker processes alive across epochs.\n\n    The batch sampler is wrapped in _RepeatSampler, so the single iterator\n    created in __init__ never exhausts; __iter__ draws len(self) batches from\n    it each epoch instead of re-spawning the workers.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        # temporarily unlock the initialized DataLoader to swap in the repeating sampler\n        self._DataLoader__initialized = False\n        self.batch_sampler = _RepeatSampler(self.batch_sampler)\n        self._DataLoader__initialized = True\n        self.iterator = super().__iter__()\n\n    def __len__(self):\n        # number of batches per epoch, taken from the wrapped (original) batch sampler\n        return len(self.batch_sampler.sampler)\n\n    def __iter__(self):\n        for i in range(len(self)):\n            yield next(self.iterator)\n\n\nclass _RepeatSampler(object):\n    \"\"\" Sampler that repeats forever.\n    Args:\n        sampler (Sampler)\n    \"\"\"\n\n    def __init__(self, sampler):\n        self.sampler = sampler\n\n    def __iter__(self):\n        while True:\n            yield from iter(self.sampler)"
  },
  {
    "path": "LFG/__init__.py",
    "content": ""
  },
  {
    "path": "LFG/augmentation.py",
    "content": "\"\"\"\nCode from https://github.com/hassony2/torch_videovision\n\"\"\"\n\nimport numbers\n\nimport random\nimport numpy as np\nimport PIL\n\nfrom skimage.transform import resize, rotate\nfrom numpy import pad\nimport torchvision\n\nimport warnings\n\nfrom skimage import img_as_ubyte, img_as_float\n\n\ndef crop_clip(clip, min_h, min_w, h, w):\n    if isinstance(clip[0], np.ndarray):\n        cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]\n\n    elif isinstance(clip[0], PIL.Image.Image):\n        cropped = [\n            img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip\n        ]\n    else:\n        raise TypeError('Expected numpy.ndarray or PIL.Image' +\n                        ' but got list of {0}'.format(type(clip[0])))\n    return cropped\n\n\ndef pad_clip(clip, h, w):\n    im_h, im_w = clip[0].shape[:2]\n    pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)\n    pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)\n\n    return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')\n\n\ndef resize_clip(clip, size, interpolation='bilinear'):\n    if isinstance(clip[0], np.ndarray):\n        if isinstance(size, numbers.Number):\n            im_h, im_w, im_c = clip[0].shape\n            # Min spatial dim already matches minimal size\n            if (im_w <= im_h and im_w == size) or (im_h <= im_w\n                                                   and im_h == size):\n                return clip\n            new_h, new_w = get_resize_sizes(im_h, im_w, size)\n            size = (new_w, new_h)\n        else:\n            size = size[1], size[0]\n\n        scaled = [\n            resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,\n                   mode='constant', anti_aliasing=True) for img in clip\n        ]\n    elif isinstance(clip[0], PIL.Image.Image):\n        if isinstance(size, numbers.Number):\n            im_w, im_h = clip[0].size\n            # Min spatial dim already matches minimal size\n            if (im_w <= im_h and im_w == size) or (im_h <= im_w\n                                                   and im_h == size):\n                return clip\n            new_h, new_w = get_resize_sizes(im_h, im_w, size)\n            size = (new_w, new_h)\n        else:\n            size = size[1], size[0]\n        # 'bilinear' maps to bilinear resampling, anything else falls back to nearest (matches the numpy branch above)\n        if interpolation == 'bilinear':\n            pil_inter = PIL.Image.BILINEAR\n        else:\n            pil_inter = PIL.Image.NEAREST\n        scaled = [img.resize(size, pil_inter) for img in clip]\n    else:\n        raise TypeError('Expected numpy.ndarray or PIL.Image' +\n                        ' but got list of {0}'.format(type(clip[0])))\n    return scaled\n\n\ndef get_resize_sizes(im_h, im_w, size):\n    if im_w < im_h:\n        ow = size\n        oh = int(size * im_h / im_w)\n    else:\n        oh = size\n        ow = int(size * im_w / im_h)\n    return oh, ow\n\n\nclass RandomFlip(object):\n    def __init__(self, time_flip=False, horizontal_flip=False):\n        self.time_flip = time_flip\n        self.horizontal_flip = horizontal_flip\n\n    def __call__(self, clip):\n        if random.random() < 0.5 and self.time_flip:\n            return clip[::-1]\n        if random.random() < 0.5 and self.horizontal_flip:\n            return [np.fliplr(img) for img in clip]\n\n        return clip\n\n\nclass RandomResize(object):\n    \"\"\"Resizes a list of (H x W x C) numpy.ndarray to the final size\n    The larger the original image is, the more times it 
takes to\n    interpolate\n    Args:\n    interpolation (str): Can be one of 'nearest', 'bilinear'\n    defaults to nearest\n    size (tuple): (widht, height)\n    \"\"\"\n\n    def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):\n        self.ratio = ratio\n        self.interpolation = interpolation\n\n    def __call__(self, clip):\n        scaling_factor = random.uniform(self.ratio[0], self.ratio[1])\n\n        if isinstance(clip[0], np.ndarray):\n            im_h, im_w, im_c = clip[0].shape\n        elif isinstance(clip[0], PIL.Image.Image):\n            im_w, im_h = clip[0].size\n\n        new_w = int(im_w * scaling_factor)\n        new_h = int(im_h * scaling_factor)\n        new_size = (new_w, new_h)\n        resized = resize_clip(\n            clip, new_size, interpolation=self.interpolation)\n\n        return resized\n\n\nclass RandomCrop(object):\n    \"\"\"Extract random crop at the same location for a list of videos\n    Args:\n    size (sequence or int): Desired output size for the\n    crop in format (h, w)\n    \"\"\"\n\n    def __init__(self, size):\n        if isinstance(size, numbers.Number):\n            size = (size, size)\n\n        self.size = size\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n        img (PIL.Image or numpy.ndarray): List of videos to be cropped\n        in format (h, w, c) in numpy.ndarray\n        Returns:\n        PIL.Image or numpy.ndarray: Cropped list of videos\n        \"\"\"\n        h, w = self.size\n        if isinstance(clip[0], np.ndarray):\n            im_h, im_w, im_c = clip[0].shape\n        elif isinstance(clip[0], PIL.Image.Image):\n            im_w, im_h = clip[0].size\n        else:\n            raise TypeError('Expected numpy.ndarray or PIL.Image' +\n                            'but got list of {0}'.format(type(clip[0])))\n\n        clip = pad_clip(clip, h, w)\n        im_h, im_w = clip.shape[1:3]\n        x1 = 0 if h == im_h else random.randint(0, im_w - w)\n        y1 = 0 if w == im_w else random.randint(0, im_h - h)\n        cropped = crop_clip(clip, y1, x1, h, w)\n\n        return cropped\n\n\nclass RandomRotation(object):\n    \"\"\"Rotate entire clip randomly by a random angle within\n    given bounds\n    Args:\n    degrees (sequence or int): Range of degrees to select from\n    If degrees is a number instead of sequence like (min, max),\n    the range of degrees, will be (-degrees, +degrees).\n    \"\"\"\n\n    def __init__(self, degrees):\n        if isinstance(degrees, numbers.Number):\n            if degrees < 0:\n                raise ValueError('If degrees is a single number,'\n                                 'must be positive')\n            degrees = (-degrees, degrees)\n        else:\n            if len(degrees) != 2:\n                raise ValueError('If degrees is a sequence,'\n                                 'it must be of len 2.')\n\n        self.degrees = degrees\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n        img (PIL.Image or numpy.ndarray): List of videos to be cropped\n        in format (h, w, c) in numpy.ndarray\n        Returns:\n        PIL.Image or numpy.ndarray: Cropped list of videos\n        \"\"\"\n        angle = random.uniform(self.degrees[0], self.degrees[1])\n        if isinstance(clip[0], np.ndarray):\n            rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]\n        elif isinstance(clip[0], PIL.Image.Image):\n            rotated = [img.rotate(angle) for img in clip]\n        else:\n            
raise TypeError('Expected numpy.ndarray or PIL.Image' +\n                            'but got list of {0}'.format(type(clip[0])))\n\n        return rotated\n\n\nclass ColorJitter(object):\n    \"\"\"Randomly change the brightness, contrast and saturation and hue of the clip\n    Args:\n    brightness (float): How much to jitter brightness. brightness_factor\n    is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].\n    contrast (float): How much to jitter contrast. contrast_factor\n    is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].\n    saturation (float): How much to jitter saturation. saturation_factor\n    is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].\n    hue(float): How much to jitter hue. hue_factor is chosen uniformly from\n    [-hue, hue]. Should be >=0 and <= 0.5.\n    \"\"\"\n\n    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):\n        self.brightness = brightness\n        self.contrast = contrast\n        self.saturation = saturation\n        self.hue = hue\n\n    def get_params(self, brightness, contrast, saturation, hue):\n        if brightness > 0:\n            brightness_factor = random.uniform(\n                max(0, 1 - brightness), 1 + brightness)\n        else:\n            brightness_factor = None\n\n        if contrast > 0:\n            contrast_factor = random.uniform(\n                max(0, 1 - contrast), 1 + contrast)\n        else:\n            contrast_factor = None\n\n        if saturation > 0:\n            saturation_factor = random.uniform(\n                max(0, 1 - saturation), 1 + saturation)\n        else:\n            saturation_factor = None\n\n        if hue > 0:\n            hue_factor = random.uniform(-hue, hue)\n        else:\n            hue_factor = None\n        return brightness_factor, contrast_factor, saturation_factor, hue_factor\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n        clip (list): list of PIL.Image\n        Returns:\n        list PIL.Image : list of transformed PIL.Image\n        \"\"\"\n        if isinstance(clip[0], np.ndarray):\n            brightness, contrast, saturation, hue = self.get_params(\n                self.brightness, self.contrast, self.saturation, self.hue)\n\n            # Create img transform function sequence\n            img_transforms = []\n            if brightness is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))\n            if saturation is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))\n            if hue is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))\n            if contrast is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))\n            random.shuffle(img_transforms)\n            img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,\n                                                                                                     img_as_float]\n\n            with warnings.catch_warnings():\n                warnings.simplefilter(\"ignore\")\n                jittered_clip = []\n                for img in clip:\n                    jittered_img = img\n                    for func in img_transforms:\n                        jittered_img = 
func(jittered_img)\n                    jittered_clip.append(jittered_img.astype('float32'))\n        elif isinstance(clip[0], PIL.Image.Image):\n            brightness, contrast, saturation, hue = self.get_params(\n                self.brightness, self.contrast, self.saturation, self.hue)\n\n            # Create img transform function sequence\n            img_transforms = []\n            if brightness is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))\n            if saturation is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))\n            if hue is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))\n            if contrast is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))\n            random.shuffle(img_transforms)\n\n            # Apply the jitter functions sequentially to every frame\n            jittered_clip = []\n            for img in clip:\n                jittered_img = img\n                for func in img_transforms:\n                    jittered_img = func(jittered_img)\n                jittered_clip.append(jittered_img)\n\n        else:\n            raise TypeError('Expected numpy.ndarray or PIL.Image' +\n                            ' but got list of {0}'.format(type(clip[0])))\n        return jittered_clip\n\n\nclass AllAugmentationTransform:\n    def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):\n        self.transforms = []\n\n        if flip_param is not None:\n            self.transforms.append(RandomFlip(**flip_param))\n\n        if rotation_param is not None:\n            self.transforms.append(RandomRotation(**rotation_param))\n\n        if resize_param is not None:\n            self.transforms.append(RandomResize(**resize_param))\n\n        if crop_param is not None:\n            self.transforms.append(RandomCrop(**crop_param))\n\n        if jitter_param is not None:\n            self.transforms.append(ColorJitter(**jitter_param))\n\n    def __call__(self, clip):\n        for t in self.transforms:\n            clip = t(clip)\n        return clip\n"
  },
  {
    "path": "LFG/frames_dataset.py",
    "content": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\npublish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\nSuch code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\ntitle, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\nIn no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\"\"\"\n\nimport os\nfrom skimage import io, img_as_float32\nfrom skimage.color import gray2rgb\nfrom skimage.transform import resize\n\nfrom sklearn.model_selection import train_test_split\nfrom imageio import mimread\n\nimport numpy as np\nfrom torch.utils.data import Dataset\nimport pandas as pd\nfrom augmentation import AllAugmentationTransform\nimport glob\nfrom functools import partial\n\n\ndef read_video(name, frame_shape):\n    \"\"\"\n    Read video which can be:\n      - an image of concatenated frames\n      - '.mp4' and'.gif'\n      - folder with videos\n    \"\"\"\n\n    if os.path.isdir(name):\n        frames = sorted(os.listdir(name))\n        num_frames = len(frames)\n        video_array = [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)]\n        if frame_shape is not None:\n            video_array = np.array([resize(frame, frame_shape) for frame in video_array])\n    elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):\n        image = io.imread(name)\n\n        if frame_shape is None:\n            raise ValueError('Frame shape can not be None for stacked png format.')\n\n        frame_shape = tuple(frame_shape)\n\n        if len(image.shape) == 2 or image.shape[2] == 1:\n            image = gray2rgb(image)\n\n        if image.shape[2] == 4:\n            image = image[..., :3]\n\n        image = img_as_float32(image)\n\n        video_array = np.moveaxis(image, 1, 0)\n\n        video_array = video_array.reshape((-1,) + frame_shape + (3, ))\n        video_array = np.moveaxis(video_array, 1, 2)\n    elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):\n        video = mimread(name)\n        if len(video[0].shape) == 2:\n            video = [gray2rgb(frame) for frame in video]\n        if frame_shape is not None:\n            video = np.array([resize(frame, frame_shape) for frame in video])\n        video = np.array(video)\n        if video.shape[-1] == 4:\n            video = video[..., :3]\n        video_array = img_as_float32(video)\n    else:\n        raise Exception(\"Unknown file extensions  %s\" % name)\n\n    return video_array\n\n\nclass FramesDataset(Dataset):\n    \"\"\"\n    Dataset of videos, each video can be represented as:\n      - an image of concatenated frames\n      - '.mp4' or '.gif'\n      - folder with all frames\n    \"\"\"\n\n    def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,\n                 random_seed=0, pairs_list=None, augmentation_params=None):\n        self.root_dir = root_dir\n        self.videos = os.listdir(root_dir)\n        self.frame_shape = frame_shape\n        self.pairs_list = pairs_list\n        self.id_sampling = id_sampling\n        if 
os.path.exists(os.path.join(root_dir, 'train')):\n            assert os.path.exists(os.path.join(root_dir, 'test'))\n            print(\"Use predefined train-test split.\")\n            if id_sampling:\n                train_videos = {os.path.basename(video).split('#')[0] for video in\n                                os.listdir(os.path.join(root_dir, 'train'))}\n                train_videos = list(train_videos)\n            else:\n                train_videos = os.listdir(os.path.join(root_dir, 'train'))\n            test_videos = os.listdir(os.path.join(root_dir, 'test'))\n            self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')\n        else:\n            print(\"Use random train-test split.\")\n            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)\n\n        if is_train:\n            self.videos = train_videos\n        else:\n            self.videos = test_videos\n\n        self.is_train = is_train\n\n        if self.is_train:\n            self.transform = AllAugmentationTransform(**augmentation_params)\n        else:\n            self.transform = None\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, idx):\n        if self.is_train and self.id_sampling:\n            name = self.videos[idx]\n            try:\n                path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))\n            except ValueError:\n                raise ValueError(\"File formatting is not correct for id_sampling=True. \"\n                                 \"Change file formatting, or set id_sampling=False.\")\n        else:\n            name = self.videos[idx]\n            path = os.path.join(self.root_dir, name)\n\n        video_name = os.path.basename(path)\n\n        if self.is_train and os.path.isdir(path):\n            frames = os.listdir(path)\n            num_frames = len(frames)\n            frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))\n            \n            if self.frame_shape is not None:\n                resize_fn = partial(resize, output_shape=self.frame_shape)\n            else:\n                resize_fn = img_as_float32\n\n            if type(frames[0]) is bytes:\n                video_array = [resize_fn(io.imread(os.path.join(path, frames[idx].decode('utf-8')))) for idx in\n                               frame_idx]\n            else:\n                video_array = [resize_fn(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]\n\n        else:\n            video_array = read_video(path, frame_shape=self.frame_shape)\n            num_frames = len(video_array)\n            frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(\n                num_frames)\n            video_array = video_array[frame_idx][..., :3]\n\n        if self.transform is not None:\n            video_array = self.transform(video_array)\n\n        out = {}\n        if self.is_train:\n            source = np.array(video_array[0], dtype='float32')\n            driving = np.array(video_array[1], dtype='float32')\n\n            out['driving'] = driving.transpose((2, 0, 1))\n            out['source'] = source.transpose((2, 0, 1))\n        else:\n            video = np.array(video_array, dtype='float32')\n            out['video'] = video.transpose((3, 0, 1, 2))\n\n        out['name'] = video_name\n        out['id'] = idx\n        \n        return out\n\n\nclass DatasetRepeater(Dataset):\n    
\"\"\"\n    Pass several times over the same dataset for better i/o performance\n    \"\"\"\n\n    def __init__(self, dataset, num_repeats=100):\n        self.dataset = dataset\n        self.num_repeats = num_repeats\n\n    def __len__(self):\n        return self.num_repeats * self.dataset.__len__()\n\n    def __getitem__(self, idx):\n        return self.dataset[idx % self.dataset.__len__()]\n\n\nclass PairedDataset(Dataset):\n    \"\"\"\n    Dataset of pairs for animation.\n    \"\"\"\n\n    def __init__(self, initial_dataset, number_of_pairs, seed=0):\n        self.initial_dataset = initial_dataset\n        pairs_list = self.initial_dataset.pairs_list\n\n        np.random.seed(seed)\n\n        if pairs_list is None:\n            max_idx = min(number_of_pairs, len(initial_dataset))\n            nx, ny = max_idx, max_idx\n            xy = np.mgrid[:nx, :ny].reshape(2, -1).T\n            number_of_pairs = min(xy.shape[0], number_of_pairs)\n            self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0)\n        else:\n            videos = self.initial_dataset.videos\n            name_to_index = {name: index for index, name in enumerate(videos)}\n            pairs = pd.read_csv(pairs_list)\n            pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))]\n\n            number_of_pairs = min(pairs.shape[0], number_of_pairs)\n            self.pairs = []\n            self.start_frames = []\n            for ind in range(number_of_pairs):\n                self.pairs.append(\n                    (name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]]))\n\n    def __len__(self):\n        return len(self.pairs)\n\n    def __getitem__(self, idx):\n        pair = self.pairs[idx]\n        first = self.initial_dataset[pair[0]]\n        second = self.initial_dataset[pair[1]]\n        first = {'driving_' + key: value for key, value in first.items()}\n        second = {'source_' + key: value for key, value in second.items()}\n\n        return {**first, **second}\n"
  },
  {
    "path": "LFG/hdtf_dataset.py",
    "content": "# build MUG dataset for RegionMM\n\nimport os\nimport imageio\n\nimport numpy as np\nfrom torch.utils.data import Dataset\nimport yaml\nfrom argparse import ArgumentParser\nfrom augmentation import AllAugmentationTransform\n\nfrom functools import partial\nimport cv2\nimport matplotlib.pyplot as plt\nimport imageio.v2 as imageio\n\n\ndef resize(im, desired_size, interpolation):\n    old_size = im.shape[:2]\n    ratio = float(desired_size)/max(old_size)\n    new_size = tuple(int(x*ratio) for x in old_size)\n\n    im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation)\n    delta_w = desired_size - new_size[1]\n    delta_h = desired_size - new_size[0]\n    top, bottom = delta_h//2, delta_h-(delta_h//2)\n    left, right = delta_w//2, delta_w-(delta_w//2)\n\n    color = [0, 0, 0]\n    new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)\n\n    return new_im\n\n\n# this is just for training\nclass FramesDataset(Dataset):\n    \"\"\"\n    Dataset of videos, each video can be represented as:\n      - an image of concatenated frames\n      - '.mp4' or '.gif'\n      - folder with all frames\n    \"\"\"\n\n    def __init__(self, root_dir, frame_shape=256, id_sampling=False,\n                 pairs_list=None, augmentation_params=None):\n        \n        self.root_dir = root_dir\n        self.frame_shape = frame_shape\n        self.pairs_list = pairs_list\n        self.id_sampling = id_sampling\n\n\n        \n        vid_list = []\n        # crema\n        for id_name in os.listdir(root_dir):\n            vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{root_dir}/{id_name}') ])\n\n         #hdtf\n        # for id_name in os.listdir(root_dir):\n        #     vid_list.append(id_name)\n            \n        self.videos = vid_list\n\n        self.transform = AllAugmentationTransform(**augmentation_params)\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, idx):\n        if self.id_sampling:\n            raise NotImplementedError\n        else:\n            name = self.videos[idx]\n            path = os.path.join(self.root_dir, name)\n\n        frames = os.listdir(path)\n        frames.sort()\n\n        num_frames = len(frames)\n        frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))\n        resize_fn = partial(resize, desired_size=self.frame_shape, interpolation=cv2.INTER_AREA)\n\n        if type(frames[0]) is bytes:\n            frame_names = [frames[idx].decode('utf-8') for idx in frame_idx]\n        else:\n            frame_names = [frames[idx] for idx in frame_idx]\n\n        video_array = [resize_fn(imageio.imread(os.path.join(path, x))) for x in frame_names]\n\n        # video_array = [img_as_float32(x) for x in video_array]\n\n        video_array = self.transform(video_array)\n\n        out = {}\n\n        source = np.array(video_array[0], dtype='float32')\n        driving = np.array(video_array[1], dtype='float32')\n\n        out['driving'] = driving.transpose((2, 0, 1))\n        out['source'] = source.transpose((2, 0, 1))\n        out['name'] = name\n        out['frame'] = frame_names\n        out['id'] = idx\n        \n        return out\n\n\nclass DatasetRepeater(Dataset):\n    \"\"\"\n    Pass several times over the same dataset for better i/o performance\n    \"\"\"\n\n    def __init__(self, dataset, num_repeats=100):\n        self.dataset = dataset\n        self.num_repeats = num_repeats\n\n    def __len__(self):\n        return 
self.num_repeats * self.dataset.__len__()\n\n    def __getitem__(self, idx):\n        return self.dataset[idx % self.dataset.__len__()]\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\"--config\",\n                        default=\"/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/config/hdtf128.yaml\",\n                        help=\"path to config\")\n    opt = parser.parse_args()\n    with open(opt.config) as f:\n        config = yaml.safe_load(f)\n    data = FramesDataset(**config['dataset_params'])\n\n    # data.__getitem__(0)\n\n    # print('_------')\n    # data.__getitem__(1)\n    # print('------')\n    data.__getitem__(2)\n    print('------')    \n    data.__getitem__(3)\n    print('------')"
  },
  {
    "path": "LFG/modules/avd_network.py",
    "content": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\npublish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\nSuch code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\ntitle, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\nIn no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\"\"\"\nimport torch\nfrom torch import nn\n\n\nclass AVDNetwork(nn.Module):\n    \"\"\"\n    Animation via Disentanglement network\n    \"\"\"\n\n    def __init__(self, num_regions, id_bottle_size=64, pose_bottle_size=64, revert_axis_swap=True):\n        super(AVDNetwork, self).__init__()\n        input_size = (2 + 4) * num_regions\n        self.num_regions = num_regions\n        self.revert_axis_swap = revert_axis_swap\n\n        self.id_encoder = nn.Sequential(\n            nn.Linear(input_size, 256),\n            nn.BatchNorm1d(256),\n            nn.ReLU(inplace=True),\n            nn.Linear(256, 512),\n            nn.BatchNorm1d(512),\n            nn.ReLU(inplace=True),\n            nn.Linear(512, 1024),\n            nn.BatchNorm1d(1024),\n            nn.ReLU(inplace=True),\n            nn.Linear(1024, id_bottle_size)\n        )\n\n        self.pose_encoder = nn.Sequential(\n            nn.Linear(input_size, 256),\n            nn.BatchNorm1d(256),\n            nn.ReLU(inplace=True),\n            nn.Linear(256, 512),\n            nn.BatchNorm1d(512),\n            nn.ReLU(inplace=True),\n            nn.Linear(512, 1024),\n            nn.BatchNorm1d(1024),\n            nn.ReLU(inplace=True),\n            nn.Linear(1024, pose_bottle_size)\n        )\n\n        self.decoder = nn.Sequential(\n            nn.Linear(pose_bottle_size + id_bottle_size, 1024),\n            nn.BatchNorm1d(1024),\n            nn.ReLU(),\n            nn.Linear(1024, 512),\n            nn.BatchNorm1d(512),\n            nn.ReLU(),\n            nn.Linear(512, 256),\n            nn.BatchNorm1d(256),\n            nn.ReLU(),\n            nn.Linear(256, input_size)\n        )\n\n    @staticmethod\n    def region_params_to_emb(x):\n        mean = x['shift']\n        jac = x['affine']\n        emb = torch.cat([mean, jac.view(jac.shape[0], jac.shape[1], -1)], dim=-1)\n        emb = emb.view(emb.shape[0], -1)\n        return emb\n\n    def emb_to_region_params(self, emb):\n        emb = emb.view(emb.shape[0], self.num_regions, 6)\n        mean = emb[:, :, :2]\n        jac = emb[:, :, 2:].view(emb.shape[0], emb.shape[1], 2, 2)\n        return {'shift': mean, 'affine': jac}\n\n    def forward(self, x_id, x_pose, alpha=0.2):\n        if self.revert_axis_swap:\n            affine = torch.matmul(x_id['affine'], torch.inverse(x_pose['affine']))\n            sign = torch.sign(affine[:, :, 0:1, 0:1])\n            x_id = {'affine': x_id['affine'] * sign, 'shift': x_id['shift']}\n\n        pose_emb = self.pose_encoder(self.region_params_to_emb(x_pose))\n        id_emb = self.id_encoder(self.region_params_to_emb(x_id))\n\n        rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1))\n\n        rec = self.emb_to_region_params(rec)\n        rec['covar'] = torch.matmul(rec['affine'], rec['affine'].permute(0, 1, 3, 
2))\n        return rec\n"
  },
  {
    "path": "LFG/modules/bg_motion_predictor.py",
    "content": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\npublish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\nSuch code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\ntitle, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\nIn no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\"\"\"\n\nfrom torch import nn\nimport torch\nfrom LFG.modules.util import Encoder\n\n\nclass BGMotionPredictor(nn.Module):\n    \"\"\"\n    Module for background estimation, return single transformation, parametrized as 3x3 matrix.\n    \"\"\"\n\n    def __init__(self, block_expansion, num_channels, max_features, num_blocks, bg_type='zero'):\n        super(BGMotionPredictor, self).__init__()\n        assert bg_type in ['zero', 'shift', 'affine', 'perspective']\n\n        self.bg_type = bg_type\n        if self.bg_type != 'zero':\n            self.encoder = Encoder(block_expansion, in_features=num_channels * 2, max_features=max_features,\n                                   num_blocks=num_blocks)\n            in_features = min(max_features, block_expansion * (2 ** num_blocks))\n            if self.bg_type == 'perspective':\n                self.fc = nn.Linear(in_features, 8)\n                self.fc.weight.data.zero_()\n                self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0], dtype=torch.float))\n            elif self.bg_type == 'affine':\n                self.fc = nn.Linear(in_features, 6)\n                self.fc.weight.data.zero_()\n                self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))\n            elif self.bg_type == 'shift':\n                self.fc = nn.Linear(in_features, 2)\n                self.fc.weight.data.zero_()\n                self.fc.bias.data.copy_(torch.tensor([0, 0], dtype=torch.float))\n\n    def forward(self, source_image, driving_image):\n        bs = source_image.shape[0]\n        out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type())\n        if self.bg_type != 'zero':\n            prediction = self.encoder(torch.cat([source_image, driving_image], dim=1))\n            prediction = prediction[-1].mean(dim=(2, 3))\n            prediction = self.fc(prediction)\n            if self.bg_type == 'shift':\n                out[:, :2, 2] = prediction\n            elif self.bg_type == 'affine':\n                out[:, :2, :] = prediction.view(bs, 2, 3)\n            elif self.bg_type == 'perspective':\n                out[:, :2, :] = prediction[:, :6].view(bs, 2, 3)\n                out[:, 2, :2] = prediction[:, 6:].view(bs, 2)\n\n        return out\n"
  },
  {
    "path": "LFG/modules/flow_autoenc.py",
    "content": "# utilize RegionMM to design a flow auto-encoder\n\nimport torch\nimport torch.nn as nn\nimport sys\nsys.path.append('/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main')\nfrom LFG.modules.generator import Generator\nfrom LFG.modules.bg_motion_predictor import BGMotionPredictor\nfrom LFG.modules.region_predictor import RegionPredictor\nimport yaml\n\n\n# based on RegionMM\nclass FlowAE(nn.Module):\n    def __init__(self, is_train=False,\n                 config_pth=\"/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/config/mug128.yaml\"):\n        super(FlowAE, self).__init__()\n\n        with open(config_pth) as f:\n            config = yaml.safe_load(f)\n\n        self.generator = Generator(num_regions=config['model_params']['num_regions'],\n                                   num_channels=config['model_params']['num_channels'],\n                                   revert_axis_swap=config['model_params']['revert_axis_swap'],\n                                   **config['model_params']['generator_params']).cuda()\n        self.region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'],\n                                                num_channels=config['model_params']['num_channels'],\n                                                estimate_affine=config['model_params']['estimate_affine'],\n                                                **config['model_params']['region_predictor_params']).cuda()\n        self.bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'],\n                                              **config['model_params']['bg_predictor_params'])\n\n        self.is_train = is_train\n\n        self.ref_img = None\n        self.dri_img = None\n        self.generated = None\n\n    def forward(self):\n        source_region_params = self.region_predictor(self.ref_img)\n        self.driving_region_params = self.region_predictor(self.dri_img)\n\n        bg_params = self.bg_predictor(self.ref_img, self.dri_img)\n        self.generated = self.generator(self.ref_img, source_region_params=source_region_params,\n                                        driving_region_params=self.driving_region_params, bg_params=bg_params)\n        self.generated.update({'source_region_params': source_region_params,\n                               'driving_region_params': self.driving_region_params})\n\n    def set_train_input(self, ref_img, dri_img):\n        self.ref_img = ref_img.cuda()\n        self.dri_img = dri_img.cuda()\n\n\nif __name__ == \"__main__\":\n    # default image size is 128\n    # import os\n    # os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n    ref_img = torch.rand((5, 3, 128, 128), dtype=torch.float32)\n    dri_img = torch.rand((5, 3, 128, 128), dtype=torch.float32)\n    model = FlowAE(is_train=True).cuda()\n    model.train()\n    model.set_train_input(ref_img=ref_img, dri_img=dri_img)\n    model.forward()\n    print(\"___finihed___\")\n\n"
  },
  {
    "path": "LFG/modules/generator.py",
    "content": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\npublish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\nSuch code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\ntitle, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\nIn no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\"\"\"\nimport time\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport sys\nsys.path.append(\"/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main\")\nfrom LFG.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d\nfrom LFG.modules.pixelwise_flow_predictor import PixelwiseFlowPredictor\n\n\nclass Generator(nn.Module):\n    \"\"\"\n    Generator that given source image and region parameters try to transform image according to movement trajectories\n    induced by region parameters. Generator follows Johnson architecture.\n    \"\"\"\n\n    def __init__(self, num_channels, num_regions, block_expansion, max_features, num_down_blocks,\n                 num_bottleneck_blocks, pixelwise_flow_predictor_params=None, skips=False, revert_axis_swap=True):\n        super(Generator, self).__init__()\n\n        if pixelwise_flow_predictor_params is not None:\n            self.pixelwise_flow_predictor = PixelwiseFlowPredictor(num_regions=num_regions, num_channels=num_channels,\n                                                                   revert_axis_swap=revert_axis_swap,\n                                                                   **pixelwise_flow_predictor_params)\n        else:\n            self.pixelwise_flow_predictor = None\n\n        self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))\n\n        down_blocks = []\n        for i in range(num_down_blocks):\n            in_features = min(max_features, block_expansion * (2 ** i))\n            out_features = min(max_features, block_expansion * (2 ** (i + 1)))\n            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n        up_blocks = []\n        for i in range(num_down_blocks):\n            in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))\n            out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))\n            up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))\n        self.up_blocks = nn.ModuleList(up_blocks)\n\n        self.bottleneck = torch.nn.Sequential()\n        in_features = min(max_features, block_expansion * (2 ** num_down_blocks))\n        for i in range(num_bottleneck_blocks):\n            self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))\n\n        self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))\n        self.num_channels = num_channels\n        self.skips = skips\n\n    @staticmethod\n    def deform_input(inp, optical_flow):\n        _, h_old, w_old, _ = optical_flow.shape\n        _, _, h, w = inp.shape\n        if 
h_old != h or w_old != w:\n            optical_flow = optical_flow.permute(0, 3, 1, 2)\n            optical_flow = F.interpolate(optical_flow, size=(h, w), mode='bilinear')\n            optical_flow = optical_flow.permute(0, 2, 3, 1)\n        return F.grid_sample(inp, optical_flow)\n\n    def apply_optical(self, input_previous=None, input_skip=None, motion_params=None):\n        if motion_params is not None:\n            if 'occlusion_map' in motion_params:\n                occlusion_map = motion_params['occlusion_map']\n            else:\n                occlusion_map = None\n            deformation = motion_params['optical_flow']\n            input_skip = self.deform_input(input_skip, deformation)\n\n            if occlusion_map is not None:\n                if input_skip.shape[2] != occlusion_map.shape[2] or input_skip.shape[3] != occlusion_map.shape[3]:\n                    occlusion_map = F.interpolate(occlusion_map, size=input_skip.shape[2:], mode='bilinear')\n                if input_previous is not None:\n                    input_skip = input_skip * occlusion_map + input_previous * (1 - occlusion_map)\n                else:\n                    input_skip = input_skip * occlusion_map\n            out = input_skip\n        else:\n            out = input_previous if input_previous is not None else input_skip\n        return out\n\n    def forward(self, source_image, driving_region_params, source_region_params, bg_params=None):\n        out = self.first(source_image)\n        skips = [out]\n        for i in range(len(self.down_blocks)):\n            out = self.down_blocks[i](out)\n            skips.append(out)\n\n        output_dict = {}\n        output_dict[\"bottle_neck_feat\"] = out\n        if self.pixelwise_flow_predictor is not None:\n            motion_params = self.pixelwise_flow_predictor(source_image=source_image,\n                                                          driving_region_params=driving_region_params,\n                                                          source_region_params=source_region_params,\n                                                          bg_params=bg_params)\n            output_dict[\"deformed\"] = self.deform_input(source_image, motion_params['optical_flow'])\n            output_dict[\"optical_flow\"] = motion_params['optical_flow']\n            if 'occlusion_map' in motion_params:\n                output_dict['occlusion_map'] = motion_params['occlusion_map']\n        else:\n            motion_params = None\n\n        out = self.apply_optical(input_previous=None, input_skip=out, motion_params=motion_params)\n\n        out = self.bottleneck(out)\n        for i in range(len(self.up_blocks)):\n            if self.skips:\n                out = self.apply_optical(input_skip=skips[-(i + 1)], input_previous=out, motion_params=motion_params)\n            out = self.up_blocks[i](out)\n        if self.skips:\n            out = self.apply_optical(input_skip=skips[0], input_previous=out, motion_params=motion_params)\n        out = self.final(out)\n        out = torch.sigmoid(out)\n\n        if self.skips:\n            out = self.apply_optical(input_skip=source_image, input_previous=out, motion_params=motion_params)\n\n        output_dict[\"prediction\"] = out\n\n        return output_dict\n\n    def compute_fea(self, source_image):\n        out = self.first(source_image)\n        for i in range(len(self.down_blocks)):\n            out = self.down_blocks[i](out)\n        return out\n\n    def forward_with_flow(self, source_image, optical_flow, 
occlusion_map):\n        start_time = time.time()  # start timing\n        out = self.first(source_image)\n        end_time = time.time()  # end timing\n        # print(f'img fea extract time: {end_time - start_time}')\n        skips = [out]\n        for i in range(len(self.down_blocks)):\n            out = self.down_blocks[i](out)\n            skips.append(out)\n\n        output_dict = {}\n        motion_params = {}\n        motion_params[\"optical_flow\"] = optical_flow\n        motion_params[\"occlusion_map\"] = occlusion_map\n        output_dict[\"deformed\"] = self.deform_input(source_image, motion_params['optical_flow'])\n\n        out = self.apply_optical(input_previous=None, input_skip=out, motion_params=motion_params)\n\n        out = self.bottleneck(out)\n        for i in range(len(self.up_blocks)):\n            if self.skips:\n                out = self.apply_optical(input_skip=skips[-(i + 1)], input_previous=out, motion_params=motion_params)\n            out = self.up_blocks[i](out)\n        if self.skips:\n            out = self.apply_optical(input_skip=skips[0], input_previous=out, motion_params=motion_params)\n        out = self.final(out)\n        out = torch.sigmoid(out)\n\n        if self.skips:\n            out = self.apply_optical(input_skip=source_image, input_previous=out, motion_params=motion_params)\n\n        output_dict[\"prediction\"] = out\n\n        return output_dict\n"
  },
  {
    "path": "LFG/modules/model.py",
    "content": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\npublish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\nSuch code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\ntitle, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\nIn no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\"\"\"\n\nfrom torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom LFG.modules.util import AntiAliasInterpolation2d, make_coordinate_grid\nfrom torchvision import models\nimport numpy as np\nfrom torch.autograd import grad\n\n\nclass Vgg19(torch.nn.Module):\n    \"\"\"\n    Vgg19 network for perceptual loss.\n    \"\"\"\n\n    def __init__(self, requires_grad=False):\n        super(Vgg19, self).__init__()\n        vgg_pretrained_features = models.vgg19(pretrained=True).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        for x in range(2):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(2, 7):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(7, 12):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(12, 21):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(21, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n\n        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),\n                                       requires_grad=False)\n        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),\n                                      requires_grad=False)\n\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, x):\n        x = (x - self.mean) / self.std\n        h_relu1 = self.slice1(x)\n        h_relu2 = self.slice2(h_relu1)\n        h_relu3 = self.slice3(h_relu2)\n        h_relu4 = self.slice4(h_relu3)\n        h_relu5 = self.slice5(h_relu4)\n        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]\n        return out\n\n\nclass ImagePyramide(torch.nn.Module):\n    \"\"\"\n    Create image pyramide for computing pyramide perceptual loss.\n    \"\"\"\n\n    def __init__(self, scales, num_channels):\n        super(ImagePyramide, self).__init__()\n        downs = {}\n        for scale in scales:\n            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)\n        self.downs = nn.ModuleDict(downs)\n\n    def forward(self, x):\n        out_dict = {}\n        for scale, down_module in self.downs.items():\n            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)\n        return out_dict\n\n\nclass Transform:\n    \"\"\"\n    Random tps transformation for equivariance constraints.\n    \"\"\"\n\n    def 
__init__(self, bs, **kwargs):\n        noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))\n        self.theta = noise + torch.eye(2, 3).view(1, 2, 3)\n        self.bs = bs\n\n        if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):\n            self.tps = True\n            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']),\n                                                       type=self.theta.type())\n            self.control_points = self.control_points.unsqueeze(0)\n            self.control_params = torch.normal(mean=0,\n                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))\n        else:\n            self.tps = False\n\n    def transform_frame(self, frame):\n        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)\n        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)\n        grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)\n        return F.grid_sample(frame, grid, padding_mode=\"reflection\")\n\n    def warp_coordinates(self, coordinates):\n        theta = self.theta.type(coordinates.type())\n        theta = theta.unsqueeze(1)\n        transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]\n        transformed = transformed.squeeze(-1)\n\n        if self.tps:\n            control_points = self.control_points.type(coordinates.type())\n            control_params = self.control_params.type(coordinates.type())\n            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)\n            distances = torch.abs(distances).sum(-1)\n\n            # TODO this part may have bugs\n            result = distances ** 2\n            result = result * torch.log(distances + 1e-6)\n            result = result * control_params\n            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)\n            transformed = transformed + result\n\n        return transformed\n\n    def jacobian(self, coordinates):\n        new_coordinates = self.warp_coordinates(coordinates)\n        grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)\n        grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)\n        jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)\n        return jacobian\n\n\ndef detach_kp(kp):\n    return {key: value.detach() for key, value in kp.items()}\n\n\nclass ReconstructionModel(torch.nn.Module):\n    \"\"\"\n    Merge all updates into single model for better multi-gpu usage\n    \"\"\"\n\n    def __init__(self, region_predictor, bg_predictor, generator, train_params):\n        super(ReconstructionModel, self).__init__()\n        self.region_predictor = region_predictor\n        self.bg_predictor = bg_predictor\n        self.generator = generator\n        self.train_params = train_params\n        self.scales = train_params['scales']\n        self.pyramid = ImagePyramide(self.scales, generator.num_channels)\n        if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        self.loss_weights = train_params['loss_weights']\n\n        if sum(self.loss_weights['perceptual']) != 0:\n            self.vgg = Vgg19()\n            if torch.cuda.is_available():\n                self.vgg = self.vgg.cuda()\n\n    def forward(self, x):\n        source_region_params = 
self.region_predictor(x['source'])\n        driving_region_params = self.region_predictor(x['driving'])\n\n        bg_params = self.bg_predictor(x['source'], x['driving'])    #background\n        generated = self.generator(x['source'], source_region_params=source_region_params,\n                                   driving_region_params=driving_region_params, bg_params=bg_params)\n        generated.update({'source_region_params': source_region_params, 'driving_region_params': driving_region_params})\n\n        loss_values = {}\n\n        pyramide_real = self.pyramid(x['driving'])\n        pyramide_generated = self.pyramid(generated['prediction'])\n\n        if sum(self.loss_weights['perceptual']) != 0:\n            value_total = 0\n            for scale in self.scales:\n                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n                for i, weight in enumerate(self.loss_weights['perceptual']):\n                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n                    value_total += self.loss_weights['perceptual'][i] * value\n            loss_values['perceptual'] = value_total\n\n        if (self.loss_weights['equivariance_shift'] + self.loss_weights['equivariance_affine']) != 0:\n            transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])\n            transformed_frame = transform.transform_frame(x['driving'])\n            transformed_region_params = self.region_predictor(transformed_frame)\n\n            generated['transformed_frame'] = transformed_frame\n            generated['transformed_region_params'] = transformed_region_params\n\n            if self.loss_weights['equivariance_shift'] != 0:\n                value = torch.abs(driving_region_params['shift'] -\n                                  transform.warp_coordinates(transformed_region_params['shift'])).mean()\n                loss_values['equivariance_shift'] = self.loss_weights['equivariance_shift'] * value\n\n            if self.loss_weights['equivariance_affine'] != 0:\n                affine_transformed = torch.matmul(transform.jacobian(transformed_region_params['shift']),\n                                                  transformed_region_params['affine'])\n\n                normed_driving = torch.inverse(driving_region_params['affine'])\n                normed_transformed = affine_transformed\n                value = torch.matmul(normed_driving, normed_transformed)\n                eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())\n\n                if self.generator.pixelwise_flow_predictor.revert_axis_swap:\n                    value = value * torch.sign(value[:, :, 0:1, 0:1])\n\n                value = torch.abs(eye - value).mean()\n                loss_values['equivariance_affine'] = self.loss_weights['equivariance_affine'] * value\n\n        return loss_values, generated\n\n\n"
  },
  {
    "path": "LFG/modules/pixelwise_flow_predictor.py",
    "content": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\npublish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\nSuch code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\ntitle, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\nIn no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\"\"\"\n\nfrom torch import nn\nimport torch.nn.functional as F\nimport torch\nfrom LFG.modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, region2gaussian\nfrom LFG.modules.util import to_homogeneous, from_homogeneous\n\n\nclass PixelwiseFlowPredictor(nn.Module):\n    \"\"\"\n    Module that predicts a pixelwise flow from sparse motion representation given by\n    source_region_params and driving_region_params\n    \"\"\"\n\n    def __init__(self, block_expansion, num_blocks, max_features, num_regions, num_channels,\n                 estimate_occlusion_map=False, scale_factor=1, region_var=0.01,\n                 use_covar_heatmap=False, use_deformed_source=True, revert_axis_swap=False):\n        super(PixelwiseFlowPredictor, self).__init__()\n        self.hourglass = Hourglass(block_expansion=block_expansion,\n                                   in_features=(num_regions + 1) * (num_channels * use_deformed_source + 1),\n                                   max_features=max_features, num_blocks=num_blocks)\n\n        self.mask = nn.Conv2d(self.hourglass.out_filters, num_regions + 1, kernel_size=(7, 7), padding=(3, 3))\n\n        if estimate_occlusion_map:\n            self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))\n        else:\n            self.occlusion = None\n\n        self.num_regions = num_regions\n        self.scale_factor = scale_factor\n        self.region_var = region_var\n        self.use_covar_heatmap = use_covar_heatmap\n        self.use_deformed_source = use_deformed_source\n        self.revert_axis_swap = revert_axis_swap\n\n        if self.scale_factor != 1:\n            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)\n\n    def create_heatmap_representations(self, source_image, driving_region_params, source_region_params):\n        \"\"\"\n        Eq 6. 
in the paper H_k(z)\n        \"\"\"\n        spatial_size = source_image.shape[2:]\n        covar = self.region_var if not self.use_covar_heatmap else driving_region_params['covar']\n        gaussian_driving = region2gaussian(driving_region_params['shift'], covar=covar, spatial_size=spatial_size)\n        covar = self.region_var if not self.use_covar_heatmap else source_region_params['covar']\n        gaussian_source = region2gaussian(source_region_params['shift'], covar=covar, spatial_size=spatial_size)\n\n        heatmap = gaussian_driving - gaussian_source\n\n        # adding background feature\n        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1])\n        heatmap = torch.cat([zeros.type(heatmap.type()), heatmap], dim=1)\n        heatmap = heatmap.unsqueeze(2)\n        return heatmap\n\n    def create_sparse_motions(self, source_image, driving_region_params, source_region_params, bg_params=None):\n        bs, _, h, w = source_image.shape\n        identity_grid = make_coordinate_grid((h, w), type=source_region_params['shift'].type())\n        identity_grid = identity_grid.view(1, 1, h, w, 2)\n        coordinate_grid = identity_grid - driving_region_params['shift'].view(bs, self.num_regions, 1, 1, 2)\n        if 'affine' in driving_region_params:\n            affine = torch.matmul(source_region_params['affine'], torch.inverse(driving_region_params['affine'].float()))\n            if self.revert_axis_swap:\n                affine = affine * torch.sign(affine[:, :, 0:1, 0:1])\n            affine = affine.unsqueeze(-3).unsqueeze(-3)\n            affine = affine.repeat(1, 1, h, w, 1, 1)\n            coordinate_grid = torch.matmul(affine, coordinate_grid.unsqueeze(-1))\n            coordinate_grid = coordinate_grid.squeeze(-1)\n\n        driving_to_source = coordinate_grid + source_region_params['shift'].view(bs, self.num_regions, 1, 1, 2)\n\n        # adding background feature\n        if bg_params is None:\n            bg_grid = identity_grid.repeat(bs, 1, 1, 1, 1)\n        else:\n            bg_grid = identity_grid.repeat(bs, 1, 1, 1, 1)\n            bg_grid = to_homogeneous(bg_grid)\n            bg_grid = torch.matmul(bg_params.view(bs, 1, 1, 1, 3, 3), bg_grid.unsqueeze(-1)).squeeze(-1)\n            bg_grid = from_homogeneous(bg_grid)\n\n        sparse_motions = torch.cat([bg_grid, driving_to_source], dim=1)\n\n        return sparse_motions\n\n    def create_deformed_source_image(self, source_image, sparse_motions):\n        bs, _, h, w = source_image.shape\n        source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_regions + 1, 1, 1, 1, 1)\n        source_repeat = source_repeat.view(bs * (self.num_regions + 1), -1, h, w)\n        sparse_motions = sparse_motions.view((bs * (self.num_regions + 1), h, w, -1))\n        sparse_deformed = F.grid_sample(source_repeat, sparse_motions)\n        sparse_deformed = sparse_deformed.view((bs, self.num_regions + 1, -1, h, w))\n        return sparse_deformed\n\n    def forward(self, source_image, driving_region_params, source_region_params, bg_params=None):\n        if self.scale_factor != 1:\n            source_image = self.down(source_image)\n\n        bs, _, h, w = source_image.shape\n\n        out_dict = dict()\n        heatmap_representation = self.create_heatmap_representations(source_image, driving_region_params,\n                                                                     source_region_params)\n        sparse_motion = self.create_sparse_motions(source_image, driving_region_params,\n    
                                               source_region_params, bg_params=bg_params)\n        deformed_source = self.create_deformed_source_image(source_image, sparse_motion)\n        if self.use_deformed_source:\n            predictor_input = torch.cat([heatmap_representation, deformed_source], dim=2)\n        else:\n            predictor_input = heatmap_representation\n        predictor_input = predictor_input.view(bs, -1, h, w)\n\n        prediction = self.hourglass(predictor_input)\n\n        mask = self.mask(prediction)\n        mask = F.softmax(mask, dim=1)\n        mask = mask.unsqueeze(2)\n        sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)\n        deformation = (sparse_motion * mask).sum(dim=1)\n        deformation = deformation.permute(0, 2, 3, 1)\n\n        out_dict['optical_flow'] = deformation\n\n        if self.occlusion:\n            occlusion_map = torch.sigmoid(self.occlusion(prediction))\n            out_dict['occlusion_map'] = occlusion_map\n\n        return out_dict\n"
  },
  {
    "path": "LFG/modules/region_predictor.py",
    "content": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\npublish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\nSuch code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\ntitle, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\nIn no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\"\"\"\n\nfrom torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom LFG.modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d, Encoder\n\n\ndef svd(covar, fast=False):\n    if fast:\n        from torch_batch_svd import svd as fast_svd\n        return fast_svd(covar)\n    else:\n        u, s, v = torch.svd(covar.cpu())\n        s = s.to(covar.device)\n        u = u.to(covar.device)\n        v = v.to(covar.device)\n        return u, s, v\n\n\nclass RegionPredictor(nn.Module):\n    \"\"\"\n    Region estimating. Estimate affine parameters of the region.\n    \"\"\"\n\n    def __init__(self, block_expansion, num_regions, num_channels, max_features,\n                 num_blocks, temperature, estimate_affine=False, scale_factor=1,\n                 pca_based=False, fast_svd=False, pad=3):\n        super(RegionPredictor, self).__init__()\n        self.predictor = Hourglass(block_expansion, in_features=num_channels,\n                                   max_features=max_features, num_blocks=num_blocks)\n\n        self.regions = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_regions, kernel_size=(7, 7),\n                                 padding=pad)\n\n        # FOMM-like regression based representation\n        if estimate_affine and not pca_based:\n            self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,\n                                      out_channels=4, kernel_size=(7, 7), padding=pad)\n            self.jacobian.weight.data.zero_()\n            self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1], dtype=torch.float))\n        else:\n            self.jacobian = None\n\n        self.temperature = temperature\n        self.scale_factor = scale_factor\n        self.pca_based = pca_based\n        self.fast_svd = fast_svd\n\n        if self.scale_factor != 1:\n            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)\n\n    def region2affine(self, region):\n        shape = region.shape\n        region = region.unsqueeze(-1)\n        grid = make_coordinate_grid(shape[2:], region.type()).unsqueeze_(0).unsqueeze_(0)\n        mean = (region * grid).sum(dim=(2, 3))\n\n        region_params = {'shift': mean}\n\n        if self.pca_based:\n            mean_sub = grid - mean.unsqueeze(-2).unsqueeze(-2)\n            covar = torch.matmul(mean_sub.unsqueeze(-1), mean_sub.unsqueeze(-2))\n            covar = covar * region.unsqueeze(-1)\n            covar = covar.sum(dim=(2, 3))\n            region_params['covar'] = covar\n\n        return region_params\n\n    def forward(self, x):\n        if self.scale_factor != 1:\n            x = self.down(x)\n\n        feature_map = self.predictor(x)\n        prediction = self.regions(feature_map)\n\n        final_shape = prediction.shape\n        
region = prediction.view(final_shape[0], final_shape[1], -1)\n        region = F.softmax(region / self.temperature, dim=2)\n        region = region.view(*final_shape)\n\n        region_params = self.region2affine(region)\n        region_params['heatmap'] = region\n\n        # Regression-based estimation\n        if self.jacobian is not None:\n            jacobian_map = self.jacobian(feature_map)\n            jacobian_map = jacobian_map.reshape(final_shape[0], 1, 4, final_shape[2],\n                                                final_shape[3])\n            region = region.unsqueeze(2)\n\n            jacobian = region * jacobian_map\n            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)\n            jacobian = jacobian.sum(dim=-1)\n            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)\n            region_params['affine'] = jacobian\n            region_params['covar'] = torch.matmul(jacobian, jacobian.permute(0, 1, 3, 2))\n        elif self.pca_based:\n            covar = region_params['covar']\n            shape = covar.shape\n            covar = covar.view(-1, 2, 2)\n            u, s, v = svd(covar, self.fast_svd)\n            d = torch.diag_embed(s ** 0.5)\n            sqrt = torch.matmul(u, d)\n            sqrt = sqrt.view(*shape)\n            region_params['affine'] = sqrt\n            region_params['u'] = u\n            region_params['d'] = d\n\n        return region_params\n"
  },
  {
    "path": "LFG/modules/util.py",
    "content": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\npublish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\nSuch code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\ntitle, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\nIn no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\"\"\"\n\nfrom torch import nn\n\nimport torch.nn.functional as F\nimport torch\nfrom LFG.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\n\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom skimage.draw import disk as circle\nimport math\n\n\ndef region2gaussian(center, covar, spatial_size):\n    \"\"\"\n    Transform a region parameters into gaussian like heatmap\n    \"\"\"\n    mean = center\n\n    coordinate_grid = make_coordinate_grid(spatial_size, mean.type())\n    number_of_leading_dimensions = len(mean.shape) - 1\n    shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape\n    coordinate_grid = coordinate_grid.view(*shape)\n    repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)\n    coordinate_grid = coordinate_grid.repeat(*repeats)\n\n    # Preprocess kp shape\n    shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)\n    mean = mean.view(*shape)\n\n    mean_sub = (coordinate_grid - mean)\n    if type(covar) == float:\n        out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / covar)\n    else:\n        shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2, 2)\n        covar_inverse = torch.inverse(covar).view(*shape)\n        under_exp = torch.matmul(torch.matmul(mean_sub.unsqueeze(-2), covar_inverse), mean_sub.unsqueeze(-1))\n        out = torch.exp(-0.5 * under_exp.sum(dim=(-1, -2)))\n\n    return out\n\n\ndef make_coordinate_grid(spatial_size, type):\n    \"\"\"\n    Create a meshgrid [-1,1] x [-1,1] of given spatial_size.\n    \"\"\"\n    h, w = spatial_size\n    x = torch.arange(w).type(type)\n    y = torch.arange(h).type(type)\n\n    x = (2 * (x / (w - 1)) - 1)\n    y = (2 * (y / (h - 1)) - 1)\n\n    yy = y.view(-1, 1).repeat(1, w)\n    xx = x.view(1, -1).repeat(h, 1)\n\n    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)\n\n    return meshed\n\n\nclass ResBlock2d(nn.Module):\n    \"\"\"\n    Res block, preserve spatial resolution.\n    \"\"\"\n\n    def __init__(self, in_features, kernel_size, padding):\n        super(ResBlock2d, self).__init__()\n        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,\n                               padding=padding)\n        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,\n                               padding=padding)\n        self.norm1 = BatchNorm2d(in_features, affine=True)\n        self.norm2 = BatchNorm2d(in_features, affine=True)\n\n    def forward(self, x):\n        out = self.norm1(x)\n        out = F.relu(out)\n        out = self.conv1(out)\n        out = self.norm2(out)\n        out = F.relu(out)\n        out = self.conv2(out)\n        out += x\n        return out\n\n\nclass UpBlock2d(nn.Module):\n    \"\"\"\n    Upsampling block for 
use in decoder.\n    \"\"\"\n\n    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):\n        super(UpBlock2d, self).__init__()\n\n        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,\n                              padding=padding, groups=groups)\n        self.norm = BatchNorm2d(out_features, affine=True)\n\n    def forward(self, x):\n        out = F.interpolate(x, scale_factor=2)\n        out = self.conv(out)\n        out = self.norm(out)\n        out = F.relu(out)\n        return out\n\n\nclass DownBlock2d(nn.Module):\n    \"\"\"\n    Downsampling block for use in encoder.\n    \"\"\"\n\n    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):\n        super(DownBlock2d, self).__init__()\n        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,\n                              padding=padding, groups=groups)\n        self.norm = BatchNorm2d(out_features, affine=True)\n        self.pool = nn.AvgPool2d(kernel_size=(2, 2))\n\n    def forward(self, x):\n        out = self.conv(x)\n        out = self.norm(out)\n        out = F.relu(out)\n        out = self.pool(out)\n        return out\n\n\nclass SameBlock2d(nn.Module):\n    \"\"\"\n    Simple block, preserve spatial resolution.\n    \"\"\"\n\n    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):\n        super(SameBlock2d, self).__init__()\n        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,\n                              kernel_size=kernel_size, padding=padding, groups=groups)\n        self.norm = BatchNorm2d(out_features, affine=True)\n\n    def forward(self, x):\n        out = self.conv(x)\n        out = self.norm(out)\n        out = F.relu(out)\n        return out\n\n\nclass Encoder(nn.Module):\n    \"\"\"\n    Hourglass Encoder\n    \"\"\"\n\n    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):\n        super(Encoder, self).__init__()\n\n        down_blocks = []\n        for i in range(num_blocks):\n            down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),\n                                           min(max_features, block_expansion * (2 ** (i + 1))),\n                                           kernel_size=3, padding=1))\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n    def forward(self, x):\n        outs = [x]\n        for down_block in self.down_blocks:\n            outs.append(down_block(outs[-1]))\n        return outs\n\n\nclass Decoder(nn.Module):\n    \"\"\"\n    Hourglass Decoder\n    \"\"\"\n\n    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):\n        super(Decoder, self).__init__()\n\n        up_blocks = []\n\n        for i in range(num_blocks)[::-1]:\n            in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))\n            out_filters = min(max_features, block_expansion * (2 ** i))\n            up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))\n\n        self.up_blocks = nn.ModuleList(up_blocks)\n        self.out_filters = block_expansion + in_features\n\n    def forward(self, x):\n        out = x.pop()\n        for up_block in self.up_blocks:\n            out = up_block(out)\n            skip = x.pop()\n            out = torch.cat([out, skip], dim=1)\n        return out\n\n\nclass 
Hourglass(nn.Module):\n    \"\"\"\n    Hourglass architecture.\n    \"\"\"\n\n    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):\n        super(Hourglass, self).__init__()\n        self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)\n        self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)\n        self.out_filters = self.decoder.out_filters\n\n    def forward(self, x):\n        return self.decoder(self.encoder(x))\n\n\nclass AntiAliasInterpolation2d(nn.Module):\n    \"\"\"\n    Band-limited downsampling, for better preservation of the input signal.\n    \"\"\"\n\n    def __init__(self, channels, scale):\n        super(AntiAliasInterpolation2d, self).__init__()\n        sigma = (1 / scale - 1) / 2\n        kernel_size = 2 * round(sigma * 4) + 1\n        self.ka = kernel_size // 2\n        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka\n\n        kernel_size = [kernel_size, kernel_size]\n        sigma = [sigma, sigma]\n        # The gaussian kernel is the product of the\n        # gaussian function of each dimension.\n        kernel = 1\n        meshgrids = torch.meshgrid(\n            [\n                torch.arange(size, dtype=torch.float32)\n                for size in kernel_size\n            ]\n        )\n        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):\n            mean = (size - 1) / 2\n            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))\n\n        # Make sure sum of values in gaussian kernel equals 1.\n        kernel = kernel / torch.sum(kernel)\n        # Reshape to depthwise convolutional weight\n        kernel = kernel.view(1, 1, *kernel.size())\n        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))\n\n        self.register_buffer('weight', kernel)\n        self.groups = channels\n        self.scale = scale\n        inv_scale = 1 / scale\n        self.int_inv_scale = int(inv_scale)\n\n    def forward(self, input):\n        if self.scale == 1.0:\n            return input\n\n        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))\n        out = F.conv2d(out, weight=self.weight, groups=self.groups)\n        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]\n\n        return out\n\n\ndef to_homogeneous(coordinates):\n    ones_shape = list(coordinates.shape)\n    ones_shape[-1] = 1\n    ones = torch.ones(ones_shape).type(coordinates.type())\n\n    return torch.cat([coordinates, ones], dim=-1)\n\n\ndef from_homogeneous(coordinates):\n    return coordinates[..., :2] / coordinates[..., 2:3]\n\n\ndef draw_colored_heatmap(heatmap, colormap, bg_color):\n    parts = []\n    weights = []\n    bg_color = np.array(bg_color).reshape((1, 1, 1, 3))\n    num_regions = heatmap.shape[-1]\n    for i in range(num_regions):\n        color = np.array(colormap(i / num_regions))[:3]\n        color = color.reshape((1, 1, 1, 3))\n        part = heatmap[:, :, :, i:(i + 1)]\n        part = part / np.max(part, axis=(1, 2), keepdims=True)\n        weights.append(part)\n\n        color_part = part * color\n        parts.append(color_part)\n\n    weight = sum(weights)\n    bg_weight = 1 - np.minimum(1, weight)\n    weight = np.maximum(1, weight)\n    result = sum(parts) / weight + bg_weight * bg_color\n    return result\n\n\nclass Visualizer:\n    def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow', region_bg_color=(0, 0, 0)):\n        self.kp_size = kp_size\n        self.draw_border = draw_border\n        self.colormap = 
plt.get_cmap(colormap)\n        self.region_bg_color = np.array(region_bg_color)\n\n    def draw_image_with_kp(self, image, kp_array):\n        image = np.copy(image)\n        spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]\n        kp_array = spatial_size * (kp_array + 1) / 2\n        num_regions = kp_array.shape[0]\n        for kp_ind, kp in enumerate(kp_array):\n            rr, cc = circle((kp[1], kp[0]), self.kp_size, shape=image.shape[:2])\n            image[rr, cc] = np.array(self.colormap(kp_ind / num_regions))[:3]\n        return image\n\n    def create_image_column_with_kp(self, images, kp):\n        image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])\n        return self.create_image_column(image_array)\n\n    def create_image_column(self, images):\n        if self.draw_border:\n            images = np.copy(images)\n            images[:, :, [0, -1]] = (1, 1, 1)\n            images[:, :, [0, -1]] = (1, 1, 1)\n        return np.concatenate(list(images), axis=0)\n\n    def create_image_grid(self, *args):\n        out = []\n        for arg in args:\n            if type(arg) == tuple:\n                out.append(self.create_image_column_with_kp(arg[0], arg[1]))\n            else:\n                out.append(self.create_image_column(arg))\n        return np.concatenate(out, axis=1)\n\n    @staticmethod\n    def sample(x, index):\n        return x[index].unsqueeze(dim=0).clone().detach()\n\n    def visualize(self, driving, source, out, index=0):\n        images = []\n\n        # Source image with region centers\n        source = self.sample(source, index)\n        source = source.data.cpu()\n        source_region_params = self.sample(out['source_region_params']['shift'], index)\n        source_region_params = source_region_params.data.cpu().numpy()\n        source = np.transpose(source, [0, 2, 3, 1])\n        images.append((source, source_region_params))\n\n        if 'heatmap' in out['source_region_params']:\n            source_heatmap = self.sample(out['source_region_params']['heatmap'], index)\n            source_heatmap = F.interpolate(source_heatmap, size=source.shape[1:3])\n            source_heatmap = np.transpose(source_heatmap.data.cpu().numpy(), [0, 2, 3, 1])\n            images.append(draw_colored_heatmap(source_heatmap, self.colormap, self.region_bg_color))\n\n        # Deformed image\n        if 'deformed' in out:\n            deformed = self.sample(out['deformed'], index)\n            deformed = deformed.data.cpu().numpy()\n            deformed = np.transpose(deformed, [0, 2, 3, 1])\n            images.append(deformed)\n\n        # Equivariance visualization\n        if 'transformed_frame' in out:\n            transformed = self.sample(out['transformed_frame'], index)\n            transformed = transformed.data.cpu().numpy()\n            transformed = np.transpose(transformed, [0, 2, 3, 1])\n            transformed_kp = self.sample(out['transformed_region_params']['shift'], index)\n            transformed_kp = transformed_kp.data.cpu().numpy()\n            images.append((transformed, transformed_kp))\n\n        # Driving image with region centers\n        driving_region_params = self.sample(out['driving_region_params']['shift'], index)\n        driving_region_params = driving_region_params.data.cpu().numpy()\n        driving = self.sample(driving, index)\n        driving = driving.data.cpu().numpy()\n        driving = np.transpose(driving, [0, 2, 3, 1])\n        images.append((driving, driving_region_params))\n\n        # Heatmaps 
visualizations\n        if 'heatmap' in out['driving_region_params']:\n            driving_heatmap = self.sample(out['driving_region_params']['heatmap'], index)\n            driving_heatmap = F.interpolate(driving_heatmap, size=source.shape[1:3])\n            driving_heatmap = np.transpose(driving_heatmap.data.cpu().numpy(), [0, 2, 3, 1])\n            images.append(draw_colored_heatmap(driving_heatmap, self.colormap, self.region_bg_color))\n\n        # Result\n        prediction = self.sample(out['prediction'], index)\n        prediction = prediction.data.cpu().numpy()\n        prediction = np.transpose(prediction, [0, 2, 3, 1])\n        images.append(prediction)\n\n        # Occlusion map\n        if 'occlusion_map' in out:\n            occlusion_map = self.sample(out['occlusion_map'], index)\n            occlusion_map = occlusion_map.data.cpu().repeat(1, 3, 1, 1)\n            occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()\n            occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])\n            images.append(occlusion_map)\n\n        image = self.create_image_grid(*images)\n        # reshape (1, 8) to (2, 4)\n        H, W, _ = image.shape\n        base = H\n        num_image = W // H\n        row, col = 2, math.ceil(num_image/2)\n        new_image = np.zeros((row*base, col*base, 3), dtype=np.float32)\n        cnt = 0\n        for ii in range(row):\n            for jj in range(col):\n                try:\n                    new_image[ii*base:(ii+1)*base, jj*base:(jj+1)*base, :] = image[:, cnt*base:(cnt+1)*base, :]\n                except:\n                    pass\n                cnt += 1\n        new_image = (255 * new_image).astype(np.uint8)\n        return new_image\n\n"
  },
  {
    "path": "LFG/run_hdtf.py",
    "content": "# Estimate flow and occlusion mask via RegionMM (or called MRAA) for MUG dataset\n# this code is based on RegionMM from Snap Inc.\n# https://github.com/snap-research/articulated-animation\n\nimport os\nimport sys\nsys.path.append(\"your/path/DAWN-pytorch\")  # change this path to your current work directory\nimport math\nimport yaml\nfrom argparse import ArgumentParser\nfrom shutil import copy\nfrom datetime import datetime\n\nfrom LFG.hdtf_dataset import FramesDataset\n\nfrom LFG.modules.generator import Generator\nfrom LFG.modules.bg_motion_predictor import BGMotionPredictor\nfrom LFG.modules.region_predictor import RegionPredictor\nfrom LFG.modules.avd_network import AVDNetwork\n\nimport torch\nimport torch.backends.cudnn as cudnn\nimport numpy as np\nimport random\n\nfrom LFG.train import train\n\n\nclass Logger(object):\n    def __init__(self, filename='default.log', stream=sys.stdout):\n        self.terminal = stream\n        self.log = open(filename, 'w')\n\n    def write(self, message):\n        self.terminal.write(message)\n        self.log.write(message)\n\n    def flush(self):\n        pass\n\n\ndef setup_seed(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == \"__main__\":\n    # os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n    cudnn.enabled = True\n    cudnn.benchmark = True\n\n    if sys.version_info[0] < 3:\n        raise Exception(\"You must use Python 3 or higher. Recommended version is Python 3.7\")\n\n    parser = ArgumentParser()\n    parser.add_argument(\"--postfix\", default=\"\")  # indicate different settings\n    parser.add_argument(\"--random-seed\", default=1234)\n    parser.add_argument(\"--set-start\", default=False)\n    parser.add_argument(\"--config\",\n                        default=\"your/path/DAWN-pytorch/config/hdtf128_llm.yaml\",\n                        help=\"path to config\")\n    parser.add_argument(\"--mode\", default=\"train\", choices=[\"train\"])\n    parser.add_argument(\"--log_dir\",\n                        default='your/path/DAWN-pytorch/AE/data/log-hdtf',\n                        help=\"path to log into\")\n    parser.add_argument(\"--checkpoint\",  # use the pretrained VOX model given by Snap\n                        default=\"/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/data/ckp/vox256.pth\",\n                        help=\"path to checkpoint to restore\")\n    parser.add_argument(\"--device_ids\", default=\"0\", type=lambda x: list(map(int, x.split(','))),\n                        help=\"Names of the devices comma separated.\")\n    parser.add_argument(\"--verbose\", dest=\"verbose\", default=False, help=\"Print model architecture\")\n    parser.set_defaults(verbose=False)\n\n    opt = parser.parse_args()\n\n    setup_seed(opt.random_seed)\n\n    with open(opt.config) as f:\n        config = yaml.safe_load(f)\n\n    current_time = datetime.now()\n    current_time = current_time.strftime(\"%Y-%m-%d_%H:%M\")\n    log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]+opt.postfix+'_'+current_time)\n    if not os.path.exists(log_dir):\n        os.makedirs(log_dir)\n    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):\n        copy(opt.config, log_dir)\n\n    # the directory to save checkpoints\n    config[\"snapshots\"] = os.path.join(log_dir, 'snapshots'+opt.postfix)\n    os.makedirs(config[\"snapshots\"], exist_ok=True)\n    
# the directory to save images of training results\n    config[\"imgshots\"] = os.path.join(log_dir, 'imgshots'+opt.postfix)\n    os.makedirs(config[\"imgshots\"], exist_ok=True)\n    config[\"set_start\"] = opt.set_start\n    log_txt = os.path.join(log_dir,\n                           \"B\"+format(config['train_params']['batch_size'], \"04d\")+\n                           \"E\"+format(config['train_params']['max_epochs'], \"04d\")+\".log\")\n    sys.stdout = Logger(log_txt, sys.stdout)\n\n    print(\"postfix:\", opt.postfix)\n    print(\"checkpoint:\", opt.checkpoint)\n    print(\"batch size:\", config['train_params']['batch_size'])\n\n    generator = Generator(num_regions=config['model_params']['num_regions'],\n                          num_channels=config['model_params']['num_channels'],\n                          revert_axis_swap=config['model_params']['revert_axis_swap'],\n                          **config['model_params']['generator_params'])\n\n    if torch.cuda.is_available():\n        generator.to(opt.device_ids[0])\n    if opt.verbose:\n        print(generator)\n\n    region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'],\n                                       num_channels=config['model_params']['num_channels'],\n                                       estimate_affine=config['model_params']['estimate_affine'],\n                                       **config['model_params']['region_predictor_params'])\n\n    if torch.cuda.is_available():\n        region_predictor.to(opt.device_ids[0])\n\n    if opt.verbose:\n        print(region_predictor)\n\n    bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'],\n                                     **config['model_params']['bg_predictor_params'])\n    if torch.cuda.is_available():\n        bg_predictor.to(opt.device_ids[0])\n    if opt.verbose:\n        print(bg_predictor)\n\n    avd_network = AVDNetwork(num_regions=config['model_params']['num_regions'],\n                             **config['model_params']['avd_network_params'])\n    if torch.cuda.is_available():\n        avd_network.to(opt.device_ids[0])\n    if opt.verbose:\n        print(avd_network)\n\n    dataset = FramesDataset(**config['dataset_params'])\n    config[\"num_example_per_epoch\"] = config['train_params']['num_repeats'] * len(dataset)\n    config[\"num_step_per_epoch\"] = math.ceil(config[\"num_example_per_epoch\"]/float(config['train_params']['batch_size']))\n    # save 10 checkpoints in total\n    config[\"save_ckpt_freq\"] = config[\"num_step_per_epoch\"] * (config['train_params']['max_epochs'] // 10)\n    print(\"save ckpt freq:\", config[\"save_ckpt_freq\"])\n\n    print(\"Training...\")\n    train(config, generator, region_predictor, bg_predictor, opt.checkpoint, log_dir, dataset, opt.device_ids)\n\n"
  },
  {
    "path": "LFG/run_hdtf_crema.py",
    "content": "# Estimate flow and occlusion mask via RegionMM (or called MRAA) for MUG dataset\n# this code is based on RegionMM from Snap Inc.\n# https://github.com/snap-research/articulated-animation\n\nimport os\nimport sys\nsys.path.append(\"your/path/DAWN-pytorch\")  # change this path to your current work directory\nimport math\nimport yaml\nfrom argparse import ArgumentParser\nfrom shutil import copy\nfrom datetime import datetime\n\nfrom LFG.frames_dataset import FramesDataset\n\nfrom LFG.modules.generator import Generator\nfrom LFG.modules.bg_motion_predictor import BGMotionPredictor\nfrom LFG.modules.region_predictor import RegionPredictor\nfrom LFG.modules.avd_network import AVDNetwork\n\nimport torch\nimport torch.backends.cudnn as cudnn\nimport numpy as np\nimport random\n\nfrom LFG.train import train\n\n\nclass Logger(object):\n    def __init__(self, filename='default.log', stream=sys.stdout):\n        self.terminal = stream\n        self.log = open(filename, 'w')\n\n    def write(self, message):\n        self.terminal.write(message)\n        self.log.write(message)\n\n    def flush(self):\n        pass\n\n\ndef setup_seed(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == \"__main__\":\n    # os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n    cudnn.enabled = True\n    cudnn.benchmark = True\n\n    if sys.version_info[0] < 3:\n        raise Exception(\"You must use Python 3 or higher. Recommended version is Python 3.7\")\n\n    parser = ArgumentParser()\n    parser.add_argument(\"--postfix\", default=\"\")  # indicate different settings\n    parser.add_argument(\"--random-seed\", default=1234)\n    parser.add_argument(\"--set-start\", default=False)\n    parser.add_argument(\"--config\",\n                        default=\"your/path/DAWN-pytorch/config/hdtf128_llm.yaml\",\n                        help=\"path to config\")\n    parser.add_argument(\"--mode\", default=\"train\", choices=[\"train\"])\n    parser.add_argument(\"--log_dir\",\n                        default='your/path/DAWN-pytorch/AE/data/log-hdtf',\n                        help=\"path to log into\")\n    parser.add_argument(\"--checkpoint\",  # use the pretrained VOX model given by Snap\n                        default=\"/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/data/ckp/vox256.pth\",\n                        help=\"path to checkpoint to restore\")\n    parser.add_argument(\"--device_ids\", default=\"0\", type=lambda x: list(map(int, x.split(','))),\n                        help=\"Names of the devices comma separated.\")\n    parser.add_argument(\"--verbose\", dest=\"verbose\", default=False, help=\"Print model architecture\")\n    parser.set_defaults(verbose=False)\n\n    opt = parser.parse_args()\n\n    setup_seed(opt.random_seed)\n\n    with open(opt.config) as f:\n        config = yaml.safe_load(f)\n\n    current_time = datetime.now()\n    current_time = current_time.strftime(\"%Y-%m-%d_%H:%M\")\n    log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]+opt.postfix+'_'+current_time)\n    if not os.path.exists(log_dir):\n        os.makedirs(log_dir)\n    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):\n        copy(opt.config, log_dir)\n\n    # the directory to save checkpoints\n    config[\"snapshots\"] = os.path.join(log_dir, 'snapshots'+opt.postfix)\n    os.makedirs(config[\"snapshots\"], exist_ok=True)\n   
 # the directory to save images of training results\n    config[\"imgshots\"] = os.path.join(log_dir, 'imgshots'+opt.postfix)\n    os.makedirs(config[\"imgshots\"], exist_ok=True)\n    config[\"set_start\"] = opt.set_start\n    log_txt = os.path.join(log_dir,\n                           \"B\"+format(config['train_params']['batch_size'], \"04d\")+\n                           \"E\"+format(config['train_params']['max_epochs'], \"04d\")+\".log\")\n    sys.stdout = Logger(log_txt, sys.stdout)\n\n    print(\"postfix:\", opt.postfix)\n    print(\"checkpoint:\", opt.checkpoint)\n    print(\"batch size:\", config['train_params']['batch_size'])\n\n    generator = Generator(num_regions=config['model_params']['num_regions'],\n                          num_channels=config['model_params']['num_channels'],\n                          revert_axis_swap=config['model_params']['revert_axis_swap'],\n                          **config['model_params']['generator_params'])\n\n    if torch.cuda.is_available():\n        generator.to(opt.device_ids[0])\n    if opt.verbose:\n        print(generator)\n\n    region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'],\n                                       num_channels=config['model_params']['num_channels'],\n                                       estimate_affine=config['model_params']['estimate_affine'],\n                                       **config['model_params']['region_predictor_params'])\n\n    if torch.cuda.is_available():\n        region_predictor.to(opt.device_ids[0])\n\n    if opt.verbose:\n        print(region_predictor)\n\n    bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'],\n                                     **config['model_params']['bg_predictor_params'])\n    if torch.cuda.is_available():\n        bg_predictor.to(opt.device_ids[0])\n    if opt.verbose:\n        print(bg_predictor)\n\n    avd_network = AVDNetwork(num_regions=config['model_params']['num_regions'],\n                             **config['model_params']['avd_network_params'])\n    if torch.cuda.is_available():\n        avd_network.to(opt.device_ids[0])\n    if opt.verbose:\n        print(avd_network)\n\n    dataset = FramesDataset(**config['dataset_params'])\n    config[\"num_example_per_epoch\"] = config['train_params']['num_repeats'] * len(dataset)\n    config[\"num_step_per_epoch\"] = math.ceil(config[\"num_example_per_epoch\"]/float(config['train_params']['batch_size']))\n    # save 10 checkpoints in total\n    config[\"save_ckpt_freq\"] = config[\"num_step_per_epoch\"] * (config['train_params']['max_epochs'] // 10)\n    print(\"save ckpt freq:\", config[\"save_ckpt_freq\"])\n\n    print(\"Training...\")\n    train(config, generator, region_predictor, bg_predictor, opt.checkpoint, log_dir, dataset, opt.device_ids)\n\n"
  },
  {
    "path": "LFG/sync_batchnorm/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : __init__.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nfrom .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d\nfrom .replicate import DataParallelWithCallback, patch_replication_callback\n"
  },
  {
    "path": "LFG/sync_batchnorm/batchnorm.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : batchnorm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport collections\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast\n\nfrom .comm import SyncMaster\n\n__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']\n\n\ndef _sum_ft(tensor):\n    \"\"\"sum over the first and last dimention\"\"\"\n    return tensor.sum(dim=0).sum(dim=-1)\n\n\ndef _unsqueeze_ft(tensor):\n    \"\"\"add new dementions at the front and the tail\"\"\"\n    return tensor.unsqueeze(0).unsqueeze(-1)\n\n\n_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])\n_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])\n\n\nclass _SynchronizedBatchNorm(_BatchNorm):\n    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):\n        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)\n\n        self._sync_master = SyncMaster(self._data_parallel_master)\n\n        self._is_parallel = False\n        self._parallel_id = None\n        self._slave_pipe = None\n\n    def forward(self, input):\n        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.\n        if not (self._is_parallel and self.training):\n            return F.batch_norm(\n                input, self.running_mean, self.running_var, self.weight, self.bias,\n                self.training, self.momentum, self.eps)\n\n        # Resize the input to (B, C, -1).\n        input_shape = input.size()\n        input = input.view(input.size(0), self.num_features, -1)\n\n        # Compute the sum and square-sum.\n        sum_size = input.size(0) * input.size(2)\n        input_sum = _sum_ft(input)\n        input_ssum = _sum_ft(input ** 2)\n\n        # Reduce-and-broadcast the statistics.\n        if self._parallel_id == 0:\n            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))\n        else:\n            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))\n\n        # Compute the output.\n        if self.affine:\n            # MJY:: Fuse the multiplication for speed.\n            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)\n        else:\n            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)\n\n        # Reshape it.\n        return output.view(input_shape)\n\n    def __data_parallel_replicate__(self, ctx, copy_id):\n        self._is_parallel = True\n        self._parallel_id = copy_id\n\n        # parallel_id == 0 means master device.\n        if self._parallel_id == 0:\n            ctx.sync_master = self._sync_master\n        else:\n            self._slave_pipe = ctx.sync_master.register_slave(copy_id)\n\n    def _data_parallel_master(self, intermediates):\n        \"\"\"Reduce the sum and square-sum, compute the statistics, and broadcast it.\"\"\"\n\n        # Always using same \"device order\" makes the ReduceAdd operation faster.\n        # Thanks to:: Tete Xiao (http://tetexiao.com/)\n        intermediates = sorted(intermediates, key=lambda i: 
i[1].sum.get_device())\n\n        to_reduce = [i[1][:2] for i in intermediates]\n        to_reduce = [j for i in to_reduce for j in i]  # flatten\n        target_gpus = [i[1].sum.get_device() for i in intermediates]\n\n        sum_size = sum([i[1].sum_size for i in intermediates])\n        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)\n        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)\n\n        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)\n\n        outputs = []\n        for i, rec in enumerate(intermediates):\n            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))\n\n        return outputs\n\n    def _compute_mean_std(self, sum_, ssum, size):\n        \"\"\"Compute the mean and standard-deviation with sum and square-sum. This method\n        also maintains the moving average on the master device.\"\"\"\n        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'\n        mean = sum_ / size\n        sumvar = ssum - sum_ * mean\n        unbias_var = sumvar / (size - 1)\n        bias_var = sumvar / size\n\n        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data\n        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data\n\n        return mean, bias_var.clamp(self.eps) ** -0.5\n\n\nclass SynchronizedBatchNorm1d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a\n    mini-batch.\n\n    .. math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm1d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of size\n            `batch_size x num_features [x width]`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. 
Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C)` or :math:`(N, C, L)`\n        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm1d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm1d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 2 and input.dim() != 3:\n            raise ValueError('expected 2D or 3D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)\n\n\nclass SynchronizedBatchNorm2d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Batch Normalization over a 4d input that is seen as a mini-batch\n    of 3d inputs\n\n    .. math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm2d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of\n            size batch_size x num_features x height x width\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. 
Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, H, W)`\n        - Output: :math:`(N, C, H, W)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm2d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm2d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 4:\n            raise ValueError('expected 4D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)\n\n\nclass SynchronizedBatchNorm3d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Batch Normalization over a 5d input that is seen as a mini-batch\n    of 4d inputs\n\n    .. math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm3d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm\n    or Spatio-temporal BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of\n            size batch_size x num_features x depth x height x width\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, D, H, W)`\n        - Output: :math:`(N, C, D, H, W)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm3d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm3d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 5:\n            raise ValueError('expected 5D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)\n"
  },
  {
    "path": "LFG/sync_batchnorm/comm.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : comm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport queue\nimport collections\nimport threading\n\n__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']\n\n\nclass FutureResult(object):\n    \"\"\"A thread-safe future implementation. Used only as one-to-one pipe.\"\"\"\n\n    def __init__(self):\n        self._result = None\n        self._lock = threading.Lock()\n        self._cond = threading.Condition(self._lock)\n\n    def put(self, result):\n        with self._lock:\n            assert self._result is None, 'Previous result has\\'t been fetched.'\n            self._result = result\n            self._cond.notify()\n\n    def get(self):\n        with self._lock:\n            if self._result is None:\n                self._cond.wait()\n\n            res = self._result\n            self._result = None\n            return res\n\n\n_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])\n_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])\n\n\nclass SlavePipe(_SlavePipeBase):\n    \"\"\"Pipe for master-slave communication.\"\"\"\n\n    def run_slave(self, msg):\n        self.queue.put((self.identifier, msg))\n        ret = self.result.get()\n        self.queue.put(True)\n        return ret\n\n\nclass SyncMaster(object):\n    \"\"\"An abstract `SyncMaster` object.\n\n    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should\n    call `register(id)` and obtain an `SlavePipe` to communicate with the master.\n    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,\n    and passed to a registered callback.\n    - After receiving the messages, the master device should gather the information and determine to message passed\n    back to each slave devices.\n    \"\"\"\n\n    def __init__(self, master_callback):\n        \"\"\"\n\n        Args:\n            master_callback: a callback to be invoked after having collected messages from slave devices.\n        \"\"\"\n        self._master_callback = master_callback\n        self._queue = queue.Queue()\n        self._registry = collections.OrderedDict()\n        self._activated = False\n\n    def __getstate__(self):\n        return {'master_callback': self._master_callback}\n\n    def __setstate__(self, state):\n        self.__init__(state['master_callback'])\n\n    def register_slave(self, identifier):\n        \"\"\"\n        Register an slave device.\n\n        Args:\n            identifier: an identifier, usually is the device id.\n\n        Returns: a `SlavePipe` object which can be used to communicate with the master device.\n\n        \"\"\"\n        if self._activated:\n            assert self._queue.empty(), 'Queue is not clean before next initialization.'\n            self._activated = False\n            self._registry.clear()\n        future = FutureResult()\n        self._registry[identifier] = _MasterRegistry(future)\n        return SlavePipe(identifier, self._queue, future)\n\n    def run_master(self, master_msg):\n        \"\"\"\n        Main entry for the master device in each forward pass.\n        The messages were first collected from each devices (including the master device), and then\n        an callback will 
be invoked to compute the message to be sent back to each device\n        (including the master device).\n\n        Args:\n            master_msg: the message that the master wants to send to itself. This will be placed as the first\n            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.\n\n        Returns: the message to be sent back to the master device.\n\n        \"\"\"\n        self._activated = True\n\n        intermediates = [(0, master_msg)]\n        for i in range(self.nr_slaves):\n            intermediates.append(self._queue.get())\n\n        results = self._master_callback(intermediates)\n        assert results[0][0] == 0, 'The first result should belong to the master.'\n\n        for i, res in results:\n            if i == 0:\n                continue\n            self._registry[i].result.put(res)\n\n        for i in range(self.nr_slaves):\n            assert self._queue.get() is True\n\n        return results[0][1]\n\n    @property\n    def nr_slaves(self):\n        return len(self._registry)\n"
  },
  {
    "path": "LFG/sync_batchnorm/replicate.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : replicate.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport functools\n\nfrom torch.nn.parallel.data_parallel import DataParallel\n\n__all__ = [\n    'CallbackContext',\n    'execute_replication_callbacks',\n    'DataParallelWithCallback',\n    'patch_replication_callback'\n]\n\n\nclass CallbackContext(object):\n    pass\n\n\ndef execute_replication_callbacks(modules):\n    \"\"\"\n    Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.\n\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Note that, as all modules are isomorphism, we assign each sub-module with a context\n    (shared among multiple copies of this module on different devices).\n    Through this context, different copies can share some information.\n\n    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback\n    of any slave copies.\n    \"\"\"\n    master_copy = modules[0]\n    nr_modules = len(list(master_copy.modules()))\n    ctxs = [CallbackContext() for _ in range(nr_modules)]\n\n    for i, module in enumerate(modules):\n        for j, m in enumerate(module.modules()):\n            if hasattr(m, '__data_parallel_replicate__'):\n                m.__data_parallel_replicate__(ctxs[j], i)\n\n\nclass DataParallelWithCallback(DataParallel):\n    \"\"\"\n    Data Parallel with a replication callback.\n\n    An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by\n    original `replicate` function.\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n        # sync_bn.__data_parallel_replicate__ will be invoked.\n    \"\"\"\n\n    def replicate(self, module, device_ids):\n        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n\ndef patch_replication_callback(data_parallel):\n    \"\"\"\n    Monkey-patch an existing `DataParallel` object. Add the replication callback.\n    Useful when you have customized `DataParallel` implementation.\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])\n        > patch_replication_callback(sync_bn)\n        # this is equivalent to\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n    \"\"\"\n\n    assert isinstance(data_parallel, DataParallel)\n\n    old_replicate = data_parallel.replicate\n\n    @functools.wraps(old_replicate)\n    def new_replicate(module, device_ids):\n        modules = old_replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n    data_parallel.replicate = new_replicate\n"
  },
  {
    "path": "LFG/sync_batchnorm/unittest.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : unittest.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport unittest\n\nimport numpy as np\nfrom torch.autograd import Variable\n\n\ndef as_numpy(v):\n    if isinstance(v, Variable):\n        v = v.data\n    return v.cpu().numpy()\n\n\nclass TorchTestCase(unittest.TestCase):\n    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):\n        npa, npb = as_numpy(a), as_numpy(b)\n        self.assertTrue(\n                np.allclose(npa, npb, atol=atol),\n                'Tensor close check failed\\n{}\\n{}\\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())\n        )\n"
  },
  {
    "path": "LFG/test_flowautoenc_crema_video.py",
    "content": "# use LFG to reconstruct testing videos and measure the loss in video domain\n# using RegionMM\n\nimport argparse\nimport imageio\nimport torch\nfrom torch.utils import data\nimport numpy as np\nimport torch.backends.cudnn as cudnn\nimport os\nimport timeit\nfrom PIL import Image\nimport sys\nsys.path.append(\"your/path/DAWN-pytorch\")\nfrom misc import grid2fig\nfrom DM_2.datasets_crema_wpose_lmk_block import HDTF\nimport random\nfrom LFG.modules.flow_autoenc import FlowAE\nimport torch.nn.functional as F\nfrom LFG.modules.util import Visualizer\nimport json_tricks as json\nimport cv2\nimport tempfile\nfrom subprocess import call\nfrom pydub import AudioSegment\nfrom einops import rearrange\nfrom tqdm import tqdm\n\nstart = timeit.default_timer()\nBATCH_SIZE = 1\nINPUT_SIZE = 128\nroot_dir = 'your/path/DAWN-pytorch/AE'  # your work directory\ndata_dir = \"/train20/intern/permanent/hbcheng2/data/crema/images_25hz_128_chunk\"\npose_dir = '/train20/intern/permanent/hbcheng2/data/crema/pose_bar_chunk'\neye_blink_dir = '/train20/intern/permanent/hbcheng2/data/crema/eye_blink_bbox_bar_2_chunk'\n\nDATASAVE_DIR = '/train20/intern/permanent/hbcheng2/data'\nCKPT_DIR = os.path.join(DATASAVE_DIR, 'mraa_result_crema', str(INPUT_SIZE),'video')\nos.makedirs(CKPT_DIR, exist_ok=True)\nIMG_DIR = os.path.join(DATASAVE_DIR, 'mraa_result_crema', str(INPUT_SIZE),'img')\nos.makedirs(IMG_DIR, exist_ok=True)\n\n# GPU = \"6\"\npostfix = \"\"\n\nN_FRAMES = 40\nNUM_VIDEOS = 10\nSAVE_VIDEO = True\nNUM_ITER = NUM_VIDEOS // BATCH_SIZE\nRANDOM_SEED = 1234\nMEAN = (0.0, 0.0, 0.0)\n# the path to trained LFG model\nRESTORE_FROM ='your_path/data/log-hdtf/hdtf128_2024-02-11_15:45/snapshots/RegionMM_0100_S074360.pth'\n# RESTORE_FROM = \"/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/data/log-hdtf/hdtf256_2023-11-21_16:49/snapshots/RegionMM_0020_S080000.pth\"\nconfig_pth = \"your/path/DAWN-pytorch/AE/data/log-hdtf/hdtf128_llm_2024-07-26_12:54/hdtf128_llm.yaml\"\n\njson_path = os.path.join(CKPT_DIR, \"loss%d%s.json\" % (NUM_VIDEOS, postfix))\nvisualizer = Visualizer()\nprint(root_dir)\nprint(postfix)\nprint(\"RESTORE_FROM:\", RESTORE_FROM)\nprint(\"config_path:\", config_pth)\nprint(json_path)\nprint(\"save video:\", SAVE_VIDEO)\n\n\ndef get_arguments():\n    \"\"\"Parse all the arguments provided from the CLI.\n\n    Returns:\n      A list of parsed arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"Flow Autoencoder\")\n    parser.add_argument(\"--num-workers\", default=8)\n    parser.add_argument(\"--gpu\", default=0,\n                        help=\"choose gpu device.\")\n    parser.add_argument('--print-freq', '-p', default=1, type=int,\n                        metavar='N', help='print frequency')\n    parser.add_argument(\"--batch-size\", type=int, default=BATCH_SIZE,\n                        help=\"Number of images sent to the network in one step.\")\n    parser.add_argument(\"--input-size\", type=str, default=INPUT_SIZE,\n                        help=\"Comma-separated string with height and width of images.\")\n    parser.add_argument(\"--random-seed\", type=int, default=RANDOM_SEED,\n                        help=\"Random seed to have reproducible results.\")\n    parser.add_argument(\"--restore-from\", default=RESTORE_FROM)\n    parser.add_argument(\"--fp16\", default=False)\n    return parser.parse_args()\n\n\nargs = get_arguments()\n\ndef extract_audio_by_frames(input_wav_path, start_frame_index, num_frames, frame_rate, output_wav_path):\n    # \n    audio = 
AudioSegment.from_wav(input_wav_path)\n\n    # \n    frame_duration = 1000 / frame_rate  # \n\n    # \n    start_time_ms = start_frame_index * frame_duration\n    end_time_ms = (start_frame_index + num_frames) * frame_duration\n\n    # \n    selected_audio = audio[start_time_ms:end_time_ms]\n\n    # \n    selected_audio.export(output_wav_path, format=\"wav\")\n\n\n\ndef sample_img(rec_img_batch):\n    rec_img = rec_img_batch.permute(1, 2, 0).data.cpu().numpy().copy()\n    rec_img += np.array(MEAN)/255.0\n    rec_img[rec_img < 0] = 0\n    rec_img[rec_img > 1] = 1\n    rec_img *= 255\n    return np.array(rec_img, np.uint8)\n\n\ndef main():\n    \"\"\"Create the model and start the training.\"\"\"\n\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(args.gpu)\n\n    cudnn.enabled = True\n    cudnn.benchmark = True\n    setup_seed(args.random_seed)\n\n    model = FlowAE(is_train=False, config_pth=config_pth)\n    model.cuda()\n\n    if os.path.isfile(args.restore_from):\n        print(\"=> loading checkpoint '{}'\".format(args.restore_from))\n        checkpoint = torch.load(args.restore_from)\n        model.generator.load_state_dict(checkpoint['generator'])\n        model.region_predictor.load_state_dict(checkpoint['region_predictor'])\n        model.bg_predictor.load_state_dict(checkpoint['bg_predictor'])\n        print(\"=> loaded checkpoint '{}'\".format(args.restore_from))\n    else:\n        print(\"=> no checkpoint found at '{}'\".format(args.restore_from))\n        exit(-1)\n\n    model.eval()\n\n    setup_seed(args.random_seed)\n\n    testloader = data.DataLoader(HDTF(data_dir=data_dir,\n                                       pose_dir=pose_dir,\n                                       eye_blink_dir = eye_blink_dir,\n                                       image_size=INPUT_SIZE,\n                                       mode='test',\n                                       max_num_frames=1e8,\n                                       color_jitter=True,\n                                       mean=MEAN),\n                                 batch_size=BATCH_SIZE,\n                                 shuffle=True, num_workers=8,\n                                 pin_memory=True)\n\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n\n    iter_end = timeit.default_timer()\n    cnt = 0\n\n    out_loss = 0.0\n    warp_loss = 0.0\n    num_sample = 0.0\n    l1_loss = torch.nn.L1Loss(reduction='sum')\n\n    global_iter = 0\n\n    while global_iter < NUM_ITER:\n        for i_iter, batch in enumerate(testloader):\n            # if i_iter < NUM_ITER:\n            #     break\n            # if global_iter < NUM_ITER:\n            #     break\n\n            data_time.update(timeit.default_timer() - iter_end)\n\n            real_vids, ref_hubert, real_poses, real_blink_bbox, mouth_lmk_tensor, real_names, _ = batch\n            # use first frame of each video as reference frame\n            real_vids = real_vids/255.\n            ref_imgs = real_vids[:, :, 0, :, :].clone().detach()\n            bs = real_vids.size(0)\n\n            batch_time.update(timeit.default_timer() - iter_end)\n\n            nf = real_vids.size(2)\n            out_img_list = []\n            warped_img_list = []\n            warped_grid_list = []\n            conf_map_list = []\n\n            segment_length = 80\n            b,c,f,h,w = real_vids.size()\n            real_vid_tmp = rearrange(real_vids, 'b c f h w -> (b f) c h w')# real_vid.reshape(b * f, c, h,  w) \n            ref_img_tmp = 
ref_imgs.repeat(segment_length,1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE)\n            for frame_idx in tqdm(range(0, nf, segment_length)):\n                \n                end_fn = min(nf, frame_idx + segment_length)\n                dri_imgs = real_vid_tmp[frame_idx : end_fn, :, :, :]\n                if end_fn == nf:\n                    ref_img_tmp = ref_imgs.repeat(dri_imgs.shape[0],1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE)\n                with torch.no_grad():\n                    model.set_train_input(ref_img=ref_img_tmp, dri_img=dri_imgs)\n                    model.forward()\n                out_img_list.append(model.generated['prediction'].clone().detach().cpu())\n                # warped_img_list.append(model.generated['deformed'].clone().detach())\n\n            out_img_list_tensor = torch.concat(out_img_list, dim = 0)\n\n            # out_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), out_img_list_tensor.cpu()).item()\n            # warp_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), warped_img_list_tensor.cpu()).item()\n            num_sample += bs\n            \n            \n            if SAVE_VIDEO:\n                for batch_idx in range(bs):\n                    msk_size = ref_imgs.shape[-1]\n                    new_im_list = []\n                    img_dir_name = \"%04d_%s\" % (i_iter, real_names[batch_idx])\n                    cur_img_dir_gt = os.path.join(IMG_DIR, img_dir_name,'gt')\n                    os.makedirs(cur_img_dir_gt, exist_ok=True)\n                    cur_img_dir_samp = os.path.join(IMG_DIR, img_dir_name,'mraa')\n                    os.makedirs(cur_img_dir_samp, exist_ok=True)\n                    \n                    fps = 25\n\n                    tmp_video_file_pred = tempfile.NamedTemporaryFile('w', suffix='.mp4', dir='your/path/DAWN-pytorch/demo')\n                    output_wav_path = tempfile.NamedTemporaryFile('w', suffix='.wav', dir='your/path/DAWN-pytorch/demo').name\n\n                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n                    video_writer = cv2.VideoWriter(tmp_video_file_pred.name, fourcc, fps, (INPUT_SIZE, INPUT_SIZE))\n                    SAV_DIR = os.path.join(CKPT_DIR, str(i_iter)+'_'+real_names[0] + '.mp4')\n\n\n                    wav_path = os.path.join('/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/audio', real_names[0].replace('_','/',1)+'.wav')\n\n                    extract_audio_by_frames(wav_path, 0, nf, fps, output_wav_path)\n\n                    for frame_idx in range(nf):\n                        new_im_gt = Image.new('RGB', (msk_size, msk_size))\n                        new_im_sample = Image.new('RGB', (msk_size, msk_size))\n\n                        save_tar_img = sample_img(real_vids[0, :, frame_idx])\n                        save_out_img = sample_img(out_img_list_tensor[frame_idx])\n                        # save_warped_img = sample_img(warped_img_list_tensor[frame_idx], batch_idx)\n                        # save_warped_grid = grid2fig(warped_grid_list_tensor[frame_idx, batch_idx].data.cpu().numpy(),\n                        #                             grid_size=32, img_size=msk_size)\n                        # save_conf_map = conf_map_list_tensor[frame_idx, batch_idx].unsqueeze(dim=0)\n                        # save_conf_map = save_conf_map.data.cpu()\n                        # save_conf_map = F.interpolate(save_conf_map, size=real_vids.shape[3:5]).numpy()\n                        # save_conf_map = np.transpose(save_conf_map, [0, 2, 3, 1])\n       
                 # save_conf_map = np.array(save_conf_map[0, :, :, 0]*255, dtype=np.uint8)\n\n                        frame_rgb = np.uint8(save_out_img)\n                        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)\n                        video_writer.write(frame_bgr)\n\n                        # save sample and gt imgs\n                        new_im_gt.paste(Image.fromarray(save_tar_img, 'RGB'), (0, 0))\n                        new_im_sample.paste(Image.fromarray(save_out_img, 'RGB'), (0, 0))\n                        new_im_arr_gt = np.array(new_im_gt)\n                        new_im_arr_sample = np.array(new_im_sample)\n                        new_im_name = \"%03d_%s.png\" % (frame_idx, real_names[batch_idx])\n                        imageio.imsave(os.path.join(cur_img_dir_gt,new_im_name), new_im_arr_gt)\n                        imageio.imsave(os.path.join(cur_img_dir_samp,new_im_name), new_im_arr_sample)\n        \n                        # new_im = Image.new('RGB', (msk_size * 5, msk_size))\n                        # new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, 0))\n                        # new_im.paste(Image.fromarray(save_out_img, 'RGB'), (msk_size, 0))\n                        # new_im.paste(Image.fromarray(save_warped_img, 'RGB'), (msk_size * 2, 0))\n                        # new_im.paste(Image.fromarray(save_warped_grid), (msk_size * 3, 0))\n                        # new_im.paste(Image.fromarray(save_conf_map, \"L\"), (msk_size * 4, 0))\n                        # new_im_list.append(new_im)\n                    # video_name = \"%04d_%s.gif\" % (cnt, real_names[batch_idx])\n                    # imageio.mimsave(os.path.join(CKPT_DIR, video_name), new_im_list)\n                    cnt += 1\n                    video_writer.release()\n                    cmd = ('ffmpeg -y ' + ' -i {0} -i {1} -vcodec copy -ac 2 -channel_layout stereo -pix_fmt yuv420p {2} -shortest'.format(\n                    output_wav_path, tmp_video_file_pred.name, SAV_DIR)).split()  \n                     \n                    call(cmd)  \n                    try:\n                        os.remove(tmp_video_file_pred.name)\n                        os.remove(output_wav_path)\n                    except OSError as e:\n                        print(f'Error: {e.strerror}')\n\n            iter_end = timeit.default_timer()\n\n            if global_iter % args.print_freq == 0:\n                print('Test:[{0}/{1}]\\t'\n                      'Time {batch_time.val:.3f}({batch_time.avg:.3f})'\n                      .format(global_iter, NUM_ITER, batch_time=batch_time))\n            global_iter += 1\n\n    print(\"loss for prediction: %.5f\" % (out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)))\n    print(\"loss for warping: %.5f\" % (warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)))\n\n    res_dict = {}\n    res_dict[\"out_loss\"] = out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)\n    res_dict[\"warp_loss\"] = warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)\n    with open(json_path, \"w\") as f:\n        json.dump(res_dict, f)\n\n    end = timeit.default_timer()\n    print(end - start, 'seconds')\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / 
self.count\n\n\ndef setup_seed(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "LFG/test_flowautoenc_hdtf_video.py",
    "content": "# use LFG to reconstruct testing videos and measure the loss in video domain\n# using RegionMM\n\nimport argparse\nimport imageio\nimport torch\nfrom torch.utils import data\nimport numpy as np\nimport torch.backends.cudnn as cudnn\nimport os\nimport timeit\nfrom PIL import Image\nimport sys\nsys.path.append(\"your/path/DAWN-pytorch\")\nfrom misc import grid2fig\nfrom DM.datasets_hdtf_wpose_lmk_mo_block import HDTF \nimport random\nfrom LFG.modules.flow_autoenc import FlowAE\nimport torch.nn.functional as F\nfrom LFG.modules.util import Visualizer\nimport json_tricks as json\nimport cv2\nimport tempfile\nfrom subprocess import call\nfrom pydub import AudioSegment\nfrom einops import rearrange\nfrom tqdm import tqdm\n\nstart = timeit.default_timer()\nBATCH_SIZE = 1\nINPUT_SIZE = 128\nroot_dir = 'your/path/DAWN-pytorch/AE'  # your work directory\ndata_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/images_25hz_128_chunk\"\npose_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar_chunk\"\neye_blink_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/eye_blink_bbox_from_xpc_bar_2_chunk\"\n\nDATASAVE_DIR = '/train20/intern/permanent/hbcheng2/data'\nCKPT_DIR = os.path.join(DATASAVE_DIR, 'mraa_result', str(INPUT_SIZE) + '_1000ep','video')\nos.makedirs(CKPT_DIR, exist_ok=True)\nIMG_DIR = os.path.join(DATASAVE_DIR, 'mraa_result', str(INPUT_SIZE) + '_1000ep','img')\nos.makedirs(IMG_DIR, exist_ok=True)\n\n# GPU = \"6\"\npostfix = \"\"\n\nN_FRAMES = 40\nNUM_VIDEOS = 10\nSAVE_VIDEO = True\nNUM_ITER = NUM_VIDEOS // BATCH_SIZE\nRANDOM_SEED = 1234\nMEAN = (0.0, 0.0, 0.0)\n# the path to trained LFG model\nRESTORE_FROM ='your/path/DAWN-pytorch/AE/data/log-hdtf-cosin/hdtf128_1000ep_2024-08-08_15:04/snapshots/RegionMM.pth'\n# RESTORE_FROM = \"/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/data/log-hdtf/hdtf256_2023-11-21_16:49/snapshots/RegionMM_0020_S080000.pth\"\nconfig_pth = \"your/path/DAWN-pytorch/AE/data/log-hdtf/hdtf128_llm_2024-07-26_12:54/hdtf128_llm.yaml\"\n\njson_path = os.path.join(CKPT_DIR, \"loss%d%s.json\" % (NUM_VIDEOS, postfix))\nvisualizer = Visualizer()\nprint(root_dir)\nprint(postfix)\nprint(\"RESTORE_FROM:\", RESTORE_FROM)\nprint(\"config_path:\", config_pth)\nprint(json_path)\nprint(\"save video:\", SAVE_VIDEO)\n\n\ndef get_arguments():\n    \"\"\"Parse all the arguments provided from the CLI.\n\n    Returns:\n      A list of parsed arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"Flow Autoencoder\")\n    parser.add_argument(\"--num-workers\", default=8)\n    parser.add_argument(\"--gpu\", default=0,\n                        help=\"choose gpu device.\")\n    parser.add_argument('--print-freq', '-p', default=1, type=int,\n                        metavar='N', help='print frequency')\n    parser.add_argument(\"--batch-size\", type=int, default=BATCH_SIZE,\n                        help=\"Number of images sent to the network in one step.\")\n    parser.add_argument(\"--input-size\", type=str, default=INPUT_SIZE,\n                        help=\"Comma-separated string with height and width of images.\")\n    parser.add_argument(\"--random-seed\", type=int, default=RANDOM_SEED,\n                        help=\"Random seed to have reproducible results.\")\n    parser.add_argument(\"--restore-from\", default=RESTORE_FROM)\n    parser.add_argument(\"--fp16\", default=False)\n    return parser.parse_args()\n\n\nargs = get_arguments()\n\ndef extract_audio_by_frames(input_wav_path, start_frame_index, num_frames, frame_rate, 
output_wav_path):\n    # \n    audio = AudioSegment.from_wav(input_wav_path)\n\n    # \n    frame_duration = 1000 / frame_rate  # \n\n    # \n    start_time_ms = start_frame_index * frame_duration\n    end_time_ms = (start_frame_index + num_frames) * frame_duration\n\n    # \n    selected_audio = audio[start_time_ms:end_time_ms]\n\n    # \n    selected_audio.export(output_wav_path, format=\"wav\")\n\n\n\ndef sample_img(rec_img_batch):\n    rec_img = rec_img_batch.permute(1, 2, 0).data.cpu().numpy().copy()\n    rec_img += np.array(MEAN)/255.0\n    rec_img[rec_img < 0] = 0\n    rec_img[rec_img > 1] = 1\n    rec_img *= 255\n    return np.array(rec_img, np.uint8)\n\n\ndef main():\n    \"\"\"Create the model and start the training.\"\"\"\n\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(args.gpu)\n\n    cudnn.enabled = True\n    cudnn.benchmark = True\n    setup_seed(args.random_seed)\n\n    model = FlowAE(is_train=False, config_pth=config_pth)\n    model.cuda()\n\n    if os.path.isfile(args.restore_from):\n        print(\"=> loading checkpoint '{}'\".format(args.restore_from))\n        checkpoint = torch.load(args.restore_from)\n        model.generator.load_state_dict(checkpoint['generator'])\n        model.region_predictor.load_state_dict(checkpoint['region_predictor'])\n        model.bg_predictor.load_state_dict(checkpoint['bg_predictor'])\n        print(\"=> loaded checkpoint '{}'\".format(args.restore_from))\n    else:\n        print(\"=> no checkpoint found at '{}'\".format(args.restore_from))\n        exit(-1)\n\n    model.eval()\n\n    setup_seed(args.random_seed)\n\n    testloader = data.DataLoader(HDTF(data_dir=data_dir,\n                                       pose_dir=pose_dir,\n                                       eye_blink_dir = eye_blink_dir,\n                                       image_size=INPUT_SIZE,\n                                       mode='test',\n                                       max_num_frames=1e8,\n                                       color_jitter=True,\n                                       mean=MEAN),\n                                 batch_size=BATCH_SIZE,\n                                 shuffle=True, num_workers=8,\n                                 pin_memory=True)\n\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n\n    iter_end = timeit.default_timer()\n    cnt = 0\n\n    out_loss = 0.0\n    warp_loss = 0.0\n    num_sample = 0.0\n    l1_loss = torch.nn.L1Loss(reduction='sum')\n\n    global_iter = 0\n\n    while global_iter < NUM_ITER:\n        for i_iter, batch in enumerate(testloader):\n            # if i_iter < NUM_ITER:\n            #     break\n            # if global_iter < NUM_ITER:\n            #     break\n\n            data_time.update(timeit.default_timer() - iter_end)\n\n            real_vids, ref_hubert, real_poses, real_blink_bbox, real_mouth_ratio, real_names, start_frame_index = batch\n            # use first frame of each video as reference frame\n            real_vids = real_vids/255.\n            ref_imgs = real_vids[:, :, 0, :, :].clone().detach()\n            bs = real_vids.size(0)\n\n            batch_time.update(timeit.default_timer() - iter_end)\n\n            nf = real_vids.size(2)\n            out_img_list = []\n            warped_img_list = []\n            warped_grid_list = []\n            conf_map_list = []\n\n            segment_length = 120\n            b,c,f,h,w = real_vids.size()\n            real_vid_tmp = rearrange(real_vids, 'b c f h w -> (b f) c h w')# real_vid.reshape(b * f, c, h,  w) \n           
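 # tile the single reference frame to segment_length copies so that each chunk of driving frames\n            # can be pushed through the flow autoencoder in one batched forward pass; the shorter tail\n            # chunk re-tiles it to its own length inside the loop below\n           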
 ref_img_tmp = ref_imgs.repeat(segment_length,1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE)\n            for frame_idx in tqdm(range(0, nf, segment_length)):\n                \n                end_fn = min(nf, frame_idx + segment_length)\n                dri_imgs = real_vid_tmp[frame_idx : end_fn, :, :, :]\n                if end_fn == nf:\n                    ref_img_tmp = ref_imgs.repeat(dri_imgs.shape[0],1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE)\n                with torch.no_grad():\n                    model.set_train_input(ref_img=ref_img_tmp, dri_img=dri_imgs)\n                    model.forward()\n                out_img_list.append(model.generated['prediction'].clone().detach().cpu())\n                # warped_img_list.append(model.generated['deformed'].clone().detach())\n\n            out_img_list_tensor = torch.concat(out_img_list, dim = 0)\n\n            # out_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), out_img_list_tensor.cpu()).item()\n            # warp_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), warped_img_list_tensor.cpu()).item()\n            num_sample += bs\n            \n            \n            if SAVE_VIDEO:\n                for batch_idx in range(bs):\n                    msk_size = ref_imgs.shape[-1]\n                    new_im_list = []\n                    img_dir_name = \"%04d_%s\" % (i_iter, real_names[batch_idx])\n                    cur_img_dir_gt = os.path.join(IMG_DIR, img_dir_name,'gt')  \n                    os.makedirs(cur_img_dir_gt, exist_ok=True)\n                    cur_img_dir_samp = os.path.join(IMG_DIR, img_dir_name,'mraa')  \n                    os.makedirs(cur_img_dir_samp, exist_ok=True)\n                    \n                    fps = 25  # \n\n                    tmp_video_file_pred = tempfile.NamedTemporaryFile('w', suffix='.mp4', dir='your/path/DAWN-pytorch/demo')\n                    output_wav_path = tempfile.NamedTemporaryFile('w', suffix='.wav', dir='your/path/DAWN-pytorch/demo').name\n\n                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n                    video_writer = cv2.VideoWriter(tmp_video_file_pred.name, fourcc, fps, (INPUT_SIZE, INPUT_SIZE))\n                    SAV_DIR = os.path.join(CKPT_DIR, str(i_iter)+'_'+real_names[0] + '.mp4')\n\n\n                    wav_path = os.path.join(\"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\".replace('/images_25hz','/image_audio'), real_names[0]+'.wav')\n\n                    extract_audio_by_frames(wav_path, 0, nf, fps, output_wav_path)\n\n                    for frame_idx in range(nf):\n                        new_im_gt = Image.new('RGB', (msk_size, msk_size))\n                        new_im_sample = Image.new('RGB', (msk_size, msk_size))\n\n                        save_tar_img = sample_img(real_vids[0, :, frame_idx])\n                        save_out_img = sample_img(out_img_list_tensor[frame_idx])\n                        # save_warped_img = sample_img(warped_img_list_tensor[frame_idx], batch_idx)\n                        # save_warped_grid = grid2fig(warped_grid_list_tensor[frame_idx, batch_idx].data.cpu().numpy(),\n                        #                             grid_size=32, img_size=msk_size)\n                        # save_conf_map = conf_map_list_tensor[frame_idx, batch_idx].unsqueeze(dim=0)\n                        # save_conf_map = save_conf_map.data.cpu()\n                        # save_conf_map = F.interpolate(save_conf_map, size=real_vids.shape[3:5]).numpy()\n                        # save_conf_map = np.transpose(save_conf_map, 
[0, 2, 3, 1])\n                        # save_conf_map = np.array(save_conf_map[0, :, :, 0]*255, dtype=np.uint8)\n\n                        frame_rgb = np.uint8(save_out_img)\n                        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)\n                        video_writer.write(frame_bgr)\n\n                        # save sample and gt imgs\n                        new_im_gt.paste(Image.fromarray(save_tar_img, 'RGB'), (0, 0))\n                        new_im_sample.paste(Image.fromarray(save_out_img, 'RGB'), (0, 0))\n                        new_im_arr_gt = np.array(new_im_gt)\n                        new_im_arr_sample = np.array(new_im_sample)\n                        new_im_name = \"%03d_%s.png\" % (frame_idx, real_names[batch_idx])\n                        imageio.imsave(os.path.join(cur_img_dir_gt,new_im_name), new_im_arr_gt)\n                        imageio.imsave(os.path.join(cur_img_dir_samp,new_im_name), new_im_arr_sample)\n        \n                        # new_im = Image.new('RGB', (msk_size * 5, msk_size))\n                        # new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, 0))\n                        # new_im.paste(Image.fromarray(save_out_img, 'RGB'), (msk_size, 0))\n                        # new_im.paste(Image.fromarray(save_warped_img, 'RGB'), (msk_size * 2, 0))\n                        # new_im.paste(Image.fromarray(save_warped_grid), (msk_size * 3, 0))\n                        # new_im.paste(Image.fromarray(save_conf_map, \"L\"), (msk_size * 4, 0))\n                        # new_im_list.append(new_im)\n                    # video_name = \"%04d_%s.gif\" % (cnt, real_names[batch_idx])\n                    # imageio.mimsave(os.path.join(CKPT_DIR, video_name), new_im_list)\n                    cnt += 1\n                    video_writer.release()\n                    cmd = ('ffmpeg -y ' + ' -i {0} -i {1} -vcodec copy -ac 2 -channel_layout stereo -pix_fmt yuv420p {2} -shortest'.format(\n                    output_wav_path, tmp_video_file_pred.name, SAV_DIR)).split()  \n                     \n                    call(cmd)  \n                    try:\n                        os.remove(tmp_video_file_pred.name)\n                        os.remove(output_wav_path)\n                    except OSError as e:\n                        print(f'Error: {e.strerror}')\n\n            iter_end = timeit.default_timer()\n\n            if global_iter % args.print_freq == 0:\n                print('Test:[{0}/{1}]\\t'\n                      'Time {batch_time.val:.3f}({batch_time.avg:.3f})'\n                      .format(global_iter, NUM_ITER, batch_time=batch_time))\n            global_iter += 1\n\n    print(\"loss for prediction: %.5f\" % (out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)))\n    print(\"loss for warping: %.5f\" % (warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)))\n\n    res_dict = {}\n    res_dict[\"out_loss\"] = out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)\n    res_dict[\"warp_loss\"] = warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)\n    with open(json_path, \"w\") as f:\n        json.dump(res_dict, f)\n\n    end = timeit.default_timer()\n    print(end - start, 'seconds')\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        
self.avg = self.sum / self.count\n\n\ndef setup_seed(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "LFG/test_flowautoenc_hdtf_video_256.py",
    "content": "# use LFG to reconstruct testing videos and measure the loss in video domain\n# using RegionMM\n\nimport argparse\nimport imageio\nimport torch\nfrom torch.utils import data\nimport numpy as np\nimport torch.backends.cudnn as cudnn\nimport os\nimport timeit\nfrom PIL import Image\nimport sys\nsys.path.append(\"your/path/DAWN-pytorch\")\nfrom misc import grid2fig\nfrom DM.datasets_hdtf_wpose_lmk_mo_block_mraa import HDTF_test as HDTF\nimport random\nfrom LFG.modules.flow_autoenc import FlowAE\nimport torch.nn.functional as F\nfrom LFG.modules.util import Visualizer\nimport json_tricks as json\nimport cv2\nimport tempfile\nfrom subprocess import call\nfrom pydub import AudioSegment\nfrom einops import rearrange\nfrom tqdm import tqdm\n\nstart = timeit.default_timer()\nBATCH_SIZE = 1\nINPUT_SIZE = 256\nroot_dir = 'your/path/DAWN-pytorch/AE'  # your work directory\ndata_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/images_25hz_256_chunk\"\npose_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar_chunk\"\neye_blink_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/eye_blink_bbox_from_xpc_bar_2_chunk\"\n\nDATASAVE_DIR = '/train20/intern/permanent/hbcheng2/data'\nCKPT_DIR = os.path.join(DATASAVE_DIR, 'mraa_result', str(INPUT_SIZE) + '_400ep','video')\nos.makedirs(CKPT_DIR, exist_ok=True)\nIMG_DIR = os.path.join(DATASAVE_DIR, 'mraa_result', str(INPUT_SIZE) + '_400ep','img')\nos.makedirs(IMG_DIR, exist_ok=True)\n\n# GPU = \"6\"\npostfix = \"\"\n\nN_FRAMES = 40\nNUM_VIDEOS = 10\nSAVE_VIDEO = True\nNUM_ITER = NUM_VIDEOS // BATCH_SIZE\nRANDOM_SEED = 1234\nMEAN = (0.0, 0.0, 0.0)\n# the path to trained LFG model\nRESTORE_FROM ='your/path/DAWN-pytorch/AE/data/log-hdtf-256-cosin/hdtf256_400ep_2024-08-08_00:15/snapshots/RegionMM.pth'\n# RESTORE_FROM = \"/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/data/log-hdtf/hdtf256_2023-11-21_16:49/snapshots/RegionMM_0020_S080000.pth\"\nconfig_pth = \"your/path/DAWN-pytorch/config/hdtf256.yaml\"\n\njson_path = os.path.join(CKPT_DIR, \"loss%d%s.json\" % (NUM_VIDEOS, postfix))\nvisualizer = Visualizer()\nprint(root_dir)\nprint(postfix)\nprint(\"RESTORE_FROM:\", RESTORE_FROM)\nprint(\"config_path:\", config_pth)\nprint(json_path)\nprint(\"save video:\", SAVE_VIDEO)\n\n\ndef get_arguments():\n    \"\"\"Parse all the arguments provided from the CLI.\n\n    Returns:\n      A list of parsed arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"Flow Autoencoder\")\n    parser.add_argument(\"--num-workers\", default=8)\n    parser.add_argument(\"--gpu\", default=0,\n                        help=\"choose gpu device.\")\n    parser.add_argument('--print-freq', '-p', default=1, type=int,\n                        metavar='N', help='print frequency')\n    parser.add_argument(\"--batch-size\", type=int, default=BATCH_SIZE,\n                        help=\"Number of images sent to the network in one step.\")\n    parser.add_argument(\"--input-size\", type=str, default=INPUT_SIZE,\n                        help=\"Comma-separated string with height and width of images.\")\n    parser.add_argument(\"--random-seed\", type=int, default=RANDOM_SEED,\n                        help=\"Random seed to have reproducible results.\")\n    parser.add_argument(\"--restore-from\", default=RESTORE_FROM)\n    parser.add_argument(\"--fp16\", default=False)\n    return parser.parse_args()\n\n\nargs = get_arguments()\n\ndef extract_audio_by_frames(input_wav_path, start_frame_index, num_frames, frame_rate, output_wav_path):\n    # \n    audio = 
AudioSegment.from_wav(input_wav_path)\n\n    # \n    frame_duration = 1000 / frame_rate  # \n\n    # \n    start_time_ms = start_frame_index * frame_duration\n    end_time_ms = (start_frame_index + num_frames) * frame_duration\n\n    # \n    selected_audio = audio[start_time_ms:end_time_ms]\n\n    # \n    selected_audio.export(output_wav_path, format=\"wav\")\n\n\n\ndef sample_img(rec_img_batch):\n    rec_img = rec_img_batch.permute(1, 2, 0).data.cpu().numpy().copy()\n    rec_img += np.array(MEAN)/255.0\n    rec_img[rec_img < 0] = 0\n    rec_img[rec_img > 1] = 1\n    rec_img *= 255\n    return np.array(rec_img, np.uint8)\n\n\ndef main():\n    \"\"\"Create the model and start the training.\"\"\"\n\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(args.gpu)\n\n    cudnn.enabled = True\n    cudnn.benchmark = True\n    setup_seed(args.random_seed)\n\n    model = FlowAE(is_train=False, config_pth=config_pth)\n    model.cuda()\n\n    if os.path.isfile(args.restore_from):\n        print(\"=> loading checkpoint '{}'\".format(args.restore_from))\n        checkpoint = torch.load(args.restore_from)\n        model.generator.load_state_dict(checkpoint['generator'])\n        model.region_predictor.load_state_dict(checkpoint['region_predictor'])\n        model.bg_predictor.load_state_dict(checkpoint['bg_predictor'])\n        print(\"=> loaded checkpoint '{}'\".format(args.restore_from))\n    else:\n        print(\"=> no checkpoint found at '{}'\".format(args.restore_from))\n        exit(-1)\n\n    model.eval()\n\n    setup_seed(args.random_seed)\n\n    testloader = data.DataLoader(HDTF(data_dir=data_dir,\n                                       pose_dir=pose_dir,\n                                       eye_blink_dir = eye_blink_dir,\n                                       image_size=INPUT_SIZE,\n                                       mode='test',\n                                       max_num_frames=1e8,\n                                       color_jitter=True,\n                                       mean=MEAN),\n                                 batch_size=BATCH_SIZE,\n                                 shuffle=True, num_workers=8,\n                                 pin_memory=True)\n\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n\n    iter_end = timeit.default_timer()\n    cnt = 0\n\n    out_loss = 0.0\n    warp_loss = 0.0\n    num_sample = 0.0\n    l1_loss = torch.nn.L1Loss(reduction='sum')\n\n    global_iter = 0\n\n\n    while global_iter < NUM_ITER:\n        for i_iter, batch in enumerate(testloader):\n            # if i_iter < NUM_ITER:\n            #     break\n            # if global_iter < NUM_ITER:\n            #     break\n\n            data_time.update(timeit.default_timer() - iter_end)\n\n            block_path_list, real_names, total_num_frames = batch\n\n            out_img_list = []\n            # use first frame of each video as reference frame\n\n            ref_path = block_path_list[0][0]\n            ref_imgs = np.load(ref_path) # 25, 256, 256, 3\n            ref_imgs = torch.tensor(ref_imgs).permute(0, 3, 1, 2)\n            ref_imgs = ref_imgs[0].clone().detach().to(torch.float32)/255.\n            ref_img_tmp = ref_imgs.repeat(25,1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE)\n            \n            msk_size = ref_imgs.shape[-1]\n            new_im_list = []\n            img_dir_name = \"%04d_%s\" % (i_iter, real_names[0])\n            cur_img_dir_gt = os.path.join(IMG_DIR, img_dir_name,'gt')\n            os.makedirs(cur_img_dir_gt, exist_ok=True)\n            
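# reconstructed frames go to a parallel 'mraa' folder beside the 'gt' frames for side-by-side comparison\n            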
cur_img_dir_samp = os.path.join(IMG_DIR, img_dir_name,'mraa') \n            os.makedirs(cur_img_dir_samp, exist_ok=True)\n            \n            fps = 25  # \n\n            tmp_video_file_pred = tempfile.NamedTemporaryFile('w', suffix='.mp4', dir='your/path/DAWN-pytorch/demo')\n            output_wav_path = tempfile.NamedTemporaryFile('w', suffix='.wav', dir='your/path/DAWN-pytorch/demo').name\n\n            fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n            video_writer = cv2.VideoWriter(tmp_video_file_pred.name, fourcc, fps, (INPUT_SIZE, INPUT_SIZE))\n            SAV_DIR = os.path.join(CKPT_DIR, str(i_iter)+'_'+real_names[0] + '.mp4')\n\n\n            wav_path = os.path.join(\"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\".replace('/images_25hz','/image_audio'), real_names[0]+'.wav')\n\n            extract_audio_by_frames(wav_path, 0, total_num_frames, fps, output_wav_path)\n\n            batch_time.update(timeit.default_timer() - iter_end)\n\n            new_im_gt = Image.new('RGB', (msk_size, msk_size))\n            new_im_sample = Image.new('RGB', (msk_size, msk_size))\n\n            frame_cnt = 0\n            for id in range(len(block_path_list)):\n                block_path = block_path_list[id][0]\n                real_vids = np.load(block_path)\n                if real_vids.shape[0] !=ref_img_tmp.shape[0]:\n                    ref_img_tmp = ref_imgs.repeat(real_vids.shape[0],1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE)\n                real_vids = torch.tensor(real_vids).permute(0, 3, 1, 2)  # 25, 256, 256, 3 - > 25, 3, 256, 256\n\n                \n\n                dri_imgs = real_vids.to(torch.float32)/255.\n                with torch.no_grad():\n                    model.set_train_input(ref_img=ref_img_tmp, dri_img=dri_imgs)\n                    model.forward()\n                out_img_tensor = (model.generated['prediction'] * 255.).to(torch.uint8).clone().detach().cpu()\n\n                # save real_vids\n                for i in range(real_vids.shape[0]):\n                    save_tar_img = np.array(real_vids[i].permute(1, 2, 0), np.uint8)\n                    save_out_img = np.array(out_img_tensor[i].permute(1, 2, 0), np.uint8)\n                    frame_rgb = np.uint8(save_out_img)\n                    frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)\n                    video_writer.write(frame_bgr)\n\n                    new_im_gt.paste(Image.fromarray(save_tar_img, 'RGB'), (0, 0))\n                    new_im_sample.paste(Image.fromarray(save_out_img, 'RGB'), (0, 0))\n                    new_im_arr_gt = np.array(new_im_gt)\n                    new_im_arr_sample = np.array(new_im_sample)\n                    new_im_name = \"%03d_%s.png\" % (frame_cnt, real_names[0])\n                    imageio.imsave(os.path.join(cur_img_dir_gt,new_im_name), new_im_arr_gt)\n                    imageio.imsave(os.path.join(cur_img_dir_samp,new_im_name), new_im_arr_sample)\n\n                    frame_cnt += 1\n\n\n            # out_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), out_img_list_tensor.cpu()).item()\n            # warp_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), warped_img_list_tensor.cpu()).item()\n            \n            \n            \n            video_writer.release()\n            cmd = ('ffmpeg -y ' + ' -i {0} -i {1} -vcodec copy -ac 2 -channel_layout stereo -pix_fmt yuv420p {2} -shortest'.format(\n            output_wav_path, tmp_video_file_pred.name, SAV_DIR)).split()  \n                \n            call(cmd)  \n            try:\n  
              os.remove(tmp_video_file_pred.name)\n                os.remove(output_wav_path)\n            except OSError as e:\n                print(f'Error: {e.strerror}')\n\n            iter_end = timeit.default_timer()\n\n            if global_iter % args.print_freq == 0:\n                print('Test:[{0}/{1}]\\t'\n                      'Time {batch_time.val:.3f}({batch_time.avg:.3f})'\n                      .format(global_iter, NUM_ITER, batch_time=batch_time))\n            global_iter += 1\n\n    # the per-pixel L1 accumulation above is commented out in this script, so only report losses when samples were actually counted\n    if num_sample > 0:\n        print(\"loss for prediction: %.5f\" % (out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)))\n        print(\"loss for warping: %.5f\" % (warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)))\n\n        res_dict = {}\n        res_dict[\"out_loss\"] = out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)\n        res_dict[\"warp_loss\"] = warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3)\n        with open(json_path, \"w\") as f:\n            json.dump(res_dict, f)\n\n    end = timeit.default_timer()\n    print(end - start, 'seconds')\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef setup_seed(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "LFG/train.py",
    "content": "# train a LFAE\n# this code is based on RegionMM (MRAA): https://github.com/snap-research/articulated-animation\nimport os.path\nimport torch\nfrom torch.utils.data import DataLoader\nfrom modules.model import ReconstructionModel\nfrom torch.optim.lr_scheduler import MultiStepLR\nfrom sync_batchnorm import DataParallelWithCallback\nfrom frames_dataset import DatasetRepeater\nimport timeit\nfrom modules.util import Visualizer\nimport imageio\nimport math\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef train(config, generator, region_predictor, bg_predictor, checkpoint, log_dir, dataset, device_ids):\n    train_params = config['train_params']\n\n    optimizer = torch.optim.Adam(list(generator.parameters()) +\n                                 list(region_predictor.parameters()) +\n                                 list(bg_predictor.parameters()), lr=train_params['lr'], betas=(0.5, 0.999))\n\n    start_epoch = 0\n    start_step = 0\n    if checkpoint is not None:\n        ckpt = torch.load(checkpoint)\n        if config[\"set_start\"]:\n            start_step = int(math.ceil(ckpt['example'] / config['train_params']['batch_size']))\n            start_epoch = ckpt['epoch']\n        generator.load_state_dict(ckpt['generator'])\n        region_predictor.load_state_dict(ckpt['region_predictor'])\n        bg_predictor.load_state_dict(ckpt['bg_predictor'])\n        if 'optimizer' in list(ckpt.keys()):\n            try:\n                optimizer.load_state_dict(ckpt['optimizer'])\n            except:\n                optimizer.load_state_dict(ckpt['optimizer'].state_dict())\n\n    # scheduler = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1, last_epoch=start_epoch - 1)\n    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_params[\"max_epochs\"], eta_min=2e-6)\n    if 'num_repeats' in train_params or train_params['num_repeats'] != 1:\n        dataset = DatasetRepeater(dataset, train_params['num_repeats'])\n\n    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,\n                            num_workers=8, drop_last=True)\n\n    model = ReconstructionModel(region_predictor, bg_predictor, generator, train_params)\n\n    visualizer = Visualizer(**config['visualizer_params'])\n\n    if torch.cuda.is_available():\n        if ('use_sync_bn' in train_params) and train_params['use_sync_bn']:\n            model = DataParallelWithCallback(model, device_ids=device_ids)\n        else:\n            model = torch.nn.DataParallel(model, device_ids=device_ids)\n\n    # rewritten by nhm\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n\n    total_losses = AverageMeter()\n    losses_perc = AverageMeter()\n    losses_equiv_shift = AverageMeter()\n    losses_equiv_affine = AverageMeter()\n\n    cnt = 0\n    epoch_cnt = start_epoch\n    actual_step = start_step\n    final_step = config[\"num_step_per_epoch\"] * train_params[\"max_epochs\"]\n\n    while actual_step < final_step:\n        iter_end = timeit.default_timer()\n\n        for i_iter, x in enumerate(dataloader):\n            actual_step = int(start_step + cnt)\n            
data_time.update(timeit.default_timer() - iter_end)\n            optimizer.zero_grad()\n            losses, generated = model(x)\n            loss_values = [val.mean() for val in losses.values()]\n            loss = sum(loss_values)\n            loss.backward()\n            optimizer.step()\n\n            batch_time.update(timeit.default_timer() - iter_end)\n            iter_end = timeit.default_timer()\n\n            bs = x['source'].size(0)\n            total_losses.update(loss.item(), bs)\n            losses_perc.update(loss_values[0].item(), bs)\n            losses_equiv_shift.update(loss_values[1].item(), bs)\n            losses_equiv_affine.update(loss_values[2].item(), bs)\n\n            if actual_step % train_params[\"print_freq\"] == 0:\n                print('iter: [{0}]{1}/{2}\\t'\n                      'loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n                      'loss_perc {loss_perc.val:.4f} ({loss_perc.avg:.4f})\\n'\n                      'loss_shift {loss_shift.val:.4f} ({loss_shift.avg:.4f})\\t'\n                      'loss_affine {loss_affine.val:.4f} ({loss_affine.avg:.4f})'\n                    .format(\n                    cnt, actual_step, final_step,\n                    loss=total_losses,\n                    loss_perc=losses_perc,\n                    loss_shift=losses_equiv_shift,\n                    loss_affine=losses_equiv_affine\n                ))\n\n            if actual_step % train_params['save_img_freq'] == 0:\n                save_image = visualizer.visualize(x['driving'], x['source'], generated, index=0)\n                save_name = 'B' + format(train_params[\"batch_size\"], \"04d\") + '_S' + format(actual_step, \"06d\") \\\n                            + '_' + x[\"frame\"][0][0][:-4] + '_to_' + x[\"frame\"][1][0][-7:]\n                save_file = os.path.join(config[\"imgshots\"], save_name)\n                imageio.imsave(save_file, save_image)\n\n            if actual_step % config[\"save_ckpt_freq\"] == 0 and cnt != 0:\n                print('taking snapshot...')\n                torch.save({'example': actual_step * train_params[\"batch_size\"],\n                            'epoch': epoch_cnt,\n                            'generator': generator.state_dict(),\n                            'bg_predictor': bg_predictor.state_dict(),\n                            'region_predictor': region_predictor.state_dict(),\n                            'optimizer': optimizer.state_dict()},\n                           os.path.join(config[\"snapshots\"],\n                                        'RegionMM_' + format(train_params[\"batch_size\"], \"04d\") +\n                                        '_S' + format(actual_step, \"06d\") + '.pth'))\n\n            if actual_step % train_params[\"update_ckpt_freq\"] == 0 and cnt != 0:\n                print('updating snapshot...')\n                torch.save({'example': actual_step * train_params[\"batch_size\"],\n                            'epoch': epoch_cnt,\n                            'generator': generator.state_dict(),\n                            'bg_predictor': bg_predictor.state_dict(),\n                            'region_predictor': region_predictor.state_dict(),\n                            'optimizer': optimizer.state_dict()},\n                           os.path.join(config[\"snapshots\"],'RegionMM.pth'))\n\n            if actual_step >= final_step:\n                break\n\n            cnt += 1\n\n        scheduler.step()\n        epoch_cnt += 1\n        # print lr\n        print(\"epoch %d, lr= %.7f\" % 
(epoch_cnt, optimizer.param_groups[0][\"lr\"]))\n\n    print('save the final model...')\n    torch.save({'example': actual_step * train_params[\"batch_size\"],\n                'epoch': epoch_cnt,\n                'generator': generator.state_dict(),\n                'bg_predictor': bg_predictor.state_dict(),\n                'region_predictor': region_predictor.state_dict(),\n                'optimizer': optimizer.state_dict()},\n               os.path.join(config[\"snapshots\"],\n                            'RegionMM_' + format(train_params[\"batch_size\"], \"04d\") +\n                            '_S' + format(actual_step, \"06d\") + '.pth'))\n\n\n"
  },
  {
    "path": "LFG/vis_flow.py",
    "content": "import torch\nimport numpy as np\nimport matplotlib.pyplot as plt\n\ndef visualize_dense_optical_flow(flow_tensor, save_path):\n    flow_np = flow_tensor.cpu().numpy()\n    flow_tensor = flow_tensor + 1e-7\n\n    magnitude = np.sqrt(flow_np[0]**2 + flow_np[1]**2)\n\n    # mask = magnitude > 1/64\n\n\n    magnitude = magnitude # * mask\n    angle = np.arctan2(flow_np[1], flow_np[0])\n\n    angle = angle # * mask\n\n    plt.figure()\n    plt.imshow(magnitude, cmap='BuPu', alpha=0.8)\n    plt.imshow(angle, cmap='hsv', alpha=0.2)\n    plt.title('Dense Optical Flow')\n    plt.axis('off')\n    plt.savefig(save_path)\n    plt.close()\n\ndef grid2flow(warped_grid, grid_size=64, img_size=256):\n    dpi = 1000\n    # plt.ioff()\n    h_range = torch.linspace(-1, 1, grid_size)\n    w_range = torch.linspace(-1, 1, grid_size)\n    grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).flip(2)\n    \n    out = warped_grid - grid\n    return out\n\nif __name__ == '__main__':\n    dense_flow_tensor = torch.zeros(2, 100, 100)\n    visualize_dense_optical_flow(dense_flow_tensor, 'test.jpg')"
  },
  {
    "path": "PBnet/run_cvae_h_ann_reemb_rope_eye_3.sh",
    "content": "source /home4/intern/lmlin2/.bashrc\nconda activate actor\n# crema rc delta pose\nexport CUDA_VISIBLE_DEVICES=\"0\"\n# python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/train/train_cvae.py\\\n#      --num_frames 40\\\n#      --lambda_kl 1\\\n#      --lambda_ssim 1\\\n#      --lambda_freq 1\\\n#      --modelname cvae_transformer_ssim_kl_freq\\\n#      --dataset hdtf\\\n#      --num_epochs 10000\\\n#      --folder exps_delta_pose/HDTF_nf40_kl1_ssim1_freq_128_w5_1w_6\n\npython /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/train/train_cvae_ganloss_ann_eye.py\\\n     --num_frames 200\\\n     --eye True\\\n     --lr 0.0004 \\\n     --batch_size 40\\\n     --lambda_kl 0.004\\\n     --lambda_reg 0.0005\\\n     --lambda_rc 1\\\n     --ff_size 128\\\n     --max_distance 128\\\n     --num_buckets 128\\\n     --num_layers 2\\\n     --audio_latent_dim 256\\\n     --snapshot 10000\\\n     --modelname cvae_transformerreemb8_rc_kl_reg\\\n     --dataset hdtf\\\n     --num_epochs 100000\\\n     --folder exps_delta_pose_rope_eye/HDTF_b40_200_eye_kl4e3_lr4e-4_reg5e-4_rope16_3 #  > output.log &\n\n# nohup python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/train/train_cvae_ganloss_first3.py\\\n#      --num_frames 40\\\n#      --batch_size 20\\\n#      --lambda_kl 1\\\n#      --lambda_rc 1\\\n#      --num_layers 4\\\n#      --modelname cvae_transformerold_kl_ssim\\\n#      --dataset hdtf\\\n#      --num_epochs 30000\\\n#      --folder exps_delta_pose_f3/HDTF_l2_nf40_kl1_ssim_norm_w5_1w_b20_first_3 > output.log &\n"
  },
  {
    "path": "PBnet/src/__init__.py",
    "content": ""
  },
  {
    "path": "PBnet/src/config.py",
    "content": "import os\n\nSMPL_DATA_PATH = \"models/smpl/\"\n\nSMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, \"kintree_table.pkl\")\nSMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, \"SMPL_NEUTRAL.pkl\")\n\nJOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy')\n"
  },
  {
    "path": "PBnet/src/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "PBnet/src/datasets/datasets_hdtf_pos_chunk_norm_2_fast.py",
    "content": "from os import name\nimport sys\nsys.path.append('your_path')\n\nimport os\nimport random\nimport torch\n\nimport numpy as np\nimport torch.utils.data as data\nimport torch.nn.functional as Ft\nimport imageio.v2 as imageio\n\nimport cv2\nimport torchvision.transforms.functional as F\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nfrom scipy.interpolate import interp1d\n# import decord\nfrom torchvision.transforms.functional import to_pil_image\nfrom torchvision import transforms\nimport time\nimport pickle as pkl\n\n# decord.bridge.set_bridge('torch')\n\n\ndef resize(im, desired_size, interpolation):\n    old_size = im.shape[:2]\n    ratio = float(desired_size)/max(old_size)\n    new_size = tuple(int(x*ratio) for x in old_size)\n\n    im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation)\n    delta_w = desired_size - new_size[1]\n    delta_h = desired_size - new_size[0]\n    top, bottom = delta_h//2, delta_h-(delta_h//2)\n    left, right = delta_w//2, delta_w-(delta_w//2)\n\n    color = [0, 0, 0]\n    new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)\n\n    return new_im\n\nclass HDTF(data.Dataset):\n    def __init__(self, data_dir, max_num_frames=80, mode='train'):\n\n        super(HDTF, self).__init__()\n        self.data_dir = data_dir\n        self.max_num_frames = max_num_frames\n        self.mode = mode\n            \n        self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate_chunk'  \n        self.pose_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar_chunk'\n\n        # self.max_vals = torch.tensor([20, 10, 10,  7e-4,\n        #     7e+1,  9e+1]).to(torch.float32)\n        # self.min_vals = torch.tensor([-20, -10, -10,  4e-4,\n        #     5e+1,  6e+1]).to(torch.float32)\n\n        self.max_vals = torch.tensor([90, 90, 90,  1,\n            720,  1080]).to(torch.float32)\n        self.min_vals = torch.tensor([-90, -90, -90,  0,\n            0,  0]).to(torch.float32)\n\n        vid_list = []\n\n        # hdtf  \n        vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\\\n                            'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\\\n                            'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\\\n                            'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\\\n                            'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\\\n                            'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\\\n                            'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\\\n                            'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\\\n                            'WRA_MarcoRubio_000']\n\n        bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000']\n\n        # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list]\n        # bad_id_name = [item + '.mp4' for item in bad_id_name]\n\n        
with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f:\n            self.len_dict = pkl.load(f)\n\n        if mode == 'train':\n            for id_name in os.listdir(data_dir):\n                # id_name = id_name[:-4]\n                if id_name in vid_id_name_list or id_name in bad_id_name:\n                    continue\n                vid_list.append(id_name)\n            self.videos = vid_list\n        if mode == 'test':\n            self.videos = vid_id_name_list\n\n    def check_head(self, frame_list, video_name, start, end):\n\n        start_path = self.get_pose_path(frame_list, video_name, start)\n        end_path = self.get_pose_path(frame_list, video_name, end)\n\n        if os.path.exists(start_path) and os.path.exists(end_path):\n            return True\n        else:\n            return False\n\n\n    def get_block_data_for_two(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = block_st % 25\n        ed_pos = block_ed % 25\n\n        block_st_name = 'chunk_%04d.npy' % (block_st)\n        block_ed_name = 'chunk_%04d.npy' % (block_ed)\n\n        if block_st != block_ed:\n            block_st_path = os.path.join(path, block_st_name)\n            block_ed_path = os.path.join(path, block_ed_name)\n            block_st = np.load(block_st_path)\n            block_ed = np.load(block_ed_path)\n\n            return np.concatenate((block_st[st_pos:], block_ed[:ed_pos]))\n        else:\n            block_st_path = os.path.join(path, block_st_name)\n            block_st = np.load(block_st_path)\n            return block_st[st_pos, ed_pos]\n\n    def get_block_data(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = start % 25\n        ed_pos = end % 25\n\n        block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)]\n\n        if block_st != block_ed:\n            arr_list = []\n            block_st = np.load(block_list[0])\n            arr_list.append(block_st[st_pos:])\n            for path in block_list[1:-1]:\n                arr_list.append(np.load(path))\n\n            block_ed = np.load(block_list[-1])\n            arr_list.append(block_ed[:ed_pos])\n\n            return np.concatenate(arr_list)\n        else:\n            block_st_path = os.path.join(path, block_list[0])\n            block_st = np.load(block_st_path)\n            return block_st[st_pos: ed_pos]\n            \n\n    def check_len(self, name):\n        return self.len_dict[name]\n\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, idx):\n        video_name = self.videos[idx]\n        path = os.path.join(self.data_dir, video_name)\n        hubert_path = os.path.join(self.hubert_dir, video_name)\n        pose_path = os.path.join(self.pose_dir, video_name)\n        # eye_blink_path = os.path.join(self.eye_blink_dir, video_name)\n\n        total_num_frames = self.check_len(video_name)\n        \n\n        if total_num_frames <= self.max_num_frames:\n            sample_frames = total_num_frames\n            start = 0\n        else:\n            sample_frames = 
self.max_num_frames\n            start = np.random.randint(total_num_frames-self.max_num_frames)\n        start=start\n        stop=sample_frames+start\n\n        sample_hubert_feature_npy = self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32)\n        sample_pose_list_npy = self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32)\n\n        sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy)\n        sample_pos_feature_tensor = torch.tensor(sample_pose_list_npy)[:,:-1]\n        sample_pos_feature_tensor = (sample_pos_feature_tensor - self.min_vals)/ (self.max_vals - self.min_vals) \n        video_name = video_name.replace('/','_')\n\n        # sample_pose_list_npy = sample_pose_list_npy.transpose(1,0)  # for compatibility\n\n        return sample_hubert_feature_tensor, sample_pos_feature_tensor, video_name, start\n\n    def update_parameters(self, parameters):\n        _, self.pos_dim = self[0][1].shape\n        _, self.audio_dim = self[0][0].shape\n        parameters[\"audio_dim\"] = self.audio_dim\n        parameters[\"pos_dim\"] = self.pos_dim\n        # parameters[\"njoints\"] = self.njoints\n\n\n\n\nif __name__ == \"__main__\":\n    # hdtf\n    data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n    # pose_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/pose\"\n    # crema\n    # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'\n    dataset = HDTF(data_dir=data_dir, mode='train')\n    for i in range(10):\n        dataset.__getitem__(i)\n        print('------')    \n\n    test_dataset = data.DataLoader(dataset=dataset,\n                                    batch_size=10,\n                                    num_workers=8,\n                                    shuffle=False)\n    for i, batch in enumerate(test_dataset):\n        print(i)\n"
  },
  {
    "path": "PBnet/src/datasets/datasets_hdtf_pos_chunk_norm_eye_fast.py",
    "content": "from os import name\nimport sys\n# sys.path.append('your_path')\n\nimport os\nimport random\nimport torch\n\nimport numpy as np\nimport torch.utils.data as data\nimport torch.nn.functional as Ft\nimport imageio.v2 as imageio\n\nimport cv2\nimport torchvision.transforms.functional as F\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nfrom scipy.interpolate import interp1d\n# import decord\nfrom torchvision.transforms.functional import to_pil_image\nfrom torchvision import transforms\nimport time\nimport pickle as pkl\n\n# decord.bridge.set_bridge('torch')\n\n\ndef resize(im, desired_size, interpolation):\n    old_size = im.shape[:2]\n    ratio = float(desired_size)/max(old_size)\n    new_size = tuple(int(x*ratio) for x in old_size)\n\n    im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation)\n    delta_w = desired_size - new_size[1]\n    delta_h = desired_size - new_size[0]\n    top, bottom = delta_h//2, delta_h-(delta_h//2)\n    left, right = delta_w//2, delta_w-(delta_w//2)\n\n    color = [0, 0, 0]\n    new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)\n\n    return new_im\n\nclass HDTF(data.Dataset):\n    def __init__(self, data_dir, max_num_frames=80, mode='train'):\n\n        super(HDTF, self).__init__()\n        self.data_dir = data_dir\n        self.max_num_frames = max_num_frames\n        self.mode = mode\n            \n        self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate'  #hdtf hubert \n        # self.hubert_dir =  '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wavlm_interpolate_chunk'   \n        self.pose_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar'\n        self.eye_blink_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/eye_blink_bbox_from_xpc_bar'\n\n        # self.max_vals = torch.tensor([20, 10, 10,  7e-4,\n        #     7e+1,  9e+1]).to(torch.float32)\n        # self.min_vals = torch.tensor([-20, -10, -10,  4e-4,\n        #     5e+1,  6e+1]).to(torch.float32)\n\n        self.max_vals = torch.tensor([90, 90, 90,  1,\n            720,  1080]).to(torch.float32)\n        self.min_vals = torch.tensor([-90, -90, -90,  0,\n            0,  0]).to(torch.float32)\n\n        vid_list = []\n\n        # hdtf  \n        vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\\\n                            'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\\\n                            'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\\\n                            'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\\\n                            'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\\\n                            'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\\\n                            'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\\\n                            'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\\\n                            'WRA_MarcoRubio_000']\n\n        bad_id_name = 
['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000']\n\n        \n\n        # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list]\n        # bad_id_name = [item + '.mp4' for item in bad_id_name]\n\n        with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f:\n            self.len_dict = pkl.load(f)\n\n        if mode == 'train':\n            for id_name in os.listdir(data_dir):\n                # id_name = id_name[:-4]\n                if id_name in vid_id_name_list or id_name in bad_id_name:\n                    continue\n                vid_list.append(id_name)\n            self.videos = vid_list\n        if mode == 'test':\n            self.videos = vid_id_name_list\n\n        self.cache_audio = {}\n        self.cache_eye = {}\n        self.cache_pose = {}\n        \n        for video in self.videos:\n            hubert_path = os.path.join(self.hubert_dir, video) + '.npy'\n            pose_path = os.path.join(self.pose_dir, video) + '.npy'\n            eye_blink_path = os.path.join(self.eye_blink_dir, video) + '.npy'\n\n            hubert_fea = np.load(hubert_path)\n            pose_fea = np.load(pose_path)\n            blink_fea = np.load(eye_blink_path)\n\n            self.cache_audio[video] = hubert_fea\n            self.cache_pose[video] = pose_fea\n            self.cache_eye[video] = blink_fea\n\n    def check_head(self, frame_list, video_name, start, end):\n        start_path = self.get_pose_path(frame_list, video_name, start)\n        end_path = self.get_pose_path(frame_list, video_name, end)\n\n        if os.path.exists(start_path) and os.path.exists(end_path):\n            return True\n        else:\n            return False\n\n\n    def get_block_data_for_two(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = block_st % 25\n        ed_pos = block_ed % 25\n\n        block_st_name = 'chunk_%04d.npy' % (block_st)\n        block_ed_name = 'chunk_%04d.npy' % (block_ed)\n\n        if block_st != block_ed:\n            block_st_path = os.path.join(path, block_st_name)\n            block_ed_path = os.path.join(path, block_ed_name)\n            block_st = np.load(block_st_path)\n            block_ed = np.load(block_ed_path)\n\n            return np.concatenate((block_st[st_pos:], block_ed[:ed_pos]))\n        else:\n            block_st_path = os.path.join(path, block_st_name)\n            block_st = np.load(block_st_path)\n            return block_st[st_pos, ed_pos]\n\n    def get_block_data(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = start % 25\n        ed_pos = end % 25\n\n        block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)]\n\n        if block_st != block_ed:\n            arr_list = []\n            block_st = np.load(block_list[0])\n            arr_list.append(block_st[st_pos:])\n            for path in block_list[1:-1]:\n                arr_list.append(np.load(path))\n\n            block_ed = np.load(block_list[-1])\n            arr_list.append(block_ed[:ed_pos])\n\n            return np.concatenate(arr_list)\n   
     else:\n            block_st_path = os.path.join(path, block_list[0])\n            block_st = np.load(block_st_path)\n            return block_st[st_pos: ed_pos]\n            \n\n    def check_len(self, name):\n        return self.len_dict[name]\n\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, idx):\n        video_name = self.videos[idx]\n        # path = os.path.join(self.data_dir, video_name)\n        # hubert_path = os.path.join(self.hubert_dir, video_name)\n        # pose_path = os.path.join(self.pose_dir, video_name)\n        # eye_blink_path = os.path.join(self.eye_blink_dir, video_name)\n\n        total_num_frames = self.check_len(video_name)\n        \n\n        if total_num_frames <= self.max_num_frames:\n            sample_frames = total_num_frames\n            start = 0\n        else:\n            sample_frames = self.max_num_frames\n            start = np.random.randint(total_num_frames-self.max_num_frames)\n        start=start\n        stop=sample_frames+start\n\n        \n        start_time = time.time()\n        # sample_hubert_feature_npy = self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32)\n        # sample_pose_list_npy = self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32)\n\n        sample_hubert_feature_npy = self.cache_audio[video_name][start:stop].astype(np.float32)\n        sample_pose_list_npy = self.cache_pose[video_name][start:stop].astype(np.float32)\n        sample_eye_blink_list_npy = self.cache_eye[video_name][start:stop].astype(np.float32)\n\n        # end_time = time.time()\n        # print(\"dataset_audiopose_cost: \", - start_time + end_time)\n        # start_time = time.time()\n\n        # sample_eye_blink_list_npy = self.get_block_data(path = eye_blink_path, start = start, end = stop).astype(np.float32)\n\n        # end_time = time.time()\n        # print(\"dataset_eye_cost: \", - start_time + end_time)\n        # start_time = time.time()\n\n        sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy)\n        sample_pos_feature_tensor = torch.tensor(sample_pose_list_npy)[:,:-1]\n        sample_pos_feature_tensor = (sample_pos_feature_tensor - self.min_vals)/ (self.max_vals - self.min_vals) \n\n        # end_time = time.time()\n        # print(\"dataset_audiopose_cost2: \", - start_time + end_time)\n        # start_time = time.time()\n\n        sample_eye_feature_tensor = torch.tensor(sample_eye_blink_list_npy)[:,:2]\n\n        # end_time = time.time()\n        # print(\"dataset_eye_cost2: \", - start_time + end_time)\n        # start_time = time.time()\n\n        sample_pos_eye_cat_tensor  = torch.cat((sample_pos_feature_tensor,sample_eye_feature_tensor),dim=1)\n\n        # end_time = time.time()\n        # print(\"dataset_eye_cost3: \", - start_time + end_time)\n        # start_time = time.time()\n\n        \n        video_name = video_name.replace('/','_')\n\n        # sample_pose_list_npy = sample_pose_list_npy.transpose(1,0)  # for compatibility\n\n        return sample_hubert_feature_tensor, sample_pos_feature_tensor, sample_eye_feature_tensor, video_name, start, sample_pos_eye_cat_tensor \n\n    def update_parameters(self, parameters):\n        _, self.pos_dim = self[0][1].shape\n        _, self.eye_dim = self[0][2].shape\n        _, self.audio_dim = self[0][0].shape\n        parameters[\"audio_dim\"] = self.audio_dim\n        parameters[\"pos_dim\"] = self.pos_dim\n        parameters[\"eye_dim\"] = self.eye_dim\n       
 # parameters[\"njoints\"] = self.njoints\n\n\n\n\nif __name__ == \"__main__\":\n    # hdtf\n    data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n    # pose_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/pose\"\n    # crema\n    # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'\n    dataset = HDTF(data_dir=data_dir, mode='train')\n    for i in range(10):\n        dataset.__getitem__(i)\n        print('------')    \n\n    test_dataset = data.DataLoader(dataset=dataset,\n                                    batch_size=10,\n                                    num_workers=8,\n                                    shuffle=False)\n    for i, batch in enumerate(test_dataset):\n        print(i)\n"
  },
  {
    "path": "PBnet/src/datasets/datasets_hdtf_pos_df.py",
    "content": "from os import name\nfrom src.datasets.datasets_hdtf_pos import HDTF\nimport sys\nsys.path.append('your_path')\n\nimport os\nimport random\nimport torch\n\nimport numpy as np\nimport torch.utils.data as data\nimport imageio.v2 as imageio\n\nimport cv2\nimport torchvision.transforms.functional as F\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nfrom scipy.interpolate import interp1d\n\n# from ..utils.tensors import collate\n\ndef resize(im, desired_size, interpolation):\n    old_size = im.shape[:2]\n    ratio = float(desired_size)/max(old_size)\n    new_size = tuple(int(x*ratio) for x in old_size)\n\n    im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation)\n    delta_w = desired_size - new_size[1]\n    delta_h = desired_size - new_size[0]\n    top, bottom = delta_h//2, delta_h-(delta_h//2)\n    left, right = delta_w//2, delta_w-(delta_w//2)\n\n    color = [0, 0, 0]\n    new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)\n\n    return new_im\n\nclass HDTF(data.Dataset):\n    def __init__(self, data_dir, max_num_frames=80, min_num_frames=40, mode='train'):\n\n        super(HDTF, self).__init__()\n        \n        self.data_dir = data_dir\n        self.max_num_frames = max_num_frames\n        self.min_num_frames = min_num_frames\n        self.mode = mode\n        # self.hubert_dir = '/train20/intern/permanent/lmlin2/data/crema_wav_hubert'\n        # self.pose_dir = '/train20/intern/permanent/hbcheng2/data/crema/pose'\n\n        self.pose_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar'\n        self.hubert_dir = '/train20/intern/permanent/lmlin2/data/hdtf_wav_hubert'\n        \n        vid_list = []\n        # # crema\n        # if mode == 'train':\n        #     for id_name in os.listdir(data_dir):\n        #         if id_name in ['s15','s20','s21','s30','s33','s52','s62','s81','s82','s89']: #['s64','s76','s88','s90','s91']\n        #             continue\n        #         vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ])\n\n        # # crema\n        # if mode == 'test':\n        #     for id_name in ['s15','s20','s21','s30','s33','s52','s62','s81','s82','s89']:\n        #         vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ])\n\n\n\n        # hdtf  \n        vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\\\n                            'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\\\n                            'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\\\n                            'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\\\n                            'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\\\n                            'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\\\n                            'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\\\n                            'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\\\n       
                     'WRA_MarcoRubio_000']\n        bad_id_name_list = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000', 'RD_Radio39_000']\n        if mode == 'train':\n            for id_name in os.listdir(data_dir):\n                if id_name in vid_id_name_list or id_name in bad_id_name_list:\n                    continue\n                vid_list.append(id_name)\n            self.videos = vid_list\n        if mode == 'test':\n            self.videos = vid_id_name_list\n\n\n    # def __len__(self):\n    #     return len(self.videos)\n\n    def __len__(self):\n        num_seq_max = getattr(self, \"num_seq_max\", -1)\n        if num_seq_max == -1:\n            from math import inf\n            num_seq_max = inf\n\n        return min(len(self.videos), num_seq_max)\n        \n    def __getitem__(self, idx):\n        \n        video_name = self.videos[idx]\n        path = os.path.join(self.data_dir, video_name)\n        # path_pose = os.path.join(self.pose_dir, video_name)\n\n        frame_path_list = os.listdir(path)\n        frame_path_list.sort()\n        total_num_frames = len(frame_path_list)\n\n        # pose_path_list = os.listdir(path_pose)\n        # pose_path_list.sort()\n\n        hubert_path = os.path.join(self.hubert_dir, video_name+'.npy')\n        hubert_feature = np.load(hubert_path)\n        Nframes_hubert = hubert_feature.shape[0]\n        interp_func = interp1d(np.arange(Nframes_hubert), hubert_feature, kind='linear', axis=0)\n        hubert_feature = interp_func(np.linspace(0, Nframes_hubert - 1, total_num_frames)).astype(np.float32)\n        \n        pose_path = os.path.join(self.pose_dir, video_name+'.npy')\n        pose_seq = np.load(pose_path).astype(np.float32)\n\n        cur_num_frames = np.random.randint(self.min_num_frames, self.max_num_frames+1)\n        if total_num_frames <= cur_num_frames:\n            sample_frames = total_num_frames\n            start = 0\n        else:\n            sample_frames = cur_num_frames\n            start = np.random.randint(total_num_frames-cur_num_frames)\n        sample_idx_list = np.linspace(start=start, stop=sample_frames+start-1, num=sample_frames, dtype=int)\n        # sample_frame_path_list = [frame_path_list[x] for x in sample_idx_list]\n        # sample_pose_path_list = [pose_path_list[x] for x in sample_idx_list]\n\n        sample_hubert_feature_list = [hubert_feature[x,:] for x in sample_idx_list]  # nf,1024\n        sample_hubert_feature_tensor = [torch.from_numpy(arr) for arr in sample_hubert_feature_list]\n        sample_hubert_feature_tensor = torch.stack(sample_hubert_feature_tensor)\n        # sample_hubert_feature_list = np.stack(sample_hubert_feature_list).reshape(-1)  # (nf*1024)\n\n        # load pose\n        try:\n            # sample_pose_list = [np.load(os.path.join(path_pose, x))[0][:-1].astype(np.float32) for x in sample_pose_path_list]\n            sample_pose_list = [pose_seq[x,:] for x in sample_idx_list]\n            sample_pos_feature_tensor = [torch.from_numpy(arr) for arr in sample_pose_list]\n            sample_pos_feature_tensor = torch.stack(sample_pos_feature_tensor) # nf, 6\n        except Exception:\n            # print(os.path.join(path_pose, x))\n            print(\"load fail !! 
\")\n            print(pose_path)\n            print(sample_idx_list)\n            sample_pose_list = [pose_seq[x,:] for x in sample_idx_list]\n            sample_pos_feature_tensor = [torch.from_numpy(arr) for arr in sample_pose_list]\n            sample_pos_feature_tensor = torch.stack(sample_pos_feature_tensor) # nf, 6\n        \n        # added to change the video_name of crema\n        video_name = video_name.replace('/','_')\n        # sample_class_tensor = torch.tensor(0)\n\n        return sample_hubert_feature_tensor, sample_pos_feature_tensor, video_name\n\n    def update_parameters(self, parameters):\n        _, self.pos_dim = self[0][1].shape\n        _, self.audio_dim = self[0][0].shape\n        parameters[\"audio_dim\"] = self.audio_dim\n        parameters[\"pos_dim\"] = self.pos_dim\n        # parameters[\"njoints\"] = self.njoints\n\n\nif __name__ == \"__main__\":\n    data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n    # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'\n    dataset = HDTF(data_dir=data_dir,mode='test')\n    for i in range(100):\n        dataset.__getitem__(i)\n        print('------')    \n\n    test_dataset = data.DataLoader(dataset=dataset,\n                                    batch_size=10,\n                                    num_workers=8,\n                                    shuffle=False)\n    for i, batch in enumerate(test_dataset):\n        print(i)\n"
  },
  {
    "path": "PBnet/src/datasets/datasets_hdtf_pos_dict_norm_2.py",
    "content": "from os import name\nimport sys\nsys.path.append('your_path')\n\nimport os\nimport random\nimport torch\n\nimport numpy as np\nimport torch.utils.data as data\nimport torch.nn.functional as Ft\nimport imageio.v2 as imageio\n\nimport cv2\nimport torchvision.transforms.functional as F\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nfrom scipy.interpolate import interp1d\n# import decord\nfrom torchvision.transforms.functional import to_pil_image\nfrom torchvision import transforms\nimport time\nimport pickle as pkl\n\n# decord.bridge.set_bridge('torch')\n\n\ndef resize(im, desired_size, interpolation):\n    old_size = im.shape[:2]\n    ratio = float(desired_size)/max(old_size)\n    new_size = tuple(int(x*ratio) for x in old_size)\n\n    im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation)\n    delta_w = desired_size - new_size[1]\n    delta_h = desired_size - new_size[0]\n    top, bottom = delta_h//2, delta_h-(delta_h//2)\n    left, right = delta_w//2, delta_w-(delta_w//2)\n\n    color = [0, 0, 0]\n    new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)\n\n    return new_im\n\nclass HDTF(data.Dataset):\n    def __init__(self, data_dir, max_num_frames=80, mode='train'):\n\n        super(HDTF, self).__init__()\n        self.data_dir = data_dir\n        self.max_num_frames = max_num_frames\n        self.mode = mode\n            \n        self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate'  \n        self.pose_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar'\n\n        # self.max_vals = torch.tensor([20, 10, 10,  7e-4,\n        #     7e+1,  9e+1]).to(torch.float32)\n        # self.min_vals = torch.tensor([-20, -10, -10,  4e-4,\n        #     5e+1,  6e+1]).to(torch.float32)\n\n        self.max_vals = torch.tensor([90, 90, 90,  1,\n            720,  1080]).to(torch.float32).cuda()\n        self.min_vals = torch.tensor([-90, -90, -90,  0,\n            0,  0]).to(torch.float32).cuda()\n\n        vid_list = []\n\n        # hdtf  \n        vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\\\n                            'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\\\n                            'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\\\n                            'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\\\n                            'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\\\n                            'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\\\n                            'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\\\n                            'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\\\n                            'WRA_MarcoRubio_000']\n\n        bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000']\n\n        # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list]\n        # bad_id_name = [item + '.mp4' for item in bad_id_name]\n\n        
with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f:\n            self.len_dict = pkl.load(f)\n\n        if mode == 'train':\n            for id_name in os.listdir(data_dir):\n                # id_name = id_name[:-4]\n                if id_name in vid_id_name_list or id_name in bad_id_name:\n                    continue\n                vid_list.append(id_name)\n            self.videos = vid_list\n        if mode == 'test':\n            self.videos = vid_id_name_list\n\n        self.audio_dict = {}\n        self.pose_dict = {}\n\n        for video_name in self.videos:\n            hubert_path = os.path.join(self.hubert_dir, video_name + '.npy')\n            pose_path = os.path.join(self.pose_dir, video_name + '.npy')\n            \n            hubert_npy = torch.tensor(np.load(hubert_path).astype(np.float32))\n            pose_npy = torch.tensor(np.load(pose_path).astype(np.float32))\n\n            self.audio_dict[video_name] = hubert_npy\n            self.pose_dict[video_name] = pose_npy\n\n\n    def check_head(self, frame_list, video_name, start, end):\n        start_path = self.get_pose_path(frame_list, video_name, start)\n        end_path = self.get_pose_path(frame_list, video_name, end)\n\n        if os.path.exists(start_path) and os.path.exists(end_path):\n            return True\n        else:\n            return False\n\n\n    def get_block_data_for_two(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = block_st % 25\n        ed_pos = block_ed % 25\n\n        block_st_name = 'chunk_%04d.npy' % (block_st)\n        block_ed_name = 'chunk_%04d.npy' % (block_ed)\n\n        if block_st != block_ed:\n            block_st_path = os.path.join(path, block_st_name)\n            block_ed_path = os.path.join(path, block_ed_name)\n            block_st = np.load(block_st_path)\n            block_ed = np.load(block_ed_path)\n\n            return np.concatenate((block_st[st_pos:], block_ed[:ed_pos]))\n        else:\n            block_st_path = os.path.join(path, block_st_name)\n            block_st = np.load(block_st_path)\n            return block_st[st_pos, ed_pos]\n\n    def get_block_data(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = start % 25\n        ed_pos = end % 25\n\n        block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)]\n\n        if block_st != block_ed:\n            arr_list = []\n            block_st = np.load(block_list[0])\n            arr_list.append(block_st[st_pos:])\n            for path in block_list[1:-1]:\n                arr_list.append(np.load(path))\n\n            block_ed = np.load(block_list[-1])\n            arr_list.append(block_ed[:ed_pos])\n\n            return np.concatenate(arr_list)\n        else:\n            block_st_path = os.path.join(path, block_list[0])\n            block_st = np.load(block_st_path)\n            return block_st[st_pos: ed_pos]\n            \n\n    def check_len(self, name):\n        return self.len_dict[name]\n\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, 
idx):\n        # if __debug__:\n        video_name = self.videos[idx]\n        # path = os.path.join(self.data_dir, video_name)\n        # hubert_path = os.path.join(self.hubert_dir, video_name)\n        # pose_path = os.path.join(self.pose_dir, video_name)\n        # eye_blink_path = os.path.join(self.eye_blink_dir, video_name)\n\n        total_num_frames = self.check_len(video_name)\n        \n\n        if total_num_frames <= self.max_num_frames:\n            sample_frames = total_num_frames\n            start = 0\n        else:\n            sample_frames = self.max_num_frames\n            start = np.random.randint(total_num_frames-self.max_num_frames)\n        start=start\n        stop=sample_frames+start\n\n        # end_time = time.time()\n        # print(\"indexing: \", - start_time + end_time)\n        # start_time = time.time()\n            \n        sample_hubert_feature_tensor = self.audio_dict[video_name][start:stop].cuda() # self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32)\n        sample_pos_feature_tensor = self.pose_dict[video_name][start:stop][:,:-1].cuda() # self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32)\n\n        # end_time = time.time()\n        # print(\"loading: \", - start_time + end_time)\n        # start_time = time.time()\n\n        # sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy)\n        # sample_pos_feature_tensor = torch.tensor(sample_pose_list_npy)[:,:-1]\n        # end_time = time.time()\n        # print(\"converting: \", - start_time + end_time)\n        # start_time = time.time()\n        sample_pos_feature_tensor = (sample_pos_feature_tensor - self.min_vals)/ (self.max_vals - self.min_vals) \n        video_name = video_name.replace('/','_')\n\n        # end_time = time.time()\n        # print(\"processing: \", - start_time + end_time)\n        # start_time = time.time()\n\n        # sample_pose_list_npy = sample_pose_list_npy.transpose(1,0)  # for compatibility\n\n        return sample_hubert_feature_tensor, sample_pos_feature_tensor, video_name, start\n\n    def update_parameters(self, parameters):\n        _, self.pos_dim = self[0][1].shape\n        _, self.audio_dim = self[0][0].shape\n        parameters[\"audio_dim\"] = self.audio_dim\n        parameters[\"pos_dim\"] = self.pos_dim\n        # parameters[\"njoints\"] = self.njoints\n\n\n\n\nif __name__ == \"__main__\":\n    # hdtf\n    data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n    # pose_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/pose\"\n    # crema\n    # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'\n    dataset = HDTF(data_dir=data_dir, mode='train')\n    for i in range(10):\n        dataset.__getitem__(i)\n        print('------')    \n\n    test_dataset = data.DataLoader(dataset=dataset,\n                                    batch_size=10,\n                                    num_workers=8,\n                                    shuffle=False)\n    for i, batch in enumerate(test_dataset):\n        print(i)\n"
  },
  {
    "path": "PBnet/src/datasets/datasets_hdtf_wpose_lmk_block.py",
    "content": "from os import name\nimport sys\nsys.path.append('your_path')\n\nimport os\nimport random\nimport torch\n\nimport numpy as np\nimport torch.utils.data as data\nimport torch.nn.functional as Ft\nimport imageio.v2 as imageio\n\nimport cv2\nimport torchvision.transforms.functional as F\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nfrom scipy.interpolate import interp1d\nimport decord\nfrom torchvision.transforms.functional import to_pil_image\nfrom torchvision import transforms\nimport time\nimport pickle as pkl\n\ndecord.bridge.set_bridge('torch')\n\n\ndef resize(im, desired_size, interpolation):\n    old_size = im.shape[:2]\n    ratio = float(desired_size)/max(old_size)\n    new_size = tuple(int(x*ratio) for x in old_size)\n\n    im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation)\n    delta_w = desired_size - new_size[1]\n    delta_h = desired_size - new_size[0]\n    top, bottom = delta_h//2, delta_h-(delta_h//2)\n    left, right = delta_w//2, delta_w-(delta_w//2)\n\n    color = [0, 0, 0]\n    new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)\n\n    return new_im\n\nclass HDTF(data.Dataset):\n    def __init__(self, data_dir, pose_dir, eye_blink_dir, max_num_frames=80, image_size=128,mode='train',\n                 mean=(128, 128, 128), color_jitter=True):\n\n        super(HDTF, self).__init__()\n        self.mean = torch.tensor(mean)[None,:,None,None]\n        self.data_dir = data_dir\n        self.pose_dir = pose_dir\n        self.eye_blink_dir = eye_blink_dir\n        self.is_jitter = color_jitter\n        self.max_num_frames = max_num_frames\n        self.image_size = image_size\n        self.mode = mode\n\n        vid_list = []\n        # # crema\n        # self.hubert_dir = '/train20/intern/permanent/lmlin2/data/crema_wav_hubert'\n        # if mode == 'train':\n        #     for id_name in os.listdir(data_dir):\n        #         if id_name in ['s64','s76','s88','s90','s91']:\n        #             continue\n        #         vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ])\n        # if mode == 'test':\n        #     for id_name in ['s64','s76','s88','s90','s91']:\n        #         vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ])\n        # self.videos = vid_list\n\n        # hdtf  \n        vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\\\n                            'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\\\n                            'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\\\n                            'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\\\n                            'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\\\n                            'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\\\n                            'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\\\n                            
'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\\\n                            'WRA_MarcoRubio_000']\n\n        bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000']\n\n        # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list]\n        # bad_id_name = [item + '.mp4' for item in bad_id_name]\n        # hdtf  \n        self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate_chunk'  \n        with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f:\n            self.len_dict = pkl.load(f)\n        # vid_id_name_list = ['RD_Radio47_000','WDA_CatherineCortezMasto_000','WDA_JoeNeguse_001','WDA_MichelleLujanGrisham_000','WRA_ErikPaulsen_002', \\\n        #                     'WDA_ZoeLofgren_000','WRA_JebHensarling2_003','WRA_MichaelSteele_000', 'WRA_ToddYoung_000', 'WRA_VickyHartzler_000']\n        if mode == 'train':\n            for id_name in os.listdir(data_dir):\n                # id_name = id_name[:-4]\n                if id_name in vid_id_name_list or id_name in bad_id_name:\n                    continue\n                vid_list.append(id_name)\n            self.videos = vid_list\n        if mode == 'test':\n            self.videos = vid_id_name_list\n\n    def check_head(self, frame_list, video_name, start, end):\n        start_path = self.get_pose_path(frame_list, video_name, start)\n        end_path = self.get_pose_path(frame_list, video_name, end)\n\n        if os.path.exists(start_path) and os.path.exists(end_path):\n            return True\n        else:\n            return False\n\n\n    def get_block_data_for_two(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = block_st % 25\n        ed_pos = block_ed % 25\n\n        block_st_name = 'chunk_%04d.npy' % (block_st)\n        block_ed_name = 'chunk_%04d.npy' % (block_ed)\n\n        if block_st != block_ed:\n            block_st_path = os.path.join(path, block_st_name)\n            block_ed_path = os.path.join(path, block_ed_name)\n            block_st = np.load(block_st_path)\n            block_ed = np.load(block_ed_path)\n\n            return np.concatenate((block_st[st_pos:], block_ed[:ed_pos]))\n        else:\n            block_st_path = os.path.join(path, block_st_name)\n            block_st = np.load(block_st_path)\n            return block_st[st_pos, ed_pos]\n\n    def get_block_data(self, path, start, end):\n        # TODO： id function\n        '''\n        input: \n            start: start id\n            end:  end id\n        output:\n            the data from block\n        '''\n\n        block_st = start//25\n        block_ed = end//25\n\n        st_pos = start % 25\n        ed_pos = end % 25\n\n        block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)]\n\n        if block_st != block_ed:\n            arr_list = []\n            block_st = np.load(block_list[0])\n            arr_list.append(block_st[st_pos:])\n            for path in block_list[1:-1]:\n                arr_list.append(np.load(path))\n\n            block_ed = np.load(block_list[-1])\n            arr_list.append(block_ed[:ed_pos])\n\n            return np.concatenate(arr_list)\n        else:\n            block_st_path 
= os.path.join(path, block_list[0])\n            block_st = np.load(block_st_path)\n            return block_st[st_pos: ed_pos]\n            \n\n    def check_len(self, name):\n        \n        return self.len_dict[name]\n\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, idx):\n        video_name = self.videos[idx]\n        path = os.path.join(self.data_dir, video_name)\n        hubert_path = os.path.join(self.hubert_dir, video_name)\n        pose_path = os.path.join(self.pose_dir, video_name)\n        eye_blink_path = os.path.join(self.eye_blink_dir, video_name)\n\n        total_num_frames = self.check_len(video_name)\n\n        \n\n        \n\n        if total_num_frames <= self.max_num_frames:\n            sample_frames = total_num_frames\n            start = 0\n        else:\n            sample_frames = self.max_num_frames\n            start = np.random.randint(total_num_frames-self.max_num_frames)\n        start=start\n        stop=sample_frames+start\n\n        sample_frame_npy = self.get_block_data(path = path, start = start, end = stop)\n        sample_hubert_feature_npy = self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32)\n        sample_pose_list_npy = self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32)\n        sample_eye_blink_list_npy = self.get_block_data(path = eye_blink_path, start = start, end = stop).astype(np.float32)\n\n        sample_frame_list = torch.tensor(sample_frame_npy).permute(0,3,1,2)\n        sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy)\n        sample_frame_list = sample_frame_list - self.mean # 20, 3, 128, 128\n        # sample_frame_list = [np.transpose(x, (2, 0, 1)) for x in sample_frame_list]\n        # sample_frame_list_npy = np.stack(sample_frame_list, axis=1) \n        # sample_pose_list_npy = np.stack(sample_pose_list, axis = 1)\n        # sample_eye_blink_list_npy = np.stack(sample_eye_blink_list, axis = 1)\n        # change to float32\n        sample_frame_list = sample_frame_list.permute(1, 0, 2, 3)\n        # sample_frame_list = np.array(sample_frame_list/255.0, dtype=np.float32)  #3, 40, 128, 128\n        # sample_frame_list = sample_frame_list/255.  
# put to mode l forward\n        # added to change the video_name of crema\n        video_name = video_name.replace('/','_')\n\n        sample_pose_list_npy = sample_pose_list_npy.transpose(1,0)  # for compatibility\n        sample_eye_blink_list_npy = sample_eye_blink_list_npy.transpose(1,0)\n        \n        if self.mode == 'test':\n            return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, video_name, start\n        return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, video_name, total_num_frames\n\n\n\n\nif __name__ == \"__main__\":\n    # hdtf\n    data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n    pose_dir = \"/train20/intern/permanent/hbcheng2/data/HDTF/pose\"\n    # crema\n    # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'\n    dataset = HDTF(data_dir=data_dir, pose_dir=pose_dir ,mode='train')\n    for i in range(10):\n        dataset.__getitem__(i)\n        print('------')    \n\n    test_dataset = data.DataLoader(dataset=dataset,\n                                    batch_size=10,\n                                    num_workers=8,\n                                    shuffle=False)\n    for i, batch in enumerate(test_dataset):\n        print(i)\n"
  },
  {
    "path": "PBnet/src/datasets/get_dataset.py",
    "content": "def get_dataset(name=\"ntu13\"):\n    if name == \"ntu13\":\n        from .ntu13 import NTU13\n        return NTU13\n    elif name == \"uestc\":\n        from .uestc import UESTC\n        return UESTC\n    elif name == \"humanact12\":\n        from .humanact12poses import HumanAct12Poses\n        return HumanAct12Poses\n\n\ndef get_datasets(parameters):\n    name = parameters[\"dataset\"]\n\n    DATA = get_dataset(name)\n    dataset = DATA(split=\"train\", **parameters)\n\n    train = dataset\n\n    # test: shallow copy (share the memory) but set the other indices\n    from copy import copy\n    test = copy(train)\n    test.split = test\n\n    datasets = {\"train\": train,\n                \"test\": test}\n\n    # add specific parameters from the dataset loading\n    dataset.update_parameters(parameters)\n\n    return datasets\n"
  },
  {
    "path": "PBnet/src/datasets/tools.py",
    "content": "import os\nimport string\n\n\ndef parse_info_name(path):\n    name = os.path.splitext(os.path.split(path)[-1])[0]\n    info = {}\n    current_letter = None\n    for letter in name:\n        if letter in string.ascii_letters:\n            info[letter] = []\n            current_letter = letter\n        else:\n            info[current_letter].append(letter)\n    for key in info.keys():\n        info[key] = \"\".join(info[key])\n    return info\n\n\n"
  },
  {
    "path": "PBnet/src/evaluate/__init__.py",
    "content": ""
  },
  {
    "path": "PBnet/src/evaluate/action2motion/accuracy.py",
    "content": "import torch\n\n\ndef calculate_accuracy(model, motion_loader, num_labels, classifier, device):\n    confusion = torch.zeros(num_labels, num_labels, dtype=torch.long)\n    with torch.no_grad():\n        for batch in motion_loader:\n            batch_prob = classifier(batch[\"output_xyz\"], lengths=batch[\"lengths\"])\n            batch_pred = batch_prob.max(dim=1).indices\n            for label, pred in zip(batch[\"y\"], batch_pred):\n                confusion[label][pred] += 1\n\n    accuracy = torch.trace(confusion)/torch.sum(confusion)\n    return accuracy.item(), confusion\n"
  },
  {
    "path": "PBnet/src/evaluate/action2motion/diversity.py",
    "content": "import torch\nimport numpy as np\n\n\n# from action2motion\ndef calculate_diversity_multimodality(activations, labels, num_labels):\n    diversity_times = 200\n    multimodality_times = 20\n    labels = labels.long()\n    num_motions = len(labels)\n\n    diversity = 0\n        \n    first_indices = np.random.randint(0, num_motions, diversity_times)\n    second_indices = np.random.randint(0, num_motions, diversity_times)\n    for first_idx, second_idx in zip(first_indices, second_indices):\n        diversity += torch.dist(activations[first_idx, :],\n                                activations[second_idx, :])\n    diversity /= diversity_times\n\n    multimodality = 0\n    label_quotas = np.repeat(multimodality_times, num_labels)\n    while np.any(label_quotas > 0):\n        # print(label_quotas)\n        first_idx = np.random.randint(0, num_motions)\n        first_label = labels[first_idx]\n        if not label_quotas[first_label]:\n            continue\n\n        second_idx = np.random.randint(0, num_motions)\n        second_label = labels[second_idx]\n        while first_label != second_label:\n            second_idx = np.random.randint(0, num_motions)\n            second_label = labels[second_idx]\n\n        label_quotas[first_label] -= 1\n\n        first_activation = activations[first_idx, :]\n        second_activation = activations[second_idx, :]\n        multimodality += torch.dist(first_activation,\n                                    second_activation)\n\n    multimodality /= (multimodality_times * num_labels)\n\n    return diversity.item(), multimodality.item()\n\n"
  },
  {
    "path": "PBnet/src/evaluate/action2motion/evaluate.py",
    "content": "import torch\nimport numpy as np\nfrom .models import load_classifier, load_classifier_for_fid\nfrom .accuracy import calculate_accuracy\nfrom .fid import calculate_fid\nfrom .diversity import calculate_diversity_multimodality\n\n\nclass A2MEvaluation:\n    def __init__(self, dataname, device):\n        dataset_opt = {\"ntu13\": {\"joints_num\": 18,\n                                 \"input_size_raw\": 54,\n                                 \"num_classes\": 13},\n                       'humanact12': {\"input_size_raw\": 72,\n                                      \"joints_num\": 24,\n                                      \"num_classes\": 12}}\n        \n        if dataname != dataset_opt.keys():\n            assert NotImplementedError(f\"{dataname} is not supported.\")\n        \n        self.dataname = dataname\n        self.input_size_raw = dataset_opt[dataname][\"input_size_raw\"]\n        self.num_classes = dataset_opt[dataname][\"num_classes\"]\n        self.device = device\n        \n        self.gru_classifier_for_fid = load_classifier_for_fid(dataname, self.input_size_raw,\n                                                              self.num_classes, device).eval()\n        self.gru_classifier = load_classifier(dataname, self.input_size_raw,\n                                              self.num_classes, device).eval()\n        \n    def compute_features(self, model, motionloader):\n        # calculate_activations_labels function from action2motion\n        activations = []\n        labels = []\n        with torch.no_grad():\n            for idx, batch in enumerate(motionloader):\n                activations.append(self.gru_classifier_for_fid(batch[\"output_xyz\"], lengths=batch[\"lengths\"]))\n                labels.append(batch[\"y\"])\n            activations = torch.cat(activations, dim=0)\n            labels = torch.cat(labels, dim=0)\n        return activations, labels\n\n    @staticmethod\n    def calculate_activation_statistics(activations):\n        activations = activations.cpu().numpy()\n        mu = np.mean(activations, axis=0)\n        sigma = np.cov(activations, rowvar=False)\n        return mu, sigma\n\n    def evaluate(self, model, loaders):\n        \n        def print_logs(metric, key):\n            print(f\"Computing action2motion {metric} on the {key} loader ...\")\n            \n        metrics = {}\n        \n        computedfeats = {}\n        for key, loader in loaders.items():\n            metric = \"accuracy\"\n            print_logs(metric, key)\n            mkey = f\"{metric}_{key}\"\n            metrics[mkey], _ = calculate_accuracy(model, loader,\n                                                  self.num_classes,\n                                                  self.gru_classifier, self.device)\n\n            # features for diversity\n            print_logs(\"features\", key)\n            feats, labels = self.compute_features(model, loader)\n            print_logs(\"stats\", key)\n            stats = self.calculate_activation_statistics(feats)\n            \n            computedfeats[key] = {\"feats\": feats,\n                                  \"labels\": labels,\n                                  \"stats\": stats}\n\n            print_logs(\"diversity\", key)\n            ret = calculate_diversity_multimodality(feats, labels, self.num_classes)\n            metrics[f\"diversity_{key}\"], metrics[f\"multimodality_{key}\"] = ret\n            \n        # taking the stats of the ground truth and remove it from the computed feats\n       
 gtstats = computedfeats[\"gt\"][\"stats\"]\n        # computing fid\n        for key, loader in computedfeats.items():\n            metric = \"fid\"\n            mkey = f\"{metric}_{key}\"\n            \n            stats = computedfeats[key][\"stats\"]\n            metrics[mkey] = float(calculate_fid(gtstats, stats))\n            \n        return metrics\n"
  },
  {
    "path": "PBnet/src/evaluate/action2motion/fid.py",
    "content": "import numpy as np\nfrom scipy import linalg\n\n\n# from action2motion\ndef calculate_fid(statistics_1, statistics_2):\n    return calculate_frechet_distance(statistics_1[0], statistics_1[1],\n                                      statistics_2[0], statistics_2[1])\n\n\ndef calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):\n    \"\"\"Numpy implementation of the Frechet Distance.\n    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)\n    and X_2 ~ N(mu_2, C_2) is\n            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).\n    Stable version by Dougal J. Sutherland.\n    Params:\n    -- mu1   : Numpy array containing the activations of a layer of the\n               inception net (like returned by the function 'get_predictions')\n               for generated samples.\n    -- mu2   : The sample mean over activations, precalculated on an\n               representative data set.\n    -- sigma1: The covariance matrix over activations for generated samples.\n    -- sigma2: The covariance matrix over activations, precalculated on an\n               representative data set.\n    Returns:\n    --   : The Frechet Distance.\n    \"\"\"\n\n    mu1 = np.atleast_1d(mu1)\n    mu2 = np.atleast_1d(mu2)\n\n    sigma1 = np.atleast_2d(sigma1)\n    sigma2 = np.atleast_2d(sigma2)\n\n    assert mu1.shape == mu2.shape, \\\n        'Training and test mean vectors have different lengths'\n    assert sigma1.shape == sigma2.shape, \\\n        'Training and test covariances have different dimensions'\n\n    diff = mu1 - mu2\n\n    # Product might be almost singular\n    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n    if not np.isfinite(covmean).all():\n        msg = ('fid calculation produces singular product; '\n               'adding %s to diagonal of cov estimates') % eps\n        print(msg)\n        offset = np.eye(sigma1.shape[0]) * eps\n        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n\n    # Numerical error might give slight imaginary component\n    if np.iscomplexobj(covmean):\n        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n            m = np.max(np.abs(covmean.imag))\n            raise ValueError('Imaginary component {}'.format(m))\n        covmean = covmean.real\n\n    tr_covmean = np.trace(covmean)\n\n    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)\n"
  },
  {
    "path": "PBnet/src/evaluate/action2motion/models.py",
    "content": "import torch\nimport torch.nn as nn\n\n\n# adapted from action2motion to take inputs of different lengths\nclass MotionDiscriminator(nn.Module):\n    def __init__(self, input_size, hidden_size, hidden_layer, device, output_size=12, use_noise=None):\n        super(MotionDiscriminator, self).__init__()\n        self.device = device\n\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.hidden_layer = hidden_layer\n        self.use_noise = use_noise\n\n        self.recurrent = nn.GRU(input_size, hidden_size, hidden_layer)\n        self.linear1 = nn.Linear(hidden_size, 30)\n        self.linear2 = nn.Linear(30, output_size)\n\n    def forward(self, motion_sequence, lengths=None, hidden_unit=None):\n        # dim (motion_length, num_samples, hidden_size)\n        bs, njoints, nfeats, num_frames = motion_sequence.shape\n        motion_sequence = motion_sequence.reshape(bs, njoints*nfeats, num_frames)\n        motion_sequence = motion_sequence.permute(2, 0, 1)\n        if hidden_unit is None:\n            # motion_sequence = motion_sequence.permute(1, 0, 2)\n            hidden_unit = self.initHidden(motion_sequence.size(1), self.hidden_layer)\n        gru_o, _ = self.recurrent(motion_sequence.float(), hidden_unit)\n\n        # select the last valid, instead of: gru_o[-1, :, :]\n        out = gru_o[tuple(torch.stack((lengths-1, torch.arange(bs, device=self.device))))]\n\n        # dim (num_samples, 30)\n        lin1 = self.linear1(out)\n        lin1 = torch.tanh(lin1)\n        # dim (num_samples, output_size)\n        lin2 = self.linear2(lin1)\n        return lin2\n\n    def initHidden(self, num_samples, layer):\n        return torch.randn(layer, num_samples, self.hidden_size, device=self.device, requires_grad=False)\n\n\nclass MotionDiscriminatorForFID(MotionDiscriminator):\n    def forward(self, motion_sequence, lengths=None, hidden_unit=None):\n        # dim (motion_length, num_samples, hidden_size)\n        bs, njoints, nfeats, num_frames = motion_sequence.shape\n        motion_sequence = motion_sequence.reshape(bs, njoints*nfeats, num_frames)\n        motion_sequence = motion_sequence.permute(2, 0, 1)\n        if hidden_unit is None:\n            # motion_sequence = motion_sequence.permute(1, 0, 2)\n            hidden_unit = self.initHidden(motion_sequence.size(1), self.hidden_layer)\n        gru_o, _ = self.recurrent(motion_sequence.float(), hidden_unit)\n\n        # select the last valid, instead of: gru_o[-1, :, :]\n        out = gru_o[tuple(torch.stack((lengths-1, torch.arange(bs, device=self.device))))]\n\n        # dim (num_samples, 30)\n        lin1 = self.linear1(out)\n        lin1 = torch.tanh(lin1)\n        return lin1\n\n\nclassifier_model_files = {\n    \"ntu13\": \"models/actionrecognition/ntu13_gru.tar\",\n    \"humanact12\": \"models/actionrecognition/humanact12_gru.tar\",\n}\n\n\ndef load_classifier(dataset_type, input_size_raw, num_classes, device):\n    model = torch.load(classifier_model_files[dataset_type], map_location=device)\n    classifier = MotionDiscriminator(input_size_raw, 128, 2, device=device, output_size=num_classes).to(device)\n    classifier.load_state_dict(model[\"model\"])\n    classifier.eval()\n    return classifier\n\n\ndef load_classifier_for_fid(dataset_type, input_size_raw, num_classes, device):\n    model = torch.load(classifier_model_files[dataset_type], map_location=device)\n    classifier = MotionDiscriminatorForFID(input_size_raw, 128, 2, device=device, output_size=num_classes).to(device)\n    
classifier.load_state_dict(model[\"model\"])\n    classifier.eval()\n    return classifier\n\n\ndef test():\n    from src.datasets.ntu13 import NTU13\n    import src.utils.fixseed  # noqa\n\n    classifier = load_classifier(\"ntu13\", input_size_raw=54, num_classes=13, device=\"cuda\").eval()\n    params = {\"pose_rep\": \"rot6d\",\n              \"translation\": True,\n              \"glob\": True,\n              \"jointstype\": \"a2m\",\n              \"vertstrans\": True,\n              \"num_frames\": 60,\n              \"sampling\": \"conseq\",\n              \"sampling_step\": 1}\n    dataset = NTU13(**params)\n\n    from src.models.rotation2xyz import Rotation2xyz\n    rot2xyz = Rotation2xyz(device=\"cuda\")\n    confusion_xyz = torch.zeros(13, 13, dtype=torch.long)\n    confusion = torch.zeros(13, 13, dtype=torch.long)\n\n    for i in range(1000):\n        dataset.pose_rep = \"xyz\"\n        data = dataset[i][0].to(\"cuda\")\n        data = data[None]\n\n        dataset.pose_rep = params[\"pose_rep\"]\n        x = dataset[i][0].to(\"cuda\")[None]\n        mask = torch.ones(1, x.shape[-1], dtype=bool, device=\"cuda\")\n        lengths = mask.sum(1)\n\n        xyz_t = rot2xyz(x, mask, **params)\n\n        predicted_cls_xyz = classifier(data, lengths=lengths).argmax().item()\n        predicted_cls = classifier(xyz_t, lengths=lengths).argmax().item()\n\n        gt_cls = dataset[i][1]\n\n        confusion_xyz[gt_cls][predicted_cls_xyz] += 1\n        confusion[gt_cls][predicted_cls] += 1\n\n    accuracy_xyz = torch.trace(confusion_xyz)/torch.sum(confusion_xyz).item()\n    accuracy = torch.trace(confusion)/torch.sum(confusion).item()\n\n    print(f\"accuracy: {accuracy:.1%}, accuracy_xyz: {accuracy_xyz:.1%}\")\n\n\nif __name__ == \"__main__\":\n    test()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae.py",
    "content": "import sys\nsys.path.append('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk import HDTF\n# from src.datasets.datasets_hdtf_pos_chunk_mel_3 import HDTF\nfrom src.evaluate.tvae_eval import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_debug.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_norm_2 import HDTF\nfrom src.evaluate.tvae_eval_train_norm import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_f3.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_3 import HDTF\nfrom src.evaluate.tvae_eval_std import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_f3_debug.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_3 import HDTF\nfrom src.evaluate.tvae_eval_train_std import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_f3_mel.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_mel_f3 import HDTF\nfrom src.evaluate.tvae_eval_std import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_norm.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_norm_2 import HDTF\nfrom src.evaluate.tvae_eval_norm import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    parameters[\"eye\"] = False\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_norm_all.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_norm_2_all import HDTF\nfrom src.evaluate.tvae_eval_norm_all import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=1e8,\n                        mode = 'test')\n        # dataset.update_parameters(parameters)\n        parameters[\"audio_dim\"] = 1024\n        parameters[\"pos_dim\"] = 6\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_norm_all_seg.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_norm_2_all import HDTF\nfrom src.evaluate.tvae_eval_norm_seg import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=1e8,\n                        mode = 'test')\n        # dataset.update_parameters(parameters)\n        parameters[\"audio_dim\"] = 1024\n        parameters[\"pos_dim\"] = 6\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_norm_all_seg_weye.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos_eye_fast_all import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_norm_2_all import HDTF\nfrom src.evaluate.tvae_eval_norm_seg import evaluate\n\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=1e8,\n                        mode = 'test')\n        # dataset.update_parameters(parameters)\n        parameters[\"audio_dim\"] = 1024\n        parameters[\"pos_dim\"] = 6\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_norm_all_seg_weye2.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos_eye_fast_all import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_norm_eye_fast import HDTF\nfrom src.evaluate.tvae_eval_norm_eye_pose_seg import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=1e8,\n                        mode = 'test')\n        # dataset.update_parameters(parameters)\n        parameters[\"audio_dim\"] = 1024\n        parameters[\"pos_dim\"] = 3\n        parameters['latent_dim'] = 128\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_norm_eye_pose.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos_eye_fast import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_norm_eye_fast import HDTF\nfrom src.evaluate.tvae_eval_norm_eye_pose import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_norm_eye_pose_test.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_norm_eye_fast import HDTF\nfrom src.evaluate.tvae_eval_norm_eye_pose import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/evaluate_cvae_onlyeye_all_seg.py",
    "content": "import sys\nsys.path.append('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema_pos import CREMA\nfrom src.datasets.datasets_hdtf_onlyeye_fast import HDTF\nfrom src.evaluate.tvae_eval_onlyeye_all_seg import evaluate\n\n\ndef main():\n    parameters, folder, checkpointname, epoch, niter = parser()\n    \n    # data path\n\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=1e8,\n                        mode = 'test')\n        dataset.update_parameters(parameters)\n        # parameters[\"audio_dim\"] = 1024\n        # parameters[\"pos_dim\"] = 6\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    evaluate(parameters, dataset, folder, checkpointname, epoch, niter)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/evaluate/othermetrics/acceleration.py",
    "content": "import torch\nimport numpy as np\n\nfrom src.utils.tensors import lengths_to_mask\n\n\ndef calculate_acceletation(motionloader, device, xyz):\n    # for now even if it is not xyz, the acceleration is one the euclidian/pose\n    outfeat = \"output_xyz\" if xyz else \"output\"\n\n    sum_acc = 0\n    num_acc = 0\n    for batch in motionloader:\n        motion = batch[outfeat].permute(0, 3, 1, 2)\n        bs, num_frames, njoints, nfeats = motion.shape\n        \n        velocity = motion[:, 1:] - motion[:, :-1]\n        acceleration = velocity[:, 1:] - velocity[:, :-1]\n        acceleration_normed = torch.linalg.norm(acceleration, axis=3)\n\n        lengths = batch[\"lengths\"]\n        mask = lengths_to_mask(lengths - 2)  # because acceleration\n\n        usefull_accs_n = acceleration_normed[mask]\n        sum_acc += usefull_accs_n.sum().item()\n        num_acc += np.prod(usefull_accs_n.shape)\n\n    acceleration = sum_acc/num_acc\n    return acceleration\n"
  },
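  {
    "path": "PBnet/src/evaluate/othermetrics/acceleration_example.py",
    "content": "# Illustrative sketch, NOT part of the original repo: a minimal, hedged example of how\n# calculate_acceletation (acceleration.py) can be called. The toy loader, its\n# (bs, njoints, nfeats, nframes) layout and the 'output_xyz'/'lengths' batch keys are\n# assumptions inferred from that file; run from the PBnet root so the src package resolves.\nimport torch\n\nfrom src.evaluate.othermetrics.acceleration import calculate_acceletation\n\n\ndef toy_loader(njoints=6, nfeats=3, nframes=20):\n    # a single fake batch of two motions, shaped like batch[\"output_xyz\"] in the evaluators\n    motion = torch.randn(2, njoints, nfeats, nframes)\n    lengths = torch.tensor([nframes, nframes - 5])\n    yield {\"output_xyz\": motion, \"lengths\": lengths}\n\n\nif __name__ == '__main__':\n    # mean norm of the second-order finite differences over the valid frames\n    acc = calculate_acceletation(toy_loader(), device='cpu', xyz=True)\n    print(f'mean acceleration norm: {acc:.4f}')\n"
  },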
  {
    "path": "PBnet/src/evaluate/othermetrics/evaluation.py",
    "content": "import torch\nimport numpy as np\n\nfrom ..action2motion.diversity import calculate_diversity_multimodality\nfrom .acceleration import calculate_acceletation\n\n\nclass OtherMetricsEvaluation:\n    \"\"\" Evaluation of some metrics in output space (not feature space):\n    - Acceleration metrics\n    - Reconstruction loss\n    - Diversity\n    - Multimodality\n    (Not used in the paper)\n\"\"\"\n    def __init__(self, device):\n        self.device = device\n        \n    def compute_features(self, model, motionloader, xyz=True):\n        feat = \"output_xyz\" if xyz else \"output\"\n        activations = []\n        labels = []\n        for idx, batch in enumerate(motionloader):\n            batch_motion = batch[feat]\n            batch_label = batch[\"y\"]\n            activations.append(batch_motion)\n            labels.append(batch_label)\n        activations = torch.cat(activations, dim=0)\n        activations = activations.reshape(activations.shape[0], -1)\n        labels = torch.cat(labels, dim=0)\n        return activations, labels\n\n    def reconstructionloss(self, motionloader, xyz=True):\n        infeat = \"x_xyz\" if xyz else \"x\"\n        outfeat = \"output_xyz\" if xyz else \"output\"\n\n        sum_loss = 0\n        num_loss = 0\n        for batch in motionloader:\n            motion_in = batch[infeat].permute(0, 3, 1, 2)\n            motion_out = batch[outfeat].permute(0, 3, 1, 2)\n            mask = batch[\"mask\"]\n\n            square_diff = (motion_in[mask] - motion_out[mask])**2\n            sum_loss += square_diff.sum().item()\n            num_loss += np.prod(square_diff.shape)\n\n        rcloss = sum_loss / num_loss\n\n        return rcloss\n    \n    def evaluate(self, model, num_classes, loaders, xyz=True):\n        # get the xyz as well\n        model.outputxyz = True\n        metrics = {}\n        repname = \"xyz\" if xyz else \"pose\"\n        \n        def print_logs(metric, key):\n            print(f\"Computing {metric} on the {key} loader ({repname})...\")\n            \n        for key, loader in loaders.items():\n            # acceleration\n            metric = \"acceleration\"\n            print_logs(metric, key)\n            mkey = f\"{metric}_{key}\"\n            metrics[mkey] = calculate_acceletation(loader, device=self.device, xyz=xyz)\n            \n            # features for diversity\n            print_logs(\"features\", key)\n            feats, labels = self.compute_features(model, loader, xyz=xyz)\n\n            # diversity and multimodality\n            metric = \"diversity\"\n            print_logs(metric, key)\n            ret = calculate_diversity_multimodality(feats, labels, num_classes)\n            metrics[f\"diversity_{key}\"], metrics[f\"multimodality_{key}\"] = ret\n\n        metric = \"rc_recons\"\n        print(f\"Computing reconstruction loss ({repname})..\")\n        rcloss = self.reconstructionloss(loaders[\"recons\"], xyz=xyz)\n        metrics[metric] = rcloss\n        return metrics\n"
  },
  {
    "path": "PBnet/src/evaluate/stgcn/accuracy.py",
    "content": "import torch\n\n\ndef calculate_accuracy(model, motion_loader, num_labels, classifier, device):\n    confusion = torch.zeros(num_labels, num_labels, dtype=torch.long)\n    with torch.no_grad():\n        for batch in motion_loader:\n            batch_prob = classifier(batch)[\"yhat\"]\n            batch_pred = batch_prob.max(dim=1).indices\n            for label, pred in zip(batch[\"y\"], batch_pred):\n                confusion[label][pred] += 1\n\n    accuracy = torch.trace(confusion)/torch.sum(confusion)\n    return accuracy.item(), confusion\n"
  },
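  {
    "path": "PBnet/src/evaluate/stgcn/accuracy_example.py",
    "content": "# Illustrative sketch, NOT part of the original repo: shows the calling contract of\n# calculate_accuracy (accuracy.py) with a dummy classifier. The 'y' batch key and the\n# 'yhat' output key mirror that file; the dummy classifier and random labels are made up.\nimport torch\n\nfrom src.evaluate.stgcn.accuracy import calculate_accuracy\n\n\nclass DummyClassifier:\n    def __init__(self, num_labels):\n        self.num_labels = num_labels\n\n    def __call__(self, batch):\n        # random logits: one row per sample, one column per class\n        return {\"yhat\": torch.randn(len(batch[\"y\"]), self.num_labels)}\n\n\nif __name__ == '__main__':\n    num_labels = 4\n    loader = [{\"y\": torch.randint(0, num_labels, (8,))} for _ in range(3)]\n    acc, confusion = calculate_accuracy(None, loader, num_labels,\n                                        DummyClassifier(num_labels), 'cpu')\n    print(acc)        # roughly chance level (~1/num_labels)\n    print(confusion)  # num_labels x num_labels confusion matrix\n"
  },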
  {
    "path": "PBnet/src/evaluate/stgcn/diversity.py",
    "content": "import torch\nimport numpy as np\n\n\n# from action2motion\ndef calculate_diversity_multimodality(activations, labels, num_labels, seed=None):\n    diversity_times = 200\n    multimodality_times = 20\n    labels = labels.long()\n    num_motions = len(labels)\n\n    diversity = 0\n\n    if seed is not None:\n        np.random.seed(seed)\n        \n    first_indices = np.random.randint(0, num_motions, diversity_times)\n    second_indices = np.random.randint(0, num_motions, diversity_times)\n    for first_idx, second_idx in zip(first_indices, second_indices):\n        diversity += torch.dist(activations[first_idx, :],\n                                activations[second_idx, :])\n    diversity /= diversity_times\n\n    multimodality = 0\n    label_quotas = np.repeat(multimodality_times, num_labels)\n    while np.any(label_quotas > 0):\n        # print(label_quotas)\n        first_idx = np.random.randint(0, num_motions)\n        first_label = labels[first_idx]\n        if not label_quotas[first_label]:\n            continue\n\n        second_idx = np.random.randint(0, num_motions)\n        second_label = labels[second_idx]\n        while first_label != second_label:\n            second_idx = np.random.randint(0, num_motions)\n            second_label = labels[second_idx]\n\n        label_quotas[first_label] -= 1\n\n        first_activation = activations[first_idx, :]\n        second_activation = activations[second_idx, :]\n        multimodality += torch.dist(first_activation,\n                                    second_activation)\n\n    multimodality /= (multimodality_times * num_labels)\n\n    return diversity.item(), multimodality.item()\n\n"
  },
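  {
    "path": "PBnet/src/evaluate/stgcn/diversity_example.py",
    "content": "# Illustrative sketch, NOT part of the original repo: runs the action2motion-style\n# diversity/multimodality metric (diversity.py) on random activations so the expected\n# input shapes are explicit. The sizes below are arbitrary; every label should occur in\n# 'labels', otherwise the multimodality resampling loop cannot terminate.\nimport torch\n\nfrom src.evaluate.stgcn.diversity import calculate_diversity_multimodality\n\n\nif __name__ == '__main__':\n    num_labels = 5\n    num_motions = 300\n    feat_dim = 256\n\n    activations = torch.randn(num_motions, feat_dim)        # one feature vector per motion\n    labels = torch.randint(0, num_labels, (num_motions,))   # class label per motion\n\n    diversity, multimodality = calculate_diversity_multimodality(\n        activations, labels, num_labels, seed=0)\n    print(f'diversity: {diversity:.3f}  multimodality: {multimodality:.3f}')\n"
  },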
  {
    "path": "PBnet/src/evaluate/stgcn/evaluate.py",
    "content": "import torch\nimport numpy as np\nfrom .accuracy import calculate_accuracy\nfrom .fid import calculate_fid\nfrom .diversity import calculate_diversity_multimodality\n\nfrom src.recognition.models.stgcn import STGCN\n\n\nclass Evaluation:\n    def __init__(self, dataname, parameters, device, seed=None):\n        layout = \"smpl\" if parameters[\"glob\"] else \"smpl_noglobal\"\n        model = STGCN(in_channels=parameters[\"nfeats\"],\n                      num_class=parameters[\"num_classes\"],\n                      graph_args={\"layout\": layout, \"strategy\": \"spatial\"},\n                      edge_importance_weighting=True,\n                      device=parameters[\"device\"])\n\n        model = model.to(parameters[\"device\"])\n\n        modelpath = \"models/actionrecognition/uestc_rot6d_stgcn.tar\"\n\n        state_dict = torch.load(modelpath, map_location=parameters[\"device\"])\n        model.load_state_dict(state_dict)\n        model.eval()\n\n        self.num_classes = parameters[\"num_classes\"]\n        self.model = model\n\n        self.dataname = dataname\n        self.device = device\n\n        self.seed = seed\n\n    def compute_features(self, model, motionloader):\n        # calculate_activations_labels function from action2motion\n        activations = []\n        labels = []\n        with torch.no_grad():\n            for idx, batch in enumerate(motionloader):\n                activations.append(self.model(batch)[\"features\"])\n                labels.append(batch[\"y\"])\n            activations = torch.cat(activations, dim=0)\n            labels = torch.cat(labels, dim=0)\n        return activations, labels\n\n    @staticmethod\n    def calculate_activation_statistics(activations):\n        activations = activations.cpu().numpy()\n        mu = np.mean(activations, axis=0)\n        sigma = np.cov(activations, rowvar=False)\n        return mu, sigma\n\n    def evaluate(self, model, loaders):\n        def print_logs(metric, key):\n            print(f\"Computing stgcn {metric} on the {key} loader ...\")\n\n        metrics_all = {}\n        for sets in [\"train\", \"test\"]:\n            computedfeats = {}\n            metrics = {}\n            for key, loaderSets in loaders.items():\n                loader = loaderSets[sets]\n\n                metric = \"accuracy\"\n                print_logs(metric, key)\n                mkey = f\"{metric}_{key}\"\n                metrics[mkey], _ = calculate_accuracy(model, loader,\n                                                      self.num_classes,\n                                                      self.model, self.device)\n                # features for diversity\n                print_logs(\"features\", key)\n                feats, labels = self.compute_features(model, loader)\n                print_logs(\"stats\", key)\n                stats = self.calculate_activation_statistics(feats)\n\n                computedfeats[key] = {\"feats\": feats,\n                                      \"labels\": labels,\n                                      \"stats\": stats}\n\n                print_logs(\"diversity\", key)\n                ret = calculate_diversity_multimodality(feats, labels, self.num_classes,\n                                                        seed=self.seed)\n                metrics[f\"diversity_{key}\"], metrics[f\"multimodality_{key}\"] = ret\n\n            # taking the stats of the ground truth and remove it from the computed feats\n            gtstats = computedfeats[\"gt\"][\"stats\"]\n          
  # computing fid\n            for key, loader in computedfeats.items():\n                metric = \"fid\"\n                mkey = f\"{metric}_{key}\"\n\n                stats = computedfeats[key][\"stats\"]\n                metrics[mkey] = float(calculate_fid(gtstats, stats))\n\n            metrics_all[sets] = metrics\n\n        metrics = {}\n        for sets in [\"train\", \"test\"]:\n            for key in metrics_all[sets]:\n                metrics[f\"{key}_{sets}\"] = metrics_all[sets][key]\n        return metrics\n"
  },
  {
    "path": "PBnet/src/evaluate/stgcn/fid.py",
    "content": "import numpy as np\nfrom scipy import linalg\n\n\n# from action2motion\ndef calculate_fid(statistics_1, statistics_2):\n    return calculate_frechet_distance(statistics_1[0], statistics_1[1],\n                                      statistics_2[0], statistics_2[1])\n\n\ndef calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):\n    \"\"\"Numpy implementation of the Frechet Distance.\n    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)\n    and X_2 ~ N(mu_2, C_2) is\n            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).\n    Stable version by Dougal J. Sutherland.\n    Params:\n    -- mu1   : Numpy array containing the activations of a layer of the\n               inception net (like returned by the function 'get_predictions')\n               for generated samples.\n    -- mu2   : The sample mean over activations, precalculated on an\n               representative data set.\n    -- sigma1: The covariance matrix over activations for generated samples.\n    -- sigma2: The covariance matrix over activations, precalculated on an\n               representative data set.\n    Returns:\n    --   : The Frechet Distance.\n    \"\"\"\n\n    mu1 = np.atleast_1d(mu1)\n    mu2 = np.atleast_1d(mu2)\n\n    sigma1 = np.atleast_2d(sigma1)\n    sigma2 = np.atleast_2d(sigma2)\n\n    assert mu1.shape == mu2.shape, \\\n        'Training and test mean vectors have different lengths'\n    assert sigma1.shape == sigma2.shape, \\\n        'Training and test covariances have different dimensions'\n\n    diff = mu1 - mu2\n\n    # Product might be almost singular\n    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n    if not np.isfinite(covmean).all():\n        msg = ('fid calculation produces singular product; '\n               'adding %s to diagonal of cov estimates') % eps\n        print(msg)\n        offset = np.eye(sigma1.shape[0]) * eps\n        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n\n    # Numerical error might give slight imaginary component\n    if np.iscomplexobj(covmean):\n        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n            m = np.max(np.abs(covmean.imag))\n            raise ValueError('Imaginary component {}'.format(m))\n        covmean = covmean.real\n\n    tr_covmean = np.trace(covmean)\n\n    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)\n"
  },
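  {
    "path": "PBnet/src/evaluate/stgcn/fid_example.py",
    "content": "# Illustrative sketch, NOT part of the original repo: computes the Frechet distance from\n# fid.py between two random feature sets, packing (mu, sigma) tuples the same way\n# Evaluation.calculate_activation_statistics does in stgcn/evaluate.py. The feature\n# dimension and sample counts are arbitrary.\nimport numpy as np\n\nfrom src.evaluate.stgcn.fid import calculate_fid\n\n\ndef activation_statistics(feats):\n    # mean vector and covariance matrix of the activations\n    return np.mean(feats, axis=0), np.cov(feats, rowvar=False)\n\n\nif __name__ == '__main__':\n    rng = np.random.default_rng(0)\n    gt_feats = rng.normal(0.0, 1.0, size=(500, 64))\n    gen_feats = rng.normal(0.5, 1.0, size=(500, 64))  # shifted mean -> nonzero FID\n\n    fid_self = calculate_fid(activation_statistics(gt_feats),\n                             activation_statistics(gt_feats))\n    fid_gen = calculate_fid(activation_statistics(gt_feats),\n                            activation_statistics(gen_feats))\n    print(f'FID(gt, gt)  = {fid_self:.4f}')   # ~0 up to numerical error\n    print(f'FID(gt, gen) = {fid_gen:.4f}')    # roughly ||mu1 - mu2||^2 = 64 * 0.25 = 16\n"
  },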
  {
    "path": "PBnet/src/evaluate/tables/archtable.py",
    "content": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, power=3):\n    p = float(pow(10, power))\n    # \"{:<04}\".format(np.round(p*val).astype(int)/p)\n    return str(np.round(p*val).astype(int)/p).ljust(4, \"0\")\n\n\ndef format_values(values, key):\n    mean = np.mean(values)\n\n    if key == \"accuracy\":\n        mean = 100*mean\n        values = 100*values\n        smean = valformat(mean, 1)\n    else:\n        smean = valformat(mean, 2)\n\n    interval = valformat(1.96 * np.var(values), 2)  # [1:]\n    # string = rf\"${mean:.4}^{{\\pm{interval:.3}}}$\"\n    # string = rf\"${smean}$\"  # ^{{\\pm{interval}}}$\"\n    string = rf\"${smean}^{{\\pm{interval}}}$\"\n    return string\n\n\ndef construct_table(folder):\n    exppath = folder\n    paths = glob.glob(f\"{exppath}/**/evaluation*_all.yaml\")\n\n    keys = [\"fid\", \"accuracy\", \"diversity\", \"multimodality\"]\n    \n    model_metrics_dataset = {\"ntu13\": {},\n                             \"uestc\": {}}\n\n    epoch_dataset = {\"ntu13\": 1000,\n                     \"uestc\": 500}\n    \n    model_naming = {\"fc\": \"Fully connected\",\n                    \"gru\": \"GRU\",\n                    # \"transformer\": \"Old transformer\",\n                    \"gtransformer\": \"Transformer\"}\n\n    ablation_naming = {\"average_encoder\": r\"No $\\mu_{a}^{token},\\Sigma_{a}^{token}$\",\n                       \"time_encoding\": r\"No Decoder-PE\",\n                       \"zandtime\": r\"No $b_{a}^{token}$\"}\n    \n    for i, path in enumerate(paths):\n        epoch = int(path.split(\"evaluation_metrics_\")[1].split(\".\")[0].split(\"_\")[0])\n        \n        modelinfo = os.path.split(os.path.split(path)[0])[1]\n        modelname = modelinfo.split(\"_\")[1]\n        dataset = modelinfo.split(\"_kl_\")[1].split(\"_\")[0]\n\n        # Take the right epoch\n        if epoch_dataset[dataset] != epoch:\n            continue\n\n        # Ablation study\n        if \"abl\" in modelinfo:\n            ablation = modelinfo.split(\"_abl_\")[1].split(\"_sampling\")[0]\n\n            if ablation not in ablation_naming:\n                continue\n\n            name = ablation_naming[ablation]\n        else:\n            if modelname not in model_naming:\n                continue\n            name = model_naming[modelname]\n            \n        metrics = load_metrics(path)\n\n        model_metrics = model_metrics_dataset[dataset]\n        if dataset == \"ntu13\":\n            a2m = metrics[\"action2motion\"]\n\n            if \"GT\" not in model_metrics:\n                a2m[\"fid_gt\"] = a2m[\"fid_gt2\"]\n                \n                row = []\n                for key in keys:\n                    ckey = f\"{key}_gt\"\n                    values = np.array([float(x) for x in a2m[ckey]])\n                    string = format_values(values, key)\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n                \n            row = []\n            for key in keys:\n                ckey = f\"{key}_gen\"\n                values = np.array([float(x) for x in a2m[ckey]])\n                string = format_values(values, key)\n                row.append(string)\n\n            model_metrics[name] = row\n        elif dataset == \"uestc\":\n            stgcn = metrics[\"stgcn\"]\n\n            if \"GT\" not in model_metrics:\n                for sets in [\"train\", \"test\"]:\n                    stgcn[f\"fid_gt_{sets}\"] = 
stgcn[f\"fid_gt2_{sets}\"]\n                stgcnkeys = [\"fid_gt_train\", \"fid_gt_test\", \"accuracy_gt_train\", \"diversity_gt_train\", \"multimodality_gt_train\"]\n                row = []\n                for ckey in stgcnkeys:\n                    values = np.array([float(x) for x in stgcn[ckey]])\n                    string = format_values(values, ckey.split(\"_\")[0])\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n\n            stgcnkeys = [\"fid_gen_train\", \"fid_gen_test\", \"accuracy_gen_train\", \"diversity_gen_train\", \"multimodality_gen_train\"]\n            row = []\n            for ckey in stgcnkeys:\n                values = np.array([float(x) for x in stgcn[ckey]])\n                string = format_values(values, ckey.split(\"_\")[0])\n                row.append(string)\n\n            model_metrics[name] = row\n\n    archmodels = list(model_naming.values())\n    ablationmodels = list(ablation_naming.values())\n    \n    gtvalues = [\"GT\"]\n    for dataset in [\"uestc\", \"ntu13\"]:\n        model_metrics = model_metrics_dataset[dataset]\n        gtvalues.extend(model_metrics[\"GT\"])\n    gtrow = \" & \".join(gtvalues) + r\"\\\\\"\n    \n    groupedrows = []\n    for lst in [archmodels, ablationmodels]:\n        rows = []\n        for model in lst:\n            if model == \"GT\":\n                continue\n            values = [model]\n            for dataset in [\"uestc\", \"ntu13\"]:\n                model_metrics = model_metrics_dataset[dataset]\n                if model in model_metrics:\n                    values.extend(model_metrics[model])\n                else:\n                    dummy = [\"\" for _ in range(len(model_metrics[\"GT\"]))]\n                    values.extend(dummy)\n            row = \" & \".join(values) + r\"\\\\\"\n            rows.append(row)\n        groupedrows.append(\"\\n\".join(rows) + \"\\n\")\n        \n    template = r\"\"\"\\documentclass{{standalone}}\n\\usepackage{{booktabs}}\n\\usepackage[dvipsnames]{{xcolor}}\n\\begin{{document}}\n    \\begin{{tabular}}{{lccccc|cccc}}\n        \\toprule\n        Architecture &  FID$_{{tr}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\uparrow$ & Multimod.$\\uparrow$ & FID$_{{tr}}$$\\downarrow$ & FID$_{{test}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\uparrow$ & Multimod.$\\uparrow$\\\\\n        \\midrule\n        {gtrow}\n        \\midrule\n        {archrow}\n        \\midrule\n        {ablationrow}\n        \\bottomrule\n    \\end{{tabular}}\n\\end{{document}}\n\"\"\".format(gtrow=gtrow, archrow=groupedrows[0], ablationrow=groupedrows[1])\n    return template\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"exppath\", help=\"name of the exp\")\n        return parser.parse_args()\n\n    opt = parse_opts()\n    exppath = opt.exppath\n\n    folder = exppath\n    \n    tex = construct_table(folder)\n    texpath = os.path.join(folder, \"table_arch.tex\")\n\n    with open(texpath, \"w\") as ftex:\n        ftex.write(tex)\n        \n    print(f\"Table saved at {texpath}\")\n"
  },
  {
    "path": "PBnet/src/evaluate/tables/bstable.py",
    "content": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, power=3):\n    p = float(pow(10, power))\n    # \"{:<04}\".format(np.round(p*val).astype(int)/p)\n    return str(np.round(p*val).astype(int)/p).ljust(4, \"0\")\n\n\ndef format_values(values, key):\n    mean = np.mean(values)\n\n    if key == \"accuracy\":\n        mean = 100*mean\n        values = 100*values\n        smean = valformat(mean, 1)\n    else:\n        smean = valformat(mean, 2)\n\n    interval = valformat(1.96 * np.var(values), 2)  # [1:]\n    # string = rf\"${mean:.4}^{{\\pm{interval:.3}}}$\"\n    # string = rf\"${smean}$\"  # ^{{\\pm{interval}}}$\"\n    string = rf\"${smean}^{{\\pm{interval}}}$\"\n    return string\n                    \n\ndef construct_table(folder):\n    exppath = folder\n    paths = glob.glob(f\"{exppath}/**/evaluation*_all*.yaml\")\n\n    keys = [\"fid\", \"accuracy\", \"diversity\", \"multimodality\"]\n    \n    model_metrics_dataset = {\"ntu13\": {},\n                             \"uestc\": {}}\n\n    epoch_dataset = {\"ntu13\": 1000,\n                     \"uestc\": 500}\n        \n    for i, path in enumerate(paths):\n        epoch = int(path.split(\"evaluation_metrics_\")[1].split(\".\")[0].split(\"_\")[0])\n        \n        modelinfo = os.path.split(os.path.split(path)[0])[1]\n        dataset = modelinfo.split(\"_kl_\")[1].split(\"_\")[0]\n\n        # Take the right epoch\n        if epoch_dataset[dataset] != epoch:\n            continue\n\n        name = \"Batch size \" + modelinfo.split(\"bs_\")[1]\n        metrics = load_metrics(path)\n\n        model_metrics = model_metrics_dataset[dataset]\n        if dataset == \"ntu13\":\n            a2m = metrics[\"action2motion\"]\n\n            if \"GT\" not in model_metrics:\n                a2m[\"fid_gt\"] = a2m[\"fid_gt2\"]\n                \n                row = []\n                for key in keys:\n                    ckey = f\"{key}_gt\"\n                    values = np.array([float(x) for x in a2m[ckey]])\n                    string = format_values(values, key)\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n                \n            row = []\n            for key in keys:\n                ckey = f\"{key}_gen\"\n                values = np.array([float(x) for x in a2m[ckey]])\n                string = format_values(values, key)\n                row.append(string)\n\n            model_metrics[name] = row\n        elif dataset == \"uestc\":\n            stgcn = metrics[\"stgcn\"]\n\n            if \"GT\" not in model_metrics:\n                for sets in [\"train\", \"test\"]:\n                    stgcn[f\"fid_gt_{sets}\"] = stgcn[f\"fid_gt2_{sets}\"]\n                stgcnkeys = [\"fid_gt_train\", \"fid_gt_test\", \"accuracy_gt_train\", \"diversity_gt_train\", \"multimodality_gt_train\"]\n                row = []\n                for ckey in stgcnkeys:\n                    values = np.array([float(x) for x in stgcn[ckey]])\n                    string = format_values(values, ckey.split(\"_\")[0])\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n\n            stgcnkeys = [\"fid_gen_train\", \"fid_gen_test\", \"accuracy_gen_train\", \"diversity_gen_train\", \"multimodality_gen_train\"]\n            row = []\n            for ckey in stgcnkeys:\n                values = np.array([float(x) for x in stgcn[ckey]])\n                string = format_values(values, ckey.split(\"_\")[0])\n       
         row.append(string)\n\n            model_metrics[name] = row\n    \n    gtvalues = [\"GT\"]\n    for dataset in [\"uestc\", \"ntu13\"]:\n        model_metrics = model_metrics_dataset[dataset]\n        gtvalues.extend(model_metrics[\"GT\"])\n    gtrow = \" & \".join(gtvalues) + r\"\\\\\"\n\n    rows = []\n    modelnames = sorted(list(model_metrics.keys()))\n    for model in modelnames:\n        if model == \"GT\":\n            continue\n        values = [model]\n        for dataset in [\"uestc\", \"ntu13\"]:\n            model_metrics = model_metrics_dataset[dataset]\n            if model in model_metrics:\n                values.extend(model_metrics[model])\n            else:\n                dummy = [\"\" for _ in range(len(model_metrics[\"GT\"]))]\n                values.extend(dummy)\n        row = \" & \".join(values) + r\"\\\\\"\n        rows.append(row)\n        \n    rows = \"\\n\".join(rows)\n        \n    template = r\"\"\"\\documentclass{{standalone}}\n\\usepackage{{booktabs}}\n\\usepackage[dvipsnames]{{xcolor}}\n\\begin{{document}}\n    \\begin{{tabular}}{{lccccc|cccc}}\n        \\toprule\n        & \\multicolumn{{5}}{{c}}{{UESTC}} & \\multicolumn{{4}}{{|c}}{{NTU-13}} \\\\\n    Loss & FID$_{{tr}}$$\\downarrow$ & FID$_{{test}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\rightarrow$ & Multimod.$\\rightarrow$ & FID$_{{tr}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\rightarrow$ & Multimod.$\\rightarrow$ \\\\\n        \\midrule\n        {gtrow}\n        \\midrule\n        {rows}\n        \\bottomrule\n    \\end{{tabular}}\n\\end{{document}}\n\"\"\".format(rows=rows, gtrow=gtrow)\n    return template\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"exppath\", help=\"name of the exp\")\n        return parser.parse_args()\n\n    opt = parse_opts()\n    exppath = opt.exppath\n\n    folder = exppath\n    \n    tex = construct_table(folder)\n    texpath = os.path.join(folder, \"table_loss.tex\")\n\n    with open(texpath, \"w\") as ftex:\n        ftex.write(tex)\n        \n    print(f\"Table saved at {texpath}\")\n\n\n\n"
  },
  {
    "path": "PBnet/src/evaluate/tables/easy_table.py",
    "content": "import os\nimport glob\nimport math\nimport numpy as np\n\nfrom ..tools import load_metrics\n\n\ndef get_gtname(mname):\n    return mname + \"_gt\"\n\n\ndef get_genname(mname):\n    return mname + \"_gen\"\n\n\ndef get_reconsname(mname):\n    return mname + \"_recons\"\n\n\ndef valformat(val, power=3):\n    p = float(pow(10, power))\n    # \"{:<04}\".format(np.round(p*val).astype(int)/p)\n    return str(np.round(p*val).astype(int)/p).ljust(4, \"0\")\n\n\ndef format_values(values, key, latex=True):\n    mean = np.mean(values)\n\n    if \"accuracy\" in key:\n        mean = 100*mean\n        values = 100*values\n        smean = valformat(mean, 1)\n    else:\n        smean = valformat(mean, 2)\n\n    interval = valformat(1.96 * np.var(values), 2)  # [1:]\n\n    if latex:\n        string = rf\"${smean}^{{\\pm{interval}}}$\"\n    else:\n        string = rf\"{smean} +/- {interval}\"\n    return string\n\n\ndef print_results(folder, evaluation):\n    evalpath = os.path.join(folder, evaluation)\n    metrics = load_metrics(evalpath)\n\n    a2m = metrics[\"feats\"]\n\n    if \"fid_gen_test\" in a2m:\n        keys = [\"fid_{}_train\", \"fid_{}_test\", \"accuracy_{}_train\", \"diversity_{}_train\", \"multimodality_{}_train\"]\n    else:\n        keys = [\"fid_{}\", \"accuracy_{}\", \"diversity_{}\", \"multimodality_{}\"]\n\n    lines = [\"gen\", \"recons\"]\n    # print the GT, only if it is computed with respect to \"another\" GT\n    if \"fid_gt2\" in a2m:\n        a2m[\"fid_gt\"] = a2m[\"fid_gt2\"]\n        lines = [\"gt\"] + lines\n\n    rows = []\n    rows_latex = []\n\n    for model in lines:\n        row = [\"{:6}\".format(model)]\n        row_latex = [\"{:6}\".format(model)]\n        try:\n            for key in keys:\n                ckey = key.format(model)\n                values = np.array([float(x) for x in a2m[ckey]])\n                string_latex = format_values(values, key, latex=True)\n                string = format_values(values, key, latex=False)\n                row.append(string)\n                row_latex.append(string_latex)\n            rows.append(\" | \".join(row))\n            rows_latex.append(\" & \".join(row_latex) + r\"\\\\\")\n        except KeyError:\n            continue\n\n    table = \"\\n\".join(rows)\n    table_latex = \"\\n\".join(rows_latex)\n    print(\"Results\")\n    print(table)\n    print()\n    print(\"Latex table\")\n    print(table_latex)\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"evalpath\", help=\"name of the evaluation\")\n        return parser.parse_args()\n\n    opt = parse_opts()\n    evalpath = opt.evalpath\n\n    folder, evaluation = os.path.split(evalpath)\n    print_results(folder, evaluation)\n"
  },
  {
    "path": "PBnet/src/evaluate/tables/easy_table_A2M.py",
    "content": "import os\nimport glob\nimport math\nimport numpy as np\n\nfrom ..tools import load_metrics\n\n\ndef valformat(val, power=3):\n    p = float(pow(10, power))\n    # \"{:<04}\".format(np.round(p*val).astype(int)/p)\n    return str(np.round(p*val).astype(int)/p).ljust(4, \"0\")\n\n\ndef construct_table(folder, evaluation):\n    evalpath = os.path.join(folder, evaluation)\n    metrics = load_metrics(evalpath)\n\n    a2m = metrics[\"feats\"]\n    keys = [\"fid\", \"accuracy\", \"diversity\", \"multimodality\"]\n\n    a2m[\"fid_gt\"] = a2m[\"fid_gt2\"]\n\n    values = []\n    rows = []\n    for model in [\"gt\", \"gen\", \"genden\"]:\n        row = [\"{:6}\".format(model)]\n        for key in keys:\n            ckey = f\"{key}_{model}\"\n            values = np.array([float(x) for x in a2m[ckey]])\n            mean = np.mean(values)\n            if key == \"accuracy\":\n                mean = 100*mean\n                values = 100*values\n                smean = valformat(mean, 1)\n            else:\n                smean = valformat(mean, 2)\n                mean = np.mean(values)\n            interval = valformat(1.96 * np.var(values), 2)  # [1:]\n            string = rf\"${smean}^{{\\pm{interval}}}$\"\n            # string = rf\"{mean:.4}\"  #^{{\\pm{interval:.1}}}\"\n            row.append(string)\n        rows.append(\" & \".join(row) + r\"\\\\\")\n\n    test = \"\\n\".join(rows)\n    print(test)\n    import ipdb; ipdb.set_trace()\n    bodylist.append(r\"\\bottomrule\")\n    body = \"\\n\".join(bodylist)\n    ncols = 5\n    title = f\"Evaluation TODO name\"\n    template = r\"\"\"\\documentclass{{standalone}}\n\\usepackage{{booktabs}}\n\\usepackage[dvipsnames]{{xcolor}}\n\\begin{{document}}\n\\begin{{tabular}}{{{ncolsl}}}\n\\multicolumn{{{ncols}}}{{c}}{{{title}}} \\\\\n\\multicolumn{{{ncols}}}{{c}}{{}} \\\\\n& \\multicolumn{{{nbcolsxyz}}}{{c}}{{xyz}} & & \\multicolumn{{{nbcolspose}}}{{c}}{{{pose_rep}}} & & \\multicolumn{{{nbcolsa2m}}}{{c}}{{action2motion}} \\\\\n{firstrow}\n\\midrule\n{body}\n\\end{{tabular}}\n\\end{{document}}\n\"\"\".format(ncolsl=\"l\"+\"c\"*(ncols-1), ncols=ncols,\n           pose_rep=pose_rep, title=title, firstrow=firstrow,\n           nbcolsxyz=len(METRICS[\"joints\"]),\n           nbcolspose=len(METRICS[pose_rep]),\n           nbcolsa2m=len(METRICS[\"action2motion\"]),\n           body=body)\n    return template\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"evalpath\", help=\"name of the evaluation\")\n        return parser.parse_args()\n\n    opt = parse_opts()\n    evalpath = opt.evalpath\n    \n    folder, evaluation = os.path.split(evalpath)\n    tex = construct_table(folder, evaluation)\n    texpath = os.path.join(folder, os.path.splitext(evaluation)[0] + \".tex\")\n\n    with open(texpath, \"w\") as ftex:\n        ftex.write(tex)\n        \n    print(f\"Table saved at {texpath}\")\n"
  },
  {
    "path": "PBnet/src/evaluate/tables/kltable.py",
    "content": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, power=3):\n    p = float(pow(10, power))\n    # \"{:<04}\".format(np.round(p*val).astype(int)/p)\n    return str(np.round(p*val).astype(int)/p).ljust(4, \"0\")\n\n\ndef format_values(values, key):\n    mean = np.mean(values)\n\n    if key == \"accuracy\":\n        mean = 100*mean\n        values = 100*values\n        smean = valformat(mean, 1)\n    else:\n        smean = valformat(mean, 2)\n\n    interval = valformat(1.96 * np.var(values), 2)  # [1:]\n    # string = rf\"${mean:.4}^{{\\pm{interval:.3}}}$\"\n    # string = rf\"${smean}$\"  # ^{{\\pm{interval}}}$\"\n    string = rf\"${smean}^{{\\pm{interval}}}$\"\n    return string\n                    \n\ndef construct_table(folder):\n    exppath = folder\n    paths = glob.glob(f\"{exppath}/**/evaluation*_all*.yaml\")\n\n    keys = [\"fid\", \"accuracy\", \"diversity\", \"multimodality\"]\n    \n    model_metrics_dataset = {\"ntu13\": {},\n                             \"uestc\": {}}\n\n    epoch_dataset = {\"ntu13\": 1000,\n                     \"uestc\": 500}\n        \n    for i, path in enumerate(paths):\n        epoch = int(path.split(\"evaluation_metrics_\")[1].split(\".\")[0].split(\"_\")[0])\n        \n        modelinfo = os.path.split(os.path.split(path)[0])[1]\n        dataset = modelinfo.split(\"_kl_\")[1].split(\"_\")[0]\n\n        # Take the right epoch\n        if epoch_dataset[dataset] != epoch:\n            continue\n\n        name = modelinfo.split(\"samplingstep_1_\")[1].split(\"_gelu\")[0].replace(\"_\", \" \")\n        metrics = load_metrics(path)\n\n        model_metrics = model_metrics_dataset[dataset]\n        if dataset == \"ntu13\":\n            a2m = metrics[\"action2motion\"]\n\n            if \"GT\" not in model_metrics:\n                a2m[\"fid_gt\"] = a2m[\"fid_gt2\"]\n                \n                row = []\n                for key in keys:\n                    ckey = f\"{key}_gt\"\n                    values = np.array([float(x) for x in a2m[ckey]])\n                    string = format_values(values, key)\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n                \n            row = []\n            for key in keys:\n                ckey = f\"{key}_gen\"\n                values = np.array([float(x) for x in a2m[ckey]])\n                string = format_values(values, key)\n                row.append(string)\n\n            model_metrics[name] = row\n        elif dataset == \"uestc\":\n            stgcn = metrics[\"stgcn\"]\n\n            if \"GT\" not in model_metrics:\n                for sets in [\"train\", \"test\"]:\n                    stgcn[f\"fid_gt_{sets}\"] = stgcn[f\"fid_gt2_{sets}\"]\n                stgcnkeys = [\"fid_gt_train\", \"fid_gt_test\", \"accuracy_gt_train\", \"diversity_gt_train\", \"multimodality_gt_train\"]\n                row = []\n                for ckey in stgcnkeys:\n                    values = np.array([float(x) for x in stgcn[ckey]])\n                    string = format_values(values, ckey.split(\"_\")[0])\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n\n            stgcnkeys = [\"fid_gen_train\", \"fid_gen_test\", \"accuracy_gen_train\", \"diversity_gen_train\", \"multimodality_gen_train\"]\n            row = []\n            for ckey in stgcnkeys:\n                values = np.array([float(x) for x in stgcn[ckey]])\n                string = 
format_values(values, ckey.split(\"_\")[0])\n                row.append(string)\n\n            model_metrics[name] = row\n    \n    gtvalues = [\"GT\"]\n    for dataset in [\"uestc\", \"ntu13\"]:\n        model_metrics = model_metrics_dataset[dataset]\n        gtvalues.extend(model_metrics[\"GT\"])\n    gtrow = \" & \".join(gtvalues) + r\"\\\\\"\n\n    rows = []\n    for model in model_metrics:\n        if model == \"GT\":\n            continue\n        values = [model]\n        for dataset in [\"uestc\", \"ntu13\"]:\n            model_metrics = model_metrics_dataset[dataset]\n            if model in model_metrics:\n                values.extend(model_metrics[model])\n            else:\n                dummy = [\"\" for _ in range(len(model_metrics[\"GT\"]))]\n                values.extend(dummy)\n        row = \" & \".join(values) + r\"\\\\\"\n        rows.append(row)\n        \n    rows = \"\\n\".join(rows)\n        \n    template = r\"\"\"\\documentclass{{standalone}}\n\\usepackage{{booktabs}}\n\\usepackage[dvipsnames]{{xcolor}}\n\\begin{{document}}\n    \\begin{{tabular}}{{lccccc|cccc}}\n        \\toprule\n        & \\multicolumn{{5}}{{c}}{{UESTC}} & \\multicolumn{{4}}{{|c}}{{NTU-13}} \\\\\n    Loss & FID$_{{tr}}$$\\downarrow$ & FID$_{{test}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\rightarrow$ & Multimod.$\\rightarrow$ & FID$_{{tr}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\rightarrow$ & Multimod.$\\rightarrow$ \\\\\n        \\midrule\n        {gtrow}\n        \\midrule\n        {rows}\n        \\bottomrule\n    \\end{{tabular}}\n\\end{{document}}\n\"\"\".format(rows=rows, gtrow=gtrow)\n    return template\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"exppath\", help=\"name of the exp\")\n        return parser.parse_args()\n\n    opt = parse_opts()\n    exppath = opt.exppath\n\n    folder = exppath\n    \n    tex = construct_table(folder)\n    texpath = os.path.join(folder, \"table_loss.tex\")\n\n    with open(texpath, \"w\") as ftex:\n        ftex.write(tex)\n        \n    print(f\"Table saved at {texpath}\")\n\n\n\n"
  },
  {
    "path": "PBnet/src/evaluate/tables/latexmodela2m.py",
    "content": "import os\nimport glob\nimport math\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef get_gtname(mname):\n    return mname + \"_gt\"\n\n\ndef get_genname(mname):\n    return mname + \"_gen\"\n\n\ndef get_reconsname(mname):\n    return mname + \"_recons\"\n\n\ndef construct_table(folder, evaluation):\n    evalpath = os.path.join(folder, evaluation)\n    metrics = load_metrics(evalpath)\n\n    a2m = metrics[\"action2motion\"]\n    keys = [\"fid\", \"accuracy\", \"diversity\", \"multimodality\"]\n\n    a2m[\"fid_gt\"] = a2m[\"fid_gt2\"]\n    modelname = os.path.split(folder)[1]\n\n    modelname = modelname.replace(\"_ntu13_vibe_rot6d_glob_translation_numlayers_8_numframes_60_sampling_conseq_samplingstep_1_kl_1e-05_gelu\", \"\")\n    modelname = modelname.replace(\"_\", \" \")\n\n    def valformat(val, power=3):\n        p = float(pow(10, power))\n        # \"{:<04}\".format(np.round(p*val).astype(int)/p)\n        return str(np.round(p*val).astype(int)/p).ljust(5, \"0\")\n    \n    values = []\n    rows = []\n    for model in [\"gt\", \"gen\", \"recons\"]:\n        row = [\"{} {}\".format(modelname, model)]\n        for key in keys:\n            ckey = f\"{key}_{model}\"\n            values = np.array([float(x) for x in a2m[ckey]])\n            mean = valformat(np.mean(values))\n            interval = valformat(1.96 * np.var(values))[1:]\n            # string = rf\"${mean:.4}^{{\\pm{interval:.3}}}$\"\n            string = rf\"${mean}^{{\\pm{interval}}}$\"\n            row.append(string)\n        row = \" & \".join(row) + r\"\\\\\"\n        rows.append(row)\n        \n    MODELS = \"\\n        \".join(rows)\n    \n    template = r\"\"\"\\documentclass{{standalone}}\n\\usepackage{{booktabs}}\n\\usepackage[dvipsnames]{{xcolor}}\n\\begin{{document}}\n    \\begin{{tabular}}{{lcccc}}\n        \\toprule\n        Architecture &  FID$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\uparrow$ & Multimod.$\\uparrow$\\\\\n        \\midrule\n        action2motion ground truth   & $0.031^{{\\pm.004}}$  & $0.999^{{\\pm.001}}$  & $7.108^{{\\pm.048}}$ & $2.194^{{\\pm.025}}$ \\\\\n        action2motion lie model & $0.330^{{\\pm.008}}$  & $0.949^{{\\pm.001}}$  & $7.065^{{\\pm.043}}$ & $2.052^{{\\pm.030}}$ \\\\\n        \\midrule\n        {MODELS}\n        \\bottomrule\n    \\end{{tabular}}\n\\end{{document}}\n\"\"\".format(MODELS=MODELS)\n    return template\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"evalpath\", help=\"name of the evaluation\")\n        return parser.parse_args()\n\n    opt = parse_opts()\n    evalpath = opt.evalpath\n    \n    folder, evaluation = os.path.split(evalpath)\n    tex = construct_table(folder, evaluation)\n    texpath = os.path.join(folder, os.path.splitext(evaluation)[0] + \".tex\")\n\n    with open(texpath, \"w\") as ftex:\n        ftex.write(tex)\n        \n    print(f\"Table saved at {texpath}\")\n"
  },
  {
    "path": "PBnet/src/evaluate/tables/latexmodelsa2m.py",
    "content": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, power=3):\n    p = float(pow(10, power))\n    # \"{:<04}\".format(np.round(p*val).astype(int)/p)\n    return str(np.round(p*val).astype(int)/p).ljust(5, \"0\")\n\n\ndef construct_table(folder):\n    exppath = folder\n    paths = glob.glob(f\"{exppath}/**/evaluation*_all*.yaml\")\n\n    keys = [\"fid\", \"accuracy\", \"diversity\", \"multimodality\"]\n\n    models_results = []\n    for i, path in enumerate(paths):\n        metrics = load_metrics(path)\n        a2m = metrics[\"action2motion\"]\n        a2m[\"fid_gt\"] = a2m[\"fid_gt2\"]\n\n        modelname = os.path.split(os.path.split(path)[0])[1]\n\n        for info in [\"vibe\", \"rot6d\", \"glob\", \"translation\", \"numlayers_8\",\n                     \"numframes_60\", \"sampling_conseq\", \"samplingstep_1\", \"jointstype\",\n                     \"gelu\", \"kl_1e-05\", \"cvae\", \"ntu13\"]:\n            modelname = modelname.replace(info, \"\")\n            \n        modelname = re.sub(\"_{1,}\", \" \", modelname)\n\n        # takin GT only for the first one\n        if i == 0:\n            gtrow = [\"Our GT\"]\n            for key in keys:\n                ckey = f\"{key}_gt\"\n                values = np.array([float(x) for x in a2m[ckey]])\n                mean = valformat(np.mean(values))\n                interval = valformat(1.96 * np.var(values))[1:]\n                # string = rf\"${mean:.4}^{{\\pm{interval:.3}}}$\"\n                string = rf\"${mean}$\"  # ^{{\\pm{interval}}}$\"\n                gtrow.append(string)\n            gtrow = \" & \".join(gtrow) + r\"\\\\\"\n                \n        rows = []\n        for model in [\"gen\"]:  # [\"gt\", \"gen\", \"recons\"]:\n            # row = [\"{} {}\".format(modelname, model)]\n            row = [modelname]\n            for key in keys:\n                ckey = f\"{key}_{model}\"\n                values = np.array([float(x) for x in a2m[ckey]])\n                mean = valformat(np.mean(values))\n                interval = valformat(1.96 * np.var(values))[1:]\n                # string = rf\"${mean:.4}^{{\\pm{interval:.3}}}$\"\n                string = rf\"${mean}$\"  # ^{{\\pm{interval}}}$\"\n                row.append(string)\n            row = \" & \".join(row) + r\"\\\\\"\n            rows.append(row)\n        models_result = \"\\n        \".join(rows)\n        models_results.append(models_result)\n\n    sorting = [\"former rc kl\", \"former rcxyz kl\", \"former rc rcxyz kl\",\n               \"former rc rcxyz vel kl\", \"former rc rcxyz velxyz kl\",\n               \"former rc rcxyz vel velxyz kl\"]\n\n    changing = {\"rc\": r\"$\\mathcal{L}_{R}$\",\n                \"rcxyz\": r\"$\\mathcal{L}_{O}$\",\n                \"vel\": r\"$\\mathcal{L}_{\\Delta R}$\",\n                \"velxyz\": r\"$\\mathcal{L}_{\\Delta O}$\"}\n    \n    changing_jointstype = {\"smpl\": \"J\",\n                           \"vertices\": \"V\"}\n    \n    sorted_models = [gtrow, \"        \\\\midrule\\n\"]\n    for sortkey in sorting:\n        for models_result in models_results:\n            if sortkey in models_result:\n                modelsname = models_result.split(\"&\")[0].rstrip()\n                losses = sortkey.split(\" \")[1:-1]  # remove former and kl\n                wlosses = []\n                for loss in losses:\n                    renaming = changing[loss]\n                    jtype = modelsname.split(\" \")[-1]\n                    if 
jtype in changing_jointstype:\n                        renaming = renaming.replace(\"O\", changing_jointstype[jtype])\n                    wlosses.append(renaming)\n                \n                models_result = models_result.replace(modelsname, \" + \".join(wlosses))\n                sorted_models.append(models_result)\n                \n    # MODELS = \"\\n        \\\\midrule\\n\".join(sorted_models)\n    MODELS = \"\\n\".join(sorted_models) + \"\\n\"\n    \n    template = r\"\"\"\\documentclass{{standalone}}\n\\usepackage{{booktabs}}\n\\usepackage[dvipsnames]{{xcolor}}\n\\begin{{document}}\n    \\begin{{tabular}}{{lcccc}}\n        \\toprule\n        Architecture &  FID$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\uparrow$ & Multimod.$\\uparrow$\\\\\n        \\midrule\n        action2motion ground truth   & $0.031^{{\\pm.004}}$  & $0.999^{{\\pm.001}}$  & $7.108^{{\\pm.048}}$ & $2.194^{{\\pm.025}}$ \\\\\n        action2motion lie model & $0.330^{{\\pm.008}}$  & $0.949^{{\\pm.001}}$  & $7.065^{{\\pm.043}}$ & $2.052^{{\\pm.030}}$ \\\\\n        \\midrule\n        {MODELS}\n        \\bottomrule\n    \\end{{tabular}}\n\\end{{document}}\n\"\"\".format(MODELS=MODELS)\n    return template\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"exppath\", help=\"name of the exp\")\n        return parser.parse_args()\n\n    opt = parse_opts()\n    exppath = opt.exppath\n\n    folder = exppath\n    \n    tex = construct_table(folder)\n    texpath = os.path.join(folder, \"table.tex\")\n\n    with open(texpath, \"w\") as ftex:\n        ftex.write(tex)\n        \n    print(f\"Table saved at {texpath}\")\n"
  },
  {
    "path": "PBnet/src/evaluate/tables/latexmodelsstgcn.py",
    "content": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef get_gtname(mname):\n    return mname + \"_gt\"\n\n\ndef get_genname(mname):\n    return mname + \"_gen\"\n\n\ndef get_reconsname(mname):\n    return mname + \"_recons\"\n\n\ndef valformat(val, power=3):\n    p = float(pow(10, power))\n    # \"{:<04}\".format(np.round(p*val).astype(int)/p)\n    return str(np.round(p*val).astype(int)/p).ljust(5, \"0\")\n\n\ndef construct_table(folder):\n    exppath = folder\n    paths = glob.glob(f\"{exppath}/**/evaluation*0500_all.yaml\")\n\n    keys = [\"fid\", \"accuracy\", \"diversity\", \"multimodality\"]\n\n    models_results = []\n    for i, path in enumerate(paths):\n        metrics = load_metrics(path)\n        stgcn = metrics[\"stgcn\"]\n\n        # easy fid gt\n        for sets in [\"train\", \"test\"]:\n            stgcn[f\"fid_gt_{sets}\"] = stgcn[f\"fid_gt2_{sets}\"]\n        \n        modelname = os.path.split(os.path.split(path)[0])[1]\n\n        for info in [\"vibe\", \"rot6d\", \"glob\", \"translation\", \"numlayers_8\",\n                     \"numframes_60\", \"sampling_conseq\", \"samplingstep_1\", \"jointstype\",\n                     \"gelu\", \"kl_1e-05\", \"cvae\", \"uestc\"]:\n            modelname = modelname.replace(info, \"\")\n\n        modelname = re.sub(\"_{1,}\", \" \", modelname)\n\n        # takin GT only for the first one\n        if i == 0:\n            gtrow = [\"Our GT\"]\n            for sets in [\"train\", \"test\"]:\n                for key in keys:\n                    ckey = f\"{key}_gt_{sets}\"\n                    values = np.array([float(x) for x in stgcn[ckey]])\n                    mean = valformat(np.mean(values))\n                    interval = valformat(1.96 * np.var(values))[1:]\n                    # string = rf\"${mean:.4}^{{\\pm{interval:.3}}}$\"\n                    string = rf\"${mean}$\"  # ^{{\\pm{interval}}}$\"\n                    gtrow.append(string)\n                gtrow.append(\"\")\n            gtrow = \" & \".join(gtrow[:-1]) + r\"\\\\\"\n\n        rows = []\n        for model in [\"gen\"]:  # [\"gt\", \"gen\", \"recons\"]:\n            # row = [\"{} {}\".format(modelname, model)]\n            row = [modelname]\n            for sets in [\"train\", \"test\"]:\n                for key in keys:\n                    ckey = f\"{key}_{model}_{sets}\"\n                    values = np.array([float(x) for x in stgcn[ckey]])\n                    mean = valformat(np.mean(values))\n                    interval = valformat(1.96 * np.var(values))[1:]\n                    # string = rf\"${mean:.4}^{{\\pm{interval:.3}}}$\"\n                    string = rf\"${mean}$\"  # ^{{\\pm{interval}}}$\"\n                    row.append(string)\n                row.append(\"\")\n            row = \" & \".join(row[:-1]) + r\"\\\\\"\n            rows.append(row)\n        models_result = \"\\n        \".join(rows)\n        models_results.append(models_result)\n\n    sorting = [\"former rc kl\", \"former rcxyz kl\", \"former rc rcxyz kl\",\n               \"former rc rcxyz vel kl\", \"former rc rcxyz velxyz kl\",\n               \"former rc rcxyz vel velxyz kl\"]\n    \n    changing = {\"rc\": r\"$\\mathcal{L}_{R}$\",\n                \"rcxyz\": r\"$\\mathcal{L}_{O}$\",\n                \"vel\": r\"$\\mathcal{L}_{\\Delta R}$\",\n                \"velxyz\": r\"$\\mathcal{L}_{\\Delta O}$\"}\n    \n    changing_jointstype = {\"smpl\": \"J\",\n                           \"vertices\": \"V\"}\n    \n    
sorted_models = [gtrow, \"        \\\\midrule\\n\"]\n    for sortkey in sorting:\n        for models_result in models_results:\n            if sortkey in models_result:\n                modelsname = models_result.split(\"&\")[0].rstrip()\n                losses = sortkey.split(\" \")[1:-1]  # remove former and kl\n                wlosses = []\n                for loss in losses:\n                    renaming = changing[loss]\n                    jtype = modelsname.split(\" \")[-1]\n                    if jtype in changing_jointstype:\n                        renaming = renaming.replace(\"O\", changing_jointstype[jtype])\n                    wlosses.append(renaming)\n                \n                models_result = models_result.replace(modelsname, \" + \".join(wlosses))\n                sorted_models.append(models_result)\n                \n    # MODELS = \"\\n        \\\\midrule\\n\".join(sorted_models)\n    MODELS = \"\\n\".join(sorted_models) + \"\\n\"\n    \n    template = r\"\"\"\\documentclass{{standalone}}\n\\usepackage{{booktabs}}\n\\usepackage[dvipsnames]{{xcolor}}\n\\begin{{document}}\n    \\begin{{tabular}}{{lccccccccc}}\n        Architecture &  FID$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\uparrow$ & Multimod.$\\uparrow$ & & FID$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\uparrow$ & Multimod.$\\uparrow$\\\\\n        \\midrule\n        \\toprule\n        {MODELS}\n        \\bottomrule\n    \\end{{tabular}}\n\\end{{document}}\n\"\"\".format(MODELS=MODELS)\n    return template\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"exppath\", help=\"name of the exp\")\n        return parser.parse_args()\n\n    opt = parse_opts()\n    exppath = opt.exppath\n\n    folder = exppath\n    \n    tex = construct_table(folder)\n    texpath = os.path.join(folder, \"table.tex\")\n\n    with open(texpath, \"w\") as ftex:\n        ftex.write(tex)\n        \n    print(f\"Table saved at {texpath}\")\n"
  },
  {
    "path": "PBnet/src/evaluate/tables/losstable.py",
    "content": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, power=3):\n    p = float(pow(10, power))\n    # \"{:<04}\".format(np.round(p*val).astype(int)/p)\n    return str(np.round(p*val).astype(int)/p).ljust(4, \"0\")\n\n\ndef format_values(values, key):\n    mean = np.mean(values)\n\n    if key == \"accuracy\":\n        mean = 100*mean\n        values = 100*values\n        smean = valformat(mean, 1)\n    else:\n        smean = valformat(mean, 2)\n\n    interval = valformat(1.96 * np.var(values), 2)  # [1:]\n    # string = rf\"${mean:.4}^{{\\pm{interval:.3}}}$\"\n    # string = rf\"${smean}$\"  # ^{{\\pm{interval}}}$\"\n    string = rf\"${smean}^{{\\pm{interval}}}$\"\n    return string\n                    \n\ndef construct_table(folder):\n    exppath = folder\n    paths = glob.glob(f\"{exppath}/**/evaluation*_all*.yaml\")\n\n    keys = [\"fid\", \"accuracy\", \"diversity\", \"multimodality\"]\n    \n    model_metrics_dataset = {\"ntu13\": {},\n                             \"uestc\": {}}\n\n    epoch_dataset = {\"ntu13\": 1000,\n                     \"uestc\": 500}\n        \n    for i, path in enumerate(paths):\n        epoch = int(path.split(\"evaluation_metrics_\")[1].split(\".\")[0].split(\"_\")[0])\n        \n        modelinfo = os.path.split(os.path.split(path)[0])[1]\n        dataset = modelinfo.split(\"_kl_\")[1].split(\"_\")[0]\n\n        # Take the right epoch\n        if epoch_dataset[dataset] != epoch:\n            continue\n\n        if \"vel\" in modelinfo:\n            continue\n\n        if \"rc_rcxyz_kl\" in modelinfo:\n            if \"vertices\" in modelinfo:\n                name = r\"$\\mathcal{L}_{P}$ + $\\mathcal{L}_{V}$\"\n            else:\n                # name = r\"$\\mathcal{L}_{P}$ + $\\mathcal{L}_{J}$\"\n                continue\n        elif \"rc_kl\" in modelinfo:\n            name = r\"$\\mathcal{L}_{P}$\"\n        elif \"rcxyz_kl\" in modelinfo:\n            if \"vertices\" in modelinfo:\n                name = r\"$\\mathcal{L}_{V}$\"\n            else:\n                name = r\"$\\mathcal{L}_{J}$\"\n        else:\n            print(f\"weird: {modelinfo}\")\n            \n        metrics = load_metrics(path)\n\n        model_metrics = model_metrics_dataset[dataset]\n        if dataset == \"ntu13\":\n            a2m = metrics[\"action2motion\"]\n\n            if \"GT\" not in model_metrics:\n                a2m[\"fid_gt\"] = a2m[\"fid_gt2\"]\n                \n                row = []\n                for key in keys:\n                    ckey = f\"{key}_gt\"\n                    values = np.array([float(x) for x in a2m[ckey]])\n                    string = format_values(values, key)\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n                \n            row = []\n            for key in keys:\n                ckey = f\"{key}_gen\"\n                values = np.array([float(x) for x in a2m[ckey]])\n                string = format_values(values, key)\n                row.append(string)\n\n            model_metrics[name] = row\n        elif dataset == \"uestc\":\n            stgcn = metrics[\"stgcn\"]\n\n            if \"GT\" not in model_metrics:\n                for sets in [\"train\", \"test\"]:\n                    stgcn[f\"fid_gt_{sets}\"] = stgcn[f\"fid_gt2_{sets}\"]\n                stgcnkeys = [\"fid_gt_train\", \"fid_gt_test\", \"accuracy_gt_train\", \"diversity_gt_train\", \"multimodality_gt_train\"]\n                row = 
[]\n                for ckey in stgcnkeys:\n                    values = np.array([float(x) for x in stgcn[ckey]])\n                    string = format_values(values, ckey.split(\"_\")[0])\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n\n            stgcnkeys = [\"fid_gen_train\", \"fid_gen_test\", \"accuracy_gen_train\", \"diversity_gen_train\", \"multimodality_gen_train\"]\n            row = []\n            for ckey in stgcnkeys:\n                values = np.array([float(x) for x in stgcn[ckey]])\n                string = format_values(values, ckey.split(\"_\")[0])\n                row.append(string)\n\n            model_metrics[name] = row\n\n    lossmodels = [r\"$\\mathcal{L}_{J}$\", r\"$\\mathcal{L}_{P}$\", r\"$\\mathcal{L}_{V}$\",\n                  # r\"$\\mathcal{L}_{P}$ + $\\mathcal{L}_{J}$\",\n                  r\"$\\mathcal{L}_{P}$ + $\\mathcal{L}_{V}$\"]\n    \n    gtvalues = [\"GT\"]\n    for dataset in [\"uestc\", \"ntu13\"]:\n        model_metrics = model_metrics_dataset[dataset]\n        gtvalues.extend(model_metrics[\"GT\"])\n    gtrow = \" & \".join(gtvalues) + r\"\\\\\"\n    \n    rows = []\n    for model in lossmodels:\n        if model == \"GT\":\n            continue\n        values = [model]\n        for dataset in [\"uestc\", \"ntu13\"]:\n            model_metrics = model_metrics_dataset[dataset]\n            if model in model_metrics:\n                values.extend(model_metrics[model])\n            else:\n                dummy = [\"\" for _ in range(len(model_metrics[\"GT\"]))]\n                values.extend(dummy)\n        row = \" & \".join(values) + r\"\\\\\"\n        rows.append(row)\n\n    rows = \"\\n\".join(rows)\n        \n    template = r\"\"\"\\documentclass{{standalone}}\n\\usepackage{{booktabs}}\n\\usepackage[dvipsnames]{{xcolor}}\n\\begin{{document}}\n    \\begin{{tabular}}{{lccccc|cccc}}\n        \\toprule\n        & \\multicolumn{{5}}{{c}}{{UESTC}} & \\multicolumn{{4}}{{|c}}{{NTU-13}} \\\\\n    Loss & FID$_{{tr}}$$\\downarrow$ & FID$_{{test}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\rightarrow$ & Multimod.$\\rightarrow$ & FID$_{{tr}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\rightarrow$ & Multimod.$\\rightarrow$ \\\\\n        \\midrule\n        {gtrow}\n        \\midrule\n        {rows}\n        \\bottomrule\n    \\end{{tabular}}\n\\end{{document}}\n\"\"\".format(rows=rows, gtrow=gtrow)\n    return template\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"exppath\", help=\"name of the exp\")\n        return parser.parse_args()\n\n    opt = parse_opts()\n    exppath = opt.exppath\n\n    folder = exppath\n    \n    tex = construct_table(folder)\n    texpath = os.path.join(folder, \"table_loss.tex\")\n\n    with open(texpath, \"w\") as ftex:\n        ftex.write(tex)\n        \n    print(f\"Table saved at {texpath}\")\n\n\n\n"
  },
  {
    "path": "PBnet/src/evaluate/tables/maketable.py",
    "content": "import os\nimport glob\nimport math\n\nfrom .tools import load_metrics\n\nMETRICS = {\"joints\": [\"acceleration\", \"rc\", \"diversity\", \"multimodality\"],\n           \"action2motion\": [\"accuracy\", \"fid\", \"diversity\", \"multimodality\"]}\n\nUP = r\"$\\uparrow$\"\nDOWN = r\"$\\downarrow$\"\nRIGHT = r\"$\\rightarrow$\"\n\nARROWS = {\"accuracy\": UP,\n          \"acceleration\": RIGHT,\n          \"rc\": DOWN,\n          \"fid\": DOWN,\n          \"diversity\": RIGHT,\n          \"multimodality\": RIGHT}\n\nPOSE_ORDER = [\"xyz\", \"rotvec\", \"rotquat\", \"rotmat\", \"rot6d\"]\nfor pose in POSE_ORDER:\n    METRICS[pose] = METRICS[\"joints\"]\n\nGROUPORDER = POSE_ORDER + [\"action2motion\"]\n\nGREEN = \"Green\"\nRED = \"Mahogany\"\n\n\ndef bold(string):\n    return r\"\\textbf{{\" + string + r\"}}\"\n\n\ndef colorize_template(string, color):\n    return r\"\\textcolor{{\" + color + r\"}}{{\" + string + r\"}}\"\n\n\ndef colorize_bold_template(string, color):\n    return bold(colorize_template(string, color))\n\n\ndef format_table(val, gtval, mname):\n    value = float(val)\n    \n    try:\n        exp = math.floor(math.log10(value))\n    except ValueError:\n        exp = 0\n        value = 0\n    \n    if mname == \"rc\":\n        formatter = \"{:.1e}\"\n        if value >= 1:\n            formatter = colorize_bold_template(formatter, RED)\n            \n    elif mname in [\"diversity\", \"multimodality\"]:\n        if exp < -1:\n            formatter = \"{:.1e}\"\n        else:\n            formatter = \"{:.3g}\"\n        if gtval is not None:\n            gtval = float(gtval)\n            if value > 0.8*gtval:\n                formatter = colorize_bold_template(formatter, GREEN)\n            elif value < 0.3*gtval:\n                formatter = colorize_bold_template(formatter, RED)\n                    \n    elif mname == \"accuracy\":\n        formatter = \"{:.1%}\"\n        if value > 0.65:\n            formatter = colorize_bold_template(formatter, GREEN)\n        elif value < 0.35:\n            formatter = colorize_bold_template(formatter, RED)\n        \n    elif mname == \"acceleration\":\n        formatter = \"{:.1e}\"\n        if gtval is not None:\n            gtval = float(gtval)\n            diff = math.log10(value/gtval)\n            # below acceleration\n            if diff < 0.05:\n                formatter = colorize_bold_template(formatter, GREEN)\n            elif diff > 0.3:\n                formatter = colorize_bold_template(formatter, RED)\n                \n    else:\n        formatter = \"{:.2f}\"\n\n    formatter = bold(formatter)\n    return formatter.format(value).replace(\"%\", r\"\\%\")\n\n\ndef get_gtname(mname):\n    return mname + \"_gt\"\n\n\ndef get_genname(mname):\n    return mname + \"_gen\"\n\n\ndef get_reconsname(mname):\n    return mname + \"_recons\"\n\n\ndef collect_tables(folder, expname, lastepoch=False, norecons=False):\n    exppath = os.path.join(folder, expname)\n    paths = glob.glob(f\"{exppath}/**/evaluation*\")\n\n    if len(paths) == 0:\n        raise ValueError(\"No evaluation founds.\")\n\n    pose_rep, *losses = expname.split(\"_\")\n    expname = expname.replace(\"_\", \"\\\\_\")\n\n    models_kl = {}\n    allkls = set()\n    models_epochs = {}\n    for path in paths:\n        metrics = load_metrics(path)\n        fname = os.path.split(path)[0]\n        modelname = fname.split(\"cvae_\")[1].split(\"_rc\")[0]\n        kl_loss = float(fname.split(\"_kl_\")[2].split(\"_\")[0])\n        epoch = 
os.path.split(path)[1].split(\"evaluation_metrics_\")[1].split(\".\")[0]\n            \n        if lastepoch:\n            if modelname not in models_epochs:\n                models_epochs[modelname] = epoch\n            else:\n                if models_epochs[modelname] > epoch:\n                    continue\n                else:\n                    models_epochs[modelname] = epoch\n            modelname = rf\"{modelname}\"\n        else:\n            modelname = rf\"{modelname}\\_{epoch}\"\n\n        if \"numlayers\" in fname:\n            nlayers = int(fname.split(\"numlayers\")[1].split(\"_\")[1])\n            modelname += rf\"\\_nlayer\\_{nlayers}\"\n\n        if \"relu\" in fname:\n            activation = \"relu\"\n        elif \"gelu\" in fname:\n            activation = \"gelu\"\n        else:\n            activation = \"\"\n\n        modelname += rf\"\\_{activation}\"\n\n        try:\n            ablation = fname.split(\"abl_\")[1].split(\"_sampling\")[0]\n            ablation = ablation.replace(\"_\", r\"\\_\")\n            modelname += rf\"\\_{ablation}\"\n        except IndexError:\n            modelname += r\"\\_noablation\"\n        \n        if modelname not in models_kl:\n            models_kl[modelname] = {}\n        models_kl[modelname][kl_loss] = metrics\n        allkls.add(kl_loss)\n\n    lambdas_sorted = sorted(list(allkls), reverse=True)\n    \n    gtrowl = [\"ground truth\"]\n    for group in GROUPORDER:\n        if group in metrics:\n            for mname in METRICS[group]:\n                gtname = get_gtname(mname)\n                if gtname in metrics[group]:\n                    val = format_table(metrics[group][gtname], None, mname)\n                    gtrowl.append(val)\n                else:\n                    gtrowl.append(\"\")\n            gtrowl.append(\"\")\n    gtrowl.pop()\n    gtrow = \" & \".join(gtrowl) + r\"\\\\\"\n\n    bodylist = [gtrow]\n    bodylist.append(r\"\\midrule\")\n\n    modelnames = sorted(list(models_kl.keys()))\n\n    # compute first rows\n    # to add a first col\n    firstrow = [\"\"]\n    for group in GROUPORDER:\n        if group in metrics:\n            for mname in METRICS[group]:\n                mname = f\"{mname} {ARROWS[mname]}\"\n                firstrow.append(mname)\n            firstrow.append(\"\")\n    firstrow.pop()\n    firstrow = \" & \".join(firstrow) + r\"\\\\\"\n       \n    for lam in lambdas_sorted:\n        for modelname in modelnames:\n            if lam in models_kl[modelname]:\n                metrics = models_kl[modelname][lam]\n                row = [f\"{modelname} {lam}\"]\n                for group in GROUPORDER:\n                    if group in metrics:\n                        for mname in METRICS[group]:\n                            gtname = get_gtname(mname)\n                            gtval = metrics[group][gtname] if gtname in metrics[group] else None\n                            genname = get_genname(mname)\n                            reconsname = get_reconsname(mname)\n                            if not norecons and genname in metrics[group] and reconsname in metrics[group]:\n                                genval = format_table(metrics[group][genname], gtval, mname)\n                                reconsval = format_table(metrics[group][reconsname], gtval, mname)\n                                row.append(f\"{genval}/{reconsval}\")\n                            elif genname in metrics[group]:\n                                genval = format_table(metrics[group][genname], gtval, 
mname)\n                                row.append(f\"{genval}\")\n                            elif reconsname in metrics[group]:\n                                reconsval = format_table(metrics[group][reconsname], gtval, mname)\n                                row.append(f\"{reconsval}\")\n                            else:\n                                print(f\"{mname} is not present in this evaluation\")\n                        row.append(\"\")\n                row.pop()\n                row = \" & \".join(row) + r\"\\\\\"\n                bodylist.append(row)\n        # bodylist.append(emptyrow)\n        bodylist.append(r\"\\midrule\")\n\n    bodylist.append(r\"\\bottomrule\")\n    body = \"\\n\".join(bodylist)\n    ncols = len(gtrowl)\n    title = f\"Evaluation of {expname} experiment\"\n    template = r\"\"\"\\documentclass{{standalone}}\n\\usepackage{{booktabs}}\n\\usepackage[dvipsnames]{{xcolor}}\n\\begin{{document}}\n\\begin{{tabular}}{{{ncolsl}}}\n\\multicolumn{{{ncols}}}{{c}}{{{title}}} \\\\\n\\multicolumn{{{ncols}}}{{c}}{{}} \\\\\n& \\multicolumn{{{nbcolsxyz}}}{{c}}{{xyz}} & & \\multicolumn{{{nbcolspose}}}{{c}}{{{pose_rep}}} & & \\multicolumn{{{nbcolsa2m}}}{{c}}{{action2motion}} \\\\\n{firstrow}\n\\midrule\n{body}\n\\end{{tabular}}\n\\end{{document}}\n\"\"\".format(ncolsl=\"l\"+\"c\"*(ncols-1), ncols=ncols,\n           pose_rep=pose_rep, title=title, firstrow=firstrow,\n           nbcolsxyz=len(METRICS[\"joints\"]),\n           nbcolspose=len(METRICS[pose_rep]),\n           nbcolsa2m=len(METRICS[\"action2motion\"]),\n           body=body)\n    return template\n\n\nif __name__ == \"__main__\":\n    import argparse\n    \n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"exppath\", help=\"name of the exp\")\n        parser.add_argument(\"--outpath\", default=\"tex\", help=\"name of the exp\")\n        parser.add_argument(\"--norecons\", dest='norecons', action='store_true')\n        parser.set_defaults(norecons=False)\n        parser.add_argument(\"--lastepoch\", dest='lastepoch', action='store_true')\n        parser.set_defaults(lastepoch=False)\n        return parser.parse_args()\n\n    opt = parse_opts()\n    exppath = opt.exppath\n    norecons = opt.norecons\n    lastepoch = opt.lastepoch\n    \n    folder, expname = os.path.split(exppath)\n\n    template = collect_tables(folder, expname, lastepoch=lastepoch, norecons=norecons)\n\n    # os.makedirs(opt.outpath, exist_ok=True)\n    \n    name = expname\n    if norecons:\n        name += \"_norecons\"\n    texpath = os.path.join(exppath, name + \".tex\")\n\n    with open(texpath, \"w\") as ftex:\n        ftex.write(template)\n    print(f\"Table saved at {texpath}\")\n"
  },
  {
    "path": "PBnet/src/evaluate/tables/numlayertable.py",
    "content": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, power=3):\n    p = float(pow(10, power))\n    # \"{:<04}\".format(np.round(p*val).astype(int)/p)\n    return str(np.round(p*val).astype(int)/p).ljust(4, \"0\")\n\n\ndef format_values(values, key):\n    mean = np.mean(values)\n\n    if key == \"accuracy\":\n        mean = 100*mean\n        values = 100*values\n        smean = valformat(mean, 1)\n    else:\n        smean = valformat(mean, 2)\n\n    interval = valformat(1.96 * np.var(values), 2)  # [1:]\n    # string = rf\"${mean:.4}^{{\\pm{interval:.3}}}$\"\n    # string = rf\"${smean}$\"  # ^{{\\pm{interval}}}$\"\n    string = rf\"${smean}^{{\\pm{interval}}}$\"\n    return string\n                    \n\ndef construct_table(folder):\n    exppath = folder\n    paths = glob.glob(f\"{exppath}/**/evaluation*_all*.yaml\")\n\n    keys = [\"fid\", \"accuracy\", \"diversity\", \"multimodality\"]\n    \n    model_metrics_dataset = {\"ntu13\": {},\n                             \"uestc\": {}}\n\n    epoch_dataset = {\"ntu13\": 1000,\n                     \"uestc\": 500}\n        \n    for i, path in enumerate(paths):\n        epoch = int(path.split(\"evaluation_metrics_\")[1].split(\".\")[0].split(\"_\")[0])\n        \n        modelinfo = os.path.split(os.path.split(path)[0])[1]\n        dataset = modelinfo.split(\"_kl_\")[1].split(\"_\")[0]\n\n        # Take the right epoch\n        if epoch_dataset[dataset] != epoch:\n            continue\n\n        name = \"numlayers \" + modelinfo.split(\"numlayers_\")[1].split(\"_\")[0]\n        metrics = load_metrics(path)\n\n        model_metrics = model_metrics_dataset[dataset]\n        if dataset == \"ntu13\":\n            a2m = metrics[\"action2motion\"]\n\n            if \"GT\" not in model_metrics:\n                a2m[\"fid_gt\"] = a2m[\"fid_gt2\"]\n                \n                row = []\n                for key in keys:\n                    ckey = f\"{key}_gt\"\n                    values = np.array([float(x) for x in a2m[ckey]])\n                    string = format_values(values, key)\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n                \n            row = []\n            for key in keys:\n                ckey = f\"{key}_gen\"\n                values = np.array([float(x) for x in a2m[ckey]])\n                string = format_values(values, key)\n                row.append(string)\n\n            model_metrics[name] = row\n        elif dataset == \"uestc\":\n            stgcn = metrics[\"stgcn\"]\n\n            if \"GT\" not in model_metrics:\n                for sets in [\"train\", \"test\"]:\n                    stgcn[f\"fid_gt_{sets}\"] = stgcn[f\"fid_gt2_{sets}\"]\n                stgcnkeys = [\"fid_gt_train\", \"fid_gt_test\", \"accuracy_gt_train\", \"diversity_gt_train\", \"multimodality_gt_train\"]\n                row = []\n                for ckey in stgcnkeys:\n                    values = np.array([float(x) for x in stgcn[ckey]])\n                    string = format_values(values, ckey.split(\"_\")[0])\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n\n            stgcnkeys = [\"fid_gen_train\", \"fid_gen_test\", \"accuracy_gen_train\", \"diversity_gen_train\", \"multimodality_gen_train\"]\n            row = []\n            for ckey in stgcnkeys:\n                values = np.array([float(x) for x in stgcn[ckey]])\n                string = format_values(values, 
ckey.split(\"_\")[0])\n                row.append(string)\n\n            model_metrics[name] = row\n    \n    gtvalues = [\"GT\"]\n    for dataset in [\"uestc\", \"ntu13\"]:\n        model_metrics = model_metrics_dataset[dataset]\n        gtvalues.extend(model_metrics[\"GT\"])\n    gtrow = \" & \".join(gtvalues) + r\"\\\\\"\n\n    rows = []\n    modelnames = sorted(list(model_metrics.keys()))\n    for model in modelnames:\n        if model == \"GT\":\n            continue\n        values = [model]\n        for dataset in [\"uestc\", \"ntu13\"]:\n            model_metrics = model_metrics_dataset[dataset]\n            if model in model_metrics:\n                values.extend(model_metrics[model])\n            else:\n                dummy = [\"\" for _ in range(len(model_metrics[\"GT\"]))]\n                values.extend(dummy)\n        row = \" & \".join(values) + r\"\\\\\"\n        rows.append(row)\n        \n    rows = \"\\n\".join(rows)\n        \n    template = r\"\"\"\\documentclass{{standalone}}\n\\usepackage{{booktabs}}\n\\usepackage[dvipsnames]{{xcolor}}\n\\begin{{document}}\n    \\begin{{tabular}}{{lccccc|cccc}}\n        \\toprule\n        & \\multicolumn{{5}}{{c}}{{UESTC}} & \\multicolumn{{4}}{{|c}}{{NTU-13}} \\\\\n    Loss & FID$_{{tr}}$$\\downarrow$ & FID$_{{test}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\rightarrow$ & Multimod.$\\rightarrow$ & FID$_{{tr}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\rightarrow$ & Multimod.$\\rightarrow$ \\\\\n        \\midrule\n        {gtrow}\n        \\midrule\n        {rows}\n        \\bottomrule\n    \\end{{tabular}}\n\\end{{document}}\n\"\"\".format(rows=rows, gtrow=gtrow)\n    return template\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"exppath\", help=\"name of the exp\")\n        return parser.parse_args()\n\n    opt = parse_opts()\n    exppath = opt.exppath\n\n    folder = exppath\n    \n    tex = construct_table(folder)\n    texpath = os.path.join(folder, \"table_loss.tex\")\n\n    with open(texpath, \"w\") as ftex:\n        ftex.write(tex)\n        \n    print(f\"Table saved at {texpath}\")\n\n\n\n"
  },
  {
    "path": "PBnet/src/evaluate/tables/posereptable.py",
    "content": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, power=3):\n    p = float(pow(10, power))\n    # \"{:<04}\".format(np.round(p*val).astype(int)/p)\n    return str(np.round(p*val).astype(int)/p).ljust(4, \"0\")\n\n\ndef format_values(values, key):\n    mean = np.mean(values)\n\n    if key == \"accuracy\":\n        mean = 100*mean\n        values = 100*values\n        smean = valformat(mean, 1)\n    else:\n        smean = valformat(mean, 2)\n\n    interval = valformat(1.96 * np.var(values), 2)  # [1:]\n    # string = rf\"${mean:.4}^{{\\pm{interval:.3}}}$\"\n    # string = rf\"${smean}$\"  # ^{{\\pm{interval}}}$\"\n    string = rf\"${smean}^{{\\pm{interval}}}$\"\n    return string\n                    \n\ndef construct_table(folder):\n    exppath = folder\n    paths = glob.glob(f\"{exppath}/**/evaluation*_all*.yaml\")\n\n    keys = [\"fid\", \"accuracy\", \"diversity\", \"multimodality\"]\n    \n    model_metrics_dataset = {\"ntu13\": {},\n                             \"uestc\": {}}\n\n    epoch_dataset = {\"ntu13\": 1000,\n                     \"uestc\": 500}\n        \n    for i, path in enumerate(paths):\n        epoch = int(path.split(\"evaluation_metrics_\")[1].split(\".\")[0].split(\"_\")[0])\n        \n        modelinfo = os.path.split(os.path.split(path)[0])[1]\n        dataset = modelinfo.split(\"_kl_\")[1].split(\"_\")[0]\n\n        # Take the right epoch\n        if epoch_dataset[dataset] != epoch:\n            continue\n\n        name = \"Pose rep \" + modelinfo.split(\"_vibe_\")[1].split(\"_\")[0]\n        if \"xyz\" in name:\n            continue\n        \n        metrics = load_metrics(path)\n\n        model_metrics = model_metrics_dataset[dataset]\n        if dataset == \"ntu13\":\n            a2m = metrics[\"action2motion\"]\n\n            if \"GT\" not in model_metrics:\n                a2m[\"fid_gt\"] = a2m[\"fid_gt2\"]\n                \n                row = []\n                for key in keys:\n                    ckey = f\"{key}_gt\"\n                    values = np.array([float(x) for x in a2m[ckey]])\n                    string = format_values(values, key)\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n                \n            row = []\n            for key in keys:\n                ckey = f\"{key}_gen\"\n                values = np.array([float(x) for x in a2m[ckey]])\n                string = format_values(values, key)\n                row.append(string)\n\n            model_metrics[name] = row\n        elif dataset == \"uestc\":\n            stgcn = metrics[\"stgcn\"]\n\n            if \"GT\" not in model_metrics:\n                for sets in [\"train\", \"test\"]:\n                    stgcn[f\"fid_gt_{sets}\"] = stgcn[f\"fid_gt2_{sets}\"]\n                stgcnkeys = [\"fid_gt_train\", \"fid_gt_test\", \"accuracy_gt_train\", \"diversity_gt_train\", \"multimodality_gt_train\"]\n                row = []\n                for ckey in stgcnkeys:\n                    values = np.array([float(x) for x in stgcn[ckey]])\n                    string = format_values(values, ckey.split(\"_\")[0])\n                    row.append(string)\n                model_metrics[\"GT\"] = row\n\n            stgcnkeys = [\"fid_gen_train\", \"fid_gen_test\", \"accuracy_gen_train\", \"diversity_gen_train\", \"multimodality_gen_train\"]\n            row = []\n            for ckey in stgcnkeys:\n                values = np.array([float(x) for x in 
stgcn[ckey]])\n                string = format_values(values, ckey.split(\"_\")[0])\n                row.append(string)\n\n            model_metrics[name] = row\n    \n    gtvalues = [\"GT\"]\n    for dataset in [\"uestc\", \"ntu13\"]:\n        model_metrics = model_metrics_dataset[dataset]\n        if \"GT\" not in model_metrics and dataset == \"uestc\":\n            gtvalues.extend([\" \"] * (5 if dataset == \"uestc\" else 4))\n        else:\n            gtvalues.extend(model_metrics[\"GT\"])\n    gtrow = \" & \".join(gtvalues) + r\"\\\\\"\n\n    rows = []\n    modelnames = sorted(list(model_metrics_dataset[\"ntu13\"].keys()))\n    for model in modelnames:\n        if model == \"GT\":\n            continue\n        values = [model]\n        for dataset in [\"uestc\", \"ntu13\"]:\n            model_metrics = model_metrics_dataset[dataset]\n            if model in model_metrics:\n                values.extend(model_metrics[model])\n            else:\n                dummy = [\"\" for _ in range(5 if dataset == \"uestc\" else 4)]\n                values.extend(dummy)\n        row = \" & \".join(values) + r\"\\\\\"\n        rows.append(row)\n        \n    rows = \"\\n\".join(rows)\n        \n    template = r\"\"\"\\documentclass{{standalone}}\n\\usepackage{{booktabs}}\n\\usepackage[dvipsnames]{{xcolor}}\n\\begin{{document}}\n    \\begin{{tabular}}{{lccccc|cccc}}\n        \\toprule\n        & \\multicolumn{{5}}{{c}}{{UESTC}} & \\multicolumn{{4}}{{|c}}{{NTU-13}} \\\\\n    Loss & FID$_{{tr}}$$\\downarrow$ & FID$_{{test}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\rightarrow$ & Multimod.$\\rightarrow$ & FID$_{{tr}}$$\\downarrow$ & Acc.$\\uparrow$ & Div.$\\rightarrow$ & Multimod.$\\rightarrow$ \\\\\n        \\midrule\n        {gtrow}\n        \\midrule\n        {rows}\n        \\bottomrule\n    \\end{{tabular}}\n\\end{{document}}\n\"\"\".format(rows=rows, gtrow=gtrow)\n    return template\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    def parse_opts():\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\"exppath\", help=\"name of the exp\")\n        return parser.parse_args()\n\n    opt = parse_opts()\n    exppath = opt.exppath\n\n    folder = exppath\n    \n    tex = construct_table(folder)\n    texpath = os.path.join(folder, \"table_loss.tex\")\n\n    with open(texpath, \"w\") as ftex:\n        ftex.write(tex)\n        \n    print(f\"Table saved at {texpath}\")\n"
  },
  {
    "path": "PBnet/src/evaluate/tools.py",
    "content": "import yaml\n\n\ndef format_metrics(metrics, formatter=\"{:.6}\"):\n    newmetrics = {}\n    for key, val in metrics.items():\n        newmetrics[key] = formatter.format(val)\n    return newmetrics\n\n\ndef save_metrics(path, metrics):\n    with open(path, \"w\") as yfile:\n        yaml.dump(metrics, yfile)\n\n        \ndef load_metrics(path):\n    with open(path, \"r\") as yfile:\n        string = yfile.read()\n        return yaml.load(string, yaml.loader.BaseLoader)\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom torch.utils.data import DataLoader\nfrom src.utils.tensors_hdtf import collate\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\n\n\n\ndef evaluate(parameters, dataset, folder, checkpointname, epoch, niter):\n    # num_frames = 60\n\n    device = parameters[\"device\"]\n\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=device)\n    model.load_state_dict(state_dict)\n    model.eval()\n\n    if checkpointname.split(\"_\")[0] == 'retraincheckpoint':\n        save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0])\n    else:\n        save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0])\n    os.makedirs(save_folder, exist_ok=True)\n\n    allseeds = list(range(niter))\n\n    try:\n        for index, seed in enumerate(allseeds):\n            print(f\"Evaluation number: {index+1}/{niter}\")\n            fixseed(seed)\n            save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed))\n            save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed))\n            os.makedirs(save_pred_path, exist_ok=True)\n            os.makedirs(save_gt_path, exist_ok=True)\n            \n\n            dataiterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=False, num_workers=8, collate_fn=collate)\n\n            with torch.no_grad():\n                for databatch in tqdm(dataiterator, desc=f\"Construct dataloader: generating..\"):\n                    # batch = {key: val.to(device) for key, val in databatch.items()}\n                    pose = databatch[\"x\"]\n                    audio = databatch[\"y\"]\n                    gendurations = databatch[\"lengths\"]\n                    # start = databatch[\"start\"]\n                    batch = model.generate(pose, audio, gendurations)\n                    batch = {key: val.to(device) for key, val in batch.items()}\n                    \n\n                    for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']):\n                        # \n                        # x_ref = pose_gt[0,:].unsqueeze(dim=0).cpu()\n                        # pose_pre = (pose_pre.cpu()+x_ref - 0.5) * 180\n                        # gtmasked = (pose_gt[mask].cpu() -0.5 ) * 180\n\n                        # \n                        x_ref = pose_gt[0,:].unsqueeze(dim=0)\n                        pose_pre = pose_pre.cpu()+x_ref\n                        gtmasked = pose_gt[mask].cpu()\n                        outmasked = pose_pre[mask].cpu()\n                        pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num))\n                        gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt')\n                        # np.save(pred_path, pose_pre.cpu())\n                        # np.save(gt_path, pose_gt.cpu())\n                        np.savetxt(pred_path, outmasked)\n                        np.savetxt(gt_path, gtmasked)\n                        loss = F.mse_loss(gtmasked, 
outmasked, reduction='mean')\n                        print(loss)\n                        \n\n\n    except KeyboardInterrupt:\n        string = \"Saving the evaluation before exiting..\"\n        print(string)\n\n\n    epoch = checkpointname.split(\"_\")[1].split(\".\")[0]\n    metricname = \"evaluation_metrics_{}_all.yaml\".format(epoch)\n\n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving evaluation: {evalpath}\")\n    # save_metrics(evalpath, metrics)\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_norm.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.utils.tensors_hdtf import collate_old\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\n\n\ndef transform(x, min_val, max_val):\n    out = x * (max_val - min_val) + min_val\n    return out\n\ndef evaluate(parameters, dataset, folder, checkpointname, epoch, niter):\n    # num_frames = 60\n    min_val = dataset.min_vals\n    max_val = dataset.max_vals\n    device = parameters[\"device\"]\n\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=device)\n    model.load_state_dict(state_dict)\n    model.eval()\n\n    if checkpointname.split(\"_\")[0] == 'retraincheckpoint':\n        save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0])\n    else:\n        save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0])\n    os.makedirs(save_folder, exist_ok=True)\n\n    allseeds = list(range(niter))\n\n    try:\n        for index, seed in enumerate(allseeds):\n            print(f\"Evaluation number: {index+1}/{niter}\")\n            fixseed(seed)\n            save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed))\n            save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed))\n            os.makedirs(save_pred_path, exist_ok=True)\n            os.makedirs(save_gt_path, exist_ok=True)\n            \n\n            dataiterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=False, num_workers=8, collate_fn=collate_old)\n\n            with torch.no_grad():\n                for databatch in tqdm(dataiterator, desc=f\"Construct dataloader: generating..\"):\n                    # batch = {key: val.to(device) for key, val in databatch.items()}\n                    pose = databatch[\"x\"]\n                    audio = databatch[\"y\"]\n                    gendurations = databatch[\"lengths\"]\n                    # start = databatch[\"start\"]\n                    batch = model.generate(pose, audio, gendurations, fact = 1)\n                    batch = {key: val.to(device) for key, val in batch.items()}\n                    \n\n                    for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']):\n                        x_ref = pose_gt[0,:].unsqueeze(dim=0)\n                        pose_pre = pose_pre.cpu()+x_ref\n                        gtmasked = pose_gt[mask].cpu()\n                        outmasked = pose_pre[mask].cpu()\n                        gtmasked = transform(gtmasked, min_val, max_val)\n                        outmasked = transform(outmasked, min_val, max_val)\n                        pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num))\n                        gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt')\n                        # np.save(pred_path, pose_pre.cpu())\n                        # np.save(gt_path, pose_gt.cpu())\n                        np.savetxt(pred_path, outmasked)\n                        
np.savetxt(gt_path, gtmasked)\n                        loss = F.mse_loss(gtmasked, outmasked, reduction='mean')\n                        print(loss)\n                        \n\n\n    except KeyboardInterrupt:\n        string = \"Saving the evaluation before exiting..\"\n        print(string)\n\n\n    epoch = checkpointname.split(\"_\")[1].split(\".\")[0]\n    metricname = \"evaluation_metrics_{}_all.yaml\".format(epoch)\n\n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving evaluation: {evalpath}\")\n    # save_metrics(evalpath, metrics)\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_norm_all.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.utils.tensors_hdtf import collate_old\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\n\n\ndef save_images_as_npy(input_data, output_file):\n    # save_npy = np.zeros(input_data.shape[0], 7)\n    # save_npy[:,:, :-1] = input_data\n    # save_npy[:, -1] = ref[:,:, -1]\n    # images_array = np.array(images)\n    np.save(output_file, input_data)\n\n\ndef save_as_chunk(dir, data): \n    if not os.path.exists(dir):\n        os.makedirs(dir)\n    chunks = [data[i:min(i + 25, data.shape[0])] for i in range(0, data.shape[0], 25)]\n\n    for i, chunk in enumerate(chunks):\n        output_file = os.path.join(dir, f'chunk_%04d.npy' % (i))\n        # chunk = np.stack(chunk, axis = 0)\n        save_images_as_npy(chunk, output_file)\n\n\n\ndef transform(x, min_val, max_val):\n    out = x * (max_val - min_val) + min_val\n    return out\n\ndef evaluate(parameters, dataset, folder, checkpointname, epoch, niter):\n    # num_frames = 60\n    min_val = dataset.min_vals\n    max_val = dataset.max_vals\n    device = parameters[\"device\"]\n\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    model_ckpt = model.state_dict()\n    state_dict = torch.load(checkpointpath, map_location=device)\n    for name, _ in model_ckpt.items():\n        if  model_ckpt[name].shape == state_dict[name].shape:\n            model_ckpt[name].copy_(state_dict[name])\n        model.load_state_dict(model_ckpt)\n    # model.load_state_dict(state_dict)\n    model.eval()\n\n    if checkpointname.split(\"_\")[0] == 'retraincheckpoint':\n        save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0])\n    else:\n        save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0])\n    os.makedirs(save_folder, exist_ok=True)\n\n    allseeds = list(range(niter))\n\n    try:\n        for index, seed in enumerate(allseeds):\n            print(f\"Evaluation number: {index+1}/{niter}\")\n            fixseed(seed)\n            save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed))\n            save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed))\n            os.makedirs(save_pred_path, exist_ok=True)\n            os.makedirs(save_gt_path, exist_ok=True)\n            \n\n            dataiterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=False, num_workers=8, collate_fn=collate_old)\n\n            with torch.no_grad():\n                for databatch in tqdm(dataiterator, desc=f\"Construct dataloader: generating..\"):\n                    # batch = {key: val.to(device) for key, val in databatch.items()}\n                    pose = databatch[\"x\"][:,:,:-1]   # b, len, c\n                    ref = databatch['x'][:,:, -1]\n                    audio = databatch[\"y\"]  # b, len, c\n                    gendurations = databatch[\"lengths\"]\n                    # start = databatch[\"start\"]\n                    batch = model.generate(pose, audio, gendurations, fact = 1)\n                    batch = {key: val.to(device) for key, val in batch.items()}\n           
         \n\n                    for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']):\n                        x_ref = pose_gt[0,:].unsqueeze(dim=0)\n                        pose_pre = pose_pre.cpu()\n                        padding_vec = torch.zeros(pose_pre.shape[0], 1)\n                        pose_pre = torch.concat([pose_pre, padding_vec], dim = -1)\n                        pose_pre = pose_pre.cpu()+x_ref\n                        gtmasked = pose_gt[mask].cpu()\n                        outmasked = pose_pre[mask].cpu()\n                        gtmasked = transform(gtmasked, min_val, max_val)\n                        outmasked = transform(outmasked, min_val, max_val)\n                        pred_dir = os.path.join(save_pred_path, filename)\n                        save_as_chunk(pred_dir, outmasked)\n                        # np.save(pred_path, pose_pre.cpu())\n                        # np.save(gt_path, pose_gt.cpu())\n                        # np.savetxt(pred_path, outmasked)\n                        # np.savetxt(gt_path, gtmasked)\n                        loss = F.mse_loss(gtmasked, outmasked, reduction='mean')\n                        print(loss)\n                        \n\n\n    except KeyboardInterrupt:\n        string = \"Saving the evaluation before exiting..\"\n        print(string)\n\n\n    epoch = checkpointname.split(\"_\")[1].split(\".\")[0]\n    metricname = \"evaluation_metrics_{}_all.yaml\".format(epoch)\n\n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving evaluation: {evalpath}\")\n    # save_metrics(evalpath, metrics)\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_norm_eye_pose.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.utils.tensors_eye_eval import collate\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\nimport time\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\n\n\ndef transform(x, min_val, max_val):\n    out = x * (max_val - min_val) + min_val\n    return out\n\ndef evaluate(parameters, dataset, folder, checkpointname, epoch, niter):\n    # num_frames = 60\n    min_val = dataset.min_vals\n    max_val = dataset.max_vals\n    device = parameters[\"device\"]\n\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=device)\n    model.load_state_dict(state_dict)\n    model.eval()\n\n    if checkpointname.split(\"_\")[0] == 'retraincheckpoint':\n        save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0])\n    else:\n        save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0])\n    os.makedirs(save_folder, exist_ok=True)\n\n    allseeds = list(range(niter))\n\n    try:\n        for index, seed in enumerate(allseeds):\n            print(f\"Evaluation number: {index+1}/{niter}\")\n            fixseed(seed)\n            save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed))\n            save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed))\n            os.makedirs(save_pred_path, exist_ok=True)\n            os.makedirs(save_gt_path, exist_ok=True)\n            \n\n            dataiterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=False, num_workers=8, collate_fn=collate)\n\n            with torch.no_grad():\n                for databatch in tqdm(dataiterator, desc=f\"Construct dataloader: generating..\"):\n                    # batch = {key: val.to(device) for key, val in databatch.items()}\n                    pose_eye = databatch[\"x\"]\n                    audio = databatch[\"y\"]\n                    gendurations = databatch[\"lengths\"]\n                    # start = databatch[\"start\"]\n                    start_time = time.time()\n                    batch = model.generate(pose_eye, audio, gendurations, fact = 1)\n                    end_time = time.time()\n                    print(f'generate audio time {end_time- start_time}')\n                    start_time = end_time\n                    # exit()\n                    batch = {key: val.to(device) for key, val in batch.items()}\n                    \n\n                    for pose_eye_pre, pose_eye_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']):\n                        x_ref = pose_eye_gt[0,:].unsqueeze(dim=0)\n                        pose_eye_pre = pose_eye_pre.cpu()+x_ref\n                        gtmasked = pose_eye_gt[mask].cpu()\n                        outmasked = pose_eye_pre[mask].cpu()\n                        gtmasked[:,:-2] = transform(gtmasked[:,:-2], min_val, max_val)\n                        outmasked[:,:-2] = transform(outmasked[:,:-2], min_val, max_val)\n                        pred_path = os.path.join(save_pred_path, 
filename+'_'+str(start_num))\n                        gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt')\n                        # np.save(pred_path, pose_pre.cpu())\n                        # np.save(gt_path, pose_gt.cpu())\n                        np.savetxt(pred_path, outmasked)\n                        np.savetxt(gt_path, gtmasked)\n                        loss = F.mse_loss(gtmasked[:,:3], outmasked[:,:3], reduction='mean')\n                        print(loss)\n                        \n\n\n    except KeyboardInterrupt:\n        string = \"Saving the evaluation before exiting..\"\n        print(string)\n\n\n    epoch = checkpointname.split(\"_\")[1].split(\".\")[0]\n    metricname = \"evaluation_metrics_{}_all.yaml\".format(epoch)\n\n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving evaluation: {evalpath}\")\n    # save_metrics(evalpath, metrics)\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_norm_eye_pose_seg.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.utils.tensors_eye_eval import collate\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\nimport time\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\n\nINF_LENGTH = 200\n\ndef transform(x, min_val, max_val):\n    out = x * (max_val - min_val) + min_val\n    return out\n\ndef save_images_as_npy(input_data, output_file):\n    # save_npy = np.zeros(input_data.shape[0], 7)\n    # save_npy[:,:, :-1] = input_data\n    # save_npy[:, -1] = ref[:,:, -1]\n    # images_array = np.array(images)\n    np.save(output_file, input_data)\n\ndef save_as_chunk(dir, data): \n    if not os.path.exists(dir):\n        os.makedirs(dir)\n    chunks = [data[i:min(i + 25, data.shape[0])] for i in range(0, data.shape[0], 25)]\n\n    for i, chunk in enumerate(chunks):\n        output_file = os.path.join(dir, f'chunk_%04d.npy' % (i))\n        # chunk = np.stack(chunk, axis = 0)\n        save_images_as_npy(chunk, output_file)\n\ndef evaluate(parameters, dataset, folder, checkpointname, epoch, niter):\n    # num_frames = 60\n    min_val = dataset.min_vals\n    max_val = dataset.max_vals\n    device = parameters[\"device\"]\n\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=device)\n    model.load_state_dict(state_dict)\n    model.eval()\n\n    if checkpointname.split(\"_\")[0] == 'retraincheckpoint':\n        save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0])\n    else:\n        save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0])\n    os.makedirs(save_folder, exist_ok=True)\n\n    allseeds = list(range(niter))\n\n    try:\n        for index, seed in enumerate(allseeds):\n            print(f\"Evaluation number: {index+1}/{niter}\")\n            fixseed(seed)\n            save_pred_path_pose = os.path.join(save_folder, 'eval_pred', str(seed),'pose')\n            save_gt_path_pose = os.path.join(save_folder, 'eval_gt', str(seed),'pose')\n            save_pred_path_eye = os.path.join(save_folder, 'eval_pred', str(seed),'eye')\n            save_gt_path_eye = os.path.join(save_folder, 'eval_gt', str(seed),'eye')\n            os.makedirs(save_pred_path_pose, exist_ok=True)\n            os.makedirs(save_gt_path_pose, exist_ok=True)\n            os.makedirs(save_pred_path_eye, exist_ok=True)\n            os.makedirs(save_gt_path_eye, exist_ok=True)\n            \n\n            dataiterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=False, num_workers=8, collate_fn=collate)\n\n            with torch.no_grad():\n                for databatch in tqdm(dataiterator, desc=f\"Construct dataloader: generating..\"):\n                    # batch = {key: val.to(device) for key, val in databatch.items()}\n                    pose_eye = databatch[\"x\"]\n                    audio = databatch[\"y\"]\n                    gendurations = databatch[\"lengths\"]\n                    # start = databatch[\"start\"]\n                    # start_time = time.time()\n                    # batch = model.generate(pose_eye, audio, gendurations, fact = 1)\n           
         # end_time = time.time()\n                    # print(f'generate audio time {end_time- start_time}')\n                    # start_time = end_time\n                    # # exit()\n                    # batch = {key: val.to(device) for key, val in batch.items()}\n\n                    output = None\n                    for i in range(0, pose_eye.shape[1], INF_LENGTH):\n                        # step 1: seg\n                        start = i\n                        end = min(pose_eye.shape[1], i + INF_LENGTH)\n                        pose_seg = pose_eye[:, start:end]\n                        audio_seg = audio[:, start:end]\n                        gendurations_seg = torch.tensor([end - start])\n                        # step 2: predict\n                        batch = model.generate(pose_seg, audio_seg, gendurations_seg, fact = 1)\n                        # step 3: merge\n                        if output == None:\n                            output = batch['output'].detach().cpu()\n                        else:\n                            output = torch.concat([output, batch['output'].detach().cpu()], dim= 1)\n                    \n\n                    for pose_pre, pose_gt, mask, filename, start_num in zip(output, databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']):\n                        \n                        pose_pre = pose_pre.cpu()\n                        # padding_vec = torch.zeros(pose_pre.shape[0], 1)\n                        # pose_pre = torch.concat([pose_pre], dim = -1)\n                        for i in range(0, pose_gt.shape[0], INF_LENGTH):\n                            start = i\n                            end = min(pose_gt.shape[0], i + INF_LENGTH)\n                            x_ref = pose_gt[i,:].unsqueeze(dim=0)\n                            pose_pre[start:end] = pose_pre[start:end]+x_ref\n                        gtmasked = pose_gt[mask].cpu()\n                        outmasked = pose_pre[mask].cpu()\n                        gtmasked[:,:-2] = transform(gtmasked[:,:-2], min_val, max_val)\n                        outmasked[:,:-2] = transform(outmasked[:,:-2], min_val, max_val)\n                        pred_dir_pose = os.path.join(save_pred_path_pose, filename)\n                        pred_dir_eye = os.path.join(save_pred_path_eye, filename)\n                        out_eye = outmasked[:, 6:]\n                        out_pose = outmasked[:, :6]\n                        save_as_chunk(pred_dir_pose, out_pose)\n                        save_as_chunk(pred_dir_eye, out_eye)\n                        # save_as_chunk(pred_dir, outmasked)\n                        # np.save(pred_path, pose_pre.cpu())\n                        # np.save(gt_path, pose_gt.cpu())\n                        # np.savetxt(pred_path, outmasked)\n                        # np.savetxt(gt_path, gtmasked)\n                        loss = F.mse_loss(gtmasked, outmasked, reduction='mean')\n                        print(loss)\n                        \n\n\n    except KeyboardInterrupt:\n        string = \"Saving the evaluation before exiting..\"\n        print(string)\n\n\n    epoch = checkpointname.split(\"_\")[1].split(\".\")[0]\n    metricname = \"evaluation_metrics_{}_all.yaml\".format(epoch)\n\n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving evaluation: {evalpath}\")\n    # save_metrics(evalpath, metrics)\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_norm_seg.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.utils.tensors_hdtf import collate, collate_old\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\n\nINF_LENGTH = 600\n\ndef save_images_as_npy(input_data, output_file):\n    # save_npy = np.zeros(input_data.shape[0], 7)\n    # save_npy[:,:, :-1] = input_data\n    # save_npy[:, -1] = ref[:,:, -1]\n    # images_array = np.array(images)\n    np.save(output_file, input_data)\n\n\ndef save_as_chunk(dir, data): \n    if not os.path.exists(dir):\n        os.makedirs(dir)\n    chunks = [data[i:min(i + 25, data.shape[0])] for i in range(0, data.shape[0], 25)]\n\n    for i, chunk in enumerate(chunks):\n        output_file = os.path.join(dir, f'chunk_%04d.npy' % (i))\n        # chunk = np.stack(chunk, axis = 0)\n        save_images_as_npy(chunk, output_file)\n\n\n\ndef transform(x, min_val, max_val):\n    out = x * (max_val - min_val) + min_val\n    return out\n\ndef evaluate(parameters, dataset, folder, checkpointname, epoch, niter):\n    # num_frames = 60\n    min_val = dataset.min_vals\n    max_val = dataset.max_vals\n    device = parameters[\"device\"]\n\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    model_ckpt = model.state_dict()\n    state_dict = torch.load(checkpointpath, map_location=device)\n    for name, _ in model_ckpt.items():\n        if  model_ckpt[name].shape == state_dict[name].shape:\n            model_ckpt[name].copy_(state_dict[name])\n        model.load_state_dict(model_ckpt)\n    # model.load_state_dict(state_dict)\n    model.eval()\n\n    if checkpointname.split(\"_\")[0] == 'retraincheckpoint':\n        save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0])\n    else:\n        save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0])\n    os.makedirs(save_folder, exist_ok=True)\n\n    allseeds = list(range(niter))\n\n    try:\n        for index, seed in enumerate(allseeds):\n            print(f\"Evaluation number: {index+1}/{niter}\")\n            fixseed(seed)\n            save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed))\n            save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed))\n            os.makedirs(save_pred_path, exist_ok=True)\n            os.makedirs(save_gt_path, exist_ok=True)\n            \n\n            dataiterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=False, num_workers=8, collate_fn=collate_old)\n\n            with torch.no_grad():\n                for databatch in tqdm(dataiterator, desc=f\"Construct dataloader: generating..\"):\n                    # batch = {key: val.to(device) for key, val in databatch.items()}\n                    pose = databatch[\"x\"][:,:,:-1]   # b, len, c\n                    ref = databatch['x'][:,:, -1]\n                    audio = databatch[\"y\"]  # b, len, c\n                    gendurations = databatch[\"lengths\"]\n                    # start = databatch[\"start\"]\n\n                    output = None\n                    for i in range(0, pose.shape[1], INF_LENGTH):\n                        # step 1: seg\n   
                     start = i\n                        end = min(pose.shape[1], i + INF_LENGTH)\n                        pose_seg = pose[:, start:end]\n                        audio_seg = audio[:, start:end]\n                        gendurations_seg = torch.tensor([end - start])\n                        # step 2: predict\n                        batch = model.generate(pose_seg, audio_seg, gendurations_seg, fact = 1)\n                        # step 3: merge\n                        if output == None:\n                            output = batch['output'].detach().cpu()\n                        else:\n                            output = torch.concat([output, batch['output'].detach().cpu()], dim= 1)\n                        \n\n\n\n                    \n                    # batch = model.generate(pose, audio, gendurations, fact = 1)\n                    # batch = {key: val.to(device) for key, val in batch.items()}\n                    \n\n                    for pose_pre, pose_gt, mask, filename, start_num in zip(output, databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']):\n                        \n                        pose_pre = pose_pre.cpu()\n                        padding_vec = torch.zeros(pose_pre.shape[0], 1)\n                        pose_pre = torch.concat([pose_pre, padding_vec], dim = -1)\n                        for i in range(0, pose_gt.shape[0], INF_LENGTH):\n                            start = i\n                            end = min(pose_gt.shape[0], i + INF_LENGTH)\n                            x_ref = pose_gt[i,:].unsqueeze(dim=0)\n                            pose_pre[start:end] = pose_pre[start:end]+x_ref\n                        gtmasked = pose_gt[mask].cpu()\n                        outmasked = pose_pre[mask].cpu()\n                        gtmasked = transform(gtmasked, min_val, max_val)\n                        outmasked = transform(outmasked, min_val, max_val)\n                        pred_dir = os.path.join(save_pred_path, filename)\n                        save_as_chunk(pred_dir, outmasked)\n                        # np.save(pred_path, pose_pre.cpu())\n                        # np.save(gt_path, pose_gt.cpu())\n                        # np.savetxt(pred_path, outmasked)\n                        # np.savetxt(gt_path, gtmasked)\n                        loss = F.mse_loss(gtmasked, outmasked, reduction='mean')\n                        print(loss)\n                        \n\n\n    except KeyboardInterrupt:\n        string = \"Saving the evaluation before exiting..\"\n        print(string)\n\n\n    epoch = checkpointname.split(\"_\")[1].split(\".\")[0]\n    metricname = \"evaluation_metrics_{}_all.yaml\".format(epoch)\n\n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving evaluation: {evalpath}\")\n    # save_metrics(evalpath, metrics)\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_onlyeye_all_seg.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.utils.tensors_onlyeye import collate_eval\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\n\n\ndef save_images_as_npy(input_data, output_file):\n    # save_npy = np.zeros(input_data.shape[0], 7)\n    # save_npy[:,:, :-1] = input_data\n    # save_npy[:, -1] = ref[:,:, -1]\n    # images_array = np.array(images)\n    np.save(output_file, input_data)\n\n\ndef save_as_chunk(dir, data): \n    if not os.path.exists(dir):\n        os.makedirs(dir)\n    chunks = [data[i:min(i + 25, data.shape[0])] for i in range(0, data.shape[0], 25)]\n\n    for i, chunk in enumerate(chunks):\n        output_file = os.path.join(dir, f'chunk_%04d.npy' % (i))\n        # chunk = np.stack(chunk, axis = 0)\n        save_images_as_npy(chunk, output_file)\n\n\n\n# def transform(x, min_val, max_val):\n#     out = x * (max_val - min_val) + min_val\n#     return out\n\ndef evaluate(parameters, dataset, folder, checkpointname, epoch, niter):\n    # num_frames = 60\n    # min_val = dataset.min_vals\n    # max_val = dataset.max_vals\n    device = parameters[\"device\"]\n\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    model_ckpt = model.state_dict()\n    state_dict = torch.load(checkpointpath, map_location=device)\n    for name, _ in model_ckpt.items():\n        if  model_ckpt[name].shape == state_dict[name].shape:\n            model_ckpt[name].copy_(state_dict[name])\n        model.load_state_dict(model_ckpt)\n    # model.load_state_dict(state_dict)\n    model.eval()\n\n    if checkpointname.split(\"_\")[0] == 'retraincheckpoint':\n        save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0])\n    else:\n        save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0])\n    os.makedirs(save_folder, exist_ok=True)\n\n    allseeds = list(range(niter))\n\n    try:\n        for index, seed in enumerate(allseeds):\n            print(f\"Evaluation number: {index+1}/{niter}\")\n            fixseed(seed)\n            save_pred_path = os.path.join(save_folder, 'seg_all', 'eval_pred', str(seed))\n            save_gt_path = os.path.join(save_folder, 'seg_all', 'eval_gt', str(seed))\n            os.makedirs(save_pred_path, exist_ok=True)\n            os.makedirs(save_gt_path, exist_ok=True)\n            \n\n            dataiterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=False, num_workers=8, collate_fn=collate_eval)\n\n            with torch.no_grad():\n                for databatch in tqdm(dataiterator, desc=f\"Construct dataloader: generating..\"):\n                    # batch = {key: val.to(device) for key, val in databatch.items()}\n                    pose = databatch[\"x\"]  # b, len, c\n                    audio = databatch[\"y\"]  # b, len, c\n                    gendurations = databatch[\"lengths\"]\n                    # start = databatch[\"start\"]\n\n                    output = None\n                    for i in range(0, pose.shape[1], 200):\n                        # step 1: seg\n                        start = i\n                        
end = min(pose.shape[1], i + 200)\n                        pose_seg = pose[:, start:end]\n                        audio_seg = audio[:, start:end]\n                        # gendurations_seg = gendurations[:, start:end]\n                        gendurations_seg = torch.tensor([end - start])\n                        # step 2: predict\n                        batch = model.generate(pose_seg, audio_seg, gendurations_seg, fact = 1)\n                        # step 3: merge\n                        if output is None:\n                            output = batch['output'].detach().cpu()\n                        else:\n                            output = torch.concat([output, batch['output'].detach().cpu()], dim=1)\n\n                    # batch = model.generate(pose, audio, gendurations, fact = 1)\n                    # batch = {key: val.to(device) for key, val in batch.items()}\n\n                    for pose_pre, pose_gt, mask, filename, start_num in zip(output, databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']):\n                        pose_pre = pose_pre.cpu()\n                        # padding_vec = torch.zeros(pose_pre.shape[0], 1)\n                        # pose_pre = torch.concat([pose_pre, padding_vec], dim = -1)\n                        # add each segment's reference frame (the ground-truth frame at the segment start) back to the predicted deltas\n                        for i in range(0, pose_gt.shape[0], 200):\n                            start = i\n                            end = min(pose_gt.shape[0], i + 200)\n                            x_ref = pose_gt[i, :].unsqueeze(dim=0)\n                            pose_pre[start:end] = pose_pre[start:end] + x_ref\n                        gtmasked = pose_gt[mask].cpu()\n                        outmasked = pose_pre[mask].cpu()\n                        # gtmasked = transform(gtmasked, min_val, max_val)\n                        # outmasked = transform(outmasked, min_val, max_val)\n                        pred_dir = os.path.join(save_pred_path, filename)\n                        save_as_chunk(pred_dir, outmasked)\n                        # np.save(pred_path, pose_pre.cpu())\n                        # np.save(gt_path, pose_gt.cpu())\n                        # np.savetxt(pred_path, outmasked)\n                        # np.savetxt(gt_path, gtmasked)\n                        loss = F.mse_loss(gtmasked, outmasked, reduction='mean')\n                        print(loss)\n\n    except KeyboardInterrupt:\n        print(\"Saving the evaluation before exiting..\")\n\n    epoch = checkpointname.split(\"_\")[1].split(\".\")[0]\n    metricname = \"evaluation_metrics_{}_all.yaml\".format(epoch)\n\n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving evaluation: {evalpath}\")\n    # save_metrics(evalpath, metrics)\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_single.py",
    "content": "import torch\nfrom tqdm import tqdm\nimport sys\nimport os\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nparent_dir = os.path.dirname(os.path.dirname(current_dir))\nif parent_dir not in sys.path:\n    sys.path.append(parent_dir)\n    print(parent_dir)\n\nfrom src.utils.fixseed import fixseed\nfrom src.parser.tools import load_args\nimport os\nimport numpy as np\nimport torch.nn.functional as F\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\nimport argparse\n\nmax_vals = torch.tensor([90, 90, 90,  1,\n            720,  1080]).to(torch.float32).reshape(1, 1, 6)\nmin_vals = torch.tensor([-90, -90, -90,  0,\n            0,  0]).to(torch.float32).reshape(1, 1, 6)\n\ndef inv_transform(x, min_val, max_val):\n    out = x * (max_val - min_val) + min_val\n    return out\n\ndef save_images_as_npy(input_data, output_file):\n    # save_npy = np.zeros(input_data.shape[0], 7)\n    # save_npy[:,:, :-1] = input_data\n    # save_npy[:, -1] = ref[:,:, -1]\n    # images_array = np.array(images)\n    np.save(output_file, input_data)\n\n\n\n\n# def transform(x, min_val, max_val):\n#     out = x * (max_val - min_val) + min_val\n#     return out\n\ndef evaluate(parameters_pose, parameters_blink, audio_path, init_pose_path, init_blink_path, checkpoint_p_path, checkpoint_b_path, output_path):\n    # num_frames = 60\n    # min_val = dataset.min_vals\n    # max_val = dataset.max_vals\n    device = \"cuda:0\"\n    pose_dim = parameters_pose['pos_dim']\n    eye_dim = parameters_blink['eye_dim']\n    # dummy => update parameters info\n    model_p = get_gen_model(parameters_pose)\n    model_b = get_gen_model(parameters_blink)\n\n    print(\"Restore weights..\")\n    # checkpointpath = os.path.join(folder, checkpointname)\n    # model_p_ckpt = model_p.state_dict()\n    # model_b_ckpt = model_b.state_dict()\n    state_dict_p = torch.load(checkpoint_p_path, map_location=device)\n    state_dict_b = torch.load(checkpoint_b_path, map_location=device)\n    # for name, _ in model_ckpt.items():\n    #     if  model_ckpt[name].shape == state_dict[name].shape:\n    #         model_ckpt[name].copy_(state_dict[name])\n    #     model.load_state_dict(model_ckpt)\n    model_p.load_state_dict(state_dict_p)\n    model_b.load_state_dict(state_dict_b)\n    model_p.eval()\n    model_b.eval()\n\n    os.makedirs(output_path, exist_ok=True)\n\n    try:\n        init_pose = torch.from_numpy(np.load(init_pose_path))[:,:pose_dim].unsqueeze(0).to(torch.float32)\n        init_blink = torch.from_numpy(np.load(init_blink_path))[:,:eye_dim].unsqueeze(0).to(torch.float32)\n        audio = torch.from_numpy(np.load(audio_path)).unsqueeze(0).to(torch.float32)\n    except Exception:\n        # the 3ddfa fail to extract valid pose, using typical value instead\n        init_pose = torch.from_numpy(np.array([[0, 0, 0, 4.79e-04, 5.65e+01, 6.49e+01,]]))[:,:pose_dim].unsqueeze(0).to(torch.float32)\n        init_blink = torch.from_numpy(np.array([[0.3,0.3]]))[:,:eye_dim].unsqueeze(0).to(torch.float32)\n        audio = torch.from_numpy(np.load(audio_path)).unsqueeze(0).to(torch.float32)\n\n    init_pose = (init_pose - min_vals)/ (max_vals - min_vals)\n    fixseed(1234)\n        \n\n    with torch.no_grad():\n\n        # batch = {key: val.to(device) for key, val in databatch.items()}\n\n\n        # step 1: seg\n        pose_seg = init_pose\n        blink_seg = init_blink\n        audio_seg = audio\n        # gendurations_seg = gendurations[:, start:end]\n        
gendurations_seg = torch.tensor([audio.shape[1] - 0])\n        # step 2: predict\n        batch_p = model_p.generate(pose_seg, audio_seg, gendurations_seg, fact = 1)\n        batch_b = model_b.generate(blink_seg, audio_seg, gendurations_seg, fact = 1)\n        # step 3: merge\n\n        output_p = batch_p['output'].detach().cpu()\n        output_b = batch_b['output'].detach().cpu()\n\n        output_p = output_p + pose_seg\n        output_p = inv_transform(output_p, min_vals, max_vals)\n        output_b = output_b + blink_seg\n\n        output_pose_path = os.path.join(output_path, 'dri_pose.npy')\n        output_blink_path = os.path.join(output_path, 'dri_blink.npy')\n\n        np.save(output_pose_path , output_p[0])\n        np.save(output_blink_path, output_b[0])\n\ndef get_arguments():\n    \"\"\"Parse all the arguments provided from the CLI.\n\n    Returns:\n      A list of parsed arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"PBnet\")\n    parser.add_argument(\"--audio_path\", default='/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate/RD_Radio54_000.npy')\n    parser.add_argument(\"--ckpt_pose\", default='your_path/pretrain_models/pbnet_seperate/pose/checkpoint_40000.pth.tar',\n                        help=\"ckpt of PoseNet\")\n    parser.add_argument(\"--ckpt_blink\", default='your_path/pretrain_models/pbnet_seperate/blink/checkpoint_95000.pth.tar',\n                        help=\"ckpt of BlinkNet\")\n\n    parser.add_argument(\"--init_pose_blink\", default='your/path/DAWN-pytorch/ood_data/ood_test_material/cache_2',\n                        help=\"dir of init pose/blink\")\n    \n    parser.add_argument(\"--output\", default='/train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/demo_output',\n                        help=\"output_dir\")\n\n    return parser.parse_args()\n\nif __name__ == '__main__':\n    args = get_arguments()\n    audio_path = args.audio_path\n    ckpt_pose = args.ckpt_pose\n    ckpt_blink = args.ckpt_blink\n    output_dir = args.output\n    init_blink = os.path.join(args.init_pose_blink, 'init_eye_bbox.npy') # init_eye_bbox.npy\n    init_pose = os.path.join(args.init_pose_blink, 'init_pose.npy')\n\n    folder_p, _ = os.path.split(ckpt_pose)\n    parameters_p = load_args(os.path.join(folder_p, \"opt.yaml\"))\n    parameters_p['device'] = 'cuda:0'\n    parameters_p[\"audio_dim\"] = 1024\n    parameters_p[\"pos_dim\"] = 6\n    parameters_p[\"eye_dim\"] = 0\n\n    folder_b, _ = os.path.split(ckpt_blink)\n    parameters_b = load_args(os.path.join(folder_b, \"opt.yaml\"))\n    parameters_b['device'] = 'cuda:0'\n    parameters_b[\"audio_dim\"] = 1024\n    parameters_b[\"pos_dim\"] = 0\n    parameters_b[\"eye_dim\"] = 2\n\n    evaluate(parameters_pose = parameters_p,\n            parameters_blink = parameters_b,\n            audio_path = audio_path,\n            init_pose_path = init_pose,\n            init_blink_path = init_blink,\n            checkpoint_p_path = ckpt_pose,\n            checkpoint_b_path = ckpt_blink,\n            output_path = output_dir)\n\n\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_single_both_eye_pose.py",
    "content": "import torch\nfrom tqdm import tqdm\nimport os\nimport sys\n# adding path of PBnet\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nparent_dir = os.path.dirname(os.path.dirname(current_dir))\nif parent_dir not in sys.path:\n    sys.path.append(parent_dir)\n    print(parent_dir)\n\nfrom src.utils.fixseed import fixseed\nfrom src.parser.tools import load_args\nimport numpy as np\nimport torch.nn.functional as F\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\nimport argparse\n\nmax_vals = torch.tensor([90, 90, 90,  1,\n            720,  1080, 1, 1]).to(torch.float32).reshape(1, 1, 8)\nmin_vals = torch.tensor([-90, -90, -90,  0,\n            0,  0, 0, 0]).to(torch.float32).reshape(1, 1, 8)\n\ndef inv_transform(x, min_val, max_val):\n    out = x * (max_val - min_val) + min_val\n    return out\n\ndef save_images_as_npy(input_data, output_file):\n    # save_npy = np.zeros(input_data.shape[0], 7)\n    # save_npy[:,:, :-1] = input_data\n    # save_npy[:, -1] = ref[:,:, -1]\n    # images_array = np.array(images)\n    np.save(output_file, input_data)\n\n\n\n\n# def transform(x, min_val, max_val):\n#     out = x * (max_val - min_val) + min_val\n#     return out\n\ndef evaluate(parameters, audio_path, init_pose_path, init_blink_path, checkpoint_path, output_path):\n    # num_frames = 60\n    # min_val = dataset.min_vals\n    # max_val = dataset.max_vals\n    device = parameters[\"device\"]\n    pose_dim = parameters['pos_dim']\n    eye_dim = parameters['eye_dim']\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    # checkpointpath = os.path.join(folder, checkpointname)\n    # model_p_ckpt = model_p.state_dict()\n    # model_b_ckpt = model_b.state_dict()\n    state_dict_p = torch.load(checkpoint_path, map_location=device)\n    # for name, _ in model_ckpt.items():\n    #     if  model_ckpt[name].shape == state_dict[name].shape:\n    #         model_ckpt[name].copy_(state_dict[name])\n    #     model.load_state_dict(model_ckpt)\n    model.load_state_dict(state_dict_p)\n\n    model.eval()\n\n\n    os.makedirs(output_path, exist_ok=True)\n\n    try:\n        init_pose = torch.from_numpy(np.load(init_pose_path))[:,:pose_dim].unsqueeze(0).to(torch.float32)\n        init_blink = torch.from_numpy(np.load(init_blink_path))[:,:eye_dim].unsqueeze(0).to(torch.float32)\n        audio = torch.from_numpy(np.load(audio_path)).unsqueeze(0).to(torch.float32)\n    except Exception:\n        # the 3ddfa fail to extract valid pose, using typical value instead\n        init_pose = torch.from_numpy(np.array([[0, 0, 0, 4.79e-04, 5.65e+01, 6.49e+01,]]))[:,:pose_dim].unsqueeze(0).to(torch.float32)\n        init_blink = torch.from_numpy(np.array([[0.3,0.3]]))[:,:eye_dim].unsqueeze(0).to(torch.float32)\n        audio = torch.from_numpy(np.load(audio_path)).unsqueeze(0).to(torch.float32)\n\n    pose_seg = init_pose\n    blink_seg = init_blink\n    init_pose = torch.concat([pose_seg, blink_seg], dim = -1)\n\n    init_pose = (init_pose - min_vals)/ (max_vals - min_vals)\n    fixseed(1234)\n        \n\n    with torch.no_grad():\n\n        # batch = {key: val.to(device) for key, val in databatch.items()}\n\n\n        # step 1: seg\n        \n        audio_seg = audio\n        # gendurations_seg = gendurations[:, start:end]\n        gendurations_seg = torch.tensor([audio.shape[1] - 0])\n        # step 2: predict\n        batch = model.generate(init_pose, audio_seg, 
gendurations_seg, fact = 1)\n        # step 3: merge\n\n        output = batch['output'].detach().cpu()\n\n        output = output + init_pose\n        output = inv_transform(output, min_vals, max_vals)\n\n        output_pose_path = os.path.join(output_path, 'dri_pose.npy')\n        output_blink_path = os.path.join(output_path, 'dri_blink.npy')\n\n        np.save(output_pose_path , output[0,:,:pose_dim])\n        np.save(output_blink_path, output[0,:,pose_dim:])\n\ndef get_arguments():\n    \"\"\"Parse all the arguments provided from the CLI.\n\n    Returns:\n      A list of parsed arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"PBnet\")\n    parser.add_argument(\"--audio_path\", default='/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate/RD_Radio54_000.npy')\n    parser.add_argument(\"--ckpt\", default='../pretrain_models/pbnet_both/checkpoint_100000.pth.tar',\n                        help=\"ckpt of PoseNet\")\n\n    parser.add_argument(\"--init_pose_blink\", default='your/path/DAWN-pytorch/ood_data/ood_test_material/cache_2',\n                        help=\"dir of init pose/blink\")\n    \n    parser.add_argument(\"--output\", default='/train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/demo_output',\n                        help=\"output_dir\")\n\n    return parser.parse_args()\n\nif __name__ == '__main__':\n    args = get_arguments()\n    audio_path = args.audio_path\n    ckpt_pose = args.ckpt\n    output_dir = args.output\n    init_blink = os.path.join(args.init_pose_blink, 'init_eye_bbox.npy') # init_eye_bbox.npy\n    init_pose = os.path.join(args.init_pose_blink, 'init_pose.npy')\n\n    folder_p, _ = os.path.split(ckpt_pose)\n    parameters = load_args(os.path.join(folder_p, \"opt.yaml\"))\n    parameters['device'] = 'cuda:0'\n    parameters[\"audio_dim\"] = 1024\n    parameters[\"pos_dim\"] = 6\n    parameters[\"eye_dim\"] = 2\n\n\n    evaluate(parameters = parameters,\n            audio_path = audio_path,\n            init_pose_path = init_pose,\n            init_blink_path = init_blink,\n            checkpoint_path = ckpt_pose,\n            output_path = output_dir)\n\n\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_std.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.utils.tensors_hdtf import collate\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\n\n\n\ndef evaluate(parameters, dataset, folder, checkpointname, epoch, niter):\n    # num_frames = 60\n\n    device = parameters[\"device\"]\n\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    model_ckpt = model.state_dict()\n    state_dict = torch.load(checkpointpath, map_location=device)\n    for name, _ in model_ckpt.items():\n        if  model_ckpt[name].shape == state_dict[name].shape:\n            model_ckpt[name].copy_(state_dict[name])\n        model.load_state_dict(model_ckpt)\n    model.eval()\n\n    if checkpointname.split(\"_\")[0] == 'retraincheckpoint':\n        save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0])\n    else:\n        save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0])\n    os.makedirs(save_folder, exist_ok=True)\n\n    allseeds = list(range(niter))\n\n    try:\n        for index, seed in enumerate(allseeds):\n            print(f\"Evaluation number: {index+1}/{niter}\")\n            fixseed(seed)\n            save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed))\n            save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed))\n            os.makedirs(save_pred_path, exist_ok=True)\n            os.makedirs(save_gt_path, exist_ok=True)\n            \n\n            dataiterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=False, num_workers=8, collate_fn=collate)\n\n            with torch.no_grad():\n                for databatch in tqdm(dataiterator, desc=f\"Construct dataloader: generating..\"):\n                    # batch = {key: val.to(device) for key, val in databatch.items()}\n                    pose = databatch[\"x\"]\n                    audio = databatch[\"y\"]\n                    gendurations = databatch[\"lengths\"]\n                    # start = databatch[\"start\"]\n                    batch = model.generate(pose, audio, gendurations, fact = 1)\n                    batch = {key: val.to(device) for key, val in batch.items()}\n                    \n\n                    for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']):\n                        x_ref = pose_gt[0,:].unsqueeze(dim=0).cpu()\n                        pose_pre = (pose_pre.cpu()+x_ref - 0.5) * 180\n                        gtmasked = (pose_gt[mask].cpu() -0.5 ) * 180\n                        outmasked = pose_pre[mask].cpu()\n                        pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num))\n                        gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt')\n                        # np.save(pred_path, pose_pre.cpu())\n                        # np.save(gt_path, pose_gt.cpu())\n                        np.savetxt(pred_path, outmasked)\n                        np.savetxt(gt_path, gtmasked)\n                        loss = F.mse_loss(gtmasked, 
outmasked, reduction='mean')\n                        print(loss)\n\n    except KeyboardInterrupt:\n        print(\"Saving the evaluation before exiting..\")\n\n    epoch = checkpointname.split(\"_\")[1].split(\".\")[0]\n    metricname = \"evaluation_metrics_{}_all.yaml\".format(epoch)\n\n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving evaluation: {evalpath}\")\n    # save_metrics(evalpath, metrics)\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_train.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.utils.tensors_hdtf import collate\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\n\n\n\ndef evaluate(parameters, dataset, folder, checkpointname, epoch, niter):\n    # num_frames = 60\n\n    device = parameters[\"device\"]\n\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=device)\n    model.load_state_dict(state_dict)\n    model.eval()\n\n    if checkpointname.split(\"_\")[0] == 'retraincheckpoint':\n        save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0])\n    else:\n        save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0])\n    os.makedirs(save_folder, exist_ok=True)\n\n    allseeds = list(range(niter))\n\n    try:\n        for index, seed in enumerate(allseeds):\n            print(f\"Evaluation number: {index+1}/{niter}\")\n            fixseed(seed)\n            save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed))\n            save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed))\n            os.makedirs(save_pred_path, exist_ok=True)\n            os.makedirs(save_gt_path, exist_ok=True)\n            \n\n            dataiterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=False, num_workers=8, collate_fn=collate)\n\n            with torch.no_grad():\n                for databatch in tqdm(dataiterator, desc=f\"Construct dataloader: generating..\"):\n                    name_list = databatch['videoname']\n                    start_list = databatch['start']\n                    databatch = {key: val.to(device) for key, val in databatch.items() if key!='videoname' and key!='start'}\n                    \n                    pose = databatch[\"x\"]\n                    audio = databatch[\"y\"]\n                    gendurations = databatch[\"lengths\"]\n                    # start = databatch[\"start\"]\n                    batch = model.forward(databatch)\n                    batch = {key: val.to(device) for key, val in batch.items()}\n                    \n\n                    for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], name_list, start_list):\n                        x_ref = pose_gt[0,:].unsqueeze(dim=0)\n                        pose_pre = pose_pre.cpu()+x_ref.cpu()\n                        gtmasked = pose_gt[mask].cpu()\n                        outmasked = pose_pre[mask].cpu()\n                        pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num))\n                        gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt')\n                        # np.save(pred_path, pose_pre.cpu())\n                        # np.save(gt_path, pose_gt.cpu())\n                        np.savetxt(pred_path, outmasked)\n                        np.savetxt(gt_path, gtmasked)\n                        loss = F.mse_loss(gtmasked, outmasked, reduction='mean')\n                        print(loss)\n                        \n\n\n    except 
KeyboardInterrupt:\n        print(\"Saving the evaluation before exiting..\")\n\n    epoch = checkpointname.split(\"_\")[1].split(\".\")[0]\n    metricname = \"evaluation_metrics_{}_all.yaml\".format(epoch)\n\n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving evaluation: {evalpath}\")\n    # save_metrics(evalpath, metrics)\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_train_norm.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.utils.tensors_hdtf import collate\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\n\n\ndef transform(x, min_val, max_val):\n    out = x * (max_val - min_val) + min_val\n    return out\n\ndef evaluate(parameters, dataset, folder, checkpointname, epoch, niter):\n    # num_frames = 60\n    min_val = dataset.min_vals\n    max_val = dataset.max_vals\n    device = parameters[\"device\"]\n\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=device)\n    model.load_state_dict(state_dict)\n    model.eval()\n\n    if checkpointname.split(\"_\")[0] == 'retraincheckpoint':\n        save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0])\n    else:\n        save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0])\n    os.makedirs(save_folder, exist_ok=True)\n\n    allseeds = list(range(niter))\n\n    try:\n        for index, seed in enumerate(allseeds):\n            print(f\"Evaluation number: {index+1}/{niter}\")\n            fixseed(seed)\n            save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed))\n            save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed))\n            os.makedirs(save_pred_path, exist_ok=True)\n            os.makedirs(save_gt_path, exist_ok=True)\n            \n\n            dataiterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=False, num_workers=8, collate_fn=collate)\n\n            with torch.no_grad():\n                for databatch in tqdm(dataiterator, desc=f\"Construct dataloader: generating..\"):\n                    # batch = {key: val.to(device) for key, val in databatch.items()}\n                    name_list = databatch['videoname']\n                    start_list = databatch['start']\n                    databatch = {key: val.to(device) for key, val in databatch.items() if key!='videoname' and key!='start'}\n                    \n                    pose = databatch[\"x\"]\n                    audio = databatch[\"y\"]\n                    gendurations = databatch[\"lengths\"]\n                    # start = databatch[\"start\"]\n                    batch = model.forward(databatch)\n                    batch = {key: val.to(device) for key, val in batch.items()}\n                    \n\n                    for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], name_list, start_list):\n                        x_ref = pose_gt[0,:].unsqueeze(dim=0).cpu()\n                        pose_pre = pose_pre.cpu()+x_ref\n                        gtmasked = pose_gt[mask].cpu()\n                        outmasked = pose_pre[mask].cpu()\n                        gtmasked = transform(gtmasked, min_val, max_val)\n                        outmasked = transform(outmasked, min_val, max_val)\n                        pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num))\n                        gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt')\n 
                       # np.save(pred_path, pose_pre.cpu())\n                        # np.save(gt_path, pose_gt.cpu())\n                        np.savetxt(pred_path, outmasked)\n                        np.savetxt(gt_path, gtmasked)\n                        loss = F.mse_loss(gtmasked, outmasked, reduction='mean')\n                        print('all loss: ', loss)\n                        loss_f3 = F.mse_loss(gtmasked[:, :3], outmasked[:, :3], reduction='mean')\n                        print('f3 loss: ', loss_f3)\n                        loss_ls = F.mse_loss(gtmasked[:, 3:], outmasked[:, 3:], reduction='mean')\n                        print('ls loss: ', loss_ls)\n\n    except KeyboardInterrupt:\n        print(\"Saving the evaluation before exiting..\")\n\n    epoch = checkpointname.split(\"_\")[1].split(\".\")[0]\n    metricname = \"evaluation_metrics_{}_all.yaml\".format(epoch)\n\n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving evaluation: {evalpath}\")\n    # save_metrics(evalpath, metrics)\n"
  },
  {
    "path": "PBnet/src/evaluate/tvae_eval_train_std.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.utils.tensors_hdtf import collate\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\n# from .tools import save_metrics, format_metrics\nfrom src.models.get_model import get_model as get_gen_model\n\n\n\ndef evaluate(parameters, dataset, folder, checkpointname, epoch, niter):\n    # num_frames = 60\n\n    device = parameters[\"device\"]\n\n    # dummy => update parameters info\n    model = get_gen_model(parameters)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=device)\n    model.load_state_dict(state_dict)\n    model.eval()\n\n    if checkpointname.split(\"_\")[0] == 'retraincheckpoint':\n        save_folder = os.path.join(folder, 'fintune_train', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0])\n    else:\n        save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0])\n    os.makedirs(save_folder, exist_ok=True)\n\n    allseeds = list(range(niter))\n\n    try:\n        for index, seed in enumerate(allseeds):\n            print(f\"Evaluation number: {index+1}/{niter}\")\n            fixseed(seed)\n            save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed))\n            save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed))\n            os.makedirs(save_pred_path, exist_ok=True)\n            os.makedirs(save_gt_path, exist_ok=True)\n            \n\n            dataiterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=False, num_workers=8, collate_fn=collate)\n\n            with torch.no_grad():\n                for databatch in tqdm(dataiterator, desc=f\"Construct dataloader: generating..\"):\n                    name_list = databatch['videoname']\n                    start_list = databatch['start']\n                    databatch = {key: val.to(device) for key, val in databatch.items() if key!='videoname' and key!='start'}\n                    \n                    pose = databatch[\"x\"]\n                    audio = databatch[\"y\"]\n                    gendurations = databatch[\"lengths\"]\n                    # start = databatch[\"start\"]\n                    batch = model.forward(databatch)\n                    batch = {key: val.to(device) for key, val in batch.items()}\n                    \n\n                    for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], name_list, start_list):\n                        x_ref = pose_gt[0,:].unsqueeze(dim=0).cpu()\n                        pose_pre = (pose_pre.cpu()+x_ref - 0.5) * 180\n                        gtmasked = (pose_gt[mask].cpu() -0.5 ) * 180\n                        outmasked = pose_pre[mask].cpu()\n                        pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num))\n                        gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt')\n                        # np.save(pred_path, pose_pre.cpu())\n                        # np.save(gt_path, pose_gt.cpu())\n                        np.savetxt(pred_path, outmasked)\n                        np.savetxt(gt_path, gtmasked)\n                        loss = F.mse_loss(gtmasked, outmasked, reduction='mean')\n                        print(loss)\n           
             \n\n    except KeyboardInterrupt:\n        print(\"Saving the evaluation before exiting..\")\n\n    epoch = checkpointname.split(\"_\")[1].split(\".\")[0]\n    metricname = \"evaluation_metrics_{}_all.yaml\".format(epoch)\n\n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving evaluation: {evalpath}\")\n    # save_metrics(evalpath, metrics)\n"
  },
  {
    "path": "PBnet/src/generate/generate_sequences.py",
    "content": "import os\n\nimport matplotlib.pyplot as plt\nimport torch\nimport numpy as np\n\nfrom src.utils.get_model_and_data import get_model_and_data\nfrom src.models.get_model import get_model\n\nfrom src.parser.generate import parser\nimport src.utils.fixseed  # noqa\n\nplt.switch_backend('agg')\n\n\ndef generate_actions(beta, model, dataset, epoch, params, folder, num_frames=60,\n                     durationexp=False, vertstrans=True, onlygen=False, nspa=10, inter=False, writer=None):\n    \"\"\" Generate & viz samples \"\"\"\n\n    # visualize with joints3D\n    model.outputxyz = True\n    # print(\"remove smpl\")\n    model.param2xyz[\"jointstype\"] = \"vertices\"\n\n    print(f\"Visualization of the epoch {epoch}\")\n\n    fact = params[\"fact_latent\"]\n    num_classes = dataset.num_classes\n    classes = torch.arange(num_classes)\n\n    if not onlygen:\n        nspa = 1\n\n    nats = num_classes\n\n    if durationexp:\n        nspa = 4\n        durations = [40, 60, 80, 100]\n        gendurations = torch.tensor([[dur for cl in classes] for dur in durations], dtype=int)\n    else:\n        gendurations = torch.tensor([num_frames for cl in classes], dtype=int)\n\n    if not onlygen:\n        # extract the real samples\n        real_samples, mask_real, real_lengths = dataset.get_label_sample_batch(classes.numpy())\n        # to visualize directly\n\n        # Visualizaion of real samples\n        visualization = {\"x\": real_samples.to(model.device),\n                         \"y\": classes.to(model.device),\n                         \"mask\": mask_real.to(model.device),\n                         \"lengths\": real_lengths.to(model.device),\n                         \"output\": real_samples.to(model.device)}\n\n        reconstruction = {\"x\": real_samples.to(model.device),\n                          \"y\": classes.to(model.device),\n                          \"lengths\": real_lengths.to(model.device),\n                          \"mask\": mask_real.to(model.device)}\n\n    print(\"Computing the samples poses..\")\n\n    # generate the repr (joints3D/pose etc)\n    model.eval()\n    with torch.no_grad():\n        if not onlygen:\n            # Get xyz for the real ones\n            visualization[\"output_xyz\"] = model.rot2xyz(visualization[\"output\"],\n                                                        visualization[\"mask\"],\n                                                        vertstrans=vertstrans,\n                                                        beta=beta)\n\n            # Reconstruction of the real data\n            reconstruction = model(reconstruction)  # update reconstruction dicts\n\n            noise_same_action = \"random\"\n            noise_diff_action = \"random\"\n\n            # Generate the new data\n            generation = model.generate(classes, gendurations, nspa=nspa,\n                                        noise_same_action=noise_same_action,\n                                        noise_diff_action=noise_diff_action,\n                                        fact=fact)\n\n            generation[\"output_xyz\"] = model.rot2xyz(generation[\"output\"],\n                                                     generation[\"mask\"], vertstrans=vertstrans,\n                                                     beta=beta)\n\n            outxyz = model.rot2xyz(reconstruction[\"output\"],\n                                   reconstruction[\"mask\"], vertstrans=vertstrans,\n                                   beta=beta)\n            
reconstruction[\"output_xyz\"] = outxyz\n        else:\n            if inter:\n                noise_same_action = \"interpolate\"\n            else:\n                noise_same_action = \"random\"\n\n            noise_diff_action = \"random\"\n\n            # Generate the new data\n            generation = model.generate(classes, gendurations, nspa=nspa,\n                                        noise_same_action=noise_same_action,\n                                        noise_diff_action=noise_diff_action,\n                                        fact=fact)\n\n            generation[\"output_xyz\"] = model.rot2xyz(generation[\"output\"],\n                                                     generation[\"mask\"], vertstrans=vertstrans,\n                                                     beta=beta)\n            output = generation[\"output_xyz\"].reshape(nspa, nats, *generation[\"output_xyz\"].shape[1:]).cpu().numpy()\n\n    if not onlygen:\n        output = np.stack([visualization[\"output_xyz\"].cpu().numpy(),\n                           generation[\"output_xyz\"].cpu().numpy(),\n                           reconstruction[\"output_xyz\"].cpu().numpy()])\n\n    return output\n\n\ndef main():\n    parameters, folder, checkpointname, epoch = parser()\n    nspa = parameters[\"num_samples_per_action\"]\n\n    # no dataset needed\n    if parameters[\"mode\"] in []:   # [\"gen\", \"duration\", \"interpolate\"]:\n        model = get_model(parameters)\n    else:\n        model, datasets = get_model_and_data(parameters)\n        dataset = datasets[\"train\"]  # same for ntu\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=parameters[\"device\"])\n    model.load_state_dict(state_dict)\n\n    from src.utils.fixseed import fixseed  # noqa\n    for seed in [1]:  # [0, 1, 2]:\n        fixseed(seed)\n        # visualize_params\n        onlygen = True\n        vertstrans = False\n        inter = True and onlygen\n        varying_beta = False\n        if varying_beta:\n            betas = [-2, -1, 0, 1, 2]\n        else:\n            betas = [0]\n        for beta in betas:\n            output = generate_actions(beta, model, dataset, epoch, parameters,\n                                      folder, inter=inter, vertstrans=vertstrans,\n                                      nspa=nspa, onlygen=onlygen)\n            if varying_beta:\n                filename = \"generation_beta_{}.npy\".format(beta)\n            else:\n                filename = \"generation.npy\"\n\n            filename = os.path.join(folder, filename)\n            np.save(filename, output)\n            print(\"Saved at: \" + filename)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/models/__init__.py",
    "content": ""
  },
  {
    "path": "PBnet/src/models/architectures/__init__.py",
    "content": ""
  },
  {
    "path": "PBnet/src/models/architectures/autotrans.py",
    "content": "from .transformer import Encoder_TRANSFORMER as Encoder_AUTOTRANS  # noqa\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nfrom .tools.transformer_layers import PositionalEncoding\nfrom .tools.transformer_layers import TransformerDecoderLayer\n\n\n# taken from joeynmt repo\ndef subsequent_mask(size: int):\n    \"\"\"\n    Mask out subsequent positions (to prevent attending to future positions)\n    Transformer helper function.\n\n    :param size: size of mask (2nd and 3rd dim)\n    :return: Tensor with 0s and 1s of shape (1, size, size)\n    \"\"\"\n    mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8')\n    return torch.from_numpy(mask) == 0\n\n\ndef augment_x(x, y, mask, lengths, num_classes, concatenate_time):\n    bs, nframes, njoints, nfeats = x.size()\n    x = x.reshape(bs, nframes, njoints*nfeats)\n    if len(y.shape) == 1:  # can give on hot encoded as input\n        y = F.one_hot(y, num_classes)\n    y = y.to(dtype=x.dtype)\n    y = y[:, None, :].repeat((1, nframes, 1))\n\n    if concatenate_time:\n        # Time embedding\n        time = mask * 1/(lengths[..., None]-1)\n        time = (time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :])[:, 0]\n        time = time[..., None]\n        x_augmented = torch.cat((x, y, time), 2)\n    else:\n        x_augmented = torch.cat((x, y), 2)\n    return x_augmented\n\n\ndef augment_z(z, y, mask, lengths, num_classes, concatenate_time):\n    if len(y.shape) == 1:  # can give on hot encoded as input\n        y = F.one_hot(y, num_classes)\n    y = y.to(dtype=z.dtype)\n    # concatenete z and y and repeat the input\n    z_augmented = torch.cat((z, y), 1)[:, None].repeat((1, mask.shape[1], 1))\n\n    # Time embedding\n    if concatenate_time:\n        time = mask * 1/(lengths[..., None]-1)\n        time = (time[:, None] * torch.arange(time.shape[1], device=z.device)[None, :])[:, 0]\n        z_augmented = torch.cat((z_augmented, time[..., None]), 2)\n        \n    return z_augmented\n\n\nclass Decoder_AUTOTRANS(nn.Module):\n    def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes, translation, pose_rep, glob, glob_rot,\n                 concatenate_time=True, positional_encoding=True, latent_dim=256, ff_size=1024, num_layers=4, num_heads=4,\n                 dropout=0.1, emb_dropout=0.1, teacher_forcing=True, **kargs):\n        super().__init__()\n\n        self.modeltype = modeltype\n        self.njoints = njoints\n        self.nfeats = nfeats\n        self.num_frames = num_frames\n        self.num_classes = num_classes\n        \n        self.pose_rep = pose_rep\n        self.glob = glob\n        self.glob_rot = glob_rot\n        self.translation = translation\n\n        self.concatenate_time = concatenate_time\n        self.positional_encoding = positional_encoding\n        self.latent_dim = latent_dim\n\n        self.ff_size = ff_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.emb_dropout = emb_dropout\n        self.teacher_forcing = teacher_forcing\n        \n        self.input_feats = self.latent_dim + self.num_classes\n        self.input_feats_x = self.njoints*self.nfeats + self.num_classes\n        if self.concatenate_time:\n            self.input_feats += 1\n            self.input_feats_x += 1\n\n        self.embedding = nn.Linear(self.input_feats, self.latent_dim)\n        self.embedding_x = nn.Linear(self.input_feats_x, self.latent_dim)\n            \n     
   self.output_feats = self.njoints*self.nfeats\n        \n        # create num_layers decoder layers and put them in a list\n        self.layers = nn.ModuleList([TransformerDecoderLayer(size=self.latent_dim,\n                                                             ff_size=self.ff_size,\n                                                             num_heads=self.num_heads,\n                                                             dropout=self.dropout)\n                                     for _ in range(self.num_layers)])\n\n        self.pe = PositionalEncoding(self.latent_dim)\n        self.layer_norm = nn.LayerNorm(self.latent_dim, eps=1e-6)\n\n        self.emb_dropout = nn.Dropout(p=self.emb_dropout)\n        self.output_layer = nn.Linear(self.latent_dim, self.output_feats, bias=False)\n        \n    def forward(self, batch):\n        z, y, mask = batch[\"z\"], batch[\"y\"], batch[\"mask\"]\n        lengths = mask.sum(1)\n        \n        lenseqmax = mask.shape[1]\n        bs, njoints, nfeats = len(z), self.njoints, self.nfeats\n        \n        z_augmented = augment_z(z, y, mask, lengths, self.num_classes, self.concatenate_time)\n        src = self.embedding(z_augmented)\n        \n        src_mask = mask.unsqueeze(1)\n        \n        # Check if using teacher forcing or not\n        # if it is allowed and possible\n        teacher_forcing = self.teacher_forcing and \"x\" in batch\n        # in eval mode, by default it it not unless it is \"forced\"\n        teacher_forcing = teacher_forcing and (self.training or batch.get(\"teacher_force\", False))\n            \n        if teacher_forcing:\n            x = batch[\"x\"].permute((0, 3, 1, 2))\n            # shift the input\n            x = torch.cat((x.new_zeros((x.shape[0], 1, *x.shape[2:])), x[:, :-1]), axis=1)\n            # Embedding of the input\n            x_augmented = augment_x(x, y, mask, lengths, self.num_classes, self.concatenate_time)\n            trg = self.embedding_x(x_augmented)\n            trg_mask = (mask[:, None] * subsequent_mask(lenseqmax).type_as(mask))\n            # shape: torch.Size([48, 183, 183])\n            \n            if self.positional_encoding:\n                trg = self.pe(trg)\n            trg = self.emb_dropout(trg)\n\n            val = trg\n            for layer in self.layers:\n                val = layer(val, src, src_mask=src_mask, trg_mask=trg_mask)\n                \n            val = self.layer_norm(val)\n            val = self.output_layer(val)\n\n            # pad the output\n            val[~mask] = 0\n            \n            val = val.reshape((bs, lenseqmax, njoints, nfeats))\n            batch[\"output\"] = val.permute(0, 2, 3, 1)\n        else:\n            # Create the first input x/src_mask\n            x = torch.Tensor.new_zeros(z, (bs, 1, njoints, nfeats))\n            for index in range(lenseqmax):\n                # change it to speed up\n                current_mask = mask[:, :index+1]\n                x_augmented = augment_x(x, y, current_mask, lengths,\n                                        self.num_classes, self.concatenate_time)\n                trg = self.embedding_x(x_augmented)\n                trg_mask = (current_mask[:, None] * subsequent_mask(index+1).type_as(mask))\n\n                if self.positional_encoding:\n                    trg = self.pe(trg)\n                trg = self.emb_dropout(trg)\n\n                val = trg\n                for layer in self.layers:\n                    val = layer(val, src, src_mask=src_mask, trg_mask=trg_mask)\n\n  
              val = self.layer_norm(val)\n                val = self.output_layer(val)\n\n                # pad the output\n                val[~current_mask] = 0\n                val = val.reshape((bs, index+1, njoints, nfeats))\n\n                # extract the last output\n                last_out = val[:, -1]\n                # concatenate it to input x\n                x = torch.cat((x, last_out[:, None]), 1)\n            # remove the dummy first input (BOS)\n            batch[\"output\"] = x[:, 1:].permute(0, 2, 3, 1)\n        return batch\n"
  },
  {
    "path": "PBnet/src/models/architectures/fc.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Encoder_FC(nn.Module):\n    def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes, translation, pose_rep, glob, glob_rot,\n                 latent_dim=256, **kargs):\n        super().__init__()\n\n        self.modeltype = modeltype\n        self.njoints = njoints\n        self.nfeats = nfeats\n        self.num_frames = num_frames\n        self.num_classes = num_classes\n        self.translation = translation\n        self.pose_rep = pose_rep\n        self.glob = glob\n        self.glob_rot = glob_rot\n\n        self.latent_dim = latent_dim\n\n        self.activation = nn.GELU()\n\n        self.input_dim = self.njoints*self.nfeats*self.num_frames+self.num_classes\n\n        self.fully_connected = nn.Sequential(nn.Linear(self.input_dim, 512),\n                                             nn.GELU(),\n                                             nn.Linear(512, 256),\n                                             nn.GELU())\n        if self.modeltype == \"cvae\":\n            self.mu = nn.Linear(256, self.latent_dim)\n            self.var = nn.Linear(256, self.latent_dim)\n        else:\n            self.final = nn.Linear(256, self.latent_dim)\n\n    def forward(self, batch):\n        x, y = batch[\"x\"], batch[\"y\"]\n        bs, njoints, feats, nframes = x.size()\n        if (njoints * feats * nframes) != self.njoints*self.nfeats*self.num_frames:\n            raise ValueError(\"This model is not adapted with this input\")\n        \n        if len(y.shape) == 1:  # can give on hot encoded as input\n            y = F.one_hot(y, self.num_classes)\n        y = y.to(dtype=x.dtype)\n        x = x.reshape(bs, njoints*feats*nframes)\n        x = torch.cat((x, y), 1)\n\n        x = self.fully_connected(x)\n\n        if self.modeltype == \"cvae\":\n            return {\"mu\": self.mu(x), \"logvar\": self.var(x)}\n        else:\n            return {\"z\": self.final(x)}\n\n\nclass Decoder_FC(nn.Module):\n    def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes, translation, pose_rep, glob, glob_rot,\n                 latent_dim=256, **kargs):\n        super().__init__()\n\n        self.modeltype = modeltype\n        self.njoints = njoints\n        self.nfeats = nfeats\n        self.num_frames = num_frames\n        self.num_classes = num_classes\n        self.translation = translation\n        self.pose_rep = pose_rep\n        self.glob = glob\n        self.glob_rot = glob_rot\n\n        self.latent_dim = latent_dim\n\n        self.input_dim = self.latent_dim + self.num_classes\n        self.output_dim = self.njoints*self.nfeats*self.num_frames\n\n        self.fully_connected = nn.Sequential(nn.Linear(self.input_dim, 256),\n                                             nn.GELU(),\n                                             nn.Linear(256, 512),\n                                             nn.GELU(),\n                                             nn.Linear(512, self.output_dim),\n                                             nn.GELU())\n        \n    def forward(self, batch):\n        z, y = batch[\"z\"], batch[\"y\"]\n        # z: [batch_size, latent_dim]\n        # y: [batch_size]\n        if len(y.shape) == 1:  # can give on hot encoded as input\n            y = F.one_hot(y, self.num_classes)\n        y = y.to(dtype=z.dtype)  # y: [batch_size, num_classes]\n        # z: [batch_size, latent_dim+num_classes]\n        z = torch.cat((z, y), dim=1)\n        \n        z = 
self.fully_connected(z)\n\n        bs, _ = z.size()\n\n        z = z.reshape(bs, self.njoints, self.nfeats, self.num_frames)\n        batch[\"output\"] = z\n        return batch\n"
  },
  {
    "path": "PBnet/src/models/architectures/gru.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef augment_x(x, y, mask, lengths, num_classes, concatenate_time):\n    bs, nframes, njoints, nfeats = x.size()\n    x = x.reshape(bs, nframes, njoints*nfeats)\n    if len(y.shape) == 1:  # can give on hot encoded as input\n        y = F.one_hot(y, num_classes)\n    y = y.to(dtype=x.dtype)\n    y = y[:, None, :].repeat((1, nframes, 1))\n\n    if concatenate_time:\n        # Time embedding\n        time = mask * 1/(lengths[..., None]-1)\n        time = (time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :])[:, 0]\n        time = time[..., None]\n        x_augmented = torch.cat((x, y, time), 2)\n    else:\n        x_augmented = torch.cat((x, y), 2)\n    return x_augmented\n\n\ndef augment_z(z, y, mask, lengths, num_classes, concatenate_time):\n    if len(y.shape) == 1:  # can give on hot encoded as input\n        y = F.one_hot(y, num_classes)\n    y = y.to(dtype=z.dtype)\n    # concatenete z and y and repeat the input\n    z_augmented = torch.cat((z, y), 1)[:, None].repeat((1, mask.shape[1], 1))\n\n    # Time embedding\n    if concatenate_time:\n        time = mask * 1/(lengths[..., None]-1)\n        time = (time[:, None] * torch.arange(time.shape[1], device=z.device)[None, :])[:, 0]\n        z_augmented = torch.cat((z_augmented, time[..., None]), 2)\n        \n    return z_augmented\n\n\nclass Encoder_GRU(nn.Module):\n    def __init__(self, modeltype, njoints, nfeats, num_frames,\n                 num_classes, translation, pose_rep, glob, glob_rot,\n                 concatenate_time=True, latent_dim=256, num_layers=4, **kargs):\n        super().__init__()\n\n        self.modeltype = modeltype\n        self.njoints = njoints\n        self.nfeats = nfeats\n        self.num_frames = num_frames\n        self.num_classes = num_classes\n        \n        self.pose_rep = pose_rep\n        self.glob = glob\n        self.glob_rot = glob_rot\n        self.translation = translation\n        \n        self.concatenate_time = concatenate_time\n        self.latent_dim = latent_dim\n        self.num_layers = num_layers\n\n        # Layers\n        self.input_feats = self.njoints*self.nfeats + self.num_classes\n        if self.concatenate_time:\n            self.input_feats += 1\n            \n        self.feats_embedding = nn.Linear(self.input_feats, self.latent_dim)\n        self.gru = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True)\n\n        if self.modeltype == \"cvae\":\n            self.mu = nn.Linear(self.latent_dim, self.latent_dim)\n            self.var = nn.Linear(self.latent_dim, self.latent_dim)\n        else:\n            self.final = nn.Linear(self.latent_dim, self.latent_dim)\n\n    def forward(self, batch):\n        x, y, mask, lengths = batch[\"x\"], batch[\"y\"], batch[\"mask\"], batch[\"lengths\"]\n        bs = len(y)\n        x = x.permute((0, 3, 1, 2))\n        x = augment_x(x, y, mask, lengths, self.num_classes, self.concatenate_time)\n\n        # Model\n        x = self.feats_embedding(x)\n        x = self.gru(x)[0]\n        \n        # Get last valid input\n        x = x[tuple(torch.stack((torch.arange(bs, device=x.device), lengths-1)))]\n        \n        if self.modeltype == \"cvae\":\n            return {\"mu\": self.mu(x), \"logvar\": self.var(x)}\n        else:\n            return {\"z\": self.final(x)}\n\n\nclass Decoder_GRU(nn.Module):\n    def __init__(self, modeltype, njoints, nfeats, num_frames,\n                 num_classes, 
translation, pose_rep, glob, glob_rot,\n                 concatenate_time=True, latent_dim=256, num_layers=4, **kargs):\n        super().__init__()\n\n        self.modeltype = modeltype\n        self.njoints = njoints\n        self.nfeats = nfeats\n        self.num_frames = num_frames\n        self.num_classes = num_classes\n        \n        self.pose_rep = pose_rep\n        self.glob = glob\n        self.glob_rot = glob_rot\n        self.translation = translation\n        \n        self.concatenate_time = concatenate_time\n        self.latent_dim = latent_dim\n        self.num_layers = num_layers\n\n        # Layers\n        self.input_feats = self.latent_dim + self.num_classes\n        if self.concatenate_time:\n            self.input_feats += 1\n            \n        self.feats_embedding = nn.Linear(self.input_feats, self.latent_dim)\n        self.gru = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True)\n\n        self.output_feats = self.njoints*self.nfeats\n        self.final_layer = nn.Linear(self.latent_dim, self.output_feats)\n        \n    def forward(self, batch):\n        z, y, mask, lengths = batch[\"z\"], batch[\"y\"], batch[\"mask\"], batch[\"lengths\"]\n        bs, nframes = mask.shape\n\n        z = augment_z(z, y, mask, lengths, self.num_classes, self.concatenate_time)\n        # Model\n        z = self.feats_embedding(z)\n        z = self.gru(z)[0]\n        z = self.final_layer(z)\n\n        # Post process\n        z = z.reshape(bs, nframes, self.njoints, self.nfeats)\n        # 0 for padded sequences\n        z[~mask] = 0\n        z = z.permute(0, 2, 3, 1)\n\n        batch[\"output\"] = z\n        return batch\n"
  },
  {
    "path": "PBnet/src/models/architectures/grutrans.py",
    "content": "from .gru import Encoder_GRU as Encoder_GRUTRANS  # noqa\nfrom .transformer import Decoder_TRANSFORMER as Decoder_GRUTRANS  # noqa\n"
  },
  {
    "path": "PBnet/src/models/architectures/mlp.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Upsample(nn.Module):\n    def __init__(self, input_dim, output_dim, kernel, stride):\n        super(Upsample, self).__init__()\n\n        self.upsample = nn.ConvTranspose2d(\n            input_dim, output_dim, kernel_size=kernel, stride=stride\n        )\n\n    def forward(self, x):\n        return self.upsample(x)\n\nclass ResidualConv(nn.Module):\n    def __init__(self, input_dim, output_dim, stride, padding):\n        super(ResidualConv, self).__init__()\n\n        self.conv_block = nn.Sequential(\n            nn.BatchNorm2d(input_dim),\n            nn.ReLU(),\n            nn.Conv2d(\n                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding\n            ),\n            nn.BatchNorm2d(output_dim),\n            nn.ReLU(),\n            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),\n        )\n        self.conv_skip = nn.Sequential(\n            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),\n            nn.BatchNorm2d(output_dim),\n        )\n\n    def forward(self, x):\n\n        return self.conv_block(x) + self.conv_skip(x)\n\nclass PositionalEncoding(nn.Module):\n    def __init__(self, d_model, dropout=0.1, max_len=5000):\n        super(PositionalEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n        pe = torch.zeros(max_len, d_model)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0).transpose(0, 1)\n        \n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        # not used in the final model\n        x = x + self.pe[:x.shape[0], :]\n        return self.dropout(x)\n\n\nclass RelativePositionBias(nn.Module):\n    def __init__(\n            self,\n            heads=8,\n            num_buckets=32,\n            max_distance=128\n    ):\n        super().__init__()\n        self.num_buckets = num_buckets\n        self.max_distance = max_distance\n        self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n\n    @staticmethod\n    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):\n        ret = 0\n        n = -relative_position\n\n        num_buckets //= 2\n        ret += (n < 0).long() * num_buckets\n        n = torch.abs(n)\n\n        max_exact = num_buckets // 2\n        is_small = n < max_exact\n\n        val_if_large = max_exact + (\n                torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)\n        ).long()\n        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n\n        ret += torch.where(is_small, n, val_if_large)\n        return ret\n\n    def forward(self, n, device):\n        q_pos = torch.arange(n, dtype=torch.long, device=device)\n        k_pos = torch.arange(n, dtype=torch.long, device=device)\n        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,\n                                                   max_distance=self.max_distance)\n        values = self.relative_attention_bias(rp_bucket)\n        return rearrange(values, 'i j h -> h i 
j')\n        \n# only for ablation / not used in the final model\nclass TimeEncoding(nn.Module):\n    def __init__(self, d_model, dropout=0.1, max_len=5000):\n        super(TimeEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, x, mask, lengths):\n        time = mask * 1/(lengths[..., None]-1)\n        time = time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :]\n        time = time[:, 0].T\n        # add the time encoding\n        x = x + time[..., None]\n        return self.dropout(x)\n    \nclass ResUnet(nn.Module):\n    def __init__(self, channel=1, filters=[32, 64, 128, 256]):\n        super(ResUnet, self).__init__()\n\n        self.input_layer = nn.Sequential(\n            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),\n            nn.BatchNorm2d(filters[0]),\n            nn.ReLU(),\n            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),\n        )\n        self.input_skip = nn.Sequential(\n            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)\n        )\n\n        self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)\n        self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)\n\n        self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)\n\n        self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))\n        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)\n\n        self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))\n        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)\n\n        self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))\n        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)\n\n        self.output_layer = nn.Sequential(\n            nn.Conv2d(filters[0], 1, 1, 1),\n            nn.Sigmoid(),\n        )\n\n    def forward(self, x):\n        # Encode\n        x1 = self.input_layer(x) + self.input_skip(x)\n        x2 = self.residual_conv_1(x1)\n        x3 = self.residual_conv_2(x2)\n        # Bridge\n        x4 = self.bridge(x3)\n\n        # Decode\n        x4 = self.upsample_1(x4)\n        x5 = torch.cat([x4, x3], dim=1)\n\n        x6 = self.up_residual_conv1(x5)\n\n        x6 = self.upsample_2(x6)\n        x7 = torch.cat([x6, x2], dim=1)\n\n        x8 = self.up_residual_conv2(x7)\n\n        x8 = self.upsample_3(x8)\n        x9 = torch.cat([x8, x1], dim=1)\n\n        x10 = self.up_residual_conv3(x9)\n\n        output = self.output_layer(x10)\n\n        return output\n\nclass Encoder_MLP(nn.Module):\n    def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, pose_latent_dim=64,\n                 audio_latent_dim=256, ff_size=128, num_layers=4, num_heads=4, dropout=0.1,\n                 ablation=None, activation=\"gelu\", **kargs):\n        super().__init__()\n\n        self.modeltype = modeltype\n        self.resunet = ResUnet()\n        self.audio_latent_dim = audio_latent_dim\n        # self.num_classes = num_classes\n        self.seq_len = num_frames\n        self.pose_latent_dim = pose_latent_dim\n\n        self.MLP = nn.Sequential()\n        layer_sizes = [pos_dim + self.seq_len * pos_dim + self.seq_len * self.audio_latent_dim, ff_size]\n        # layer_sizes[0] = self.audio_latent_dim + self.pose_latent_dim*2\n        for i, (in_size, 
out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):\n            self.MLP.add_module(\n                name=\"L{:d}\".format(i), module=nn.Linear(in_size, out_size))\n            self.MLP.add_module(name=\"A{:d}\".format(i), module=nn.ReLU())\n\n        self.linear_means = nn.Linear(layer_sizes[-1], ff_size)\n        self.linear_logvar = nn.Linear(layer_sizes[-1], ff_size)\n        self.linear_audio = nn.Linear(audio_dim, self.audio_latent_dim)\n\n        # self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))\n\n    def forward(self, batch):\n        # class_id = batch['class']\n        pose_motion_gt_ori = batch[\"x\"]                             #bs seq_len 6\n        ref = pose_motion_gt_ori[:,0,:]                            #bs 6\n        batch['x_delta'] = pose_motion_gt_ori - ref[:,None,:]\n        pose_motion_gt = batch['x_delta']\n        bs = pose_motion_gt_ori.shape[0]\n        audio_in = batch[\"y\"]                  # bs seq_len audio_emb_in_size\n\n        #pose encode\n        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))          #bs 1 seq_len 6 \n        pose_emb = pose_emb.reshape(bs, -1)                    #bs seq_len*6\n\n        #audio mapping\n        # print(audio_in.shape)\n        audio_out = self.linear_audio(audio_in)                # bs seq_len audio_emb_out_size\n        audio_out = audio_out.reshape(bs, -1)\n\n        # class_bias = self.classbias[class_id]                  #bs latent_size\n        x_in = torch.cat([ref, pose_emb, audio_out], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size\n        x_out = self.MLP(x_in)\n\n        mu = self.linear_means(x_out)\n        logvar = self.linear_means(x_out)                      #bs latent_size \n\n        batch.update({'mu':mu, 'logvar':logvar})\n        return batch\n\n\nclass Decoder_MLP(nn.Module):\n    def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, pose_latent_dim=64,\n                 audio_latent_dim=256, ff_size=128, num_layers=4, num_heads=4, dropout=0.1, activation=\"gelu\",\n                 ablation=None, **kargs):\n        super().__init__()\n\n        self.resunet = ResUnet()\n        # self.num_classes = num_classes\n        self.seq_len = num_frames\n\n        self.MLP = nn.Sequential()\n        self.audio_latent_dim = audio_latent_dim\n        layer_sizes = [ff_size, self.seq_len * pos_dim]\n        input_size = ff_size + self.seq_len*audio_latent_dim + pos_dim\n        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):\n            self.MLP.add_module(\n                name=\"L{:d}\".format(i), module=nn.Linear(in_size, out_size))\n            if i+1 < len(layer_sizes):\n                self.MLP.add_module(name=\"A{:d}\".format(i), module=nn.ReLU())\n            else:\n                self.MLP.add_module(name=\"sigmoid\", module=nn.Sigmoid())\n        \n        self.pose_linear = nn.Linear(pos_dim, pos_dim)\n        self.linear_audio = nn.Linear(audio_dim, audio_latent_dim)\n\n        # self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))\n\n    def forward(self, batch):\n\n        z = batch['z']                                          #bs latent_size\n        bs = z.shape[0]\n        pose_motion_gt = batch[\"x\"]                             #bs seq_len 6\n        ref = pose_motion_gt[:,0,:]    \n        # class_id = batch['class']                           #bs 6\n        audio_in = batch['y']                           # bs seq_len audio_emb_in_size\n        
#print('audio_in: ', audio_in[:, :, :10])\n\n        audio_out = self.linear_audio(audio_in)                 # bs seq_len audio_emb_out_size\n        #print('audio_out: ', audio_out[:, :, :10])\n        audio_out = audio_out.reshape([bs, -1])                 # bs seq_len*audio_emb_out_size\n        # class_bias = self.classbias[class_id]                   #bs latent_size\n\n        z = z # + class_bias\n        x_in = torch.cat([ref, z, audio_out], dim=-1)\n        x_out = self.MLP(x_in)                                  # bs layer_sizes[-1]\n        x_out = x_out.reshape((bs, self.seq_len, -1))\n\n        #print('x_out: ', x_out)\n\n        pose_emb = self.resunet(x_out.unsqueeze(1))             #bs 1 seq_len 6\n\n        pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))       #bs seq_len 6\n\n        pose_motion_pred = pose_motion_pred\n\n        batch.update({'output':pose_motion_pred})\n        return batch\n\n"
  },
  {
    "path": "PBnet/src/models/architectures/resnet34.py",
    "content": "import torch\nimport torch.nn as nn\nfrom sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\nfrom sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None):\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None):\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\nclass ResNet(nn.Module):\n\n    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,\n                 groups=1, width_per_group=64, replace_stride_with_dilation=None,\n                 norm_layer=None,input_channel = 3):\n        
super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(input_channel, self.inplanes, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0])\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n                                       dilate=replace_stride_with_dilation[1])\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n                                       dilate=replace_stride_with_dilation[2])\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                           
     norm_layer=norm_layer))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        x = self.fc(x)\n\n        return x\n\ndef _resnet(arch, block, layers, pretrained, progress, **kwargs):\n    model = ResNet(block, layers, **kwargs)\n    return model\n\ndef resnet34(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-34 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\n\nclass MyResNet34(nn.Module):\n    def __init__(self,embedding_dim,input_channel = 3):\n        super(MyResNet34, self).__init__()\n        self.resnet = resnet34(norm_layer = BatchNorm2d,num_classes=embedding_dim,input_channel = input_channel)\n    def forward(self, x):\n        return self.resnet(x)\n"
  },
  {
    "path": "PBnet/src/models/architectures/tools/embeddings.py",
    "content": "# This file is taken from signjoey repository\nimport math\nimport torch\n\nfrom torch import nn, Tensor\nfrom ....tools.tools import freeze_params\n\n\ndef get_activation(activation_type):\n    if activation_type == \"relu\":\n        return nn.ReLU()\n    elif activation_type == \"relu6\":\n        return nn.ReLU6()\n    elif activation_type == \"prelu\":\n        return nn.PReLU()\n    elif activation_type == \"selu\":\n        return nn.SELU()\n    elif activation_type == \"celu\":\n        return nn.CELU()\n    elif activation_type == \"gelu\":\n        return nn.GELU()\n    elif activation_type == \"sigmoid\":\n        return nn.Sigmoid()\n    elif activation_type == \"softplus\":\n        return nn.Softplus()\n    elif activation_type == \"softshrink\":\n        return nn.Softshrink()\n    elif activation_type == \"softsign\":\n        return nn.Softsign()\n    elif activation_type == \"tanh\":\n        return nn.Tanh()\n    elif activation_type == \"tanhshrink\":\n        return nn.Tanhshrink()\n    else:\n        raise ValueError(\"Unknown activation type {}\".format(activation_type))\n\n\nclass MaskedNorm(nn.Module):\n    \"\"\"\n        Original Code from:\n        https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8\n    \"\"\"\n\n    def __init__(self, norm_type, num_groups, num_features):\n        super().__init__()\n        self.norm_type = norm_type\n        if self.norm_type == \"batch\":\n            self.norm = nn.BatchNorm1d(num_features=num_features)\n        elif self.norm_type == \"group\":\n            self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)\n        elif self.norm_type == \"layer\":\n            self.norm = nn.LayerNorm(normalized_shape=num_features)\n        else:\n            raise ValueError(\"Unsupported Normalization Layer\")\n\n        self.num_features = num_features\n\n    def forward(self, x: Tensor, mask: Tensor):\n        if self.training:\n            reshaped = x.reshape([-1, self.num_features])\n            reshaped_mask = mask.reshape([-1, 1]) > 0\n            selected = torch.masked_select(reshaped, reshaped_mask).reshape(\n                [-1, self.num_features]\n            )\n            batch_normed = self.norm(selected)\n            scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)\n            return scattered.reshape([x.shape[0], -1, self.num_features])\n        else:\n            reshaped = x.reshape([-1, self.num_features])\n            batched_normed = self.norm(reshaped)\n            return batched_normed.reshape([x.shape[0], -1, self.num_features])\n\n\n# TODO (Cihan): Spatial and Word Embeddings are pretty much the same\n#       We might as well convert them into a single module class.\n#       Only difference is the lut vs linear layers.\nclass Embeddings(nn.Module):\n\n    \"\"\"\n    Simple embeddings class\n    \"\"\"\n\n    # pylint: disable=unused-argument\n    def __init__(\n        self,\n        embedding_dim: int = 64,\n        num_heads: int = 8,\n        scale: bool = False,\n        scale_factor: float = None,\n        norm_type: str = None,\n        activation_type: str = None,\n        vocab_size: int = 0,\n        padding_idx: int = 1,\n        freeze: bool = False,\n        **kwargs\n    ):\n        \"\"\"\n        Create new embeddings for the vocabulary.\n        Use scaling for the Transformer.\n\n        :param embedding_dim:\n        :param scale:\n        :param vocab_size:\n        :param padding_idx:\n        :param 
freeze: freeze the embeddings during training\n        \"\"\"\n        super().__init__()\n\n        self.embedding_dim = embedding_dim\n        self.vocab_size = vocab_size\n        self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx)\n\n        self.norm_type = norm_type\n        if self.norm_type:\n            self.norm = MaskedNorm(\n                norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim\n            )\n\n        self.activation_type = activation_type\n        if self.activation_type:\n            self.activation = get_activation(activation_type)\n\n        self.scale = scale\n        if self.scale:\n            if scale_factor:\n                self.scale_factor = scale_factor\n            else:\n                self.scale_factor = math.sqrt(self.embedding_dim)\n\n        if freeze:\n            freeze_params(self)\n\n    # pylint: disable=arguments-differ\n    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:\n        \"\"\"\n        Perform lookup for input `x` in the embedding table.\n\n        :param mask: token masks\n        :param x: index in the vocabulary\n        :return: embedded representation for `x`\n        \"\"\"\n\n        x = self.lut(x)\n\n        if self.norm_type:\n            x = self.norm(x, mask)\n\n        if self.activation_type:\n            x = self.activation(x)\n\n        if self.scale:\n            return x * self.scale_factor\n        else:\n            return x\n\n    def __repr__(self):\n        return \"%s(embedding_dim=%d, vocab_size=%d)\" % (\n            self.__class__.__name__,\n            self.embedding_dim,\n            self.vocab_size,\n        )\n\n\nclass SpatialEmbeddings(nn.Module):\n\n    \"\"\"\n    Simple Linear Projection Layer\n    (For encoder outputs to predict glosses)\n    \"\"\"\n\n    # pylint: disable=unused-argument\n    def __init__(\n        self,\n        embedding_dim: int,\n        input_size: int,\n        num_heads: int,\n        freeze: bool = False,\n        norm_type: str = \"batch\",\n        activation_type: str = \"softsign\",\n        scale: bool = False,\n        scale_factor: float = None,\n        **kwargs\n    ):\n        \"\"\"\n        Create new embeddings for the vocabulary.\n        Use scaling for the Transformer.\n\n        :param embedding_dim:\n        :param input_size:\n        :param freeze: freeze the embeddings during training\n        \"\"\"\n        super().__init__()\n\n        self.embedding_dim = embedding_dim\n        self.input_size = input_size\n        self.ln = nn.Linear(self.input_size, self.embedding_dim)\n\n        self.norm_type = norm_type\n        if self.norm_type:\n            self.norm = MaskedNorm(\n                norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim\n            )\n\n        self.activation_type = activation_type\n        if self.activation_type:\n            self.activation = get_activation(activation_type)\n\n        self.scale = scale\n        if self.scale:\n            if scale_factor:\n                self.scale_factor = scale_factor\n            else:\n                self.scale_factor = math.sqrt(self.embedding_dim)\n\n        if freeze:\n            freeze_params(self)\n\n    # pylint: disable=arguments-differ\n    def forward(self, x: Tensor, mask: Tensor) -> Tensor:\n        \"\"\"\n        :param mask: frame masks\n        :param x: input frame features\n        :return: embedded representation for `x`\n        \"\"\"\n\n        x = self.ln(x)\n\n        if 
self.norm_type:\n            x = self.norm(x, mask)\n\n        if self.activation_type:\n            x = self.activation(x)\n\n        if self.scale:\n            return x * self.scale_factor\n        else:\n            return x\n\n    def __repr__(self):\n        return \"%s(embedding_dim=%d, input_size=%d)\" % (\n            self.__class__.__name__,\n            self.embedding_dim,\n            self.input_size,\n        )\n"
  },
  {
    "path": "PBnet/src/models/architectures/tools/resnet.py",
    "content": "import torch\nimport torch.nn as nn\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None):\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None):\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\nclass ResNet(nn.Module):\n\n    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,\n                 groups=1, width_per_group=64, replace_stride_with_dilation=None,\n                 norm_layer=None,input_channel = 3):\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = 
norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(input_channel, self.inplanes, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0])\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n                                       dilate=replace_stride_with_dilation[1])\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n                                       dilate=replace_stride_with_dilation[2])\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n     
   x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        x = self.fc(x)\n\n        return x\n\ndef _resnet(arch, block, layers, pretrained, progress, **kwargs):\n    model = ResNet(block, layers, **kwargs)\n    return model\n\ndef resnet34(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-34 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)"
  },
  {
    "path": "PBnet/src/models/architectures/tools/transformer_layers.py",
    "content": "# -*- coding: utf-8 -*-\nimport math\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\n\n# Took from https://github.com/joeynmt/joeynmt/blob/fb66afcbe1beef9acd59283bcc084c4d4c1e6343/joeynmt/transformer_layers.py\n\n\n# pylint: disable=arguments-differ\nclass MultiHeadedAttention(nn.Module):\n    \"\"\"\n    Multi-Head Attention module from \"Attention is All You Need\"\n\n    Implementation modified from OpenNMT-py.\n    https://github.com/OpenNMT/OpenNMT-py\n    \"\"\"\n\n    def __init__(self, num_heads: int, size: int, dropout: float = 0.1):\n        \"\"\"\n        Create a multi-headed attention layer.\n        :param num_heads: the number of heads\n        :param size: model size (must be divisible by num_heads)\n        :param dropout: probability of dropping a unit\n        \"\"\"\n        super().__init__()\n\n        assert size % num_heads == 0\n\n        self.head_size = head_size = size // num_heads\n        self.model_size = size\n        self.num_heads = num_heads\n\n        self.k_layer = nn.Linear(size, num_heads * head_size)\n        self.v_layer = nn.Linear(size, num_heads * head_size)\n        self.q_layer = nn.Linear(size, num_heads * head_size)\n\n        self.output_layer = nn.Linear(size, size)\n        self.softmax = nn.Softmax(dim=-1)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None):\n        \"\"\"\n        Computes multi-headed attention.\n\n        :param k: keys   [B, M, D] with M being the sentence length.\n        :param v: values [B, M, D]\n        :param q: query  [B, M, D]\n        :param mask: optional mask [B, 1, M] or [B, M, M]\n        :return:\n        \"\"\"            \n        batch_size = k.size(0)\n        num_heads = self.num_heads\n\n        # project the queries (q), keys (k), and values (v)\n        k = self.k_layer(k)\n        v = self.v_layer(v)\n        q = self.q_layer(q)\n\n        # reshape q, k, v for our computation to [batch_size, num_heads, ..]\n        k = k.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)\n        v = v.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)\n        q = q.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)\n\n        # compute scores\n        q = q / math.sqrt(self.head_size)\n\n        # batch x num_heads x query_len x key_len\n        scores = torch.matmul(q, k.transpose(2, 3))\n        # torch.Size([48, 8, 183, 183])\n        \n        # apply the mask (if we have one)\n        # we add a dimension for the heads to it below: [B, 1, 1, M]\n        if mask is not None:\n            scores = scores.masked_fill(~mask.unsqueeze(1), float('-inf'))\n\n        # apply attention dropout and compute context vectors.\n        attention = self.softmax(scores)\n        attention = self.dropout(attention)\n        # torch.Size([48, 8, 183, 183]) [bs, nheads, time, time] (for decoding)\n\n        # v: torch.Size([48, 8, 183, 32]) (32 is 256/8)\n        # get context vector (select values with attention) and reshape\n        # back to [B, M, D]\n        context = torch.matmul(attention, v)  # torch.Size([48, 8, 183, 32])\n        context = context.transpose(1, 2).contiguous().view(\n            batch_size, -1, num_heads * self.head_size)\n        # torch.Size([48, 183, 256]) put back to 256 (combine the heads)\n\n        output = self.output_layer(context)\n        # torch.Size([48, 183, 256]): 1 output per time step\n        \n        return output\n\n\n# pylint: 
disable=arguments-differ\nclass PositionwiseFeedForward(nn.Module):\n    \"\"\"\n    Position-wise Feed-forward layer\n    Projects to ff_size and then back down to input_size.\n    \"\"\"\n\n    def __init__(self, input_size, ff_size, dropout=0.1):\n        \"\"\"\n        Initializes position-wise feed-forward layer.\n        :param input_size: dimensionality of the input.\n        :param ff_size: dimensionality of intermediate representation\n        :param dropout:\n        \"\"\"\n        super().__init__()\n        self.layer_norm = nn.LayerNorm(input_size, eps=1e-6)\n        self.pwff_layer = nn.Sequential(\n            nn.Linear(input_size, ff_size),\n            nn.ReLU(),\n            nn.Dropout(dropout),\n            nn.Linear(ff_size, input_size),\n            nn.Dropout(dropout),\n        )\n\n    def forward(self, x):\n        x_norm = self.layer_norm(x)\n        return self.pwff_layer(x_norm) + x\n\n\n# pylint: disable=arguments-differ\nclass PositionalEncoding(nn.Module):\n    \"\"\"\n    Pre-compute position encodings (PE).\n    In forward pass, this adds the position-encodings to the\n    input for as many time steps as necessary.\n\n    Implementation based on OpenNMT-py.\n    https://github.com/OpenNMT/OpenNMT-py\n    \"\"\"\n\n    def __init__(self,\n                 size: int = 0,\n                 max_len: int = 5000):\n        \"\"\"\n        Positional Encoding with maximum length max_len\n        :param size:\n        :param max_len:\n        :param dropout:\n        \"\"\"\n        if size % 2 != 0:\n            raise ValueError(\"Cannot use sin/cos positional encoding with \"\n                             \"odd dim (got dim={:d})\".format(size))\n        pe = torch.zeros(max_len, size)\n        position = torch.arange(0, max_len).unsqueeze(1)\n        div_term = torch.exp((torch.arange(0, size, 2, dtype=torch.float) *\n                              -(math.log(10000.0) / size)))\n        pe[:, 0::2] = torch.sin(position.float() * div_term)\n        pe[:, 1::2] = torch.cos(position.float() * div_term)\n        pe = pe.unsqueeze(0)  # shape: [1, size, max_len]\n        super().__init__()\n        self.register_buffer('pe', pe)\n        self.dim = size\n\n    def forward(self, emb):\n        \"\"\"Embed inputs.\n        Args:\n            emb (FloatTensor): Sequence of word vectors\n                ``(seq_len, batch_size, self.dim)``\n        \"\"\"\n        # Add position encodings\n        return emb + self.pe[:, :emb.size(1)]\n\n\nclass TransformerEncoderLayer(nn.Module):\n    \"\"\"\n    One Transformer encoder layer has a Multi-head attention layer plus\n    a position-wise feed-forward layer.\n    \"\"\"\n\n    def __init__(self,\n                 size: int = 0,\n                 ff_size: int = 0,\n                 num_heads: int = 0,\n                 dropout: float = 0.1):\n        \"\"\"\n        A single Transformer layer.\n        :param size:\n        :param ff_size:\n        :param num_heads:\n        :param dropout:\n        \"\"\"\n        super().__init__()\n\n        self.layer_norm = nn.LayerNorm(size, eps=1e-6)\n        self.src_src_att = MultiHeadedAttention(num_heads, size,\n                                                dropout=dropout)\n        self.feed_forward = PositionwiseFeedForward(size, ff_size=ff_size,\n                                                    dropout=dropout)\n        self.dropout = nn.Dropout(dropout)\n        self.size = size\n\n    # pylint: disable=arguments-differ\n    def forward(self, x: Tensor, mask: Tensor) -> 
Tensor:\n        \"\"\"\n        Forward pass for a single transformer encoder layer.\n        First applies layer norm, then self attention,\n        then dropout with residual connection (adding the input to the result),\n        and then a position-wise feed-forward layer.\n\n        :param x: layer input\n        :param mask: input mask\n        :return: output tensor\n        \"\"\"\n        x_norm = self.layer_norm(x)\n        h = self.src_src_att(x_norm, x_norm, x_norm, mask)\n        h = self.dropout(h) + x\n        o = self.feed_forward(h)\n        return o\n\n\nclass TransformerDecoderLayer(nn.Module):\n    \"\"\"\n    Transformer decoder layer.\n\n    Consists of self-attention, source-attention, and feed-forward.\n    \"\"\"\n\n    def __init__(self,\n                 size: int = 0,\n                 ff_size: int = 0,\n                 num_heads: int = 0,\n                 dropout: float = 0.1):\n        \"\"\"\n        Represents a single Transformer decoder layer.\n\n        It attends to the source representation and the previous decoder states.\n\n        :param size: model dimensionality\n        :param ff_size: size of the feed-forward intermediate layer\n        :param num_heads: number of heads\n        :param dropout: dropout to apply to input\n        \"\"\"\n        super().__init__()\n        self.size = size\n\n        self.trg_trg_att = MultiHeadedAttention(num_heads, size,\n                                                dropout=dropout)\n        self.src_trg_att = MultiHeadedAttention(num_heads, size,\n                                                dropout=dropout)\n\n        self.feed_forward = PositionwiseFeedForward(size, ff_size=ff_size,\n                                                    dropout=dropout)\n\n        self.x_layer_norm = nn.LayerNorm(size, eps=1e-6)\n        self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6)\n\n        self.dropout = nn.Dropout(dropout)\n\n    # pylint: disable=arguments-differ\n    def forward(self,\n                x: Tensor = None,\n                memory: Tensor = None,\n                src_mask: Tensor = None,\n                trg_mask: Tensor = None) -> Tensor:\n        \"\"\"\n        Forward pass of a single Transformer decoder layer.\n\n        :param x: inputs\n        :param memory: source representations\n        :param src_mask: source mask\n        :param trg_mask: target mask (so as to not condition on future steps)\n        :return: output tensor\n        \"\"\"\n        # decoder/target self-attention\n        x_norm = self.x_layer_norm(x)  # torch.Size([48, 183, 256])\n        h1 = self.trg_trg_att(x_norm, x_norm, x_norm, mask=trg_mask)\n        h1 = self.dropout(h1) + x\n\n        # source-target attention\n        h1_norm = self.dec_layer_norm(h1)  # torch.Size([48, 183, 256]) (same for memory)\n        h2 = self.src_trg_att(memory, memory, h1_norm, mask=src_mask)\n\n        # final position-wise feed-forward layer\n        o = self.feed_forward(self.dropout(h2) + h1)\n\n        return o\n"
  },
  {
    "path": "PBnet/src/models/architectures/tools/util.py",
    "content": "from torch import nn\n\nimport torch.nn.functional as F\nimport torch\n\nfrom sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\nfrom sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d\nfrom src.models.architectures.tools.resnet import resnet34\n\nclass MyResNet34(nn.Module):\n    def __init__(self,embedding_dim,input_channel = 3):\n        super(MyResNet34, self).__init__()\n        self.resnet = resnet34(norm_layer = BatchNorm2d,num_classes=embedding_dim,input_channel = input_channel)\n    def forward(self, x):\n        return self.resnet(x)"
  },
  {
    "path": "PBnet/src/models/architectures/transformer.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass PositionalEncoding(nn.Module):\n    def __init__(self, d_model, dropout=0.1, max_len=5000):\n        super(PositionalEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n        pe = torch.zeros(max_len, d_model)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0).transpose(0, 1)\n        \n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        # not used in the final model\n        x = x + self.pe[:x.shape[0], :]\n        return self.dropout(x)\n\n\nclass RelativePositionBias(nn.Module):\n    def __init__(\n            self,\n            heads=8,\n            num_buckets=32,\n            max_distance=128\n    ):\n        super().__init__()\n        self.num_buckets = num_buckets\n        self.max_distance = max_distance\n        self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n\n    @staticmethod\n    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):\n        ret = 0\n        n = -relative_position\n\n        num_buckets //= 2\n        ret += (n < 0).long() * num_buckets\n        n = torch.abs(n)\n\n        max_exact = num_buckets // 2\n        is_small = n < max_exact\n\n        val_if_large = max_exact + (\n                torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)\n        ).long()\n        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n\n        ret += torch.where(is_small, n, val_if_large)\n        return ret\n\n    def forward(self, n, device):\n        q_pos = torch.arange(n, dtype=torch.long, device=device)\n        k_pos = torch.arange(n, dtype=torch.long, device=device)\n        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,\n                                                   max_distance=self.max_distance)\n        values = self.relative_attention_bias(rp_bucket)\n        return rearrange(values, 'i j h -> h i j')\n        \n# only for ablation / not used in the final model\nclass TimeEncoding(nn.Module):\n    def __init__(self, d_model, dropout=0.1, max_len=5000):\n        super(TimeEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, x, mask, lengths):\n        time = mask * 1/(lengths[..., None]-1)\n        time = time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :]\n        time = time[:, 0].T\n        # add the time encoding\n        x = x + time[..., None]\n        return self.dropout(x)\n    \n\nclass Encoder_TRANSFORMER(nn.Module):\n    def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, pose_latent_dim=64,\n                 audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1,\n                 ablation=None, activation=\"gelu\", **kargs):\n        super().__init__()\n        \n        self.modeltype = modeltype\n        self.pos_dim = pos_dim\n        self.num_frames = num_frames\n        self.audio_dim = audio_dim\n        \n        self.pose_latent_dim = pose_latent_dim\n        
self.audio_latent_dim = audio_latent_dim\n        self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2\n        \n        self.ff_size = ff_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.dropout = dropout\n\n        self.ablation = ablation\n        self.activation = activation\n\n        \n        # if self.ablation == \"average_encoder\":\n        #     self.mu_layer = nn.Linear(self.latent_dim, self.latent_dim)\n        #     self.sigma_layer = nn.Linear(self.latent_dim, self.latent_dim)\n        # else:\n        #     self.muQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n        #     self.sigmaQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n        \n        # # there's no class of our dataset CREMA/HDTF, so noly  dont need to use nn.parameter\n        self.mu_layer = nn.Linear(self.latent_dim, self.audio_latent_dim)\n        self.sigma_layer = nn.Linear(self.latent_dim, self.audio_latent_dim)\n        \n        self.poseEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64\n        self.firstposeEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64\n        self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256\n        \n        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)\n        \n        # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))\n        \n        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,\n                                                          nhead=self.num_heads,\n                                                          dim_feedforward=self.ff_size,\n                                                          dropout=self.dropout,\n                                                          activation=self.activation)\n        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,\n                                                     num_layers=self.num_layers)\n\n    def forward(self, batch):\n        '''\n            x: 6-dim pos, (bs, max_num_frames, 6)\n            y: 1024-dim audio embbeding, (bs, max_num_frames, 1024)\n        '''\n\n        x, y, mask = batch[\"x\"], batch[\"y\"], batch[\"mask\"]\n        # bs, njoints, nfeats, nframes = x.shape\n        # x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats) \n        x_ref = x[:,0,:].unsqueeze(dim=1) # The pose information of the first frame(refrence img)\n        x = x-x_ref.repeat(1,x.size(1),1) # bs, nf, 6  Obtain the difference from the first frame\n        batch['x_delta'] = x\n        x_ref = x_ref.permute((1,0,2)) #1, bs, 6\n        x = x.permute((1, 0, 2)) #nf, bs, 6\n        y = y.permute((1, 0, 2)) #nf, bs, 1024\n        # embedding of the pose/audio\n        x_ref = self.firstposeEmbedding(x_ref).repeat(x.size(0),1,1) #nf, bs, 64\n        x = self.poseEmbedding(x) #nf, bs, 64\n        y = self.audioEmbedding(y) #nf, bs, 256\n        x = torch.cat([x_ref, x, y],dim=-1) # nf, bs, 64+64+256\n\n        # only use the \"average_encoder\" mode\n        # add positional encoding\n        x = self.sequence_pos_encoder(x)\n        # transformer layers\n        final = self.seqTransEncoder(x, src_key_padding_mask=~mask) #nu_frames, bs, 64+64+256\n        # get the average of the output\n        z = final# final.mean(axis=0) # nf, bs, 64+64+256\n        # extract mu and logvar\n        mu = self.mu_layer(z) # nf, bs, 256\n        logvar = 
self.sigma_layer(z) # nf, bs, 256\n        # logvar = - torch.ones_like(logvar) * 1e10\n\n        return {\"mu\": mu, \"logvar\": logvar}\n\n\nclass Decoder_TRANSFORMER(nn.Module):\n    def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, pose_latent_dim=64,\n                 audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, activation=\"gelu\",\n                 ablation=None, **kargs):\n        super().__init__()\n\n        self.modeltype = modeltype\n\n        self.pos_dim = pos_dim\n        self.num_frames = num_frames\n        self.audio_dim = audio_dim\n        \n        self.pose_latent_dim = pose_latent_dim\n        self.audio_latent_dim = audio_latent_dim\n        self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2\n        \n        self.ff_size = ff_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.dropout = dropout\n\n        self.ablation = ablation\n\n        self.activation = activation\n\n        self.firstposeEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64\n        self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256\n        self.ztimelinear = nn.Linear(self.audio_latent_dim*2+self.pose_latent_dim, self.pose_latent_dim) #256*2+64,64\n        # self.input_feats = self.njoints*self.nfeats\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"zandtime\":\n        #     self.ztimelinear = nn.Linear(self.latent_dim + self.num_classes, self.latent_dim)\n        # else:\n        #     self.actionBiases = nn.Parameter(torch.randn(1024, self.latent_dim))\n            # self.actionBiases = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"time_encoding\":\n        #     self.sequence_pos_encoder = TimeEncoding(self.dropout)\n        # else:\n        #     self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)\n\n        self.sequence_pos_encoder = PositionalEncoding(self.pose_latent_dim, self.dropout)\n        # self.sequence_pos_encoder = TimeEncoding(self.dropout) #time_encoding\n        \n        seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.pose_latent_dim,\n                                                          nhead=self.num_heads,\n                                                          dim_feedforward=self.ff_size,\n                                                          dropout=self.dropout,\n                                                          activation=activation)\n        self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer,\n                                                     num_layers=self.num_layers)\n        \n        self.finallayer = nn.Linear(self.pose_latent_dim, self.pos_dim)\n        \n    def forward(self, batch):\n        '''\n            z: bs, audio_latent_dim(256)\n            y: bs, num_frames, 1024\n            mask: bs, num_frames\n            lengths: [num_frames,...]\n        '''\n        x, z, y, mask, lengths = batch[\"x\"], batch[\"z\"], batch[\"y\"], batch[\"mask\"], batch[\"lengths\"]\n        bs, nframes = mask.shape\n        # first img\n        x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64\n        x_ref = self.firstposeEmbedding(x_ref.repeat(1, nframes, 1)) #bs, nf, 64\n        y = self.audioEmbedding(y) #bs, num_frames, 256\n        z = z.permute(1, 0, 2)\n        #z = 
z.unsqueeze(dim=1).repeat(1, nframes, 1) #bs, num_frames, 256\n        z = torch.cat([x_ref, z, y], dim=-1) # bs, num_frames, 256*2+64\n        z = self.ztimelinear(z)\n        z = z.permute((1, 0, 2)) # nf, bs, 64\n        pose_latent_dim = z.shape[2]\n        # z = z[None]  # sequence of size 1\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"zandtime\":\n        #     yoh = F.one_hot(y, self.num_classes)\n        #     z = torch.cat((z, yoh), axis=1)\n        #     z = self.ztimelinear(z)\n        #     z = z[None]  # sequence of size 1\n        # else:\n        #     # only for ablation / not used in the final model\n        #     if self.ablation == \"concat_bias\":\n        #         # sequence of size 2\n        #         z = torch.stack((z, self.actionBiases[y]), axis=0)\n        #     else:\n        #         # shift the latent noise vector to be the action noise\n        #         z = z + self.actionBiases[y.long()] # NEED CHECK\n        #         z = z[None]  # sequence of size 1\n            \n        timequeries = torch.zeros(nframes, bs, pose_latent_dim, device=z.device)\n        timequeries = self.sequence_pos_encoder(timequeries)\n        # timequeries = self.sequence_pos_encoder(timequeries, mask, lengths) #time_encoding\n        \n        # # only for ablation / not used in the final model\n        # if self.ablation == \"time_encoding\":\n        #     timequeries = self.sequence_pos_encoder(timequeries, mask, lengths)\n        # else:\n        #     timequeries = self.sequence_pos_encoder(timequeries)\n        \n        # num_frames, bs, 64\n        output = self.seqTransDecoder(tgt=timequeries, memory=z,\n                                      tgt_key_padding_mask=~mask)\n        \n        output = self.finallayer(output).reshape(nframes, bs, self.pos_dim) # num_frames, bs, 6\n        # output = self.finallayer(output).reshape(nframes, bs, njoints, nfeats)\n        \n        # zero for padded area\n        output[~mask.T] = 0 #nf, bs, 6\n        output = output.permute(1,0,2)#bs, nf, 6\n        \n        batch[\"output\"] = output\n        return batch\n"
  },
  {
    "path": "PBnet/src/models/architectures/transformerdecoder.py",
    "content": "import copy\nfrom typing import Optional, Any, Union, Callable\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn.functional import dropout\nfrom torch.nn import functional as F\nfrom torch.nn.modules import Module\nfrom torch.nn.modules.container import ModuleList\nfrom torch.nn.modules.activation import MultiheadAttention as MultiheadAttention0\nfrom torch.nn.init import xavier_uniform_\nfrom torch.nn.modules.dropout import Dropout\nfrom torch.nn.modules.linear import Linear\nfrom torch.nn.modules.normalization import LayerNorm\nimport torch.nn as nn\n\nclass MultiheadAttention(nn.Module):\n    def __init__(self, embed_size, heads, dropout = None, batch_first = None):\n        super(MultiheadAttention, self).__init__()\n        self.embed_size = embed_size\n        self.heads = heads\n        self.head_dim = embed_size // heads\n\n        assert (\n            self.head_dim * heads == embed_size\n        ), \"Embedding size needs to be divisible by heads\"\n\n        self.values = nn.Linear(embed_size, self.head_dim * heads, bias=False)\n        self.keys = nn.Linear(embed_size, self.head_dim * heads, bias=False)\n        self.queries = nn.Linear(embed_size, self.head_dim * heads, bias=False)\n        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)\n\n        self.dropout = nn.Dropout(dropout)\n\n    def sinusoidal_position_embedding(self, batch_size, nums_head, max_len, output_dim, device):\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)\n        ids = torch.arange(0, output_dim // 2, dtype=torch.float)\n        theta = torch.pow(10000, -2 * ids / output_dim)\n\n        embeddings = position * theta\n\n        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)\n\n        embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))\n\n        embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))\n        embeddings = embeddings.to(device)\n        return embeddings\n\n    def RoPE(self, q, k):\n        # q,k: (B, H, L, D)\n        batch_size = q.shape[0]\n        nums_head = q.shape[1]\n        max_len = q.shape[2]\n        output_dim = q.shape[-1]\n\n        pos_emb = self.sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)\n\n        cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)\n        sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)\n\n        # q,k: (B, H, L, D)\n        q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)\n        q2 = q2.reshape(q.shape)\n        q = q * cos_pos + q2 * sin_pos\n\n        k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)\n        k2 = k2.reshape(k.shape)\n        k = k * cos_pos + k2 * sin_pos\n\n        return q, k\n\n    def forward(self, q, k, v, attn_mask = None, key_padding_mask=None, need_weights = None):\n        B = q.shape[0]\n        use_rope = True\n        len = q.shape[1]\n\n        values = self.values(v).view(B, len, self.heads, self.head_dim)\n        keys = self.keys(k).view(B, len, self.heads, self.head_dim)\n        queries = self.queries(q).view(B, len, self.heads, self.head_dim)\n\n        values = values.permute(0, 2, 1, 3)\n        keys = keys.permute(0, 2, 1, 3)\n        queries = queries.permute(0, 2, 1, 3)\n        # [B, H, L, D]\n\n        if use_rope:\n            queries, keys = self.RoPE(queries, keys)\n\n        energy = torch.matmul(queries, keys.permute(0, 1, 3, 2))\n\n        # if attn_mask is not None:\n    
#     energy = energy.masked_fill(attn_mask == 1, float(\"-1e20\"))\n\n        if attn_mask is None:\n            attn_mask = 0\n\n        # scale the attention logits, then apply the additive mask before the softmax\n        attention = F.softmax(energy / (self.head_dim ** (1 / 2)) + attn_mask, dim=-1)\n\n        attention = self.dropout(attention)\n\n        out = torch.matmul(attention, values)\n        out = out.permute(0, 2, 1, 3).contiguous().view(B, len, self.heads * self.head_dim)\n\n        out = self.fc_out(out)\n        # return (output, attn_weights) like nn.MultiheadAttention, so callers can index [0]\n        return out, None\n\nclass Transformer(Module):\n    r\"\"\"A transformer model. User is able to modify the attributes as needed. The architecture\n    is based on the paper \"Attention Is All You Need\". Ashish Vaswani, Noam Shazeer,\n    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and\n    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information\n    Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)\n    model with corresponding parameters.\n\n    Args:\n        d_model: the number of expected features in the encoder/decoder inputs (default=512).\n        nhead: the number of heads in the multiheadattention models (default=8).\n        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).\n        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).\n        dim_feedforward: the dimension of the feedforward network model (default=2048).\n        dropout: the dropout value (default=0.1).\n        activation: the activation function of encoder/decoder intermediate layer, can be a string\n            (\"relu\" or \"gelu\") or a unary callable. Default: relu\n        custom_encoder: custom encoder (default=None).\n        custom_decoder: custom decoder (default=None).\n        layer_norm_eps: the eps value in layer normalization components (default=1e-5).\n        batch_first: If ``True``, then the input and output tensors are provided\n            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).\n        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before\n            other attention and feedforward operations, otherwise after. 
Default: ``False`` (after).\n\n    Examples::\n        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)\n        >>> src = torch.rand((10, 32, 512))\n        >>> tgt = torch.rand((20, 32, 512))\n        >>> out = transformer_model(src, tgt)\n\n    Note: A full example to apply nn.Transformer module for the word language model is available in\n    https://github.com/pytorch/examples/tree/master/word_language_model\n    \"\"\"\n\n    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,\n                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,\n                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,\n                 custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,\n                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,\n                 device=None, dtype=None) -> None:\n        factory_kwargs = {'device': device, 'dtype': dtype}\n        super(Transformer, self).__init__()\n\n        if custom_encoder is not None:\n            self.encoder = custom_encoder\n        else:\n            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,\n                                                    activation, layer_norm_eps, batch_first, norm_first,\n                                                    **factory_kwargs)\n            encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)\n            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n\n        if custom_decoder is not None:\n            self.decoder = custom_decoder\n        else:\n            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,\n                                                    activation, layer_norm_eps, batch_first, norm_first,\n                                                    **factory_kwargs)\n            decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)\n            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)\n\n        self._reset_parameters()\n\n        self.d_model = d_model\n        self.nhead = nhead\n\n        self.batch_first = batch_first\n\n    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,\n                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,\n                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Take in and process masked source/target sequences.\n\n        Args:\n            src: the sequence to the encoder (required).\n            tgt: the sequence to the decoder (required).\n            src_mask: the additive mask for the src sequence (optional).\n            tgt_mask: the additive mask for the tgt sequence (optional).\n            memory_mask: the additive mask for the encoder output (optional).\n            src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).\n            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).\n            memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).\n\n        Shape:\n            - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or\n              `(N, S, E)` if `batch_first=True`.\n            - 
tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or\n              `(N, T, E)` if `batch_first=True`.\n            - src_mask: :math:`(S, S)`.\n            - tgt_mask: :math:`(T, T)`.\n            - memory_mask: :math:`(T, S)`.\n            - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.\n            - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.\n            - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.\n\n            Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked\n            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend\n            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``\n            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor\n            is provided, it will be added to the attention weight.\n            [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by\n            the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero\n            positions will be unchanged. If a BoolTensor is provided, the positions with the\n            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.\n\n            - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or\n              `(N, T, E)` if `batch_first=True`.\n\n            Note: Due to the multi-head attention architecture in the transformer model,\n            the output sequence length of a transformer is same as the input sequence\n            (i.e. target) length of the decode.\n\n            where S is the source sequence length, T is the target sequence length, N is the\n            batch size, E is the feature number\n\n        Examples:\n            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)\n        \"\"\"\n\n        is_batched = src.dim() == 3\n        if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:\n            raise RuntimeError(\"the batch number of src and tgt must be equal\")\n        elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:\n            raise RuntimeError(\"the batch number of src and tgt must be equal\")\n\n        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:\n            raise RuntimeError(\"the feature number of src and tgt must be equal to d_model\")\n\n        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)\n        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,\n                              tgt_key_padding_mask=tgt_key_padding_mask,\n                              memory_key_padding_mask=memory_key_padding_mask)\n        return output\n\n    @staticmethod\n    def generate_square_subsequent_mask(sz: int) -> Tensor:\n        r\"\"\"Generate a square mask for the sequence. 
The masked positions are filled with float('-inf').\n            Unmasked positions are filled with float(0.0).\n        \"\"\"\n        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)\n\n    def _reset_parameters(self):\n        r\"\"\"Initiate parameters in the transformer model.\"\"\"\n\n        for p in self.parameters():\n            if p.dim() > 1:\n                xavier_uniform_(p)\n\n\nclass TransformerEncoder(Module):\n    r\"\"\"TransformerEncoder is a stack of N encoder layers\n\n    Args:\n        encoder_layer: an instance of the TransformerEncoderLayer() class (required).\n        num_layers: the number of sub-encoder-layers in the encoder (required).\n        norm: the layer normalization component (optional).\n\n    Examples::\n        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)\n        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)\n        >>> src = torch.rand(10, 32, 512)\n        >>> out = transformer_encoder(src)\n    \"\"\"\n    __constants__ = ['norm']\n\n    def __init__(self, encoder_layer, num_layers, norm=None):\n        super(TransformerEncoder, self).__init__()\n        self.layers = _get_clones(encoder_layer, num_layers)\n        self.num_layers = num_layers\n        self.norm = norm\n\n    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Pass the input through the encoder layers in turn.\n\n        Args:\n            src: the sequence to the encoder (required).\n            mask: the mask for the src sequence (optional).\n            src_key_padding_mask: the mask for the src keys per batch (optional).\n\n        Shape:\n            see the docs in Transformer class.\n        \"\"\"\n        output = src\n\n        for mod in self.layers:\n            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)\n\n        if self.norm is not None:\n            output = self.norm(output)\n\n        return output\n\n\nclass TransformerDecoder(Module):\n    r\"\"\"TransformerDecoder is a stack of N decoder layers\n\n    Args:\n        decoder_layer: an instance of the TransformerDecoderLayer() class (required).\n        num_layers: the number of sub-decoder-layers in the decoder (required).\n        norm: the layer normalization component (optional).\n\n    Examples::\n        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)\n        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n        >>> memory = torch.rand(10, 32, 512)\n        >>> tgt = torch.rand(20, 32, 512)\n        >>> out = transformer_decoder(tgt, memory)\n    \"\"\"\n    __constants__ = ['norm']\n\n    def __init__(self, decoder_layer, num_layers, norm=None):\n        super(TransformerDecoder, self).__init__()\n        self.layers = _get_clones(decoder_layer, num_layers)\n        self.num_layers = num_layers\n        self.norm = norm\n\n    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,\n                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,\n                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Pass the inputs (and mask) through the decoder layer in turn.\n\n        Args:\n            tgt: the sequence to the decoder (required).\n            memory: the sequence from the last layer of the encoder (required).\n            tgt_mask: the mask for the tgt 
sequence (optional).\n            memory_mask: the mask for the memory sequence (optional).\n            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).\n            memory_key_padding_mask: the mask for the memory keys per batch (optional).\n\n        Shape:\n            see the docs in Transformer class.\n        \"\"\"\n        output = tgt\n\n        for mod in self.layers:\n            output = mod(output, memory, tgt_mask=tgt_mask,\n                         memory_mask=memory_mask,\n                         tgt_key_padding_mask=tgt_key_padding_mask,\n                         memory_key_padding_mask=memory_key_padding_mask)\n\n        if self.norm is not None:\n            output = self.norm(output)\n\n        return output\n\nclass TransformerEncoderLayer(Module):\n    r\"\"\"TransformerEncoderLayer is made up of self-attn and feedforward network.\n    This standard encoder layer is based on the paper \"Attention Is All You Need\".\n    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,\n    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in\n    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement\n    in a different way during application.\n\n    Args:\n        d_model: the number of expected features in the input (required).\n        nhead: the number of heads in the multiheadattention models (required).\n        dim_feedforward: the dimension of the feedforward network model (default=2048).\n        dropout: the dropout value (default=0.1).\n        activation: the activation function of the intermediate layer, can be a string\n            (\"relu\" or \"gelu\") or a unary callable. Default: relu\n        layer_norm_eps: the eps value in layer normalization components (default=1e-5).\n        batch_first: If ``True``, then the input and output tensors are provided\n            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).\n        norm_first: if ``True``, layer norm is done prior to attention and feedforward\n            operations, respectivaly. Otherwise it's done after. 
Default: ``False`` (after).\n\n    Examples::\n        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)\n        >>> src = torch.rand(10, 32, 512)\n        >>> out = encoder_layer(src)\n\n    Alternatively, when ``batch_first`` is ``True``:\n        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)\n        >>> src = torch.rand(32, 10, 512)\n        >>> out = encoder_layer(src)\n    \"\"\"\n    __constants__ = ['batch_first', 'norm_first']\n\n    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,\n                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,\n                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,\n                 device=None, dtype=None) -> None:\n        factory_kwargs = {'device': device, 'dtype': dtype}\n        super(TransformerEncoderLayer, self).__init__()\n        # the RoPE MultiheadAttention defined above does not accept device/dtype factory kwargs\n        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)\n        # Implementation of Feedforward model\n        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)\n        self.dropout = Dropout(dropout)\n        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)\n\n        self.norm_first = norm_first\n        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)\n        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)\n        self.dropout1 = Dropout(dropout)\n        self.dropout2 = Dropout(dropout)\n\n        # Legacy string support for activation function.\n        if isinstance(activation, str):\n            self.activation = _get_activation_fn(activation)\n        else:\n            self.activation = activation\n\n    def __setstate__(self, state):\n        if 'activation' not in state:\n            state['activation'] = F.relu\n        super(TransformerEncoderLayer, self).__setstate__(state)\n\n    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Pass the input through the encoder layer.\n\n        Args:\n            src: the sequence to the encoder layer (required).\n            src_mask: the mask for the src sequence (optional).\n            src_key_padding_mask: the mask for the src keys per batch (optional).\n\n        Shape:\n            see the docs in Transformer class.\n        \"\"\"\n\n        # see Fig. 
1 of https://arxiv.org/pdf/2002.04745v1.pdf\n\n        x = src\n        if self.norm_first:\n            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)\n            x = x + self._ff_block(self.norm2(x))\n        else:\n            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))\n            x = self.norm2(x + self._ff_block(x))\n\n        return x\n\n    # self-attention block\n    def _sa_block(self, x: Tensor,\n                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:\n        x = self.self_attn(x, x, x,\n                           attn_mask=attn_mask,\n                           key_padding_mask=key_padding_mask,\n                           need_weights=False)[0]\n        return self.dropout1(x)\n\n    # feed forward block\n    def _ff_block(self, x: Tensor) -> Tensor:\n        x = self.linear2(self.dropout(self.activation(self.linear1(x))))\n        return self.dropout2(x)\n\n\nclass TransformerDecoderLayer(Module):\n    r\"\"\"TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.\n    This standard decoder layer is based on the paper \"Attention Is All You Need\".\n    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,\n    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in\n    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement\n    in a different way during application.\n\n    Args:\n        d_model: the number of expected features in the input (required).\n        nhead: the number of heads in the multiheadattention models (required).\n        dim_feedforward: the dimension of the feedforward network model (default=2048).\n        dropout: the dropout value (default=0.1).\n        activation: the activation function of the intermediate layer, can be a string\n            (\"relu\" or \"gelu\") or a unary callable. Default: relu\n        layer_norm_eps: the eps value in layer normalization components (default=1e-5).\n        batch_first: If ``True``, then the input and output tensors are provided\n            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).\n        norm_first: if ``True``, layer norm is done prior to self attention, multihead\n            attention and feedforward operations, respectivaly. 
Otherwise it's done after.\n            Default: ``False`` (after).\n\n    Examples::\n        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)\n        >>> memory = torch.rand(10, 32, 512)\n        >>> tgt = torch.rand(20, 32, 512)\n        >>> out = decoder_layer(tgt, memory)\n\n    Alternatively, when ``batch_first`` is ``True``:\n        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)\n        >>> memory = torch.rand(32, 10, 512)\n        >>> tgt = torch.rand(32, 20, 512)\n        >>> out = decoder_layer(tgt, memory)\n    \"\"\"\n    __constants__ = ['batch_first', 'norm_first']\n\n    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,\n                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,\n                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,\n                 device=None, dtype=None) -> None:\n        factory_kwargs = {'device': device, 'dtype': dtype}\n        super(TransformerDecoderLayer, self).__init__()\n        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,\n                                            )\n        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,\n                                                 )\n        # Implementation of Feedforward model\n        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)\n        self.dropout = Dropout(dropout)\n        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)\n\n        self.norm_first = norm_first\n        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)\n        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)\n        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)\n        self.dropout1 = Dropout(dropout)\n        self.dropout2 = Dropout(dropout)\n        self.dropout3 = Dropout(dropout)\n\n        # Legacy string support for activation function.\n        if isinstance(activation, str):\n            self.activation = _get_activation_fn(activation)\n        else:\n            self.activation = activation\n\n    def __setstate__(self, state):\n        if 'activation' not in state:\n            state['activation'] = F.relu\n        super(TransformerDecoderLayer, self).__setstate__(state)\n\n    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,\n                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Pass the inputs (and mask) through the decoder layer.\n\n        Args:\n            tgt: the sequence to the decoder layer (required).\n            memory: the sequence from the last layer of the encoder (required).\n            tgt_mask: the mask for the tgt sequence (optional).\n            memory_mask: the mask for the memory sequence (optional).\n            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).\n            memory_key_padding_mask: the mask for the memory keys per batch (optional).\n\n        Shape:\n            see the docs in Transformer class.\n        \"\"\"\n        # see Fig. 
1 of https://arxiv.org/pdf/2002.04745v1.pdf\n\n        x = tgt\n        if self.norm_first:\n            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)\n            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)\n            x = x + self._ff_block(self.norm3(x))\n        else:\n            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))\n            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))\n            x = self.norm3(x + self._ff_block(x))\n\n        return x\n\n    # self-attention block\n    def _sa_block(self, x: Tensor,\n                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:\n        x = self.self_attn(x, x, x,\n                           attn_mask=attn_mask,\n                           key_padding_mask=key_padding_mask,\n                           need_weights=False)[0]\n        return self.dropout1(x)\n\n    # multihead attention block\n    def _mha_block(self, x: Tensor, mem: Tensor,\n                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:\n        x = self.multihead_attn(x, mem, mem,\n                                attn_mask=attn_mask,\n                                key_padding_mask=key_padding_mask,\n                                need_weights=False)[0]\n        return self.dropout2(x)\n\n    # feed forward block\n    def _ff_block(self, x: Tensor) -> Tensor:\n        x = self.linear2(self.dropout(self.activation(self.linear1(x))))\n        return self.dropout3(x)\n\n\ndef _get_clones(module, N):\n    return ModuleList([copy.deepcopy(module) for i in range(N)])\n\n\ndef _get_activation_fn(activation):\n    if activation == \"relu\":\n        return F.relu\n    elif activation == \"gelu\":\n        return F.gelu\n\n    raise RuntimeError(\"activation should be relu/gelu, not {}\".format(activation))\n"
  },
  {
    "path": "PBnet/src/models/architectures/transformerdecoder4.py",
    "content": "import copy\nfrom typing import Optional, Any, Union, Callable\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn.functional import dropout\nfrom torch.nn import functional as F\nfrom torch.nn.modules import Module\nfrom torch.nn.modules.container import ModuleList\nfrom torch.nn.modules.activation import MultiheadAttention as MultiheadAttention0\nfrom torch.nn.init import xavier_uniform_\nfrom torch.nn.modules.dropout import Dropout\nfrom torch.nn.modules.linear import Linear\nfrom torch.nn.modules.normalization import LayerNorm\nimport torch.nn as nn\nfrom einops import rearrange, repeat, reduce, pack, unpack\nfrom torch import einsum\nfrom einops_exts import rearrange_many\nfrom rotary_embedding_torch import RotaryEmbedding\n\ndef exists(x):\n    return x is not None\n\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            heads=4,\n            dim_head=32,\n            rotary_emb=None\n    ):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n\n        self.rotary_emb = rotary_emb\n        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)\n        self.to_out = nn.Linear(hidden_dim, dim, bias=False)\n\n    def forward(\n            self,\n            x,\n            pos_bias=None,\n    ):  # temperal: 'b (h w) f c'  ; spatial :  'b f (h w) c'\n        n, device = x.shape[-2], x.device\n\n        qkv = self.to_qkv(x).chunk(3, dim=-1)\n\n        # if exists(focus_present_mask) and focus_present_mask.all():\n        #     # if all batch samples are focusing on present\n        #     # it would be equivalent to passing that token's values through to the output\n        #     values = qkv[-1]\n        #     return self.to_out(values)\n\n        # split out heads\n\n        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)\n\n        # scale\n\n        q = q * self.scale\n\n        # rotate positions into queries and keys for time attention\n\n        if exists(self.rotary_emb):\n            q = self.rotary_emb.rotate_queries_or_keys(q)\n            k = self.rotary_emb.rotate_queries_or_keys(k)\n\n        # similarity\n\n        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)\n\n        # relative positional bias\n\n        if exists(pos_bias):\n            sim = sim + pos_bias\n\n        # if exists(focus_present_mask) and not (~focus_present_mask).all():\n        #     attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool)\n        #     attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)\n\n        #     mask = torch.where(\n        #         rearrange(focus_present_mask, 'b -> b 1 1 1 1'),\n        #         rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),\n        #         rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),\n        #     )\n\n        #     sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)\n\n        # numerical stability\n\n        sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n        attn = sim.softmax(dim=-1)\n\n        # aggregate values\n\n        out = einsum('... h i j, ... h j d -> ... h i d', attn, v)\n        out = rearrange(out, '... h n d -> ... 
n (h d)')\n        return self.to_out(out)\n\nclass Attention_2(nn.Module):\n    def __init__(\n            self,\n            dim,\n            heads=4,\n            dim_head=32,\n            rotary_emb=None\n    ):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n\n        self.rotary_emb = rotary_emb\n        self.to_q = nn.Linear(dim, hidden_dim, bias=False)\n        self.to_k = nn.Linear(dim, hidden_dim, bias=False)\n        self.to_v = nn.Linear(dim, hidden_dim, bias=False)\n        self.to_out = nn.Linear(hidden_dim, dim, bias=False)\n\n    def forward(\n            self,\n            q,\n            k,\n            v,\n            pos_bias=None,\n            focus_present_mask=None\n    ):  # temperal: 'b (h w) f c'  ; spatial :  'b f (h w) c'\n\n        q = self.to_q(q)\n        k = self.to_k(k)\n        v = self.to_v(v)\n\n        # split out heads\n        q = rearrange(q, '... n (h d) -> ... h n d', h=self.heads) # b, head, fn, c\n        k = rearrange(k, '... n (h d) -> ... h n d', h=self.heads)\n        v = rearrange(v, '... n (h d) -> ... h n d', h=self.heads)\n        # q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)\n\n        # scale\n\n        q = q * self.scale\n\n        # rotate positions into queries and keys for time attention\n\n        if exists(self.rotary_emb):\n            q = self.rotary_emb.rotate_queries_or_keys(q)\n            k = self.rotary_emb.rotate_queries_or_keys(k)\n\n        # similarity\n\n        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)\n\n        # relative positional bias\n\n        if exists(pos_bias):\n            sim = sim + pos_bias\n\n        # numerical stability\n\n        sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n        attn = sim.softmax(dim=-1)\n\n        # aggregate values\n\n        out = einsum('... h i j, ... h j d -> ... h i d', attn, v)\n        out = rearrange(out, '... h n d -> ... 
n (h d)')\n        return self.to_out(out)\n\nclass PositionwiseFeedforwardLayer(nn.Module):\n    def __init__(self, d_model, d_ff, dropout):\n        super(PositionwiseFeedforwardLayer, self).__init__()\n\n        self.linear1 = nn.Linear(d_model, d_ff)\n        self.linear2 = nn.Linear(d_ff, d_model)\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        x = F.gelu(self.linear1(x))\n        x = self.dropout(x)\n        x = self.linear2(x)\n\n        return x\n\nclass DecoderLayer(nn.Module):\n    def __init__(self, d_model, num_heads, d_ff, dropout, rotary_emb):\n        super(DecoderLayer, self).__init__()\n\n        self.self_attn = Attention(dim = d_model, heads = num_heads, rotary_emb = rotary_emb)\n        self.multihead_attn = Attention_2(dim = d_model, heads = num_heads, rotary_emb = rotary_emb)\n\n        self.ffn = PositionwiseFeedforwardLayer(d_model, d_ff, dropout)\n\n        self.layer_norm1 = nn.LayerNorm(d_model)\n        self.layer_norm2 = nn.LayerNorm(d_model)\n        self.layer_norm3 = nn.LayerNorm(d_model)\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):\n        tgt = self.layer_norm1(tgt + self.dropout(self.self_attn(tgt, tgt_mask)))\n        tgt = self.layer_norm2(tgt + self.dropout(self.multihead_attn(tgt, memory, memory, memory_mask)))\n        tgt = self.layer_norm3(tgt + self.dropout(self.ffn(tgt)))\n\n        return tgt\n\nclass TransformerDecoder(nn.Module):\n    def __init__(self, num_layers, d_model, num_heads, dim_feedforward, dropout):\n        super(TransformerDecoder, self).__init__()\n\n        self.num_layers = num_layers\n        rotary_emb = RotaryEmbedding(min(32, num_heads))\n        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model = d_model, num_heads = num_heads, d_ff = dim_feedforward, dropout = dropout, rotary_emb = rotary_emb) for _ in range(num_layers)])\n    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):\n        output = tgt\n\n        for layer in self.decoder_layers:\n            output = layer(output, memory, tgt_mask, memory_mask)\n\n        return output"
  },
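  {
    "path": "PBnet/src/models/architectures/transformerdecoder4_demo.py",
    "content": "# Hypothetical usage sketch (illustrative only, not part of the original repository):\n# drives the rotary-embedding TransformerDecoder defined in transformerdecoder4.py,\n# assuming PBnet/src is importable as `src` (the same convention transformerreemb5.py uses).\nimport torch\n\nfrom src.models.architectures.transformerdecoder4 import TransformerDecoder\n\nif __name__ == '__main__':\n    # batch-first tensors: Attention / Attention_2 treat dim -2 as the sequence axis\n    decoder = TransformerDecoder(num_layers=2, d_model=64, num_heads=4,\n                                 dim_feedforward=256, dropout=0.1)\n    tgt = torch.randn(2, 40, 64)     # (batch, target frames, d_model)\n    memory = torch.randn(2, 40, 64)  # (batch, memory frames, d_model)\n    out = decoder(tgt, memory)       # tgt_mask / memory_mask default to None\n    print(out.shape)                 # torch.Size([2, 40, 64])\n"
  },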
  {
    "path": "PBnet/src/models/architectures/transformerdecoder5.py",
    "content": "import copy\nfrom typing import Optional, Any, Union, Callable\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn.functional import dropout\nfrom torch.nn import functional as F\nfrom torch.nn.modules import Module\nfrom torch.nn.modules.container import ModuleList\nfrom torch.nn.init import xavier_uniform_\nfrom torch.nn.modules.dropout import Dropout\nfrom torch.nn.modules.linear import Linear\nfrom torch.nn.modules.normalization import LayerNorm\nimport torch.nn as nn\nfrom einops import rearrange, repeat, reduce, pack, unpack\nfrom torch import einsum\nfrom einops_exts import rearrange_many\nfrom rotary_embedding_torch import RotaryEmbedding\n\ndef exists(x):\n    return x is not None\n\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            heads=4,\n            dim_head=32,\n            rotary_emb=None\n    ):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n\n        self.rotary_emb = rotary_emb\n        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)\n        self.to_out = nn.Linear(hidden_dim, dim, bias=False)\n\n    def forward(\n            self,\n            x,\n            pos_bias=None,\n    ):  # temperal: 'b (h w) f c'  ; spatial :  'b f (h w) c'\n        n, device = x.shape[-2], x.device\n\n        qkv = self.to_qkv(x).chunk(3, dim=-1)\n\n        # if exists(focus_present_mask) and focus_present_mask.all():\n        #     # if all batch samples are focusing on present\n        #     # it would be equivalent to passing that token's values through to the output\n        #     values = qkv[-1]\n        #     return self.to_out(values)\n\n        # split out heads\n\n        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)\n\n        # scale\n\n        q = q * self.scale\n\n        # rotate positions into queries and keys for time attention\n\n        if exists(self.rotary_emb):\n            q = self.rotary_emb.rotate_queries_or_keys(q)\n            k = self.rotary_emb.rotate_queries_or_keys(k)\n\n        # similarity\n\n        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)\n\n        # relative positional bias\n\n        if exists(pos_bias):\n            sim = sim + pos_bias\n\n        # if exists(focus_present_mask) and not (~focus_present_mask).all():\n        #     attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool)\n        #     attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)\n\n        #     mask = torch.where(\n        #         rearrange(focus_present_mask, 'b -> b 1 1 1 1'),\n        #         rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),\n        #         rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),\n        #     )\n\n        #     sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)\n\n        # numerical stability\n\n        sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n        attn = sim.softmax(dim=-1)\n\n        # aggregate values\n\n        out = einsum('... h i j, ... h j d -> ... h i d', attn, v)\n        out = rearrange(out, '... h n d -> ... 
n (h d)')\n        return self.to_out(out)\n\nclass Attention_2(nn.Module):\n    def __init__(\n            self,\n            dim,\n            heads=4,\n            dim_head=32,\n            rotary_emb=None\n    ):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n\n        self.rotary_emb = rotary_emb\n        self.to_q = nn.Linear(dim, hidden_dim, bias=False)\n        self.to_k = nn.Linear(dim, hidden_dim, bias=False)\n        self.to_v = nn.Linear(dim, hidden_dim, bias=False)\n        self.to_out = nn.Linear(hidden_dim, dim, bias=False)\n\n    def forward(\n            self,\n            q,\n            k,\n            v,\n            pos_bias=None,\n            focus_present_mask=None\n    ):  # temperal: 'b (h w) f c'  ; spatial :  'b f (h w) c'\n\n        q = self.to_q(q)\n        k = self.to_k(k)\n        v = self.to_v(v)\n\n        # split out heads\n        q = rearrange(q, '... n (h d) -> ... h n d', h=self.heads) # b, head, fn, c\n        k = rearrange(k, '... n (h d) -> ... h n d', h=self.heads)\n        v = rearrange(v, '... n (h d) -> ... h n d', h=self.heads)\n        # q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)\n\n        # scale\n\n        q = q * self.scale\n\n        # rotate positions into queries and keys for time attention\n\n        if exists(self.rotary_emb):\n            q = self.rotary_emb.rotate_queries_or_keys(q)\n            k = self.rotary_emb.rotate_queries_or_keys(k)\n\n        # similarity\n\n        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)\n\n        # relative positional bias\n\n        if exists(pos_bias):\n            sim = sim + pos_bias\n\n        # numerical stability\n\n        sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n        attn = sim.softmax(dim=-1)\n\n        # aggregate values\n\n        out = einsum('... h i j, ... h j d -> ... h i d', attn, v)\n        out = rearrange(out, '... h n d -> ... 
n (h d)')\n        return self.to_out(out)\n\nclass PositionwiseFeedforwardLayer(nn.Module):\n    def __init__(self, d_model, d_ff, dropout):\n        super(PositionwiseFeedforwardLayer, self).__init__()\n\n        self.linear1 = nn.Linear(d_model, d_ff)\n        self.linear2 = nn.Linear(d_ff, d_model)\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        x = F.gelu(self.linear1(x))\n        x = self.dropout(x)\n        x = self.linear2(x)\n\n        return x\n\nclass DecoderLayer(nn.Module):\n    def __init__(self, d_model, num_heads, d_ff, dropout, rotary_emb):\n        super(DecoderLayer, self).__init__()\n\n        self.self_attn = Attention(dim = d_model, heads = num_heads, rotary_emb = rotary_emb) # , rotary_emb = rotary_emb)\n        self.multihead_attn = Attention_2(dim = d_model, heads = num_heads, rotary_emb = rotary_emb)\n\n        self.ffn = PositionwiseFeedforwardLayer(d_model, d_ff, dropout)\n\n        self.layer_norm1 = nn.LayerNorm(d_model)\n        self.layer_norm2 = nn.LayerNorm(d_model)\n        self.layer_norm3 = nn.LayerNorm(d_model)\n\n        self.dropout1 = nn.Dropout(dropout)\n        self.dropout2 = nn.Dropout(dropout)\n        self.dropout3 = nn.Dropout(dropout)\n\n    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):\n        tgt = self.layer_norm1(tgt + self.dropout1(self.self_attn(tgt, tgt_mask)))\n        tgt = self.layer_norm2(tgt + self.dropout2(self.multihead_attn(tgt, memory, memory, memory_mask)))\n        tgt = self.layer_norm3(tgt + self.dropout3(self.ffn(tgt)))\n\n        return tgt\n\nclass TransformerDecoder(nn.Module):\n    def __init__(self, num_layers, d_model, num_heads, dim_feedforward, dropout):\n        super(TransformerDecoder, self).__init__()\n\n        self.num_layers = num_layers\n        rotary_emb = RotaryEmbedding(min(32, num_heads))\n        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model = d_model, num_heads = num_heads, d_ff = dim_feedforward, dropout = dropout, rotary_emb = rotary_emb) for _ in range(num_layers)])\n    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):\n        output = tgt\n\n        for layer in self.decoder_layers:\n            output = layer(output, memory, tgt_mask, memory_mask)\n\n        return output"
  },
  {
    "path": "PBnet/src/models/architectures/transformerreemb.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat, reduce, pack, unpack\nfrom einops_exts import rearrange_many\nfrom torch import einsum\nfrom rotary_embedding_torch import RotaryEmbedding\nimport math \n\n\ndef exists(x):\n    return x is not None\n\nclass LayerNorm(nn.Module):\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.eps = eps\n        self.gamma = nn.Parameter(torch.ones(1, 1,  dim))\n\n    def forward(self, x):\n        var = torch.var(x, dim=-1, unbiased=False, keepdim=True)\n        mean = torch.mean(x, dim=-1, keepdim=True)\n        return (x - mean) / (var + self.eps).sqrt() * self.gamma\n\nclass PreNorm(nn.Module):\n    def __init__(self, dim, fn):\n        super().__init__()\n        self.fn = fn\n        self.norm = LayerNorm(dim)\n\n    def forward(self, x, **kwargs):\n        x = self.norm(x)\n        return self.fn(x, **kwargs)\n\nclass Residual(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, x, *args, **kwargs):\n        return self.fn(x, *args, **kwargs) + x\n\nclass EinopsToAndFrom(nn.Module):\n    def __init__(self, from_einops, to_einops, fn):\n        super().__init__()\n        self.from_einops = from_einops\n        self.to_einops = to_einops\n        self.fn = fn\n\n    def forward(self, x, **kwargs):\n        shape = x.shape\n        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))\n        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')\n        x = self.fn(x, **kwargs)\n        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)\n        return x\n\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            heads=4,\n            dim_head=32,\n            rotary_emb=None\n    ):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        hidden_dim = dim_head * heads\n\n        self.rotary_emb = rotary_emb\n        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)\n        self.to_out = nn.Linear(hidden_dim, dim, bias=False)\n\n    def forward(\n            self,\n            x,\n            pos_bias=None,\n            focus_present_mask=None\n    ):  # temperal: 'b (h w) f c'  ; spatial :  'b f (h w) c'\n        n, device = x.shape[-2], x.device\n\n        qkv = self.to_qkv(x).chunk(3, dim=-1)\n\n        if exists(focus_present_mask) and focus_present_mask.all():\n            # if all batch samples are focusing on present\n            # it would be equivalent to passing that token's values through to the output\n            values = qkv[-1]\n            return self.to_out(values)\n\n        # split out heads\n\n        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)\n\n        # scale\n\n        q = q * self.scale\n\n        # rotate positions into queries and keys for time attention\n\n        if exists(self.rotary_emb):\n            q = self.rotary_emb.rotate_queries_or_keys(q)\n            k = self.rotary_emb.rotate_queries_or_keys(k)\n\n        # similarity\n\n        sim = einsum('... h i d, ... h j d -> ... 
h i j', q, k)\n\n        # relative positional bias\n\n        if exists(pos_bias):\n            sim = sim + pos_bias\n\n        if exists(focus_present_mask) and not (~focus_present_mask).all():\n            attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool)\n            attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)\n\n            mask = torch.where(\n                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),\n                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),\n                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),\n            )\n\n            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)\n\n        # numerical stability\n\n        sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n        attn = sim.softmax(dim=-1)\n\n        # aggregate values\n\n        out = einsum('... h i j, ... h j d -> ... h i d', attn, v)\n        out = rearrange(out, '... h n d -> ... n (h d)')\n        return self.to_out(out)\n\nclass PositionalEncoding(nn.Module):\n    def __init__(self, d_model, dropout=0.1, max_len=20000):\n        super(PositionalEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n        pe = torch.zeros(max_len, d_model)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0).transpose(0, 1)\n        \n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        # not used in the final model\n        x = x + self.pe[:x.shape[0], :]\n        return self.dropout(x)\n\n\nclass RelativePositionBias(nn.Module):\n    def __init__(\n            self,\n            heads=8,\n            num_buckets=32,\n            max_distance=128\n    ):\n        super().__init__()\n        self.num_buckets = num_buckets\n        self.max_distance = max_distance\n        self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n\n    @staticmethod\n    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):\n        ret = 0\n        n = -relative_position\n\n        num_buckets //= 2\n        ret += (n < 0).long() * num_buckets\n        n = torch.abs(n)\n\n        max_exact = num_buckets // 2\n        is_small = n < max_exact\n\n        val_if_large = max_exact + (\n                torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)\n        ).long()\n        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n\n        ret += torch.where(is_small, n, val_if_large)\n        return ret\n\n    def forward(self, n, device, eval = False):\n        q_pos = torch.arange(n, dtype=torch.long, device=device)\n        k_pos = torch.arange(n, dtype=torch.long, device=device)\n        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,\n                                                   max_distance=self.max_distance)\n        if True:\n            mask = - (((rel_pos > 32) + (rel_pos  < -32)) * (1e8))\n            values = self.relative_attention_bias(rp_bucket)\n            return rearrange(values, 'i j h -> h i j') + mask\n        else:\n            values = self.relative_attention_bias(rp_bucket)\n            return 
rearrange(values, 'i j h -> h i j')\n        \n# only for ablation / not used in the final model\nclass TimeEncoding(nn.Module):\n    def __init__(self, d_model, dropout=0.1, max_len=5000):\n        super(TimeEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, x, mask, lengths):\n        time = mask * 1/(lengths[..., None]-1)\n        time = time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :]\n        time = time[:, 0].T\n        # add the time encoding\n        x = x + time[..., None]\n        return self.dropout(x)\n    \n\nclass Encoder_TRANSFORMERREEMB(nn.Module):\n    def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, pose_latent_dim=64,\n                 audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1,\n                 ablation=None, activation=\"gelu\", **kargs):\n        super().__init__()\n        \n        self.modeltype = modeltype\n        self.pos_dim = pos_dim\n        self.num_frames = num_frames\n        self.audio_dim = audio_dim\n        \n        self.pose_latent_dim = pose_latent_dim\n        self.audio_latent_dim = audio_latent_dim\n        self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2\n        \n        self.ff_size = ff_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.dropout = dropout\n\n        self.ablation = ablation\n        self.activation = activation\n\n        \n        # if self.ablation == \"average_encoder\":\n        #     self.mu_layer = nn.Linear(self.latent_dim, self.latent_dim)\n        #     self.sigma_layer = nn.Linear(self.latent_dim, self.latent_dim)\n        # else:\n        #     self.muQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n        #     self.sigmaQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n        \n        # # there's no class of our dataset CREMA/HDTF, so noly  dont need to use nn.parameter\n        self.mu_layer = nn.Linear(self.latent_dim, self.audio_latent_dim)\n        self.sigma_layer = nn.Linear(self.latent_dim, self.audio_latent_dim)\n        \n        self.poseEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64\n        self.firstposeEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64\n        self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256\n        \n        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)\n        \n        # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))\n        \n        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,\n                                                          nhead=self.num_heads,\n                                                          dim_feedforward=self.ff_size,\n                                                          dropout=self.dropout,\n                                                          activation=self.activation)\n        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,\n                                                     num_layers=self.num_layers)\n\n    def forward(self, batch):\n        '''\n            x: 6-dim pos, (bs, max_num_frames, 6)\n            y: 1024-dim audio embbeding, (bs, max_num_frames, 1024)\n        '''\n\n        x, y, mask = batch[\"x\"], batch[\"y\"], batch[\"mask\"]\n        # bs, njoints, nfeats, nframes = x.shape\n        # x = x.permute((3, 0, 1, 
2)).reshape(nframes, bs, njoints*nfeats) \n        x_ref = x[:,0,:].unsqueeze(dim=1) # The pose information of the first frame(refrence img)\n        x = x-x_ref.repeat(1,x.size(1),1) # bs, nf, 6  Obtain the difference from the first frame\n        batch['x_delta'] = x\n        x_ref = x_ref.permute((1,0,2)) #1, bs, 6\n        x = x.permute((1, 0, 2)) #nf, bs, 6\n        y = y.permute((1, 0, 2)) #nf, bs, 1024\n        # embedding of the pose/audio\n        x_ref = self.firstposeEmbedding(x_ref).repeat(x.size(0),1,1) #nf, bs, 64\n        x = self.poseEmbedding(x) #nf, bs, 64\n        y = self.audioEmbedding(y) #nf, bs, 256\n        x = torch.cat([x_ref, x, y],dim=-1) # nf, bs, 64+64+256\n\n        # only use the \"average_encoder\" mode\n        # add positional encoding\n        x = self.sequence_pos_encoder(x)\n        # transformer layers\n        final = self.seqTransEncoder(x, src_key_padding_mask=~mask) #nu_frames, bs, 64+64+256\n        # get the average of the output\n        z = final# final.mean(axis=0) # nf, bs, 64+64+256\n        # extract mu and logvar\n        mu = self.mu_layer(z) # nf, bs, 256\n        logvar = self.sigma_layer(z) # nf, bs, 256\n        # logvar = - torch.ones_like(logvar) * 1e10\n\n        return {\"mu\": mu, \"logvar\": logvar}\n\n\nclass Decoder_TRANSFORMERREEMB(nn.Module):\n    def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, pose_latent_dim=64,\n                 audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, activation=\"gelu\",\n                 ablation=None, num_buckets = 32, max_distance = 32,**kargs):\n        super().__init__()\n\n        self.modeltype = modeltype\n\n        self.pos_dim = pos_dim\n        self.num_frames = num_frames\n        self.audio_dim = audio_dim\n        \n        self.pose_latent_dim = pose_latent_dim\n        self.audio_latent_dim = audio_latent_dim\n        self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2\n        \n        self.ff_size = ff_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.dropout = dropout\n\n        self.ablation = ablation\n\n        self.activation = activation\n\n        self.firstposeEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64\n        self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256\n        self.ztimelinear = nn.Linear(self.audio_latent_dim*2+self.pose_latent_dim, self.pose_latent_dim) #256*2+64,64\n        \n        self.init_proj = nn.Linear(self.pose_latent_dim, self.pose_latent_dim)\n        # self.input_feats = self.njoints*self.nfeats\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"zandtime\":\n        #     self.ztimelinear = nn.Linear(self.latent_dim + self.num_classes, self.latent_dim)\n        # else:\n        #     self.actionBiases = nn.Parameter(torch.randn(1024, self.latent_dim))\n            # self.actionBiases = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"time_encoding\":\n        #     self.sequence_pos_encoder = TimeEncoding(self.dropout)\n        # else:\n        #     self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)\n\n        self.sequence_pos_encoder = PositionalEncoding(self.pose_latent_dim, self.dropout)\n        rotary_emb = RotaryEmbedding(min(32, num_heads))\n\n        self.time_rel_pos_bias = 
RelativePositionBias(heads=num_heads,\n                                                      num_buckets=num_buckets,\n                                                      max_distance=max_distance)  \n\n        temporal_attn = lambda dim: EinopsToAndFrom('l b c', 'b l c',  # len, b, c\n                                                    Attention(dim, heads=num_heads, dim_head=32,\n                                                              rotary_emb=rotary_emb))\n\n        self.init_temporal_attn = Residual(PreNorm(self.pose_latent_dim, temporal_attn(self.pose_latent_dim)))\n        # self.sequence_pos_encoder = TimeEncoding(self.dropout) #time_encoding\n        \n        seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.pose_latent_dim,\n                                                          nhead=self.num_heads,\n                                                          dim_feedforward=self.ff_size,\n                                                          dropout=self.dropout,\n                                                          activation=activation)\n        self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer,\n                                                     num_layers=self.num_layers)\n        \n        self.finallayer = nn.Linear(self.pose_latent_dim, self.pos_dim)\n        \n    def forward(self, batch):\n        '''\n            z: bs, audio_latent_dim(256)\n            y: bs, num_frames, 1024\n            mask: bs, num_frames\n            lengths: [num_frames,...]\n        '''\n        x, z, y, mask, lengths = batch[\"x\"], batch[\"z\"], batch[\"y\"], batch[\"mask\"], batch[\"lengths\"]\n        bs, nframes = mask.shape\n        # first img\n        x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64\n        x_ref = self.firstposeEmbedding(x_ref.repeat(1, nframes, 1)) #bs, nf, 64\n        y = self.audioEmbedding(y) #bs, num_frames, 256\n        z = z.permute(1, 0, 2)\n        #z = z.unsqueeze(dim=1).repeat(1, nframes, 1) #bs, num_frames, 256\n        z = torch.cat([x_ref, z, y], dim=-1) # bs, num_frames, 256*2+64\n        z = self.ztimelinear(z)\n        z = z.permute((1, 0, 2)) # nf, bs, 64\n        pose_latent_dim = z.shape[2]\n        # z = z[None]  # sequence of size 1\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"zandtime\":\n        #     yoh = F.one_hot(y, self.num_classes)\n        #     z = torch.cat((z, yoh), axis=1)\n        #     z = self.ztimelinear(z)\n        #     z = z[None]  # sequence of size 1\n        # else:\n        #     # only for ablation / not used in the final model\n        #     if self.ablation == \"concat_bias\":\n        #         # sequence of size 2\n        #         z = torch.stack((z, self.actionBiases[y]), axis=0)\n        #     else:\n        #         # shift the latent noise vector to be the action noise\n        #         z = z + self.actionBiases[y.long()] # NEED CHECK\n        #         z = z[None]  # sequence of size 1\n            \n        timequeries = torch.zeros(nframes, bs, pose_latent_dim, device=z.device) # len, b, c\n        timequeries = self.sequence_pos_encoder(timequeries)\n\n        time_rel_pos_bias = self.time_rel_pos_bias(timequeries.shape[0], device=x.device)\n\n        timequeries = self.init_proj(timequeries)\n\n        timequeries = self.init_temporal_attn(timequeries, pos_bias=time_rel_pos_bias)\n\n        # timequeries = self.sequence_pos_encoder(timequeries, mask, lengths) #time_encoding\n        \n        # # only for 
ablation / not used in the final model\n        # if self.ablation == \"time_encoding\":\n        #     timequeries = self.sequence_pos_encoder(timequeries, mask, lengths)\n        # else:\n        #     timequeries = self.sequence_pos_encoder(timequeries)\n        \n        # num_frames, bs, 64\n        output = self.seqTransDecoder(tgt=timequeries, memory=z, tgt_mask=time_rel_pos_bias.repeat(bs, 1, 1),\n                                      tgt_key_padding_mask=~mask)\n        \n        output = self.finallayer(output).reshape(nframes, bs, self.pos_dim) # num_frames, bs, 6\n        # output = self.finallayer(output).reshape(nframes, bs, njoints, nfeats)\n        \n        # zero for padded area\n        output[~mask.T] = 0 #nf, bs, 6\n        output = output.permute(1,0,2)#bs, nf, 6\n        \n        batch[\"output\"] = output\n        return batch\n"
  },
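  {
    "path": "PBnet/examples/delta_pose_sketch.py",
    "content": "# Hypothetical standalone sketch, not part of the original repo: illustrates the\n# first-frame 'delta pose' preprocessing performed by the Encoder_* modules above,\n# where every frame is re-expressed as an offset from the reference (first) frame.\n# Shapes follow the encoder docstring: x is (bs, num_frames, pos_dim).\nimport torch\n\nbs, num_frames, pos_dim = 2, 5, 6\nx = torch.randn(bs, num_frames, pos_dim)\n\n# pose of the reference image (first frame), kept so the offsets can be inverted\nx_ref = x[:, 0, :].unsqueeze(dim=1)            # (bs, 1, pos_dim)\nx_delta = x - x_ref.repeat(1, num_frames, 1)   # (bs, num_frames, pos_dim)\n\n# the first frame of the delta sequence is exactly zero by construction\nassert torch.allclose(x_delta[:, 0, :], torch.zeros(bs, pos_dim))\n# absolute poses are recovered by adding the reference back\nassert torch.allclose(x_delta + x_ref, x)\n"
  },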
  {
    "path": "PBnet/src/models/architectures/transformerreemb5.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat, reduce, pack, unpack\nfrom einops_exts import rearrange_many\nfrom torch import einsum\nfrom rotary_embedding_torch import RotaryEmbedding\nimport math \nfrom src.models.architectures.transformerdecoder4 import *\n\n\ndef exists(x):\n    return x is not None\n\nclass LayerNorm(nn.Module):\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.eps = eps\n        self.gamma = nn.Parameter(torch.ones(1, 1,  dim))\n\n    def forward(self, x):\n        var = torch.var(x, dim=-1, unbiased=False, keepdim=True)\n        mean = torch.mean(x, dim=-1, keepdim=True)\n        return (x - mean) / (var + self.eps).sqrt() * self.gamma\n\nclass PreNorm(nn.Module):\n    def __init__(self, dim, fn):\n        super().__init__()\n        self.fn = fn\n        self.norm = LayerNorm(dim)\n\n    def forward(self, x, **kwargs):\n        x = self.norm(x)\n        return self.fn(x, **kwargs)\n\nclass Residual(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, x, *args, **kwargs):\n        return self.fn(x, *args, **kwargs) + x\n\nclass EinopsToAndFrom(nn.Module):\n    def __init__(self, from_einops, to_einops, fn):\n        super().__init__()\n        self.from_einops = from_einops\n        self.to_einops = to_einops\n        self.fn = fn\n\n    def forward(self, x, **kwargs):\n        shape = x.shape\n        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))\n        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')\n        x = self.fn(x, **kwargs)\n        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)\n        return x\n\n\nclass PositionalEncoding(nn.Module):\n    def __init__(self, d_model, dropout=0.1, max_len=20000):\n        super(PositionalEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n        pe = torch.zeros(max_len, d_model)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0).transpose(0, 1)\n        \n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        # not used in the final model\n        x = x + self.pe[:x.shape[0], :]\n        return self.dropout(x)\n\n\nclass RelativePositionBias(nn.Module):\n    def __init__(\n            self,\n            heads=8,\n            num_buckets=32,\n            max_distance=128\n    ):\n        super().__init__()\n        self.num_buckets = num_buckets\n        self.max_distance = max_distance\n        self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n\n    @staticmethod\n    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):\n        ret = 0\n        n = -relative_position\n\n        num_buckets //= 2\n        ret += (n < 0).long() * num_buckets\n        n = torch.abs(n)\n\n        max_exact = num_buckets // 2\n        is_small = n < max_exact\n\n        val_if_large = max_exact + (\n                torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)\n        ).long()\n        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 
1))\n\n        ret += torch.where(is_small, n, val_if_large)\n        return ret\n\n    def forward(self, n, device, eval = False):\n        q_pos = torch.arange(n, dtype=torch.long, device=device)\n        k_pos = torch.arange(n, dtype=torch.long, device=device)\n        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,\n                                                   max_distance=self.max_distance)\n        if not self.relative_attention_bias.training:\n            print('eval!')\n            mask = - (((rel_pos > 200) + (rel_pos  < -200)) * (1e8))\n            values = self.relative_attention_bias(rp_bucket)\n            return rearrange(values, 'i j h -> h i j') + mask\n        else:\n            # values = self.relative_attention_bias(rp_bucket)\n            # return rearrange(values, 'i j h -> h i j')\n            # mask = - (((rel_pos > 100) + (rel_pos  < -100)) * (1e8))\n            values = self.relative_attention_bias(rp_bucket)\n            return rearrange(values, 'i j h -> h i j') # + mask\n        \n# only for ablation / not used in the final model\nclass TimeEncoding(nn.Module):\n    def __init__(self, d_model, dropout=0.1, max_len=5000):\n        super(TimeEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, x, mask, lengths):\n        time = mask * 1/(lengths[..., None]-1)\n        time = time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :]\n        time = time[:, 0].T\n        # add the time encoding\n        x = x + time[..., None]\n        return self.dropout(x)\n    \n\nclass Encoder_TRANSFORMERREEMB5(nn.Module):\n    def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, eye_dim=2, pose_latent_dim=64,\n                 audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1,\n                 ablation=None, activation=\"gelu\", **kargs):\n        super().__init__()\n        \n        self.modeltype = modeltype\n        self.pos_dim = pos_dim\n        self.eye_dim = eye_dim\n        self.num_frames = num_frames\n        self.audio_dim = audio_dim\n        \n        self.pose_latent_dim = pose_latent_dim\n        self.audio_latent_dim = audio_latent_dim\n        self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2\n        \n        self.ff_size = ff_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.dropout = dropout\n\n        self.ablation = ablation\n        self.activation = activation\n\n        \n        # if self.ablation == \"average_encoder\":\n        #     self.mu_layer = nn.Linear(self.latent_dim, self.latent_dim)\n        #     self.sigma_layer = nn.Linear(self.latent_dim, self.latent_dim)\n        # else:\n        #     self.muQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n        #     self.sigmaQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n        \n        # # there's no class of our dataset CREMA/HDTF, so noly  dont need to use nn.parameter\n        self.mu_layer = nn.Linear(self.latent_dim, self.audio_latent_dim)\n        self.sigma_layer = nn.Linear(self.latent_dim, self.audio_latent_dim)\n        \n        self.poseEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64\n        self.firstposeEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64\n        self.audioEmbedding = nn.Linear(self.audio_dim, 
self.audio_latent_dim) #1024, 256\n        \n        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)\n        \n        # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))\n        \n        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,\n                                                          nhead=self.num_heads,\n                                                          dim_feedforward=self.ff_size,\n                                                          dropout=self.dropout,\n                                                          activation=self.activation)\n        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,\n                                                     num_layers=self.num_layers)\n\n    def forward(self, batch):\n        '''\n            x: 6-dim pos, (bs, max_num_frames, 6)\n            y: 1024-dim audio embbeding, (bs, max_num_frames, 1024)\n        '''\n\n        x, y, mask = batch[\"x\"], batch[\"y\"], batch[\"mask\"]\n        # bs, njoints, nfeats, nframes = x.shape\n        # x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats) \n        x_ref = x[:,0,:].unsqueeze(dim=1) # The pose information of the first frame(refrence img)\n        x = x-x_ref.repeat(1,x.size(1),1) # bs, nf, 6  Obtain the difference from the first frame\n        batch['x_delta'] = x\n        x_ref = x_ref.permute((1,0,2)) #1, bs, 6\n        x = x.permute((1, 0, 2)) #nf, bs, 6\n        y = y.permute((1, 0, 2)) #nf, bs, 1024\n        # embedding of the pose/audio\n        x_ref = self.firstposeEmbedding(x_ref).repeat(x.size(0),1,1) #nf, bs, 64\n        x = self.poseEmbedding(x) #nf, bs, 64\n        y = self.audioEmbedding(y) #nf, bs, 256\n        x = torch.cat([x_ref, x, y],dim=-1) # nf, bs, 64+64+256\n\n        # only use the \"average_encoder\" mode\n        # add positional encoding\n        x = self.sequence_pos_encoder(x)\n        # transformer layers\n        final = self.seqTransEncoder(x, src_key_padding_mask=~mask) #nu_frames, bs, 64+64+256\n        # get the average of the output\n        z = final# final.mean(axis=0) # nf, bs, 64+64+256\n        # extract mu and logvar\n        mu = self.mu_layer(z) # nf, bs, 256\n        logvar = self.sigma_layer(z) # nf, bs, 256\n        # logvar = - torch.ones_like(logvar) * 1e10\n\n        return {\"mu\": mu, \"logvar\": logvar}\n\n\nclass Decoder_TRANSFORMERREEMB5(nn.Module):\n    def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, eye_dim=2, pose_latent_dim=64,\n                 audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, activation=\"gelu\",\n                 ablation=None, num_buckets = 32, max_distance = 32,**kargs):\n        super().__init__()\n\n        self.modeltype = modeltype\n\n        self.pos_dim = pos_dim\n        self.eye_dim = eye_dim\n        self.num_frames = num_frames\n        self.audio_dim = audio_dim\n        \n        self.pose_latent_dim = pose_latent_dim\n        self.audio_latent_dim = audio_latent_dim\n        self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2\n        \n        self.ff_size = ff_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.dropout = dropout\n\n        self.ablation = ablation\n\n        self.activation = activation\n\n        self.firstposeEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64\n        self.audioEmbedding = nn.Linear(self.audio_dim, 
self.audio_latent_dim) #1024, 256\n        self.ztimelinear = nn.Linear(self.audio_latent_dim*2+self.pose_latent_dim, self.pose_latent_dim) #256*2+64,64\n        \n        self.init_proj = nn.Linear(self.pose_latent_dim, self.pose_latent_dim)\n        # self.input_feats = self.njoints*self.nfeats\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"zandtime\":\n        #     self.ztimelinear = nn.Linear(self.latent_dim + self.num_classes, self.latent_dim)\n        # else:\n        #     self.actionBiases = nn.Parameter(torch.randn(1024, self.latent_dim))\n            # self.actionBiases = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"time_encoding\":\n        #     self.sequence_pos_encoder = TimeEncoding(self.dropout)\n        # else:\n        #     self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)\n\n        self.sequence_pos_encoder = PositionalEncoding(self.pose_latent_dim, self.dropout)\n        rotary_emb = RotaryEmbedding(min(32, num_heads))\n\n        self.time_rel_pos_bias_tgt = RelativePositionBias(heads=num_heads,\n                                                      num_buckets=num_buckets,\n                                                      max_distance=max_distance)  \n\n        self.time_rel_pos_bias_mem = RelativePositionBias(heads=num_heads,\n                                                      num_buckets=num_buckets,\n                                                      max_distance=max_distance)  \n\n        temporal_attn = lambda dim: Attention(dim, heads=num_heads, dim_head=32,\n                                                              rotary_emb=rotary_emb)\n\n        self.init_temporal_attn = Residual(PreNorm(self.pose_latent_dim, temporal_attn(self.pose_latent_dim)))\n        # self.sequence_pos_encoder = TimeEncoding(self.dropout) #time_encoding\n        \n        # seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.pose_latent_dim,\n        #                                                   nhead=self.num_heads,\n        #                                                   dim_feedforward=self.ff_size,\n        #                                                   dropout=self.dropout,\n        #                                                   activation=activation)\n        self.seqTransDecoder = TransformerDecoder(d_model = self.pose_latent_dim,\n                                                  num_heads = self.num_heads,\n                                                  dim_feedforward=self.ff_size,\n                                                  dropout=self.dropout,\n                                                  num_layers=self.num_layers)\n        \n        self.finallayer = nn.Linear(self.pose_latent_dim, self.pos_dim+self.eye_dim)\n        \n        self.q_dropout = nn.Dropout(self.dropout)\n    def forward(self, batch):\n        '''\n            z: bs, audio_latent_dim(256)\n            y: bs, num_frames, 1024\n            mask: bs, num_frames\n            lengths: [num_frames,...]\n        '''\n        x, z, y, mask, lengths = batch[\"x\"], batch[\"z\"], batch[\"y\"], batch[\"mask\"], batch[\"lengths\"]\n        bs, nframes = mask.shape\n        # first img\n        x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64\n        x_ref = self.firstposeEmbedding(x_ref.repeat(1, nframes, 1)) #bs, nf, 64\n        y = self.audioEmbedding(y) #bs, num_frames, 
256\n        z = z.permute(1, 0, 2)\n        #z = z.unsqueeze(dim=1).repeat(1, nframes, 1) #bs, num_frames, 256\n        z = torch.cat([x_ref, z, y], dim=-1) # bs, num_frames, 256*2+64\n        z = self.ztimelinear(z)\n        # z = z.permute((1, 0, 2)) # nf, bs, 64\n        pose_latent_dim = z.shape[2]\n        # z = z[None]  # sequence of size 1\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"zandtime\":\n        #     yoh = F.one_hot(y, self.num_classes)\n        #     z = torch.cat((z, yoh), axis=1)\n        #     z = self.ztimelinear(z)\n        #     z = z[None]  # sequence of size 1\n        # else:\n        #     # only for ablation / not used in the final model\n        #     if self.ablation == \"concat_bias\":\n        #         # sequence of size 2\n        #         z = torch.stack((z, self.actionBiases[y]), axis=0)\n        #     else:\n        #         # shift the latent noise vector to be the action noise\n        #         z = z + self.actionBiases[y.long()] # NEED CHECK\n        #         z = z[None]  # sequence of size 1\n            \n        timequeries = torch.zeros(bs, nframes, pose_latent_dim, device=z.device) # len, b, c\n        # timequeries = self.sequence_pos_encoder(timequeries) #time_encoding\n\n        time_rel_pos_bias_tgt = self.time_rel_pos_bias_tgt(nframes, device=x.device)\n        time_rel_pos_bias_mem = self.time_rel_pos_bias_mem(nframes, device=x.device)\n\n        timequeries = self.init_proj(timequeries)      \n\n        \n\n        timequeries = self.init_temporal_attn(timequeries, pos_bias=time_rel_pos_bias_tgt.repeat(bs, 1, 1, 1))\n        \n        # # only for ablation / not used in the final model\n        # if self.ablation == \"time_encoding\":\n        #     timequeries = self.sequence_pos_encoder(timequeries, mask, lengths)\n        # else:\n        #     timequeries = self.sequence_pos_encoder(timequeries)\n        \n        # num_frames, bs, 64\n        output = self.seqTransDecoder(tgt=timequeries, memory=z, tgt_mask=time_rel_pos_bias_tgt.repeat(bs, 1, 1, 1), memory_mask = time_rel_pos_bias_mem.repeat(bs, 1, 1, 1),\n                                      )\n        \n        output = self.finallayer(output) # .reshape(nframes, bs, self.pos_dim) # num_frames, bs, 6\n        # output = self.finallayer(output).reshape(nframes, bs, njoints, nfeats)\n        \n        # zero for padded area\n        output[~mask] = 0 #nf, bs, 6\n        batch[\"out_pose\"] = output[:,:,:6] # .permute(1,0,2)#bs, nf, 6\n        batch[\"out_eye\"] = output[:,:,6:] # .permute(1,0,2)#bs, nf, 6\n        batch[\"output\"] = output\n        return batch\n"
  },
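  {
    "path": "PBnet/examples/relative_position_bucket_sketch.py",
    "content": "# Hypothetical standalone sketch, not part of the original repo: reproduces the\n# T5-style relative position bucketing used by RelativePositionBias in the\n# transformerreemb* architectures, to show how signed frame offsets map to the\n# bucket ids that index a learned per-head bias embedding.\nimport math\nimport torch\n\ndef relative_position_bucket(relative_position, num_buckets=32, max_distance=128):\n    ret = 0\n    n = -relative_position\n    num_buckets //= 2\n    ret += (n < 0).long() * num_buckets        # one half of the buckets per sign\n    n = torch.abs(n)\n    max_exact = num_buckets // 2\n    is_small = n < max_exact                   # small offsets get exact buckets\n    val_if_large = max_exact + (\n        torch.log(n.float() / max_exact) / math.log(max_distance / max_exact)\n        * (num_buckets - max_exact)\n    ).long()                                   # large offsets share log-spaced buckets\n    val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n    return ret + torch.where(is_small, n, val_if_large)\n\nn = 6                                          # number of frames\npos = torch.arange(n)\nrel_pos = pos[None, :] - pos[:, None]          # (n, n) signed offsets, as in forward()\nprint(relative_position_bucket(rel_pos))\n"
  },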
  {
    "path": "PBnet/src/models/architectures/transformerreemb6.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat, reduce, pack, unpack\nfrom einops_exts import rearrange_many\nfrom torch import einsum\nfrom rotary_embedding_torch import RotaryEmbedding\nimport math \nfrom src.models.architectures.transformerdecoder5 import *\n\n\ndef exists(x):\n    return x is not None\n\nclass LayerNorm(nn.Module):\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.eps = eps\n        self.gamma = nn.Parameter(torch.ones(1, 1,  dim))\n\n    def forward(self, x):\n        var = torch.var(x, dim=-1, unbiased=False, keepdim=True)\n        mean = torch.mean(x, dim=-1, keepdim=True)\n        return (x - mean) / (var + self.eps).sqrt() * self.gamma\n\nclass PreNorm(nn.Module):\n    def __init__(self, dim, fn):\n        super().__init__()\n        self.fn = fn\n        self.norm = LayerNorm(dim)\n\n    def forward(self, x, **kwargs):\n        x = self.norm(x)\n        return self.fn(x, **kwargs)\n\nclass Residual(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, x, *args, **kwargs):\n        return self.fn(x, *args, **kwargs) + x\n\nclass EinopsToAndFrom(nn.Module):\n    def __init__(self, from_einops, to_einops, fn):\n        super().__init__()\n        self.from_einops = from_einops\n        self.to_einops = to_einops\n        self.fn = fn\n\n    def forward(self, x, **kwargs):\n        shape = x.shape\n        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))\n        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')\n        x = self.fn(x, **kwargs)\n        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)\n        return x\n\n\nclass PositionalEncoding(nn.Module):\n    def __init__(self, d_model, dropout=0.1, max_len=20000):\n        super(PositionalEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n        pe = torch.zeros(max_len, d_model)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0).transpose(0, 1)\n        \n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        # not used in the final model\n        x = x + self.pe[:x.shape[0], :]\n        return self.dropout(x)\n\n\nclass RelativePositionBias(nn.Module):\n    def __init__(\n            self,\n            heads=8,\n            num_buckets=32,\n            max_distance=128\n    ):\n        super().__init__()\n        self.num_buckets = num_buckets\n        self.max_distance = max_distance\n        self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n\n    @staticmethod\n    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):\n        ret = 0\n        n = -relative_position\n\n        num_buckets //= 2\n        ret += (n < 0).long() * num_buckets\n        n = torch.abs(n)\n\n        max_exact = num_buckets // 2\n        is_small = n < max_exact\n\n        val_if_large = max_exact + (\n                torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)\n        ).long()\n        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 
1))\n\n        ret += torch.where(is_small, n, val_if_large)\n        return ret\n\n    def forward(self, n, device, eval = False):\n        q_pos = torch.arange(n, dtype=torch.long, device=device)\n        k_pos = torch.arange(n, dtype=torch.long, device=device)\n        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,\n                                                   max_distance=self.max_distance)\n        if not self.relative_attention_bias.training:\n            print('eval!')\n            mask = - (((rel_pos > 100) + (rel_pos  < -100)) * (1e8))\n            values = self.relative_attention_bias(rp_bucket)\n            return rearrange(values, 'i j h -> h i j') + mask\n        else:\n            # values = self.relative_attention_bias(rp_bucket)\n            # return rearrange(values, 'i j h -> h i j')\n            # mask = - (((rel_pos > 100) + (rel_pos  < -100)) * (1e8))\n            values = self.relative_attention_bias(rp_bucket)\n            return rearrange(values, 'i j h -> h i j') # + mask\n        \n# only for ablation / not used in the final model\nclass TimeEncoding(nn.Module):\n    def __init__(self, d_model, dropout=0.1, max_len=5000):\n        super(TimeEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, x, mask, lengths):\n        time = mask * 1/(lengths[..., None]-1)\n        time = time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :]\n        time = time[:, 0].T\n        # add the time encoding\n        x = x + time[..., None]\n        return self.dropout(x)\n    \n\nclass Encoder_TRANSFORMERREEMB6(nn.Module):\n    def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, eye_dim=2, pose_latent_dim=64,\n                 audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1,\n                 ablation=None, activation=\"gelu\", **kargs):\n        super().__init__()\n        \n        self.modeltype = modeltype\n        self.pos_dim = pos_dim\n        self.eye_dim = 0\n        self.num_frames = num_frames\n        self.audio_dim = audio_dim\n        \n        self.pose_latent_dim = pose_latent_dim\n        self.audio_latent_dim = audio_latent_dim\n        self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2\n        \n        self.ff_size = ff_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.dropout = dropout\n\n        self.ablation = ablation\n        self.activation = activation\n\n        \n        # if self.ablation == \"average_encoder\":\n        #     self.mu_layer = nn.Linear(self.latent_dim, self.latent_dim)\n        #     self.sigma_layer = nn.Linear(self.latent_dim, self.latent_dim)\n        # else:\n        #     self.muQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n        #     self.sigmaQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n        \n        # # there's no class of our dataset CREMA/HDTF, so noly  dont need to use nn.parameter\n        self.mu_layer = nn.Linear(self.latent_dim, self.audio_latent_dim)\n        self.sigma_layer = nn.Linear(self.latent_dim, self.audio_latent_dim)\n        \n        self.poseEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64\n        self.firstposeEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64\n        self.audioEmbedding = nn.Linear(self.audio_dim, 
self.audio_latent_dim) #1024, 256\n        \n        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)\n        \n        # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))\n        \n        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,\n                                                          nhead=self.num_heads,\n                                                          dim_feedforward=self.ff_size,\n                                                          dropout=self.dropout,\n                                                          activation=self.activation)\n        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,\n                                                     num_layers=self.num_layers)\n\n    def forward(self, batch):\n        '''\n            x: 6-dim pos, (bs, max_num_frames, 6)\n            y: 1024-dim audio embbeding, (bs, max_num_frames, 1024)\n        '''\n\n        x, y, mask = batch[\"x\"], batch[\"y\"], batch[\"mask\"]\n        # bs, njoints, nfeats, nframes = x.shape\n        # x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats) \n        x_ref = x[:,0,:].unsqueeze(dim=1) # The pose information of the first frame(refrence img)\n        x = x-x_ref.repeat(1,x.size(1),1) # bs, nf, 6  Obtain the difference from the first frame\n        batch['x_delta'] = x\n        x_ref = x_ref.permute((1,0,2)) #1, bs, 6\n        x = x.permute((1, 0, 2)) #nf, bs, 6\n        y = y.permute((1, 0, 2)) #nf, bs, 1024\n        # embedding of the pose/audio\n        x_ref = self.firstposeEmbedding(x_ref).repeat(x.size(0),1,1) #nf, bs, 64\n        x = self.poseEmbedding(x) #nf, bs, 64\n        y = self.audioEmbedding(y) #nf, bs, 256\n        x = torch.cat([x_ref, x, y],dim=-1) # nf, bs, 64+64+256\n\n        # only use the \"average_encoder\" mode\n        # add positional encoding\n        x = self.sequence_pos_encoder(x)\n        # transformer layers\n        final = self.seqTransEncoder(x, src_key_padding_mask=~mask) #nu_frames, bs, 64+64+256\n        # get the average of the output\n        z = final# final.mean(axis=0) # nf, bs, 64+64+256\n        # extract mu and logvar\n        mu = self.mu_layer(z) # nf, bs, 256\n        logvar = self.sigma_layer(z) # nf, bs, 256\n        # logvar = - torch.ones_like(logvar) * 1e10\n\n        return {\"mu\": mu, \"logvar\": logvar}\n\n\nclass Decoder_TRANSFORMERREEMB6(nn.Module):\n    def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, eye_dim=2, pose_latent_dim=64,\n                 audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, activation=\"gelu\",\n                 ablation=None, num_buckets = 32, max_distance = 32,**kargs):\n        super().__init__()\n\n        self.modeltype = modeltype\n\n        self.pos_dim = pos_dim\n        self.eye_dim = 0\n        self.num_frames = num_frames\n        self.audio_dim = audio_dim\n        \n        self.pose_latent_dim = pose_latent_dim\n        self.audio_latent_dim = audio_latent_dim\n        self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2\n        \n        self.ff_size = ff_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.dropout = dropout\n\n        self.ablation = ablation\n\n        self.activation = activation\n\n        self.firstposeEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64\n        self.audioEmbedding = nn.Linear(self.audio_dim, 
self.audio_latent_dim) #1024, 256\n        self.ztimelinear = nn.Linear(self.audio_latent_dim*2+self.pose_latent_dim, self.pose_latent_dim) #256*2+64,64\n        \n        self.init_proj = nn.Linear(self.pose_latent_dim, self.pose_latent_dim)\n        # self.input_feats = self.njoints*self.nfeats\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"zandtime\":\n        #     self.ztimelinear = nn.Linear(self.latent_dim + self.num_classes, self.latent_dim)\n        # else:\n        #     self.actionBiases = nn.Parameter(torch.randn(1024, self.latent_dim))\n            # self.actionBiases = nn.Parameter(torch.randn(self.num_classes, self.latent_dim))\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"time_encoding\":\n        #     self.sequence_pos_encoder = TimeEncoding(self.dropout)\n        # else:\n        #     self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)\n\n        self.sequence_pos_encoder = PositionalEncoding(self.pose_latent_dim, self.dropout)\n        rotary_emb = RotaryEmbedding(min(32, num_heads))\n\n        self.time_rel_pos_bias_tgt = RelativePositionBias(heads=num_heads,\n                                                      num_buckets=num_buckets,\n                                                      max_distance=max_distance)  \n\n        self.time_rel_pos_bias_mem = RelativePositionBias(heads=num_heads,\n                                                      num_buckets=num_buckets,\n                                                      max_distance=max_distance)  \n\n        temporal_attn = lambda dim: Attention(dim, heads=num_heads, dim_head=32,\n                                                              rotary_emb=rotary_emb)\n\n        self.init_temporal_attn = Residual(PreNorm(self.pose_latent_dim, temporal_attn(self.pose_latent_dim)))\n        # self.sequence_pos_encoder = TimeEncoding(self.dropout) #time_encoding\n        \n        # seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.pose_latent_dim,\n        #                                                   nhead=self.num_heads,\n        #                                                   dim_feedforward=self.ff_size,\n        #                                                   dropout=self.dropout,\n        #                                                   activation=activation)\n        self.seqTransDecoder = TransformerDecoder(d_model = self.pose_latent_dim,\n                                                  num_heads = self.num_heads,\n                                                  dim_feedforward=self.ff_size,\n                                                  dropout=self.dropout,\n                                                  num_layers=self.num_layers)\n        \n        self.finallayer = nn.Linear(self.pose_latent_dim, self.pos_dim+self.eye_dim)\n        \n    def forward(self, batch):\n        '''\n            z: bs, audio_latent_dim(256)\n            y: bs, num_frames, 1024\n            mask: bs, num_frames\n            lengths: [num_frames,...]\n        '''\n        x, z, y, mask, lengths = batch[\"x\"], batch[\"z\"], batch[\"y\"], batch[\"mask\"], batch[\"lengths\"]\n        bs, nframes = mask.shape\n        # first img\n        x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64\n        x_ref = self.firstposeEmbedding(x_ref.repeat(1, nframes, 1)) #bs, nf, 64\n        y = self.audioEmbedding(y) #bs, num_frames, 256\n        z = z.permute(1, 0, 2)\n        #z = 
z.unsqueeze(dim=1).repeat(1, nframes, 1) #bs, num_frames, 256\n        z = torch.cat([x_ref, z, y], dim=-1) # bs, num_frames, 256*2+64\n        z = self.ztimelinear(z)\n        # z = z.permute((1, 0, 2)) # nf, bs, 64\n        pose_latent_dim = z.shape[2]\n        # z = z[None]  # sequence of size 1\n\n        # # only for ablation / not used in the final model\n        # if self.ablation == \"zandtime\":\n        #     yoh = F.one_hot(y, self.num_classes)\n        #     z = torch.cat((z, yoh), axis=1)\n        #     z = self.ztimelinear(z)\n        #     z = z[None]  # sequence of size 1\n        # else:\n        #     # only for ablation / not used in the final model\n        #     if self.ablation == \"concat_bias\":\n        #         # sequence of size 2\n        #         z = torch.stack((z, self.actionBiases[y]), axis=0)\n        #     else:\n        #         # shift the latent noise vector to be the action noise\n        #         z = z + self.actionBiases[y.long()] # NEED CHECK\n        #         z = z[None]  # sequence of size 1\n            \n        timequeries = torch.zeros(bs, nframes, pose_latent_dim, device=z.device) # len, b, c\n        # timequeries = self.sequence_pos_encoder(timequeries) #time_encoding\n\n        time_rel_pos_bias_tgt = self.time_rel_pos_bias_tgt(nframes, device=x.device)\n        time_rel_pos_bias_mem = self.time_rel_pos_bias_mem(nframes, device=x.device)\n\n        timequeries = self.init_proj(timequeries)      \n\n        \n\n        timequeries = self.init_temporal_attn(timequeries, pos_bias=time_rel_pos_bias_tgt.repeat(bs, 1, 1, 1))\n        \n        # # only for ablation / not used in the final model\n        # if self.ablation == \"time_encoding\":\n        #     timequeries = self.sequence_pos_encoder(timequeries, mask, lengths)\n        # else:\n        #     timequeries = self.sequence_pos_encoder(timequeries)\n        \n        # num_frames, bs, 64\n        output = self.seqTransDecoder(tgt=timequeries, memory=z, tgt_mask=time_rel_pos_bias_tgt.repeat(bs, 1, 1, 1), memory_mask = time_rel_pos_bias_mem.repeat(bs, 1, 1, 1),\n                                      )\n        \n        output = self.finallayer(output) # .reshape(nframes, bs, self.pos_dim) # num_frames, bs, 6\n        # output = self.finallayer(output).reshape(nframes, bs, njoints, nfeats)\n        \n        # zero for padded area\n        output[~mask] = 0 #nf, bs, 6\n        # batch[\"out_pose\"] = output[:,:,:6] # .permute(1,0,2)#bs, nf, 6\n        # batch[\"out_eye\"] = output[:,:,6:] # .permute(1,0,2)#bs, nf, 6\n        batch[\"output\"] = output\n        return batch\n"
  },
  {
    "path": "PBnet/src/models/architectures/transgru.py",
    "content": "from .transformer import Encoder_TRANSFORMER as Encoder_TRANSGRU  # noqa\nfrom .gru import Decoder_GRU as Decoder_TRANSGRU  # noqa\n\n\n"
  },
  {
    "path": "PBnet/src/models/get_model.py",
    "content": "import importlib\n\nimport sys\nimport os\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nparent_dir = os.path.dirname(os.path.dirname(current_dir))\nif parent_dir not in sys.path:\n    sys.path.append(parent_dir)\n    print(parent_dir)\n    \n# JOINTSTYPES = [\"a2m\", \"a2mpl\", \"smpl\", \"vibe\", \"vertices\"]\n\nLOSSES = [\"rc\", \"kl\", \"rcw\", \"ssim\", \"var\", 'reg']  # not used: \"hp\", \"mmd\", \"vel\", \"velxyz\"\n\nMODELTYPES = [\"cvae\"]  # not used: \"cae\"\nARCHINAMES = [\"fc\", \"gru\", \"transformer\",\"transformerreemb5\", \"transformerreemb6\", \"transformerreemb7\", \"transformerreemb8\",\"transformermel\", \"transgru\", \"grutrans\", \"autotrans\"]\n\n\ndef get_model(parameters):\n    modeltype = parameters[\"modeltype\"]\n    archiname = parameters[\"archiname\"]\n\n    archi_module = importlib.import_module(f'.architectures.{archiname}', package=\"src.models\")\n    Encoder = archi_module.__getattribute__(f\"Encoder_{archiname.upper()}\")\n    Decoder = archi_module.__getattribute__(f\"Decoder_{archiname.upper()}\")\n\n    model_module = importlib.import_module(f'.modeltype.{modeltype}', package=\"src.models\")\n    Model = model_module.__getattribute__(f\"{modeltype.upper()}\")\n\n    encoder = Encoder(**parameters)\n    decoder = Decoder(**parameters)\n    \n    # parameters[\"outputxyz\"] = \"rcxyz\" in parameters[\"lambdas\"]\n    return Model(encoder, decoder, **parameters).to(parameters[\"device\"])\n"
  },
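  {
    "path": "PBnet/examples/get_model_usage_sketch.py",
    "content": "# Hypothetical usage sketch, not part of the original repo: shows how get_model()\n# resolves Encoder_<ARCHINAME>/Decoder_<ARCHINAME> and the model type via importlib.\n# The parameter values below are illustrative defaults taken from the architecture\n# constructors; the real values come from the PBnet training configuration.\n# Assumes the PBnet dependencies (torch, einops, rotary_embedding_torch, ...) are\n# installed and the script is run from PBnet/ so that the 'src' package is importable.\nfrom src.models.get_model import get_model\n\nparameters = {\n    'modeltype': 'cvae',                # -> src.models.modeltype.cvae.CVAE\n    'archiname': 'transformerreemb5',   # -> Encoder_/Decoder_TRANSFORMERREEMB5\n    'num_frames': 80,\n    'audio_dim': 1024,\n    'pos_dim': 6,\n    'eye_dim': 2,\n    'pose_latent_dim': 64,\n    'audio_latent_dim': 256,\n    'latent_dim': 256,                  # dimensionality of the sampled z\n    'lambdas': {'rc': 1.0, 'kl': 1e-5}, # illustrative loss weights\n    'device': 'cpu',\n}\n\nmodel = get_model(parameters)\nprint(type(model).__name__, type(model.encoder).__name__, type(model.decoder).__name__)\n"
  },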
  {
    "path": "PBnet/src/models/modeltype/__init__.py",
    "content": ""
  },
  {
    "path": "PBnet/src/models/modeltype/cae.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom ..tools.losses import get_loss_function\nimport torch.nn.functional as F\n# from ..rotation2xyz import Rotation2xyz\n\n\n\nclass CAE(nn.Module):\n    def __init__(self, encoder, decoder, device, lambdas, latent_dim, **kwargs):\n        super().__init__()\n\n        self.encoder = encoder\n        self.decoder = decoder\n\n        # self.outputxyz = outputxyz\n        \n        self.lambdas = lambdas\n        \n        self.latent_dim = latent_dim\n        # self.pose_rep = pose_rep\n        # self.glob = glob\n        # self.glob_rot = glob_rot\n        self.device = device\n        # self.translation = translation\n        # self.jointstype = jointstype\n        # self.vertstrans = vertstrans\n        \n        self.losses = list(self.lambdas) + [\"mixed\"]\n\n\n        # self.rotation2xyz = Rotation2xyz(device=self.device)\n        # self.param2xyz = {\"pose_rep\": self.pose_rep,\n        #                   \"glob_rot\": self.glob_rot,\n        #                   \"glob\": self.glob,\n        #                   \"jointstype\": self.jointstype,\n        #                   \"translation\": self.translation,\n        #                   \"vertstrans\": self.vertstrans}\n        \n    # def rot2xyz(self, x, mask, **kwargs):\n    #     kargs = self.param2xyz.copy()\n    #     kargs.update(kwargs)\n    #     return self.rotation2xyz(x, mask, **kargs)\n    \n    def forward(self, batch):\n        # if self.outputxyz:\n        #     batch[\"x_xyz\"] = self.rot2xyz(batch[\"x\"], batch[\"mask\"])\n        # elif self.pose_rep == \"xyz\":\n        #     batch[\"x_xyz\"] = batch[\"x\"]\n        \n        # encode\n        batch.update(self.encoder(batch))\n        # decode\n        batch.update(self.decoder(batch))\n\n        # # if we want to output xyz\n        # if self.outputxyz:\n        #     batch[\"output_xyz\"] = self.rot2xyz(batch[\"output\"], batch[\"mask\"])\n        # elif self.pose_rep == \"xyz\":\n        #     batch[\"output_xyz\"] = batch[\"output\"]\n        return batch\n\n    \n\n    def compute_loss(self, batch, epoch = 0):\n        mixed_loss = 0\n        losses = {}\n        for ltype, lam in self.lambdas.items():\n            loss_function = get_loss_function(ltype)\n            loss = loss_function(self, batch)\n            if 'kl' in ltype:\n                if epoch < 1e4 and epoch != 0: \n                    lam = 0\n                elif epoch != 0:\n                    lam = lam * max(epoch - 1e4, 7e4) / 7e4\n            mixed_loss += loss*lam\n            losses[ltype] = loss.item()\n        \n        # D_loss, G_loss = self.calculate_GAN_loss(batch)\n        # mixed_loss += G_loss * 0.7\n        # losses['GAN_D'] = D_loss\n        # losses['GAN_G'] = D_loss\n        losses[\"mixed\"] = mixed_loss.item()\n        return mixed_loss, losses\n\n    @staticmethod\n    def lengths_to_mask(lengths):\n        max_len = max(lengths)\n        if isinstance(max_len, torch.Tensor):\n            max_len = max_len.item()\n        index = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len)\n        mask = index < lengths.unsqueeze(1)\n        return mask\n\n    def generate_one(self, cls, duration, fact=1, xyz=False):\n        y = torch.tensor([cls], dtype=int, device=self.device)[None]\n        lengths = torch.tensor([duration], dtype=int, device=self.device)\n        mask = self.lengths_to_mask(lengths)\n        z = torch.randn(self.latent_dim, device=self.device)[None]\n        \n        batch = 
{\"z\": fact*z, \"y\": y, \"mask\": mask, \"lengths\": lengths}\n        batch = self.decoder(batch)\n\n        if not xyz:\n            return batch[\"output\"][0]\n        \n        output_xyz = self.rot2xyz(batch[\"output\"], batch[\"mask\"])\n\n        return output_xyz[0]\n            \n    def generate(self, pose, audio, durations,\n                 noise_same_action=\"random\", noise_diff_action=\"random\",\n                 fact=1):\n        '''\n            audio: hubert embeddbing, (bs, fn, 1024)\n            durations: different num_frames, (bs, )\n        '''\n        # if nspa is None:\n        #     nspa = 1\n        bs = len(audio)\n            \n        # y = audio.to(self.device).repeat(nspa)  # (view(nspa, nats))\n        x = pose.to(self.device)\n        y = audio.to(self.device) \n\n        if len(durations.shape) == 1:\n            lengths = durations.to(self.device)\n        else:\n            lengths = durations.to(self.device).reshape(y.shape)\n        \n        mask = self.lengths_to_mask(lengths)\n        z = torch.randn(audio[0].shape[0], bs, self.latent_dim, device=self.device)\n        # z = torch.randn(1, bs, self.latent_dim, device=self.device).repeat(audio[0].shape[0], 1, 1)\n        \n        # if noise_same_action == \"random\":\n        #     if noise_diff_action == \"random\":\n        #         z = torch.randn(nspa*bs, self.latent_dim, device=self.device)\n        #     elif noise_diff_action == \"same\":\n        #         z_same_action = torch.randn(nspa, self.latent_dim, device=self.device)\n        #         z = z_same_action.repeat_interleave(bs, axis=0)\n        #     else:\n        #         raise NotImplementedError(\"Noise diff action must be random or same.\")\n        # elif noise_same_action == \"interpolate\":\n        #     if noise_diff_action == \"random\":\n        #         z_diff_action = torch.randn(bs, self.latent_dim, device=self.device)\n        #     elif noise_diff_action == \"same\":\n        #         z_diff_action = torch.randn(1, self.latent_dim, device=self.device).repeat(bs, 1)\n        #     else:\n        #         raise NotImplementedError(\"Noise diff action must be random or same.\")\n        #     interpolation_factors = torch.linspace(-1, 1, nspa, device=self.device)\n        #     z = torch.einsum(\"ij,k->kij\", z_diff_action, interpolation_factors).view(nspa*bs, -1)\n        # elif noise_same_action == \"same\":\n        #     if noise_diff_action == \"random\":\n        #         z_diff_action = torch.randn(bs, self.latent_dim, device=self.device)\n        #     elif noise_diff_action == \"same\":\n        #         z_diff_action = torch.randn(1, self.latent_dim, device=self.device).repeat(bs, 1)\n        #     else:\n        #         raise NotImplementedError(\"Noise diff action must be random or same.\")\n        #     z = z_diff_action.repeat((nspa, 1))\n        # else:\n        #     raise NotImplementedError(\"Noise same action must be random, same or interpolate.\")\n\n        batch = {\"x\": x,\"z\": fact*z, \"y\": y, \"mask\": mask, \"lengths\": lengths}\n        batch = self.decoder(batch)\n        \n        # if self.outputxyz:\n        #     batch[\"output_xyz\"] = self.rot2xyz(batch[\"output\"], batch[\"mask\"])\n        # elif self.pose_rep == \"xyz\":\n        #     batch[\"output_xyz\"] = batch[\"output\"]\n        \n        return batch\n    \n    def return_latent(self, batch, seed=None):\n        return self.encoder(batch)[\"z\"]\n"
  },
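  {
    "path": "PBnet/examples/lengths_to_mask_sketch.py",
    "content": "# Hypothetical standalone sketch, not part of the original repo: illustrates the\n# CAE.lengths_to_mask() helper above, which turns per-sample frame counts into the\n# boolean padding mask consumed by the encoder and decoder.\nimport torch\n\ndef lengths_to_mask(lengths):\n    max_len = max(lengths)\n    if isinstance(max_len, torch.Tensor):\n        max_len = max_len.item()\n    index = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len)\n    return index < lengths.unsqueeze(1)\n\nlengths = torch.tensor([3, 5])\nprint(lengths_to_mask(lengths))\n# tensor([[ True,  True,  True, False, False],\n#         [ True,  True,  True,  True,  True]])\n"
  },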
  {
    "path": "PBnet/src/models/modeltype/cae_0.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom ..tools.losses import get_loss_function\n# from ..rotation2xyz import Rotation2xyz\n\n\nclass CAE(nn.Module):\n    def __init__(self, encoder, decoder, device, lambdas, latent_dim, **kwargs):\n        super().__init__()\n\n        self.encoder = encoder\n        self.decoder = decoder\n\n        # self.outputxyz = outputxyz\n        \n        self.lambdas = lambdas\n        \n        self.latent_dim = latent_dim\n        # self.pose_rep = pose_rep\n        # self.glob = glob\n        # self.glob_rot = glob_rot\n        self.device = device\n        # self.translation = translation\n        # self.jointstype = jointstype\n        # self.vertstrans = vertstrans\n        \n        self.losses = list(self.lambdas) + [\"mixed\"]\n\n        # self.rotation2xyz = Rotation2xyz(device=self.device)\n        # self.param2xyz = {\"pose_rep\": self.pose_rep,\n        #                   \"glob_rot\": self.glob_rot,\n        #                   \"glob\": self.glob,\n        #                   \"jointstype\": self.jointstype,\n        #                   \"translation\": self.translation,\n        #                   \"vertstrans\": self.vertstrans}\n        \n    # def rot2xyz(self, x, mask, **kwargs):\n    #     kargs = self.param2xyz.copy()\n    #     kargs.update(kwargs)\n    #     return self.rotation2xyz(x, mask, **kargs)\n    \n    def forward(self, batch):\n        # if self.outputxyz:\n        #     batch[\"x_xyz\"] = self.rot2xyz(batch[\"x\"], batch[\"mask\"])\n        # elif self.pose_rep == \"xyz\":\n        #     batch[\"x_xyz\"] = batch[\"x\"]\n        \n        # encode\n        batch.update(self.encoder(batch))\n        # decode\n        batch.update(self.decoder(batch))\n\n        # # if we want to output xyz\n        # if self.outputxyz:\n        #     batch[\"output_xyz\"] = self.rot2xyz(batch[\"output\"], batch[\"mask\"])\n        # elif self.pose_rep == \"xyz\":\n        #     batch[\"output_xyz\"] = batch[\"output\"]\n        return batch\n\n    def compute_loss(self, batch):\n        mixed_loss = 0\n        losses = {}\n        for ltype, lam in self.lambdas.items():\n            loss_function = get_loss_function(ltype)\n            loss = loss_function(self, batch)\n            mixed_loss += loss*lam\n            losses[ltype] = loss.item()\n        losses[\"mixed\"] = mixed_loss.item()\n        return mixed_loss, losses\n\n    @staticmethod\n    def lengths_to_mask(lengths):\n        max_len = max(lengths)\n        if isinstance(max_len, torch.Tensor):\n            max_len = max_len.item()\n        index = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len)\n        mask = index < lengths.unsqueeze(1)\n        return mask\n\n    def generate_one(self, cls, duration, fact=1, xyz=False):\n        y = torch.tensor([cls], dtype=int, device=self.device)[None]\n        lengths = torch.tensor([duration], dtype=int, device=self.device)\n        mask = self.lengths_to_mask(lengths)\n        z = torch.randn(self.latent_dim, device=self.device)[None]\n        \n        batch = {\"z\": fact*z, \"y\": y, \"mask\": mask, \"lengths\": lengths}\n        batch = self.decoder(batch)\n\n        if not xyz:\n            return batch[\"output\"][0]\n        \n        output_xyz = self.rot2xyz(batch[\"output\"], batch[\"mask\"])\n\n        return output_xyz[0]\n            \n    def generate(self, pose, audio, durations,\n                 noise_same_action=\"random\", noise_diff_action=\"random\",\n                 
fact=1):\n        '''\n            audio: hubert embeddbing, (bs, fn, 1024)\n            durations: different num_frames, (bs, )\n        '''\n        # if nspa is None:\n        #     nspa = 1\n        bs = len(audio)\n            \n        # y = audio.to(self.device).repeat(nspa)  # (view(nspa, nats))\n        x = pose.to(self.device)\n        y = audio.to(self.device) \n\n        if len(durations.shape) == 1:\n            lengths = durations.to(self.device)\n        else:\n            lengths = durations.to(self.device).reshape(y.shape)\n        \n        mask = self.lengths_to_mask(lengths)\n        # z = torch.randn(bs, self.latent_dim, device=self.device)\n        z = torch.randn(audio[0].shape[0], bs, self.latent_dim, device=self.device)\n        \n        # if noise_same_action == \"random\":\n        #     if noise_diff_action == \"random\":\n        #         z = torch.randn(nspa*bs, self.latent_dim, device=self.device)\n        #     elif noise_diff_action == \"same\":\n        #         z_same_action = torch.randn(nspa, self.latent_dim, device=self.device)\n        #         z = z_same_action.repeat_interleave(bs, axis=0)\n        #     else:\n        #         raise NotImplementedError(\"Noise diff action must be random or same.\")\n        # elif noise_same_action == \"interpolate\":\n        #     if noise_diff_action == \"random\":\n        #         z_diff_action = torch.randn(bs, self.latent_dim, device=self.device)\n        #     elif noise_diff_action == \"same\":\n        #         z_diff_action = torch.randn(1, self.latent_dim, device=self.device).repeat(bs, 1)\n        #     else:\n        #         raise NotImplementedError(\"Noise diff action must be random or same.\")\n        #     interpolation_factors = torch.linspace(-1, 1, nspa, device=self.device)\n        #     z = torch.einsum(\"ij,k->kij\", z_diff_action, interpolation_factors).view(nspa*bs, -1)\n        # elif noise_same_action == \"same\":\n        #     if noise_diff_action == \"random\":\n        #         z_diff_action = torch.randn(bs, self.latent_dim, device=self.device)\n        #     elif noise_diff_action == \"same\":\n        #         z_diff_action = torch.randn(1, self.latent_dim, device=self.device).repeat(bs, 1)\n        #     else:\n        #         raise NotImplementedError(\"Noise diff action must be random or same.\")\n        #     z = z_diff_action.repeat((nspa, 1))\n        # else:\n        #     raise NotImplementedError(\"Noise same action must be random, same or interpolate.\")\n\n        batch = {\"x\": x,\"z\": fact*z, \"y\": y, \"mask\": mask, \"lengths\": lengths}\n        batch = self.decoder(batch)\n        \n        # if self.outputxyz:\n        #     batch[\"output_xyz\"] = self.rot2xyz(batch[\"output\"], batch[\"mask\"])\n        # elif self.pose_rep == \"xyz\":\n        #     batch[\"output_xyz\"] = batch[\"output\"]\n        \n        return batch\n    \n    def return_latent(self, batch, seed=None):\n        return self.encoder(batch)[\"z\"]\n"
  },
  {
    "path": "PBnet/src/models/modeltype/cvae.py",
    "content": "import torch\nfrom .cae import CAE\n\n\nclass CVAE(CAE):\n    def reparameterize(self, batch, seed=None):\n        mu, logvar = batch[\"mu\"], batch[\"logvar\"]\n        std = torch.exp(logvar / 2)\n\n        if seed is None:\n            eps = std.data.new(std.size()).normal_()\n        else:\n            generator = torch.Generator(device=self.device)\n            generator.manual_seed(seed)\n            eps = std.data.new(std.size()).normal_(generator=generator)\n\n        z = eps.mul(std).add_(mu)\n        return z\n\n    def forward(self, batch):\n        \n        # if self.outputxyz:\n        #     batch[\"x_xyz\"] = self.rot2xyz(batch[\"x\"], batch[\"mask\"])\n        # elif self.pose_rep == \"xyz\":\n        #     batch[\"x_xyz\"] = batch[\"x\"]\n        # encode\n        batch.update(self.encoder(batch))\n        batch[\"z\"] = self.reparameterize(batch)\n        \n        # decode\n        batch.update(self.decoder(batch))\n        \n        # if we want to output xyz\n        # if self.outputxyz:\n        #     batch[\"output_xyz\"] = self.rot2xyz(batch[\"output\"], batch[\"mask\"])\n        # elif self.pose_rep == \"xyz\":\n        #     batch[\"output_xyz\"] = batch[\"output\"]\n        return batch\n\n    def return_latent(self, batch, seed=None):\n        distrib_param = self.encoder(batch)\n        batch.update(distrib_param)\n        return self.reparameterize(batch, seed=seed)\n"
  },
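  {
    "path": "PBnet/examples/reparameterize_sketch.py",
    "content": "# Hypothetical standalone sketch, not part of the original repo: the\n# reparameterization trick used by CVAE.reparameterize above, written out so the\n# sampling step stays differentiable with respect to both mu and logvar.\nimport torch\n\nnum_frames, bs, latent_dim = 4, 2, 256\nmu = torch.zeros(num_frames, bs, latent_dim, requires_grad=True)\nlogvar = torch.zeros(num_frames, bs, latent_dim, requires_grad=True)\n\nstd = torch.exp(logvar / 2)   # logvar -> standard deviation\neps = torch.randn_like(std)   # noise drawn independently of the parameters\nz = eps * std + mu            # same as eps.mul(std).add_(mu) in CVAE.reparameterize\n\n# gradients flow back into mu and logvar through the sampled z\nz.sum().backward()\nprint(mu.grad.shape, logvar.grad.shape)\n"
  },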
  {
    "path": "PBnet/src/models/modeltype/lstm.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom ..tools.losses import get_loss_function\nimport torch.nn.functional as F\n# from ..rotation2xyz import Rotation2xyz\nfrom sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\nfrom sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d\nfrom src.models.architectures.tools.resnet import resnet34\n\nclass MyResNet34(nn.Module):\n    def __init__(self,embedding_dim,input_channel = 3):\n        super(MyResNet34, self).__init__()\n        self.resnet = resnet34(norm_layer = BatchNorm2d,num_classes=embedding_dim,input_channel = input_channel)\n    def forward(self, x):\n        return self.resnet(x)\n\n\nclass LSTM(nn.Module):\n    def __init__(self, encoder, decoder, device, lambdas, latent_dim, **kwargs):\n        super(LSTM,self).__init__()\n\n        self.em_audio = MyResNet34(256, 1)\n        self.em_init_pose = nn.Linear(3,256)\n\n        self.lstm = nn.LSTM(512,256,num_layers=2,bias=True,batch_first=True)\n        self.output = nn.Linear(256,3)\n\n        self.lambdas = lambdas\n        self.losses = list(self.lambdas) + [\"mixed\"]\n        self.device = device\n\n    def compute_loss(self, batch):\n        mixed_loss = 0\n        losses = {}\n        for ltype, lam in self.lambdas.items():\n            loss_function = get_loss_function(ltype)\n            loss = loss_function(self, batch)\n            mixed_loss += loss*lam\n            losses[ltype] = loss.item()\n        \n        # D_loss, G_loss = self.calculate_GAN_loss(batch)\n        # mixed_loss += G_loss * 0.7\n        # losses['GAN_D'] = D_loss\n        # losses['GAN_G'] = D_loss\n        losses[\"mixed\"] = mixed_loss.item()\n        return mixed_loss, losses\n        \n    def forward(self,batch):\n        x, y, mask = batch[\"x\"], batch[\"y\"], batch[\"mask\"]\n        bs = x.shape[0]\n        x_ref = x[:,0,:].unsqueeze(dim=1) # The pose information of the first frame(refrence img)\n        x = x-x_ref.repeat(1,x.size(1),1) # bs, nf, 6  Obtain the difference from the first frame\n        batch['x_delta'] = x\n        ref_pose = self.em_init_pose(batch[\"x\"][:,0,:])\n        result = []\n        bs,seqlen,_,_ = batch[\"y\"].shape\n        zero_state = torch.zeros((2,bs,256),requires_grad=True).to(ref_pose.device)\n        cur_state = (zero_state,zero_state)\n        audio = batch[\"y\"].reshape(-1, 1, 4, 41)\n        audio_em = self.em_audio(audio).reshape(bs, seqlen, 256)\n        for i in range(seqlen):\n\n            ref_pose,cur_state = self.lstm(torch.cat((audio_em[:,i:i+1],ref_pose.unsqueeze(1)),dim=2),cur_state)\n            ref_pose = ref_pose.reshape(-1, 256)\n            result.append(self.output(ref_pose).unsqueeze(1))\n        res = torch.cat(result,dim=1)\n        batch['output'] = res\n        return batch\n\n    @staticmethod\n    def lengths_to_mask(lengths):\n        max_len = max(lengths)\n        if isinstance(max_len, torch.Tensor):\n            max_len = max_len.item()\n        index = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len)\n        mask = index < lengths.unsqueeze(1)\n        return mask\n\n    def generate(self, pose, audio, durations,\n                 noise_same_action=\"random\", noise_diff_action=\"random\",\n                 fact=1):\n        \n        x = pose.to(self.device)\n        y = audio.to(self.device) \n\n        if len(durations.shape) == 1:\n            lengths = durations.to(self.device)\n        else:\n            lengths = durations.to(self.device).reshape(y.shape)\n       
 \n        mask = self.lengths_to_mask(lengths)\n        batch = {\"x\": x, \"y\": y, \"mask\": mask, \"lengths\": lengths}\n        batch = self.forward(batch)\n        \n        return batch\n"
  },
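  {
    "path": "PBnet/examples/lstm_rollout_sketch.py",
    "content": "# Hypothetical standalone sketch, not part of the original repo: the autoregressive\n# rollout pattern used by the LSTM baseline above, where each step feeds the previous\n# hidden pose feature back in together with the audio feature of the current frame.\n# Dimensions are illustrative.\nimport torch\nimport torch.nn as nn\n\nbs, seqlen, audio_dim, pose_dim, hidden = 2, 5, 256, 3, 256\naudio_em = torch.randn(bs, seqlen, audio_dim)   # per-frame audio features\ninit_pose = torch.randn(bs, pose_dim)           # pose of the reference frame\n\nem_pose = nn.Linear(pose_dim, hidden)\nlstm = nn.LSTM(audio_dim + hidden, hidden, num_layers=2, batch_first=True)\nhead = nn.Linear(hidden, pose_dim)\n\nstate = None                                    # zero initial (h, c) state\npose_feat = em_pose(init_pose)                  # (bs, hidden)\noutputs = []\nfor i in range(seqlen):\n    step_in = torch.cat((audio_em[:, i:i+1], pose_feat.unsqueeze(1)), dim=2)\n    pose_feat, state = lstm(step_in, state)     # (bs, 1, hidden)\n    pose_feat = pose_feat.reshape(bs, hidden)\n    outputs.append(head(pose_feat).unsqueeze(1))\n\nresult = torch.cat(outputs, dim=1)              # (bs, seqlen, pose_dim)\nprint(result.shape)\n"
  },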
  {
    "path": "PBnet/src/models/rotation2xyz.py",
    "content": "import torch\nimport src.utils.rotation_conversions as geometry\n\nfrom .smpl import SMPL, JOINTSTYPE_ROOT\nfrom .get_model import JOINTSTYPES\n\n\nclass Rotation2xyz:\n    def __init__(self, device):\n        self.device = device\n        self.smpl_model = SMPL().eval().to(device)\n\n    def __call__(self, x, mask, pose_rep, translation, glob,\n                 jointstype, vertstrans, betas=None, beta=0,\n                 glob_rot=None, **kwargs):\n        if pose_rep == \"xyz\":\n            return x\n\n        if mask is None:\n            mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device)\n\n        if not glob and glob_rot is None:\n            raise TypeError(\"You must specify global rotation if glob is False\")\n\n        if jointstype not in JOINTSTYPES:\n            raise NotImplementedError(\"This jointstype is not implemented.\")\n\n        if translation:\n            x_translations = x[:, -1, :3]\n            x_rotations = x[:, :-1]\n        else:\n            x_rotations = x\n\n        x_rotations = x_rotations.permute(0, 3, 1, 2)\n        nsamples, time, njoints, feats = x_rotations.shape\n\n        # Compute rotations (convert only masked sequences output)\n        if pose_rep == \"rotvec\":\n            rotations = geometry.axis_angle_to_matrix(x_rotations[mask])\n        elif pose_rep == \"rotmat\":\n            rotations = x_rotations[mask].view(-1, njoints, 3, 3)\n        elif pose_rep == \"rotquat\":\n            rotations = geometry.quaternion_to_matrix(x_rotations[mask])\n        elif pose_rep == \"rot6d\":\n            rotations = geometry.rotation_6d_to_matrix(x_rotations[mask])\n        else:\n            raise NotImplementedError(\"No geometry for this one.\")\n\n        if not glob:\n            global_orient = torch.tensor(glob_rot, device=x.device)\n            global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3)\n            global_orient = global_orient.repeat(len(rotations), 1, 1, 1)\n        else:\n            global_orient = rotations[:, 0]\n            rotations = rotations[:, 1:]\n\n        if betas is None:\n            betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas],\n                                dtype=rotations.dtype, device=rotations.device)\n            betas[:, 1] = beta\n            # import ipdb; ipdb.set_trace()\n        out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas)\n\n        # get the desirable joints\n        joints = out[jointstype]\n\n        x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype)\n        x_xyz[~mask] = 0\n        x_xyz[mask] = joints\n\n        x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous()\n\n        # the first translation root at the origin on the prediction\n        if jointstype != \"vertices\":\n            rootindex = JOINTSTYPE_ROOT[jointstype]\n            x_xyz = x_xyz - x_xyz[:, [rootindex], :, :]\n\n        if translation and vertstrans:\n            # the first translation root at the origin\n            x_translations = x_translations - x_translations[:, :, [0]]\n\n            # add the translation to all the joints\n            x_xyz = x_xyz + x_translations[:, None, :, :]\n\n        return x_xyz\n"
  },
  {
    "path": "PBnet/src/models/smpl.py",
    "content": "import numpy as np\nimport torch\n\nimport contextlib\n\nfrom smplx import SMPLLayer as _SMPLLayer\nfrom smplx.lbs import vertices2joints\n\nfrom src.datasets.ntu13 import action2motion_joints\n\nfrom src.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA\n\nJOINTSTYPE_ROOT = {\"a2m\": 0, # action2motion\n                   \"smpl\": 0,\n                   \"a2mpl\": 0, # set(smpl, a2m)\n                   \"vibe\": 8}  # 0 is the 8 position: OP MidHip below\n\nJOINT_MAP = {\n    'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,\n    'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,\n    'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,\n    'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,\n    'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,\n    'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,\n    'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,\n    'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,\n    'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,\n    'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,\n    'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,\n    'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,\n    'Neck (LSP)': 47, 'Top of Head (LSP)': 48,\n    'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,\n    'Spine (H36M)': 51, 'Jaw (H36M)': 52,\n    'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,\n    'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27\n}\n\nJOINT_NAMES = [\n    'OP Nose', 'OP Neck', 'OP RShoulder',\n    'OP RElbow', 'OP RWrist', 'OP LShoulder',\n    'OP LElbow', 'OP LWrist', 'OP MidHip',\n    'OP RHip', 'OP RKnee', 'OP RAnkle',\n    'OP LHip', 'OP LKnee', 'OP LAnkle',\n    'OP REye', 'OP LEye', 'OP REar',\n    'OP LEar', 'OP LBigToe', 'OP LSmallToe',\n    'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel',\n    'Right Ankle', 'Right Knee', 'Right Hip',\n    'Left Hip', 'Left Knee', 'Left Ankle',\n    'Right Wrist', 'Right Elbow', 'Right Shoulder',\n    'Left Shoulder', 'Left Elbow', 'Left Wrist',\n    'Neck (LSP)', 'Top of Head (LSP)',\n    'Pelvis (MPII)', 'Thorax (MPII)',\n    'Spine (H36M)', 'Jaw (H36M)',\n    'Head (H36M)', 'Nose', 'Left Eye',\n    'Right Eye', 'Left Ear', 'Right Ear'\n]\n\n\n# adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints\nclass SMPL(_SMPLLayer):\n    \"\"\" Extension of the official SMPL implementation to support more joints \"\"\"\n\n    def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs):\n        kwargs[\"model_path\"] = model_path\n\n        # remove the verbosity for the 10-shapes beta parameters\n        with contextlib.redirect_stdout(None):\n            super(SMPL, self).__init__(**kwargs)\n            \n        J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA)\n        self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))\n        vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES])\n        a2m_indexes = vibe_indexes[action2motion_joints]\n        smpl_indexes = np.arange(24)\n        a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes])\n\n        self.maps = {\"vibe\": vibe_indexes,\n                     \"a2m\": a2m_indexes,\n                     \"smpl\": smpl_indexes,\n                     \"a2mpl\": a2mpl_indexes}\n        \n    def forward(self, *args, **kwargs):\n        smpl_output = super(SMPL, self).forward(*args, **kwargs)\n        \n        extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)\n        all_joints = 
torch.cat([smpl_output.joints, extra_joints], dim=1)\n\n        output = {\"vertices\": smpl_output.vertices}\n\n        for joinstype, indexes in self.maps.items():\n            output[joinstype] = all_joints[:, indexes]\n            \n        return output\n"
  },
  {
    "path": "PBnet/src/models/tools/__init__.py",
    "content": ""
  },
  {
    "path": "PBnet/src/models/tools/graphconv.py",
    "content": "import math\n\nimport torch\n\nfrom torch.nn.parameter import Parameter\nfrom torch.nn.modules.module import Module\n\n\nclass GraphConvolution(Module):\n    \"\"\"\n    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907\n    \"\"\"\n\n    def __init__(self, in_features, out_features, bias=True):\n        super(GraphConvolution, self).__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.weight = Parameter(torch.FloatTensor(in_features, out_features))\n        if bias:\n            self.bias = Parameter(torch.FloatTensor(out_features))\n        else:\n            self.register_parameter('bias', None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        stdv = 1. / math.sqrt(self.weight.size(1))\n        self.weight.data.uniform_(-stdv, stdv)\n        if self.bias is not None:\n            self.bias.data.uniform_(-stdv, stdv)\n\n    def forward(self, input, adj):\n        support = torch.mm(input, self.weight)\n        output = torch.spmm(adj, support)\n        if self.bias is not None:\n            return output + self.bias\n        else:\n            return output\n\n    def __repr__(self):\n        return self.__class__.__name__ + ' (' \\\n               + str(self.in_features) + ' -> ' \\\n               + str(self.out_features) + ')'\n"
  },
  {
    "path": "PBnet/src/models/tools/hessian_penalty.py",
    "content": "\"\"\"\n## Adapted to work with our \"batches\"\nOfficial PyTorch implementation of the Hessian Penalty regularization term from https://arxiv.org/pdf/2008.10599.pdf\nAuthor: Bill Peebles\nTensorFlow Implementation (GPU + Multi-Layer): hessian_penalty_tf.py\nSimple Pure NumPy Implementation: hessian_penalty_np.py\n\nSimple use case where you want to apply the Hessian Penalty to the output of net w.r.t. net_input:\n>>> from hessian_penalty_pytorch import hessian_penalty\n>>> net = MyNeuralNet()\n>>> net_input = sample_input()\n>>> loss = hessian_penalty(net, z=net_input)  # Compute hessian penalty of net's output w.r.t. net_input\n>>> loss.backward()  # Compute gradients w.r.t. net's parameters\n\nIf your network takes multiple inputs, simply supply them to hessian_penalty as you do in the net's forward pass. In the\nfollowing example, we assume BigGAN.forward takes a second input argument \"y\". Note that we always take the Hessian\nPenalty w.r.t. the z argument supplied to hessian_penalty:\n>>> from hessian_penalty_pytorch import hessian_penalty\n>>> net = BigGAN()\n>>> z_input = sample_z_vector()\n>>> class_label = sample_class_label()\n>>> loss = hessian_penalty(net, z=net_input, y=class_label)\n>>> loss.backward()\n\"\"\"\n\nimport torch\n\n\ndef hessian_penalty(G, batch, k=2, epsilon=0.1, reduction=torch.max, return_separately=False, G_z=None, **G_kwargs):\n    \"\"\"\n    Official PyTorch Hessian Penalty implementation.\n\n    Note: If you want to regularize multiple network activations simultaneously, you need to\n    make sure the function G you pass to hessian_penalty returns a list of those activations when it's called with\n    G(z, **G_kwargs). Otherwise, if G returns a tensor the Hessian Penalty will only be computed for the final\n    output of G.\n\n    :param G: Function that maps input z to either a tensor or a list of tensors (activations)\n    :param z: Input to G that the Hessian Penalty will be computed with respect to\n    :param k: Number of Hessian directions to sample (must be >= 2)\n    :param epsilon: Amount to blur G before estimating Hessian (must be > 0)\n    :param reduction: Many-to-one function to reduce each pixel/neuron's individual hessian penalty into a final loss\n    :param return_separately: If False, hessian penalties for each activation output by G are automatically summed into\n                              a final loss. If True, the hessian penalties for each layer will be returned in a list\n                              instead. If G outputs a single tensor, setting this to True will produce a length-1\n                              list.\n    :param G_z: [Optional small speed-up] If you have already computed G(z, **G_kwargs) for the current training\n                iteration, then you can provide it here to reduce the number of forward passes of this method by 1\n    :param G_kwargs: Additional inputs to G besides the z vector. 
For example, in BigGAN you\n                     would pass the class label into this function via y=<class_label_tensor>\n\n    :return: A differentiable scalar (the hessian penalty), or a list of hessian penalties if return_separately is True\n    \"\"\"\n    if G_z is None:\n        G_z = G(batch, **G_kwargs)\n    z = batch[\"x\"]\n    rademacher_size = torch.Size((k, *z.size()))  # (k, N, z.size())\n    dzs = epsilon * rademacher(rademacher_size, device=z.device)\n    second_orders = []\n    for dz in dzs:  # Iterate over each (N, z.size()) tensor in xs\n        central_second_order = multi_layer_second_directional_derivative(G, batch, dz, G_z, epsilon, **G_kwargs)\n        second_orders.append(central_second_order)  # Appends a tensor with shape equal to G(z).size()\n    loss = multi_stack_var_and_reduce(second_orders, reduction, return_separately)  # (k, G(z).size()) --> scalar\n    return loss\n\n\ndef rademacher(shape, device='cpu'):\n    \"\"\"Creates a random tensor of size [shape] under the Rademacher distribution (P(x=1) == P(x=-1) == 0.5)\"\"\"\n    x = torch.empty(shape, device=device)\n    x.random_(0, 2)  # Creates random tensor of 0s and 1s\n    x[x == 0] = -1  # Turn the 0s into -1s\n    return x\n\n\ndef multi_layer_second_directional_derivative(G, batch, dz, G_z, epsilon, **G_kwargs):\n    \"\"\"Estimates the second directional derivative of G w.r.t. its input at z in the direction x\"\"\"\n    batch_plus = {**batch, \"x\": batch[\"x\"] + dz}\n    batch_moins = {**batch, \"x\": batch[\"x\"] - dz}\n    G_to_x = G(batch_plus, **G_kwargs)\n    G_from_x = G(batch_moins, **G_kwargs)\n\n    G_to_x = listify(G_to_x)\n    G_from_x = listify(G_from_x)\n    G_z = listify(G_z)\n\n    eps_sqr = epsilon ** 2\n    sdd = [(G2x - 2 * G_z_base + Gfx) / eps_sqr for G2x, G_z_base, Gfx in zip(G_to_x, G_z, G_from_x)]\n    return sdd\n\n\ndef stack_var_and_reduce(list_of_activations, reduction=torch.max):\n    \"\"\"Equation (5) from the paper.\"\"\"\n    second_orders = torch.stack(list_of_activations)  # (k, N, C, H, W)\n    var_tensor = torch.var(second_orders, dim=0, unbiased=True)  # (N, C, H, W)\n    penalty = reduction(var_tensor)  # (1,) (scalar)\n    return penalty\n\n\ndef multi_stack_var_and_reduce(sdds, reduction=torch.max, return_separately=False):\n    \"\"\"Iterate over all activations to be regularized, then apply Equation (5) to each.\"\"\"\n    sum_of_penalties = 0 if not return_separately else []\n    for activ_n in zip(*sdds):\n        penalty = stack_var_and_reduce(activ_n, reduction)\n        sum_of_penalties += penalty if not return_separately else [penalty]\n    return sum_of_penalties\n\n\ndef listify(x):\n    \"\"\"If x is already a list, do nothing. 
Otherwise, wrap x in a list.\"\"\"\n    if isinstance(x, list):\n        return x\n    else:\n        return [x]\n\n\ndef _test_hessian_penalty():\n    \"\"\"\n    A simple multi-layer test to verify the implementation.\n    Function: G(z) = [z_0 * z_1, z_0**2 * z_1]\n    Ground Truth Hessian Penalty: [4, 16 * z_0**2]\n    \"\"\"\n    batch_size = 10\n    nz = 2\n    z = torch.randn(batch_size, nz)\n    def reduction(x): return x.abs().mean()\n    def G(z): return [z[:, 0] * z[:, 1], (z[:, 0] ** 2) * z[:, 1]]\n    ground_truth = [4, reduction(16 * z[:, 0] ** 2).item()]\n    # In this simple example, we use k=100 to reduce variance, but when applied to neural networks\n    # you will probably want to use a small k (e.g., k=2) due to memory considerations.\n    predicted = hessian_penalty(G, z, G_z=None, k=100, reduction=reduction, return_separately=True)\n    predicted = [p.item() for p in predicted]\n    print('Ground Truth: %s' % ground_truth)\n    print('Approximation: %s' % predicted)  # This should be close to ground_truth, but not exactly correct\n    print('Difference: %s' % [str(100 * abs(p - gt) / gt) + '%' for p, gt in zip(predicted, ground_truth)])\n\n\nif __name__ == '__main__':\n    _test_hessian_penalty()\n"
  },
  {
    "path": "PBnet/src/models/tools/losses.py",
    "content": "import torch\nfrom einops import rearrange\nimport torch.nn.functional as F\nfrom .hessian_penalty import hessian_penalty\nfrom .mmd import compute_mmd\nfrom .ssim_loss import ssim\nfrom .normalize_data import normalize_data\n\ndef compute_rc_loss(model, batch):\n    # x = batch[\"x\"] #bs, nf, 6\n    x_delta = batch[\"x_delta\"]\n    output = batch[\"output\"] #bs, nf, 6\n    mask = batch[\"mask\"] #bs, nf\n\n    # gtmasked = x[mask]\n    gtmasked = x_delta[mask]\n    outmasked = output[mask]\n    \n    # loss is large in the beginning\n    loss = F.mse_loss(gtmasked, outmasked, reduction='mean')\n    return loss\n\ndef compute_reg_loss(model, batch):\n    # x = batch[\"x\"] #bs, nf, 6\n    x_delta = batch[\"x_delta\"]\n    mask = batch[\"mask\"] #bs, nf\n    x_1 = x_delta[:,:-1]    \n    x_2 = x_delta[:,1:]\n\n    # gtmasked = x[mask]\n    \n    \n    # loss is large in the beginning\n    loss = F.mse_loss(x_1, x_2, reduction='mean')\n    return loss\n\ndef compute_rc_weight_loss(model, batch):\n    x = batch[\"x\"] #bs, nf, 6\n    x_delta = batch[\"x_delta\"]\n    output = batch[\"output\"] #bs, nf, 6\n    mask = batch[\"mask\"] #bs, nf\n\n    # gtmasked = x[mask] #bs*nf, 6\n    gtmasked = x_delta[mask] #bs*nf, 6\n    outmasked = output[mask] #bs*nf, 6\n    if x.size(2) == 6:\n        weights = torch.tensor([3, 3, 3, 1, 1, 1], dtype=torch.float32).cuda()\n    elif x.size(2) == 7:\n        weights = torch.tensor([3, 3, 3, 1, 1, 1, 0.5], dtype=torch.float32).cuda()\n    elif x.size(2) == 8:\n        weights = torch.tensor([3, 3, 3, 0, 0, 0, 3, 3], dtype=torch.float32).cuda()\n    else:\n        weights = torch.ones(x.size(2), dtype=torch.float32).cuda()\n    weights = weights.unsqueeze(0)\n   \n    # loss is large in the beginning\n    loss = F.mse_loss(gtmasked*weights, outmasked*weights, reduction='mean')\n\n    return loss\n\n\ndef compute_hp_loss(model, batch):\n    loss = hessian_penalty(model.return_latent, batch, seed=torch.random.seed())\n    return loss\n\n\ndef compute_kl_loss(model, batch):\n    # mu, logvar: bs, 256\n    mu, logvar = batch[\"mu\"], batch[\"logvar\"]\n    loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())\n    return loss\n\ndef compute_ssim_loss(model, batch):\n    x = batch[\"x\"] #bs, nf, 6\n    x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64\n    bs = x_ref.shape[0]\n    mask = batch[\"mask\"] #bs, nf\n\n    x_delta = batch[\"x_delta\"]\n    output = batch[\"output\"]\n    loss = ssimnorm_loss(x_delta, output, mask, bs)\n\n\n    return loss\n\ndef ssimnorm_loss(x, output, mask, bs):\n    min_vals = min(x.min(),output.min())\n    max_vals = max(x.max(),output.max())\n    x_norm = normalize_data(x, min_vals, max_vals)\n    out_norm = normalize_data(output, min_vals, max_vals)\n    gtmasked = x_norm[mask] #bs*nf, 6\n    outmasked = out_norm[mask] #bs*nf, 6\n    gtmasked = rearrange(gtmasked, '(b f) c -> b f c', b=bs)\n    outmasked = rearrange(outmasked, '(b f) c -> b f c', b=bs)\n    gtmasked = gtmasked.unsqueeze(dim=1) # b 1 f c\n    outmasked = outmasked.unsqueeze(dim=1) # b 1 f c\n    loss = 1-ssim(gtmasked, outmasked, val_range=1, window_size=3)\n    return loss\n\ndef ssimnorm_self_loss(x, output, mask, bs):\n    x_norm = normalize_data(x, x.min(), x.max())\n    out_norm = normalize_data(output, output.min(),output.max())\n    gtmasked = x_norm[mask] #bs*nf, 6\n    outmasked = out_norm[mask] #bs*nf, 6\n    gtmasked = rearrange(gtmasked, '(b f) c -> b f c', b=bs)\n    outmasked = rearrange(outmasked, '(b f) c -> b f c', b=bs)\n   
 gtmasked = gtmasked.unsqueeze(dim=1) # b 1 f c\n    outmasked = outmasked.unsqueeze(dim=1) # b 1 f c\n    loss = 1-ssim(gtmasked, outmasked, val_range=1, window_size=5)\n    return loss\n\ndef ssim255_loss(x, output, mask, bs):\n    gtmasked = x[mask] #bs*nf, 6\n    outmasked = output[mask] #bs*nf, 6\n\n    # add 128 to ensue input range is 0-255\n    gtmasked = rearrange(gtmasked, '(b f) c -> b f c', b=bs)+128\n    outmasked = rearrange(outmasked, '(b f) c -> b f c', b=bs)+128\n\n    gtmasked = gtmasked.unsqueeze(dim=1) # b 1 f c\n    outmasked = outmasked.unsqueeze(dim=1) # b 1 f c\n\n    loss = 1-ssim(gtmasked, outmasked, val_range=255, window_size=5)\n    return loss\n\ndef comput_var_loss(model, batch):\n    output = batch[\"output\"] #bs, nf, 6\n    mask = batch[\"mask\"] #bs, nf\n    outmasked = output[mask] #bs*nf, 6\n\n    batch_size, num_frames, dim = output.size()\n    outmasked = rearrange(outmasked, '(b f) c -> b f c', b=batch_size)\n    variance_loss = 0\n    zero_loss = torch.tensor(0)\n\n    for b in range(batch_size):\n        for d in range(dim):\n            dimension_output = outmasked[b, :, d]  # shape: (bs, nf)\n            frame_variance = torch.var(dimension_output)\n            variance_loss += frame_variance\n    variance_loss /= (batch_size * dim)\n    if 3>variance_loss>0:\n        return variance_loss\n    else:\n        return zero_loss\n\ndef compute_mmd_loss(model, batch):\n    z = batch[\"z\"]\n    true_samples = torch.randn(z.shape, requires_grad=False, device=model.device)\n    loss = compute_mmd(true_samples, z)\n    return loss\n\n\n_matching_ = {\"rc\": compute_rc_loss, \"rcw\": compute_rc_weight_loss,\n              \"kl\": compute_kl_loss, \"hp\": compute_hp_loss,\n              \"mmd\": compute_mmd_loss, \"ssim\": compute_ssim_loss,\n              \"var\": comput_var_loss, 'reg': compute_reg_loss}\n\n# _matching_ = {\"rc\": compute_rc_loss, \"kl\": compute_kl_loss, \"hp\": compute_hp_loss,\n#               \"mmd\": compute_mmd_loss, \"rcxyz\": compute_rcxyz_loss,\n#               \"vel\": compute_vel_loss, \"velxyz\": compute_velxyz_loss}\n\n\ndef get_loss_function(ltype):\n    return _matching_[ltype]\n\n\ndef get_loss_names():\n    return list(_matching_.keys())\n"
  },
  {
    "path": "PBnet/src/models/tools/mmd.py",
    "content": "import torch\n\n\n# from https://github.com/napsternxg/pytorch-practice/blob/master/Pytorch%20-%20MMD%20VAE.ipynb\ndef compute_kernel(x, y):\n    x_size = x.size(0)\n    y_size = y.size(0)\n    dim = x.size(1)\n    x = x.unsqueeze(1)  # (x_size, 1, dim)\n    y = y.unsqueeze(0)  # (1, y_size, dim)\n    tiled_x = x.expand(x_size, y_size, dim)\n    tiled_y = y.expand(x_size, y_size, dim)\n    kernel_input = (tiled_x - tiled_y).pow(2).mean(2)/float(dim)\n    return torch.exp(-kernel_input)  # (x_size, y_size)\n\n\ndef compute_mmd(x, y):\n    x_kernel = compute_kernel(x, x)\n    y_kernel = compute_kernel(y, y)\n    xy_kernel = compute_kernel(x, y)\n    mmd = x_kernel.mean() + y_kernel.mean() - 2*xy_kernel.mean()\n    return mmd\n"
  },
  {
    "path": "PBnet/src/models/tools/msssim_loss.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom math import exp\nimport numpy as np\n\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])\n    return gauss/gauss.sum()\n\n\ndef create_window(window_size, channel=1):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()\n    return window\n\n\ndef ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):\n    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).\n    if val_range is None:\n        if torch.max(img1) > 128:\n            max_val = 255\n        else:\n            max_val = 1\n\n        if torch.min(img1) < -0.5:\n            min_val = -1\n        else:\n            min_val = 0\n        L = max_val - min_val\n    else:\n        L = val_range\n\n    padd = 0\n    (_, channel, height, width) = img1.size()\n    if window is None:\n        real_size = min(window_size, height, width)\n        window = create_window(real_size, channel=channel).to(img1.device)\n\n    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)\n    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq\n    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq\n    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2\n\n    C1 = (0.01 * L) ** 2\n    C2 = (0.03 * L) ** 2\n\n    v1 = 2.0 * sigma12 + C2\n    v2 = sigma1_sq + sigma2_sq + C2\n    cs = v1 / v2  # contrast sensitivity\n\n    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)\n\n    if size_average:\n        cs = cs.mean()\n        ret = ssim_map.mean()\n    else:\n        cs = cs.mean(1).mean(1).mean(1)\n        ret = ssim_map.mean(1).mean(1).mean(1)\n\n    if full:\n        return ret, cs\n    return ret\n\n\ndef msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None):\n    device = img1.device\n    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)\n    levels = weights.size()[0]\n    ssims = []\n    mcs = []\n    for _ in range(levels):\n        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)\n\n        # Relu normalize (not compliant with original definition)\n        if normalize == \"relu\":\n            ssims.append(torch.relu(sim))\n            mcs.append(torch.relu(cs))\n        else:\n            ssims.append(sim)\n            mcs.append(cs)\n\n        img1 = F.avg_pool2d(img1, (2, 2))\n        img2 = F.avg_pool2d(img2, (2, 2))\n\n    ssims = torch.stack(ssims)\n    mcs = torch.stack(mcs)\n\n    # Simple normalize (not compliant with original definition)\n    # TODO: remove support for normalize == True (kept for backward support)\n    if normalize == \"simple\" or normalize == True:\n        ssims = (ssims + 1) / 2\n        mcs = (mcs + 1) / 2\n\n    pow1 = mcs ** weights\n    pow2 = ssims ** weights\n\n    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/\n    output = torch.prod(pow1[:-1]) * pow2[-1]\n    return output\n\n\n# Classes to re-use 
window\nclass SSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True, val_range=None):\n        super(SSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.val_range = val_range\n\n        # Assume 1 channel for SSIM\n        self.channel = 1\n        self.window = create_window(window_size)\n\n    def forward(self, img1, img2):\n        (_, channel, _, _) = img1.size()\n\n        if channel == self.channel and self.window.dtype == img1.dtype:\n            window = self.window\n        else:\n            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)\n            self.window = window\n            self.channel = channel\n\n        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)\n\nclass MSSSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True, channel=3):\n        super(MSSSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.channel = channel\n\n    def forward(self, img1, img2):\n        # TODO: store window between calls if possible\n        return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)\n\n\nif __name__ == \"__main__\":\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    m = MSSSIM()\n\n    img1 = torch.rand(1, 1, 256, 256)\n    img2 = torch.rand(1, 1, 256, 256)\n\n    print(msssim(img1, img2))\n    print(m(img1, img2))\n\n"
  },
  {
    "path": "PBnet/src/models/tools/normalize_data.py",
    "content": "import torch\n\ndef normalize_data(data, min_vals, max_vals):\n    min_vals = min_vals.unsqueeze(0).unsqueeze(0)  \n    max_vals = max_vals.unsqueeze(0).unsqueeze(0)  \n\n    normalized_data = (data - min_vals) / (max_vals - min_vals)\n\n    return normalized_data\n\nif __name__ == \"__main__\":\n    bs = 32\n    nf = 10\n    data = torch.randn((bs, nf, 6))  \n\n    # means = torch.tensor([2.17239228e-02 -8.76334959e-01 1.83403242e-01 4.68812609e-04 6.09114990e+01 6.82846017e+01])\n    # stds = torch.tensor([3.95977561e+00 2.74141379e+00 2.70259097e+00 8.42982963e-06 1.71036724e+00 1.94872744e+00])\n    min_vals = torch.tensor([-1.03461033e+01, -8.08477430e+00, -7.56659334e+00, 4.33026857e-04, 5.68175623e+01, 6.36141304e+01])\n    max_vals = torch.tensor([1.75214498e+01, 8.44862517e+00, 7.98321722e+00, 6.12732050e-04, 6.88481830e+01, 8.21925801e+01])\n\n    normalized_data = normalize_data(data, min_vals, max_vals)\n    print(normalized_data)\n"
  },
  {
    "path": "PBnet/src/models/tools/ssim_loss.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nimport numpy as np\nfrom math import exp\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])\n    return gauss/gauss.sum()\n\ndef create_window(window_size, channel):\n    _1D_window = gaussian(window_size, 0.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())\n    return window\n\ndef _ssim(img1, img2, window, window_size, channel, val_range = 1, size_average = True):\n    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)\n    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1*mu2\n\n    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq\n    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq\n    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2\n\n    C1 = (0.01*val_range)**2\n    C2 = (0.03*val_range)**2\n\n    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))\n\n    if size_average:\n        return ssim_map.mean()\n    else:\n        return ssim_map.mean(1).mean(1).mean(1)\n\nclass SSIM(torch.nn.Module):\n    def __init__(self, window_size = 11, size_average = True):\n        super(SSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.channel = 1\n        self.window = create_window(window_size, self.channel)\n\n    def forward(self, img1, img2):\n        (_, channel, _, _) = img1.size()\n\n        if channel == self.channel and self.window.data.type() == img1.data.type():\n            window = self.window\n        else:\n            window = create_window(self.window_size, channel)\n            \n            if img1.is_cuda:\n                window = window.cuda(img1.get_device())\n            window = window.type_as(img1)\n            \n            self.window = window\n            self.channel = channel\n\n\n        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)\n\ndef ssim(img1, img2, window_size = 11, val_range=1, size_average = True):\n    (_, channel, _, _) = img1.size()\n    window = create_window(window_size, channel)\n    \n    if img1.is_cuda:\n        window = window.cuda(img1.get_device())\n    window = window.type_as(img1)\n    \n    return _ssim(img1, img2, window, window_size, channel, val_range, size_average)\n\ndef read_pose_from_txt(file_path):\n    data = np.loadtxt(file_path)\n    return data\n\nif __name__ == \"__main__\":\n    pose1 = read_pose_from_txt('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master/exps_delta_pose/HDTF_nf40_kl1_ssim1_128_w5_1w/nofinetune/3500/eval_gt/0/RD_Radio14_000_522_gt')\n    pose2 = read_pose_from_txt('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master/exps_delta_pose/HDTF_nf40_kl1_ssim1_128_w5_1w/nofinetune/3500/eval_pred/0/RD_Radio14_000_522')\n\n    pose1_tensor = torch.tensor(pose1).unsqueeze(0).unsqueeze(0).float()[:,:,:,:-1]\n    pose2_tensor = torch.tensor(pose2).unsqueeze(0).unsqueeze(0).float()[:,:,:,:-1]\n    # pose1_tensor = torch.tensor(pose1).unsqueeze(0).unsqueeze(0).float()[:,:,:,:-1]+128\n    # 
pose2_tensor = torch.tensor(pose2).unsqueeze(0).unsqueeze(0).float()[:,:,:,:-1]+128\n\n    pose1_tensor = (pose1_tensor - pose1_tensor.min()) / (pose1_tensor.max() - pose1_tensor.min())\n    pose2_tensor = (pose2_tensor - pose2_tensor.min()) / (pose2_tensor.max() - pose2_tensor.min())\n\n    ssim_loss = 1-ssim(pose1_tensor, pose2_tensor, window_size=3, val_range=1)\n\n    print(ssim_loss)\n"
  },
  {
    "path": "PBnet/src/models/tools/tools.py",
    "content": "import torch.nn as nn\nfrom torch.nn.modules.module import ModuleAttributeError\n\n\nclass AutoParams(nn.Module):\n    def __init__(self, **kargs):\n        try:\n            for param in self.needed_params:\n                if param in kargs:\n                    setattr(self, param, kargs[param])\n                else:\n                    raise ValueError(f\"{param} is needed.\")\n        except ModuleAttributeError:\n            pass\n            \n        try:\n            for param, default in self.optional_params.items():\n                if param in kargs and kargs[param] is not None:\n                    setattr(self, param, kargs[param])\n                else:\n                    setattr(self, param, default)\n        except ModuleAttributeError:\n            pass\n        super().__init__()\n\n\n# taken from joeynmt repo\ndef freeze_params(module: nn.Module) -> None:\n    \"\"\"\n    Freeze the parameters of this module,\n    i.e. do not update them during training\n\n    :param module: freeze parameters of this module\n    \"\"\"\n    for _, p in module.named_parameters():\n        p.requires_grad = False\n"
  },
  {
    "path": "PBnet/src/parser/base.py",
    "content": "from argparse import ArgumentParser  # noqa\n\n\ndef add_misc_options(parser):\n    group = parser.add_argument_group('Miscellaneous options')\n    group.add_argument(\"--expname\", default=\"exps\", help=\"general directory to this experiments, use it if you don't provide folder name\")\n    group.add_argument(\"--folder\", default=\"exps/default_path\", help=\"directory name to save models\")\n    \n\ndef add_cuda_options(parser):\n    group = parser.add_argument_group('Cuda options')\n    group.add_argument(\"--cuda\", dest='cuda', action='store_true', help=\"if we want to try to use gpu\")\n    group.add_argument('--cpu', dest='cuda', action='store_false', help=\"if we want to use cpu\")\n    group.set_defaults(cuda=True)\n    \n    group.add_argument(\"--gpu\", default='0', help=\"choose gpu device.\")\n\n    \ndef adding_cuda(parameters):\n    import torch\n    if (parameters[\"cuda\"] or parameters[\"gpu\"])  and torch.cuda.is_available():\n        parameters[\"device\"] = torch.device(\"cuda\")\n    else:\n        parameters[\"device\"] = torch.device(\"cpu\")\n        \n    \n"
  },
  {
    "path": "PBnet/src/parser/checkpoint.py",
    "content": "import os\nfrom .base import ArgumentParser, adding_cuda\nfrom .tools import load_args\n\n\ndef parser():\n    parser = ArgumentParser()\n    parser.add_argument(\"checkpointname\")\n    parser.add_argument(\"--num_epochs\", type=int, default=5000, help=\"new number of epochs of training\")\n\n    opt = parser.parse_args()\n    \n    folder, checkpoint = os.path.split(opt.checkpointname)\n    parameters = load_args(os.path.join(folder, \"opt.yaml\"))\n\n    parameters[\"num_epochs\"] = opt.num_epochs\n\n    adding_cuda(parameters)\n    epoch = int(checkpoint.split(\"_\")[-1].split('.')[0])\n    return parameters, folder, checkpoint, epoch\n\n\ndef construct_checkpointname(parameters, folder):\n    implist = [parameters[\"modelname\"],\n               parameters[\"dataset\"],\n               parameters[\"extraction_method\"],\n               parameters[\"pose_rep\"]]\n    if parameters[\"pose_rep\"] != \"xyz\":\n        # [True, \"\"] to be compatible with generate job\n        if \"glob\" in parameters:\n            implist.append(\"glob\" if parameters[\"glob\"] in [True, \"\"] else \"noglob\")\n        else:\n            implist.append(\"noglob\")\n        if \"translation\" in parameters:\n            implist.append(\"translation\" if parameters[\"translation\"] in [True, \"\"] else \"notranslation\")\n        else:\n            implist.append(\"notranslation\")\n            \n        if \"rcxyz\" in parameters[\"modelname\"]:\n            implist.append(\"joinstype_{}\".format(parameters[\"jointstype\"]))\n\n    if \"num_layers\" in parameters:\n        implist.append(\"numlayers_{}\".format(parameters[\"num_layers\"]))\n            \n    for name in [\"num_frames\", \"min_len\", \"max_len\", \"num_seq_max\"]:\n        pvalue = parameters[name]\n        pname = name.replace(\"_\", \"\")\n        if pvalue != -1:\n            implist.append(f\"{pname}_{pvalue}\")\n    \n    if \"view\" in parameters:\n        if parameters[\"view\"] == \"frontview\":\n            implist.append(\"frontview\")\n\n    if \"use_z\" in parameters:\n        if parameters[\"use_z\"] != 0:\n            implist.append(\"usez\")\n        else:\n            implist.append(\"noz\")\n\n    if \"vertstrans\" in parameters:\n        implist.append(\"vetr\" if parameters[\"vertstrans\"] else \"novetr\")\n        \n    if \"ablation\" in parameters:\n        abl = parameters[\"ablation\"]\n        if abl not in [\"\", None]:\n            implist.append(f\"abl_{abl}\")\n            \n    if parameters[\"num_frames\"] != -1:\n        implist.append(\"sampling_{}\".format(parameters[\"sampling\"]))\n        if parameters[\"sampling\"] == \"conseq\":\n            implist.append(\"samplingstep_{}\".format(parameters[\"sampling_step\"]))\n    if \"lambda_kl\" in parameters:\n        implist.append(\"kl_{:.0e}\".format(float(parameters[\"lambda_kl\"])))\n\n    if \"activation\" in parameters:\n        act = parameters[\"activation\"]\n        implist.append(act)\n\n    implist.append(\"bs_{}\".format(parameters[\"batch_size\"]))\n    implist.append(\"ldim_{}\".format(parameters[\"latent_dim\"]))\n    \n    checkpoint = \"_\".join(implist)\n    return os.path.join(folder, checkpoint)\n\n\n"
  },
  {
    "path": "PBnet/src/parser/dataset.py",
    "content": "from src.datasets.dataset import POSE_REPS\n\n\ndef add_dataset_options(parser):\n    group = parser.add_argument_group('Dataset options')\n\n    group.add_argument(\"--dataset\", default='crema', help=\"Name of the dataset\")\n    group.add_argument(\"--num_frames\", default=60, type=int, help=\"number of frames or -1 => whole, -2 => random between min_len and total\")\n    # group.add_argument(\"--sampling\", default=\"conseq\", choices=[\"conseq\", \"random_conseq\", \"random\"], help=\"sampling choices\")\n    # group.add_argument(\"--sampling_step\", default=1, type=int, help=\"sampling step\")\n    # group.add_argument(\"--pose_rep\", required=True, choices=POSE_REPS, help=\"xyz or rotvec etc\")\n\n    group.add_argument(\"--max_len\", default=-1, type=int, help=\"number of frames maximum per sequence or -1\")\n    group.add_argument(\"--min_len\", default=-1, type=int, help=\"number of frames minimum per sequence or -1\")\n    group.add_argument(\"--num_seq_max\", default=-1, type=int, help=\"number of sequences maximum to load or -1\")\n\n    # group.add_argument(\"--glob\", dest='glob', action='store_true', help=\"if we want global rotation\")\n    # group.add_argument('--no-glob', dest='glob', action='store_false', help=\"if we don't want global rotation\")\n    # group.set_defaults(glob=True)\n    # group.add_argument(\"--glob_rot\", type=int, nargs=\"+\", default=[3.141592653589793, 0, 0],\n    #                    help=\"Default rotation, usefull if glob is False\")\n    # group.add_argument(\"--translation\", dest='translation', action='store_true',\n    #                    help=\"if we want to output translation\")\n    # group.add_argument('--no-translation', dest='translation', action='store_false',\n    #                    help=\"if we don't want to output translation\")\n    # group.set_defaults(translation=True)\n\n    # group.add_argument(\"--debug\", dest='debug', action='store_true', help=\"if we are in debug mode\")\n    # group.set_defaults(debug=False)\n"
  },
  {
    "path": "PBnet/src/parser/evaluation.py",
    "content": "import argparse\nimport os\nimport sys\nsys.path.append('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master')\n\nfrom src.parser.tools import load_args\nfrom src.parser.base import add_cuda_options, adding_cuda\n\n\ndef parser():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"checkpointname\")\n    parser.add_argument(\"--dataset\", default='crema', help=\"name of dataset\")\n    parser.add_argument(\"--batch_size\", type=int, default=32, help=\"size of the batches\")\n    parser.add_argument(\"--num_frames\", default=60, type=int, help=\"number of frames, if value is bigger than gt nf, load all nf. \")\n    parser.add_argument(\"--niter\", default=20, type=int, help=\"number of iterations\")\n    parser.add_argument(\"--num_seq_max\", default=3000, type=int, help=\"number of sequences maximum to load or -1\")\n\n    # cuda options\n    add_cuda_options(parser)\n    \n    opt = parser.parse_args()\n    newparameters = {key: val for key, val in vars(opt).items() if val is not None}\n    \n    folder, checkpoint = os.path.split(newparameters[\"checkpointname\"])\n    parameters = load_args(os.path.join(folder, \"opt.yaml\"))\n    parameters.update(newparameters)\n    adding_cuda(parameters)\n    \n    if checkpoint.split(\"_\")[0] == 'retraincheckpoint':\n        epoch = int(checkpoint.split(\"_\")[2])+int(checkpoint.split(\"_\")[4].split('.')[0])\n    else:\n        epoch = int(checkpoint.split(\"_\")[1].split('.')[0])\n \n    return parameters, folder, checkpoint, epoch, opt.niter\n\n\n"
  },
  {
    "path": "PBnet/src/parser/finetunning.py",
    "content": "import os\nfrom .base import argparse, adding_cuda, load_args\n\n    \ndef parser():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"checkpointname\")\n\n    group = parser.add_argument_group('Finetunning options (what should change)')\n    group.add_argument(\"--num_epochs\", type=int, help=\"new number of epochs of training\")\n    group.add_argument(\"--batch_size\", type=int, help=\"size of the batches\")\n    group.add_argument(\"--lr\", type=float, help=\"AdamW: learning rate\")\n    group.add_argument(\"--snapshot\", type=int, help=\"frequency of saving model/viz\")\n    group.add_argument(\"--num_frames\", default=-2, type=int, help=\"number of frames or -1 => whole, -2 => random between min_len and total\")\n    group.add_argument(\"--min_len\", default=60, type=int, help=\"number of frames minimum per sequence or -1\")\n    group.add_argument(\"--max_len\", default=100, type=int, help=\"number of frames maximum per sequence or -1\")\n    \n    opt = parser.parse_args()\n    \n    folder, checkpoint = os.path.split(opt.checkpointname)\n    parameters = load_args(os.path.join(folder, \"opt.yaml\"))\n    parameters[\"folder\"] = folder\n    \n    adding_cuda(parameters)\n    epoch = int(checkpoint.split(\"_\")[-1].split('.')[0])\n    return parameters, folder, checkpoint, epoch\n"
  },
  {
    "path": "PBnet/src/parser/generate.py",
    "content": "import os\n\nfrom src.models.get_model import JOINTSTYPES\nfrom .base import ArgumentParser, add_cuda_options, adding_cuda\nfrom .tools import load_args\n\n\ndef add_generation_options(parser):\n    group = parser.add_argument_group('Generation options')\n    group.add_argument(\"--num_samples_per_action\", default=5, type=int, help=\"num samples per action\")\n    group.add_argument(\"--num_frames\", default=60, type=int, help=\"The number of frames considered (overrided if duration mode is chosen)\")\n    group.add_argument(\"--fact_latent\", default=1, type=int, help=\"Fact latent\")\n\n    group.add_argument(\"--jointstype\", default=\"smpl\", choices=JOINTSTYPES,\n                       help=\"Jointstype for training with xyz\")\n\n    group.add_argument('--vertstrans', dest='vertstrans', action='store_true', help=\"Add the vertex translations\")\n    group.add_argument('--no-vertstrans', dest='vertstrans', action='store_false', help=\"Do not add the vertex translations\")\n    group.set_defaults(vertstrans=False)\n\n    group.add_argument(\"--mode\", default=\"gen\", choices=[\"interpolate\", \"gen\", \"duration\", \"reconstruction\"],\n                       help=\"The kind of generation considered.\")\n\n\ndef parser():\n    parser = ArgumentParser()\n    parser.add_argument(\"checkpointname\")\n\n    # add visualize options back\n    add_generation_options(parser)\n\n    # cuda options\n    add_cuda_options(parser)\n\n    opt = parser.parse_args()\n    newparameters = {key: val for key, val in vars(opt).items() if val is not None}\n    folder, checkpoint = os.path.split(newparameters[\"checkpointname\"])\n    parameters = load_args(os.path.join(folder, \"opt.yaml\"))\n    parameters.update(newparameters)\n\n    adding_cuda(parameters)\n\n    epoch = int(checkpoint.split(\"_\")[-1].split('.')[0])\n    return parameters, folder, checkpoint, epoch\n"
  },
  {
    "path": "PBnet/src/parser/model.py",
    "content": "from src.models.get_model import LOSSES, MODELTYPES, ARCHINAMES\n\n\ndef add_model_options(parser):\n    group = parser.add_argument_group('Model options')\n    group.add_argument(\"--modelname\", default='cvae_transformer_rc_kl', help=\"Choice of the model, should be like cvae_transformer_rc_rcxyz_kl\")\n    group.add_argument(\"--latent_dim\", default=256, type=int, help=\"dimensionality of the latent space\")\n    group.add_argument(\"--lambda_kl\", default=1.0, type=float, help=\"weight of the kl divergence loss\")\n    group.add_argument(\"--lambda_rcw\", default=1.0, type=float, help=\"weight of the rc divergence loss with weight\")\n    group.add_argument(\"--lambda_rc\", default=1.0, type=float, help=\"weight of the rc divergence loss\")\n    group.add_argument(\"--lambda_ssim\", default=1.0, type=float, help=\"weight of the ssim divergence loss\")\n    group.add_argument(\"--lambda_reg\", default=0.1, type=float, help=\"weight of the reg loss\")\n    # group.add_argument(\"--lambda_var\", default=-0.1, type=float, help=\"weight of the var divergence loss\")\n\n    group.add_argument(\"--num_layers\", default=2, type=int, help=\"Number of layers for GRU and transformer\")\n    group.add_argument(\"--ff_size\", default=128, type=int, help=\"Size of feedforward for transformer\")\n    group.add_argument(\"--max_distance\", default=128, type=int, help=\"\")\n    group.add_argument(\"--num_buckets\", default=128, type=int, help=\"\")\n    group.add_argument(\"--audio_latent_dim\", default=256, type=int, help=\"Size of audio latent for transformer\")\n    group.add_argument(\"--first3\", default=False, help=\"Dim of pose, 3 or 6\")\n    group.add_argument(\"--eye\", default=False, help=\"eye information\")\n    group.add_argument(\"--activation\", default=\"gelu\", help=\"Activation for function for the transformer layers\")\n    group.add_argument(\"--dropout\", default=0.1, type=float, help=\"Activation for function for the transformer layers\")\n\n    # # Ablations\n    # group.add_argument(\"--ablation\", choices=[None, \"average_encoder\", \"zandtime\", \"time_encoding\", \"concat_bias\"],\n    #                    help=\"Ablations for the transformer architechture\")\n\n\ndef parse_modelname(modelname):\n    modeltype, archiname, *losses = modelname.split(\"_\")\n\n    if modeltype not in MODELTYPES:\n        raise NotImplementedError(\"This type of model is not implemented.\")\n    if archiname not in ARCHINAMES:\n        raise NotImplementedError(\"This architechture is not implemented.\")\n\n    if len(losses) == 0:\n        raise NotImplementedError(\"You have to specify at least one loss function.\")\n\n    for loss in losses:\n        if loss not in LOSSES:\n            raise NotImplementedError(\"This loss is not implemented.\")\n\n    return modeltype, archiname, losses\n"
  },
  {
    "path": "PBnet/src/parser/recognition.py",
    "content": "import os\n\nfrom .base import argparse, add_misc_options, add_cuda_options, adding_cuda\nfrom .tools import save_args\nfrom .dataset import add_dataset_options\nfrom .training import add_training_options\nfrom .checkpoint import construct_checkpointname\n\n\ndef training_parser():\n    parser = argparse.ArgumentParser()\n    \n    # misc options\n    add_misc_options(parser)\n    \n    # training options\n    add_training_options(parser)\n\n    # dataset options\n    add_dataset_options(parser)\n\n    # model options\n    add_cuda_options(parser)\n    \n    opt = parser.parse_args()\n    \n    # remove None params, and create a dictionnary\n    parameters = {key: val for key, val in vars(opt).items() if val is not None}\n\n    parameters[\"modelname\"] = \"recognition\"\n    \n    if \"folder\" not in parameters:\n        parameters[\"folder\"] = construct_checkpointname(parameters,\n                                                        parameters[\"expname\"])\n\n    os.makedirs(parameters[\"folder\"], exist_ok=True)\n    save_args(parameters, folder=parameters[\"folder\"])\n\n    adding_cuda(parameters)\n    \n    return parameters\n"
  },
  {
    "path": "PBnet/src/parser/tools.py",
    "content": "import os\nimport yaml\n\n\ndef save_args(opt, folder):\n    os.makedirs(folder, exist_ok=True)\n    \n    # Save as yaml\n    optpath = os.path.join(folder, \"opt.yaml\")\n    with open(optpath, 'w') as opt_file:\n        yaml.dump(opt, opt_file)\n\n\ndef load_args(filename):\n    with open(filename, \"rb\") as optfile:\n        opt = yaml.load(optfile, Loader=yaml.Loader)\n    return opt\n\n\n"
  },
  {
    "path": "PBnet/src/parser/training.py",
    "content": "import os\n\nfrom .base import add_misc_options, add_cuda_options, adding_cuda, ArgumentParser\nfrom .tools import save_args\nfrom .dataset import add_dataset_options\nfrom .model import add_model_options, parse_modelname\nfrom .checkpoint import construct_checkpointname\n\n\ndef add_training_options(parser):\n    group = parser.add_argument_group('Training options')\n    group.add_argument(\"--ckpt\", default='')\n    group.add_argument(\"--batch_size\", default=100, type=int, help=\"size of the batches\")\n    group.add_argument(\"--num_epochs\", default=5000, type=int, help=\"number of epochs of training\")\n    group.add_argument(\"--lr\", default=0.0004, type=float, help=\"AdamW: learning rate\")\n    group.add_argument(\"--snapshot\", default=2000, type=int, help=\"frequency of saving model/viz\")\n    # ff_size\n\ndef parser():\n    parser = ArgumentParser()\n\n    # misc options\n    add_misc_options(parser)\n\n    # cuda options\n    add_cuda_options(parser)\n    \n    # training options\n    add_training_options(parser)\n\n    # dataset options\n    add_dataset_options(parser)\n\n    # model options\n    add_model_options(parser)\n\n    opt = parser.parse_args()\n    \n    # remove None params, and create a dictionnary\n    parameters = {key: val for key, val in vars(opt).items() if val is not None}\n\n    # parse modelname\n    ret = parse_modelname(parameters[\"modelname\"])\n    parameters[\"modeltype\"], parameters[\"archiname\"], parameters[\"losses\"] = ret\n    \n    # update lambdas params\n    lambdas = {}\n    for loss in parameters[\"losses\"]:\n        lambdas[loss] = opt.__getattribute__(f\"lambda_{loss}\")\n    parameters[\"lambdas\"] = lambdas\n    \n    if \"folder\" not in parameters:\n        parameters[\"folder\"] = construct_checkpointname(parameters, parameters[\"expname\"])\n\n    os.makedirs(parameters[\"folder\"], exist_ok=True)\n    save_args(parameters, folder=parameters[\"folder\"])\n\n    adding_cuda(parameters)\n    \n    return parameters\n"
  },
  {
    "path": "PBnet/src/parser/visualize.py",
    "content": "import os\n\nfrom src.models.get_model import JOINTSTYPES\nfrom .base import ArgumentParser, add_cuda_options, adding_cuda\nfrom .tools import load_args\nfrom .dataset import add_dataset_options\n\n\ndef construct_figname(parameters):\n    figname = \"fig_{:03d}\"\n    return figname\n\n\ndef add_visualize_options(parser):\n    group = parser.add_argument_group('Visualization options')\n    group.add_argument(\"--num_actions_to_sample\", default=5, type=int, help=\"num actions to sample\")\n    group.add_argument(\"--num_samples_per_action\", default=5, type=int, help=\"num samples per action\")\n    group.add_argument(\"--fps\", default=20, type=int, help=\"FPS for the rendering\")\n\n    group.add_argument(\"--force_visu_joints\", dest='force_visu_joints', action='store_true',\n                       help=\"if we want to visualize joints even if it is rotation\")\n    group.add_argument('--no-force_visu_joints', dest='force_visu_joints', action='store_false',\n                       help=\"if we don't want to visualize joints even if it is rotation\")\n    group.set_defaults(force_visu_joints=True)\n\n    group.add_argument(\"--jointstype\", default=\"smpl\", choices=JOINTSTYPES,\n                       help=\"Jointstype for training with xyz\")\n    group.add_argument('--vertstrans', dest='vertstrans', action='store_true', help=\"Training with vertex translations\")\n    group.add_argument('--no-vertstrans', dest='vertstrans', action='store_false', help=\"Training without vertex translations\")\n    group.set_defaults(vertstrans=False)\n\n    group.add_argument(\"--noise_same_action\", default=\"random\",\n                       choices=[\"interpolate\", \"random\", \"same\"],\n                       help=\"inside one action, sample several noise or interpolate it\")\n    \n    group.add_argument(\"--noise_diff_action\", default=\"random\",\n                       choices=[\"random\", \"same\"],\n                       help=\"use the same noise or different noise for every actions\")\n\n    group.add_argument(\"--duration_mode\", default=\"mean\",\n                       choices=[\"mean\", \"interpolate\"],\n                       help=\"use the same noise or different noise for every actions\")\n\n    group.add_argument(\"--reconstruction_mode\", default=\"ntf\",\n                       choices=[\"tf\", \"ntf\", \"both\"],\n                       help=\"reconstruction: teacher forcing or not or both\")\n\n    group.add_argument(\"--decoder_test\", default=\"new\",\n                       choices=[\"new\", \"diffaction\", \"diffduration\", \"interpolate_action\"],\n                       help=\"what is the test we want to do\")\n    \n    group.add_argument(\"--fact_latent\", type=int, default=1,\n                       help=\"factor for max latent space\")\n\n\ndef parser(checkpoint=True):\n    parser = ArgumentParser()\n    if checkpoint:\n        parser.add_argument(\"checkpointname\")\n    else:\n        add_dataset_options(parser)\n    \n    # add visualize options back\n    add_visualize_options(parser)\n\n    # cuda options\n    add_cuda_options(parser)\n    \n    opt = parser.parse_args()\n    if checkpoint:\n        newparameters = {key: val for key, val in vars(opt).items() if val is not None}\n        folder, checkpoint = os.path.split(newparameters[\"checkpointname\"])\n        parameters = load_args(os.path.join(folder, \"opt.yaml\"))\n        parameters.update(newparameters)\n    else:\n        parameters = {key: val for key, val in vars(opt).items() if val 
is not None}\n        \n    adding_cuda(parameters)\n\n    if checkpoint:\n        parameters[\"figname\"] = construct_figname(parameters)\n        epoch = int(checkpoint.split(\"_\")[-1].split('.')[0])\n        return parameters, folder, checkpoint, epoch\n    else:\n        return parameters\n"
  },
  {
    "path": "PBnet/src/preprocess/humanact12_process.py",
    "content": "import os\nimport numpy as np\nimport pickle as pkl\nfrom phspdtools import CameraParams\n\n\ndef splitname(name):\n    subject = name[1:3]\n    group = name[4:6]\n    time = name[7:9]\n    frame1 = name[10:14]\n    frame2 = name[15:19]\n    action = name[20:24]\n    return subject, group, time, frame1, frame2, action\n\n\ndef create_phpsd_name(name):\n    subject, group, time, frame1, frame2, action = splitname(name)\n    phpsdname = f\"subject{subject}_group{int(group)}_time{int(time)}\"\n    return phpsdname\n\n\ndef get_frames(name):\n    subject, group, time, frame1, frame2, action = splitname(name)\n    return int(frame1), int(frame2)\n\n\ndef get_action(name, coarse=True):\n    subject, group, time, frame1, frame2, action = splitname(name)\n    if coarse:\n        return action[:2]\n    else:\n        return action\n\n\nhumanact12_coarse_action_enumerator = {\n    1: \"warm_up\",\n    2: \"walk\",\n    3: \"run\",\n    4: \"jump\",\n    5: \"drink\",\n    6: \"lift_dumbbell\",\n    7: \"sit\",\n    8: \"eat\",\n    9: \"turn steering wheel\",\n    10: \"phone\",\n    11: \"boxing\",\n    12: \"throw\",\n}\n\n\nhumanact12_coarse_action_to_label = {x: x-1 for x in range(1, 13)}\n\n\ndef process_datata(savepath, posesfolder=\"data/PHPSDposes\", datapath=\"data/HumanAct12\", campath=\"data/phspdCameras\"):\n    data_list = os.listdir(datapath)\n    data_list.sort()\n\n    camera_params = CameraParams(campath)\n\n    vibestyle = {\"poses\": [], \"oldposes\": [], \"joints3D\": [], \"y\": []}\n    for index, name in enumerate(data_list):\n        foldername = create_phpsd_name(name)\n        subject = foldername.split(\"_\")[0]\n        T = camera_params.get_extrinsic(\"c2\", subject)\n\n        frame1, frame2 = get_frames(name)\n        # subjecta, groupa, timea, frame1a, frame2a, actiona = splitname(name)\n\n        posepath = os.path.join(posesfolder, foldername, \"pose.txt\")\n        smplposepath = os.path.join(posesfolder, foldername, \"shape_smpl.txt\")\n        npypath = os.path.join(datapath, name)\n        joints3D = np.load(npypath)\n\n        # take this one to get same number of frames that HumanAct12 joints .npy file\n        # Otherwise we have to much frames (the registration is not perfect)\n        poses = []\n        goodframes = []\n        with open(posepath) as f:\n            for line in f.readlines():\n                tmp = line.split(' ')\n                frame_idx = int(tmp[0])\n                if frame_idx >= frame1 and frame_idx <= frame2:\n                    goodframes.append(frame_idx)\n                    pose = np.asarray([float(i) for i in tmp[1:]]).reshape([-1, 3])\n                    poses.append(pose)\n        poses = np.array(poses)\n\n        # if joints3D.shape[0] == (frame2 - frame1 + 1):\n        #     continue\n\n        smplposes = []\n        with open(smplposepath) as f:\n            for line in f.readlines():\n                tmp = line.split(' ')\n                frame_idx = int(tmp[0])\n                if frame_idx in goodframes:\n                    # pose = np.asarray([float(i) for i in tmp[1:]]).reshape([-1, 3])\n                    # poses.append(pose)\n                    smplparam = np.asarray([float(i) for i in tmp[1:]])\n                    smplpose = smplparam[13:85]\n                    smplposes.append(smplpose)\n        smplposes = np.array(smplposes)\n\n        oldposes = poses.copy()\n        # rotate to the good camera\n        poses = T.transform(poses)\n        poses = poses - poses[0][0] + 
joints3D[0][0]\n\n        # and verify that the pose correspond to the humanact12 data\n        if np.linalg.norm(poses - joints3D) >= 1e-10:\n            print(\"bad\")\n            continue\n\n        assert np.linalg.norm(poses - joints3D) < 1e-10\n\n        rotation = T.getmat4()[:3, :3]\n\n        import pytorch3d.transforms.rotation_conversions as p3d\n        import torch\n\n        # rotate the global rotation\n        global_matrix = p3d.axis_angle_to_matrix(torch.from_numpy(smplposes[:, :3]))\n        smplposes[:, :3] = p3d.matrix_to_axis_angle(torch.from_numpy(rotation) @ global_matrix).numpy()\n\n        assert poses.shape[0] == joints3D.shape[0]\n        assert smplposes.shape[0] == joints3D.shape[0]\n\n        vibestyle[\"poses\"].append(smplposes)\n        vibestyle[\"joints3D\"].append(joints3D)\n\n        action = get_action(name, coarse=True)\n        label = humanact12_coarse_action_to_label[int(action)]\n        vibestyle[\"y\"].append(label)\n\n    pkl.dump(vibestyle, open(savepath, \"wb\"))\n\n\nif __name__ == \"__main__\":\n    folder = \"data/HumanAct12Poses/\"\n    os.makedirs(folder, exist_ok=True)\n    savepath = os.path.join(folder, \"humanact12poses.pkl\")\n    process_datata(savepath)\n"
  },
  {
    "path": "PBnet/src/preprocess/phspdtools.py",
    "content": "# taken and adapted from https://github.com/JimmyZou/PolarHumanPoseShape/\nimport pickle\nimport numpy as np\nimport os\n\n\nclass Transform:\n    def __init__(self, R=np.eye(3, dtype='float'), t=np.zeros(3, 'float'), s=np.ones(3, 'float')):\n        self.R = R.copy()  # rotation\n        self.t = t.reshape(-1).copy()  # translation\n        self.s = s.copy()  # scale\n\n    def __mul__(self, other):\n        # combine two transformation together\n        R = np.dot(self.R, other.R)\n        t = np.dot(self.R, other.t * self.s) + self.t\n        if not hasattr(other, 's'):\n            other.s = np.ones(3, 'float').copy()\n        s = other.s.copy()\n        return Transform(R, t, s)\n\n    def inv(self):\n        # inverse the rigid tansformation\n        R = self.R.T\n        t = -np.dot(self.R.T, self.t)\n        return Transform(R, t)\n\n    def transform(self, xyz):\n        # transform 3D point\n        if not hasattr(self, 's'):\n            self.s = np.ones(3, 'float').copy()\n        assert xyz.shape[-1] == 3\n        assert len(self.s) == 3\n        return np.dot(xyz * self.s, self.R.T) + self.t\n\n    def getmat4(self):\n        # homogeneous transformation matrix\n        M = np.eye(4)\n        M[:3, :3] = self.R * self.s\n        M[:3, 3] = self.t\n        return M\n\n\ndef quat2R(quat):\n    \"\"\"\n    Description\n    ===========\n    convert vector q to matrix R\n\n    Parameters\n    ==========\n    :param quat: (4,) array\n\n    Returns\n    =======\n    :return: (3,3) array\n    \"\"\"\n    w = quat[0]\n    x = quat[1]\n    y = quat[2]\n    z = quat[3]\n\n    n = w * w + x * x + y * y + z * z\n    s = 2. / np.clip(n, 1e-7, 1e7)\n\n    wx = s * w * x\n    wy = s * w * y\n    wz = s * w * z\n    xx = s * x * x\n    xy = s * x * y\n    xz = s * x * z\n    yy = s * y * y\n    yz = s * y * z\n    zz = s * z * z\n\n    R = np.stack([1 - (yy + zz), xy - wz, xz + wy,\n                  xy + wz, 1 - (xx + zz), yz - wx,\n                  xz - wy, yz + wx, 1 - (xx + yy)])\n\n    return R.reshape((3, 3))\n\n\ndef convert_param2tranform(param, scale=1):\n    R = quat2R(param[0:4])\n    t = param[4:7]\n    s = scale * np.ones(3, 'float')\n    return Transform(R, t, s)\n\n\nclass CameraParams:\n    def __init__(self, cam_folder=\"data/phspdCameras\"):\n        \n        # load camera params, save intrinsic and extrinsic camera parameters as a dictionary\n        # intrinsic ['param_p', 'param_c1', 'param_d1', 'param_c2', 'param_d2', 'param_c3', 'param_d3']\n        # extrinsic ['d1p', 'd2p', 'd3p', 'cd1', 'cd2', 'cd3']\n        self.cam_params = []\n        with open(os.path.join(cam_folder, \"CamParams0906.pkl\"), 'rb') as f:\n            self.cam_params.append(pickle.load(f))\n        with open(os.path.join(cam_folder, \"CamParams0909.pkl\"), 'rb') as f:\n            self.cam_params.append(pickle.load(f))\n\n        # corresponding cam params to each subject\n        self.name_cam_params = {}  # {\"name\": 0 or 1}\n        for name in ['subject06', 'subject09', 'subject11', 'subject05', 'subject12', 'subject04']:\n            self.name_cam_params[name] = 0\n        for name in ['subject03', 'subject01', 'subject02', 'subject10', 'subject07', 'subject08']:\n            self.name_cam_params[name] = 1\n\n        # corresponding cam params to each subject\n        self.name_gender = {}  # {\"name\": 0 or 1}\n        for name in ['subject02', 'subject03', 'subject04', 'subject05', 'subject06',\n                     'subject08', 'subject09', 'subject11', 'subject12']:\n   
         self.name_gender[name] = 0  # male\n        for name in ['subject01', 'subject07', 'subject10']:\n            self.name_gender[name] = 1  # female\n\n    def get_intrinsic(self, cam_name, subject_no):\n        \"\"\"\n        'p': polarization camera, color\n        'c1': color camera for the 1st Kinect\n        'd1': depth (ToF) camera for the 1st Kinect\n        ...\n        return\n            (fx, fy, cx, cy)\n        \"\"\"\n        assert cam_name in ['p', 'c1', 'd1', 'c2', 'd2', 'c3', 'd3']\n        assert subject_no in ['subject06', 'subject09', 'subject11', 'subject05', 'subject12', 'subject04',\n                              'subject03', 'subject01', 'subject02', 'subject10', 'subject07', 'subject08']\n        fx, fy, cx, cy, _, _, _ = self.cam_params[self.name_cam_params[subject_no]]['param_%s' % cam_name]\n        intrinsic = (fx, fy, cx, cy)\n        return intrinsic\n\n    def get_extrinsic(self, cams_name, subject_no):\n        \"\"\"\n        The annotated poses and shapes are saved in polarization camera coordinates.\n        'd1p': transform from polarization camera to 1st Kinect depth image\n        'c1p': transform from polarization camera to 1st Kinect color image\n        ...\n        return\n            transform class\n        \"\"\"\n        assert cams_name in ['d1', 'd2', 'd3', 'c1', 'c2', 'c3']\n        assert subject_no in ['subject06', 'subject09', 'subject11', 'subject05', 'subject12', 'subject04',\n                              'subject03', 'subject01', 'subject02', 'subject10', 'subject07', 'subject08']\n\n        if cams_name in ['d1', 'd2', 'd3']:\n            # depth cameras: the extrinsic is stored directly as 'd1p', 'd2p', 'd3p'\n            T = convert_param2tranform(self.cam_params[self.name_cam_params[subject_no]]['%sp' % cams_name])\n        else:\n            i = cams_name[1]\n            T_dp = convert_param2tranform(self.cam_params[self.name_cam_params[subject_no]]['d%sp' % i])\n            T_cd = convert_param2tranform(self.cam_params[self.name_cam_params[subject_no]]['cd%s' % i])\n            T = T_cd * T_dp\n        return T\n\n    def get_gender(self, subject_no):\n        return self.name_gender[subject_no]\n\n\nif __name__ == '__main__':\n    # test with the default camera folder\n    camera_params = CameraParams()\n    T = camera_params.get_extrinsic('c2', 'subject01')\n    print(T.getmat4())\n\n\n"
  },
  {
    "path": "PBnet/src/preprocess/uestc_vibe_postprocessing.py",
    "content": "import numpy as np\nimport pickle as pkl\nimport tarfile\nimport os\nimport scipy.io as sio\nfrom tqdm import tqdm\nimport src.utils.rotation_conversions as geometry\nimport torch\n\nW = 960\nH = 540\n\n\ndef get_kinect_motion(tar, videos, index):\n    # skeleton loading\n    video = videos[index]\n    skeleton_name = video.replace(\"color.avi\", \"skeleton.mat\")\n    skeleton_path = os.path.join(\"mat_from_skeleton\", skeleton_name)\n    ffile = tar.extractfile(skeleton_path)\n    skeleton = sio.loadmat(ffile, variable_names=[\"v\"])[\"v\"]\n    skeleton = skeleton.reshape(-1, 25, 3)\n    return skeleton\n\n\ndef motionto2d(motion, W=960, H=540):\n    K = np.array(((540, 0, W / 2),\n                  (0, 540, H / 2),\n                  (0, 0, 1)))\n    motion[..., 1] = -motion[..., 1]\n    motion2d = np.einsum(\"tjk,lk->tjl\", motion, K)\n    nonzeroix = np.where(motion2d[..., 2] != 0)\n    motion2d[nonzeroix] = motion2d[nonzeroix] / motion2d[(*nonzeroix, 2)][..., None]\n    return motion2d[..., :2]\n\n\ndef motionto2dvibe(motion, cam):\n    sx, sy, tx, ty = cam\n    return (motion[..., :2] + [tx, ty]) * [W/2*sx, H/2*sy] + [W/2, H/2]\n\n\ndef get_kcenter(tar, videos, index):\n    kmotion2d = motionto2d(get_kinect_motion(tar, videos, index))\n    kboxes = np.hstack((kmotion2d.min(1), kmotion2d.max(1)))\n    x1, y1, x2, y2 = kboxes.T\n    kcenter = np.stack(((x1 + x2)/2, (y1 + y2)/2)).T\n    return kcenter\n\n\ndef get_concat_goodtracks(allvibe, tar, videos, index):\n    idxall = allvibe[index]\n    kcenter = get_kcenter(tar, videos, index)\n    tracks = np.array(list(idxall.keys()))\n\n    if len(tracks) == 1:\n        return idxall[tracks[0]], tracks\n\n    remainingmask = np.ones(len(tracks), dtype=bool)\n\n    currenttrack = None\n    vibetracks = []\n    while remainingmask.any():\n        # find new track\n        # first look at the closest new track in time\n        candidate = np.argmin([idxall[track][\"frame_ids\"][0] for track in tracks[remainingmask]])\n        candidate_max = idxall[tracks[remainingmask][candidate]][\"frame_ids\"][-1]\n\n        # look for other candidate which intersect with the candidate (conflict)\n        candidates = np.where(np.array([idxall[track][\"frame_ids\"][0] <= candidate_max for track in tracks[remainingmask]]))[0]\n\n        # if the candidate is alone, take it\n        if len(candidates) == 1:\n            idx = np.where(remainingmask)[0][candidate]\n        # if there are conflit, find the closest match\n        else:\n            # take the closest one in distance to the last center observed\n            if currenttrack is None:  # take the kinect output\n                lastbox = kcenter[0]\n            else:  # take the last boxe output\n                lastbox = idxall[currenttrack][\"bboxes\"][-1, :2]\n            dists = np.linalg.norm([idxall[tracks[remainingmask][candidate]][\"bboxes\"][0, :2] - lastbox\n                                    for candidate in candidates], axis=1)\n            idx = np.where(remainingmask)[0][candidates[np.argmin(dists)]]\n\n        # compute informations\n        currenttrack = tracks[idx]\n        vibetracks.append(currenttrack)\n        lastframe = idxall[currenttrack][\"frame_ids\"][-1]\n\n        # filter overlapping frames\n        remainingmask = np.array([idxall[track][\"frame_ids\"][0] > lastframe for track in tracks]) & remainingmask\n\n    goodvibe = {key: [] for key in ['pred_cam', 'orig_cam', 'pose',\n                                    'betas', 'joints3d', 'bboxes', 
'frame_ids']}\n\n    for key in goodvibe:\n        goodvibe[key] = np.concatenate([idxall[track][key] for track in vibetracks])\n\n    return goodvibe, vibetracks\n\n\ndef interpolate_track(gvibe):\n    # interpolation\n    starting = np.where((gvibe[\"frame_ids\"][1:] - gvibe[\"frame_ids\"][:-1]) != 1)[0] + 1\n\n    lastend = 0\n    saveall = {key: [] for key in gvibe.keys() if key != \"joints2d\"}\n\n    for start in starting:\n        begin = start - 1\n        end = start\n        lastgoodidx = gvibe[\"frame_ids\"][begin]\n        firstnewgoodidx = gvibe[\"frame_ids\"][end]\n\n        for key in saveall.keys():\n            # save the segment before the cut\n            saveall[key].append(gvibe[key][lastend:begin+1])\n\n            # extract the last good info\n            lastgoodinfo = gvibe[key][begin]\n\n            # extract the first regood info\n            newfirstgoodinfo = gvibe[key][end]\n\n            if key == \"pose\":  # interpolate in quaternions\n                q0 = geometry.axis_angle_to_quaternion(torch.from_numpy(lastgoodinfo.reshape(24, 3)))\n                q1 = geometry.axis_angle_to_quaternion(torch.from_numpy(newfirstgoodinfo.reshape(24, 3)))\n                q2 = geometry.axis_angle_to_quaternion(-torch.from_numpy(newfirstgoodinfo.reshape(24, 3)))\n                # Help when the interpolation is between pi and -pi\n                # It avoid the problem of inverting people with global rotation\n                # It is not optimal but it is better than nothing\n                # newfirstgoodinfo = torch.where((torch.argmin(torch.stack((torch.linalg.norm(q0-q1, axis=1),\n                # torch.linalg.norm(q0-q2, axis=1))), axis=0) == 0)[:, None], q1, q2)\n                first = [q1[0], q2[0]][np.argmin((torch.linalg.norm(q0[0]-q1[0]),\n                                                  torch.linalg.norm(q0[0]-q2[0])))]\n                newfirstgoodinfo = q1\n                newfirstgoodinfo[0] = first\n                lastgoodinfo = q0\n\n            # interpolate in between\n            interinfo = []\n            for x in range(lastgoodidx+1, firstnewgoodidx):\n                # linear coeficient\n                w2 = x - lastgoodidx\n                w1 = firstnewgoodidx - x\n                w1, w2 = w1/(w1+w2), w2/(w1+w2)\n\n                inter = lastgoodinfo * w1 + newfirstgoodinfo * w2\n                if key == \"pose\":  # interpolate in quaternions\n                    # normalize the quaternion\n                    inter = inter/torch.linalg.norm(inter, axis=1)[:, None]\n                    inter = geometry.quaternion_to_axis_angle(inter).numpy().reshape(-1)\n\n                interinfo.append(inter)\n\n            saveall[key].append(interinfo)\n        lastend = end\n\n    for key in saveall.keys():\n        saveall[key].append(gvibe[key][lastend:])\n        saveall[key] = np.concatenate(saveall[key])\n\n    saveall[\"frame_ids\"] = np.round(saveall[\"frame_ids\"]).astype(int)\n\n    # make sure the interpolation was fine => looking at a whole frame_ids\n\n    assert (saveall[\"frame_ids\"] == np.arange(gvibe[\"frame_ids\"].min(), gvibe[\"frame_ids\"].max()+1)).all()\n\n    return saveall\n\n\nif __name__ == \"__main__\":\n    datapath = \"datasets/uestc/\"\n    allpath = os.path.join(datapath, \"vibe_cache_all_tracks.pkl\")\n    oldpath = os.path.join(datapath, \"vibe_cache.pkl\")\n    videopath = os.path.join(datapath, 'info', 'names.txt')\n\n    kinectpath = os.path.join(datapath, \"mat_from_skeleton.tar\")\n\n    allvibe = 
pkl.load(open(allpath, \"rb\"))\n    oldvibe = pkl.load(open(oldpath, \"rb\"))\n\n    videos = open(videopath, 'r').read().splitlines()\n\n    tar = tarfile.open(kinectpath, \"r\")\n\n    newvibelst = []\n    allvtracks = []\n    for index in tqdm(range(len(videos))):\n        gvibe, vtracks = get_concat_goodtracks(allvibe, tar, videos, index)\n        allvtracks.append(vtracks)\n        newvibelst.append(interpolate_track(gvibe))\n\n    newvibe = {key: [] for key in newvibelst[0].keys()}\n\n    for nvibe in newvibelst:\n        for key in newvibe:\n            newvibe[key].append(nvibe[key])\n\n    pkl.dump(newvibe, open(\"newvibe.pkl\", \"wb\"))\n"
  },
  {
    "path": "PBnet/src/recognition/compute_accuracy.py",
    "content": "import os\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom tqdm import tqdm\n\nfrom src.utils.get_model_and_data import get_model_and_data\nfrom src.utils.tensors import collate\n\nfrom src.evaluate.tools import save_metrics\nfrom src.parser.checkpoint import parser\n\nimport src.utils.fixseed  # noqa\n\n\ndef compute_accuracy(model, datasets, parameters):\n    device = parameters[\"device\"]\n    iterators = {key: DataLoader(datasets[key], batch_size=parameters[\"batch_size\"],\n                                 shuffle=False, num_workers=8, collate_fn=collate)\n                 for key in datasets.keys()}\n\n    model.eval()\n    num_labels = parameters[\"num_classes\"]\n\n    accuracies = {}\n    with torch.no_grad():\n        for key, iterator in iterators.items():\n            confusion = torch.zeros(num_labels, num_labels, dtype=torch.long)\n            for batch in tqdm(iterator, desc=f\"Computing {key} batch\"):\n                # Put everything in device\n                batch = {key: val.to(device) for key, val in batch.items()}\n                # forward pass\n                batch = model(batch)\n                yhat = batch[\"yhat\"].max(dim=1).indices\n                ygt = batch[\"y\"]\n                for label, pred in zip(ygt, yhat):\n                    confusion[label][pred] += 1\n            accuracy = (torch.trace(confusion)/torch.sum(confusion)).item()\n            accuracies[key] = accuracy\n    return accuracies\n        \n\ndef main():\n    # parse options\n    parameters, folder, checkpointname, epoch = parser()\n    model, datasets = get_model_and_data(parameters)\n    \n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=parameters[\"device\"])\n    model.load_state_dict(state_dict)\n\n    accuracies = compute_accuracy(model, datasets, parameters)\n\n    metricname = \"recognition_accuracies_on_samedata_{}.yaml\".format(epoch)\n    \n    evalpath = os.path.join(folder, metricname)\n    print(f\"Saving score: {evalpath}\")\n    save_metrics(evalpath, accuracies)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/recognition/get_model.py",
    "content": "from .models.stgcn import STGCN\n\n\ndef get_model(parameters):\n    layout = \"smpl\" if parameters[\"glob\"] else \"smpl_noglobal\"\n    \n    model = STGCN(in_channels=parameters[\"nfeats\"],\n                  num_class=parameters[\"num_classes\"],\n                  graph_args={\"layout\": layout, \"strategy\": \"spatial\"},\n                  edge_importance_weighting=True,\n                  device=parameters[\"device\"])\n    \n    model = model.to(parameters[\"device\"])\n    return model\n    \n"
  },
  {
    "path": "PBnet/src/recognition/models/stgcn.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .stgcnutils.tgcn import ConvTemporalGraphical\nfrom .stgcnutils.graph import Graph\n\n__all__ = [\"STGCN\"]\n\n\nclass STGCN(nn.Module):\n    r\"\"\"Spatial temporal graph convolutional networks.\n    Args:\n        in_channels (int): Number of channels in the input data\n        num_class (int): Number of classes for the classification task\n        graph_args (dict): The arguments for building the graph\n        edge_importance_weighting (bool): If ``True``, adds a learnable\n            importance weighting to the edges of the graph\n        **kwargs (optional): Other parameters for graph convolution units\n    Shape:\n        - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`\n        - Output: :math:`(N, num_class)` where\n            :math:`N` is a batch size,\n            :math:`T_{in}` is a length of input sequence,\n            :math:`V_{in}` is the number of graph nodes,\n            :math:`M_{in}` is the number of instance in a frame.\n    \"\"\"\n\n    def __init__(self, in_channels, num_class, graph_args,\n                 edge_importance_weighting, device, **kwargs):\n        super().__init__()\n\n        self.device = device\n        self.num_class = num_class\n        \n        self.losses = [\"accuracy\", \"cross_entropy\", \"mixed\"]\n        self.criterion = torch.nn.CrossEntropyLoss(reduction='mean')\n\n        # load graph\n        self.graph = Graph(**graph_args)\n        A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)\n        self.register_buffer('A', A)\n\n        # build networks\n        spatial_kernel_size = A.size(0)\n        temporal_kernel_size = 9\n        kernel_size = (temporal_kernel_size, spatial_kernel_size)\n        self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))\n        kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}\n        self.st_gcn_networks = nn.ModuleList((\n            st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),\n            st_gcn(64, 64, kernel_size, 1, **kwargs),\n            st_gcn(64, 64, kernel_size, 1, **kwargs),\n            st_gcn(64, 64, kernel_size, 1, **kwargs),\n            st_gcn(64, 128, kernel_size, 2, **kwargs),\n            st_gcn(128, 128, kernel_size, 1, **kwargs),\n            st_gcn(128, 128, kernel_size, 1, **kwargs),\n            st_gcn(128, 256, kernel_size, 2, **kwargs),\n            st_gcn(256, 256, kernel_size, 1, **kwargs),\n            st_gcn(256, 256, kernel_size, 1, **kwargs),\n        ))\n\n        # initialize parameters for edge importance weighting\n        if edge_importance_weighting:\n            self.edge_importance = nn.ParameterList([\n                nn.Parameter(torch.ones(self.A.size()))\n                for i in self.st_gcn_networks\n            ])\n        else:\n            self.edge_importance = [1] * len(self.st_gcn_networks)\n\n        # fcn for prediction\n        self.fcn = nn.Conv2d(256, num_class, kernel_size=1)\n\n    def forward(self, batch):\n        # TODO: use mask\n        # Received batch[\"x\"] as\n        #   Batch(48), Joints(23), Quat(4), Time(157\n        # Expecting:\n        #   Batch, Quat:4, Time, Joints, 1\n        x = batch[\"x\"].permute(0, 2, 3, 1).unsqueeze(4).contiguous()\n\n        # data normalization\n        N, C, T, V, M = x.size()\n        x = x.permute(0, 4, 3, 1, 2).contiguous()\n        x = x.view(N * M, V * C, T)\n        x = self.data_bn(x)\n        x = x.view(N, M, V, C, T)\n        x = 
x.permute(0, 1, 3, 4, 2).contiguous()\n        x = x.view(N * M, C, T, V)\n\n        # forward\n        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):\n            x, _ = gcn(x, self.A * importance)\n\n        # compute feature\n        # _, c, t, v = x.size()\n        # features = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1)\n        # batch[\"features\"] = features\n        \n        # global pooling\n        x = F.avg_pool2d(x, x.size()[2:])\n        x = x.view(N, M, -1, 1, 1).mean(dim=1)\n\n        # features\n        batch[\"features\"] = x.squeeze()\n        \n        # prediction\n        x = self.fcn(x)\n        x = x.view(x.size(0), -1)\n        batch[\"yhat\"] = x\n        return batch\n\n    def compute_accuracy(self, batch):\n        confusion = torch.zeros(self.num_class, self.num_class, dtype=int)\n        yhat = batch[\"yhat\"].max(dim=1).indices\n        ygt = batch[\"y\"]\n        for label, pred in zip(ygt, yhat):\n            confusion[label][pred] += 1\n        accuracy = torch.trace(confusion)/torch.sum(confusion)\n        return accuracy\n    \n    def compute_loss(self, batch):\n        cross_entropy = self.criterion(batch[\"yhat\"], batch[\"y\"])\n        mixed_loss = cross_entropy\n        \n        acc = self.compute_accuracy(batch)\n        losses = {\"cross_entropy\": cross_entropy.item(),\n                  \"mixed\": mixed_loss.item(),\n                  \"accuracy\": acc.item()}\n        return mixed_loss, losses\n\n\nclass st_gcn(nn.Module):\n    r\"\"\"Applies a spatial temporal graph convolution over an input graph sequence.\n    Args:\n        in_channels (int): Number of channels in the input sequence data\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel\n        stride (int, optional): Stride of the temporal convolution. Default: 1\n        dropout (int, optional): Dropout rate of the final output. Default: 0\n        residual (bool, optional): If ``True``, applies a residual mechanism. 
Default: ``True``\n    Shape:\n        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format\n        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format\n        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format\n        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format\n        where\n            :math:`N` is a batch size,\n            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,\n            :math:`T_{in}/T_{out}` is a length of input/output sequence,\n            :math:`V` is the number of graph nodes.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 kernel_size,\n                 stride=1,\n                 dropout=0,\n                 residual=True):\n        super().__init__()\n\n        assert len(kernel_size) == 2\n        assert kernel_size[0] % 2 == 1\n        padding = ((kernel_size[0] - 1) // 2, 0)\n\n        self.gcn = ConvTemporalGraphical(in_channels, out_channels,\n                                         kernel_size[1])\n\n        self.tcn = nn.Sequential(\n            nn.BatchNorm2d(out_channels),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(\n                out_channels,\n                out_channels,\n                (kernel_size[0], 1),\n                (stride, 1),\n                padding,\n            ),\n            nn.BatchNorm2d(out_channels),\n            nn.Dropout(dropout, inplace=True),\n        )\n\n        if not residual:\n            self.residual = lambda x: 0\n\n        elif (in_channels == out_channels) and (stride == 1):\n            self.residual = lambda x: x\n\n        else:\n            self.residual = nn.Sequential(\n                nn.Conv2d(\n                    in_channels,\n                    out_channels,\n                    kernel_size=1,\n                    stride=(stride, 1)),\n                nn.BatchNorm2d(out_channels),\n            )\n\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x, A):\n\n        res = self.residual(x)\n        x, A = self.gcn(x, A)\n        x = self.tcn(x) + res\n\n        return self.relu(x), A\n\n\nif __name__ == \"__main__\":\n    # quick smoke test: the constructor requires a device and forward expects a batch dict\n    model = STGCN(in_channels=3, num_class=60, edge_importance_weighting=True,\n                  graph_args={\"layout\": \"smpl_noglobal\", \"strategy\": \"spatial\"}, device=\"cpu\")\n    # batch[\"x\"] is (Batch, Joints, Channels, Time); forward permutes it to (N, C, T, V, M=1)\n    batch = {\"x\": torch.rand(10, 23, 3, 16)}\n    out = model(batch)\n    print(out[\"yhat\"].shape)\n"
  },
  {
    "path": "PBnet/src/recognition/models/stgcnutils/graph.py",
    "content": "import numpy as np\nimport pickle as pkl\n\nfrom src.config import SMPL_KINTREE_PATH\n\n\nclass Graph:\n    \"\"\" The Graph to model the skeletons extracted by the openpose\n    Args:\n        strategy (string): must be one of the follow candidates\n        - uniform: Uniform Labeling\n        - distance: Distance Partitioning\n        - spatial: Spatial Configuration\n        For more information, please refer to the section 'Partition Strategies'\n            in our paper (https://arxiv.org/abs/1801.07455).\n        layout (string): must be one of the follow candidates\n        - openpose: Is consists of 18 joints. For more information, please\n            refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output\n        - ntu-rgb+d: Is consists of 25 joints. For more information, please\n            refer to https://github.com/shahroudy/NTURGB-D\n        - smpl: Consists of 24/23 joints with without global rotation.\n        max_hop (int): the maximal distance between two connected nodes\n        dilation (int): controls the spacing between the kernel points\n    \"\"\"\n\n    def __init__(self,\n                 layout='openpose',\n                 strategy='uniform',\n                 kintree_path=SMPL_KINTREE_PATH,\n                 max_hop=1,\n                 dilation=1):\n        self.max_hop = max_hop\n        self.dilation = dilation\n\n        self.kintree_path = kintree_path\n        \n        self.get_edge(layout)\n        self.hop_dis = get_hop_distance(\n            self.num_node, self.edge, max_hop=max_hop)\n        self.get_adjacency(strategy)\n\n    def __str__(self):\n        return self.A\n\n    def get_edge(self, layout):\n        if layout == 'openpose':\n            self.num_node = 18\n            self_link = [(i, i) for i in range(self.num_node)]\n            neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12,\n                                                                        11),\n                             (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1),\n                             (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)]\n            self.edge = self_link + neighbor_link\n            self.center = 1\n        elif layout == 'smpl':\n            self.num_node = 24\n            self_link = [(i, i) for i in range(self.num_node)]\n            kt = pkl.load(open(self.kintree_path, \"rb\"))\n            neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])]\n            self.edge = self_link + neighbor_link\n            self.center = 0\n        elif layout == 'smpl_noglobal':\n            self.num_node = 23\n            self_link = [(i, i) for i in range(self.num_node)]\n            kt = pkl.load(open(self.kintree_path, \"rb\"))\n            neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])]\n            # remove the root joint\n            neighbor_1base = [n for n in neighbor_link if n[0] != 0 and n[1] != 0]\n            neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]\n            self.edge = self_link + neighbor_link\n            self.center = 0\n        elif layout == 'ntu-rgb+d':\n            self.num_node = 25\n            self_link = [(i, i) for i in range(self.num_node)]\n            neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21),\n                              (6, 5), (7, 6), (8, 7), (9, 21), (10, 9),\n                              (11, 10), (12, 11), (13, 1), (14, 13), (15, 14),\n                              (16, 15), (17, 1), (18, 17), 
(19, 18), (20, 19),\n                              (22, 23), (23, 8), (24, 25), (25, 12)]\n            neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]\n            self.edge = self_link + neighbor_link\n            self.center = 21 - 1\n        elif layout == 'ntu_edge':\n            self.num_node = 24\n            self_link = [(i, i) for i in range(self.num_node)]\n            neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6),\n                              (8, 7), (9, 2), (10, 9), (11, 10), (12, 11),\n                              (13, 1), (14, 13), (15, 14), (16, 15), (17, 1),\n                              (18, 17), (19, 18), (20, 19), (21, 22), (22, 8),\n                              (23, 24), (24, 12)]\n            neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]\n            self.edge = self_link + neighbor_link\n            self.center = 2\n        # elif layout=='customer settings'\n        #     pass\n        else:\n            raise NotImplementedError(\"This Layout is not supported\")\n\n    def get_adjacency(self, strategy):\n        valid_hop = range(0, self.max_hop + 1, self.dilation)\n        adjacency = np.zeros((self.num_node, self.num_node))\n        for hop in valid_hop:\n            adjacency[self.hop_dis == hop] = 1\n        normalize_adjacency = normalize_digraph(adjacency)\n\n        if strategy == 'uniform':\n            A = np.zeros((1, self.num_node, self.num_node))\n            A[0] = normalize_adjacency\n            self.A = A\n        elif strategy == 'distance':\n            A = np.zeros((len(valid_hop), self.num_node, self.num_node))\n            for i, hop in enumerate(valid_hop):\n                A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop]\n            self.A = A\n        elif strategy == 'spatial':\n            A = []\n            for hop in valid_hop:\n                a_root = np.zeros((self.num_node, self.num_node))\n                a_close = np.zeros((self.num_node, self.num_node))\n                a_further = np.zeros((self.num_node, self.num_node))\n                for i in range(self.num_node):\n                    for j in range(self.num_node):\n                        if self.hop_dis[j, i] == hop:\n                            if self.hop_dis[j, self.center] == self.hop_dis[\n                                    i, self.center]:\n                                a_root[j, i] = normalize_adjacency[j, i]\n                            elif self.hop_dis[j, self.\n                                              center] > self.hop_dis[i, self.\n                                                                     center]:\n                                a_close[j, i] = normalize_adjacency[j, i]\n                            else:\n                                a_further[j, i] = normalize_adjacency[j, i]\n                if hop == 0:\n                    A.append(a_root)\n                else:\n                    A.append(a_root + a_close)\n                    A.append(a_further)\n            A = np.stack(A)\n            self.A = A\n        else:\n            raise NotImplementedError(\"This Strategy is not supported\")\n\n\ndef get_hop_distance(num_node, edge, max_hop=1):\n    A = np.zeros((num_node, num_node))\n    for i, j in edge:\n        A[j, i] = 1\n        A[i, j] = 1\n\n    # compute hop steps\n    hop_dis = np.zeros((num_node, num_node)) + np.inf\n    transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]\n    arrive_mat = (np.stack(transfer_mat) > 0)\n    for d 
in range(max_hop, -1, -1):\n        hop_dis[arrive_mat[d]] = d\n    return hop_dis\n\n\ndef normalize_digraph(A):\n    Dl = np.sum(A, 0)\n    num_node = A.shape[0]\n    Dn = np.zeros((num_node, num_node))\n    for i in range(num_node):\n        if Dl[i] > 0:\n            Dn[i, i] = Dl[i]**(-1)\n    AD = np.dot(A, Dn)\n    return AD\n\n\ndef normalize_undigraph(A):\n    Dl = np.sum(A, 0)\n    num_node = A.shape[0]\n    Dn = np.zeros((num_node, num_node))\n    for i in range(num_node):\n        if Dl[i] > 0:\n            Dn[i, i] = Dl[i]**(-0.5)\n    DAD = np.dot(np.dot(Dn, A), Dn)\n    return DAD\n"
  },
  {
    "path": "PBnet/src/recognition/models/stgcnutils/tgcn.py",
    "content": "# The based unit of graph convolutional networks.\n\nimport torch\nimport torch.nn as nn\n\n\nclass ConvTemporalGraphical(nn.Module):\n\n    r\"\"\"The basic module for applying a graph convolution.\n    Args:\n        in_channels (int): Number of channels in the input sequence data\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int): Size of the graph convolving kernel\n        t_kernel_size (int): Size of the temporal convolving kernel\n        t_stride (int, optional): Stride of the temporal convolution. Default: 1\n        t_padding (int, optional): Temporal zero-padding added to both sides of\n            the input. Default: 0\n        t_dilation (int, optional): Spacing between temporal kernel elements.\n            Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the output.\n            Default: ``True``\n    Shape:\n        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format\n        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format\n        - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format\n        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format\n        where\n            :math:`N` is a batch size,\n            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,\n            :math:`T_{in}/T_{out}` is a length of input/output sequence,\n            :math:`V` is the number of graph nodes.\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 kernel_size,\n                 t_kernel_size=1,\n                 t_stride=1,\n                 t_padding=0,\n                 t_dilation=1,\n                 bias=True):\n        super().__init__()\n\n        self.kernel_size = kernel_size\n        self.conv = nn.Conv2d(\n            in_channels,\n            out_channels * kernel_size,\n            kernel_size=(t_kernel_size, 1),\n            padding=(t_padding, 0),\n            stride=(t_stride, 1),\n            dilation=(t_dilation, 1),\n            bias=bias)\n\n    def forward(self, x, A):\n        assert A.size(0) == self.kernel_size\n\n        x = self.conv(x)\n\n        n, kc, t, v = x.size()\n        x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v)\n        x = torch.einsum('nkctv,kvw->nctw', (x, A))\n\n        return x.contiguous(), A\n"
  },
  {
    "path": "PBnet/src/render/renderer.py",
    "content": "\"\"\"\nThis script is borrowed from https://github.com/mkocabas/VIBE\n Adhere to their licence to use this script\n It has been modified\n\"\"\"\n\nimport math\nimport trimesh\nimport pyrender\nimport numpy as np\nfrom pyrender.constants import RenderFlags\nimport os\n\n\nos.environ['PYOPENGL_PLATFORM'] = 'egl'\nSMPL_MODEL_DIR = \"models/smpl/\"\n\n\ndef get_smpl_faces():\n    return np.load(os.path.join(SMPL_MODEL_DIR, \"smplfaces.npy\"))\n\n\nclass WeakPerspectiveCamera(pyrender.Camera):\n    def __init__(self,\n                 scale,\n                 translation,\n                 znear=pyrender.camera.DEFAULT_Z_NEAR,\n                 zfar=None,\n                 name=None):\n        super(WeakPerspectiveCamera, self).__init__(\n            znear=znear,\n            zfar=zfar,\n            name=name,\n        )\n        self.scale = scale\n        self.translation = translation\n\n    def get_projection_matrix(self, width=None, height=None):\n        P = np.eye(4)\n        P[0, 0] = self.scale[0]\n        P[1, 1] = self.scale[1]\n        P[0, 3] = self.translation[0] * self.scale[0]\n        P[1, 3] = -self.translation[1] * self.scale[1]\n        P[2, 2] = -1\n        return P\n\n\nclass Renderer:\n    def __init__(self, background=None, resolution=(224, 224), bg_color=[0, 0, 0, 0.5], orig_img=False, wireframe=False):\n        width, height = resolution\n        self.background = np.zeros((height, width, 3))\n        self.resolution = resolution\n\n        self.faces = get_smpl_faces()\n        self.orig_img = orig_img\n        self.wireframe = wireframe\n        self.renderer = pyrender.OffscreenRenderer(\n            viewport_width=self.resolution[0],\n            viewport_height=self.resolution[1],\n            point_size=0.5\n        )\n\n        # set the scene\n        self.scene = pyrender.Scene(bg_color=bg_color, ambient_light=(0.4, 0.4, 0.4))\n\n        light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=4)\n\n        light_pose = np.eye(4)\n        light_pose[:3, 3] = [0, -1, 1]\n        self.scene.add(light, pose=light_pose.copy())\n\n        light_pose[:3, 3] = [0, 1, 1]\n        self.scene.add(light, pose=light_pose.copy())\n\n        light_pose[:3, 3] = [1, 1, 2]\n        self.scene.add(light, pose=light_pose.copy())\n\n        \"\"\"ok\n        light_pose = np.eye(4)\n        light_pose[:3, 3] = [0, -1, 1]\n        self.scene.add(light, pose=light_pose)\n\n        light_pose[:3, 3] = [0, 1, 1]\n        self.scene.add(light, pose=light_pose)\n\n        light_pose[:3, 3] = [1, 1, 2]\n        self.scene.add(light, pose=light_pose)\n        \"\"\"\n\n        # light_pose[:3, 3] = [0, -2, 2]\n        # [droite, hauteur, profondeur camera]\n        \"\"\"\n        light_pose = np.eye(4)\n        light_pose[:3, 3] = [0, -1, 1]\n        self.scene.add(light, pose=light_pose)\n\n        light_pose[:3, 3] = [0, 1, 1]\n        self.scene.add(light, pose=light_pose)\n\n        light_pose[:3, 3] = [1, 1, 2]\n        self.scene.add(light, pose=light_pose)\n        \"\"\"\n\n    def render(self, img, verts, cam, angle=None, axis=None, mesh_filename=None, color=[1.0, 1.0, 0.9]):\n        mesh = trimesh.Trimesh(vertices=verts, faces=self.faces, process=False)\n\n        Rx = trimesh.transformations.rotation_matrix(math.radians(180), [1, 0, 0])\n        mesh.apply_transform(Rx)\n\n        if mesh_filename is not None:\n            mesh.export(mesh_filename)\n\n        if angle and axis:\n            R = 
trimesh.transformations.rotation_matrix(math.radians(angle), axis)\n            mesh.apply_transform(R)\n\n        sx, sy, tx, ty = cam\n\n        camera = WeakPerspectiveCamera(\n            scale=[sx, sy],\n            translation=[tx, ty],\n            zfar=1000.\n        )\n\n        material = pyrender.MetallicRoughnessMaterial(\n            metallicFactor=0.7,\n            alphaMode='OPAQUE',\n            baseColorFactor=(color[0], color[1], color[2], 1.0)\n        )\n\n        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)\n\n        mesh_node = self.scene.add(mesh, 'mesh')\n\n        camera_pose = np.eye(4)\n        cam_node = self.scene.add(camera, pose=camera_pose)\n\n        if self.wireframe:\n            render_flags = RenderFlags.RGBA | RenderFlags.ALL_WIREFRAME\n        else:\n            render_flags = RenderFlags.RGBA\n\n        rgb, _ = self.renderer.render(self.scene, flags=render_flags)\n        valid_mask = (rgb[:, :, -1] > 0)[:, :, np.newaxis]\n        output_img = rgb[:, :, :-1] * valid_mask + (1 - valid_mask) * img\n        image = output_img.astype(np.uint8)\n\n        self.scene.remove_node(mesh_node)\n        self.scene.remove_node(cam_node)\n\n        return image\n\n\ndef get_renderer(width, height):\n    renderer = Renderer(resolution=(width, height),\n                        bg_color=[1, 1, 1, 0.5],\n                        orig_img=False,\n                        wireframe=False)\n    return renderer\n"
  },
  {
    "path": "PBnet/src/render/rendermotion.py",
    "content": "import numpy as np\nimport imageio\nimport os\nimport argparse\nfrom tqdm import tqdm\nfrom .renderer import get_renderer\n\n\ndef get_rotation(theta=np.pi/3):\n    import src.utils.rotation_conversions as geometry\n    import torch\n    axis = torch.tensor([0, 1, 0], dtype=torch.float)\n    axisangle = theta*axis\n    matrix = geometry.axis_angle_to_matrix(axisangle)\n    return matrix.numpy()\n\n\ndef render_video(meshes, key, action, renderer, savepath, background, cam=(0.75, 0.75, 0, 0.10), color=[0.11, 0.53, 0.8]):\n    writer = imageio.get_writer(savepath, fps=30)\n    # center the first frame\n    meshes = meshes - meshes[0].mean(axis=0)\n    # matrix = get_rotation(theta=np.pi/4)\n    # meshes = meshes[45:]\n    # meshes = np.einsum(\"ij,lki->lkj\", matrix, meshes)\n    imgs = []\n    for mesh in tqdm(meshes, desc=f\"Visualize {key}, action {action}\"):\n        img = renderer.render(background, mesh, cam, color=color)\n        imgs.append(img)\n        # show(img)\n\n    imgs = np.array(imgs)\n    masks = ~(imgs/255. > 0.96).all(-1)\n\n    coords = np.argwhere(masks.sum(axis=0))\n    y1, x1 = coords.min(axis=0)\n    y2, x2 = coords.max(axis=0)\n\n    for cimg in imgs[:, y1:y2, x1:x2]:\n        writer.append_data(cimg)\n    writer.close()\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"filename\")\n    opt = parser.parse_args()\n    filename = opt.filename\n    savefolder = os.path.splitext(filename)[0]\n    os.makedirs(savefolder, exist_ok=True)\n\n    output = np.load(filename)\n\n    if output.shape[0] == 3:\n        visualization, generation, reconstruction = output\n        output = {\"visualization\": visualization,\n                  \"generation\": generation,\n                  \"reconstruction\": reconstruction}\n    else:\n        # output = {f\"generation_{key}\": output[key] for key in range(2)} #  len(output))}\n        # output = {f\"generation_{key}\": output[key] for key in range(len(output))}\n        output = {f\"generation_{key}\": output[key] for key in range(len(output))}\n\n    width = 1024\n    height = 1024\n\n    background = np.zeros((height, width, 3))\n    renderer = get_renderer(width, height)\n\n    # if duration mode, put back durations\n    if output[\"generation_3\"].shape[-1] == 100:\n        output[\"generation_0\"] = output[\"generation_0\"][:, :, :, :40]\n        output[\"generation_1\"] = output[\"generation_1\"][:, :, :, :60]\n        output[\"generation_2\"] = output[\"generation_2\"][:, :, :, :80]\n        output[\"generation_3\"] = output[\"generation_3\"][:, :, :, :100]\n    elif output[\"generation_3\"].shape[-1] == 160:\n        print(\"160 mode\")\n        output[\"generation_0\"] = output[\"generation_0\"][:, :, :, :100]\n        output[\"generation_1\"] = output[\"generation_1\"][:, :, :, :120]\n        output[\"generation_2\"] = output[\"generation_2\"][:, :, :, :140]\n        output[\"generation_3\"] = output[\"generation_3\"][:, :, :, :160]\n\n    # if str(action) == str(1) and str(key) == \"generation_4\":\n    for key in output:\n        vidmeshes = output[key]\n        for action in range(len(vidmeshes)):\n            meshes = vidmeshes[action].transpose(2, 0, 1)\n            path = os.path.join(savefolder, \"action{}_{}.mp4\".format(action, key))\n            render_video(meshes, key, action, renderer, path, background)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "PBnet/src/train/__init__.py",
    "content": ""
  },
  {
    "path": "PBnet/src/train/train_cvae_ganloss_ann_eye.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nimport os\nimport torch\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.utils.utils import CudaDataLoader\n# import torch.utils.data.dataloader as DataLoader\nfrom src.train.trainer_gan_ann import train\n\nimport src.utils.fixseed  # noqa\n\nfrom src.parser.training import parser\n\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport importlib\nimport random\nimport numpy as np\n\nJOINTSTYPES = [\"a2m\", \"a2mpl\", \"smpl\", \"vibe\", \"vertices\"]\n\nLOSSES = [\"rc\", \"kl\", \"rcw\", \"ssim\"]  # not used: \"hp\", \"mmd\", \"vel\", \"velxyz\"\n\nMODELTYPES = [\"cvae\"]  # not used: \"cae\"\nARCHINAMES = [\"fc\", \"gru\", \"transformer\", \"transformerreemb5\", \"transformermel\", \"transgru\", \"grutrans\", \"autotrans\"]\n\nclass ConvNormRelu(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, norm='batch', leaky=True):\n        super(ConvNormRelu, self).__init__()\n        layers = []\n        layers.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding))\n        if norm == 'batch':\n            layers.append(nn.BatchNorm1d(out_channels))\n        if leaky:\n            layers.append(nn.LeakyReLU(0.2))\n        else:\n            layers.append(nn.ReLU())\n        self.model = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.model(x)\n\n    \n\nclass D_patchgan(nn.Module):\n    def __init__(self, n_downsampling=2, pos_dim=6, eye_dim=0, norm='batch'):\n        super(D_patchgan, self).__init__()\n        ndf = 64\n        self.eye_dim = eye_dim\n        self.dim = pos_dim + self.eye_dim\n        self.conv1 = nn.Conv1d(self.dim, ndf, kernel_size=4, stride=2, padding=1)\n        self.leaky_relu = nn.LeakyReLU(0.2)\n\n        layers = []\n        for n in range(0, n_downsampling):\n            nf_mult = min(2**n, 8)\n            layers.append(ConvNormRelu(ndf * nf_mult, ndf * nf_mult * 2, kernel_size=4, stride=2, padding=1, norm=norm))\n\n        nf_mult = min(2**n_downsampling, 8)\n        layers.append(ConvNormRelu(ndf * nf_mult, ndf * nf_mult, kernel_size=4, stride=1, padding=1, norm=norm))\n\n        layers.append(nn.Conv1d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1))\n        self.model = nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.leaky_relu(out)\n        out = self.model(out)\n        return out\n\n    def calculate_GAN_loss(self, batch):\n            x = batch[\"x\"] #bs, nf, 6\n            x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64\n            output = batch[\"output\"]+x_ref #bs, nf, 6\n\n            real_pose_score = self.forward(x.permute(0,2,1))\n            fake_pose_score = self.forward(output.permute(0,2,1))\n\n            D_loss = F.binary_cross_entropy_with_logits(real_pose_score, torch.ones_like(real_pose_score)) + F.binary_cross_entropy_with_logits(fake_pose_score, torch.zeros_like(fake_pose_score))\n            G_loss = F.binary_cross_entropy_with_logits(fake_pose_score, torch.ones_like(fake_pose_score))\n\n            return D_loss.mean(), G_loss.mean()\n\n\ndef get_model(parameters):\n    modeltype = parameters[\"modeltype\"]\n    archiname = parameters[\"archiname\"]\n\n    archi_module = importlib.import_module(f'.architectures.{archiname}', package=\"src.models\")\n    Encoder = 
archi_module.__getattribute__(f\"Encoder_{archiname.upper()}\")\n    Decoder = archi_module.__getattribute__(f\"Decoder_{archiname.upper()}\")\n\n    model_module = importlib.import_module(f'.modeltype.{modeltype}', package=\"src.models\")\n    Model = model_module.__getattribute__(f\"{modeltype.upper()}\")\n\n    encoder = Encoder(**parameters)\n    decoder = Decoder(**parameters)\n    \n    # parameters[\"outputxyz\"] = \"rcxyz\" in parameters[\"lambdas\"]\n    return Model(encoder, decoder, **parameters).to(parameters[\"device\"])\n\ndef do_epochs(model, model_d, dataset, parameters, optimizer_g, optimizer_d, scheduler_g, scheduler_d, writer):\n    # train_iterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n    #                             shuffle=True, num_workers=8, pin_memory=True)\n    train_iterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=True, num_workers=16, collate_fn=collate, pin_memory=True)\n    train_iterator = CudaDataLoader(train_iterator, device = 'cuda:0')\n\n    logpath = os.path.join(parameters[\"folder\"], \"training.log\")\n    with open(logpath, \"w\") as logfile:\n        for epoch in range(1, parameters[\"num_epochs\"]+1):\n            dict_loss = train(model, model_d, optimizer_g, optimizer_d, train_iterator, model.device, epoch)\n\n            for key in dict_loss.keys():\n                dict_loss[key] /= len(train_iterator)\n                writer.add_scalar(f\"Loss/{key}\", dict_loss[key], epoch)\n\n            epochlog = f\"Epoch {epoch}, train losses: {dict_loss}\"\n            print(epochlog)\n            print(epochlog, file=logfile)\n            scheduler_g.step()\n            scheduler_d.step()\n            if ((epoch % parameters[\"snapshot\"]) == 0) or (epoch == parameters[\"num_epochs\"]):\n                checkpoint_path = os.path.join(parameters[\"folder\"],\n                                               'checkpoint_{:04d}.pth.tar'.format(epoch))\n                print('Saving checkpoint {}'.format(checkpoint_path))\n                torch.save(model.state_dict(), checkpoint_path)\n\n            writer.flush()\n\n\nif __name__ == '__main__':\n    # setup_seed(1234)\n    # parse options\n    parameters = parser()\n    \n    # logging tensorboard\n    writer = SummaryWriter(log_dir=parameters[\"folder\"])\n\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = parameters[\"gpu\"]\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        from src.datasets.datasets_crema_pos_eye_fast import CREMA\n        from src.utils.tensors_eye import collate\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'train')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        if parameters[\"first3\"]=='True':\n            if parameters[\"eye\"]=='True':\n                from src.utils.tensors_eye import collate\n                from src.datasets.datasets_hdtf_pos_chunk_norm_eye_first3 import HDTF\n            else:\n                from src.utils.tensors import collate\n                from src.datasets.datasets_hdtf_pos_chunk_norm_2_first3 import HDTF\n        else:\n            if parameters[\"eye\"]=='True':\n                from 
src.utils.tensors_eye import collate\n                from src.datasets.datasets_hdtf_pos_chunk_norm_eye_fast import HDTF\n            else:\n                from src.utils.tensors_eye import collate\n                from src.datasets.datasets_hdtf_pos_chunk_norm_2 import HDTF\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'train')\n        dataset.update_parameters(parameters)\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    \n\n    model = get_model(parameters)\n    if parameters['eye']=='True':\n        model_d = D_patchgan(pos_dim=parameters[\"pos_dim\"], eye_dim=parameters[\"eye_dim\"]).to(parameters[\"device\"])\n    else:\n        model_d = D_patchgan(pos_dim=parameters[\"pos_dim\"]).to(parameters[\"device\"])\n    # optimizer\n    optimizer_g = torch.optim.AdamW(model.parameters(), lr=parameters[\"lr\"])\n    optimizer_d = torch.optim.AdamW(model_d.parameters(), lr=parameters[\"lr\"])\n    scheduler_g = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_g, T_max=parameters[\"num_epochs\"], eta_min=2e-5)\n    scheduler_d = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_d, T_max=parameters[\"num_epochs\"], eta_min=2e-5)\n    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))\n    print(\"Training model..\")\n    do_epochs(model, model_d, dataset, parameters, optimizer_g, optimizer_d, scheduler_g, scheduler_d, writer)\n\n    writer.close()\n"
  },
  {
    "path": "PBnet/src/train/train_cvae_ganloss_ann_fast.py",
    "content": "import sys\nsys.path.append('your_path/PBnet')\n\nimport os\nimport torch\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom src.utils.utils import MultiEpochsDataLoader as DataLoader\nfrom src.train.trainer_gan_ann import train\nfrom src.utils.tensors import collate\nimport src.utils.fixseed  # noqa\n\nfrom src.parser.training import parser\nfrom src.datasets.datasets_crema_pos import CREMA\nfrom src.datasets.datasets_hdtf_pos_chunk_norm_2 import HDTF\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport importlib\n\nJOINTSTYPES = [\"a2m\", \"a2mpl\", \"smpl\", \"vibe\", \"vertices\"]\n\nLOSSES = [\"rc\", \"kl\", \"rcw\", \"ssim\"]  # not used: \"hp\", \"mmd\", \"vel\", \"velxyz\"\n\nMODELTYPES = [\"cvae\"]  # not used: \"cae\"\nARCHINAMES = [\"fc\", \"gru\", \"transformer\", \"transgru\", \"grutrans\", \"autotrans\"]\n\nclass ConvNormRelu(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, norm='batch', leaky=True):\n        super(ConvNormRelu, self).__init__()\n        layers = []\n        layers.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding))\n        if norm == 'batch':\n            layers.append(nn.BatchNorm1d(out_channels))\n        if leaky:\n            layers.append(nn.LeakyReLU(0.2))\n        else:\n            layers.append(nn.ReLU())\n        self.model = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.model(x)\n\n    \n\nclass D_patchgan(nn.Module):\n    def __init__(self, n_downsampling=2, norm='batch'):\n        super(D_patchgan, self).__init__()\n        ndf = 64\n        self.conv1 = nn.Conv1d(6, ndf, kernel_size=4, stride=2, padding=1)\n        self.leaky_relu = nn.LeakyReLU(0.2)\n\n        layers = []\n        for n in range(0, n_downsampling):\n            nf_mult = min(2**n, 8)\n            layers.append(ConvNormRelu(ndf * nf_mult, ndf * nf_mult * 2, kernel_size=4, stride=2, padding=1, norm=norm))\n\n        nf_mult = min(2**n_downsampling, 8)\n        layers.append(ConvNormRelu(ndf * nf_mult, ndf * nf_mult, kernel_size=4, stride=1, padding=1, norm=norm))\n\n        layers.append(nn.Conv1d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1))\n        self.model = nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.leaky_relu(out)\n        out = self.model(out)\n        return out\n\n    def calculate_GAN_loss(self, batch):\n            x = batch[\"x\"] #bs, nf, 6\n            x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64\n            output = batch[\"output\"]+x_ref #bs, nf, 6\n\n            real_pose_score = self.forward(x.permute(0,2,1))\n            fake_pose_score = self.forward(output.permute(0,2,1))\n\n            D_loss = F.binary_cross_entropy_with_logits(real_pose_score, torch.ones_like(real_pose_score)) + F.binary_cross_entropy_with_logits(fake_pose_score, torch.zeros_like(fake_pose_score))\n            G_loss = F.binary_cross_entropy_with_logits(fake_pose_score, torch.ones_like(fake_pose_score))\n\n            return D_loss.mean(), G_loss.mean()\n\n\ndef get_model(parameters):\n    modeltype = parameters[\"modeltype\"]\n    archiname = parameters[\"archiname\"]\n\n    archi_module = importlib.import_module(f'.architectures.{archiname}', package=\"src.models\")\n    Encoder = archi_module.__getattribute__(f\"Encoder_{archiname.upper()}\")\n    Decoder = archi_module.__getattribute__(f\"Decoder_{archiname.upper()}\")\n\n    model_module = 
importlib.import_module(f'.modeltype.{modeltype}', package=\"src.models\")\n    Model = model_module.__getattribute__(f\"{modeltype.upper()}\")\n\n    encoder = Encoder(**parameters)\n    decoder = Decoder(**parameters)\n    \n    # parameters[\"outputxyz\"] = \"rcxyz\" in parameters[\"lambdas\"]\n    return Model(encoder, decoder, **parameters).to(parameters[\"device\"])\n\ndef do_epochs(model, model_d, dataset, parameters, optimizer_g, optimizer_d, scheduler_g, scheduler_d, writer):\n    # train_iterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n    #                             shuffle=True, num_workers=8, pin_memory=True)\n    train_iterator = DataLoader(dataset, batch_size=parameters[\"batch_size\"],\n                                shuffle=True, num_workers=8, collate_fn=collate, pin_memory = True)\n\n    logpath = os.path.join(parameters[\"folder\"], \"training.log\")\n    with open(logpath, \"w\") as logfile:\n        for epoch in range(1, parameters[\"num_epochs\"]+1):\n            dict_loss = train(model, model_d, optimizer_g, optimizer_d, train_iterator, model.device, epoch)\n\n            for key in dict_loss.keys():\n                dict_loss[key] /= len(train_iterator)\n                writer.add_scalar(f\"Loss/{key}\", dict_loss[key], epoch)\n\n            epochlog = f\"Epoch {epoch}, train losses: {dict_loss}\"\n            print(epochlog)\n            print(epochlog, file=logfile)\n            scheduler_g.step()\n            scheduler_d.step()\n            if ((epoch % parameters[\"snapshot\"]) == 0) or (epoch == parameters[\"num_epochs\"]):\n                checkpoint_path = os.path.join(parameters[\"folder\"],\n                                               'checkpoint_{:04d}.pth.tar'.format(epoch))\n                print('Saving checkpoint {}'.format(checkpoint_path))\n                torch.save(model.state_dict(), checkpoint_path)\n\n            writer.flush()\n\n\nif __name__ == '__main__':\n\n    # parse options\n    parameters = parser()\n    \n    # logging tensorboard\n    writer = SummaryWriter(log_dir=parameters[\"folder\"])\n\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = parameters[\"gpu\"]\n    dataset_name = parameters[\"dataset\"]\n    if dataset_name == 'crema':\n        # data path\n        data_dir = \"/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\"\n        # model and dataset\n        dataset = CREMA(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'train')\n        dataset.update_parameters(parameters)\n    elif dataset_name == 'hdtf':\n        data_dir = \"/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz\"\n        dataset = HDTF(data_dir=data_dir,\n                        max_num_frames=parameters[\"num_frames\"],\n                        mode = 'train')\n        dataset.update_parameters(parameters)\n    else:\n        dataset = None\n        print('Dataset can not be found!!')\n\n    \n\n    model = get_model(parameters)\n    model_d = D_patchgan().to(parameters[\"device\"])\n    # optimizer\n    optimizer_g = torch.optim.AdamW(model.parameters(), lr=parameters[\"lr\"])\n    optimizer_d = torch.optim.AdamW(model_d.parameters(), lr=parameters[\"lr\"])\n    scheduler_g = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_g, T_max=parameters[\"num_epochs\"], eta_min=2e-5)\n    scheduler_d = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_d, T_max=parameters[\"num_epochs\"], eta_min=2e-5)\n    print('Total params: %.2fM' % 
(sum(p.numel() for p in model.parameters()) / 1000000.0))\n    print(\"Training model..\")\n    do_epochs(model, model_d, dataset, parameters, optimizer_g, optimizer_d, scheduler_g, scheduler_d, writer)\n\n    writer.close()\n"
  },
  {
    "path": "PBnet/src/train/trainer.py",
    "content": "import torch\nfrom tqdm import tqdm\n\n\ndef train_or_test(model, optimizer, iterator, device, mode=\"train\"):\n    if mode == \"train\":\n        model.train()\n        grad_env = torch.enable_grad\n    elif mode == \"test\":\n        model.eval()\n        grad_env = torch.no_grad\n    else:\n        raise ValueError(\"This mode is not recognized.\")\n\n    # loss of the epoch\n    dict_loss = {loss: 0 for loss in model.losses}\n\n    with grad_env():\n        for i, batch in tqdm(enumerate(iterator), desc=\"Computing batch\"):\n            # Put everything in device\n            batch = {key: val.to(device) for key, val in batch.items() if key!='videoname'}\n\n            if mode == \"train\":\n                # update the gradients to zero\n                optimizer.zero_grad()\n\n            # forward pass\n            batch = model(batch)\n            mixed_loss, losses = model.compute_loss(batch)\n            \n            for key in dict_loss.keys():\n                dict_loss[key] += losses[key]\n\n            if mode == \"train\":\n                # backward pass\n                mixed_loss.backward()\n                # update the weights\n                optimizer.step()\n\n            if i % 10 == 0:\n                print(losses)\n    return dict_loss\n\n\ndef train(model, optimizer, iterator, device):\n    return train_or_test(model, optimizer, iterator, device, mode=\"train\")\n\n\ndef test(model, optimizer, iterator, device):\n    return train_or_test(model, optimizer, iterator, device, mode=\"test\")\n"
  },
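  {
    "path": "PBnet/examples/toy_trainer_interface.py",
    "content": "# Illustrative sketch only -- this file is NOT part of the original codebase.\n# It shows, under stated assumptions, the minimal interface that\n# PBnet/src/train/trainer.py expects from a model: a `losses` list of keys to\n# aggregate, a forward pass that writes its outputs back into the batch dict,\n# and compute_loss() returning (total_loss, per_loss_dict).\n# Assumes the PBnet root is on sys.path so that `src` is importable.\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom src.train.trainer import train\n\n\nclass ToyModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.losses = ['rc']  # keys accumulated by train_or_test\n        self.linear = nn.Linear(6, 6)\n\n    def forward(self, batch):\n        # the trainer re-uses the same dict, so predictions are written back into it\n        batch['output'] = self.linear(batch['x'])\n        return batch\n\n    def compute_loss(self, batch):\n        rc = F.mse_loss(batch['output'], batch['x'])\n        return rc, {'rc': rc.item()}\n\n\nif __name__ == '__main__':\n    model = ToyModel()\n    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)\n    # one dummy batch shaped like the 6-dim pose data: (batch, frames, 6)\n    iterator = [{'x': torch.randn(2, 40, 6), 'videoname': ['a', 'b']}]\n    print(train(model, optimizer, iterator, device='cpu'))\n"
  },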
  {
    "path": "PBnet/src/train/trainer_gan.py",
    "content": "import torch\nfrom tqdm import tqdm\nimport time\n\n\ndef train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode=\"train\", epoch = 0):\n    if mode == \"train\":\n        model.train()\n        model_d.train()\n        grad_env = torch.enable_grad\n    elif mode == \"test\":\n        model.eval()\n        model_d.eval()\n        grad_env = torch.no_grad\n    else:\n        raise ValueError(\"This mode is not recognized.\")\n\n    # loss of the epoch\n    dict_loss = {loss: 0 for loss in (model.losses)}\n    dict_loss['Dloss'] = 0\n    dict_loss['Gloss'] = 0\n\n    with grad_env():\n        start_time = time.time()  # end\n        # print(f'load time {end_time- start_time}')\n\n        for i, batch in tqdm(enumerate(iterator), desc=\"Computing batch\"):\n            # Put everything in device\n            # end_time = time.time()\n            # print(\"load_cost: \", - start_time + end_time)\n            # start_time = time.time()\n\n            batch = {key: val.to(device) for key, val in batch.items() if key!='videoname'}\n\n            if mode == \"train\":\n                # update the gradients to zero\n                optimizer_g.zero_grad()\n                optimizer_d.zero_grad()\n\n            # forward pass\n            batch = model(batch)\n            mixed_loss, losses = model.compute_loss(batch, epoch)\n            D_loss, G_loss = model_d.calculate_GAN_loss(batch)\n\n            end_time = time.time()\n            print(\"forward: \", - start_time + end_time)\n            start_time = time.time()\n            \n            for key in dict_loss.keys():\n                if key != 'Gloss' and key != 'Dloss':\n                    dict_loss[key] += losses[key]\n            \n            dict_loss['Dloss'] += D_loss.item()\n            dict_loss['Gloss'] += G_loss.item()\n\n            if mode == \"train\":\n                # backward pass\n                (mixed_loss + (G_loss + D_loss * 0.5) ).backward()\n                torch.nn.utils.clip_grad_norm_(model.parameters(), 2.)\n                # update the weights\n                optimizer_g.step()\n                optimizer_d.step()\n            \n            end_time = time.time()\n            print(\"back: \", - start_time + end_time)\n            start_time = time.time()\n\n            # if i % 10 == 0:\n            #     print(dict_loss)\n    return dict_loss\n\n\ndef train(model, model_d, optimizer_g, optimizer_d, iterator, device):\n    return train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode=\"train\")\n\n\ndef test(model, model_d, optimizer_g, optimizer_d, iterator, device):\n    return train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode=\"test\")\n"
  },
  {
    "path": "PBnet/src/train/trainer_gan_ann.py",
    "content": "import torch\nfrom tqdm import tqdm\nimport time\n\n\ndef train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode=\"train\", epoch = 0):\n    if mode == \"train\":\n        model.train()\n        model_d.train()\n        grad_env = torch.enable_grad\n    elif mode == \"test\":\n        model.eval()\n        model_d.eval()\n        grad_env = torch.no_grad\n    else:\n        raise ValueError(\"This mode is not recognized.\")\n\n    # loss of the epoch\n    dict_loss = {loss: 0 for loss in (model.losses)}\n    dict_loss['Dloss'] = 0\n    dict_loss['Gloss'] = 0\n\n    with grad_env():\n        # start_time = time.time()  # end\n        # print(f'load time {end_time- start_time}')\n\n        for i, batch in tqdm(enumerate(iterator), desc=\"Computing batch\"):\n            # Put everything in device\n\n            # end_time = time.time()\n            # print(\"load_cost: \", - start_time + end_time)\n            # start_time = time.time() \n\n            batch = {key: val.to(device) for key, val in batch.items() if key!='videoname'}\n\n            # end_time = time.time()\n            # print(\"tocuda_cost: \", - start_time + end_time)\n            # start_time = time.time()\n\n            if mode == \"train\":\n                # update the gradients to zero\n                optimizer_g.zero_grad()\n                optimizer_d.zero_grad()\n\n            # forward pass\n            batch = model(batch)\n            mixed_loss, losses = model.compute_loss(batch, epoch)\n            D_loss, G_loss = model_d.calculate_GAN_loss(batch)\n\n            # end_time = time.time()\n            # print(\"forward: \", - start_time + end_time)\n            # start_time = time.time()\n            \n            for key in dict_loss.keys():\n                if key != 'Gloss' and key != 'Dloss':\n                    dict_loss[key] += losses[key]\n            \n            dict_loss['Dloss'] += D_loss.item()\n            dict_loss['Gloss'] += G_loss.item()\n\n            if mode == \"train\":\n                # backward pass\n                ((mixed_loss + (G_loss + D_loss) )).backward()\n\n                torch.nn.utils.clip_grad_norm_(model.parameters(), 2.)\n                # update the weights\n                optimizer_g.step()\n                optimizer_d.step()\n            \n            # end_time = time.time()\n            # print(\"back: \", - start_time + end_time)\n            # start_time = time.time()\n\n            # if i % 10 == 0:\n            #     print(dict_loss)\n    return dict_loss\n\n\ndef train(model, model_d, optimizer_g, optimizer_d, iterator, device, epoch):\n    return train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode=\"train\", epoch = epoch)\n\n\ndef test(model, model_d, optimizer_g, optimizer_d, iterator, device):\n    return train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode=\"test\")\n"
  },
  {
    "path": "PBnet/src/utils/PYTORCH3D_LICENSE",
    "content": "BSD License\n\nFor PyTorch3D software\n\nCopyright (c) Facebook, Inc. and its affiliates. All rights reserved.\n\nRedistribution and use in source and binary forms, with or without modification,\nare permitted provided that the following conditions are met:\n\n * Redistributions of source code must retain the above copyright notice, this\n    list of conditions and the following disclaimer.\n\n * Redistributions in binary form must reproduce the above copyright notice,\n    this list of conditions and the following disclaimer in the documentation\n       and/or other materials provided with the distribution.\n\n * Neither the name Facebook nor the names of its contributors may be used to\n    endorse or promote products derived from this software without specific\n       prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\nANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\nWARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR\nANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\nLOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\nANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\nSOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "PBnet/src/utils/__init__.py",
    "content": ""
  },
  {
    "path": "PBnet/src/utils/fixseed.py",
    "content": "import numpy as np\nimport torch\nimport random\n\n\ndef fixseed(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n\n\nSEED = 10\nEVALSEED = 0\n# Provoc warning: not fully functionnal yet\n# torch.set_deterministic(True)\ntorch.backends.cudnn.benchmark = False\n\nfixseed(SEED)\n"
  },
  {
    "path": "PBnet/src/utils/get_model_and_data.py",
    "content": "from ..datasets.get_dataset import get_datasets\nfrom ..recognition.get_model import get_model as get_rec_model\nfrom ..models.get_model import get_model as get_gen_model\n\n\ndef get_model_and_data(parameters):\n    datasets = get_datasets(parameters)\n\n    if parameters[\"modelname\"] == \"recognition\":\n        model = get_rec_model(parameters)\n    else:\n        model = get_gen_model(parameters)\n    return model, datasets\n"
  },
  {
    "path": "PBnet/src/utils/misc.py",
    "content": "import torch\n\n\ndef to_numpy(tensor):\n    if torch.is_tensor(tensor):\n        return tensor.cpu().numpy()\n    elif type(tensor).__module__ != 'numpy':\n        raise ValueError(\"Cannot convert {} to numpy array\".format(\n            type(tensor)))\n    return tensor\n\n\ndef to_torch(ndarray):\n    if type(ndarray).__module__ == 'numpy':\n        return torch.from_numpy(ndarray)\n    elif not torch.is_tensor(ndarray):\n        raise ValueError(\"Cannot convert {} to torch tensor\".format(\n            type(ndarray)))\n    return ndarray\n\n\ndef cleanexit():\n    import sys\n    import os\n    try:\n        sys.exit(0)\n    except SystemExit:\n        os._exit(0)\n\n"
  },
  {
    "path": "PBnet/src/utils/rotation_conversions.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.\n# Check PYTORCH3D_LICENCE before use\n\nimport functools\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\n\n\n\"\"\"\nThe transformation matrices returned from the functions in this file assume\nthe points on which the transformation will be applied are column vectors.\ni.e. the R matrix is structured as\n\n    R = [\n            [Rxx, Rxy, Rxz],\n            [Ryx, Ryy, Ryz],\n            [Rzx, Rzy, Rzz],\n        ]  # (3, 3)\n\nThis matrix can be applied to column vectors by post multiplication\nby the points e.g.\n\n    points = [[0], [1], [2]]  # (3 x 1) xyz coordinates of a point\n    transformed_points = R * points\n\nTo apply the same matrix to points which are row vectors, the R matrix\ncan be transposed and pre multiplied by the points:\n\ne.g.\n    points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point\n    transformed_points = points * R.transpose(1, 0)\n\"\"\"\n\n\ndef quaternion_to_matrix(quaternions):\n    \"\"\"\n    Convert rotations given as quaternions to rotation matrices.\n\n    Args:\n        quaternions: quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n    \"\"\"\n    r, i, j, k = torch.unbind(quaternions, -1)\n    two_s = 2.0 / (quaternions * quaternions).sum(-1)\n\n    o = torch.stack(\n        (\n            1 - two_s * (j * j + k * k),\n            two_s * (i * j - k * r),\n            two_s * (i * k + j * r),\n            two_s * (i * j + k * r),\n            1 - two_s * (i * i + k * k),\n            two_s * (j * k - i * r),\n            two_s * (i * k - j * r),\n            two_s * (j * k + i * r),\n            1 - two_s * (i * i + j * j),\n        ),\n        -1,\n    )\n    return o.reshape(quaternions.shape[:-1] + (3, 3))\n\n\ndef _copysign(a, b):\n    \"\"\"\n    Return a tensor where each element has the absolute value taken from the,\n    corresponding element of a, with sign taken from the corresponding\n    element of b. 
This is like the standard copysign floating-point operation,\n    but is not careful about negative 0 and NaN.\n\n    Args:\n        a: source tensor.\n        b: tensor whose signs will be used, of the same shape as a.\n\n    Returns:\n        Tensor of the same shape as a with the signs of b.\n    \"\"\"\n    signs_differ = (a < 0) != (b < 0)\n    return torch.where(signs_differ, -a, a)\n\n\ndef _sqrt_positive_part(x):\n    \"\"\"\n    Returns torch.sqrt(torch.max(0, x))\n    but with a zero subgradient where x is 0.\n    \"\"\"\n    ret = torch.zeros_like(x)\n    positive_mask = x > 0\n    ret[positive_mask] = torch.sqrt(x[positive_mask])\n    return ret\n\n\ndef matrix_to_quaternion(matrix):\n    \"\"\"\n    Convert rotations given as rotation matrices to quaternions.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n    \"\"\"\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix  shape f{matrix.shape}.\")\n    m00 = matrix[..., 0, 0]\n    m11 = matrix[..., 1, 1]\n    m22 = matrix[..., 2, 2]\n    o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)\n    x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)\n    y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)\n    z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)\n    o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])\n    o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])\n    o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])\n    return torch.stack((o0, o1, o2, o3), -1)\n\n\ndef _axis_angle_rotation(axis: str, angle):\n    \"\"\"\n    Return the rotation matrices for one of the rotations about an axis\n    of which Euler angles describe, for each value of the angle given.\n\n    Args:\n        axis: Axis label \"X\" or \"Y or \"Z\".\n        angle: any shape tensor of Euler angles in radians\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n    \"\"\"\n\n    cos = torch.cos(angle)\n    sin = torch.sin(angle)\n    one = torch.ones_like(angle)\n    zero = torch.zeros_like(angle)\n\n    if axis == \"X\":\n        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)\n    if axis == \"Y\":\n        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)\n    if axis == \"Z\":\n        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)\n\n    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))\n\n\ndef euler_angles_to_matrix(euler_angles, convention: str):\n    \"\"\"\n    Convert rotations given as Euler angles in radians to rotation matrices.\n\n    Args:\n        euler_angles: Euler angles in radians as tensor of shape (..., 3).\n        convention: Convention string of three uppercase letters from\n            {\"X\", \"Y\", and \"Z\"}.\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n    \"\"\"\n    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:\n        raise ValueError(\"Invalid input euler angles.\")\n    if len(convention) != 3:\n        raise ValueError(\"Convention must have 3 letters.\")\n    if convention[1] in (convention[0], convention[2]):\n        raise ValueError(f\"Invalid convention {convention}.\")\n    for letter in convention:\n        if letter not in (\"X\", \"Y\", \"Z\"):\n            raise ValueError(f\"Invalid letter {letter} in convention string.\")\n    matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, 
-1))\n    return functools.reduce(torch.matmul, matrices)\n\n\ndef _angle_from_tan(\n    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool\n):\n    \"\"\"\n    Extract the first or third Euler angle from the two members of\n    the matrix which are positive constant times its sine and cosine.\n\n    Args:\n        axis: Axis label \"X\" or \"Y or \"Z\" for the angle we are finding.\n        other_axis: Axis label \"X\" or \"Y or \"Z\" for the middle axis in the\n            convention.\n        data: Rotation matrices as tensor of shape (..., 3, 3).\n        horizontal: Whether we are looking for the angle for the third axis,\n            which means the relevant entries are in the same row of the\n            rotation matrix. If not, they are in the same column.\n        tait_bryan: Whether the first and third axes in the convention differ.\n\n    Returns:\n        Euler Angles in radians for each matrix in data as a tensor\n        of shape (...).\n    \"\"\"\n\n    i1, i2 = {\"X\": (2, 1), \"Y\": (0, 2), \"Z\": (1, 0)}[axis]\n    if horizontal:\n        i2, i1 = i1, i2\n    even = (axis + other_axis) in [\"XY\", \"YZ\", \"ZX\"]\n    if horizontal == even:\n        return torch.atan2(data[..., i1], data[..., i2])\n    if tait_bryan:\n        return torch.atan2(-data[..., i2], data[..., i1])\n    return torch.atan2(data[..., i2], -data[..., i1])\n\n\ndef _index_from_letter(letter: str):\n    if letter == \"X\":\n        return 0\n    if letter == \"Y\":\n        return 1\n    if letter == \"Z\":\n        return 2\n\n\ndef matrix_to_euler_angles(matrix, convention: str):\n    \"\"\"\n    Convert rotations given as rotation matrices to Euler angles in radians.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n        convention: Convention string of three uppercase letters.\n\n    Returns:\n        Euler angles in radians as tensor of shape (..., 3).\n    \"\"\"\n    if len(convention) != 3:\n        raise ValueError(\"Convention must have 3 letters.\")\n    if convention[1] in (convention[0], convention[2]):\n        raise ValueError(f\"Invalid convention {convention}.\")\n    for letter in convention:\n        if letter not in (\"X\", \"Y\", \"Z\"):\n            raise ValueError(f\"Invalid letter {letter} in convention string.\")\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix  shape f{matrix.shape}.\")\n    i0 = _index_from_letter(convention[0])\n    i2 = _index_from_letter(convention[2])\n    tait_bryan = i0 != i2\n    if tait_bryan:\n        central_angle = torch.asin(\n            matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)\n        )\n    else:\n        central_angle = torch.acos(matrix[..., i0, i0])\n\n    o = (\n        _angle_from_tan(\n            convention[0], convention[1], matrix[..., i2], False, tait_bryan\n        ),\n        central_angle,\n        _angle_from_tan(\n            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan\n        ),\n    )\n    return torch.stack(o, -1)\n\n\ndef random_quaternions(\n    n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False\n):\n    \"\"\"\n    Generate random quaternions representing rotations,\n    i.e. versors with nonnegative real part.\n\n    Args:\n        n: Number of quaternions in a batch to return.\n        dtype: Type to return.\n        device: Desired device of returned tensor. 
Default:\n            uses the current device for the default tensor type.\n        requires_grad: Whether the resulting tensor should have the gradient\n            flag set.\n\n    Returns:\n        Quaternions as tensor of shape (N, 4).\n    \"\"\"\n    o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)\n    s = (o * o).sum(1)\n    o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]\n    return o\n\n\ndef random_rotations(\n    n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False\n):\n    \"\"\"\n    Generate random rotations as 3x3 rotation matrices.\n\n    Args:\n        n: Number of rotation matrices in a batch to return.\n        dtype: Type to return.\n        device: Device of returned tensor. Default: if None,\n            uses the current device for the default tensor type.\n        requires_grad: Whether the resulting tensor should have the gradient\n            flag set.\n\n    Returns:\n        Rotation matrices as tensor of shape (n, 3, 3).\n    \"\"\"\n    quaternions = random_quaternions(\n        n, dtype=dtype, device=device, requires_grad=requires_grad\n    )\n    return quaternion_to_matrix(quaternions)\n\n\ndef random_rotation(\n    dtype: Optional[torch.dtype] = None, device=None, requires_grad=False\n):\n    \"\"\"\n    Generate a single random 3x3 rotation matrix.\n\n    Args:\n        dtype: Type to return\n        device: Device of returned tensor. Default: if None,\n            uses the current device for the default tensor type\n        requires_grad: Whether the resulting tensor should have the gradient\n            flag set\n\n    Returns:\n        Rotation matrix as tensor of shape (3, 3).\n    \"\"\"\n    return random_rotations(1, dtype, device, requires_grad)[0]\n\n\ndef standardize_quaternion(quaternions):\n    \"\"\"\n    Convert a unit quaternion to a standard form: one in which the real\n    part is non negative.\n\n    Args:\n        quaternions: Quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Standardized quaternions as tensor of shape (..., 4).\n    \"\"\"\n    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)\n\n\ndef quaternion_raw_multiply(a, b):\n    \"\"\"\n    Multiply two quaternions.\n    Usual torch rules for broadcasting apply.\n\n    Args:\n        a: Quaternions as tensor of shape (..., 4), real part first.\n        b: Quaternions as tensor of shape (..., 4), real part first.\n\n    Returns:\n        The product of a and b, a tensor of quaternions shape (..., 4).\n    \"\"\"\n    aw, ax, ay, az = torch.unbind(a, -1)\n    bw, bx, by, bz = torch.unbind(b, -1)\n    ow = aw * bw - ax * bx - ay * by - az * bz\n    ox = aw * bx + ax * bw + ay * bz - az * by\n    oy = aw * by - ax * bz + ay * bw + az * bx\n    oz = aw * bz + ax * by - ay * bx + az * bw\n    return torch.stack((ow, ox, oy, oz), -1)\n\n\ndef quaternion_multiply(a, b):\n    \"\"\"\n    Multiply two quaternions representing rotations, returning the quaternion\n    representing their composition, i.e. 
the versor with nonnegative real part.\n    Usual torch rules for broadcasting apply.\n\n    Args:\n        a: Quaternions as tensor of shape (..., 4), real part first.\n        b: Quaternions as tensor of shape (..., 4), real part first.\n\n    Returns:\n        The product of a and b, a tensor of quaternions of shape (..., 4).\n    \"\"\"\n    ab = quaternion_raw_multiply(a, b)\n    return standardize_quaternion(ab)\n\n\ndef quaternion_invert(quaternion):\n    \"\"\"\n    Given a quaternion representing rotation, get the quaternion representing\n    its inverse.\n\n    Args:\n        quaternion: Quaternions as tensor of shape (..., 4), with real part\n            first, which must be versors (unit quaternions).\n\n    Returns:\n        The inverse, a tensor of quaternions of shape (..., 4).\n    \"\"\"\n\n    return quaternion * quaternion.new_tensor([1, -1, -1, -1])\n\n\ndef quaternion_apply(quaternion, point):\n    \"\"\"\n    Apply the rotation given by a quaternion to a 3D point.\n    Usual torch rules for broadcasting apply.\n\n    Args:\n        quaternion: Tensor of quaternions, real part first, of shape (..., 4).\n        point: Tensor of 3D points of shape (..., 3).\n\n    Returns:\n        Tensor of rotated points of shape (..., 3).\n    \"\"\"\n    if point.size(-1) != 3:\n        raise ValueError(f\"Points are not in 3D, f{point.shape}.\")\n    real_parts = point.new_zeros(point.shape[:-1] + (1,))\n    point_as_quaternion = torch.cat((real_parts, point), -1)\n    out = quaternion_raw_multiply(\n        quaternion_raw_multiply(quaternion, point_as_quaternion),\n        quaternion_invert(quaternion),\n    )\n    return out[..., 1:]\n\n\ndef axis_angle_to_matrix(axis_angle):\n    \"\"\"\n    Convert rotations given as axis/angle to rotation matrices.\n\n    Args:\n        axis_angle: Rotations given as a vector in axis angle form,\n            as a tensor of shape (..., 3), where the magnitude is\n            the angle turned anticlockwise in radians around the\n            vector's direction.\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n    \"\"\"\n    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))\n\n\ndef matrix_to_axis_angle(matrix):\n    \"\"\"\n    Convert rotations given as rotation matrices to axis/angle.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n\n    Returns:\n        Rotations given as a vector in axis angle form, as a tensor\n            of shape (..., 3), where the magnitude is the angle\n            turned anticlockwise in radians around the vector's\n            direction.\n    \"\"\"\n    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))\n\n\ndef axis_angle_to_quaternion(axis_angle):\n    \"\"\"\n    Convert rotations given as axis/angle to quaternions.\n\n    Args:\n        axis_angle: Rotations given as a vector in axis angle form,\n            as a tensor of shape (..., 3), where the magnitude is\n            the angle turned anticlockwise in radians around the\n            vector's direction.\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n    \"\"\"\n    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)\n    half_angles = 0.5 * angles\n    eps = 1e-6\n    small_angles = angles.abs() < eps\n    sin_half_angles_over_angles = torch.empty_like(angles)\n    sin_half_angles_over_angles[~small_angles] = (\n        torch.sin(half_angles[~small_angles]) / angles[~small_angles]\n    )\n    # for x small, sin(x/2) is 
about x/2 - (x/2)^3/6\n    # so sin(x/2)/x is about 1/2 - (x*x)/48\n    sin_half_angles_over_angles[small_angles] = (\n        0.5 - (angles[small_angles] * angles[small_angles]) / 48\n    )\n    quaternions = torch.cat(\n        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1\n    )\n    return quaternions\n\n\ndef quaternion_to_axis_angle(quaternions):\n    \"\"\"\n    Convert rotations given as quaternions to axis/angle.\n\n    Args:\n        quaternions: quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Rotations given as a vector in axis angle form, as a tensor\n            of shape (..., 3), where the magnitude is the angle\n            turned anticlockwise in radians around the vector's\n            direction.\n    \"\"\"\n    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)\n    half_angles = torch.atan2(norms, quaternions[..., :1])\n    angles = 2 * half_angles\n    eps = 1e-6\n    small_angles = angles.abs() < eps\n    sin_half_angles_over_angles = torch.empty_like(angles)\n    sin_half_angles_over_angles[~small_angles] = (\n        torch.sin(half_angles[~small_angles]) / angles[~small_angles]\n    )\n    # for x small, sin(x/2) is about x/2 - (x/2)^3/6\n    # so sin(x/2)/x is about 1/2 - (x*x)/48\n    sin_half_angles_over_angles[small_angles] = (\n        0.5 - (angles[small_angles] * angles[small_angles]) / 48\n    )\n    return quaternions[..., 1:] / sin_half_angles_over_angles\n\n\ndef rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix\n    using Gram--Schmidt orthogonalisation per Section B of [1].\n    Args:\n        d6: 6D rotation representation, of size (*, 6)\n\n    Returns:\n        batch of rotation matrices of size (*, 3, 3)\n\n    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.\n    On the Continuity of Rotation Representations in Neural Networks.\n    IEEE Conference on Computer Vision and Pattern Recognition, 2019.\n    Retrieved from http://arxiv.org/abs/1812.07035\n    \"\"\"\n\n    a1, a2 = d6[..., :3], d6[..., 3:]\n    b1 = F.normalize(a1, dim=-1)\n    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1\n    b2 = F.normalize(b2, dim=-1)\n    b3 = torch.cross(b1, b2, dim=-1)\n    return torch.stack((b1, b2, b3), dim=-2)\n\n\ndef matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]\n    by dropping the last row. Note that 6D representation is not unique.\n    Args:\n        matrix: batch of rotation matrices of size (*, 3, 3)\n\n    Returns:\n        6D rotation representation, of size (*, 6)\n\n    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.\n    On the Continuity of Rotation Representations in Neural Networks.\n    IEEE Conference on Computer Vision and Pattern Recognition, 2019.\n    Retrieved from http://arxiv.org/abs/1812.07035\n    \"\"\"\n    return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)\n"
  },
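  {
    "path": "PBnet/examples/rotation_roundtrip_demo.py",
    "content": "# Illustrative sketch only -- this file is NOT part of the original codebase.\n# A quick consistency check of the conversion helpers in\n# PBnet/src/utils/rotation_conversions.py: Euler angles -> rotation matrix, and\n# round-trips through the 6D, quaternion and Euler representations.\n# Assumes the PBnet root is on sys.path so that `src` is importable.\nimport torch\n\nfrom src.utils.rotation_conversions import (\n    euler_angles_to_matrix,\n    matrix_to_euler_angles,\n    matrix_to_quaternion,\n    matrix_to_rotation_6d,\n    quaternion_to_matrix,\n    rotation_6d_to_matrix,\n)\n\nif __name__ == '__main__':\n    # small batch of head-pose-like Euler angles (radians), XYZ convention\n    euler = 0.3 * torch.randn(4, 3)\n    rot = euler_angles_to_matrix(euler, 'XYZ')\n\n    # 6D representation round-trips back to the same matrix\n    rot_6d = rotation_6d_to_matrix(matrix_to_rotation_6d(rot))\n    print('6d round-trip error:', (rot - rot_6d).abs().max().item())\n\n    # quaternion round-trip\n    rot_quat = quaternion_to_matrix(matrix_to_quaternion(rot))\n    print('quaternion round-trip error:', (rot - rot_quat).abs().max().item())\n\n    # Euler round-trip (individual angles may differ, the matrices agree)\n    rot_euler = euler_angles_to_matrix(matrix_to_euler_angles(rot, 'XYZ'), 'XYZ')\n    print('euler round-trip error:', (rot - rot_euler).abs().max().item())\n"
  },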
  {
    "path": "PBnet/src/utils/tensors.py",
    "content": "import torch\n\n\ndef lengths_to_mask(lengths):\n    max_len = max(lengths)\n    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)\n    return mask\n    \n\ndef collate_tensors(batch):\n    dims = batch[0].dim()\n    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]\n    size = (len(batch),) + tuple(max_size)\n    canvas = batch[0].new_zeros(size=size)\n    for i, b in enumerate(batch):\n        sub_tensor = canvas[i]\n        for d in range(dims):\n            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))\n        sub_tensor.add_(b)\n    return canvas\n\n\ndef collate(batch):\n\n    posbatch = [b[1] for b in batch]\n    audiobatch = [b[0] for b in batch]\n    lenbatch = [len(b[0]) for b in batch]\n    videonamebatch=[b[2] for b in batch]\n\n    posbatchTensor = collate_tensors(posbatch)\n    audiobatchTensor = collate_tensors(audiobatch)\n    lenbatchTensor = torch.as_tensor(lenbatch)\n\n    maskbatchTensor = lengths_to_mask(lenbatchTensor)\n    batch = {\"x\": posbatchTensor, \"y\": audiobatchTensor,\n             \"mask\": maskbatchTensor, \"lengths\": lenbatchTensor,\n             \"videoname\": videonamebatch}\n    return batch\n"
  },
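  {
    "path": "PBnet/examples/collate_demo.py",
    "content": "# Illustrative sketch only -- this file is NOT part of the original codebase.\n# Shows how collate() in PBnet/src/utils/tensors.py pads variable-length\n# (audio, pose, videoname) items into a batch dict with a validity mask.\n# The feature sizes below (1024-dim audio, 6-dim pose) are dummy values.\n# Assumes the PBnet root is on sys.path so that `src` is importable.\nimport torch\n\nfrom src.utils.tensors import collate\n\nif __name__ == '__main__':\n    # each dataset item is (audio_feature, pose, videoname); clip lengths differ\n    items = [\n        (torch.randn(40, 1024), torch.randn(40, 6), 'clip_a'),\n        (torch.randn(25, 1024), torch.randn(25, 6), 'clip_b'),\n    ]\n    batch = collate(items)\n    print(batch['y'].shape)     # audio padded to (2, 40, 1024)\n    print(batch['x'].shape)     # poses padded to (2, 40, 6)\n    print(batch['lengths'])     # tensor([40, 25])\n    print(batch['mask'].shape)  # (2, 40), True where frames are valid\n    print(batch['videoname'])   # ['clip_a', 'clip_b']\n"
  },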
  {
    "path": "PBnet/src/utils/tensors_eye.py",
    "content": "import torch\n\n\ndef lengths_to_mask(lengths):\n    max_len = max(lengths)\n    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)\n    return mask\n    \n\ndef collate_tensors(batch):\n    dims = batch[0].dim()\n    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]\n    size = (len(batch),) + tuple(max_size)\n    canvas = batch[0].new_zeros(size=size)\n    for i, b in enumerate(batch):\n        sub_tensor = canvas[i]\n        for d in range(dims):\n            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))\n        sub_tensor.add_(b)\n    return canvas\n\n\ndef collate(batch):\n\n    posbatch = [b[1] for b in batch]\n    audiobatch = [b[0] for b in batch]\n    # eyebatch = [b[2] for b in batch]\n    lenbatch = [len(b[0]) for b in batch]\n    # startbatch = [b[4] for b in batch]\n    videonamebatch=[b[3] for b in batch]\n    poseyebatch=[b[5] for b in batch]\n\n    posbatchTensor = collate_tensors(posbatch)\n    audiobatchTensor = collate_tensors(audiobatch)\n    # eyebatchTensor = collate_tensors(eyebatch)\n    poseyebatchTensor = collate_tensors(poseyebatch)\n    # startbatchTensor = collate_tensors(startbatch)\n    lenbatchTensor = torch.as_tensor(lenbatch)\n\n    maskbatchTensor = lengths_to_mask(lenbatchTensor)\n    batch = {\"x\":poseyebatchTensor,\"p\": posbatchTensor, \"y\": audiobatchTensor,\n             \"mask\": maskbatchTensor, \"lengths\": lenbatchTensor,\n             \"videoname\": videonamebatch} # \n    return batch \n"
  },
  {
    "path": "PBnet/src/utils/tensors_eye_eval.py",
    "content": "import torch\n\n\ndef lengths_to_mask(lengths):\n    max_len = max(lengths)\n    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)\n    return mask\n    \n\ndef collate_tensors(batch):\n    dims = batch[0].dim()\n    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]\n    size = (len(batch),) + tuple(max_size)\n    canvas = batch[0].new_zeros(size=size)\n    for i, b in enumerate(batch):\n        sub_tensor = canvas[i]\n        for d in range(dims):\n            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))\n        sub_tensor.add_(b)\n    return canvas\n\n\ndef collate(batch):\n\n    posbatch = [b[1] for b in batch]\n    audiobatch = [b[0] for b in batch]\n    eyebatch = [b[2] for b in batch]\n    lenbatch = [len(b[0]) for b in batch]\n    startbatch = [b[4] for b in batch]\n    videonamebatch=[b[3] for b in batch]\n    poseyebatch=[b[5] for b in batch]\n\n    posbatchTensor = collate_tensors(posbatch)\n    audiobatchTensor = collate_tensors(audiobatch)\n    eyebatchTensor = collate_tensors(eyebatch)\n    poseyebatchTensor = collate_tensors(poseyebatch)\n    # startbatchTensor = collate_tensors(startbatch)\n    lenbatchTensor = torch.as_tensor(lenbatch)\n\n    maskbatchTensor = lengths_to_mask(lenbatchTensor)\n    batch = {\"x\":poseyebatchTensor,\"p\": posbatchTensor, \"y\": audiobatchTensor,\n             \"e\": eyebatchTensor, \"mask\": maskbatchTensor, \"lengths\": lenbatchTensor,\n             \"videoname\": videonamebatch, \"start\": startbatch}\n    return batch\n"
  },
  {
    "path": "PBnet/src/utils/tensors_hdtf.py",
    "content": "import torch\n\n\ndef lengths_to_mask(lengths):\n    max_len = max(lengths)\n    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)\n    return mask\n    \n\ndef collate_tensors(batch):\n    dims = batch[0].dim()\n    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]\n    size = (len(batch),) + tuple(max_size)\n    canvas = batch[0].new_zeros(size=size)\n    for i, b in enumerate(batch):\n        sub_tensor = canvas[i]\n        for d in range(dims):\n            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))\n        sub_tensor.add_(b)\n    return canvas\n\n\ndef collate_old(batch):\n\n    posbatch = [b[1] for b in batch]\n    audiobatch = [b[0] for b in batch]\n    lenbatch = [len(b[0]) for b in batch]\n    startbatch = [b[3] for b in batch]\n    videonamebatch=[b[2] for b in batch]\n\n    posbatchTensor = collate_tensors(posbatch)\n    audiobatchTensor = collate_tensors(audiobatch)\n    # startbatchTensor = collate_tensors(startbatch)\n    lenbatchTensor = torch.as_tensor(lenbatch)\n\n    maskbatchTensor = lengths_to_mask(lenbatchTensor)\n    batch = {\"x\": posbatchTensor, \"y\": audiobatchTensor,\n             \"mask\": maskbatchTensor, \"lengths\": lenbatchTensor,\n             \"videoname\": videonamebatch, \"start\": startbatch}\n    return batch\n\n\ndef collate(batch):\n\n    posbatch = [b[1] for b in batch]\n    audiobatch = [b[0] for b in batch]\n    eyebatch = [b[2] for b in batch]\n    lenbatch = [len(b[0]) for b in batch]\n    # startbatch = [b[4] for b in batch]\n    videonamebatch=[b[3] for b in batch]\n    poseyebatch=[b[5] for b in batch]\n\n    posbatchTensor = collate_tensors(posbatch)\n    audiobatchTensor = collate_tensors(audiobatch)\n    eyebatchTensor = collate_tensors(eyebatch)\n    poseyebatchTensor = collate_tensors(poseyebatch)\n    # startbatchTensor = collate_tensors(startbatch)\n    lenbatchTensor = torch.as_tensor(lenbatch)\n\n    maskbatchTensor = lengths_to_mask(lenbatchTensor)\n    batch = {\"x\":poseyebatchTensor,\"p\": posbatchTensor, \"y\": audiobatchTensor,\n             \"e\": eyebatchTensor, \"mask\": maskbatchTensor, \"lengths\": lenbatchTensor,\n             \"videoname\": videonamebatch}\n    return batch"
  },
  {
    "path": "PBnet/src/utils/tensors_onlyeye.py",
    "content": "import torch\n\n\ndef lengths_to_mask(lengths):\n    max_len = max(lengths)\n    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)\n    return mask\n    \n\ndef collate_tensors(batch):\n    dims = batch[0].dim()\n    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]\n    size = (len(batch),) + tuple(max_size)\n    canvas = batch[0].new_zeros(size=size)\n    for i, b in enumerate(batch):\n        sub_tensor = canvas[i]\n        for d in range(dims):\n            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))\n        sub_tensor.add_(b)\n    return canvas\n\n\ndef collate(batch):\n\n    # posbatch = [b[1] for b in batch]\n    audiobatch = [b[0] for b in batch]\n    eyebatch = [b[1] for b in batch]\n    lenbatch = [len(b[0]) for b in batch]\n    # startbatch = [b[4] for b in batch]\n    videonamebatch=[b[2] for b in batch]\n    # poseyebatch=[b[5] for b in batch]\n\n    # posbatchTensor = collate_tensors(posbatch)\n    audiobatchTensor = collate_tensors(audiobatch)\n    eyebatchTensor = collate_tensors(eyebatch)\n    # poseyebatchTensor = collate_tensors(poseyebatch)\n    # startbatchTensor = collate_tensors(startbatch)\n    lenbatchTensor = torch.as_tensor(lenbatch)\n\n    maskbatchTensor = lengths_to_mask(lenbatchTensor)\n    batch = {\"x\":eyebatchTensor, \"y\": audiobatchTensor,\n             \"mask\": maskbatchTensor, \"lengths\": lenbatchTensor,\n             \"videoname\": videonamebatch}\n    return batch\n\ndef collate_eval(batch):\n\n    # posbatch = [b[1] for b in batch]\n    audiobatch = [b[0] for b in batch]\n    eyebatch = [b[1] for b in batch]\n    lenbatch = [len(b[0]) for b in batch]\n    startbatch = [b[3] for b in batch]\n    videonamebatch=[b[2] for b in batch]\n    # poseyebatch=[b[5] for b in batch]\n\n    # posbatchTensor = collate_tensors(posbatch)\n    audiobatchTensor = collate_tensors(audiobatch)\n    eyebatchTensor = collate_tensors(eyebatch)\n    # poseyebatchTensor = collate_tensors(poseyebatch)\n    # startbatchTensor = collate_tensors(startbatch)\n    lenbatchTensor = torch.as_tensor(lenbatch)\n\n    maskbatchTensor = lengths_to_mask(lenbatchTensor)\n    batch = {\"x\":eyebatchTensor, \"y\": audiobatchTensor,\n             \"mask\": maskbatchTensor, \"lengths\": lenbatchTensor,\n             \"videoname\": videonamebatch,\"start\": startbatch}\n    return batch\n"
  },
  {
    "path": "PBnet/src/utils/utils.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom threading import Thread\nfrom queue import Queue\n\nclass _RepeatSampler(object):\n\n    def __init__(self, sampler):\n        self.sampler = sampler\n\n    def __iter__(self):\n        while True:\n            yield from iter(self.sampler)\n\nclass MultiEpochsDataLoader(torch.utils.data.DataLoader):\n    \"\"\" During multi-epoch training, the DataLoader object does not need to recreate the thread and batch_sampler objects, \n    in order to save the initialization time for each epoch. \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))\n        self.iterator = super().__iter__()\n\n    def __len__(self):\n        return len(self.batch_sampler.sampler)\n\n    def __iter__(self):\n        for i in range(len(self)):\n            yield next(self.iterator)\n\nclass CudaDataLoader:\n\n    def __init__(self, loader, device, queue_size=2):\n        self.device = device\n        self.queue_size = queue_size\n        self.loader = loader\n\n        self.load_stream = torch.cuda.Stream(device=device)\n        self.queue = Queue(maxsize=self.queue_size)\n\n        self.idx = 0\n        self.worker = Thread(target=self.load_loop)\n        self.worker.setDaemon(True)\n        self.worker.start()\n\n    def load_loop(self):\n        # The loop that will load into the queue in the background\n        while True:\n            for i, sample in enumerate(self.loader):\n                self.queue.put(self.load_instance(sample))\n\n    def load_instance(self, sample):\n        if torch.is_tensor(sample):\n            with torch.cuda.stream(self.load_stream):\n                return sample.to(self.device, non_blocking=True)\n        elif sample is None or type(sample) in (list, str):\n            return sample\n        elif isinstance(sample, dict):\n            return {k: self.load_instance(v) for k, v in sample.items()}\n        else:\n            return [self.load_instance(s) for s in sample]\n\n    def __iter__(self):\n        self.idx = 0\n        return self\n\n    def __next__(self):\n        if not self.worker.is_alive() and self.queue.empty():\n            self.idx = 0\n            self.queue.join()\n            self.worker.join()\n            raise StopIteration\n        elif self.idx >= len(self.loader):\n            self.idx = 0\n            raise StopIteration\n        else:\n            out = self.queue.get()\n            self.queue.task_done()\n            self.idx += 1\n        return out\n\n    def next(self):\n        return self.__next__()\n\n    def __len__(self):\n        return len(self.loader)\n\n    @property\n    def sampler(self):\n        return self.loader.sampler\n\n    @property\n    def dataset(self):\n        return self.loader.dataset"
  },
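  {
    "path": "PBnet/examples/multi_epochs_dataloader_demo.py",
    "content": "# Illustrative sketch only -- this file is NOT part of the original codebase.\n# Minimal usage of MultiEpochsDataLoader from PBnet/src/utils/utils.py, which\n# keeps the underlying iterator and batch sampler alive across epochs instead\n# of re-creating them at the start of every epoch.\n# Assumes the PBnet root is on sys.path so that `src` is importable.\nimport torch\nfrom torch.utils.data import TensorDataset\n\nfrom src.utils.utils import MultiEpochsDataLoader\n\nif __name__ == '__main__':\n    dataset = TensorDataset(torch.randn(64, 6))\n    loader = MultiEpochsDataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)\n\n    for epoch in range(3):\n        # each pass draws len(loader) batches from the persistent iterator\n        n_batches = sum(1 for _ in loader)\n        print('epoch', epoch, 'batches', n_batches)\n"
  },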
  {
    "path": "PBnet/src/utils/video.py",
    "content": "import numpy as np\nimport imageio\n\n\ndef load_video(filename):\n    vid = imageio.get_reader(filename, 'ffmpeg')\n    fps = vid.get_meta_data()['fps']\n    nframes = vid.count_frames()\n    return vid, fps, nframes\n\n\nclass SaveVideo:\n    def __init__(self, outname, fps):\n        self.outname = outname\n        self.fps = fps\n\n    def __enter__(self):\n        self.writter = imageio.get_writer(self.outname,\n                                          format='FFMPEG',\n                                          fps=self.fps)\n        return self\n\n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        self.writter.close()\n\n    def __iadd__(self, data):\n        if np.max(data) <= 1:\n            data = np.array(255*data, dtype=np.uint8)\n        else:\n            data = np.array(data, dtype=np.uint8)\n        self.writter.append_data(data)\n        return self\n"
  },
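  {
    "path": "PBnet/examples/save_video_demo.py",
    "content": "# Illustrative sketch only -- this file is NOT part of the original codebase.\n# Writes a short dummy clip with the SaveVideo context manager from\n# PBnet/src/utils/video.py; float frames in [0, 1] are rescaled to uint8 by __iadd__.\n# Assumes the PBnet root is on sys.path so that `src` is importable.\nimport numpy as np\n\nfrom src.utils.video import SaveVideo\n\nif __name__ == '__main__':\n    with SaveVideo('demo.mp4', fps=25) as video:\n        for t in range(50):\n            # simple moving horizontal gradient, float in [0, 1], shape (64, 64, 3)\n            frame = np.tile(np.linspace(0, 1, 64)[None, :, None], (64, 1, 3))\n            video += np.roll(frame, t, axis=1)\n"
  },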
  {
    "path": "PBnet/src/visualize/__init__.py",
    "content": ""
  },
  {
    "path": "PBnet/src/visualize/anim.py",
    "content": "import numpy as np\nimport torch\nimport imageio\n\n# from action2motion\n# Define a kinematic tree for the skeletal struture\nhumanact12_kinematic_chain = [[0, 1, 4, 7, 10],\n                              [0, 2, 5, 8, 11],\n                              [0, 3, 6, 9, 12, 15],\n                              [9, 13, 16, 18, 20, 22],\n                              [9, 14, 17, 19, 21, 23]]  # same as smpl\n\nsmpl_kinematic_chain = humanact12_kinematic_chain\n\nmocap_kinematic_chain = [[0, 1, 2, 3],\n                         [0, 12, 13, 14, 15],\n                         [0, 16, 17, 18, 19],\n                         [1, 4, 5, 6, 7],\n                         [1, 8, 9, 10, 11]]\n\nvibe_kinematic_chain = [[0, 12, 13, 14, 15],\n                        [0, 9, 10, 11, 16],\n                        [0, 1, 8, 17],\n                        [1, 5, 6, 7],\n                        [1, 2, 3, 4]]\n\naction2motion_kinematic_chain = vibe_kinematic_chain\n\n\ndef add_shadow(img, shadow=15):\n    img = np.copy(img)\n    mask = img > shadow\n    img[mask] = img[mask] - shadow\n    img[~mask] = 0\n    return img\n\n\ndef load_anim(path, timesize=None):\n    data = np.array(imageio.mimread(path, memtest=False))[..., :3]\n    if timesize is None:\n        return data\n    # take the last frame and put shadow repeat the last frame but with a little shadow\n    lastframe = add_shadow(data[-1])\n    alldata = np.tile(lastframe, (timesize, 1, 1, 1))\n\n    # copy the first frames\n    lenanim = data.shape[0]\n    alldata[:lenanim] = data[:lenanim]\n    return alldata\n\n\ndef plot_3d_motion(motion, length, save_path, params, title=\"\", interval=50):\n    import matplotlib\n    import matplotlib.pyplot as plt\n    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401\n    from mpl_toolkits.mplot3d.art3d import Poly3DCollection  # noqa: F401\n    from matplotlib.animation import FuncAnimation, writers  # noqa: F401\n    # import mpl_toolkits.mplot3d.axes3d as p3\n    matplotlib.use('Agg')\n    pose_rep = params[\"pose_rep\"]\n\n    fig = plt.figure(figsize=[2.6, 2.8])\n    ax = fig.add_subplot(111, projection='3d')\n    # ax = p3.Axes3D(fig)\n    # ax = fig.gca(projection='3d')\n\n    def init():\n        ax.set_xticklabels([])\n        ax.set_yticklabels([])\n        ax.set_zticklabels([])\n\n        ax.set_xlim(-0.7, 0.7)\n        ax.set_ylim(-0.7, 0.7)\n        ax.set_zlim(-0.7, 0.7)\n\n        ax.view_init(azim=-90, elev=110)\n        # ax.set_axis_off()\n        ax.xaxis._axinfo[\"grid\"]['color'] = (0.5, 0.5, 0.5, 0.25)\n        ax.yaxis._axinfo[\"grid\"]['color'] = (0.5, 0.5, 0.5, 0.25)\n        ax.zaxis._axinfo[\"grid\"]['color'] = (0.5, 0.5, 0.5, 0.25)\n\n    colors = ['red', 'magenta', 'black', 'green', 'blue']\n\n    if pose_rep != \"xyz\":\n        raise ValueError(\"It should already be xyz.\")\n\n    if torch.is_tensor(motion):\n        motion = motion.numpy()\n\n    # invert axis\n    motion[:, 1, :] = -motion[:, 1, :]\n    motion[:, 2, :] = -motion[:, 2, :]\n\n    \"\"\"\n    Debug: to rotate the bodies\n    import src.utils.rotation_conversions as geometry\n    glob_rot = [0, 1.5707963267948966, 0]\n    global_orient = torch.tensor(glob_rot)\n    rotmat = geometry.axis_angle_to_matrix(global_orient)\n    motion = np.einsum(\"ikj,ko->ioj\", motion, rotmat)\n    \"\"\"\n\n    if motion.shape[0] == 18:\n        kinematic_tree = action2motion_kinematic_chain\n    elif motion.shape[0] == 24:\n        kinematic_tree = smpl_kinematic_chain\n    else:\n        kinematic_tree = None\n\n    def 
update(index):\n        ax.lines = []\n        ax.collections = []\n        if kinematic_tree is not None:\n            for chain, color in zip(kinematic_tree, colors):\n                ax.plot(motion[chain, 0, index],\n                        motion[chain, 1, index],\n                        motion[chain, 2, index], linewidth=4.0, color=color)\n        else:\n            ax.scatter(motion[1:, 0, index], motion[1:, 1, index],\n                       motion[1:, 2, index], c=\"red\")\n            ax.scatter(motion[:1, 0, index], motion[:1, 1, index],\n                       motion[:1, 2, index], c=\"blue\")\n\n    ax.set_title(title)\n\n    ani = FuncAnimation(fig, update, frames=length, interval=interval, repeat=False, init_func=init)\n\n    plt.tight_layout()\n    # pillow have problem droping frames\n    ani.save(save_path, writer='ffmpeg', fps=1000/interval)\n    plt.close()\n\n\ndef plot_3d_motion_dico(x):\n    motion, length, save_path, params, kargs = x\n    plot_3d_motion(motion, length, save_path, params, **kargs)\n"
  },
  {
    "path": "PBnet/src/visualize/visualize.py",
    "content": "import os\nimport imageio\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\nfrom .anim import plot_3d_motion_dico, load_anim\n\n\ndef stack_images(real, real_gens, gen):\n    nleft_cols = len(real_gens) + 1\n    print(\"Stacking frames..\")\n    allframes = np.concatenate((real[:, None, ...], *[x[:, None, ...] for x in real_gens], gen), 1)\n    nframes, nspa, nats, h, w, pix = allframes.shape\n    blackborder = np.zeros((w//30, h*nats, pix), dtype=allframes.dtype)\n    frames = []\n    for frame_idx in tqdm(range(nframes)):\n        columns = np.vstack(allframes[frame_idx].transpose(1, 2, 3, 4, 0)).transpose(3, 1, 0, 2)\n        frame = np.concatenate((*columns[0:nleft_cols], blackborder, *columns[nleft_cols:]), 0).transpose(1, 0, 2)\n        frames.append(frame)\n    return np.stack(frames)\n\n\ndef generate_by_video(visualization, reconstructions, generation,\n                      label_to_action_name, params, nats, nspa, tmp_path):\n    # shape : (17, 3, 4, 480, 640, 3)\n    # (nframes, row, column, h, w, 3)\n    fps = params[\"fps\"]\n\n    params = params.copy()\n\n    if \"output_xyz\" in visualization:\n        outputkey = \"output_xyz\"\n        params[\"pose_rep\"] = \"xyz\"\n    else:\n        outputkey = \"poses\"\n\n    keep = [outputkey, \"lengths\", \"y\"]\n\n    visu = {key: visualization[key].data.cpu().numpy() for key in keep}\n    recons = {mode: {key: reconstruction[key].data.cpu().numpy() for key in keep}\n              for mode, reconstruction in reconstructions.items()}\n    gener = {key: generation[key].data.cpu().numpy() for key in keep}\n\n    lenmax = max(gener[\"lengths\"].max(),\n                 visu[\"lengths\"].max())\n\n    timesize = lenmax + 5\n    import multiprocessing\n\n    def pool_job_with_desc(pool, iterator, desc, max_, save_path_format, isij):\n        with tqdm(total=max_, desc=desc.format(\"Render\")) as pbar:\n            for _ in pool.imap_unordered(plot_3d_motion_dico, iterator):\n                pbar.update()\n        if isij:\n            array = np.stack([[load_anim(save_path_format.format(i, j), timesize)\n                               for j in range(nats)]\n                              for i in tqdm(range(nspa), desc=desc.format(\"Load\"))])\n            return array.transpose(2, 0, 1, 3, 4, 5)\n        else:\n            array = np.stack([load_anim(save_path_format.format(i), timesize)\n                              for i in tqdm(range(nats), desc=desc.format(\"Load\"))])\n            return array.transpose(1, 0, 2, 3, 4)\n\n    with multiprocessing.Pool() as pool:\n        # Generated samples\n        save_path_format = os.path.join(tmp_path, \"gen_{}_{}.gif\")\n        iterator = ((gener[outputkey][i, j],\n                     gener[\"lengths\"][i, j],\n                     save_path_format.format(i, j),\n                     params, {\"title\": f\"gen: {label_to_action_name(gener['y'][i, j])}\", \"interval\": 1000/fps})\n                    for j in range(nats) for i in range(nspa))\n        gener[\"frames\"] = pool_job_with_desc(pool, iterator,\n                                             \"{} the generated samples\",\n                                             nats*nspa,\n                                             save_path_format,\n                                             True)\n        # Real samples\n        save_path_format = os.path.join(tmp_path, \"real_{}.gif\")\n        iterator = ((visu[outputkey][i],\n                     visu[\"lengths\"][i],\n        
             save_path_format.format(i),\n                     params, {\"title\": f\"real: {label_to_action_name(visu['y'][i])}\", \"interval\": 1000/fps})\n                    for i in range(nats))\n        visu[\"frames\"] = pool_job_with_desc(pool, iterator,\n                                            \"{} the real samples\",\n                                            nats,\n                                            save_path_format,\n                                            False)\n        for mode, recon in recons.items():\n            # Reconstructed samples\n            save_path_format = os.path.join(tmp_path, f\"reconstructed_{mode}_\" + \"{}.gif\")\n            iterator = ((recon[outputkey][i],\n                         recon[\"lengths\"][i],\n                         save_path_format.format(i),\n                         params, {\"title\": f\"recons: {label_to_action_name(recon['y'][i])}\",\n                                  \"interval\": 1000/fps})\n                        for i in range(nats))\n            recon[\"frames\"] = pool_job_with_desc(pool, iterator,\n                                                 \"{} the reconstructed samples\",\n                                                 nats,\n                                                 save_path_format,\n                                                 False)\n\n    frames = stack_images(visu[\"frames\"], [recon[\"frames\"] for recon in recons.values()], gener[\"frames\"])\n    return frames\n\n\ndef viz_epoch(model, dataset, epoch, params, folder, writer=None):\n    \"\"\" Generate & viz samples \"\"\"\n\n    # visualize with joints3D\n    model.outputxyz = True\n\n    print(f\"Visualization of the epoch {epoch}\")\n\n    noise_same_action = params[\"noise_same_action\"]\n    noise_diff_action = params[\"noise_diff_action\"]\n    duration_mode = params[\"duration_mode\"]\n    reconstruction_mode = params[\"reconstruction_mode\"]\n    decoder_test = params[\"decoder_test\"]\n\n    fact = params[\"fact_latent\"]\n    figname = params[\"figname\"].format(epoch)\n\n    nspa = params[\"num_samples_per_action\"]\n    nats = params[\"num_actions_to_sample\"]\n\n    num_classes = params[\"num_classes\"]\n\n    # define some classes\n    classes = torch.randperm(num_classes)[:nats]\n\n    meandurations = torch.from_numpy(np.array([round(dataset.get_mean_length_label(cl.item()))\n                                               for cl in classes]))\n\n    if duration_mode == \"interpolate\" or decoder_test == \"diffduration\":\n        points, step = np.linspace(-nspa, nspa, nspa, retstep=True)\n        points = np.round(10*points/step).astype(int)\n        gendurations = meandurations.repeat((nspa, 1)) + points[:, None]\n    else:\n        gendurations = meandurations.repeat((nspa, 1))\n\n    # extract the real samples\n    real_samples, mask_real, real_lengths = dataset.get_label_sample_batch(classes.numpy())\n    # to visualize directly\n\n    # Visualizaion of real samples\n    visualization = {\"x\": real_samples.to(model.device),\n                     \"y\": classes.to(model.device),\n                     \"mask\": mask_real.to(model.device),\n                     \"lengths\": real_lengths.to(model.device),\n                     \"output\": real_samples.to(model.device)}\n\n    # Visualizaion of real samples\n    if reconstruction_mode == \"both\":\n        reconstructions = {\"tf\": {\"x\": real_samples.to(model.device),\n                                  \"y\": classes.to(model.device),\n                       
           \"lengths\": real_lengths.to(model.device),\n                                  \"mask\": mask_real.to(model.device),\n                                  \"teacher_force\": True},\n                           \"ntf\": {\"x\": real_samples.to(model.device),\n                                   \"y\": classes.to(model.device),\n                                   \"lengths\": real_lengths.to(model.device),\n                                   \"mask\": mask_real.to(model.device)}}\n    else:\n        reconstructions = {reconstruction_mode: {\"x\": real_samples.to(model.device),\n                                                 \"y\": classes.to(model.device),\n                                                 \"lengths\": real_lengths.to(model.device),\n                                                 \"mask\": mask_real.to(model.device),\n                                                 \"teacher_force\": reconstruction_mode == \"tf\"}}\n    print(\"Computing the samples poses..\")\n\n    # generate the repr (joints3D/pose etc)\n    model.eval()\n    with torch.no_grad():\n        # Reconstruction of the real data\n        for mode in reconstructions:\n            model(reconstructions[mode])  # update reconstruction dicts\n        reconstruction = reconstructions[list(reconstructions.keys())[0]]\n\n        if decoder_test == \"new\":\n            # Generate the new data\n            generation = model.generate(classes, gendurations, nspa=nspa,\n                                        noise_same_action=noise_same_action,\n                                        noise_diff_action=noise_diff_action,\n                                        fact=fact)\n        elif decoder_test == \"diffaction\":\n            assert nats == nspa\n            # keep the same noise for each \"sample\"\n            z = reconstruction[\"z\"].repeat((nspa, 1))\n            mask = reconstruction[\"mask\"].repeat((nspa, 1))\n            lengths = reconstruction[\"lengths\"].repeat(nspa)\n            # but use other labels\n            y = classes.repeat_interleave(nspa).to(model.device)\n            generation = {\"z\": z, \"y\": y, \"mask\": mask, \"lengths\": lengths}\n            model.decoder(generation)\n\n        elif decoder_test == \"diffduration\":\n            z = reconstruction[\"z\"].repeat((nspa, 1))\n            lengths = gendurations.reshape(-1).to(model.device)\n            mask = model.lengths_to_mask(lengths)\n            y = classes.repeat(nats).to(model.device)\n            generation = {\"z\": z, \"y\": y, \"mask\": mask, \"lengths\": lengths}\n            model.decoder(generation)\n\n        elif decoder_test == \"interpolate_action\":\n            assert nats == nspa\n            # same noise for each sample\n            z_diff_action = torch.randn(1, model.latent_dim, device=model.device).repeat(nats, 1)\n            z = z_diff_action.repeat((nspa, 1))\n\n            # but use combination of labels and labels below\n            y = F.one_hot(classes.to(model.device), model.num_classes).to(model.device)\n            y_below = F.one_hot(torch.cat((classes[1:], classes[0:1])), model.num_classes).to(model.device)\n            convex_factors = torch.linspace(0, 1, nspa, device=model.device)\n            y_mixed = torch.einsum(\"nk,m->mnk\", y, 1-convex_factors) + torch.einsum(\"nk,m->mnk\", y_below, convex_factors)\n            y_mixed = y_mixed.reshape(nspa*nats, y_mixed.shape[-1])\n\n            durations = gendurations[0].to(model.device)\n            durations_below = 
torch.cat((durations[1:], durations[0:1]))\n\n            gendurations = torch.einsum(\"l,k->kl\", durations, 1-convex_factors) + torch.einsum(\"l,k->kl\", durations_below, convex_factors)\n            gendurations = gendurations.to(dtype=durations.dtype)\n\n            lengths = gendurations.to(model.device).reshape(z.shape[0])\n            mask = model.lengths_to_mask(lengths)\n\n            generation = {\"z\": z, \"y\": y_mixed, \"mask\": mask, \"lengths\": lengths}\n            model.decoder(generation)\n\n        # Get xyz for the real ones\n        visualization[\"output_xyz\"] = model.rot2xyz(visualization[\"output\"], visualization[\"mask\"])\n\n    for key, val in generation.items():\n        if len(generation[key].shape) == 1:\n            generation[key] = val.reshape(nspa, nats)\n        else:\n            generation[key] = val.reshape(nspa, nats, *val.shape[1:])\n\n    finalpath = os.path.join(folder, figname + \".gif\")\n    tmp_path = os.path.join(folder, f\"subfigures_{figname}\")\n    os.makedirs(tmp_path, exist_ok=True)\n\n    print(\"Generate the videos..\")\n    frames = generate_by_video(visualization, reconstructions, generation,\n                               dataset.label_to_action_name, params, nats, nspa, tmp_path)\n\n    print(f\"Writing video {finalpath}..\")\n    imageio.mimsave(finalpath, frames, fps=params[\"fps\"])\n\n    if writer is not None:\n        writer.add_video(f\"Video/Epoch {epoch}\", frames.transpose(0, 3, 1, 2)[None], epoch, fps=params[\"fps\"])\n\n\ndef viz_dataset(dataset, params, folder):\n    \"\"\" Generate & viz samples \"\"\"\n    print(\"Visualization of the dataset\")\n\n    nspa = params[\"num_samples_per_action\"]\n    nats = params[\"num_actions_to_sample\"]\n\n    num_classes = params[\"num_classes\"]\n\n    figname = \"{}_{}_numframes_{}_sampling_{}_step_{}\".format(params[\"dataset\"],\n                                                              params[\"pose_rep\"],\n                                                              params[\"num_frames\"],\n                                                              params[\"sampling\"],\n                                                              params[\"sampling_step\"])\n\n    # define some classes\n    classes = torch.randperm(num_classes)[:nats]\n\n    allclasses = classes.repeat(nspa, 1).reshape(nspa*nats)\n    # extract the real samples\n    real_samples, mask_real, real_lengths = dataset.get_label_sample_batch(allclasses.numpy())\n    # to visualize directly\n\n    # Visualizaion of real samples\n    visualization = {\"x\": real_samples,\n                     \"y\": allclasses,\n                     \"mask\": mask_real,\n                     \"lengths\": real_lengths,\n                     \"output\": real_samples}\n\n    from src.models.rotation2xyz import Rotation2xyz\n\n    device = params[\"device\"]\n    rot2xyz = Rotation2xyz(device=device)\n\n    rot2xyz_params = {\"pose_rep\": params[\"pose_rep\"],\n                      \"glob_rot\": params[\"glob_rot\"],\n                      \"glob\": params[\"glob\"],\n                      \"jointstype\": params[\"jointstype\"],\n                      \"translation\": params[\"translation\"]}\n\n    output = visualization[\"output\"]\n    visualization[\"output_xyz\"] = rot2xyz(output.to(device),\n                                          visualization[\"mask\"].to(device), **rot2xyz_params)\n\n    for key, val in visualization.items():\n        if len(visualization[key].shape) == 1:\n            visualization[key] = 
val.reshape(nspa, nats)\n        else:\n            visualization[key] = val.reshape(nspa, nats, *val.shape[1:])\n\n    finalpath = os.path.join(folder, figname + \".gif\")\n    tmp_path = os.path.join(folder, f\"subfigures_{figname}\")\n    os.makedirs(tmp_path, exist_ok=True)\n\n    print(\"Generate the videos..\")\n    frames = generate_by_video_sequences(visualization, dataset.label_to_action_name, params, nats, nspa, tmp_path)\n\n    print(f\"Writing video {finalpath}..\")\n    imageio.mimsave(finalpath, frames, fps=params[\"fps\"])\n\n\ndef generate_by_video_sequences(visualization, label_to_action_name, params, nats, nspa, tmp_path):\n    # shape : (17, 3, 4, 480, 640, 3)\n    # (nframes, row, column, h, w, 3)\n    fps = params[\"fps\"]\n\n    if \"output_xyz\" in visualization:\n        outputkey = \"output_xyz\"\n        params[\"pose_rep\"] = \"xyz\"\n    else:\n        outputkey = \"poses\"\n\n    keep = [outputkey, \"lengths\", \"y\"]\n    visu = {key: visualization[key].data.cpu().numpy() for key in keep}\n    lenmax = visu[\"lengths\"].max()\n\n    timesize = lenmax + 5\n    import multiprocessing\n\n    def pool_job_with_desc(pool, iterator, desc, max_, save_path_format):\n        with tqdm(total=max_, desc=desc.format(\"Render\")) as pbar:\n            for _ in pool.imap_unordered(plot_3d_motion_dico, iterator):\n                pbar.update()\n        array = np.stack([[load_anim(save_path_format.format(i, j), timesize)\n                           for j in range(nats)]\n                          for i in tqdm(range(nspa), desc=desc.format(\"Load\"))])\n        return array.transpose(2, 0, 1, 3, 4, 5)\n\n    with multiprocessing.Pool() as pool:\n        # Real samples\n        save_path_format = os.path.join(tmp_path, \"real_{}_{}.gif\")\n        iterator = ((visu[outputkey][i, j],\n                     visu[\"lengths\"][i, j],\n                     save_path_format.format(i, j),\n                     params, {\"title\": f\"real: {label_to_action_name(visu['y'][i, j])}\", \"interval\": 1000/fps})\n                    for j in range(nats) for i in range(nspa))\n        visu[\"frames\"] = pool_job_with_desc(pool, iterator,\n                                            \"{} the real samples\",\n                                            nats,\n                                            save_path_format)\n    frames = stack_images_sequence(visu[\"frames\"])\n    return frames\n\n\ndef stack_images_sequence(visu):\n    print(\"Stacking frames..\")\n    allframes = visu\n    nframes, nspa, nats, h, w, pix = allframes.shape\n    frames = []\n    for frame_idx in tqdm(range(nframes)):\n        columns = np.vstack(allframes[frame_idx].transpose(1, 2, 3, 4, 0)).transpose(3, 1, 0, 2)\n        frame = np.concatenate(columns).transpose(1, 0, 2)\n        frames.append(frame)\n    return np.stack(frames)\n"
  },
  {
    "path": "PBnet/src/visualize/visualize_checkpoint.py",
    "content": "import os\n\nimport matplotlib.pyplot as plt\nimport torch\nfrom src.utils.get_model_and_data import get_model_and_data\nfrom src.parser.visualize import parser\nfrom .visualize import viz_epoch\n\nimport src.utils.fixseed  # noqa\n\nplt.switch_backend('agg')\n\n\ndef main():\n    # parse options\n    parameters, folder, checkpointname, epoch = parser()\n\n    model, datasets = get_model_and_data(parameters)\n    dataset = datasets[\"train\"]\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=parameters[\"device\"])\n    model.load_state_dict(state_dict)\n    \n    # visualize_params\n    viz_epoch(model, dataset, epoch, parameters, folder=folder, writer=None)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "PBnet/src/visualize/visualize_dataset.py",
    "content": "import matplotlib.pyplot as plt\n# import torch\nimport os\n\nfrom src.datasets.get_dataset import get_dataset\nfrom src.utils import optutils\nfrom src.utils.visualize import viz_dataset\n\nimport src.utils.fixseed  # noqa\n\nplt.switch_backend('agg')\n\n\nif __name__ == '__main__':\n    # parse options\n    parameters = optutils.visualize_dataset_parser()\n\n    # get device\n    device = parameters[\"device\"]\n\n    # get data\n    DATA = get_dataset(name=parameters[\"dataset\"])\n    dataset = DATA(split=\"train\", **parameters)\n\n    # add specific parameters from the dataset loading\n    dataset.update_parameters(parameters)\n\n    name = f\"{parameters['dataset']}_{parameters['extraction_method']}\"\n    folder = os.path.join(\"datavisualize\", name)\n    viz_dataset(dataset, parameters, folder)\n"
  },
  {
    "path": "PBnet/src/visualize/visualize_latent_space.py",
    "content": "import matplotlib.pyplot as plt\nimport numpy as np\nimport os\nimport scipy\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch.utils.tensorboard import SummaryWriter\nfrom tqdm import tqdm\n\nfrom ..utils import optutils\nfrom ..utils.visualize import viz_epoch, viz_fake, viz_real\n\nfrom ..models.get_model import get_model\nfrom ..datasets.get_dataset import get_dataset\nfrom ..utils.trainer import train, test\n\n# import ..utils.fixseed  # noqa\n\n\nplt.switch_backend('agg')\n\n\nif __name__ == '__main__':\n    # parse options\n    opt, folder, checkpointname, epoch = optutils.parse_load_args()\n    \n    # get device \n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    \n    # get data\n    DATA = get_dataset(name=opt.dataname)\n    dataset = DATA(split=\"train\", **opt.data)\n    test_dataset = train_dataset = dataset\n    \n    # update model parameters\n    opt.model.update({\"num_classes\": dataset.num_classes, \"nfeats\": dataset.nfeats, \"device\": device})\n\n    # update visualize params\n    opt.visualize.update({\"num_classes\": dataset.num_classes,\n                          \"num_actions_to_sample\": min(opt.visualize[\"num_actions_to_sample\"],\n                                                       dataset.num_classes)})\n    \n    # get model\n    MODEL = get_model(opt.modelname)\n    model = MODEL(**opt.model)\n    model = model.to(device)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=device)\n    model.load_state_dict(state_dict)\n\n    nexemple = 20\n    latents = []\n    labels = []\n    generats = []\n    \n    print(\"Evaluating model..\")\n    keep = {\"x\": [], \"y\": [], \"di\": []}\n\n    num_classes = dataset.num_classes\n    # num_classes = 1\n    \n    for label in tqdm(range(num_classes)):\n        xcp, ycp, di = dataset.get_label_sample(label, n=nexemple, return_labels=True, return_index=True)\n        keep[\"x\"].append(xcp)\n        keep[\"y\"].append(ycp)\n        keep[\"di\"].append(di)\n        \n        x = torch.from_numpy(xcp).to(device)\n        y = torch.from_numpy(ycp).to(device)\n        h = model.return_latent(x, y)\n        \n        # mu, var = model.encoder(x, y)\n        # h = mu\n\n        hy = torch.randn(nexemple, model.latent_dim, device=device)\n        \n        hcp = h.data.cpu().numpy()\n        hycp = hy.data.cpu().numpy()\n        \n        latents.append(hcp)\n        generats.append(hycp)\n        \n        labels.append(ycp)\n        \n    latents = np.array(latents)\n    generats = np.array(generats)\n    \n    nclasses, nexemple, latent_dim = latents.shape\n    labels = np.array(labels)\n    all_latents = np.concatenate(latents)\n    all_generats = np.concatenate(generats)\n\n    nall_latents = len(all_latents)\n\n    # import ipdb; ipdb.set_trace()\n    print(\"Computing tsne..\")\n    from sklearn.manifold import TSNE\n\n    all_input = np.concatenate((all_latents, all_generats))\n    # tsne = TSNE(n_components=2)\n    # all_vizu_concat = tsne.fit_transform(all_input)\n    # import ipdb; ipdb.set_trace()\n    # feats = tuple(np.argsort(all_latents.var(0))[::-1][:2])\n    feats = tuple(np.argsort(all_latents.min(0)-all_latents.max(0))[::-1][:2] )\n    all_vizu_concat = all_input[:, feats]\n    \n    all_vizu_vectors = all_vizu_concat[:nall_latents]\n    all_gen_vizu_vectors = all_vizu_concat[nall_latents:]\n\n    gen_vizu_vectors = 
all_gen_vizu_vectors.reshape(nclasses, nexemple, 2)\n    vizu_vectors = all_vizu_vectors.reshape(nclasses, nexemple, 2)\n    \n    print(\"Plotting..\")\n    import matplotlib.pyplot as plt\n    import matplotlib.colors as mcolors\n\n    colors = list(mcolors.TABLEAU_COLORS.values()) + list(mcolors.BASE_COLORS.values()) + list(mcolors.CSS4_COLORS.values())    \n    for label in tqdm(range(num_classes)):\n        color = colors[label]\n        plt.scatter(*gen_vizu_vectors[label].T, color=color, marker=\"X\")\n        \n    for label in tqdm(range(num_classes)):\n        color = colors[label]\n        plt.scatter(*vizu_vectors[label].T, color=color)\n        \n    plt.savefig(\"tsne_all.png\")\n    plt.close()\n\n    import ipdb; ipdb.set_trace()\n    \"\"\"\n    mean = all_vizu_vectors.mean()\n    farthest = np.argsort(np.linalg.norm(mean - all_vizu_vectors, axis=1))[::-1][0]\n    cl_number, exnumber = np.argwhere(np.arange(all_vizu_vectors.shape[0]).reshape(nclasses, nexemple) == farthest)[0]\n\n    outlier_vid = keep[\"x\"][cl_number][exnumber]\n    nframe = outlier_vid.shape[-1]\n    \n    from ..utils.video import SaveVideo\n    save_path = \"outlier.mp4\"\n\n    cl_name = dataset.label_to_action_name(cl_number)\n    \n    with SaveVideo(save_path, opt.visualize[\"fps\"]) as outvideo:\n        for frame in range(nframe):\n            outvideo += repr_to_frame(outlier_vid[..., frame], f\"{cl_name} outlier\", {\"pose_rep\": \"xyz\"})\n    \n\"\"\"\n"
  },
  {
    "path": "PBnet/src/visualize/visualize_nturefined.py",
    "content": "import matplotlib.pyplot as plt\nimport torch\n\nfrom src.datasets.get_dataset import get_dataset\nfrom src.utils.anim import plot_3d_motion\nimport src.utils.fixseed  # noqa\n\nplt.switch_backend('agg')\n\n\ndef viz_ntu13(dataset, device):\n    \"\"\" Generate & viz samples \"\"\"\n    print(\"Visualization of the ntu13\")\n    \n    from src.models.rotation2xyz import Rotation2xyz\n    rot2xyz = Rotation2xyz(device)\n    \n    realsamples = []\n    pose18samples = []\n    pose24samples = []\n\n    translation = True\n    dataset.glob = True\n    dataset.translation = translation\n    \n    for i in range(1, 2):\n        dataset.pose_rep = \"xyz\"\n        x_xyz = dataset[i][0]\n        realsamples.append(x_xyz)\n        \n        dataset.pose_rep = \"rotvec\"\n        pose = dataset[i][0]\n        mask = torch.ones(pose.shape[2], dtype=bool)\n\n        # from src.models.smpl import SMPL\n        # smplmodel = SMPL().eval().to(device)\n        # import ipdb; ipdb.set_trace()\n        pose24 = rot2xyz(pose[None], mask[None], pose_rep=\"rotvec\", jointstype=\"smpl\", glob=True, translation=translation)[0]\n        pose18 = rot2xyz(pose[None], mask[None], pose_rep=\"rotvec\", jointstype=\"a2m\", glob=True, translation=translation)[0]\n        \n        translation = True\n        dataset.glob = True\n        dataset.translation = translation\n        \n        # poseT = dataset[i][0]\n        # pose18T = rot2xyz(poseT[None], mask[None], pose_rep=\"rotvec\", jointstype=\"action2motion\", glob=True, translation=translation)[0]\n        \n        # import ipdb; ipdb.set_trace()\n        pose18samples.append(pose18)\n        pose24samples.append(pose24)\n\n    params = {\"pose_rep\": \"xyz\"}\n    for i in [0]:\n        for x_xyz, title in zip([pose24samples[i], pose18samples[i], realsamples[i]], [\"pose_to_24\", \"pose_to_18\", \"action2motion_18\"]):\n            save_path = title + \".gif\"\n            plot_3d_motion(x_xyz, x_xyz.shape[-1], save_path, params, title=title)\n            print(f\"saving {save_path}\")\n    \n\nif __name__ == '__main__':\n    # get device\n    device = torch.device('cpu')\n\n    # get data\n    DATA = get_dataset(name=\"ntu13\")\n    dataset = DATA(split=\"train\")\n    \n    viz_ntu13(dataset, device)\n"
  },
  {
    "path": "PBnet/src/visualize/visualize_sequence.py",
    "content": "import os\n\nimport matplotlib.pyplot as plt\nimport torch\nimport numpy as np\n\nfrom src.datasets.get_dataset import get_dataset\nfrom src.models.get_model import get_model\nfrom src.utils import optutils\n\nfrom src.utils.anim import plot_3d_motion_on_oneframe\nfrom src.utils.visualize import process_to_visualize\n\nimport src.utils.fixseed  # noqa\n\n\nplt.switch_backend('agg')\n\n\nif __name__ == '__main__':\n    # parse options\n    opt, folder, checkpointname, epoch = optutils.visualize_parser()\n\n    # get device\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n    # get data\n    DATA = get_dataset(name=opt.dataset)\n    dataset = DATA(split=\"train\", **opt.data)\n    test_dataset = train_dataset = dataset\n\n    # update model parameters\n    opt.model.update({\"num_classes\": dataset.num_classes, \"nfeats\": dataset.nfeats, \"device\": device})\n\n    # update visualize params\n    opt.visualize.update({\"num_classes\": dataset.num_classes,\n                          \"num_actions_to_sample\": min(opt.visualize[\"num_actions_to_sample\"],\n                                                       dataset.num_classes)})\n\n    # get model\n    MODEL = get_model(opt.modelname)\n    model = MODEL(**opt.model)\n    model = model.to(device)\n\n    print(\"Restore weights..\")\n    checkpointpath = os.path.join(folder, checkpointname)\n    state_dict = torch.load(checkpointpath, map_location=device)\n    model.load_state_dict(state_dict)\n\n    save_path = os.path.join(folder, f\"fig_{epoch}\")\n\n    action_number = 0\n    actioname = dataset.action_to_action_name(action_number)\n    label = dataset.action_to_label(action_number)\n    print(f\"Generate {actioname}..\")\n    \n    y = torch.from_numpy(np.array([label], dtype=int)).to(device)\n    motion = model.generate(y, fact=1)\n    motion = process_to_visualize(motion.data.cpu().numpy(), opt.visualize)[0]\n    \n    print(\"Plot motion..\")\n    plot_3d_motion_on_oneframe(motion, \"motion.png\", opt.visualize, title=actioname)\n"
  },
  {
    "path": "README.md",
    "content": "# 🌅 DAWN: Dynamic Frame Avatar with Non-autoregressive Diffusion Framework for Talking Head Video Generation\n\n[![arXiv](https://img.shields.io/badge/Arxiv-2410.13726-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.13726)\n[![Demo Page](https://img.shields.io/badge/Demo_Page-blue)](https://hanbo-cheng.github.io/DAWN/)\n[![zhihu](https://img.shields.io/badge/知乎-0079FF.svg?logo=zhihu&logoColor=white)](https://zhuanlan.zhihu.com/p/2253009511)\n <a href='https://huggingface.co/Hanbo-Cheng/DAWN'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Model-yellow'></a>\n\n\n[中文文档](README_CN.md)\n<p align=\"center\">\n<img src=\"structure_img\\ifferent-styles-at-higher-resolution.gif\" width=600>\n</p>\n\n\n<h5 align=\"center\"> 😊 Please give us a star ⭐ to support our continuous updates 😊  </h5>\n\n## News\n* ```2024.10.14``` 🔥 We release the [Demo page](https://hanbo-cheng.github.io/DAWN/).\n* ```2024.10.18``` 🔥 We release the paper [DAWN](https://arxiv.org/abs/2410.13726).\n* ```2024.10.21``` 🔥 We update the Chinese introduction on [zhihu](https://zhuanlan.zhihu.com/p/2253009511).\n* ```2024.11.7``` 🔥🔥 We release the pretrained model on [hugging face](https://huggingface.co/Hanbo-Cheng/DAWN).\n* ```2024.11.9``` 🔥🔥🔥 We release the inference code. We sincerely invite you to experience our model. 😊\n*  ```2025.2.16``` 🔥🔥🔥 We optimize the unified inference code. Now you can run the test pipeline with only one script. 🚀\n## TODO list:\n- [x]  release the inference code\n- [x]  release the pretrained model of **128*128**\n- [x]  release the pretrained model of **256*256** \n- [x] release the unified test code\n- [ ] in progress ...\n\n\n## Equipment Requirements\n\nWith our VRAM-oriented optimized [code](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py), the maximum length of video that can be generated is **linearly related** to the size of the GPU VRAM. Larger VRAM produces longer videos.\n- To generate **128*128** video, we recommend using a GPU with **12GB** or more VRAM. This can generate videos of at least approximately **400 frames**.\n- To generate **256*256** video, we recommend using a GPU with **24GB** or more VRAM. This can generate videos of at least approximately **200 frames**.\n\nPS: Although the optimized [code](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py) can improve VRAM utilization, it currently sacrifices inference speed due to incomplete optimization of local attention. We are actively working on this issue, and if you have a better solution, we welcome your PR. If you wish to achieve faster inference speed, you can use the [unoptimized code](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py), but this will increase VRAM usage (O(n²) spatial complexity).\n## Methodology\n### The overall structure of DAWN:\n<p align=\"center\">\n<img src=\"structure_img\\pipeline.png\" width=600 alt=\"framework\"/>\n</p>\n\n\n## Environment\nWe highly recommend trying DAWN on a Linux platform. Running on Windows may produce some junk files that need to be deleted manually, and it requires additional effort to deploy the 3DDFA repository (our `extract_init_states` folder); see this [comment](https://github.com/cleardusk/3DDFA_V2/issues/12#issuecomment-697479173).\n\n1. Set up the conda environment\n```\nconda create -n DAWN python=3.8\nconda activate DAWN\npip install -r requirements.txt\n```\n\n2. 
Follow the [readme](extract_init_states/readme.md) and [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2) to set up the 3DDFA environment.\n\n## Inference\n\nSince our model **is trained only on the HDTF dataset** and has few parameters, to ensure the best driving effect, please provide examples with:\n- standard human photos as much as possible; try not to wear hats or large headgear\n- a clear boundary between the background and the subject\n- the face occupying the main position in the image.\n\nPreparation for inference:\n1. Download the pretrained checkpoints from [hugging face](https://huggingface.co/Hanbo-Cheng/DAWN). Create the `./pretrain_models` directory and put the checkpoint files into it. Please download the Hubert model from [facebook/hubert-large-ls960-ft](https://huggingface.co/facebook/hubert-large-ls960-ft/tree/main). \n      ```\n      directory structure:\n\n      pretrain_models/\n            ├── LFG_256_400ep.pth\n            ├── LFG_128_1000ep.pth\n            ├── DAWN_256.pth\n            ├── DAWN_128.pth\n            └── hubert-large-ls960-ft/\n                  ├── .....\n      ```\n\n2. Run the inference script: \n   ```\n   python unified_video_generator.py  \\\n      --audio_path your/audio/path  \\\n      --image_path your/image/path  \\\n      --output_path output/path \\\n      --cache_path cache/path \\\n      --resolution 128   # optional: 128 or 256\n   ```\n\n***Inference on other datasets:***\nBy specifying the `audio_path`, `image_path`, and `output_path` of the `VideoGenerator` class for each inference run, and by modifying `directory_name` and `output_video_path` in `unified_video_generator.py` (Lines 310-312 and 393-394), you can control the naming logic for saving images and videos, enabling testing on any dataset.\n\n\n## Citing DAWN\nIf you wish to refer to the baseline results published here, please use the following BibTeX entry:\n\n```BibTeX\n@misc{dawn2024,\n      title={DAWN: Dynamic Frame Avatar with Non-autoregressive Diffusion Framework for Talking Head Video Generation}, \n      author={Hanbo Cheng and Limin Lin and Chenyu Liu and Pengcheng Xia and Pengfei Hu and Jiefeng Ma and Jun Du and Jia Pan},\n      year={2024},\n      eprint={2410.13726},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV},\n      url={https://arxiv.org/abs/2410.13726}, \n}\n```\n## Acknowledgement\n\n[Limin Lin](https://github.com/LiminLin0) and [Hanbo Cheng](https://github.com/Hanbo-Cheng) contributed equally to the project.\n\nThank you to the authors of [Diffused Heads](https://github.com/MStypulkowski/diffused-heads) for assisting us in reproducing their work! We also extend our gratitude to the authors of [MRAA](https://github.com/snap-research/articulated-animation), [LFDM](https://github.com/snap-research/articulated-animation), [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2) and [ACTOR](https://github.com/Mathux/ACTOR) for their contributions to the open-source community. Lastly, we thank our mentors and co-authors for their continuous support in our research work!\n\n"
  },
  {
    "path": "README_CN.md",
    "content": "# 🌅 DAWN：Dynamic Frame Avatar with Non-autoregressive Diffusion Framework for Talking Head Video Generation\n\n[![arXiv](https://img.shields.io/badge/Arxiv-2410.13726-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.13726)\n[![Demo Page](https://img.shields.io/badge/Demo_Page-blue)](https://hanbo-cheng.github.io/DAWN/)\n[![zhihu](https://img.shields.io/badge/知乎-0079FF.svg?logo=zhihu&logoColor=white)](https://zhuanlan.zhihu.com/p/2253009511)\n <a href='https://huggingface.co/Hanbo-Cheng/DAWN'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Model-yellow'></a>\n<p align=\"center\">\n\n<img src=\"structure_img\\ifferent-styles-at-higher-resolution.gif\" width=600>\n</p>\n\n\n😊 请给我们一个star⭐支持我们的持续更新 😊\n## 新闻\n* ```2024.10.14``` 🔥 我们发布了 [DEMO](https://hanbo-cheng.github.io/DAWN/)。\n* ```2024.10.18``` 🔥 我们发布了论文 [DAWN](https://arxiv.org/abs/2410.13726)。\n* ```2024.10.21``` 🔥 我们更新了中文介绍 [知乎](https://zhuanlan.zhihu.com/p/2253009511)。\n* ```2024.11.7``` 🔥🔥 我们在 [hugging face](https://huggingface.co/Hanbo-Cheng/DAWN) 上发布了预训练模型。\n* ```2024.11.9``` 🔥🔥🔥 我们发布了推理代码。我们诚挚邀请您体验我们的模型。😊\n*  ```2025.2.16``` 🔥🔥🔥 我们优化了统一推理代码。现在您可以仅用一个脚本运行测试流程。🚀\n\n## 待办事项列表：\n- [x]  发布推理代码\n- [x]  发布 **128*128** 的预训练模型\n- [x]  发布 **256*256** 的预训练模型 \n- [x] 发布统一测试代码\n- [ ]  进行中 ...\n\n## 设备要求\n\n使用我们针对VRAM优化的 [代码](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py)，生成的视频最大长度与GPU VRAM的大小 **成线性关系**。更大的VRAM可以生成更长的视频。\n- 要生成 **128*128** 视频，我们建议使用 **12GB** 或更多VRAM的GPU。这至少可以生成大约 **400帧** 的视频。\n- 要生成 **256*256** 视频，我们建议使用 **24GB** 或更多VRAM的GPU。这至少可以生成大约 **200帧** 的视频。\n\nPS：尽管优化的 [代码](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py) 可以提高VRAM利用率，但由于局部注意力的优化尚不完整，目前牺牲了推理速度。我们正在积极解决这个问题，如果您有更好的解决方案，欢迎您提交PR。如果您希望实现更快的推理速度，可以使用 [未优化的代码](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py)，但这将增加VRAM使用（O(n²) 空间复杂度）。\n\n## 方法论\n### DAWN的整体结构：\n<p align=\"center\">\n<img src=\"structure_img\\pipeline.png\" width=600 alt=\"framework\"/>\n</p>\n\n## 环境\n我们强烈建议在Linux平台上尝试DAWN。在Windows上运行可能会产生一些需要手动删除的垃圾文件，并且需要额外的努力来部署3DDFA库（我们的 `extract_init_states` 文件夹） [评论](https://github.com/cleardusk/3DDFA_V2/issues/12#issuecomment-697479173)。\n\n1. 设置conda环境\n```\nconda create -n DAWN python=3.8\nconda activate DAWN\npip install -r requirements.txt\n```\n\n2. 按照 [readme](extract_init_states/readme.md) 和 [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2) 设置3DDFA环境。\n\n## 推理\n\n由于我们的模型 **仅在HDTF数据集上训练**，并且参数较少，为了确保最佳的驱动效果，请尽量提供以下示例：\n- 尽量使用标准人像照片，避免佩戴帽子或大型头饰\n- 确保背景与主体之间有清晰的边界\n- 确保面部在图像中占据主要位置。\n\n推理准备：\n1. 从 [hugging face](https://huggingface.co/Hanbo-Cheng/DAWN) 下载预训练检查点。创建 `./pretrain_models` 目录并将检查点文件放入其中。请从 [facebook/hubert-large-ls960-ft](https://huggingface.co/facebook/hubert-large-ls960-ft/tree/main) 下载Hubert模型。\n   \n2. 运行推理脚本： \n   ```\n   python unified_video_generator.py  \\\n      --audio_path your/audio/path  \\\n      --image_path your/image/path  \\\n      --output_path output/path \\\n      --cache_path cache/path \n   ```\n\n***在其他数据集上的推理：***\n通过在每次推理时指定 `VideoGenerator` 类的 `audio_path`、`image_path` 和 `output_path`，并修改 `unified_video_generator.py` 中第310-312行和393-394行的 `directory_name` 和 `output_video_path` 的内容，您可以控制保存图像和视频的命名逻辑，从而在任何数据集上进行测试。\n\n\n"
  },
  {
    "path": "config/DAWN_128.yaml",
    "content": "input_size: 128\nmax_n_frames: 200\nrandom_seed: 1234\nmean: [0.0, 0.0, 0.0]\nwin_width: 40\nsampling_step: 20\nddim_sampling_eta: 1.0\ncond_scale: 1.0\n\nmodel_config:\n  is_train: true\n  pose_dim: 6\n  config_pth: './config/hdtf128.yaml'\n  ae_pretrained_pth: './pretrain_models/LFG_128_1000ep.pth'\n  diffusion_pretrained_pth: './pretrain_models/DAWN_128.pth'"
  },
  {
    "path": "config/DAWN_256.yaml",
    "content": "input_size: 256\nmax_n_frames: 200\nrandom_seed: 1234\nmean: [0.0, 0.0, 0.0]\nwin_width: 40\nsampling_step: 20\nddim_sampling_eta: 1.0\ncond_scale: 1.0\n\nmodel_config:\n  is_train: true\n  pose_dim: 6\n  config_pth: './config/hdtf256.yaml'\n  ae_pretrained_pth: './pretrain_models/LFG_256_400ep.pth'\n  diffusion_pretrained_pth: './pretrain_models/DAWN_256.pth'"
  },
  {
    "path": "config/hdtf128.yaml",
    "content": "#Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\n#No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\n#publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\n#Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\n#title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\n#In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\n\n# Dataset parameters\n# Each dataset should contain 2 folders train and test\n# Each video can be represented as:\n#   - an image of concatenated frames\n#   - '.mp4' or '.gif'\n#   - folder with all frames from a specific video\ndataset_params:\n  # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.\n  # Folder with frames is preferred format for training, since it is the fastest.\n  root_dir: /work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\n  # Shape to resize all frames to, specify null if resizing is not needed\n  frame_shape: 128\n  # In case of Vox or Taichi single video can be splitted in many chunks, or the maybe several videos for single person.\n  # In this case epoch can be a pass over different identities (if id_sampling=True) or over different chunks (if id_sampling=False)\n  # If the name the video '12335#adsbf.mp4' the id is assumed to be 12335\n  id_sampling: False\n  # List with pairs for animation, null for random pairs\n  pairs_list: null\n  # Augmentation parameters see augmentation.py for all possible augmentations\n  augmentation_params:\n    flip_param:\n      horizontal_flip: True\n      time_flip: True\n    jitter_param:\n      brightness: 0.1\n      contrast: 0.1\n      saturation: 0.1\n      hue: 0.1\n\n# Defines architecture of the model\nmodel_params:\n  # Number of regions\n  num_regions: 10\n  # Number of channels, for RGB image it is always 3\n  num_channels: 3\n  # Enable estimation of affine parameters for each region,\n  # set to False if only region centers (keypoints) need to be estimated\n  estimate_affine: True\n  # Svd can perform random axis swap between source and driving if singular values are close to each other\n  # Set to True to avoid axis swap between source and driving\n  revert_axis_swap: True\n\n  # Parameters of background prediction network based on simple Unet-like encoder.\n  bg_predictor_params:\n    # Number of features multiplier\n    block_expansion: 32\n    # Maximum allowed number of features\n    max_features: 1024\n    # Number of block in the Encoder.\n    num_blocks: 5\n    # Type of background movement model, select one from ['zero', 'shift', 'affine', 'perspective']\n    bg_type: 'affine'\n\n  # Parameters of the region prediction network based on Unet\n  region_predictor_params:\n    # Softmax temperature for heatmaps\n    temperature: 0.1\n    # Number of features multiplier\n    block_expansion: 32\n    # Maximum allowed number of features\n    max_features: 1024\n    # Regions is predicted on smaller images for better performance,\n    # scale_factor=0.25 means that 256x256 image will be resized to 64x64\n    scale_factor: 0.25\n    # Number of block in Unet. 
Can be increased or decreased depending or resolution.\n    num_blocks: 5\n    # Either to use pca_based estimation of affine parameters of regression based\n    pca_based: True\n    # Either to use fast_svd (https://github.com/KinglittleQ/torch-batch-svd) or standard pytorch svd\n    # Fast svd may produce not meaningful regions if used along with revert_axis_swap\n    fast_svd: False\n\n  # Parameters of Generator, based on Jonson architecture\n  generator_params:\n    # Number of features multiplier\n    block_expansion: 64\n    # Maximum allowed number of features\n    max_features: 512\n    # Number of down-sampling blocks in Jonson architecture.\n    # Can be increased or decreased depending or resolution.\n    num_down_blocks: 2\n    # Number of ResBlocks  in Jonson architecture.\n    num_bottleneck_blocks: 6\n    # To use skip connections or no.\n    skips: True\n    # Parameters of pixelwise flow predictor based on Unet\n    pixelwise_flow_predictor_params:\n      # Number of features multiplier\n      block_expansion: 64\n      # Maximum allowed number of features\n      max_features: 1024\n      # Number of block in Unet. Can be increased or decreased depending on resolution.\n      num_blocks: 5\n      # Flow predictor operates on the smaller images for better performance,\n      # scale_factor=0.25 means that 256x256 image will be resized to 64x64\n      scale_factor: 0.25\n      # Set to True in order to use deformed source images using sparse flow\n      use_deformed_source: True\n      # Set to False in order to render region heatmaps with fixed covariance\n      # True for covariance estimate using region_predictor\n      use_covar_heatmap: True\n      # Set to False to disable occlusion mask estimation\n      estimate_occlusion_map: True\n\n  # Parameter for animation-via-disentanglement (avd) network\n  avd_network_params:\n    # Bottleneck for identity branch\n    id_bottle_size: 64\n    # Bottleneck for pose branch\n    pose_bottle_size: 64\n\n# Parameters of training (reconstruction)\ntrain_params:\n  max_epochs: 100\n  # For better i/o performance when number of videos is small number of epochs can be multiplied by this number.\n  # Thus effectively with num_repeats=100 each epoch is 100 times larger.\n  num_repeats: 100\n  # Drop learning rate 10 times after this epochs\n  epoch_milestones: [60, 90]\n  # Initial learning rate\n  lr: 2.0e-4\n  # Batch size. (14 is batch size for one V100 gpu).\n  batch_size: 100\n  # Either to use sync_bn or not, enabling sync_bn will significantly slow the training time\n  use_sync_bn: False\n  # Dataset preprocessing cpu workers\n  dataloader_workers: 16\n  print_freq: 10\n  save_img_freq: 100\n  # update checkpoint in this frequent\n  update_ckpt_freq: 5000\n  # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,\n  # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.\n  scales: [1, 0.5, 0.25, 0.125]\n  # Parameters of transform for equivariance loss\n  transform_params:\n    sigma_affine: 0.05\n    sigma_tps: 0.005\n    points_tps: 5\n  loss_weights:\n    # Weights for perceptual pyramide loss. 
Note that here you can only specify weight across the layer, and\n    # weights across the resolution will be the same.\n    perceptual: [10, 10, 10, 10, 10]\n    rec_vgg: [0, 0, 0, 0, 0]\n    # Weights for equivariance loss.\n    equivariance_shift: 10\n    equivariance_affine: 10\n\n# Parameters of visualization\nvisualizer_params:\n  # Size of keypoints\n  kp_size: 2\n  # Draw border between images or not\n  draw_border: True\n  # Colormap for regions and keypoints visualization\n  colormap: 'gist_rainbow'\n  # Background color for region visualization\n  region_bg_color: [1, 1, 1]\n"
  },
  {
    "path": "config/hdtf128_1000ep.yaml",
    "content": "#Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\n#No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\n#publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\n#Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\n#title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\n#In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\n\n# Dataset parameters\n# Each dataset should contain 2 folders train and test\n# Each video can be represented as:\n#   - an image of concatenated frames\n#   - '.mp4' or '.gif'\n#   - folder with all frames from a specific video\ndataset_params:\n  # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.\n  # Folder with frames is preferred format for training, since it is the fastest.\n  root_dir: /yrfs2/cv2/pcxia/audiovisual/hdtf/images\n  # Shape to resize all frames to, specify null if resizing is not needed\n  frame_shape: 128\n  # In case of Vox or Taichi single video can be splitted in many chunks, or the maybe several videos for single person.\n  # In this case epoch can be a pass over different identities (if id_sampling=True) or over different chunks (if id_sampling=False)\n  # If the name the video '12335#adsbf.mp4' the id is assumed to be 12335\n  id_sampling: False\n  # List with pairs for animation, null for random pairs\n  pairs_list: null\n  # Augmentation parameters see augmentation.py for all possible augmentations\n  augmentation_params:\n    flip_param:\n      horizontal_flip: True\n      time_flip: True\n    jitter_param:\n      brightness: 0.1\n      contrast: 0.1\n      saturation: 0.1\n      hue: 0.1\n\n# Defines architecture of the model\nmodel_params:\n  # Number of regions\n  num_regions: 10\n  # Number of channels, for RGB image it is always 3\n  num_channels: 3\n  # Enable estimation of affine parameters for each region,\n  # set to False if only region centers (keypoints) need to be estimated\n  estimate_affine: True\n  # Svd can perform random axis swap between source and driving if singular values are close to each other\n  # Set to True to avoid axis swap between source and driving\n  revert_axis_swap: True\n\n  # Parameters of background prediction network based on simple Unet-like encoder.\n  bg_predictor_params:\n    # Number of features multiplier\n    block_expansion: 32\n    # Maximum allowed number of features\n    max_features: 1024\n    # Number of block in the Encoder.\n    num_blocks: 5\n    # Type of background movement model, select one from ['zero', 'shift', 'affine', 'perspective']\n    bg_type: 'affine'\n\n  # Parameters of the region prediction network based on Unet\n  region_predictor_params:\n    # Softmax temperature for heatmaps\n    temperature: 0.1\n    # Number of features multiplier\n    block_expansion: 32\n    # Maximum allowed number of features\n    max_features: 1024\n    # Regions is predicted on smaller images for better performance,\n    # scale_factor=0.25 means that 256x256 image will be resized to 64x64\n    scale_factor: 0.25\n    # Number of block in Unet. 
Can be increased or decreased depending or resolution.\n    num_blocks: 5\n    # Either to use pca_based estimation of affine parameters of regression based\n    pca_based: True\n    # Either to use fast_svd (https://github.com/KinglittleQ/torch-batch-svd) or standard pytorch svd\n    # Fast svd may produce not meaningful regions if used along with revert_axis_swap\n    fast_svd: False\n\n  # Parameters of Generator, based on Jonson architecture\n  generator_params:\n    # Number of features multiplier\n    block_expansion: 64\n    # Maximum allowed number of features\n    max_features: 512\n    # Number of down-sampling blocks in Jonson architecture.\n    # Can be increased or decreased depending or resolution.\n    num_down_blocks: 2\n    # Number of ResBlocks  in Jonson architecture.\n    num_bottleneck_blocks: 6\n    # To use skip connections or no.\n    skips: True\n    # Parameters of pixelwise flow predictor based on Unet\n    pixelwise_flow_predictor_params:\n      # Number of features multiplier\n      block_expansion: 64\n      # Maximum allowed number of features\n      max_features: 1024\n      # Number of block in Unet. Can be increased or decreased depending on resolution.\n      num_blocks: 5\n      # Flow predictor operates on the smaller images for better performance,\n      # scale_factor=0.25 means that 256x256 image will be resized to 64x64\n      scale_factor: 0.25\n      # Set to True in order to use deformed source images using sparse flow\n      use_deformed_source: True\n      # Set to False in order to render region heatmaps with fixed covariance\n      # True for covariance estimate using region_predictor\n      use_covar_heatmap: True\n      # Set to False to disable occlusion mask estimation\n      estimate_occlusion_map: True\n\n  # Parameter for animation-via-disentanglement (avd) network\n  avd_network_params:\n    # Bottleneck for identity branch\n    id_bottle_size: 64\n    # Bottleneck for pose branch\n    pose_bottle_size: 64\n\n# Parameters of training (reconstruction)\ntrain_params:\n  max_epochs: 1000\n  # For better i/o performance when number of videos is small number of epochs can be multiplied by this number.\n  # Thus effectively with num_repeats=100 each epoch is 100 times larger.\n  num_repeats: 100\n  # Drop learning rate 10 times after this epochs\n  epoch_milestones: [60, 90]\n  # Initial learning rate\n  lr: 4.0e-4\n  # Batch size. (14 is batch size for one V100 gpu).\n  batch_size: 82\n  # Either to use sync_bn or not, enabling sync_bn will significantly slow the training time\n  use_sync_bn: False\n  # Dataset preprocessing cpu workers\n  dataloader_workers: 8\n  print_freq: 10\n  save_img_freq: 100\n  # update checkpoint in this frequent\n  update_ckpt_freq: 5000\n  # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,\n  # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.\n  scales: [1, 0.5, 0.25, 0.125]\n  # Parameters of transform for equivariance loss\n  transform_params:\n    sigma_affine: 0.05\n    sigma_tps: 0.005\n    points_tps: 5\n  loss_weights:\n    # Weights for perceptual pyramide loss. 
Note that here you can only specify weight across the layer, and\n    # weights across the resolution will be the same.\n    perceptual: [10, 10, 10, 10, 10]\n    rec_vgg: [1, 1, 1, 1, 1]\n    # Weights for equivariance loss.\n    equivariance_shift: 10\n    equivariance_affine: 10\n\n# Parameters of visualization\nvisualizer_params:\n  # Size of keypoints\n  kp_size: 2\n  # Draw border between images or not\n  draw_border: True\n  # Colormap for regions and keypoints visualization\n  colormap: 'gist_rainbow'\n  # Background color for region visualization\n  region_bg_color: [1, 1, 1]\n"
  },
  {
    "path": "config/hdtf128_1000ep_crema.yaml",
    "content": "#Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\n#No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\n#publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\n#Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\n#title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\n#In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\n\n# Dataset parameters\n# Each dataset should contain 2 folders train and test\n# Each video can be represented as:\n#   - an image of concatenated frames\n#   - '.mp4' or '.gif'\n#   - folder with all frames from a specific video\ndataset_params:\n  # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.\n  # Folder with frames is preferred format for training, since it is the fastest.\n  root_dir: /work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images\n  # Shape to resize all frames to, specify null if resizing is not needed\n  frame_shape: 128\n  # In case of Vox or Taichi single video can be splitted in many chunks, or the maybe several videos for single person.\n  # In this case epoch can be a pass over different identities (if id_sampling=True) or over different chunks (if id_sampling=False)\n  # If the name the video '12335#adsbf.mp4' the id is assumed to be 12335\n  id_sampling: False\n  # List with pairs for animation, null for random pairs\n  pairs_list: null\n  # Augmentation parameters see augmentation.py for all possible augmentations\n  augmentation_params:\n    flip_param:\n      horizontal_flip: True\n      time_flip: True\n    jitter_param:\n      brightness: 0.1\n      contrast: 0.1\n      saturation: 0.1\n      hue: 0.1\n\n# Defines architecture of the model\nmodel_params:\n  # Number of regions\n  num_regions: 10\n  # Number of channels, for RGB image it is always 3\n  num_channels: 3\n  # Enable estimation of affine parameters for each region,\n  # set to False if only region centers (keypoints) need to be estimated\n  estimate_affine: True\n  # Svd can perform random axis swap between source and driving if singular values are close to each other\n  # Set to True to avoid axis swap between source and driving\n  revert_axis_swap: True\n\n  # Parameters of background prediction network based on simple Unet-like encoder.\n  bg_predictor_params:\n    # Number of features multiplier\n    block_expansion: 32\n    # Maximum allowed number of features\n    max_features: 1024\n    # Number of block in the Encoder.\n    num_blocks: 5\n    # Type of background movement model, select one from ['zero', 'shift', 'affine', 'perspective']\n    bg_type: 'affine'\n\n  # Parameters of the region prediction network based on Unet\n  region_predictor_params:\n    # Softmax temperature for heatmaps\n    temperature: 0.1\n    # Number of features multiplier\n    block_expansion: 32\n    # Maximum allowed number of features\n    max_features: 1024\n    # Regions is predicted on smaller images for better performance,\n    # scale_factor=0.25 means that 256x256 image will be resized to 64x64\n    scale_factor: 0.25\n    # Number of block in Unet. 
Can be increased or decreased depending or resolution.\n    num_blocks: 5\n    # Either to use pca_based estimation of affine parameters of regression based\n    pca_based: True\n    # Either to use fast_svd (https://github.com/KinglittleQ/torch-batch-svd) or standard pytorch svd\n    # Fast svd may produce not meaningful regions if used along with revert_axis_swap\n    fast_svd: False\n\n  # Parameters of Generator, based on Jonson architecture\n  generator_params:\n    # Number of features multiplier\n    block_expansion: 64\n    # Maximum allowed number of features\n    max_features: 512\n    # Number of down-sampling blocks in Jonson architecture.\n    # Can be increased or decreased depending or resolution.\n    num_down_blocks: 2\n    # Number of ResBlocks  in Jonson architecture.\n    num_bottleneck_blocks: 6\n    # To use skip connections or no.\n    skips: True\n    # Parameters of pixelwise flow predictor based on Unet\n    pixelwise_flow_predictor_params:\n      # Number of features multiplier\n      block_expansion: 64\n      # Maximum allowed number of features\n      max_features: 1024\n      # Number of block in Unet. Can be increased or decreased depending on resolution.\n      num_blocks: 5\n      # Flow predictor operates on the smaller images for better performance,\n      # scale_factor=0.25 means that 256x256 image will be resized to 64x64\n      scale_factor: 0.25\n      # Set to True in order to use deformed source images using sparse flow\n      use_deformed_source: True\n      # Set to False in order to render region heatmaps with fixed covariance\n      # True for covariance estimate using region_predictor\n      use_covar_heatmap: True\n      # Set to False to disable occlusion mask estimation\n      estimate_occlusion_map: True\n\n  # Parameter for animation-via-disentanglement (avd) network\n  avd_network_params:\n    # Bottleneck for identity branch\n    id_bottle_size: 64\n    # Bottleneck for pose branch\n    pose_bottle_size: 64\n\n# Parameters of training (reconstruction)\ntrain_params:\n  max_epochs: 600\n  # For better i/o performance when number of videos is small number of epochs can be multiplied by this number.\n  # Thus effectively with num_repeats=100 each epoch is 100 times larger.\n  num_repeats: 100\n  # Drop learning rate 10 times after this epochs\n  epoch_milestones: [60, 90]\n  # Initial learning rate\n  lr: 4.0e-4\n  # Batch size. (14 is batch size for one V100 gpu).\n  batch_size: 100\n  # Either to use sync_bn or not, enabling sync_bn will significantly slow the training time\n  use_sync_bn: False\n  # Dataset preprocessing cpu workers\n  dataloader_workers: 8\n  print_freq: 10\n  save_img_freq: 100\n  # update checkpoint in this frequent\n  update_ckpt_freq: 5000\n  # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,\n  # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.\n  scales: [1, 0.5, 0.25, 0.125]\n  # Parameters of transform for equivariance loss\n  transform_params:\n    sigma_affine: 0.05\n    sigma_tps: 0.005\n    points_tps: 5\n  loss_weights:\n    # Weights for perceptual pyramide loss. 
Note that here you can only specify weight across the layer, and\n    # weights across the resolution will be the same.\n    perceptual: [10, 10, 10, 10, 10]\n    rec_vgg: [1, 1, 1, 1, 1]\n    # Weights for equivariance loss.\n    equivariance_shift: 10\n    equivariance_affine: 10\n\n# Parameters of visualization\nvisualizer_params:\n  # Size of keypoints\n  kp_size: 2\n  # Draw border between images or not\n  draw_border: True\n  # Colormap for regions and keypoints visualization\n  colormap: 'gist_rainbow'\n  # Background color for region visualization\n  region_bg_color: [1, 1, 1]\n"
  },
  {
    "path": "config/hdtf256.yaml",
    "content": "#Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\n#No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\n#publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\n#Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\n#title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\n#In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\n\n# Dataset parameters\n# Each dataset should contain 2 folders train and test\n# Each video can be represented as:\n#   - an image of concatenated frames\n#   - '.mp4' or '.gif'\n#   - folder with all frames from a specific video\ndataset_params:\n  # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.\n  # Folder with frames is preferred format for training, since it is the fastest.\n  root_dir: /yrfs2/cv2/pcxia/audiovisual/hdtf/images\n  # Shape to resize all frames to, specify null if resizing is not needed\n  frame_shape: 256\n  # In case of Vox or Taichi single video can be splitted in many chunks, or the maybe several videos for single person.\n  # In this case epoch can be a pass over different identities (if id_sampling=True) or over different chunks (if id_sampling=False)\n  # If the name the video '12335#adsbf.mp4' the id is assumed to be 12335\n  id_sampling: False\n  # List with pairs for animation, null for random pairs\n  pairs_list: null\n  # Augmentation parameters see augmentation.py for all possible augmentations\n  augmentation_params:\n    flip_param:\n      horizontal_flip: True\n      time_flip: True\n    jitter_param:\n      brightness: 0.1\n      contrast: 0.1\n      saturation: 0.1\n      hue: 0.1\n\n# Defines architecture of the model\nmodel_params:\n  # Number of regions\n  num_regions: 10\n  # Number of channels, for RGB image it is always 3\n  num_channels: 3\n  # Enable estimation of affine parameters for each region,\n  # set to False if only region centers (keypoints) need to be estimated\n  estimate_affine: True\n  # Svd can perform random axis swap between source and driving if singular values are close to each other\n  # Set to True to avoid axis swap between source and driving\n  revert_axis_swap: True\n\n  # Parameters of background prediction network based on simple Unet-like encoder.\n  bg_predictor_params:\n    # Number of features multiplier\n    block_expansion: 32\n    # Maximum allowed number of features\n    max_features: 1024\n    # Number of block in the Encoder.\n    num_blocks: 5\n    # Type of background movement model, select one from ['zero', 'shift', 'affine', 'perspective']\n    bg_type: 'affine'\n\n  # Parameters of the region prediction network based on Unet\n  region_predictor_params:\n    # Softmax temperature for heatmaps\n    temperature: 0.1\n    # Number of features multiplier\n    block_expansion: 32\n    # Maximum allowed number of features\n    max_features: 1024\n    # Regions is predicted on smaller images for better performance,\n    # scale_factor=0.25 means that 256x256 image will be resized to 64x64\n    scale_factor: 0.25\n    # Number of block in Unet. 
Can be increased or decreased depending or resolution.\n    num_blocks: 5\n    # Either to use pca_based estimation of affine parameters of regression based\n    pca_based: True\n    # Either to use fast_svd (https://github.com/KinglittleQ/torch-batch-svd) or standard pytorch svd\n    # Fast svd may produce not meaningful regions if used along with revert_axis_swap\n    fast_svd: False\n\n  # Parameters of Generator, based on Jonson architecture\n  generator_params:\n    # Number of features multiplier\n    block_expansion: 64\n    # Maximum allowed number of features\n    max_features: 512\n    # Number of down-sampling blocks in Jonson architecture.\n    # Can be increased or decreased depending or resolution.\n    num_down_blocks: 2\n    # Number of ResBlocks  in Jonson architecture.\n    num_bottleneck_blocks: 6\n    # To use skip connections or no.\n    skips: True\n    # Parameters of pixelwise flow predictor based on Unet\n    pixelwise_flow_predictor_params:\n      # Number of features multiplier\n      block_expansion: 64\n      # Maximum allowed number of features\n      max_features: 1024\n      # Number of block in Unet. Can be increased or decreased depending or resolution.\n      num_blocks: 5\n      # Flow predictor operates on the smaller images for better performance,\n      # scale_factor=0.25 means that 256x256 image will be resized to 64x64\n      scale_factor: 0.25\n      # Set to True in order to use deformed source images using sparse flow\n      use_deformed_source: True\n      # Set to False in order to render region heatmaps with fixed covariance\n      # True for covariance estimate using region_predictor\n      use_covar_heatmap: True\n      # Set to False to disable occlusion mask estimation\n      estimate_occlusion_map: True\n\n  # Parameter for animation-via-disentanglement (avd) network\n  avd_network_params:\n    # Bottleneck for identity branch\n    id_bottle_size: 64\n    # Bottleneck for pose branch\n    pose_bottle_size: 64\n\n# Parameters of training (reconstruction)\ntrain_params:\n  max_epochs: 100\n  # For better i/o performance when number of videos is small number of epochs can be multiplied by this number.\n  # Thus effectively with num_repeats=100 each epoch is 100 times larger.\n  num_repeats: 100\n  # Drop learning rate 10 times after this epochs\n  epoch_milestones: [60, 90]\n  # Initial learning rate\n  lr: 2.0e-4\n  # Batch size. (14 is batch size for one V100 gpu).\n  batch_size: 42\n  # Either to use sync_bn or not, enabling sync_bn will significantly slow the training time\n  use_sync_bn: False\n  # Dataset preprocessing cpu workers\n  dataloader_workers: 12\n  print_freq: 10\n  save_img_freq: 100\n  # update checkpoint in this frequent\n  update_ckpt_freq: 5000\n  # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,\n  # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.\n  scales: [1, 0.5, 0.25, 0.125]\n  # Parameters of transform for equivariance loss\n  transform_params:\n    sigma_affine: 0.05\n    sigma_tps: 0.005\n    points_tps: 5\n  loss_weights:\n    # Weights for perceptual pyramide loss. 
Note that here you can only specify weight across the layer, and\n    # weights across the resolution will be the same.\n    perceptual: [10, 10, 10, 10, 10]\n    # Weights for equivariance loss.\n    equivariance_shift: 10\n    equivariance_affine: 10\n\n# Parameters of visualization\nvisualizer_params:\n  # Size of keypoints\n  kp_size: 2\n  # Draw border between images or not\n  draw_border: True\n  # Colormap for regions and keypoints visualization\n  colormap: 'gist_rainbow'\n  # Background color for region visualization\n  region_bg_color: [1, 1, 1]\n"
  },
  {
    "path": "config/hdtf256_400ep.yaml",
    "content": "#Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\n#No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,\n#publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.\n#Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,\n#title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.\n#In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.\n\n\n# Dataset parameters\n# Each dataset should contain 2 folders train and test\n# Each video can be represented as:\n#   - an image of concatenated frames\n#   - '.mp4' or '.gif'\n#   - folder with all frames from a specific video\ndataset_params:\n  # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.\n  # Folder with frames is preferred format for training, since it is the fastest.\n  root_dir: /yrfs2/cv2/pcxia/audiovisual/hdtf/images\n  # Shape to resize all frames to, specify null if resizing is not needed\n  frame_shape: 256\n  # In case of Vox or Taichi single video can be splitted in many chunks, or the maybe several videos for single person.\n  # In this case epoch can be a pass over different identities (if id_sampling=True) or over different chunks (if id_sampling=False)\n  # If the name the video '12335#adsbf.mp4' the id is assumed to be 12335\n  id_sampling: False\n  # List with pairs for animation, null for random pairs\n  pairs_list: null\n  # Augmentation parameters see augmentation.py for all possible augmentations\n  augmentation_params:\n    flip_param:\n      horizontal_flip: True\n      time_flip: True\n    jitter_param:\n      brightness: 0.1\n      contrast: 0.1\n      saturation: 0.1\n      hue: 0.1\n\n# Defines architecture of the model\nmodel_params:\n  # Number of regions\n  num_regions: 10\n  # Number of channels, for RGB image it is always 3\n  num_channels: 3\n  # Enable estimation of affine parameters for each region,\n  # set to False if only region centers (keypoints) need to be estimated\n  estimate_affine: True\n  # Svd can perform random axis swap between source and driving if singular values are close to each other\n  # Set to True to avoid axis swap between source and driving\n  revert_axis_swap: True\n\n  # Parameters of background prediction network based on simple Unet-like encoder.\n  bg_predictor_params:\n    # Number of features multiplier\n    block_expansion: 32\n    # Maximum allowed number of features\n    max_features: 1024\n    # Number of block in the Encoder.\n    num_blocks: 5\n    # Type of background movement model, select one from ['zero', 'shift', 'affine', 'perspective']\n    bg_type: 'affine'\n\n  # Parameters of the region prediction network based on Unet\n  region_predictor_params:\n    # Softmax temperature for heatmaps\n    temperature: 0.1\n    # Number of features multiplier\n    block_expansion: 32\n    # Maximum allowed number of features\n    max_features: 1024\n    # Regions is predicted on smaller images for better performance,\n    # scale_factor=0.25 means that 256x256 image will be resized to 64x64\n    scale_factor: 0.25\n    # Number of block in Unet. 
Can be increased or decreased depending or resolution.\n    num_blocks: 5\n    # Either to use pca_based estimation of affine parameters of regression based\n    pca_based: True\n    # Either to use fast_svd (https://github.com/KinglittleQ/torch-batch-svd) or standard pytorch svd\n    # Fast svd may produce not meaningful regions if used along with revert_axis_swap\n    fast_svd: False\n\n  # Parameters of Generator, based on Jonson architecture\n  generator_params:\n    # Number of features multiplier\n    block_expansion: 64\n    # Maximum allowed number of features\n    max_features: 512\n    # Number of down-sampling blocks in Jonson architecture.\n    # Can be increased or decreased depending or resolution.\n    num_down_blocks: 2\n    # Number of ResBlocks  in Jonson architecture.\n    num_bottleneck_blocks: 6\n    # To use skip connections or no.\n    skips: True\n    # Parameters of pixelwise flow predictor based on Unet\n    pixelwise_flow_predictor_params:\n      # Number of features multiplier\n      block_expansion: 64\n      # Maximum allowed number of features\n      max_features: 1024\n      # Number of block in Unet. Can be increased or decreased depending or resolution.\n      num_blocks: 5\n      # Flow predictor operates on the smaller images for better performance,\n      # scale_factor=0.25 means that 256x256 image will be resized to 64x64\n      scale_factor: 0.25\n      # Set to True in order to use deformed source images using sparse flow\n      use_deformed_source: True\n      # Set to False in order to render region heatmaps with fixed covariance\n      # True for covariance estimate using region_predictor\n      use_covar_heatmap: True\n      # Set to False to disable occlusion mask estimation\n      estimate_occlusion_map: True\n\n  # Parameter for animation-via-disentanglement (avd) network\n  avd_network_params:\n    # Bottleneck for identity branch\n    id_bottle_size: 64\n    # Bottleneck for pose branch\n    pose_bottle_size: 64\n\n# Parameters of training (reconstruction)\ntrain_params:\n  max_epochs: 400\n  # For better i/o performance when number of videos is small number of epochs can be multiplied by this number.\n  # Thus effectively with num_repeats=100 each epoch is 100 times larger.\n  num_repeats: 100\n  # Drop learning rate 10 times after this epochs\n  epoch_milestones: [60, 90]\n  # Initial learning rate\n  lr: 6.0e-4\n  # Batch size. (14 is batch size for one V100 gpu).\n  batch_size: 20\n  # Either to use sync_bn or not, enabling sync_bn will significantly slow the training time\n  use_sync_bn: False\n  # Dataset preprocessing cpu workers\n  dataloader_workers: 12\n  print_freq: 10\n  save_img_freq: 100\n  # update checkpoint in this frequent\n  update_ckpt_freq: 5000\n  # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,\n  # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.\n  scales: [1, 0.5, 0.25, 0.125]\n  # Parameters of transform for equivariance loss\n  transform_params:\n    sigma_affine: 0.05\n    sigma_tps: 0.005\n    points_tps: 5\n  loss_weights:\n    # Weights for perceptual pyramide loss. 
Note that here you can only specify weight across the layer, and\n    # weights across the resolution will be the same.\n    perceptual: [10, 10, 10, 10, 10]\n    # Weights for equivariance loss.\n    equivariance_shift: 10\n    equivariance_affine: 10\n\n# Parameters of visualization\nvisualizer_params:\n  # Size of keypoints\n  kp_size: 2\n  # Draw border between images or not\n  draw_border: True\n  # Colormap for regions and keypoints visualization\n  colormap: 'gist_rainbow'\n  # Background color for region visualization\n  region_bg_color: [1, 1, 1]\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/FaceBoxes.py",
    "content": "# coding: utf-8\n\nimport os.path as osp\n\nimport torch\nimport numpy as np\nimport cv2\n\nfrom .utils.prior_box import PriorBox\nfrom .utils.nms_wrapper import nms\nfrom .utils.box_utils import decode\nfrom .utils.timer import Timer\nfrom .utils.functions import check_keys, remove_prefix, load_model\nfrom .utils.config import cfg\nfrom .models.faceboxes import FaceBoxesNet\nimport torch.backends.cudnn as cudnn\n\n# some global configs\nconfidence_threshold = 0.05\ntop_k = 5000\nkeep_top_k = 750\nnms_threshold = 0.3\nvis_thres = 0.5\nresize = 1\n\nscale_flag = True\nHEIGHT, WIDTH = 720, 1080\n\nmake_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)\npretrained_path = make_abs_path('weights/FaceBoxesProd.pth')\n\n\ndef viz_bbox(img, dets, wfp='out.jpg'):\n    # show\n    for b in dets:\n        if b[4] < vis_thres:\n            continue\n        text = \"{:.4f}\".format(b[4])\n        b = list(map(int, b))\n        cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)\n        cx = b[0]\n        cy = b[1] + 12\n        cv2.putText(img, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))\n    cv2.imwrite(wfp, img)\n    print(f'Viz bbox to {wfp}')\n\n\nclass FaceBoxes:\n    def __init__(self, timer_flag=False):\n        torch.set_grad_enabled(False)\n\n        net = FaceBoxesNet(phase='test', size=None, num_classes=2)  # initialize detector\n        self.net = load_model(net, pretrained_path=pretrained_path, load_to_cpu=True)\n        self.net.eval()\n        # print('Finished loading model!')\n        cudnn.benchmark = True\n        self.net = self.net.cuda()\n        self.timer_flag = timer_flag\n\n    @torch.no_grad()\n    def __call__(self, img_):\n        img_raw = img_.copy()\n\n        # scaling to speed up\n        scale = 1\n        if scale_flag:\n            h, w = img_raw.shape[:2]\n            if h > HEIGHT:\n                scale = HEIGHT / h\n            if w * scale > WIDTH:\n                scale *= WIDTH / (w * scale)\n            # print(scale)\n            if scale == 1:\n                img_raw_scale = img_raw\n            else:\n                h_s = int(scale * h)\n                w_s = int(scale * w)\n                # print(h_s, w_s)\n                img_raw_scale = cv2.resize(img_raw, dsize=(w_s, h_s))\n                # print(img_raw_scale.shape)\n\n            img = np.float32(img_raw_scale)\n        else:\n            img = np.float32(img_raw)\n\n        # forward\n        _t = {'forward_pass': Timer(), 'misc': Timer()}\n        im_height, im_width, _ = img.shape\n        scale_bbox = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]).cuda()\n        img -= (104, 117, 123)\n        img = img.transpose(2, 0, 1)\n        img = torch.from_numpy(img).cuda().unsqueeze(0)\n\n        _t['forward_pass'].tic()\n        loc, conf = self.net(img)  # forward pass\n        _t['forward_pass'].toc()\n        _t['misc'].tic()\n        priorbox = PriorBox(image_size=(im_height, im_width))\n        priors = priorbox.forward()\n        prior_data = priors.data.cuda()\n        boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])\n        if scale_flag:\n            boxes = boxes * scale_bbox / scale / resize\n        else:\n            boxes = boxes * scale_bbox / resize\n\n        boxes = boxes.cpu().numpy()\n        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]\n\n        # ignore low scores\n        inds = np.where(scores > confidence_threshold)[0]\n        boxes = boxes[inds]\n        
scores = scores[inds]\n\n        # keep top-K before NMS\n        order = scores.argsort()[::-1][:top_k]\n        boxes = boxes[order]\n        scores = scores[order]\n\n        # do NMS\n        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)\n        keep = nms(dets, nms_threshold)\n        dets = dets[keep, :]\n\n        # keep top-K after NMS\n        dets = dets[:keep_top_k, :]\n        _t['misc'].toc()\n\n        if self.timer_flag:\n            print('Detection: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(1, 1, _t[\n                'forward_pass'].average_time, _t['misc'].average_time))\n\n        # filter using vis_thres\n        det_bboxes = []\n        for b in dets:\n            if b[4] > vis_thres:\n                xmin, ymin, xmax, ymax, score = b[0], b[1], b[2], b[3], b[4]\n                bbox = [xmin, ymin, xmax, ymax, score]\n                det_bboxes.append(bbox)\n\n        return det_bboxes\n\n\ndef main():\n    face_boxes = FaceBoxes(timer_flag=True)\n\n    fn = 'trump_hillary.jpg'\n    img_fp = f'../examples/inputs/{fn}'\n    img = cv2.imread(img_fp)\n    print(f'input shape: {img.shape}')\n    dets = face_boxes(img)  # each det: xmin, ymin, xmax, ymax, score\n    # print(dets)\n\n    # repeat inference for `n` times\n    n = 10\n    for i in range(n):\n        dets = face_boxes(img)\n\n    wfn = fn.replace('.jpg', '_det.jpg')\n    wfp = osp.join('../examples/results', wfn)\n    viz_bbox(img, dets, wfp)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/FaceBoxes_ONNX.py",
    "content": "# coding: utf-8\n\nimport os.path as osp\n\nimport torch\nimport numpy as np\nimport cv2\n\nfrom .utils.prior_box import PriorBox\nfrom .utils.nms_wrapper import nms\nfrom .utils.box_utils import decode\nfrom .utils.timer import Timer\nfrom .utils.config import cfg\nfrom .onnx import convert_to_onnx\n\nimport onnxruntime\n\n# some global configs\nconfidence_threshold = 0.05\ntop_k = 5000\nkeep_top_k = 750\nnms_threshold = 0.3\nvis_thres = 0.2\nresize = 1\n\nscale_flag = True\nHEIGHT, WIDTH = 720, 1080\n\nmake_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)\nonnx_path = make_abs_path('weights/FaceBoxesProd.onnx')\n\n\ndef viz_bbox(img, dets, wfp='out.jpg'):\n    # show\n    for b in dets:\n        if b[4] < vis_thres:\n            continue\n        text = \"{:.4f}\".format(b[4])\n        b = list(map(int, b))\n        cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)\n        cx = b[0]\n        cy = b[1] + 12\n        cv2.putText(img, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))\n    cv2.imwrite(wfp, img)\n    print(f'Viz bbox to {wfp}')\n\n\nclass FaceBoxes_ONNX(object):\n    def __init__(self, timer_flag=False):\n        if not osp.exists(onnx_path):\n            convert_to_onnx(onnx_path)\n        self.session = onnxruntime.InferenceSession(onnx_path, providers=['CUDAExecutionProvider'])\n\n        self.timer_flag = timer_flag\n\n    def __call__(self, img_):\n        img_raw = img_.copy()\n\n        # scaling to speed up\n        scale = 1\n        if scale_flag:\n            h, w = img_raw.shape[:2]\n            if h > HEIGHT:\n                scale = HEIGHT / h\n            if w * scale > WIDTH:\n                scale *= WIDTH / (w * scale)\n            # print(scale)\n            if scale == 1:\n                img_raw_scale = img_raw\n            else:\n                h_s = int(scale * h)\n                w_s = int(scale * w)\n                # print(h_s, w_s)\n                img_raw_scale = cv2.resize(img_raw, dsize=(w_s, h_s))\n                # print(img_raw_scale.shape)\n\n            img = np.float32(img_raw_scale)\n        else:\n            img = np.float32(img_raw)\n\n        # forward\n        _t = {'forward_pass': Timer(), 'misc': Timer()}\n        im_height, im_width, _ = img.shape\n        scale_bbox = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])\n\n        img -= (104, 117, 123)\n        img = img.transpose(2, 0, 1)\n        # img = torch.from_numpy(img).unsqueeze(0)\n        img = img[np.newaxis, ...]\n\n        _t['forward_pass'].tic()\n        # loc, conf = self.net(img)  # forward pass\n        out = self.session.run(None, {'input': img})\n        loc, conf = out[0], out[1]\n        # for compatibility, may need to optimize\n        loc = torch.from_numpy(loc)\n        _t['forward_pass'].toc()\n        _t['misc'].tic()\n\n        priorbox = PriorBox(image_size=(im_height, im_width))\n        priors = priorbox.forward()\n        prior_data = priors.data\n        boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])\n        if scale_flag:\n            boxes = boxes * scale_bbox / scale / resize\n        else:\n            boxes = boxes * scale_bbox / resize\n\n        boxes = boxes.cpu().numpy()\n        scores = conf[0][:, 1]\n        # scores = conf.squeeze(0).data.cpu().numpy()[:, 1]\n\n        # ignore low scores\n        inds = np.where(scores > confidence_threshold)[0]\n        boxes = boxes[inds]\n        scores = scores[inds]\n\n        # keep top-K 
before NMS\n        order = scores.argsort()[::-1][:top_k]\n        boxes = boxes[order]\n        scores = scores[order]\n\n        # do NMS\n        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)\n        keep = nms(dets, nms_threshold)\n        dets = dets[keep, :]\n\n        # keep top-K after NMS\n        dets = dets[:keep_top_k, :]\n        _t['misc'].toc()\n\n        if self.timer_flag:\n            print('Detection: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(1, 1, _t[\n                'forward_pass'].average_time, _t['misc'].average_time))\n\n        # filter using vis_thres\n        det_bboxes = []\n        for b in dets:\n            if b[4] > vis_thres:\n                xmin, ymin, xmax, ymax, score = b[0], b[1], b[2], b[3], b[4]\n                bbox = [xmin, ymin, xmax, ymax, score]\n                det_bboxes.append(bbox)\n\n        return det_bboxes\n\n\ndef main():\n    face_boxes = FaceBoxes_ONNX(timer_flag=True)\n\n    fn = 'trump_hillary.jpg'\n    img_fp = f'../examples/inputs/{fn}'\n    img = cv2.imread(img_fp)\n    print(f'input shape: {img.shape}')\n    dets = face_boxes(img)  # each det: xmin, ymin, xmax, ymax, score\n    # print(dets)\n\n    # repeat inference for `n` times\n    n = 10\n    for i in range(n):\n        dets = face_boxes(img)\n\n    wfn = fn.replace('.jpg', '_det.jpg')\n    wfp = osp.join('../examples/results', wfn)\n    viz_bbox(img, dets, wfp)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/__init__.py",
    "content": "from .FaceBoxes import FaceBoxes\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/build_cpu_nms.sh",
    "content": "cd utils\npython3 build.py build_ext --inplace\ncd .."
  },
  {
    "path": "extract_init_states/FaceBoxes/models/__init__.py",
    "content": ""
  },
  {
    "path": "extract_init_states/FaceBoxes/models/faceboxes.py",
    "content": "# coding: utf-8\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass BasicConv2d(nn.Module):\n\n    def __init__(self, in_channels, out_channels, **kwargs):\n        super(BasicConv2d, self).__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)\n        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.bn(x)\n        return F.relu(x, inplace=True)\n\n\nclass Inception(nn.Module):\n    def __init__(self):\n        super(Inception, self).__init__()\n        self.branch1x1 = BasicConv2d(128, 32, kernel_size=1, padding=0)\n        self.branch1x1_2 = BasicConv2d(128, 32, kernel_size=1, padding=0)\n        self.branch3x3_reduce = BasicConv2d(128, 24, kernel_size=1, padding=0)\n        self.branch3x3 = BasicConv2d(24, 32, kernel_size=3, padding=1)\n        self.branch3x3_reduce_2 = BasicConv2d(128, 24, kernel_size=1, padding=0)\n        self.branch3x3_2 = BasicConv2d(24, 32, kernel_size=3, padding=1)\n        self.branch3x3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1)\n\n    def forward(self, x):\n        branch1x1 = self.branch1x1(x)\n\n        branch1x1_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)\n        branch1x1_2 = self.branch1x1_2(branch1x1_pool)\n\n        branch3x3_reduce = self.branch3x3_reduce(x)\n        branch3x3 = self.branch3x3(branch3x3_reduce)\n\n        branch3x3_reduce_2 = self.branch3x3_reduce_2(x)\n        branch3x3_2 = self.branch3x3_2(branch3x3_reduce_2)\n        branch3x3_3 = self.branch3x3_3(branch3x3_2)\n\n        outputs = [branch1x1, branch1x1_2, branch3x3, branch3x3_3]\n        return torch.cat(outputs, 1)\n\n\nclass CRelu(nn.Module):\n\n    def __init__(self, in_channels, out_channels, **kwargs):\n        super(CRelu, self).__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)\n        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.bn(x)\n        x = torch.cat([x, -x], 1)\n        x = F.relu(x, inplace=True)\n        return x\n\n\nclass FaceBoxesNet(nn.Module):\n\n    def __init__(self, phase, size, num_classes):\n        super(FaceBoxesNet, self).__init__()\n        self.phase = phase\n        self.num_classes = num_classes\n        self.size = size\n\n        self.conv1 = CRelu(3, 24, kernel_size=7, stride=4, padding=3)\n        self.conv2 = CRelu(48, 64, kernel_size=5, stride=2, padding=2)\n\n        self.inception1 = Inception()\n        self.inception2 = Inception()\n        self.inception3 = Inception()\n\n        self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)\n        self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)\n\n        self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)\n        self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)\n\n        self.loc, self.conf = self.multibox(self.num_classes)\n\n        if self.phase == 'test':\n            self.softmax = nn.Softmax(dim=-1)\n\n        if self.phase == 'train':\n            for m in self.modules():\n                if isinstance(m, nn.Conv2d):\n                    if m.bias is not None:\n                        nn.init.xavier_normal_(m.weight.data)\n                        m.bias.data.fill_(0.02)\n                    else:\n                        m.weight.data.normal_(0, 0.01)\n                elif isinstance(m, 
nn.BatchNorm2d):\n                    m.weight.data.fill_(1)\n                    m.bias.data.zero_()\n\n    def multibox(self, num_classes):\n        loc_layers = []\n        conf_layers = []\n        loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)]\n        conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)]\n        loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]\n        conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]\n        loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]\n        conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]\n        return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers)\n\n    def forward(self, x):\n\n        detection_sources = list()\n        loc = list()\n        conf = list()\n\n        x = self.conv1(x)\n        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)\n        x = self.conv2(x)\n        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)\n        x = self.inception1(x)\n        x = self.inception2(x)\n        x = self.inception3(x)\n        detection_sources.append(x)\n\n        x = self.conv3_1(x)\n        x = self.conv3_2(x)\n        detection_sources.append(x)\n\n        x = self.conv4_1(x)\n        x = self.conv4_2(x)\n        detection_sources.append(x)\n\n        for (x, l, c) in zip(detection_sources, self.loc, self.conf):\n            loc.append(l(x).permute(0, 2, 3, 1).contiguous())\n            conf.append(c(x).permute(0, 2, 3, 1).contiguous())\n\n        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)\n        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)\n\n        if self.phase == \"test\":\n            output = (loc.view(loc.size(0), -1, 4),\n                      self.softmax(conf.view(conf.size(0), -1, self.num_classes)))\n        else:\n            output = (loc.view(loc.size(0), -1, 4),\n                      conf.view(conf.size(0), -1, self.num_classes))\n\n        return output\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/onnx.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport torch\n\nfrom .models.faceboxes import FaceBoxesNet\nfrom .utils.functions import load_model\n\n\ndef convert_to_onnx(onnx_path):\n    pretrained_path = onnx_path.replace('.onnx', '.pth')\n    # 1. load model\n    torch.set_grad_enabled(False)\n    net = FaceBoxesNet(phase='test', size=None, num_classes=2)  # initialize detector\n    net = load_model(net, pretrained_path=pretrained_path, load_to_cpu=True)\n    net.eval()\n\n    # 2. convert\n    batch_size = 1\n    dummy_input = torch.randn(batch_size, 3, 720, 1080)\n    # export with dynamic axes for various input sizes\n    torch.onnx.export(\n        net,\n        (dummy_input,),\n        onnx_path,\n        input_names=['input'],\n        output_names=['output'],\n        dynamic_axes={\n            'input': [0, 2, 3],\n            'output': [0]\n        },\n        do_constant_folding=True\n    )\n    print(f'Convert {pretrained_path} to {onnx_path} done.')\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/readme.md",
    "content": "## How to fun FaceBoxes\n\n### Build the cpu version of NMS\n```shell script\ncd utils\npython3 build.py build_ext --inplace\n```\n\nor just run\n\n```shell script\nsh ./build_cpu_nms.sh\n```\n\n### Run the demo of face detection\n```shell script\npython3 FaceBoxes.py\n```"
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/.gitignore",
    "content": "utils/build\nutils/nms/*.so\nutils/*.c\nbuild/\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/__init__.py",
    "content": ""
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/box_utils.py",
    "content": "# coding: utf-8\n\nimport torch\nimport numpy as np\n\n\ndef point_form(boxes):\n    \"\"\" Convert prior_boxes to (xmin, ymin, xmax, ymax)\n    representation for comparison to point form ground truth data.\n    Args:\n        boxes: (tensor) center-size default boxes from priorbox layers.\n    Return:\n        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.\n    \"\"\"\n    return torch.cat((boxes[:, :2] - boxes[:, 2:] / 2,  # xmin, ymin\n                      boxes[:, :2] + boxes[:, 2:] / 2), 1)  # xmax, ymax\n\n\ndef center_size(boxes):\n    \"\"\" Convert prior_boxes to (cx, cy, w, h)\n    representation for comparison to center-size form ground truth data.\n    Args:\n        boxes: (tensor) point_form boxes\n    Return:\n        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.\n    \"\"\"\n    return torch.cat((boxes[:, 2:] + boxes[:, :2]) / 2,  # cx, cy\n                     boxes[:, 2:] - boxes[:, :2], 1)  # w, h\n\n\ndef intersect(box_a, box_b):\n    \"\"\" We resize both tensors to [A,B,2] without new malloc:\n    [A,2] -> [A,1,2] -> [A,B,2]\n    [B,2] -> [1,B,2] -> [A,B,2]\n    Then we compute the area of intersect between box_a and box_b.\n    Args:\n      box_a: (tensor) bounding boxes, Shape: [A,4].\n      box_b: (tensor) bounding boxes, Shape: [B,4].\n    Return:\n      (tensor) intersection area, Shape: [A,B].\n    \"\"\"\n    A = box_a.size(0)\n    B = box_b.size(0)\n    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),\n                       box_b[:, 2:].unsqueeze(0).expand(A, B, 2))\n    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),\n                       box_b[:, :2].unsqueeze(0).expand(A, B, 2))\n    inter = torch.clamp((max_xy - min_xy), min=0)\n    return inter[:, :, 0] * inter[:, :, 1]\n\n\ndef jaccard(box_a, box_b):\n    \"\"\"Compute the jaccard overlap of two sets of boxes.  The jaccard overlap\n    is simply the intersection over union of two boxes.  
Here we operate on\n    ground truth boxes and default boxes.\n    E.g.:\n        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)\n    Args:\n        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]\n        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]\n    Return:\n        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]\n    \"\"\"\n    inter = intersect(box_a, box_b)\n    area_a = ((box_a[:, 2] - box_a[:, 0]) *\n              (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]\n    area_b = ((box_b[:, 2] - box_b[:, 0]) *\n              (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]\n    union = area_a + area_b - inter\n    return inter / union  # [A,B]\n\n\ndef matrix_iou(a, b):\n    \"\"\"\n    return iou of a and b, numpy version for data augenmentation\n    \"\"\"\n    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])\n    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])\n\n    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)\n    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)\n    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)\n    return area_i / (area_a[:, np.newaxis] + area_b - area_i)\n\n\ndef matrix_iof(a, b):\n    \"\"\"\n    return iof of a and b, numpy version for data augenmentation\n    \"\"\"\n    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])\n    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])\n\n    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)\n    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)\n    return area_i / np.maximum(area_a[:, np.newaxis], 1)\n\n\ndef match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):\n    \"\"\"Match each prior box with the ground truth box of the highest jaccard\n    overlap, encode the bounding boxes, then return the matched indices\n    corresponding to both confidence and location preds.\n    Args:\n        threshold: (float) The overlap threshold used when mathing boxes.\n        truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors].\n        priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].\n        variances: (tensor) Variances corresponding to each prior coord,\n            Shape: [num_priors, 4].\n        labels: (tensor) All the class labels for the image, Shape: [num_obj].\n        loc_t: (tensor) Tensor to be filled w/ endcoded location targets.\n        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.\n        idx: (int) current batch index\n    Return:\n        The matched indices corresponding to 1)location and 2)confidence preds.\n    \"\"\"\n    # jaccard index\n    overlaps = jaccard(\n        truths,\n        point_form(priors)\n    )\n    # (Bipartite Matching)\n    # [1,num_objects] best prior for each ground truth\n    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)\n\n    # ignore hard gt\n    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2\n    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]\n    if best_prior_idx_filter.shape[0] <= 0:\n        loc_t[idx] = 0\n        conf_t[idx] = 0\n        return\n\n    # [1,num_priors] best ground truth for each prior\n    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)\n    best_truth_idx.squeeze_(0)\n    best_truth_overlap.squeeze_(0)\n    best_prior_idx.squeeze_(1)\n    best_prior_idx_filter.squeeze_(1)\n    best_prior_overlap.squeeze_(1)\n    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior\n    # TODO 
refactor: index  best_prior_idx with long tensor\n    # ensure every gt matches with its prior of max overlap\n    for j in range(best_prior_idx.size(0)):\n        best_truth_idx[best_prior_idx[j]] = j\n    matches = truths[best_truth_idx]  # Shape: [num_priors,4]\n    conf = labels[best_truth_idx]  # Shape: [num_priors]\n    conf[best_truth_overlap < threshold] = 0  # label as background\n    loc = encode(matches, priors, variances)\n    loc_t[idx] = loc  # [num_priors,4] encoded offsets to learn\n    conf_t[idx] = conf  # [num_priors] top class label for each prior\n\n\ndef encode(matched, priors, variances):\n    \"\"\"Encode the variances from the priorbox layers into the ground truth boxes\n    we have matched (based on jaccard overlap) with the prior boxes.\n    Args:\n        matched: (tensor) Coords of ground truth for each prior in point-form\n            Shape: [num_priors, 4].\n        priors: (tensor) Prior boxes in center-offset form\n            Shape: [num_priors,4].\n        variances: (list[float]) Variances of priorboxes\n    Return:\n        encoded boxes (tensor), Shape: [num_priors, 4]\n    \"\"\"\n\n    # dist b/t match center and prior's center\n    g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]\n    # encode variance\n    g_cxcy /= (variances[0] * priors[:, 2:])\n    # match wh / prior wh\n    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]\n    g_wh = torch.log(g_wh) / variances[1]\n    # return target for smooth_l1_loss\n    return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]\n\n\n# Adapted from https://github.com/Hakuyume/chainer-ssd\ndef decode(loc, priors, variances):\n    \"\"\"Decode locations from predictions using priors to undo\n    the encoding we did for offset regression at train time.\n    Args:\n        loc (tensor): location predictions for loc layers,\n            Shape: [num_priors,4]\n        priors (tensor): Prior boxes in center-offset form.\n            Shape: [num_priors,4].\n        variances: (list[float]) Variances of priorboxes\n    Return:\n        decoded bounding box predictions\n    \"\"\"\n\n    boxes = torch.cat((\n        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],\n        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)\n    boxes[:, :2] -= boxes[:, 2:] / 2\n    boxes[:, 2:] += boxes[:, :2]\n    return boxes\n\n\ndef log_sum_exp(x):\n    \"\"\"Utility function for computing log_sum_exp while determining\n    This will be used to determine unaveraged confidence loss across\n    all examples in a batch.\n    Args:\n        x (Variable(tensor)): conf_preds from conf layers\n    \"\"\"\n    x_max = x.data.max()\n    return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max\n\n\n# Original author: Francisco Massa:\n# https://github.com/fmassa/object-detection.torch\n# Ported to PyTorch by Max deGroot (02/01/2017)\ndef nms(boxes, scores, overlap=0.5, top_k=200):\n    \"\"\"Apply non-maximum suppression at test time to avoid detecting too many\n    overlapping bounding boxes for a given object.\n    Args:\n        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].\n        scores: (tensor) The class predscores for the img, Shape:[num_priors].\n        overlap: (float) The overlap thresh for suppressing unnecessary boxes.\n        top_k: (int) The Maximum number of box preds to consider.\n    Return:\n        The indices of the kept boxes with respect to num_priors.\n    \"\"\"\n\n    keep = torch.Tensor(scores.size(0)).fill_(0).long()\n    if 
boxes.numel() == 0:\n        return keep\n    x1 = boxes[:, 0]\n    y1 = boxes[:, 1]\n    x2 = boxes[:, 2]\n    y2 = boxes[:, 3]\n    area = torch.mul(x2 - x1, y2 - y1)\n    v, idx = scores.sort(0)  # sort in ascending order\n    # I = I[v >= 0.01]\n    idx = idx[-top_k:]  # indices of the top-k largest vals\n    xx1 = boxes.new()\n    yy1 = boxes.new()\n    xx2 = boxes.new()\n    yy2 = boxes.new()\n    w = boxes.new()\n    h = boxes.new()\n\n    # keep = torch.Tensor()\n    count = 0\n    while idx.numel() > 0:\n        i = idx[-1]  # index of current largest val\n        # keep.append(i)\n        keep[count] = i\n        count += 1\n        if idx.size(0) == 1:\n            break\n        idx = idx[:-1]  # remove kept element from view\n        # load bboxes of next highest vals\n        torch.index_select(x1, 0, idx, out=xx1)\n        torch.index_select(y1, 0, idx, out=yy1)\n        torch.index_select(x2, 0, idx, out=xx2)\n        torch.index_select(y2, 0, idx, out=yy2)\n        # store element-wise max with next highest score\n        xx1 = torch.clamp(xx1, min=x1[i])\n        yy1 = torch.clamp(yy1, min=y1[i])\n        xx2 = torch.clamp(xx2, max=x2[i])\n        yy2 = torch.clamp(yy2, max=y2[i])\n        w.resize_as_(xx2)\n        h.resize_as_(yy2)\n        w = xx2 - xx1\n        h = yy2 - yy1\n        # check sizes of xx1 and xx2.. after each iteration\n        w = torch.clamp(w, min=0.0)\n        h = torch.clamp(h, min=0.0)\n        inter = w * h\n        # IoU = i / (area(a) + area(b) - i)\n        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas)\n        union = (rem_areas - inter) + area[i]\n        IoU = inter / union  # store result in iou\n        # keep only elements with an IoU <= overlap\n        idx = idx[IoU.le(overlap)]\n    return keep, count\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/build.py",
    "content": "# coding: utf-8\n\n# --------------------------------------------------------\n# Fast R-CNN\n# Copyright (c) 2015 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ross Girshick\n# --------------------------------------------------------\n\nimport os\nfrom os.path import join as pjoin\nimport numpy as np\nfrom distutils.core import setup\nfrom distutils.extension import Extension\nfrom Cython.Distutils import build_ext\n\n\ndef find_in_path(name, path):\n    \"Find a file in a search path\"\n    # adapted fom http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/\n    for dir in path.split(os.pathsep):\n        binpath = pjoin(dir, name)\n        if os.path.exists(binpath):\n            return os.path.abspath(binpath)\n    return None\n\n\n# Obtain the numpy include directory.  This logic works across numpy versions.\ntry:\n    numpy_include = np.get_include()\nexcept AttributeError:\n    numpy_include = np.get_numpy_include()\n\n\n# run the customize_compiler\nclass custom_build_ext(build_ext):\n    def build_extensions(self):\n        # customize_compiler_for_nvcc(self.compiler)\n        build_ext.build_extensions(self)\n\n\next_modules = [\n    Extension(\n        \"nms.cpu_nms\",\n        [\"nms/cpu_nms.pyx\"],\n        # extra_compile_args={'gcc': [\"-Wno-cpp\", \"-Wno-unused-function\"]},\n        # extra_compile_args=[\"-Wno-cpp\", \"-Wno-unused-function\"],  # !!! if you are on windows platform, you need to comment this line \n        include_dirs=[numpy_include]\n    )\n]\n\nsetup(\n    name='mot_utils',\n    ext_modules=ext_modules,\n    # inject our custom trigger\n    cmdclass={'build_ext': custom_build_ext},\n)\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/config.py",
    "content": "# coding: utf-8\n\ncfg = {\n    'name': 'FaceBoxes',\n    'min_sizes': [[32, 64, 128], [256], [512]],\n    'steps': [32, 64, 128],\n    'variance': [0.1, 0.2],\n    'clip': False\n}\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/functions.py",
    "content": "# coding: utf-8\n\nimport sys\nimport os.path as osp\nimport torch\n\ndef check_keys(model, pretrained_state_dict):\n    ckpt_keys = set(pretrained_state_dict.keys())\n    model_keys = set(model.state_dict().keys())\n    used_pretrained_keys = model_keys & ckpt_keys\n    unused_pretrained_keys = ckpt_keys - model_keys\n    missing_keys = model_keys - ckpt_keys\n    # print('Missing keys:{}'.format(len(missing_keys)))\n    # print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))\n    # print('Used keys:{}'.format(len(used_pretrained_keys)))\n    assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'\n    return True\n\n\ndef remove_prefix(state_dict, prefix):\n    ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''\n    # print('remove prefix \\'{}\\''.format(prefix))\n    f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x\n    return {f(key): value for key, value in state_dict.items()}\n\n\ndef load_model(model, pretrained_path, load_to_cpu):\n    if not osp.isfile(pretrained_path):\n        print(f'The pre-trained FaceBoxes model {pretrained_path} does not exist')\n        sys.exit('-1')\n    # print('Loading pretrained model from {}'.format(pretrained_path))\n    if load_to_cpu:\n        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)\n    else:\n        device = torch.cuda.current_device()\n        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))\n    if \"state_dict\" in pretrained_dict.keys():\n        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')\n    else:\n        pretrained_dict = remove_prefix(pretrained_dict, 'module.')\n    check_keys(model, pretrained_dict)\n    model.load_state_dict(pretrained_dict, strict=False)\n    return model\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/nms/.gitignore",
    "content": "*.c\n*.so\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/nms/__init__.py",
    "content": ""
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/nms/cpu_nms.pyx",
    "content": "# --------------------------------------------------------\n# Fast R-CNN\n# Copyright (c) 2015 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ross Girshick\n# --------------------------------------------------------\n\nimport numpy as np\ncimport numpy as np\n\ncdef inline np.float32_t max(np.float32_t a, np.float32_t b):\n    return a if a >= b else b\n\ncdef inline np.float32_t min(np.float32_t a, np.float32_t b):\n    return a if a <= b else b\n\ndef cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):\n    cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]\n    cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]\n    cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]\n    cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]\n    cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]\n\n    cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)\n    cdef np.ndarray[np.int64_t, ndim=1] order = scores.argsort()[::-1]\n\n    cdef int ndets = dets.shape[0]\n    cdef np.ndarray[np.int64_t, ndim=1] suppressed = \\\n            np.zeros((ndets), dtype=np.int64)\n\n    # nominal indices\n    cdef int _i, _j\n    # sorted indices\n    cdef int i, j\n    # temp variables for box i's (the box currently under consideration)\n    cdef np.float32_t ix1, iy1, ix2, iy2, iarea\n    # variables for computing overlap with box j (lower scoring box)\n    cdef np.float32_t xx1, yy1, xx2, yy2\n    cdef np.float32_t w, h\n    cdef np.float32_t inter, ovr\n\n    keep = []\n    for _i in range(ndets):\n        i = order[_i]\n        if suppressed[i] == 1:\n            continue\n        keep.append(i)\n        ix1 = x1[i]\n        iy1 = y1[i]\n        ix2 = x2[i]\n        iy2 = y2[i]\n        iarea = areas[i]\n        for _j in range(_i + 1, ndets):\n            j = order[_j]\n            if suppressed[j] == 1:\n                continue\n            xx1 = max(ix1, x1[j])\n            yy1 = max(iy1, y1[j])\n            xx2 = min(ix2, x2[j])\n            yy2 = min(iy2, y2[j])\n            w = max(0.0, xx2 - xx1 + 1)\n            h = max(0.0, yy2 - yy1 + 1)\n            inter = w * h\n            ovr = inter / (iarea + areas[j] - inter)\n            if ovr >= thresh:\n                suppressed[j] = 1\n\n    return keep\n\ndef cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0):\n    cdef unsigned int N = boxes.shape[0]\n    cdef float iw, ih, box_area\n    cdef float ua\n    cdef int pos = 0\n    cdef float maxscore = 0\n    cdef int maxpos = 0\n    cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov\n\n    for i in range(N):\n        maxscore = boxes[i, 4]\n        maxpos = i\n\n        tx1 = boxes[i,0]\n        ty1 = boxes[i,1]\n        tx2 = boxes[i,2]\n        ty2 = boxes[i,3]\n        ts = boxes[i,4]\n\n        pos = i + 1\n\t# get max box\n        while pos < N:\n            if maxscore < boxes[pos, 4]:\n                maxscore = boxes[pos, 4]\n                maxpos = pos\n            pos = pos + 1\n\n\t# add max box as a detection \n        boxes[i,0] = boxes[maxpos,0]\n        boxes[i,1] = boxes[maxpos,1]\n        boxes[i,2] = boxes[maxpos,2]\n        boxes[i,3] = boxes[maxpos,3]\n        boxes[i,4] = boxes[maxpos,4]\n\n\t# swap ith box with position of max box\n        boxes[maxpos,0] = tx1\n        boxes[maxpos,1] = ty1\n        boxes[maxpos,2] = tx2\n        boxes[maxpos,3] = ty2\n        boxes[maxpos,4] = ts\n\n        tx1 = 
boxes[i,0]\n        ty1 = boxes[i,1]\n        tx2 = boxes[i,2]\n        ty2 = boxes[i,3]\n        ts = boxes[i,4]\n\n        pos = i + 1\n\t# NMS iterations, note that N changes if detection boxes fall below threshold\n        while pos < N:\n            x1 = boxes[pos, 0]\n            y1 = boxes[pos, 1]\n            x2 = boxes[pos, 2]\n            y2 = boxes[pos, 3]\n            s = boxes[pos, 4]\n\n            area = (x2 - x1 + 1) * (y2 - y1 + 1)\n            iw = (min(tx2, x2) - max(tx1, x1) + 1)\n            if iw > 0:\n                ih = (min(ty2, y2) - max(ty1, y1) + 1)\n                if ih > 0:\n                    ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)\n                    ov = iw * ih / ua #iou between max box and detection box\n\n                    if method == 1: # linear\n                        if ov > Nt: \n                            weight = 1 - ov\n                        else:\n                            weight = 1\n                    elif method == 2: # gaussian\n                        weight = np.exp(-(ov * ov)/sigma)\n                    else: # original NMS\n                        if ov > Nt: \n                            weight = 0\n                        else:\n                            weight = 1\n\n                    boxes[pos, 4] = weight*boxes[pos, 4]\n\t\t    \n\t\t    # if box score falls below threshold, discard the box by swapping with last box\n\t\t    # update N\n                    if boxes[pos, 4] < threshold:\n                        boxes[pos,0] = boxes[N-1, 0]\n                        boxes[pos,1] = boxes[N-1, 1]\n                        boxes[pos,2] = boxes[N-1, 2]\n                        boxes[pos,3] = boxes[N-1, 3]\n                        boxes[pos,4] = boxes[N-1, 4]\n                        N = N - 1\n                        pos = pos - 1\n\n            pos = pos + 1\n\n    keep = [i for i in range(N)]\n    return keep\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/nms/py_cpu_nms.py",
    "content": "# --------------------------------------------------------\n# Fast R-CNN\n# Copyright (c) 2015 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ross Girshick\n# --------------------------------------------------------\n\nimport numpy as np\n\ndef py_cpu_nms(dets, thresh):\n    \"\"\"Pure Python NMS baseline.\"\"\"\n    x1 = dets[:, 0]\n    y1 = dets[:, 1]\n    x2 = dets[:, 2]\n    y2 = dets[:, 3]\n    scores = dets[:, 4]\n\n    areas = (x2 - x1 + 1) * (y2 - y1 + 1)\n    order = scores.argsort()[::-1]\n\n    keep = []\n    while order.size > 0:\n        i = order[0]\n        keep.append(i)\n        xx1 = np.maximum(x1[i], x1[order[1:]])\n        yy1 = np.maximum(y1[i], y1[order[1:]])\n        xx2 = np.minimum(x2[i], x2[order[1:]])\n        yy2 = np.minimum(y2[i], y2[order[1:]])\n\n        w = np.maximum(0.0, xx2 - xx1 + 1)\n        h = np.maximum(0.0, yy2 - yy1 + 1)\n        inter = w * h\n        ovr = inter / (areas[i] + areas[order[1:]] - inter)\n\n        inds = np.where(ovr <= thresh)[0]\n        order = order[inds + 1]\n\n    return keep\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/nms_wrapper.py",
    "content": "# coding: utf-8\n\n# --------------------------------------------------------\n# Fast R-CNN\n# Copyright (c) 2015 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ross Girshick\n# --------------------------------------------------------\n\nfrom .nms.cpu_nms import cpu_nms, cpu_soft_nms\n\n\ndef nms(dets, thresh):\n    \"\"\"Dispatch to either CPU or GPU NMS implementations.\"\"\"\n\n    if dets.shape[0] == 0:\n        return []\n    return cpu_nms(dets, thresh)\n    # return gpu_nms(dets, thresh)\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/prior_box.py",
    "content": "# coding: utf-8\n\nfrom .config import cfg\n\nimport torch\nfrom itertools import product as product\nfrom math import ceil\n\n\nclass PriorBox(object):\n    def __init__(self, image_size=None):\n        super(PriorBox, self).__init__()\n        # self.aspect_ratios = cfg['aspect_ratios']\n        self.min_sizes = cfg['min_sizes']\n        self.steps = cfg['steps']\n        self.clip = cfg['clip']\n        self.image_size = image_size\n        self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]\n\n    def forward(self):\n        anchors = []\n        for k, f in enumerate(self.feature_maps):\n            min_sizes = self.min_sizes[k]\n            for i, j in product(range(f[0]), range(f[1])):\n                for min_size in min_sizes:\n                    s_kx = min_size / self.image_size[1]\n                    s_ky = min_size / self.image_size[0]\n                    if min_size == 32:\n                        dense_cx = [x * self.steps[k] / self.image_size[1] for x in\n                                    [j + 0, j + 0.25, j + 0.5, j + 0.75]]\n                        dense_cy = [y * self.steps[k] / self.image_size[0] for y in\n                                    [i + 0, i + 0.25, i + 0.5, i + 0.75]]\n                        for cy, cx in product(dense_cy, dense_cx):\n                            anchors += [cx, cy, s_kx, s_ky]\n                    elif min_size == 64:\n                        dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0, j + 0.5]]\n                        dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0, i + 0.5]]\n                        for cy, cx in product(dense_cy, dense_cx):\n                            anchors += [cx, cy, s_kx, s_ky]\n                    else:\n                        cx = (j + 0.5) * self.steps[k] / self.image_size[1]\n                        cy = (i + 0.5) * self.steps[k] / self.image_size[0]\n                        anchors += [cx, cy, s_kx, s_ky]\n        # back to torch land\n        output = torch.Tensor(anchors).view(-1, 4)\n        if self.clip:\n            output.clamp_(max=1, min=0)\n        return output\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/utils/timer.py",
    "content": "# coding: utf-8\n\n# --------------------------------------------------------\n# Fast R-CNN\n# Copyright (c) 2015 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ross Girshick\n# --------------------------------------------------------\n\nimport time\n\n\nclass Timer(object):\n    \"\"\"A simple timer.\"\"\"\n\n    def __init__(self):\n        self.total_time = 0.\n        self.calls = 0\n        self.start_time = 0.\n        self.diff = 0.\n        self.average_time = 0.\n\n    def tic(self):\n        # using time.time instead of time.clock because time time.clock\n        # does not normalize for multithreading\n        self.start_time = time.time()\n\n    def toc(self, average=True):\n        self.diff = time.time() - self.start_time\n        self.total_time += self.diff\n        self.calls += 1\n        self.average_time = self.total_time / self.calls\n        if average:\n            return self.average_time\n        else:\n            return self.diff\n\n    def clear(self):\n        self.total_time = 0.\n        self.calls = 0\n        self.start_time = 0.\n        self.diff = 0.\n        self.average_time = 0.\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/weights/.gitignore",
    "content": "*.onnx\n"
  },
  {
    "path": "extract_init_states/FaceBoxes/weights/readme.md",
    "content": "The pre-trained model `FaceBoxesProd.pth` is downloaded from [Google Drive](https://drive.google.com/file/d/1tRVwOlu0QtjvADQ2H7vqrRwsWEmaqioI).\n\nThe converted `FaceBoxesProd.onnx`: [Google Drive](https://drive.google.com/file/d/1pccQOvYqKh3iCEHc5tSWx2-1fhgxs6rh/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1TJS2wFRLSoWZPR4l9E7G7w) (Password: 9hph)\n"
  },
  {
    "path": "extract_init_states/TDDFA_ONNX.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport os \nimport sys  \ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nif current_dir not in sys.path:\n    sys.path.append(current_dir)\n    print(current_dir)\n    \nimport os.path as osp\nimport numpy as np\nimport cv2\nimport onnxruntime\n\nfrom utils.onnx import convert_to_onnx\nfrom utils.io import _load\nfrom utils.functions import (\n    crop_img, parse_roi_box_from_bbox, parse_roi_box_from_landmark,\n)\nfrom utils.tddfa_util import _parse_param, similar_transform\nfrom bfm.bfm import BFMModel\nfrom bfm.bfm_onnx import convert_bfm_to_onnx\n\nmake_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)\n\n\nclass TDDFA_ONNX(object):\n    \"\"\"TDDFA_ONNX: the ONNX version of Three-D Dense Face Alignment (TDDFA)\"\"\"\n\n    def __init__(self, **kvs):\n        # torch.set_grad_enabled(False)\n\n        # load onnx version of BFM\n        bfm_fp = make_abs_path(kvs.get('bfm_fp', 'configs/bfm_noneck_v3.pkl'))\n        bfm_onnx_fp = bfm_fp.replace('.pkl', '.onnx')\n        if not osp.exists(bfm_onnx_fp):\n            convert_bfm_to_onnx(\n                bfm_onnx_fp,\n                shape_dim=kvs.get('shape_dim', 40),\n                exp_dim=kvs.get('exp_dim', 10)\n            )\n        self.bfm_session = onnxruntime.InferenceSession(bfm_onnx_fp, providers=['CUDAExecutionProvider'])\n\n        # load for optimization\n        bfm = BFMModel(bfm_fp, shape_dim=kvs.get('shape_dim', 40), exp_dim=kvs.get('exp_dim', 10))\n        self.tri = bfm.tri\n        self.u_base, self.w_shp_base, self.w_exp_base = bfm.u_base, bfm.w_shp_base, bfm.w_exp_base\n\n        # config\n        self.gpu_mode = kvs.get('gpu_mode', True)\n        self.gpu_id = kvs.get('gpu_id', 0)\n        self.size = kvs.get('size', 120)\n\n        param_mean_std_fp = make_abs_path(kvs.get(\n            'param_mean_std_fp', f'configs/param_mean_std_62d_{self.size}x{self.size}.pkl')\n        )\n\n        onnx_fp =  make_abs_path(kvs.get('onnx_fp', kvs.get('checkpoint_fp').replace('.pth', '.onnx'))) \n\n        # convert to onnx online if not existed\n        if onnx_fp is None or not osp.exists(onnx_fp):\n            print(f'{onnx_fp} does not exist, try to convert the `.pth` version to `.onnx` online')\n            onnx_fp = convert_to_onnx(**kvs)\n\n        self.session = onnxruntime.InferenceSession(onnx_fp, providers=['CUDAExecutionProvider'])\n\n        # params normalization config\n        r = _load(param_mean_std_fp)\n        self.param_mean = r.get('mean')\n        self.param_std = r.get('std')\n\n    def __call__(self, img_ori, objs, **kvs):\n        # Crop image, forward to get the param\n        param_lst = []\n        roi_box_lst = []\n\n        crop_policy = kvs.get('crop_policy', 'box')\n        for obj in objs:\n            if crop_policy == 'box':\n                # by face box\n                roi_box = parse_roi_box_from_bbox(obj)\n            elif crop_policy == 'landmark':\n                # by landmarks\n                roi_box = parse_roi_box_from_landmark(obj)\n            else:\n                raise ValueError(f'Unknown crop policy {crop_policy}')\n\n            roi_box_lst.append(roi_box)\n            img = crop_img(img_ori, roi_box)\n            img = cv2.resize(img, dsize=(self.size, self.size), interpolation=cv2.INTER_LINEAR)\n            img = img.astype(np.float32).transpose(2, 0, 1)[np.newaxis, ...]\n            img = (img - 127.5) / 128.\n\n            inp_dct = {'input': img}\n\n            param = 
self.session.run(None, inp_dct)[0]\n            param = param.flatten().astype(np.float32)\n            param = param * self.param_std + self.param_mean  # re-scale\n            param_lst.append(param)\n\n        return param_lst, roi_box_lst\n\n    def recon_vers(self, param_lst, roi_box_lst, **kvs):\n        dense_flag = kvs.get('dense_flag', False)\n        size = self.size\n\n        ver_lst = []\n        for param, roi_box in zip(param_lst, roi_box_lst):\n            R, offset, alpha_shp, alpha_exp = _parse_param(param)\n            if dense_flag:\n                inp_dct = {\n                    'R': R, 'offset': offset, 'alpha_shp': alpha_shp, 'alpha_exp': alpha_exp\n                }\n                pts3d = self.bfm_session.run(None, inp_dct)[0]\n                pts3d = similar_transform(pts3d, roi_box, size)\n            else:\n                pts3d = R @ (self.u_base + self.w_shp_base @ alpha_shp + self.w_exp_base @ alpha_exp). \\\n                    reshape(3, -1, order='F') + offset\n                pts3d = similar_transform(pts3d, roi_box, size)\n\n            ver_lst.append(pts3d)\n\n        return ver_lst\n"
  },
  {
    "path": "extract_init_states/bfm/.gitignore",
    "content": "*.ply\n"
  },
  {
    "path": "extract_init_states/bfm/__init__.py",
    "content": "from .bfm import BFMModel"
  },
  {
    "path": "extract_init_states/bfm/bfm.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport sys\n\nsys.path.append('..')\n\nimport os.path as osp\nimport numpy as np\nfrom utils.io import _load\n\nmake_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)\n\n\ndef _to_ctype(arr):\n    if not arr.flags.c_contiguous:\n        return arr.copy(order='C')\n    return arr\n\n\nclass BFMModel(object):\n    def __init__(self, bfm_fp, shape_dim=40, exp_dim=10):\n        bfm = _load(bfm_fp)\n        self.u = bfm.get('u').astype(np.float32)  # fix bug\n        self.w_shp = bfm.get('w_shp').astype(np.float32)[..., :shape_dim]\n        self.w_exp = bfm.get('w_exp').astype(np.float32)[..., :exp_dim]\n        if osp.split(bfm_fp)[-1] == 'bfm_noneck_v3.pkl':\n            self.tri = _load(make_abs_path('../configs/tri.pkl'))  # this tri/face is re-built for bfm_noneck_v3\n        else:\n            self.tri = bfm.get('tri')\n\n        self.tri = _to_ctype(self.tri.T).astype(np.int32)\n        self.keypoints = bfm.get('keypoints').astype(np.int64)  # fix bug\n        w = np.concatenate((self.w_shp, self.w_exp), axis=1)\n        self.w_norm = np.linalg.norm(w, axis=0)\n\n        self.u_base = self.u[self.keypoints].reshape(-1, 1)\n        self.w_shp_base = self.w_shp[self.keypoints]\n        self.w_exp_base = self.w_exp[self.keypoints]\n"
  },
  {
    "path": "extract_init_states/bfm/bfm_onnx.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport sys\n\nsys.path.append('..')\n\nimport os.path as osp\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom utils.io import _load, _numpy_to_cuda, _numpy_to_tensor\n\nmake_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)\n\n\ndef _to_ctype(arr):\n    if not arr.flags.c_contiguous:\n        return arr.copy(order='C')\n    return arr\n\n\ndef _load_tri(bfm_fp):\n    if osp.split(bfm_fp)[-1] == 'bfm_noneck_v3.pkl':\n        tri = _load(make_abs_path('../configs/tri.pkl'))  # this tri/face is re-built for bfm_noneck_v3\n    else:\n        tri = _load(bfm_fp).get('tri')\n\n    tri = _to_ctype(tri.T).astype(np.int32)\n    return tri\n\n\nclass BFMModel_ONNX(nn.Module):\n    \"\"\"BFM serves as a decoder\"\"\"\n\n    def __init__(self, bfm_fp, shape_dim=40, exp_dim=10):\n        super(BFMModel_ONNX, self).__init__()\n\n        _to_tensor = _numpy_to_tensor\n\n        # load bfm\n        bfm = _load(bfm_fp)\n\n        u = _to_tensor(bfm.get('u').astype(np.float32))\n        self.u = u.view(-1, 3).transpose(1, 0)\n        w_shp = _to_tensor(bfm.get('w_shp').astype(np.float32)[..., :shape_dim])\n        w_exp = _to_tensor(bfm.get('w_exp').astype(np.float32)[..., :exp_dim])\n        w = torch.cat((w_shp, w_exp), dim=1)\n        self.w = w.view(-1, 3, w.shape[-1]).contiguous().permute(1, 0, 2)\n\n        # self.u = _to_tensor(bfm.get('u').astype(np.float32))  # fix bug\n        # w_shp = _to_tensor(bfm.get('w_shp').astype(np.float32)[..., :shape_dim])\n        # w_exp = _to_tensor(bfm.get('w_exp').astype(np.float32)[..., :exp_dim])\n        # self.w = torch.cat((w_shp, w_exp), dim=1)\n\n        # self.keypoints = bfm.get('keypoints').astype(np.long)  # fix bug\n        # self.u_base = self.u[self.keypoints].reshape(-1, 1)\n        # self.w_shp_base = self.w_shp[self.keypoints]\n        # self.w_exp_base = self.w_exp[self.keypoints]\n\n    def forward(self, *inps):\n        R, offset, alpha_shp, alpha_exp = inps\n        alpha = torch.cat((alpha_shp, alpha_exp))\n        # pts3d = R @ (self.u + self.w_shp.matmul(alpha_shp) + self.w_exp.matmul(alpha_exp)). \\\n        #     view(-1, 3).transpose(1, 0) + offset\n        # pts3d = R @ (self.u + self.w.matmul(alpha)).view(-1, 3).transpose(1, 0) + offset\n        pts3d = R @ (self.u + self.w.matmul(alpha).squeeze()) + offset\n        return pts3d\n\n\ndef convert_bfm_to_onnx(bfm_onnx_fp, shape_dim=40, exp_dim=10):\n    # print(shape_dim, exp_dim)\n    bfm_fp = bfm_onnx_fp.replace('.onnx', '.pkl')\n    bfm_decoder = BFMModel_ONNX(bfm_fp=bfm_fp, shape_dim=shape_dim, exp_dim=exp_dim)\n    bfm_decoder.eval()\n\n    # dummy_input = torch.randn(12 + shape_dim + exp_dim)\n    dummy_input = torch.randn(3, 3), torch.randn(3, 1), torch.randn(shape_dim, 1), torch.randn(exp_dim, 1)\n    R, offset, alpha_shp, alpha_exp = dummy_input\n    torch.onnx.export(\n        bfm_decoder,\n        (R, offset, alpha_shp, alpha_exp),\n        bfm_onnx_fp,\n        input_names=['R', 'offset', 'alpha_shp', 'alpha_exp'],\n        output_names=['output'],\n        dynamic_axes={\n            'alpha_shp': [0],\n            'alpha_exp': [0],\n        },\n        do_constant_folding=True\n    )\n    print(f'Convert {bfm_fp} to {bfm_onnx_fp} done.')\n\n\nif __name__ == '__main__':\n    convert_bfm_to_onnx('../configs/bfm_noneck_v3.onnx')\n"
  },
  {
    "path": "extract_init_states/bfm/readme.md",
    "content": "## Statement\n\nThe modified BFM2009 face model in `../configs/bfm_noneck_v3.pkl` is only for academic use.\nFor commercial use, you need to apply for the commercial license, some refs are below:\n\n[1] https://faces.dmi.unibas.ch/bfm/?nav=1-0&id=basel_face_model\n\n[2] https://faces.dmi.unibas.ch/bfm/bfm2019.html\n\nIf your work benefits from this repo, please cite\n\n    @PROCEEDINGS{bfm09,\n        title={A 3D Face Model for Pose and Illumination Invariant Face Recognition},\n        author={P. Paysan and R. Knothe and B. Amberg\n                and S. Romdhani and T. Vetter},\n        journal={Proceedings of the 6th IEEE International Conference on Advanced Video and Signal based Surveillance (AVSS)\n             for Security, Safety and Monitoring in Smart Environments},\n        organization={IEEE},\n        year={2009},\n        address     = {Genova, Italy},\n    }\n\n "
  },
  {
    "path": "extract_init_states/build.sh",
    "content": "cd FaceBoxes\nsh ./build_cpu_nms.sh\ncd ..\n\n# cd Sim3DR\n# sh ./build_sim3dr.sh\n# cd ..\n\ncd utils/asset\ngcc -shared -Wall -O3 render.c -o render.so -fPIC\ncd ../.."
  },
  {
    "path": "extract_init_states/configs/.gitignore",
    "content": "# *.pkl\n# *.yml\n# *.onnx"
  },
  {
    "path": "extract_init_states/configs/mb05_120x120.yml",
    "content": "arch: mobilenet # MobileNet V1\nwiden_factor: 0.5\ncheckpoint_fp: weights/mb05_120x120.pth\nbfm_fp: configs/bfm_noneck_v3.pkl # or configs/bfm_noneck_v3_slim.pkl\nsize: 120\nnum_params: 62"
  },
  {
    "path": "extract_init_states/configs/mb1_120x120.yml",
    "content": "arch: mobilenet # MobileNet V1\nwiden_factor: 1.0\ncheckpoint_fp: weights/mb1_120x120.pth\nbfm_fp: configs/bfm_noneck_v3.pkl # or configs/bfm_noneck_v3_slim.pkl\nsize: 120\nnum_params: 62\n"
  },
  {
    "path": "extract_init_states/configs/readme.md",
    "content": "## The simplified version of BFM\n\n`bfm_noneck_v3_slim.pkl`: [Google Drive](https://drive.google.com/file/d/1iK5lD49E_gCn9voUjWDPj2ItGKvM10GI/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1C_SzYBOG3swZA_EjxpXlAw) (Password: p803)"
  },
  {
    "path": "extract_init_states/configs/resnet_120x120.yml",
    "content": "# before using this config, go through readme.md to find the onnx links to download `resnet22.onnx`\narch: resnet22\ncheckpoint_fp: weights/resnet22.pth\nbfm_fp: configs/bfm_noneck_v3.pkl\nsize: 120\nnum_params: 62\n"
  },
  {
    "path": "extract_init_states/demo_pose_extract_2d_lmk_img.py",
    "content": "# coding: utf-8\n# based on 3DDFA\n__author__ = 'cleardusk'\n\nimport sys\nimport os\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nprint(current_dir)\nif current_dir not in sys.path:\n    sys.path.append(current_dir)\n    print(current_dir)\nimport argparse\nimport cv2\nimport yaml\n\nimport time\nfrom yaml import safe_dump\nfrom FaceBoxes import FaceBoxes\nimport numpy as np\nfrom tqdm import tqdm\nimport copy\nimport time\nfrom utils.pose import viz_pose, get_pose\nfrom utils.serialization import ser_to_ply, ser_to_obj\nfrom utils.functions import draw_landmarks, get_suffix, calculate_eye, calculate_bbox\nfrom utils.tddfa_util import str2bool\nimport concurrent.futures\nfrom multiprocessing import Pool\n\ndef main(args,img, save_path, pose_path):\n #   begin = time.time()\n    \n        # face_boxes.eval()\n\n    # Given a still image path and load to BGR channel\n  #  img = cv2.imread(img_path) #args.img_fp\n\n    # Detect faces, get 3DMM params and roi boxes\n\n    # start_time = time.time()\n    boxes = face_boxes(img)\n    # end_time = time.time()\n    # execution_time = end_time - start_time\n    # print(f'box time: {execution_time}')\n    n = len(boxes)\n    if n == 0:\n        print(f'No face detected, exit')\n      #  sys.exit(-1)\n        return None\n    # print(f'Detect {n} faces')\n\n    # start_time = time.time()\n    param_lst, roi_box_lst = tddfa(img, boxes)\n    # end_time = time.time()\n    # execution_time = end_time - start_time\n    # print(f'tddfa time: {execution_time}')\n    #detection time\n  #  detect_time = time.time()-begin\n #   print('detection time: '+str(detect_time), file=open('/mnt/lustre/jixinya/Home/3DDFA_V2/pose.txt', 'a'))\n    # Visualization and serialization\n    dense_flag = args.opt in ('2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'ply', 'obj')\n  #  old_suffix = get_suffix(img_path)\n    old_suffix = 'png'\n    new_suffix = f'.{args.opt}' if args.opt in ('ply', 'obj') else '.jpg'\n\n    wfp = f'examples/results/{args.img_fp.split(\"/\")[-1].replace(old_suffix, \"\")}_{args.opt}' + new_suffix\n\n    # start_time = time.time()\n    ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag)\n    # end_time = time.time()\n    # execution_time = end_time - start_time\n    # print(f'tddfa.recon_vers time: {execution_time}')\n\n\n    # start_time = time.time()\n    all_pose = get_pose(img, param_lst, ver_lst, show_flag=args.show_flag, wfp=save_path, wnp = pose_path)\n    end_time = time.time()\n    \n\n    return all_pose, ver_lst\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='The demo of still image of 3DDFA_V2')\n    parser.add_argument('-c', '--config', type=str, default=f'{current_dir}/configs/mb1_120x120.yml')\n    parser.add_argument('-f', '--img_fp', type=str, default='/disk2/pfhu/DAWN-pytorch/images/image/anime_female2.jpeg')\n    parser.add_argument('-m', '--mode', type=str, default='gpu', help='gpu or cpu mode')\n    parser.add_argument('-o', '--opt', type=str, default='pose',\n                        choices=['2d_sparse', '2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'pose', 'ply', 'obj'])\n    parser.add_argument('--show_flag', type=str2bool, default='False', help='whether to show the visualization result')\n    parser.add_argument('--onnx', action='store_true', default=True)\n    parser.add_argument('-p', '--part',  type=int, default=1)\n    parser.add_argument('-a', '--all', type=int, default=1)\n\n    parser.add_argument('-i', '--input', type=str)\n    
parser.add_argument('-t', '--output', type=str)\n\n    args = parser.parse_args()\n\n    part = args.part\n    all_part = args.all\n\n\n    \n    filepath = args.input\n    save_path = args.output\n\n    if not os.path.exists(save_path):\n        os.makedirs(save_path)\n    \n    start_point = 30 #int((part - 1) *duration)\n    \n    cfg = yaml.load(open(args.config), Loader=yaml.SafeLoader)\n\n    # Init FaceBoxes and TDDFA, recommend using onnx flag\n    if args.onnx:\n        import os\n        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'\n        # os.environ['OMP_WAIT_POLICY'] = 'PASSIVE'\n        os.environ['OMP_NUM_THREADS'] = '8'\n\n        from FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX\n        from TDDFA_ONNX import TDDFA_ONNX\n\n        face_boxes = FaceBoxes_ONNX()\n        tddfa = TDDFA_ONNX(**cfg)\n    else:\n        gpu_mode = args.mode == 'gpu'\n        tddfa = TDDFA(gpu_mode=gpu_mode, **cfg)\n        # tddfa.eval()\n        face_boxes = FaceBoxes()\n\n\n    # save_path_pose = os.path.join(save_path, 'tmp.npy')\n    image= cv2.imread(filepath)\n    if image.shape[2] == 4:\n        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)\n    pose, lmk = main(args,image, save_path = None, pose_path =  None)\n\n    lmk = lmk[0]\n    eye_bbox_result = np.zeros(8)\n    bbox = calculate_bbox(image, lmk)\n    left_ratio, right_ratio = calculate_eye(lmk)\n    eye_bbox_result[0] = left_ratio.sum()\n    eye_bbox_result[1] = right_ratio.sum()\n    eye_bbox_result[2:] = np.array(bbox)\n\n    pose = pose.reshape(1,7)\n    eye_bbox_result = eye_bbox_result.reshape(1, -1)\n    eye_bbox_path = os.path.join(save_path, 'init_eye_bbox.npy')\n    pose_path = os.path.join(save_path, 'init_pose.npy')\n\n    np.save(eye_bbox_path, eye_bbox_result)\n    np.save(pose_path, pose)\n\n\n            \n   "
  },
  {
    "path": "extract_init_states/functions.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport numpy as np\nimport cv2\nfrom math import sqrt\nimport matplotlib.pyplot as plt\n\nRED = (0, 0, 255)\nGREEN = (0, 255, 0)\nBLUE = (255, 0, 0)\n\n\ndef get_suffix(filename):\n    \"\"\"a.jpg -> jpg\"\"\"\n    pos = filename.rfind('.')\n    if pos == -1:\n        return ''\n    return filename[pos:]\n\n\ndef crop_img(img, roi_box):\n    h, w = img.shape[:2]\n\n    sx, sy, ex, ey = [int(round(_)) for _ in roi_box]\n    dh, dw = ey - sy, ex - sx\n    if len(img.shape) == 3:\n        res = np.zeros((dh, dw, 3), dtype=np.uint8)\n    else:\n        res = np.zeros((dh, dw), dtype=np.uint8)\n    if sx < 0:\n        sx, dsx = 0, -sx\n    else:\n        dsx = 0\n\n    if ex > w:\n        ex, dex = w, dw - (ex - w)\n    else:\n        dex = dw\n\n    if sy < 0:\n        sy, dsy = 0, -sy\n    else:\n        dsy = 0\n\n    if ey > h:\n        ey, dey = h, dh - (ey - h)\n    else:\n        dey = dh\n\n    res[dsy:dey, dsx:dex] = img[sy:ey, sx:ex]\n    return res\n\n\ndef calc_hypotenuse(pts):\n    bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])]\n    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]\n    radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2\n    bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius]\n    llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2)\n    return llength / 3\n\n\ndef parse_roi_box_from_landmark(pts):\n    \"\"\"calc roi box from landmark\"\"\"\n    bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])]\n    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]\n    radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2\n    bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius]\n\n    llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2)\n    center_x = (bbox[2] + bbox[0]) / 2\n    center_y = (bbox[3] + bbox[1]) / 2\n\n    roi_box = [0] * 4\n    roi_box[0] = center_x - llength / 2\n    roi_box[1] = center_y - llength / 2\n    roi_box[2] = roi_box[0] + llength\n    roi_box[3] = roi_box[1] + llength\n\n    return roi_box\n\n\ndef parse_roi_box_from_bbox(bbox):\n    left, top, right, bottom = bbox[:4]\n    old_size = (right - left + bottom - top) / 2\n    center_x = right - (right - left) / 2.0\n    center_y = bottom - (bottom - top) / 2.0 + old_size * 0.14\n    size = int(old_size * 1.58)\n\n    roi_box = [0] * 4\n    roi_box[0] = center_x - size / 2\n    roi_box[1] = center_y - size / 2\n    roi_box[2] = roi_box[0] + size\n    roi_box[3] = roi_box[1] + size\n\n    return roi_box\n\n\ndef plot_image(img):\n    height, width = img.shape[:2]\n    plt.figure(figsize=(12, height / width * 12))\n\n    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)\n    plt.axis('off')\n\n    plt.imshow(img[..., ::-1])\n    plt.show()\n\n\ndef draw_landmarks(img, pts, style='fancy', wfp=None, show_flag=False, **kwargs):\n    \"\"\"Draw landmarks using matplotlib\"\"\"\n    height, width = img.shape[:2]\n    plt.figure(figsize=(12, height / width * 12))\n    plt.imshow(img[..., ::-1])\n    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)\n    plt.axis('off')\n\n    dense_flag = kwargs.get('dense_flag')\n\n    if not type(pts) in [tuple, list]:\n        pts = [pts]\n    for i in range(len(pts)):\n        if dense_flag:\n            plt.plot(pts[i][0, ::6], pts[i][1, ::6], 'o', markersize=0.4, color='c', alpha=0.7)\n        else:\n            alpha = 0.8\n         
   markersize = 4\n            lw = 1.5\n            color = kwargs.get('color', 'w')\n            markeredgecolor = kwargs.get('markeredgecolor', 'black')\n\n            nums = [0, 17, 22, 27, 31, 36, 42, 48, 60, 68]\n\n            # close eyes and mouths\n            plot_close = lambda i1, i2: plt.plot([pts[i][0, i1], pts[i][0, i2]], [pts[i][1, i1], pts[i][1, i2]],\n                                                 color=color, lw=lw, alpha=alpha - 0.1)\n            plot_close(41, 36)\n            plot_close(47, 42)\n            plot_close(59, 48)\n            plot_close(67, 60)\n\n            for ind in range(len(nums) - 1):\n                l, r = nums[ind], nums[ind + 1]\n                plt.plot(pts[i][0, l:r], pts[i][1, l:r], color=color, lw=lw, alpha=alpha - 0.1)\n\n                plt.plot(pts[i][0, l:r], pts[i][1, l:r], marker='o', linestyle='None', markersize=markersize,\n                         color=color,\n                         markeredgecolor=markeredgecolor, alpha=alpha)\n    if wfp is not None:\n        plt.savefig(wfp, dpi=150)\n        print(f'Save visualization result to {wfp}')\n\n    if show_flag:\n        plt.show()\n\n\ndef cv_draw_landmark(img_ori, pts, box=None, color=GREEN, size=1):\n    img = img_ori.copy()\n    n = pts.shape[1]\n    if n <= 106:\n        for i in range(n):\n            cv2.circle(img, (int(round(pts[0, i])), int(round(pts[1, i]))), size, color, -1)\n    else:\n        sep = 1\n        for i in range(0, n, sep):\n            cv2.circle(img, (int(round(pts[0, i])), int(round(pts[1, i]))), size, color, 1)\n\n    if box is not None:\n        left, top, right, bottom = np.round(box).astype(np.int32)\n        left_top = (left, top)\n        right_top = (right, top)\n        right_bottom = (right, bottom)\n        left_bottom = (left, bottom)\n        cv2.line(img, left_top, right_top, BLUE, 1, cv2.LINE_AA)\n        cv2.line(img, right_top, right_bottom, BLUE, 1, cv2.LINE_AA)\n        cv2.line(img, right_bottom, left_bottom, BLUE, 1, cv2.LINE_AA)\n        cv2.line(img, left_bottom, left_top, BLUE, 1, cv2.LINE_AA)\n\n    return img\n\ndef calculate_bbox(img, lmk):\n    lmk = lmk.transpose(1,0)\n    # point_3d_homo = np.hstack((lmk, np.ones([lmk.shape[0], 1])))  # n x 4\n    # point_2d = point_3d_homo.dot(P.T)[:, :2]\n\n    # point_2d[:, 1] = - point_2d[:, 1]\n    # point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d, 0) + np.mean(lmk[:27,:2], 0)  # lmk 0-27 \n    point_2d = lmk[:, :2]\n    point_2d = np.int32(point_2d.reshape(-1, 2))\n    H = img.shape[0]\n    W = img.shape[1]\n    x_min, x_max = point_2d[:, 0].min(), point_2d[:, 0].max()\n    y_min, y_max = point_2d[:, 1].min(), point_2d[:, 1].max()\n    # cv2.polylines(img, [point_2d], True, (40, 255, 0), 2, cv2.LINE_AA)\n    # points_list = [(p[0], p[1]) for p in point_2d]\n    # for p in points_list:\n    #     cv2.circle(img, p, 1, (40, 255, 0), -1)\n    # return img\n\n    return [x_min, x_max, y_min, y_max, H, W]\n    \ndef calculate_eye(lmk):\n    '''\n    left right obj\n    '''\n    lmk = lmk.transpose(1,0)\n    leye_upper = lmk[43]\n    leye_lower = lmk[47]\n    leye_left = lmk[45]\n    leye_right = lmk[42]\n    reye_upper = lmk[37]\n    reye_lower = lmk[41]\n    reye_left = lmk[39]\n    reye_right = lmk[36]\n\n    left_ratio = np.linalg.norm(leye_upper - leye_lower, 2) / np.linalg.norm(leye_left - leye_right, 2)\n    right_ratio = np.linalg.norm(reye_upper - reye_lower, 2) / np.linalg.norm(reye_left - reye_right, 2)\n\n    return left_ratio, right_ratio"
  },
  {
    "path": "extract_init_states/models/__init__.py",
    "content": "from .mobilenet_v1 import *\nfrom .mobilenet_v3 import *\nfrom .resnet import *"
  },
  {
    "path": "extract_init_states/models/mobilenet_v1.py",
    "content": "# coding: utf-8\n\nfrom __future__ import division\n\n\"\"\" \nCreates a MobileNet Model as defined in:\nAndrew G. Howard Menglong Zhu Bo Chen, et.al. (2017). \nMobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications. \nCopyright (c) Yang Lu, 2017\n\nModified By cleardusk\n\"\"\"\nimport math\nimport torch.nn as nn\n\n__all__ = ['MobileNet', 'mobilenet']\n\n\n# __all__ = ['mobilenet_2', 'mobilenet_1', 'mobilenet_075', 'mobilenet_05', 'mobilenet_025']\n\n\nclass DepthWiseBlock(nn.Module):\n    def __init__(self, inplanes, planes, stride=1, prelu=False):\n        super(DepthWiseBlock, self).__init__()\n        inplanes, planes = int(inplanes), int(planes)\n        self.conv_dw = nn.Conv2d(inplanes, inplanes, kernel_size=3, padding=1, stride=stride, groups=inplanes,\n                                 bias=False)\n        self.bn_dw = nn.BatchNorm2d(inplanes)\n        self.conv_sep = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False)\n        self.bn_sep = nn.BatchNorm2d(planes)\n        if prelu:\n            self.relu = nn.PReLU()\n        else:\n            self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        out = self.conv_dw(x)\n        out = self.bn_dw(out)\n        out = self.relu(out)\n\n        out = self.conv_sep(out)\n        out = self.bn_sep(out)\n        out = self.relu(out)\n\n        return out\n\n\nclass MobileNet(nn.Module):\n    def __init__(self, widen_factor=1.0, num_classes=1000, prelu=False, input_channel=3):\n        \"\"\" Constructor\n        Args:\n            widen_factor: config of widen_factor\n            num_classes: number of classes\n        \"\"\"\n        super(MobileNet, self).__init__()\n\n        block = DepthWiseBlock\n        self.conv1 = nn.Conv2d(input_channel, int(32 * widen_factor), kernel_size=3, stride=2, padding=1,\n                               bias=False)\n\n        self.bn1 = nn.BatchNorm2d(int(32 * widen_factor))\n        if prelu:\n            self.relu = nn.PReLU()\n        else:\n            self.relu = nn.ReLU(inplace=True)\n\n        self.dw2_1 = block(32 * widen_factor, 64 * widen_factor, prelu=prelu)\n        self.dw2_2 = block(64 * widen_factor, 128 * widen_factor, stride=2, prelu=prelu)\n\n        self.dw3_1 = block(128 * widen_factor, 128 * widen_factor, prelu=prelu)\n        self.dw3_2 = block(128 * widen_factor, 256 * widen_factor, stride=2, prelu=prelu)\n\n        self.dw4_1 = block(256 * widen_factor, 256 * widen_factor, prelu=prelu)\n        self.dw4_2 = block(256 * widen_factor, 512 * widen_factor, stride=2, prelu=prelu)\n\n        self.dw5_1 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu)\n        self.dw5_2 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu)\n        self.dw5_3 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu)\n        self.dw5_4 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu)\n        self.dw5_5 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu)\n        self.dw5_6 = block(512 * widen_factor, 1024 * widen_factor, stride=2, prelu=prelu)\n\n        self.dw6 = block(1024 * widen_factor, 1024 * widen_factor, prelu=prelu)\n\n        self.avgpool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Linear(int(1024 * widen_factor), num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. 
/ n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n\n        x = self.dw2_1(x)\n        x = self.dw2_2(x)\n        x = self.dw3_1(x)\n        x = self.dw3_2(x)\n        x = self.dw4_1(x)\n        x = self.dw4_2(x)\n        x = self.dw5_1(x)\n        x = self.dw5_2(x)\n        x = self.dw5_3(x)\n        x = self.dw5_4(x)\n        x = self.dw5_5(x)\n        x = self.dw5_6(x)\n        x = self.dw6(x)\n\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n\n        return x\n\n\ndef mobilenet(**kwargs):\n    \"\"\"\n    Construct MobileNet.\n    widen_factor=1.0  for mobilenet_1\n    widen_factor=0.75 for mobilenet_075\n    widen_factor=0.5  for mobilenet_05\n    widen_factor=0.25 for mobilenet_025\n    \"\"\"\n    # widen_factor = 1.0, num_classes = 1000\n    # model = MobileNet(widen_factor=widen_factor, num_classes=num_classes)\n    # return model\n\n    model = MobileNet(\n        widen_factor=kwargs.get('widen_factor', 1.0),\n        num_classes=kwargs.get('num_classes', 62)\n    )\n    return model\n\n\ndef mobilenet_2(num_classes=62, input_channel=3):\n    model = MobileNet(widen_factor=2.0, num_classes=num_classes, input_channel=input_channel)\n    return model\n\n\ndef mobilenet_1(num_classes=62, input_channel=3):\n    model = MobileNet(widen_factor=1.0, num_classes=num_classes, input_channel=input_channel)\n    return model\n\n\ndef mobilenet_075(num_classes=62, input_channel=3):\n    model = MobileNet(widen_factor=0.75, num_classes=num_classes, input_channel=input_channel)\n    return model\n\n\ndef mobilenet_05(num_classes=62, input_channel=3):\n    model = MobileNet(widen_factor=0.5, num_classes=num_classes, input_channel=input_channel)\n    return model\n\n\ndef mobilenet_025(num_classes=62, input_channel=3):\n    model = MobileNet(widen_factor=0.25, num_classes=num_classes, input_channel=input_channel)\n    return model\n"
  },
  {
    "path": "extract_init_states/models/mobilenet_v3.py",
    "content": "# coding: utf-8\n\n\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n__all__ = ['MobileNetV3', 'mobilenet_v3']\n\n\ndef conv_bn(inp, oup, stride, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU):\n    return nn.Sequential(\n        conv_layer(inp, oup, 3, stride, 1, bias=False),\n        norm_layer(oup),\n        nlin_layer(inplace=True)\n    )\n\n\ndef conv_1x1_bn(inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU):\n    return nn.Sequential(\n        conv_layer(inp, oup, 1, 1, 0, bias=False),\n        norm_layer(oup),\n        nlin_layer(inplace=True)\n    )\n\n\nclass Hswish(nn.Module):\n    def __init__(self, inplace=True):\n        super(Hswish, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return x * F.relu6(x + 3., inplace=self.inplace) / 6.\n\n\nclass Hsigmoid(nn.Module):\n    def __init__(self, inplace=True):\n        super(Hsigmoid, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return F.relu6(x + 3., inplace=self.inplace) / 6.\n\n\nclass SEModule(nn.Module):\n    def __init__(self, channel, reduction=4):\n        super(SEModule, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel, channel // reduction, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Linear(channel // reduction, channel, bias=False),\n            Hsigmoid()\n            # nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        y = self.avg_pool(x).view(b, c)\n        y = self.fc(y).view(b, c, 1, 1)\n        return x * y.expand_as(x)\n\n\nclass Identity(nn.Module):\n    def __init__(self, channel):\n        super(Identity, self).__init__()\n\n    def forward(self, x):\n        return x\n\n\ndef make_divisible(x, divisible_by=8):\n    import numpy as np\n    return int(np.ceil(x * 1. 
/ divisible_by) * divisible_by)\n\n\nclass MobileBottleneck(nn.Module):\n    def __init__(self, inp, oup, kernel, stride, exp, se=False, nl='RE'):\n        super(MobileBottleneck, self).__init__()\n        assert stride in [1, 2]\n        assert kernel in [3, 5]\n        padding = (kernel - 1) // 2\n        self.use_res_connect = stride == 1 and inp == oup\n\n        conv_layer = nn.Conv2d\n        norm_layer = nn.BatchNorm2d\n        if nl == 'RE':\n            nlin_layer = nn.ReLU  # or ReLU6\n        elif nl == 'HS':\n            nlin_layer = Hswish\n        else:\n            raise NotImplementedError\n        if se:\n            SELayer = SEModule\n        else:\n            SELayer = Identity\n\n        self.conv = nn.Sequential(\n            # pw\n            conv_layer(inp, exp, 1, 1, 0, bias=False),\n            norm_layer(exp),\n            nlin_layer(inplace=True),\n            # dw\n            conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False),\n            norm_layer(exp),\n            SELayer(exp),\n            nlin_layer(inplace=True),\n            # pw-linear\n            conv_layer(exp, oup, 1, 1, 0, bias=False),\n            norm_layer(oup),\n        )\n\n    def forward(self, x):\n        if self.use_res_connect:\n            return x + self.conv(x)\n        else:\n            return self.conv(x)\n\n\nclass MobileNetV3(nn.Module):\n    def __init__(self, widen_factor=1.0, num_classes=141, num_landmarks=136, input_size=120, mode='small'):\n        super(MobileNetV3, self).__init__()\n        input_channel = 16\n        last_channel = 1280\n        if mode == 'large':\n            # refer to Table 1 in paper\n            mobile_setting = [\n                # k, exp, c,  se,     nl,  s,\n                [3, 16, 16, False, 'RE', 1],\n                [3, 64, 24, False, 'RE', 2],\n                [3, 72, 24, False, 'RE', 1],\n                [5, 72, 40, True, 'RE', 2],\n                [5, 120, 40, True, 'RE', 1],\n                [5, 120, 40, True, 'RE', 1],\n                [3, 240, 80, False, 'HS', 2],\n                [3, 200, 80, False, 'HS', 1],\n                [3, 184, 80, False, 'HS', 1],\n                [3, 184, 80, False, 'HS', 1],\n                [3, 480, 112, True, 'HS', 1],\n                [3, 672, 112, True, 'HS', 1],\n                [5, 672, 160, True, 'HS', 2],\n                [5, 960, 160, True, 'HS', 1],\n                [5, 960, 160, True, 'HS', 1],\n            ]\n        elif mode == 'small':\n            # refer to Table 2 in paper\n            mobile_setting = [\n                # k, exp, c,  se,     nl,  s,\n                [3, 16, 16, True, 'RE', 2],\n                [3, 72, 24, False, 'RE', 2],\n                [3, 88, 24, False, 'RE', 1],\n                [5, 96, 40, True, 'HS', 2],\n                [5, 240, 40, True, 'HS', 1],\n                [5, 240, 40, True, 'HS', 1],\n                [5, 120, 48, True, 'HS', 1],\n                [5, 144, 48, True, 'HS', 1],\n                [5, 288, 96, True, 'HS', 2],\n                [5, 576, 96, True, 'HS', 1],\n                [5, 576, 96, True, 'HS', 1],\n            ]\n        else:\n            raise NotImplementedError\n\n        # building first layer\n        assert input_size % 32 == 0\n        last_channel = make_divisible(last_channel * widen_factor) if widen_factor > 1.0 else last_channel\n        self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)]\n        # self.classifier = []\n\n        # building mobile blocks\n        for k, exp, c, se, nl, s in 
mobile_setting:\n            output_channel = make_divisible(c * widen_factor)\n            exp_channel = make_divisible(exp * widen_factor)\n            self.features.append(MobileBottleneck(input_channel, output_channel, k, s, exp_channel, se, nl))\n            input_channel = output_channel\n\n        # building last several layers\n        if mode == 'large':\n            last_conv = make_divisible(960 * widen_factor)\n            self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish))\n            self.features.append(nn.AdaptiveAvgPool2d(1))\n            self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0))\n            self.features.append(Hswish(inplace=True))\n        elif mode == 'small':\n            last_conv = make_divisible(576 * widen_factor)\n            self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish))\n            # self.features.append(SEModule(last_conv))  # refer to paper Table2, but I think this is a mistake\n            self.features.append(nn.AdaptiveAvgPool2d(1))\n            self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0))\n            self.features.append(Hswish(inplace=True))\n        else:\n            raise NotImplementedError\n\n        # make it nn.Sequential\n        self.features = nn.Sequential(*self.features)\n\n        # self.fc_param = nn.Linear(int(last_channel), num_classes)\n        self.fc = nn.Linear(int(last_channel), num_classes)\n        # self.fc_lm = nn.Linear(int(last_channel), num_landmarks)\n\n        # building classifier\n        # self.classifier = nn.Sequential(\n        #     nn.Dropout(p=dropout),    # refer to paper section 6\n        #     nn.Linear(last_channel, n_class),\n        # )\n\n        self._initialize_weights()\n\n    def forward(self, x):\n        x = self.features(x)\n        x_share = x.mean(3).mean(2)\n\n        # x = self.classifier(x)\n        # print(x_share.shape)\n        # xp = self.fc_param(x_share)  # param\n        # xl = self.fc_lm(x_share)  # lm\n\n        xp = self.fc(x_share)  # param\n\n        return xp  # , xl\n\n    def _initialize_weights(self):\n        # weight initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    nn.init.zeros_(m.bias)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.ones_(m.weight)\n                nn.init.zeros_(m.bias)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, 0, 0.01)\n                if m.bias is not None:\n                    nn.init.zeros_(m.bias)\n\n\ndef mobilenet_v3(**kwargs):\n    model = MobileNetV3(\n        widen_factor=kwargs.get('widen_factor', 1.0),\n        num_classes=kwargs.get('num_classes', 62),\n        num_landmarks=kwargs.get('num_landmarks', 136),\n        input_size=kwargs.get('size', 128),\n        mode=kwargs.get('mode', 'small')\n    )\n\n    return model\n"
  },
  {
    "path": "extract_init_states/models/resnet.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport torch.nn as nn\n\n__all__ = ['ResNet', 'resnet22']\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n    \"\"\"Another Strucutre used in caffe-resnet25\"\"\"\n\n    def __init__(self, block, layers, num_classes=62, num_landmarks=136, input_channel=3, fc_flg=False):\n        self.inplanes = 64\n        super(ResNet, self).__init__()\n        self.conv1 = nn.Conv2d(input_channel, 32, kernel_size=5, stride=2, padding=2, bias=False)\n        self.bn1 = nn.BatchNorm2d(32)  # 32 is input channels number\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(64)\n        self.relu2 = nn.ReLU(inplace=True)\n\n        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\n        self.layer1 = self._make_layer(block, 128, layers[0], stride=2)\n        self.layer2 = self._make_layer(block, 256, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 512, layers[2], stride=2)\n\n        self.conv_param = nn.Conv2d(512, num_classes, 1)\n        # self.conv_lm = nn.Conv2d(512, num_landmarks, 1)\n        self.avgpool = nn.AdaptiveAvgPool2d(1)\n        # self.fc = nn.Linear(512 * block.expansion, num_classes)\n        self.fc_flg = fc_flg\n\n        # parameter initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                # 1.\n                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                # m.weight.data.normal_(0, math.sqrt(2. / n))\n\n                # 2. 
kaiming normal\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu1(x)\n\n        x = self.conv2(x)\n        x = self.bn2(x)\n        x = self.relu2(x)\n\n        # x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n\n        # if self.fc_flg:\n        #     x = self.avgpool(x)\n        #     x = x.view(x.size(0), -1)\n        #     x = self.fc(x)\n        # else:\n        xp = self.conv_param(x)\n        xp = self.avgpool(xp)\n        xp = xp.view(xp.size(0), -1)\n\n        # xl = self.conv_lm(x)\n        # xl = self.avgpool(xl)\n        # xl = xl.view(xl.size(0), -1)\n\n        return xp  # , xl\n\n\ndef resnet22(**kwargs):\n    model = ResNet(\n        BasicBlock,\n        [3, 4, 3],\n        num_landmarks=kwargs.get('num_landmarks', 136),\n        input_channel=kwargs.get('input_channel', 3),\n        fc_flg=False\n    )\n    return model\n\n\ndef main():\n    pass\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "extract_init_states/pose.py",
    "content": "# coding: utf-8\n\n\"\"\"\nReference: https://github.com/YadiraF/PRNet/blob/master/utils/estimate_pose.py\n\nCalculating pose from the output 3DMM parameters, you can also try to use solvePnP to perform estimation\n\"\"\"\n\n__author__ = 'cleardusk'\n\nimport cv2\nimport numpy as np\nfrom math import cos, sin, atan2, asin, sqrt\n\nfrom .functions import calc_hypotenuse, plot_image\n\n\ndef P2sRt(P):\n    \"\"\" decompositing camera matrix P.\n    Args:\n        P: (3, 4). Affine Camera Matrix.\n    Returns:\n        s: scale factor.\n        R: (3, 3). rotation matrix.\n        t2d: (2,). 2d translation.\n    \"\"\"\n    t3d = P[:, 3] # \n    R1 = P[0:1, :3]\n    R2 = P[1:2, :3]\n    s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0  #\n    r1 = R1 / np.linalg.norm(R1)\n    r2 = R2 / np.linalg.norm(R2)\n    r3 = np.cross(r1, r2) # r1r2，r3\n\n    R = np.concatenate((r1, r2, r3), 0)  #  r 1-3R （）\n    return s, R, t3d\n\n\ndef matrix2angle(R):\n    \"\"\" compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf\n    refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv\n    todo: check and debug\n     Args:\n         R: (3,3). rotation matrix\n     Returns:\n         x: yaw\n         y: pitch\n         z: roll\n     \"\"\"\n    if R[2, 0] > 0.998:\n        z = 0\n        x = np.pi / 2\n        y = z + atan2(-R[0, 1], -R[0, 2])\n    elif R[2, 0] < -0.998:\n        z = 0\n        x = -np.pi / 2\n        y = -z + atan2(R[0, 1], R[0, 2])\n    else:\n        x = asin(R[2, 0])\n        y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x))\n        z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x))\n\n    return x, y, z\n\ndef angle2matrix(theta):\n    \"\"\" compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf\n    refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv\n    todo: check and debug\n     Args:\n         R: (3,3). rotation matrix\n     Returns:\n         x: yaw\n         y: pitch\n         z: roll\n     \"\"\"\n    R_x = np.array([[1,         0,                  0         ],\n\n                    [0,         cos(theta[1]), -sin(theta[1]) ],\n\n                    [0,         sin(theta[1]), cos(theta[1])  ]\n\n                    ])\n\n \n\n    R_y = np.array([[cos(theta[0]),    0,      sin(-theta[0])  ],\n\n                    [0,                     1,      0         ],\n\n                    [-sin(-theta[0]),   0,      cos(theta[0])  ]\n\n                    ])\n\n \n\n    R_z = np.array([[cos(theta[2]),    -sin(theta[2]),    0],\n\n                    [sin(theta[2]),    cos(theta[2]),     0],\n\n                    [0,                     0,            1]\n\n                    ])\n\n \n\n    R = np.dot(R_z, np.dot( R_y, R_x ))\n\n \n\n    return R\n\ndef angle2matrix_3ddfa(angles):\n    ''' get rotation matrix from three rotation angles(radian). The same as in 3DDFA.\n    Args:\n        angles: [3,]. x, y, z angles\n        x: pitch.\n        y: yaw. \n        z: roll. \n    Returns:\n        R: 3x3. 
rotation matrix.\n    '''\n    # x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2])\n    x, y, z = angles[1], angles[0], angles[2]\n    \n    # x\n    Rx=np.array([[1,      0,       0],\n                 [0, cos(x),  sin(x)],\n                 [0, -sin(x),   cos(x)]])\n    # y\n    Ry=np.array([[ cos(y), 0, -sin(y)],\n                 [      0, 1,      0],\n                 [sin(y), 0, cos(y)]])\n    # z\n    Rz=np.array([[cos(z), sin(z), 0],\n                 [-sin(z),  cos(z), 0],\n                 [     0,       0, 1]])\n    R = Rx.dot(Ry).dot(Rz)\n    return R.astype(np.float32)\n\ndef calc_pose(param):\n    P = param[:12].reshape(3, -1)  # camera matrix\n    s, R, t3d = P2sRt(P)\n    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale\n    pose = matrix2angle(R)\n    pose = [p * 180 / np.pi for p in pose]\n\n    return P, pose\n\n\ndef build_camera_box(rear_size=90):\n    point_3d = []\n    rear_depth = 0\n    point_3d.append((-rear_size, -rear_size, rear_depth))\n    point_3d.append((-rear_size, rear_size, rear_depth))\n    point_3d.append((rear_size, rear_size, rear_depth))\n    point_3d.append((rear_size, -rear_size, rear_depth))\n    point_3d.append((-rear_size, -rear_size, rear_depth))\n\n    front_size = int(4 / 3 * rear_size)\n    front_depth = int(4 / 3 * rear_size)\n    point_3d.append((-front_size, -front_size, front_depth))\n    point_3d.append((-front_size, front_size, front_depth))\n    point_3d.append((front_size, front_size, front_depth))\n    point_3d.append((front_size, -front_size, front_depth))\n    point_3d.append((-front_size, -front_size, front_depth))\n    point_3d = np.array(point_3d, dtype=np.float32).reshape(-1, 3)\n\n    return point_3d\n\n\ndef plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2):\n    \"\"\" Draw a 3D box as annotation of pose.\n    Ref:https://github.com/yinguobing/head-pose-estimation/blob/master/pose_estimator.py\n    Args:\n        img: the input image\n        P: (3, 4). 
 Affine Camera Matrix.\n        kpt: (2, 68) or (3, 68)\n    \"\"\"\n    llength = calc_hypotenuse(ver)\n    point_3d = build_camera_box(llength)\n    # Map to 2d image points\n    point_3d_homo = np.hstack((point_3d, np.ones([point_3d.shape[0], 1])))  # n x 4\n    point_2d = point_3d_homo.dot(P.T)[:, :2]\n\n    point_2d[:, 1] = - point_2d[:, 1]\n    point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d[:4, :2], 0) + np.mean(ver[:2, :27], 1)  # recenter the box on the mean of landmarks 0-27\n    point_2d = np.int32(point_2d.reshape(-1, 2))\n\n    # Draw all the lines\n    cv2.polylines(img, [point_2d], True, color, line_width, cv2.LINE_AA)\n    cv2.line(img, tuple(point_2d[1]), tuple(\n        point_2d[6]), color, line_width, cv2.LINE_AA)\n    cv2.line(img, tuple(point_2d[2]), tuple(\n        point_2d[7]), color, line_width, cv2.LINE_AA)\n    cv2.line(img, tuple(point_2d[3]), tuple(\n        point_2d[8]), color, line_width, cv2.LINE_AA)\n\n    return img\n\n\ndef viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None):\n    for param, ver in zip(param_lst, ver_lst):\n        P, pose = calc_pose(param)\n        img = plot_pose_box(img, P, ver)\n        # print(P[:, :3])\n        # print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')\n\n    if wfp is not None:\n        cv2.imwrite(wfp, img)\n        print(f'Save visualization result to {wfp}')\n\n    if show_flag:\n        plot_image(img)\n\n    return img\n\ndef pose_6(param):\n    P = param[:12].reshape(3, -1)  # camera matrix\n    s, R, t3d = P2sRt(P)\n    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale\n    pose = matrix2angle(R)  # Euler angles recovered from R\n    # print(t3d)\n    R1 = angle2matrix(pose)\n    # print(R)\n    # print(R1)\n    pose = [p * 180 / np.pi for p in pose]\n    \n    return s, pose, t3d, P  # scale, angles (degrees), 3d translation, camera matrix\n\n\ndef smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=None, wnp = None):\n    for param, ver in zip(param_lst, ver_lst):\n        t3d = np.array([pose_new[4],pose_new[5],pose_new[6]])\n        \n        theta = np.array([pose_new[0],pose_new[1],pose_new[2]])\n        theta = [p * np.pi / 180 for p in theta]\n        R = angle2matrix(theta)\n        P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) \n        img = plot_pose_box(img, P, ver)\n    #    print(P,P.shape,t3d)\n        # print(P,pose_new)\n        # print(f'yaw: {theta[0]:.1f}, pitch: {theta[1]:.1f}, roll: {theta[2]:.1f}')\n        all_pose = [0]\n        all_pose = np.array(all_pose)\n\n    if wfp is not None:\n        cv2.imwrite(wfp, img)\n        print(f'Save visualization result to {wfp}')\n        \n    if wnp is not None:\n        np.save(wnp, all_pose)\n        print(f'Save pose result to {wnp}')\n        \n    if show_flag:\n        plot_image(img)\n\n    return img\n\n\ndef get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp = None):\n    for param, ver in zip(param_lst, ver_lst):\n        s, pose, t3d, P = pose_6(param)\n        img_1 = plot_pose_box(img.copy(), P, ver)\n    #    print(P,P.shape,t3d)\n        # print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')\n        all_pose = [pose[0],pose[1],pose[2],s,t3d[0],t3d[1],t3d[2]]\n        all_pose = np.array(all_pose)\n\n    # if wfp is not None:\n    #     cv2.imwrite(wfp, img_1)\n    #     print(f'Save visualization result to {wfp}')\n        \n    # if wnp is not None:\n    #     np.save(wnp, all_pose)\n    #     print(f'Save pose result to {wnp}')\n        \n    if show_flag:\n        plot_image(img)\n\n
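    # all_pose = [yaw, pitch, roll, s, tx, ty, tz]: Euler angles in degrees (matrix2angle convention)\n    # plus the scale and 3-D translation from P2sRt; this 7-dim vector is what\n    # demo_pose_extract_2d_lmk_img.py reshapes to (1, 7) and saves as init_pose.npy.\n 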
   return all_pose\n"
  },
  {
    "path": "extract_init_states/readme.md",
    "content": "# README\n\nThe `extract_init_state` is mainly from [3DDFA_v2](https://github.com/cleardusk/3DDFA_V2) with minor revision. We remove the `Sim3DR` in original repo.\nWe add or revise the file of `extract_init_states\\demo_pose_extract_2d_lmk_img.py`, `extract_init_states\\utils\\pose.py`.\n\n## Linux\nLinux user can follow the installation process on [3DDFA_v2](https://github.com/cleardusk/3DDFA_V2)\n\n## Win\nFor Windows user, be aware to these tips:\n1. Installing gcc\n2. In `extract_init_states\\FaceBoxes\\utils\\build.py`, you need comment line 47\n3. Revise the `extract_init_states\\FaceBoxes\\utils\\nms\\cpu_nms.pyx` following [comment](https://github.com/cleardusk/3DDFA_V2/issues/12#issuecomment-697479173).\n4. Run the command in sh script line by line manually\n\n\n"
  },
  {
    "path": "extract_init_states/utils/__init__.py",
    "content": ""
  },
  {
    "path": "extract_init_states/utils/asset/.gitignore",
    "content": "*.so\n"
  },
  {
    "path": "extract_init_states/utils/asset/build_render_ctypes.sh",
    "content": "gcc -shared -Wall -O3 render.c -o render.so -fPIC"
  },
  {
    "path": "extract_init_states/utils/asset/render.c",
    "content": "#include <math.h>\n#include <stdlib.h>\n#include <stdio.h>\n\n#define max(x, y) (((x) > (y)) ? (x) : (y))\n#define min(x, y) (((x) < (y)) ? (x) : (y))\n#define clip(_x, _min, _max) min(max(_x, _min), _max)\n\nstruct Tuple3D\n{\n    float x;\n    float y;\n    float z;\n};\n\nvoid _render(const int *triangles,\n             const int ntri,\n             const float *light,\n             const float *directional,\n             const float *ambient,\n             const float *vertices,\n             const int nver,\n             unsigned char *image,\n             const int h, const int w)\n{\n    int tri_p0_ind, tri_p1_ind, tri_p2_ind;\n    int color_index;\n    float dot00, dot01, dot11, dot02, dot12;\n    float cos_sum, det;\n\n    struct Tuple3D p0, p1, p2;\n    struct Tuple3D v0, v1, v2;\n    struct Tuple3D p, start, end;\n\n    struct Tuple3D ver_max = {-1.0e8, -1.0e8, -1.0e8};\n    struct Tuple3D ver_min = {1.0e8, 1.0e8, 1.0e8};\n    struct Tuple3D ver_mean = {0.0, 0.0, 0.0};\n\n    float *ver_normal = (float *)calloc(3 * nver, sizeof(float));\n    float *colors = (float *)malloc(3 * nver * sizeof(float));\n    float *depth_buffer = (float *)calloc(h * w, sizeof(float));\n\n    for (int i = 0; i < ntri; i++)\n    {\n        tri_p0_ind = triangles[3 * i];\n        tri_p1_ind = triangles[3 * i + 1];\n        tri_p2_ind = triangles[3 * i + 2];\n\n        // counter clockwise order\n        start.x = vertices[tri_p1_ind] - vertices[tri_p0_ind];\n        start.y = vertices[tri_p1_ind + 1] - vertices[tri_p0_ind + 1];\n        start.z = vertices[tri_p1_ind + 2] - vertices[tri_p0_ind + 2];\n\n        end.x = vertices[tri_p2_ind] - vertices[tri_p0_ind];\n        end.y = vertices[tri_p2_ind + 1] - vertices[tri_p0_ind + 1];\n        end.z = vertices[tri_p2_ind + 2] - vertices[tri_p0_ind + 2];\n\n        p.x = start.y * end.z - start.z * end.y;\n        p.y = start.z * end.x - start.x * end.z;\n        p.z = start.x * end.y - start.y * end.x;\n\n        ver_normal[tri_p0_ind] += p.x;\n        ver_normal[tri_p1_ind] += p.x;\n        ver_normal[tri_p2_ind] += p.x;\n\n        ver_normal[tri_p0_ind + 1] += p.y;\n        ver_normal[tri_p1_ind + 1] += p.y;\n        ver_normal[tri_p2_ind + 1] += p.y;\n\n        ver_normal[tri_p0_ind + 2] += p.z;\n        ver_normal[tri_p1_ind + 2] += p.z;\n        ver_normal[tri_p2_ind + 2] += p.z;\n    }\n\n    for (int i = 0; i < nver; ++i)\n    {\n        p.x = ver_normal[3 * i];\n        p.y = ver_normal[3 * i + 1];\n        p.z = ver_normal[3 * i + 2];\n\n        det = sqrt(p.x * p.x + p.y * p.y + p.z * p.z);\n        if (det <= 0)\n            det = 1e-6;\n\n        ver_normal[3 * i] /= det;\n        ver_normal[3 * i + 1] /= det;\n        ver_normal[3 * i + 2] /= det;\n\n        ver_mean.x += p.x;\n        ver_mean.y += p.y;\n        ver_mean.z += p.z;\n\n        ver_max.x = max(ver_max.x, p.x);\n        ver_max.y = max(ver_max.y, p.y);\n        ver_max.z = max(ver_max.z, p.z);\n\n        ver_min.x = min(ver_min.x, p.x);\n        ver_min.y = min(ver_min.y, p.y);\n        ver_min.z = min(ver_min.z, p.z);\n    }\n\n    ver_mean.x /= nver;\n    ver_mean.y /= nver;\n    ver_mean.z /= nver;\n\n    for (int i = 0; i < nver; ++i)\n    {\n        colors[3 * i] = vertices[3 * i];\n        colors[3 * i + 1] = vertices[3 * i + 1];\n        colors[3 * i + 2] = vertices[3 * i + 2];\n\n        colors[3 * i] -= ver_mean.x;\n        colors[3 * i] /= ver_max.x - ver_min.x;\n\n        colors[3 * i + 1] -= ver_mean.y;\n        colors[3 * i + 1] /= ver_max.y - 
ver_min.y;\n\n        colors[3 * i + 2] -= ver_mean.z;\n        colors[3 * i + 2] /= ver_max.z - ver_min.z;\n\n        p.x = light[0] - colors[3 * i];\n        p.y = light[1] - colors[3 * i + 1];\n        p.z = light[2] - colors[3 * i + 2];\n\n        det = sqrt(p.x * p.x + p.y * p.y + p.z * p.z);\n        if (det <= 0)\n            det = 1e-6;\n\n        colors[3 * i] = p.x / det;\n        colors[3 * i + 1] = p.y / det;\n        colors[3 * i + 2] = p.z / det;\n\n        colors[3 * i] *= ver_normal[3 * i];\n        colors[3 * i + 1] *= ver_normal[3 * i + 1];\n        colors[3 * i + 2] *= ver_normal[3 * i + 2];\n\n        cos_sum = colors[3 * i] + colors[3 * i + 1] + colors[3 * i + 2];\n\n        colors[3 * i] = clip(cos_sum * directional[0] + ambient[0], 0, 1);\n        colors[3 * i + 1] = clip(cos_sum * directional[1] + ambient[1], 0, 1);\n        colors[3 * i + 2] = clip(cos_sum * directional[2] + ambient[2], 0, 1);\n    }\n\n    for (int i = 0; i < ntri; ++i)\n    {\n        tri_p0_ind = triangles[3 * i];\n        tri_p1_ind = triangles[3 * i + 1];\n        tri_p2_ind = triangles[3 * i + 2];\n\n        p0.x = vertices[tri_p0_ind];\n        p0.y = vertices[tri_p0_ind + 1];\n        p0.z = vertices[tri_p0_ind + 2];\n\n        p1.x = vertices[tri_p1_ind];\n        p1.y = vertices[tri_p1_ind + 1];\n        p1.z = vertices[tri_p1_ind + 2];\n\n        p2.x = vertices[tri_p2_ind];\n        p2.y = vertices[tri_p2_ind + 1];\n        p2.z = vertices[tri_p2_ind + 2];\n\n        start.x = max(ceil(min(p0.x, min(p1.x, p2.x))), 0);\n        end.x = min(floor(max(p0.x, max(p1.x, p2.x))), w - 1);\n\n        start.y = max(ceil(min(p0.y, min(p1.y, p2.y))), 0);\n        end.y = min(floor(max(p0.y, max(p1.y, p2.y))), h - 1);\n\n        if (end.x < start.x || end.y < start.y)\n            continue;\n\n        v0.x = p2.x - p0.x;\n        v0.y = p2.y - p0.y;\n        v1.x = p1.x - p0.x;\n        v1.y = p1.y - p0.y;\n\n        // dot products np.dot(v0.T, v0)\n        dot00 = v0.x * v0.x + v0.y * v0.y;\n        dot01 = v0.x * v1.x + v0.y * v1.y;\n        dot11 = v1.x * v1.x + v1.y * v1.y;\n\n        // barycentric coordinates\n        start.z = dot00 * dot11 - dot01 * dot01;\n        if (start.z != 0)\n            start.z = 1 / start.z;\n\n        for (p.y = start.y; p.y <= end.y; p.y += 1.0)\n        {\n            for (p.x = start.x; p.x <= end.x; p.x += 1.0)\n            {\n                v2.x = p.x - p0.x;\n                v2.y = p.y - p0.y;\n\n                dot02 = v0.x * v2.x + v0.y * v2.y;\n                dot12 = v1.x * v2.x + v1.y * v2.y;\n\n                v2.z = (dot11 * dot02 - dot01 * dot12) * start.z;\n                v1.z = (dot00 * dot12 - dot01 * dot02) * start.z;\n                v0.z = 1 - v2.z - v1.z;\n\n                // judge is_point_in_tri by below line of code\n                if (v2.z > 0 && v1.z > 0 && v0.z > 0)\n                {\n                    p.z = v0.z * p0.z + v1.z * p1.z + v2.z * p2.z;\n                    color_index = p.y * w + p.x;\n\n                    if (p.z > depth_buffer[color_index])\n                    {\n                        end.z = v0.z * colors[tri_p0_ind];\n                        end.z += v1.z * colors[tri_p1_ind];\n                        end.z += v2.z * colors[tri_p2_ind];\n                        image[3 * color_index] = end.z * 255;\n\n                        end.z = v0.z * colors[tri_p0_ind + 1];\n                        end.z += v1.z * colors[tri_p1_ind + 1];\n                        end.z += v2.z * colors[tri_p2_ind + 1];\n                  
      image[3 * color_index + 1] = end.z * 255;\n\n                        end.z = v0.z * colors[tri_p0_ind + 2];\n                        end.z += v1.z * colors[tri_p1_ind + 2];\n                        end.z += v2.z * colors[tri_p2_ind + 2];\n                        image[3 * color_index + 2] = end.z * 255;\n\n                        depth_buffer[color_index] = p.z;\n                    }\n                }\n            }\n        }\n    }\n\n    free(depth_buffer);\n    free(colors);\n    free(ver_normal);\n}\n"
  },
  {
    "path": "extract_init_states/utils/depth.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport sys\n\nsys.path.append('..')\n\nimport cv2\nimport numpy as np\n\nfrom Sim3DR import rasterize\nfrom utils.functions import plot_image\nfrom .tddfa_util import _to_ctype\n\n\ndef depth(img, ver_lst, tri, show_flag=False, wfp=None, with_bg_flag=True):\n    if with_bg_flag:\n        overlap = img.copy()\n    else:\n        overlap = np.zeros_like(img)\n\n    for ver_ in ver_lst:\n        ver = _to_ctype(ver_.T)  # transpose\n\n        z = ver[:, 2]\n        z_min, z_max = min(z), max(z)\n\n        z = (z - z_min) / (z_max - z_min)\n\n        # expand\n        z = np.repeat(z[:, np.newaxis], 3, axis=1)\n\n        overlap = rasterize(ver, tri, z, bg=overlap)\n\n    if wfp is not None:\n        cv2.imwrite(wfp, overlap)\n        print(f'Save visualization result to {wfp}')\n\n    if show_flag:\n        plot_image(overlap)\n\n    return overlap\n"
  },
  {
    "path": "extract_init_states/utils/functions.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport numpy as np\nimport cv2\nfrom math import sqrt\nimport matplotlib.pyplot as plt\n\nRED = (0, 0, 255)\nGREEN = (0, 255, 0)\nBLUE = (255, 0, 0)\n\n\ndef get_suffix(filename):\n    \"\"\"a.jpg -> jpg\"\"\"\n    pos = filename.rfind('.')\n    if pos == -1:\n        return ''\n    return filename[pos:]\n\n\ndef crop_img(img, roi_box):\n    h, w = img.shape[:2]\n\n    sx, sy, ex, ey = [int(round(_)) for _ in roi_box]\n    dh, dw = ey - sy, ex - sx\n    if len(img.shape) == 3:\n        res = np.zeros((dh, dw, 3), dtype=np.uint8)\n    else:\n        res = np.zeros((dh, dw), dtype=np.uint8)\n    if sx < 0:\n        sx, dsx = 0, -sx\n    else:\n        dsx = 0\n\n    if ex > w:\n        ex, dex = w, dw - (ex - w)\n    else:\n        dex = dw\n\n    if sy < 0:\n        sy, dsy = 0, -sy\n    else:\n        dsy = 0\n\n    if ey > h:\n        ey, dey = h, dh - (ey - h)\n    else:\n        dey = dh\n\n    res[dsy:dey, dsx:dex] = img[sy:ey, sx:ex]\n    return res\n\n\ndef calc_hypotenuse(pts):\n    bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])]\n    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]\n    radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2\n    bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius]\n    llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2)\n    return llength / 3\n\n\ndef parse_roi_box_from_landmark(pts):\n    \"\"\"calc roi box from landmark\"\"\"\n    bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])]\n    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]\n    radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2\n    bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius]\n\n    llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2)\n    center_x = (bbox[2] + bbox[0]) / 2\n    center_y = (bbox[3] + bbox[1]) / 2\n\n    roi_box = [0] * 4\n    roi_box[0] = center_x - llength / 2\n    roi_box[1] = center_y - llength / 2\n    roi_box[2] = roi_box[0] + llength\n    roi_box[3] = roi_box[1] + llength\n\n    return roi_box\n\n\ndef parse_roi_box_from_bbox(bbox):\n    left, top, right, bottom = bbox[:4]\n    old_size = (right - left + bottom - top) / 2\n    center_x = right - (right - left) / 2.0\n    center_y = bottom - (bottom - top) / 2.0 + old_size * 0.14\n    size = int(old_size * 1.58)\n\n    roi_box = [0] * 4\n    roi_box[0] = center_x - size / 2\n    roi_box[1] = center_y - size / 2\n    roi_box[2] = roi_box[0] + size\n    roi_box[3] = roi_box[1] + size\n\n    return roi_box\n\n\ndef plot_image(img):\n    height, width = img.shape[:2]\n    plt.figure(figsize=(12, height / width * 12))\n\n    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)\n    plt.axis('off')\n\n    plt.imshow(img[..., ::-1])\n    plt.show()\n\n\ndef draw_landmarks(img, pts, style='fancy', wfp=None, show_flag=False, **kwargs):\n    \"\"\"Draw landmarks using matplotlib\"\"\"\n    height, width = img.shape[:2]\n    plt.figure(figsize=(12, height / width * 12))\n    plt.imshow(img[..., ::-1])\n    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)\n    plt.axis('off')\n\n    dense_flag = kwargs.get('dense_flag')\n\n    if not type(pts) in [tuple, list]:\n        pts = [pts]\n    for i in range(len(pts)):\n        if dense_flag:\n            plt.plot(pts[i][0, ::6], pts[i][1, ::6], 'o', markersize=0.4, color='c', alpha=0.7)\n        else:\n            alpha = 0.8\n         
   markersize = 4\n            lw = 1.5\n            color = kwargs.get('color', 'w')\n            markeredgecolor = kwargs.get('markeredgecolor', 'black')\n\n            nums = [0, 17, 22, 27, 31, 36, 42, 48, 60, 68]\n\n            # close eyes and mouths\n            plot_close = lambda i1, i2: plt.plot([pts[i][0, i1], pts[i][0, i2]], [pts[i][1, i1], pts[i][1, i2]],\n                                                 color=color, lw=lw, alpha=alpha - 0.1)\n            plot_close(41, 36)\n            plot_close(47, 42)\n            plot_close(59, 48)\n            plot_close(67, 60)\n\n            for ind in range(len(nums) - 1):\n                l, r = nums[ind], nums[ind + 1]\n                plt.plot(pts[i][0, l:r], pts[i][1, l:r], color=color, lw=lw, alpha=alpha - 0.1)\n\n                plt.plot(pts[i][0, l:r], pts[i][1, l:r], marker='o', linestyle='None', markersize=markersize,\n                         color=color,\n                         markeredgecolor=markeredgecolor, alpha=alpha)\n    if wfp is not None:\n        plt.savefig(wfp, dpi=150)\n        print(f'Save visualization result to {wfp}')\n\n    if show_flag:\n        plt.show()\n\n\ndef cv_draw_landmark(img_ori, pts, box=None, color=GREEN, size=1):\n    img = img_ori.copy()\n    n = pts.shape[1]\n    if n <= 106:\n        for i in range(n):\n            cv2.circle(img, (int(round(pts[0, i])), int(round(pts[1, i]))), size, color, -1)\n    else:\n        sep = 1\n        for i in range(0, n, sep):\n            cv2.circle(img, (int(round(pts[0, i])), int(round(pts[1, i]))), size, color, 1)\n\n    if box is not None:\n        left, top, right, bottom = np.round(box).astype(np.int32)\n        left_top = (left, top)\n        right_top = (right, top)\n        right_bottom = (right, bottom)\n        left_bottom = (left, bottom)\n        cv2.line(img, left_top, right_top, BLUE, 1, cv2.LINE_AA)\n        cv2.line(img, right_top, right_bottom, BLUE, 1, cv2.LINE_AA)\n        cv2.line(img, right_bottom, left_bottom, BLUE, 1, cv2.LINE_AA)\n        cv2.line(img, left_bottom, left_top, BLUE, 1, cv2.LINE_AA)\n\n    return img\n\ndef calculate_bbox(img, lmk):\n    lmk = lmk.transpose(1,0)\n    # point_3d_homo = np.hstack((lmk, np.ones([lmk.shape[0], 1])))  # n x 4\n    # point_2d = point_3d_homo.dot(P.T)[:, :2]\n\n    # point_2d[:, 1] = - point_2d[:, 1]\n    # point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d, 0) + np.mean(lmk[:27,:2], 0)  # lmk 0-27 face contour\n    point_2d = lmk[:, :2]\n    point_2d = np.int32(point_2d.reshape(-1, 2))\n    H = img.shape[0]\n    W = img.shape[1]\n    x_min, x_max = point_2d[:, 0].min(), point_2d[:, 0].max()\n    y_min, y_max = point_2d[:, 1].min(), point_2d[:, 1].max()\n    # cv2.polylines(img, [point_2d], True, (40, 255, 0), 2, cv2.LINE_AA)\n    # points_list = [(p[0], p[1]) for p in point_2d]\n    # for p in points_list:\n    #     cv2.circle(img, p, 1, (40, 255, 0), -1)\n    # return img\n\n    return [x_min, x_max, y_min, y_max, H, W]\n    \ndef calculate_eye(lmk):\n    lmk = lmk.transpose(1,0)\n    leye_upper = lmk[43]\n    leye_lower = lmk[47]\n    leye_left = lmk[45]\n    leye_right = lmk[42]\n    reye_upper = lmk[37]\n    reye_lower = lmk[41]\n    reye_left = lmk[39]\n    reye_right = lmk[36]\n\n    left_ratio = np.linalg.norm(leye_upper - leye_lower, 2) / np.linalg.norm(leye_left - leye_right, 2)\n    right_ratio = np.linalg.norm(reye_upper - reye_lower, 2) / np.linalg.norm(reye_left - reye_right, 2)\n\n    return left_ratio, right_ratio"
  },
  {
    "path": "extract_init_states/utils/io.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport os\nimport numpy as np\nimport torch\nimport pickle\n\n\ndef mkdir(d):\n    os.makedirs(d, exist_ok=True)\n\n\ndef _get_suffix(filename):\n    \"\"\"a.jpg -> jpg\"\"\"\n    pos = filename.rfind('.')\n    if pos == -1:\n        return ''\n    return filename[pos + 1:]\n\n\ndef _load(fp):\n    suffix = _get_suffix(fp)\n    if suffix == 'npy':\n        return np.load(fp)\n    elif suffix == 'pkl':\n        return pickle.load(open(fp, 'rb'))\n\n\ndef _dump(wfp, obj):\n    suffix = _get_suffix(wfp)\n    if suffix == 'npy':\n        np.save(wfp, obj)\n    elif suffix == 'pkl':\n        pickle.dump(obj, open(wfp, 'wb'))\n    else:\n        raise Exception('Unknown Type: {}'.format(suffix))\n\n\ndef _load_tensor(fp, mode='cpu'):\n    if mode.lower() == 'cpu':\n        return torch.from_numpy(_load(fp))\n    elif mode.lower() == 'gpu':\n        return torch.from_numpy(_load(fp)).cuda()\n\n\ndef _tensor_to_cuda(x):\n    if x.is_cuda:\n        return x\n    else:\n        return x.cuda()\n\n\ndef _load_gpu(fp):\n    return torch.from_numpy(_load(fp)).cuda()\n\n\n_load_cpu = _load\n_numpy_to_tensor = lambda x: torch.from_numpy(x)\n_tensor_to_numpy = lambda x: x.numpy()\n_numpy_to_cuda = lambda x: _tensor_to_cuda(torch.from_numpy(x))\n_cuda_to_tensor = lambda x: x.cpu()\n_cuda_to_numpy = lambda x: x.cpu().numpy()\n"
  },
  {
    "path": "extract_init_states/utils/onnx.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport sys\n\nsys.path.append('..')\n\nimport torch\nimport models\nfrom utils.tddfa_util import load_model\n\n\ndef convert_to_onnx(**kvs):\n    # 1. load model\n    size = kvs.get('size', 120)\n    model = getattr(models, kvs.get('arch'))(\n        num_classes=kvs.get('num_params', 62),\n        widen_factor=kvs.get('widen_factor', 1),\n        size=size,\n        mode=kvs.get('mode', 'small')\n    )\n    checkpoint_fp = kvs.get('checkpoint_fp')\n    model = load_model(model, checkpoint_fp)\n    model.eval()\n\n    # 2. convert\n    batch_size = 1\n    dummy_input = torch.randn(batch_size, 3, size, size)\n    wfp = checkpoint_fp.replace('.pth', '.onnx')\n    torch.onnx.export(\n        model,\n        (dummy_input, ),\n        wfp,\n        input_names=['input'],\n        output_names=['output'],\n        do_constant_folding=True\n    )\n    print(f'Convert {checkpoint_fp} to {wfp} done.')\n    return wfp\n"
  },
  {
    "path": "extract_init_states/utils/pncc.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport sys\n\nsys.path.append('..')\n\nimport cv2\nimport numpy as np\nimport os.path as osp\n\nfrom Sim3DR import rasterize\nfrom utils.functions import plot_image\nfrom utils.io import _load, _dump\nfrom utils.tddfa_util import _to_ctype\n\nmake_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)\n\n\ndef calc_ncc_code():\n    from bfm import bfm\n\n    # formula: ncc_d = ( u_d - min(u_d) ) / ( max(u_d) - min(u_d) ), d = {r, g, b}\n    u = bfm.u\n    u = u.reshape(3, -1, order='F')\n\n    for i in range(3):\n        u[i] = (u[i] - u[i].min()) / (u[i].max() - u[i].min())\n\n    _dump('../configs/ncc_code.npy', u)\n\n\ndef pncc(img, ver_lst, tri, show_flag=False, wfp=None, with_bg_flag=True):\n    ncc_code = _load(make_abs_path('../configs/ncc_code.npy'))\n\n    if with_bg_flag:\n        overlap = img.copy()\n    else:\n        overlap = np.zeros_like(img)\n\n    # rendering pncc\n    for ver_ in ver_lst:\n        ver = _to_ctype(ver_.T)  # transpose\n        overlap = rasterize(ver, tri, ncc_code.T, bg=overlap)  # m x 3\n\n    if wfp is not None:\n        cv2.imwrite(wfp, overlap)\n        print(f'Save visualization result to {wfp}')\n\n    if show_flag:\n        plot_image(overlap)\n\n    return overlap\n\n\ndef main():\n    # `configs/ncc_code.npy` is generated by `calc_nnc_code` function\n    # calc_ncc_code()\n    pass\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "extract_init_states/utils/pose.py",
    "content": "# coding: utf-8\n\n\"\"\"\nReference: https://github.com/YadiraF/PRNet/blob/master/utils/estimate_pose.py\n\nCalculating pose from the output 3DMM parameters, you can also try to use solvePnP to perform estimation\n\"\"\"\n\n__author__ = 'cleardusk'\n\nimport cv2\nimport numpy as np\nfrom math import cos, sin, atan2, asin, sqrt\n\nfrom .functions import calc_hypotenuse, plot_image\n\n\ndef P2sRt(P):\n    \"\"\" decompositing camera matrix P.\n    Args:\n        P: (3, 4). Affine Camera Matrix.\n    Returns:\n        s: scale factor.\n        R: (3, 3). rotation matrix.\n        t2d: (2,). 2d translation.\n    \"\"\"\n    t3d = P[:, 3] # shift\n    R1 = P[0:1, :3]\n    R2 = P[1:2, :3]\n    s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0  #\n    r1 = R1 / np.linalg.norm(R1)\n    r2 = R2 / np.linalg.norm(R2)\n    r3 = np.cross(r1, r2) # r1r2，r3\n\n    R = np.concatenate((r1, r2, r3), 0)  #  r 1-3R （）\n    return s, R, t3d\n\n\ndef matrix2angle(R):\n    \"\"\" compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf\n    refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv\n    todo: check and debug\n     Args:\n         R: (3,3). rotation matrix\n     Returns:\n         x: yaw\n         y: pitch\n         z: roll\n     \"\"\"\n    if R[2, 0] > 0.998:\n        z = 0\n        x = np.pi / 2\n        y = z + atan2(-R[0, 1], -R[0, 2])\n    elif R[2, 0] < -0.998:\n        z = 0\n        x = -np.pi / 2\n        y = -z + atan2(R[0, 1], R[0, 2])\n    else:\n        x = asin(R[2, 0])\n        y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x))\n        z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x))\n\n    return x, y, z\n\ndef angle2matrix(theta):\n    \"\"\" compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf\n    refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv\n    todo: check and debug\n     Args:\n         R: (3,3). rotation matrix\n     Returns:\n         x: yaw\n         y: pitch\n         z: roll\n     \"\"\"\n    R_x = np.array([[1,         0,                  0         ],\n\n                    [0,         cos(theta[1]), -sin(theta[1]) ],\n\n                    [0,         sin(theta[1]), cos(theta[1])  ]\n\n                    ])\n\n \n\n    R_y = np.array([[cos(theta[0]),    0,      sin(-theta[0])  ],\n\n                    [0,                     1,      0         ],\n\n                    [-sin(-theta[0]),   0,      cos(theta[0])  ]\n\n                    ])\n\n \n\n    R_z = np.array([[cos(theta[2]),    -sin(theta[2]),    0],\n\n                    [sin(theta[2]),    cos(theta[2]),     0],\n\n                    [0,                     0,            1]\n\n                    ])\n\n \n\n    R = np.dot(R_z, np.dot( R_y, R_x ))\n\n \n\n    return R\n\ndef angle2matrix_3ddfa(angles):\n    ''' get rotation matrix from three rotation angles(radian). The same as in 3DDFA.\n    Args:\n        angles: [3,]. x, y, z angles\n        x: pitch.\n        y: yaw. \n        z: roll. \n    Returns:\n        R: 3x3. 
rotation matrix.\n    '''\n    # x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2])\n    x, y, z = angles[1], angles[0], angles[2]\n    \n    # x\n    Rx=np.array([[1,      0,       0],\n                 [0, cos(x),  sin(x)],\n                 [0, -sin(x),   cos(x)]])\n    # y\n    Ry=np.array([[ cos(y), 0, -sin(y)],\n                 [      0, 1,      0],\n                 [sin(y), 0, cos(y)]])\n    # z\n    Rz=np.array([[cos(z), sin(z), 0],\n                 [-sin(z),  cos(z), 0],\n                 [     0,       0, 1]])\n    R = Rx.dot(Ry).dot(Rz)\n    return R.astype(np.float32)\n\ndef calc_pose(param):\n    P = param[:12].reshape(3, -1)  # camera matrix\n    s, R, t3d = P2sRt(P)\n    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale\n    pose = matrix2angle(R)\n    pose = [p * 180 / np.pi for p in pose]\n\n    return P, pose\n\n\ndef build_camera_box(rear_size=90):\n    point_3d = []\n    rear_depth = 0\n    point_3d.append((-rear_size, -rear_size, rear_depth))\n    point_3d.append((-rear_size, rear_size, rear_depth))\n    point_3d.append((rear_size, rear_size, rear_depth))\n    point_3d.append((rear_size, -rear_size, rear_depth))\n    point_3d.append((-rear_size, -rear_size, rear_depth))\n\n    front_size = int(4 / 3 * rear_size)\n    front_depth = int(4 / 3 * rear_size)\n    point_3d.append((-front_size, -front_size, front_depth))\n    point_3d.append((-front_size, front_size, front_depth))\n    point_3d.append((front_size, front_size, front_depth))\n    point_3d.append((front_size, -front_size, front_depth))\n    point_3d.append((-front_size, -front_size, front_depth))\n    point_3d = np.array(point_3d, dtype=np.float32).reshape(-1, 3)\n\n    return point_3d\n\n\ndef plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2):\n    \"\"\" Draw a 3D box as annotation of pose.\n    Ref:https://github.com/yinguobing/head-pose-estimation/blob/master/pose_estimator.py\n    Args:\n        img: the input image\n        P: (3, 4). 
Affine Camera Matrix.\n        kpt: (2, 68) or (3, 68)\n    \"\"\"\n    llength = calc_hypotenuse(ver)\n    point_3d = build_camera_box(llength)\n    # Map to 2d image points\n    point_3d_homo = np.hstack((point_3d, np.ones([point_3d.shape[0], 1])))  # n x 4\n    point_2d = point_3d_homo.dot(P.T)[:, :2]\n\n    point_2d[:, 1] = - point_2d[:, 1]\n    point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d[:4, :2], 0) + np.mean(ver[:2, :27], 1)  # lmk 0-27 \n    point_2d = np.int32(point_2d.reshape(-1, 2))\n\n    # Draw all the lines\n    cv2.polylines(img, [point_2d], True, color, line_width, cv2.LINE_AA)\n    cv2.line(img, tuple(point_2d[1]), tuple(\n        point_2d[6]), color, line_width, cv2.LINE_AA)\n    cv2.line(img, tuple(point_2d[2]), tuple(\n        point_2d[7]), color, line_width, cv2.LINE_AA)\n    cv2.line(img, tuple(point_2d[3]), tuple(\n        point_2d[8]), color, line_width, cv2.LINE_AA)\n\n    return img\n\n\ndef viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None):\n    for param, ver in zip(param_lst, ver_lst):\n        P, pose = calc_pose(param)\n        img = plot_pose_box(img, P, ver)\n        # print(P[:, :3])\n        # print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')\n\n    if wfp is not None:\n        cv2.imwrite(wfp, img)\n        print(f'Save visualization result to {wfp}')\n\n    if show_flag:\n        plot_image(img)\n\n    return img\n\ndef pose_6(param):\n    P = param[:12].reshape(3, -1)  # camera matrix\n    s, R, t3d = P2sRt(P)\n    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale\n    pose = matrix2angle(R)  # Convert the rotation matrix R to Euler angle form to obtain the pose.\n    # print(t3d)\n    R1 = angle2matrix(pose)\n    # print(R)\n    # print(R1)\n    pose = [p * 180 / np.pi for p in pose]\n    \n    return s, pose, t3d, P  # s(scale)、R(roate)、t3d(shift)\n\n\ndef smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=None, wnp = None):\n    for param, ver in zip(param_lst, ver_lst):\n        t3d = np.array([pose_new[4],pose_new[5],pose_new[6]])\n        \n        theta = np.array([pose_new[0],pose_new[1],pose_new[2]])\n        theta = [p * np.pi / 180 for p in theta]\n        R = angle2matrix(theta)\n        P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) \n        img = plot_pose_box(img, P, ver)\n    #    print(P,P.shape,t3d)\n        # print(P,pose_new)\n        # print(f'yaw: {theta[0]:.1f}, pitch: {theta[1]:.1f}, roll: {theta[2]:.1f}')\n        all_pose = [0]\n        all_pose = np.array(all_pose)\n\n    if wfp is not None:\n        cv2.imwrite(wfp, img)\n        print(f'Save visualization result to {wfp}')\n        \n    if wnp is not None:\n        np.save(wnp, all_pose)\n        print(f'Save visualization result to {wfp}')\n        \n    if show_flag:\n        plot_image(img)\n\n    return img\n\n    \n    \n    \n\ndef get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp = None):\n    for param, ver in zip(param_lst, ver_lst):  # only one loop\n        s, pose, t3d, P = pose_6(param)\n        img_1 = plot_pose_box(img.copy(), P, ver)\n    #    print(P,P.shape,t3d)\n        # print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')\n        all_pose = [pose[0],pose[1],pose[2],s,t3d[0],t3d[1],t3d[2]]\n        all_pose = np.array(all_pose)\n\n    # if wfp is not None:\n    #     cv2.imwrite(wfp, img_1)\n    #     print(f'Save visualization result to {wfp}')\n        \n    # if wnp is not None:\n    #     np.save(wnp, all_pose)\n    #     
print(f'Save visualization result to {wfp}')\n        \n    if show_flag:\n        plot_image(img)\n\n    return all_pose\n"
  },
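  {
    "path": "extract_init_states/pose_roundtrip_example.py",
    "content": "# coding: utf-8\n\n\"\"\"\nIllustrative sketch (not part of the original pipeline): a quick numerical check of the Euler-angle\nconvention used in utils/pose.py. matrix2angle(R) returns angles (x, y, z) in radians and\nangle2matrix recomposes a rotation matrix from them, so away from the gimbal-lock branch the round\ntrip angle2matrix(matrix2angle(R)) should reproduce R. The file name and the assumption that it is\nrun from the extract_init_states/ directory (so that the `utils` package is importable) are\nhypothetical.\n\"\"\"\n\nimport numpy as np\nfrom math import cos, sin\n\nfrom utils.pose import matrix2angle, angle2matrix\n\n\ndef make_rotation(pitch, yaw, roll):\n    # compose a reference rotation R = Rz(roll) @ Ry(yaw) @ Rx(pitch), all angles in radians\n    Rx = np.array([[1, 0, 0], [0, cos(pitch), -sin(pitch)], [0, sin(pitch), cos(pitch)]])\n    Ry = np.array([[cos(yaw), 0, sin(yaw)], [0, 1, 0], [-sin(yaw), 0, cos(yaw)]])\n    Rz = np.array([[cos(roll), -sin(roll), 0], [sin(roll), cos(roll), 0], [0, 0, 1]])\n    return Rz @ Ry @ Rx\n\n\nif __name__ == '__main__':\n    R = make_rotation(0.2, -0.4, 0.1)\n    angles = matrix2angle(R)      # (x, y, z) in radians\n    R_rec = angle2matrix(angles)  # recompose the rotation matrix from the angles\n    print('angles (rad):', angles)\n    print('max abs error:', np.abs(R - R_rec).max())  # expected to be around machine precision\n"
  },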
  {
    "path": "extract_init_states/utils/render.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport sys\n\nsys.path.append('..')\n\nimport cv2\nimport numpy as np\n\nfrom Sim3DR import RenderPipeline\nfrom utils.functions import plot_image\nfrom .tddfa_util import _to_ctype\n\ncfg = {\n    'intensity_ambient': 0.3,\n    'color_ambient': (1, 1, 1),\n    'intensity_directional': 0.6,\n    'color_directional': (1, 1, 1),\n    'intensity_specular': 0.1,\n    'specular_exp': 5,\n    'light_pos': (0, 0, 5),\n    'view_pos': (0, 0, 5)\n}\n\nrender_app = RenderPipeline(**cfg)\n\n\ndef render(img, ver_lst, tri, alpha=0.6, show_flag=False, wfp=None, with_bg_flag=True):\n    if with_bg_flag:\n        overlap = img.copy()\n    else:\n        overlap = np.zeros_like(img)\n\n    for ver_ in ver_lst:\n        ver = _to_ctype(ver_.T)  # transpose\n        overlap = render_app(ver, tri, overlap)\n\n    if with_bg_flag:\n        res = cv2.addWeighted(img, 1 - alpha, overlap, alpha, 0)\n    else:\n        res = overlap\n\n    if wfp is not None:\n        cv2.imwrite(wfp, res)\n        print(f'Save visualization result to {wfp}')\n\n    if show_flag:\n        plot_image(res)\n\n    return res\n"
  },
  {
    "path": "extract_init_states/utils/render_ctypes.py",
    "content": "# coding: utf-8\n\n\"\"\"\nBorrowed from https://github.com/1996scarlet/Dense-Head-Pose-Estimation/blob/main/service/CtypesMeshRender.py\n\nTo use this render, you should build the clib first:\n```\ncd utils/asset\ngcc -shared -Wall -O3 render.c -o render.so -fPIC\ncd ../..\n```\n\"\"\"\n\nimport sys\n\nsys.path.append('..')\n\nimport os.path as osp\nimport cv2\nimport numpy as np\nimport ctypes\nfrom utils.functions import plot_image\n\nmake_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)\n\n\nclass TrianglesMeshRender(object):\n    def __init__(\n            self,\n            clibs,\n            light=(0, 0, 5),\n            direction=(0.6, 0.6, 0.6),\n            ambient=(0.3, 0.3, 0.3)\n    ):\n        if not osp.exists(clibs):\n            raise Exception(f'{clibs} not found, please build it first, by run '\n                            f'\"gcc -shared -Wall -O3 render.c -o render.so -fPIC\" in utils/asset directory')\n\n        self._clibs = ctypes.CDLL(clibs)\n\n        self._light = np.array(light, dtype=np.float32)\n        self._light = np.ctypeslib.as_ctypes(self._light)\n\n        self._direction = np.array(direction, dtype=np.float32)\n        self._direction = np.ctypeslib.as_ctypes(self._direction)\n\n        self._ambient = np.array(ambient, dtype=np.float32)\n        self._ambient = np.ctypeslib.as_ctypes(self._ambient)\n\n    def __call__(self, vertices, triangles, bg):\n        self.triangles = np.ctypeslib.as_ctypes(3 * triangles)  # Attention\n        self.tri_nums = triangles.shape[0]\n\n        self._clibs._render(\n            self.triangles, self.tri_nums,\n            self._light, self._direction, self._ambient,\n            np.ctypeslib.as_ctypes(vertices),\n            vertices.shape[0],\n            np.ctypeslib.as_ctypes(bg),\n            bg.shape[0], bg.shape[1]\n        )\n\n\nrender_app = TrianglesMeshRender(clibs=make_abs_path('asset/render.so'))\n\n\ndef render(img, ver_lst, tri, alpha=0.6, show_flag=False, wfp=None, with_bg_flag=True):\n    if with_bg_flag:\n        overlap = img.copy()\n    else:\n        overlap = np.zeros_like(img)\n\n    for ver_ in ver_lst:\n        ver = np.ascontiguousarray(ver_.T)  # transpose\n        render_app(ver, tri, bg=overlap)\n\n    if with_bg_flag:\n        res = cv2.addWeighted(img, 1 - alpha, overlap, alpha, 0)\n    else:\n        res = overlap\n\n    if wfp is not None:\n        cv2.imwrite(wfp, res)\n        print(f'Save visualization result to {wfp}')\n\n    if show_flag:\n        plot_image(res)\n\n    return res\n"
  },
  {
    "path": "extract_init_states/utils/serialization.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport numpy as np\n\nfrom .tddfa_util import _to_ctype\nfrom .functions import get_suffix\n\nheader_temp = \"\"\"ply\nformat ascii 1.0\nelement vertex {}\nproperty float x\nproperty float y\nproperty float z\nelement face {}\nproperty list uchar int vertex_indices\nend_header\n\"\"\"\n\n\ndef ser_to_ply_single(ver_lst, tri, height, wfp, reverse=True):\n    suffix = get_suffix(wfp)\n\n    for i, ver in enumerate(ver_lst):\n        wfp_new = wfp.replace(suffix, f'_{i + 1}{suffix}')\n\n        n_vertex = ver.shape[1]\n        n_face = tri.shape[0]\n        header = header_temp.format(n_vertex, n_face)\n\n        with open(wfp_new, 'w') as f:\n            f.write(header + '\\n')\n            for i in range(n_vertex):\n                x, y, z = ver[:, i]\n                if reverse:\n                    f.write(f'{x:.2f} {height-y:.2f} {z:.2f}\\n')\n                else:\n                    f.write(f'{x:.2f} {y:.2f} {z:.2f}\\n')\n            for i in range(n_face):\n                idx1, idx2, idx3 = tri[i]  # m x 3\n                if reverse:\n                    f.write(f'3 {idx3} {idx2} {idx1}\\n')\n                else:\n                    f.write(f'3 {idx1} {idx2} {idx3}\\n')\n\n        print(f'Dump tp {wfp_new}')\n\n\ndef ser_to_ply_multiple(ver_lst, tri, height, wfp, reverse=True):\n    n_ply = len(ver_lst)  # count ply\n\n    if n_ply <= 0:\n        return\n\n    n_vertex = ver_lst[0].shape[1]\n    n_face = tri.shape[0]\n    header = header_temp.format(n_vertex * n_ply, n_face * n_ply)\n\n    with open(wfp, 'w') as f:\n        f.write(header + '\\n')\n\n        for i in range(n_ply):\n            ver = ver_lst[i]\n            for j in range(n_vertex):\n                x, y, z = ver[:, j]\n                if reverse:\n                    f.write(f'{x:.2f} {height - y:.2f} {z:.2f}\\n')\n                else:\n                    f.write(f'{x:.2f} {y:.2f} {z:.2f}\\n')\n\n        for i in range(n_ply):\n            offset = i * n_vertex\n            for j in range(n_face):\n                idx1, idx2, idx3 = tri[j]  # m x 3\n                if reverse:\n                    f.write(f'3 {idx3 + offset} {idx2 + offset} {idx1 + offset}\\n')\n                else:\n                    f.write(f'3 {idx1 + offset} {idx2 + offset} {idx3 + offset}\\n')\n\n    print(f'Dump tp {wfp}')\n\n\ndef get_colors(img, ver):\n    h, w, _ = img.shape\n    ver[0, :] = np.minimum(np.maximum(ver[0, :], 0), w - 1)  # x\n    ver[1, :] = np.minimum(np.maximum(ver[1, :], 0), h - 1)  # y\n    ind = np.round(ver).astype(np.int32)\n    colors = img[ind[1, :], ind[0, :], :] / 255.  
# n x 3\n\n    return colors.copy()\n\n\ndef ser_to_obj_single(img, ver_lst, tri, height, wfp):\n    suffix = get_suffix(wfp)\n\n    n_face = tri.shape[0]\n    for i, ver in enumerate(ver_lst):\n        colors = get_colors(img, ver)\n\n        n_vertex = ver.shape[1]\n\n        wfp_new = wfp.replace(suffix, f'_{i + 1}{suffix}')\n\n        with open(wfp_new, 'w') as f:\n            for i in range(n_vertex):\n                x, y, z = ver[:, i]\n                f.write(\n                    f'v {x:.2f} {height - y:.2f} {z:.2f} {colors[i, 2]:.2f} {colors[i, 1]:.2f} {colors[i, 0]:.2f}\\n')\n            for i in range(n_face):\n                idx1, idx2, idx3 = tri[i]  # m x 3\n                f.write(f'f {idx3 + 1} {idx2 + 1} {idx1 + 1}\\n')\n\n        print(f'Dump tp {wfp_new}')\n\n\ndef ser_to_obj_multiple(img, ver_lst, tri, height, wfp):\n    n_obj = len(ver_lst)  # count obj\n\n    if n_obj <= 0:\n        return\n\n    n_vertex = ver_lst[0].shape[1]\n    n_face = tri.shape[0]\n\n    with open(wfp, 'w') as f:\n        for i in range(n_obj):\n            ver = ver_lst[i]\n            colors = get_colors(img, ver)\n\n            for j in range(n_vertex):\n                x, y, z = ver[:, j]\n                f.write(\n                    f'v {x:.2f} {height - y:.2f} {z:.2f} {colors[j, 2]:.2f} {colors[j, 1]:.2f} {colors[j, 0]:.2f}\\n')\n\n        for i in range(n_obj):\n            offset = i * n_vertex\n            for j in range(n_face):\n                idx1, idx2, idx3 = tri[j]  # m x 3\n                f.write(f'f {idx3 + 1 + offset} {idx2 + 1 + offset} {idx1 + 1 + offset}\\n')\n\n    print(f'Dump tp {wfp}')\n\n\nser_to_ply = ser_to_ply_multiple  # ser_to_ply_single\nser_to_obj = ser_to_obj_multiple  # ser_to_obj_multiple\n"
  },
  {
    "path": "extract_init_states/utils/tddfa_util.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport sys\n\nsys.path.append('..')\n\nimport argparse\nimport numpy as np\nimport torch\n\n\ndef _to_ctype(arr):\n    if not arr.flags.c_contiguous:\n        return arr.copy(order='C')\n    return arr\n\n\ndef str2bool(v):\n    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n        return True\n    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected')\n\n\ndef load_model(model, checkpoint_fp):\n    checkpoint = torch.load(checkpoint_fp, map_location=lambda storage, loc: storage)['state_dict']\n    model_dict = model.state_dict()\n    # because the model is trained by multiple gpus, prefix module should be removed\n    for k in checkpoint.keys():\n        kc = k.replace('module.', '')\n        if kc in model_dict.keys():\n            model_dict[kc] = checkpoint[k]\n        if kc in ['fc_param.bias', 'fc_param.weight']:\n            model_dict[kc.replace('_param', '')] = checkpoint[k]\n\n    model.load_state_dict(model_dict)\n    return model\n\n\nclass ToTensorGjz(object):\n    def __call__(self, pic):\n        if isinstance(pic, np.ndarray):\n            img = torch.from_numpy(pic.transpose((2, 0, 1)))\n            return img.float()\n\n    def __repr__(self):\n        return self.__class__.__name__ + '()'\n\n\nclass NormalizeGjz(object):\n    def __init__(self, mean, std):\n        self.mean = mean\n        self.std = std\n\n    def __call__(self, tensor):\n        tensor.sub_(self.mean).div_(self.std)\n        return tensor\n\n\ndef similar_transform(pts3d, roi_box, size):\n    pts3d[0, :] -= 1  # for Python compatibility\n    pts3d[2, :] -= 1\n    pts3d[1, :] = size - pts3d[1, :]\n\n    sx, sy, ex, ey = roi_box\n    scale_x = (ex - sx) / size\n    scale_y = (ey - sy) / size\n    pts3d[0, :] = pts3d[0, :] * scale_x + sx\n    pts3d[1, :] = pts3d[1, :] * scale_y + sy\n    s = (scale_x + scale_y) / 2\n    pts3d[2, :] *= s\n    pts3d[2, :] -= np.min(pts3d[2, :])\n    return np.array(pts3d, dtype=np.float32)\n\n\ndef _parse_param(param):\n    \"\"\"matrix pose form\n    param: shape=(trans_dim+shape_dim+exp_dim,), i.e., 62 = 12 + 40 + 10\n    \"\"\"\n\n    # pre-defined templates for parameter\n    n = param.shape[0]\n    if n == 62:\n        trans_dim, shape_dim, exp_dim = 12, 40, 10\n    elif n == 72:\n        trans_dim, shape_dim, exp_dim = 12, 40, 20\n    elif n == 141:\n        trans_dim, shape_dim, exp_dim = 12, 100, 29\n    else:\n        raise Exception(f'Undefined templated param parsing rule')\n\n    R_ = param[:trans_dim].reshape(3, -1)\n    R = R_[:, :3]\n    offset = R_[:, -1].reshape(3, 1)\n    alpha_shp = param[trans_dim:trans_dim + shape_dim].reshape(-1, 1)\n    alpha_exp = param[trans_dim + shape_dim:].reshape(-1, 1)\n\n    return R, offset, alpha_shp, alpha_exp\n"
  },
  {
    "path": "extract_init_states/utils/uv.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport sys\n\nsys.path.append('..')\n\nimport cv2\nimport numpy as np\nimport os.path as osp\nimport scipy.io as sio\n\nfrom Sim3DR import rasterize\nfrom utils.functions import plot_image\nfrom utils.io import _load\nfrom utils.tddfa_util import _to_ctype\n\nmake_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)\n\n\ndef load_uv_coords(fp):\n    C = sio.loadmat(fp)\n    uv_coords = C['UV'].copy(order='C').astype(np.float32)\n    return uv_coords\n\n\ndef process_uv(uv_coords, uv_h=256, uv_w=256):\n    uv_coords[:, 0] = uv_coords[:, 0] * (uv_w - 1)\n    uv_coords[:, 1] = uv_coords[:, 1] * (uv_h - 1)\n    uv_coords[:, 1] = uv_h - uv_coords[:, 1] - 1\n    uv_coords = np.hstack((uv_coords, np.zeros((uv_coords.shape[0], 1), dtype=np.float32)))  # add z\n    return uv_coords\n\n\ng_uv_coords = load_uv_coords(make_abs_path('../configs/BFM_UV.mat'))\nindices = _load(make_abs_path('../configs/indices.npy'))  # todo: handle bfm_slim\ng_uv_coords = g_uv_coords[indices, :]\n\n\ndef get_colors(img, ver):\n    # nearest-neighbor sampling\n    [h, w, _] = img.shape\n    ver[0, :] = np.minimum(np.maximum(ver[0, :], 0), w - 1)  # x\n    ver[1, :] = np.minimum(np.maximum(ver[1, :], 0), h - 1)  # y\n    ind = np.round(ver).astype(np.int32)\n    colors = img[ind[1, :], ind[0, :], :]  # n x 3\n\n    return colors\n\n\ndef bilinear_interpolate(img, x, y):\n    \"\"\"\n    https://stackoverflow.com/questions/12729228/simple-efficient-bilinear-interpolation-of-images-in-numpy-and-python\n    \"\"\"\n    x0 = np.floor(x).astype(np.int32)\n    x1 = x0 + 1\n    y0 = np.floor(y).astype(np.int32)\n    y1 = y0 + 1\n\n    x0 = np.clip(x0, 0, img.shape[1] - 1)\n    x1 = np.clip(x1, 0, img.shape[1] - 1)\n    y0 = np.clip(y0, 0, img.shape[0] - 1)\n    y1 = np.clip(y1, 0, img.shape[0] - 1)\n\n    i_a = img[y0, x0]\n    i_b = img[y1, x0]\n    i_c = img[y0, x1]\n    i_d = img[y1, x1]\n\n    wa = (x1 - x) * (y1 - y)\n    wb = (x1 - x) * (y - y0)\n    wc = (x - x0) * (y1 - y)\n    wd = (x - x0) * (y - y0)\n\n    return wa[..., np.newaxis] * i_a + wb[..., np.newaxis] * i_b + wc[..., np.newaxis] * i_c + wd[..., np.newaxis] * i_d\n\n\ndef uv_tex(img, ver_lst, tri, uv_h=256, uv_w=256, uv_c=3, show_flag=False, wfp=None):\n    uv_coords = process_uv(g_uv_coords.copy(), uv_h=uv_h, uv_w=uv_w)\n\n    res_lst = []\n    for ver_ in ver_lst:\n        ver = _to_ctype(ver_.T)  # transpose to m x 3\n        colors = bilinear_interpolate(img, ver[:, 0], ver[:, 1]) / 255.\n        # `rasterize` here serves as texture sampling, may need to optimization\n        res = rasterize(uv_coords, tri, colors, height=uv_h, width=uv_w, channel=uv_c)\n        res_lst.append(res)\n\n    # concat if there more than one image\n    res = np.concatenate(res_lst, axis=1) if len(res_lst) > 1 else res_lst[0]\n\n    if wfp is not None:\n        cv2.imwrite(wfp, res)\n        print(f'Save visualization result to {wfp}')\n\n    if show_flag:\n        plot_image(res)\n\n    return res\n"
  },
  {
    "path": "extract_init_states/weights/.gitignore",
    "content": "# checkpoints/\n# *.pth\n# *.onnx"
  },
  {
    "path": "extract_init_states/weights/readme.md",
    "content": "## Pre-converted onnx model\n\n| Model | Link |\n| :-: | :-: |\n| `mb1_120x120.onnx` | [Google Drive](https://drive.google.com/file/d/1YpO1KfXvJHRmCBkErNa62dHm-CUjsoIk/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1qpQBd5KOS0-5lD6jZKXZ-Q) (Password: cqbx) |\n| `mb05_120x120.onnx` | [Google Drive](https://drive.google.com/file/d/1orJFiZPshmp7jmCx_D0tvIEtPYtnFvHS/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1sRaBOA5wHu6PFS1Qd-TBFA) (Password: 8qst) |\n| `resnet22.onnx` | [Google Drive](https://drive.google.com/file/d/1rRyrd7Ar-QYTi1hRHOYHspT8PTyXQ5ds/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1Nzkw7Ie_5trKvi1JYxymJA) (Password: 1op6) |\n| `resnet22.pth` | [Google Drive](https://drive.google.com/file/d/1dh7JZgkj1IaO4ZcSuBOBZl2suT9EPedV/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1IS7ncVxhw0f955ySg67Y4A) (Password: lv1a) |"
  },
  {
    "path": "filter_fourier.py",
    "content": "import torch\nimport torch.fft\nimport torchvision.transforms as transforms\nimport cv2\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport cv2\nimport numpy as np\n\n# Filtering function: Input optical flow field to filter out high-frequency noise.\ndef gaussian_pdf(x, mean, std):\n        return (1 / (std * torch.sqrt(2 * torch.tensor(3.141592653589793))) *\n                torch.exp(-((x - mean) ** 2) / (2 * std ** 2)))\n\ndef gaussian_density(length = 20, amplitude = 2, mean = 19, sigma = 3):\n    x = torch.arange(0, length, 1.0)\n    gaussian = amplitude * torch.exp(-(x - mean)**2 / (2 * sigma**2))\n    gaussian = torch.clip(gaussian, max = 1, min = 0)\n    return gaussian.cuda()\n\ndef fourier_filter(fea):\n    L, C , H , W = fea.shape\n    mean = 0\n    std = 3\n    _x = torch.linspace(-10, 10, H)  # Define 128 values within the range of -5 to 5.\n    X, Y = torch.meshgrid(_x, _x)  # Generate grid coordinates.\n    gaussian_map = (gaussian_pdf(X, mean, std).cuda()) * (gaussian_pdf(Y, mean, std).cuda())\n    gaussian_map = gaussian_map.unsqueeze(0).repeat(1, C, 1, 1)\n\n    gaussian_map = torch.clip((gaussian_map)/gaussian_map.max() * 3 , min = 0, max = 1)\n\n    # lowpass_filter = torch.zeros(H,H).cuda()\n    # for i in range(H):\n    #     for j in range(H):\n    #         if np.sqrt((i - H//2)**2 + (j - H//2)**2) <= 10:\n    #             lowpass_filter[i, j] = 1\n\n    x = torch.fft.fft2(fea, dim=(-2, -1))\n    x_shifted = torch.fft.fftshift(x)  # 1,3,128,128\n\n    x_shifted = x_shifted * gaussian_map# lowpass_filter #  * gaussian_map\n\n    reconstructed_x = torch.fft.ifftshift(x_shifted)\n    reconstructed_x = torch.fft.ifft2(reconstructed_x, dim=(-2, -1))\n    reconstructed_x = torch.real(reconstructed_x)\n\n\n    return reconstructed_x\n\ndef fourier_filter_1D(fea, dim):\n    # idex = freq * L / 25\n    L, C , H , W = fea.shape\n    mean = 0\n    std = 3\n    fft_result = torch.fft.rfft(fea, dim=dim)\n\n    # 低通滤波\n    cutoff_freq = 10  # 保留前 10 个频率\n    # mask = gaussian_density(length = L, mean = 0, sigma = 5, amplitude = 2)[:, None, None, None]\n    # fft_result = mask * fft_result\n    fft_result[L//4:] = 0  # 设置高频部分为 0\n\n    # 对 H 维度进行逆傅里叶变换\n    filtered_tensor = torch.fft.irfft(fft_result,n= L, dim=dim)\n    filtered_tensor = torch.real(filtered_tensor)\n\n    return filtered_tensor\n\ndef hf_loss(fea, mask, dim):\n    mask = 1- mask # gaussian_density(length = L, mean = 0, sigma = 12, amplitude = 2)\n    fft_result = torch.fft.rfft(fea, dim=dim)\n    fft_result = fft_result * mask \n    fft_result = fft_result.abs()\n\n    return fft_result\n\ndef hf_loss_2(fea_x, fea_y, dim):\n    '''\n    与GT计算频域损失\n    '''\n    fft_result_x = torch.fft.rfft(fea_x, dim=dim)\n    fft_result_y = torch.fft.rfft(fea_y, dim=dim)\n    # fft_result = fft_result.abs()\n    loss = (fft_result_y - fft_result_x).abs()\n\n    return loss\n\n    \n\nclass KalmanFilter1D:\n    def __init__(self, A, H, Q, R, x_init, P_init):\n        self.A = torch.tensor(A, requires_grad=False)\n        self.H = torch.tensor(H, requires_grad=False)\n        self.Q = torch.tensor(Q, requires_grad=False)\n        self.R = torch.tensor(R, requires_grad=False)\n        self.x = torch.tensor(x_init, requires_grad=True)\n        self.P = torch.tensor(P_init, requires_grad=True)\n\n    def update(self, z):\n        # 预测步骤\n        x_pred = self.A * self.x\n        P_pred = self.A * self.P * self.A + self.Q\n\n        # 更新步骤\n        K = P_pred * self.H / (self.H * P_pred * self.H + self.R)\n   
     self.x = x_pred + K * (z - self.H * x_pred)\n        self.P = (1 - K * self.H) * P_pred\n\n        return self.x\n\ndef kalman_filter(observations, dim):\n    kf = KalmanFilter1D(A=1., H=1., Q=0.01, R=0.1, x_init=0., P_init=1.)\n    filtered_values = torch.zeros_like(observations)\n\n    for idx in range(observations.size(dim)):\n        obs_slice = tuple(slice(None) if i != dim else idx for i in range(len(observations.size())))\n        obs = observations[obs_slice]\n        filtered_value = kf.update(obs)\n        filtered_values[obs_slice] = filtered_value\n\n    return filtered_values\n\ndef naive_filter(fea):\n    L, C , H , W = fea.shape\n    fea_mask = fea.abs()>(1/64)\n    fea = fea*fea_mask\n    return fea\n# def fourier_filter(x):\n#     L, C , H , W = x.shape\n#     mean = 0\n#     std = 3\n#     _x = torch.linspace(-5, 5, H)  # 定义一个范围为-5到5的128个值\n#     X, Y = torch.meshgrid(_x, _x)  # Generate grid coordinates.\n#     gaussian_map = (gaussian_pdf(X, mean, std).cuda()) * (gaussian_pdf(Y, mean, std).cuda())\n#     gaussian_map = gaussian_map.unsqueeze(0).repeat(1, C, 1, 1)\n\n#     gaussian_map = (gaussian_map)/gaussian_map.max()\n\n#     x = torch.fft.fft2(x, dim=(-2, -1))\n#     x_shifted = torch.fft.fftshift(x)  # 1,3,128,128\n\n#     x_shifted = x_shifted # * gaussian_map\n\n#     reconstructed_x = torch.fft.ifftshift(x_shifted)\n#     reconstructed_x = torch.fft.ifft2(reconstructed_x, dim=(-2, -1))\n#     reconstructed_x = torch.abs(reconstructed_x)\n\n\n#     return reconstructed_x\n\n\n\nif __name__ == '__main__':\n    # 读取视频\n    gd = gaussian_density(length = 20, mean = 0, sigma = 5, amplitude = 2)\n    print(gd)\n    print(gd[:10])\n    # cap = cv2.VideoCapture('your_path/demo/s2_20w_newae_crema_s1_10_s2_11-j-sl-vr-of-tr-rmm-ddim0200_1.00/7_s76_1076_ITH_FEA_XX.mp4')\n\n\n    \n\n    # # 生成均值为0，标准差为3的高斯概率密度分布张量\n    # mean = 0\n    # std = 3\n    # x = torch.linspace(-5, 5, 128)  # 定义一个范围为-5到5的128个值\n    # X, Y = torch.meshgrid(x, x)  # Generate grid coordinates.\n    # gaussian_map = gaussian_pdf(X, mean, std) * gaussian_pdf(Y, mean, std)\n    # gaussian_map = gaussian_map.unsqueeze(0).repeat(1, 3, 1, 1)\n\n    # gaussian_map = ( gaussian_map)/gaussian_map.max()\n\n    # # 输入数据，假设frames是一个包含L帧RGB图像的numpy数组，形状为(L, 3, H, W)\n    # frames = np.random.randint(0, 255, (100, 3, 256, 256)).astype(np.uint8)\n\n    # # 设置输出视频的名称、帧率和分辨率\n\n\n    # def generate_video(frames):\n    #     video_name = 'output_video.avi'\n    #     fps = 25\n    #     resolution = (128, 128)\n\n    #     # 创建视频写入对象\n    #     fourcc = cv2.VideoWriter_fourcc(*'XVID')\n    #     video = cv2.VideoWriter(video_name, fourcc, fps, resolution)\n\n    #     # 逐帧将图像写入视频\n    #     for i in range(frames.shape[0]):\n    #         frame = frames[i][:,:,:].transpose(1, 2, 0).astype(np.uint8)  # 调整通道顺序(H, W, 3)\n    #         video.write(frame)\n\n    #     # 释放资源并保存视频\n    #     video.release()\n    # # 存储还原后的图像帧\n    # reconstructed_frames = []\n\n    # # 循环遍历视频的每一帧\n    # while(cap.isOpened()):\n    #     ret, frame = cap.read()\n\n    #     if not ret:\n    #         break\n\n    #     # 将当前帧转换为 PyTorch 张量\n    #     frame = torch.tensor(frame)\n    #     frame = frame.permute(2, 0, 1).unsqueeze(0).float()\n\n    #     # 对当前帧进行 2D 傅里叶变换\n    #     fft_frame = torch.fft.fft2(frame, dim=(-2, -1))\n    #     fft_frame_shifted = torch.fft.fftshift(fft_frame)  # 1,3,128,128\n\n    #     # 将频域展开形式还原回图像\n    #     # fft_frame_shifted = fft_frame_shifted * gaussian_map\n\n    #     reconstructed_frame = 
torch.fft.ifftshift(fft_frame_shifted)\n    #     reconstructed_frame = torch.fft.ifft2(reconstructed_frame, dim=(-2, -1))\n    #     reconstructed_frame = torch.abs(reconstructed_frame)\n\n    #     # 将还原后的图像帧添加到列表中\n    #     reconstructed_frames.append(reconstructed_frame)\n\n    # # 将还原后的图像帧转换为数组\n    # reconstructed_frames = torch.cat(reconstructed_frames, dim=0)\n\n    # # 将还原后的图像帧转换为 numpy 数组\n    # reconstructed_frames = (reconstructed_frames).to(torch.int32)\n    # reconstructed_frames = reconstructed_frames.squeeze(1).numpy()\n\n    # # 显示还原后的视频\n\n    # generate_video(reconstructed_frames)\n"
  },
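  {
    "path": "filter_fourier_example.py",
    "content": "\"\"\"\nIllustrative sketch (not part of the original pipeline): a minimal CPU-only demo of fourier_filter_1D\nfrom filter_fourier.py. It builds a slow sinusoid over time, adds high-frequency noise, filters along\nthe temporal dimension (dim=0), and reports how much of the noise is removed. The file name, the\nsynthetic tensor shapes and the noise level are hypothetical; with dim=0 the function keeps only the\nrfft bins below L // 4.\n\"\"\"\n\nimport math\n\nimport torch\n\nfrom filter_fourier import fourier_filter_1D\n\nif __name__ == '__main__':\n    torch.manual_seed(0)\n\n    L, C, H, W = 80, 2, 4, 4  # (frames, channels, height, width), the layout fourier_filter_1D expects\n    t = torch.arange(L, dtype=torch.float32)\n\n    # slow oscillation over time (one cycle per 40 frames), broadcast to every channel/pixel\n    clean = torch.sin(2 * math.pi * t / 40.0)[:, None, None, None].expand(L, C, H, W)\n    noisy = clean + 0.3 * torch.randn(L, C, H, W)\n\n    filtered = fourier_filter_1D(noisy, dim=0)\n\n    err_noisy = (noisy - clean).abs().mean().item()\n    err_filtered = (filtered - clean).abs().mean().item()\n    print(f'mean abs error before filtering: {err_noisy:.4f}')\n    print(f'mean abs error after filtering:  {err_filtered:.4f}')  # expected to be noticeably smaller\n"
  },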
  {
    "path": "hubert_extract/data_gen/process_lrs3/binarizer.py",
    "content": "import os\nimport numpy as np\nfrom scipy.misc import face\nimport torch\nfrom tqdm import trange\nimport pickle\nfrom copy import deepcopy\n\nfrom data_util.face3d_helper import Face3DHelper\nfrom utils.commons.indexed_datasets import IndexedDataset, IndexedDatasetBuilder\n\n\ndef load_video_npy(fn):\n    assert fn.endswith(\".npy\")\n    ret_dict = np.load(fn,allow_pickle=True).item()\n    video_dict = {\n        'coeff': ret_dict['coeff'], # [T, h]\n        'lm68': ret_dict['lm68'], # [T, 68, 2]  \n        'lm5': ret_dict['lm5'], # [T, 5, 2]\n    }\n    return video_dict\n\ndef cal_lm3d_in_video_dict(video_dict, face3d_helper):\n    coeff = torch.from_numpy(video_dict['coeff']).float()\n    identity = coeff[:, 0:80]\n    exp = coeff[:, 80:144]\n    idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d(identity, exp).cpu().numpy()\n    video_dict['idexp_lm3d'] = idexp_lm3d\n\ndef load_audio_npy(fn):\n    assert fn.endswith(\".npy\")\n    ret_dict = np.load(fn,allow_pickle=True).item()\n    audio_dict = {\n        \"mel\": ret_dict['mel'], # [T, 80]\n        \"f0\": ret_dict['f0'], # [T,1]\n    }\n    return audio_dict\n\n\nif __name__ == '__main__':\n    face3d_helper = Face3DHelper(use_gpu=False)\n    \n    import glob,tqdm\n    prefixs = ['val', 'train']\n    binarized_ds_path = \"data/binary/lrs3\"\n    os.makedirs(binarized_ds_path, exist_ok=True)\n    for prefix in prefixs:\n        databuilder = IndexedDatasetBuilder(os.path.join(binarized_ds_path, prefix), gzip=False)\n        raw_base_dir =  '/home/yezhenhui/datasets/raw/lrs3_raw'\n        spk_ids = sorted([dir_name.split(\"/\")[-1] for dir_name in glob.glob(raw_base_dir + \"/*\")])\n        spk_id2spk_idx = {spk_id : i for i,spk_id in enumerate(spk_ids) }\n        np.save(os.path.join(binarized_ds_path, \"spk_id2spk_idx.npy\"), spk_id2spk_idx, allow_pickle=True)\n        mp4_names = glob.glob(raw_base_dir + \"/*/*.mp4\")\n        cnt = 0\n        for i, mp4_name in tqdm.tqdm(enumerate(mp4_names), total=len(mp4_names)):\n            if prefix == 'train':\n                if i % 100 == 0:\n                    continue\n            else:\n                if i % 100 != 0:\n                    continue\n            lst = mp4_name.split(\"/\")\n            spk_id = lst[-2]\n            clip_id = lst[-1][:-4]\n            audio_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+\"_audio.npy\")\n            hubert_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+\"_hubert.npy\")\n            video_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+\"_coeff_pt.npy\")\n            if (not os.path.exists(audio_npy_name)) or (not os.path.exists(video_npy_name)):\n                print(f\"Skip item for not found.\")\n                continue\n            if (not os.path.exists(hubert_npy_name)):\n                print(f\"Skip item for hubert_npy not found.\")\n                continue\n            audio_dict = load_audio_npy(audio_npy_name)\n            hubert = np.load(hubert_npy_name)\n            video_dict = load_video_npy(video_npy_name)\n            cal_lm3d_in_video_dict(video_dict, face3d_helper)\n            mel = audio_dict['mel']\n            if mel.shape[0] < 64: # the video is shorter than 0.6s\n                print(f\"Skip item for too short.\")\n                continue\n            audio_dict.update(video_dict)\n            audio_dict['spk_id'] = spk_id\n            audio_dict['spk_idx'] = spk_id2spk_idx[spk_id]\n            audio_dict['item_id'] = spk_id + \"_\" + clip_id\n            \n            
audio_dict['hubert'] = hubert # [T_x, hid=1024]\n            databuilder.add_item(audio_dict)\n            cnt += 1\n        databuilder.finalize()\n        print(f\"{prefix} set has {cnt} samples!\")"
  },
  {
    "path": "hubert_extract/data_gen/process_lrs3/process_audio_hubert.py",
    "content": "from genericpath import exists\nfrom transformers import Wav2Vec2Processor, HubertModel\nimport soundfile as sf\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\nimport fairseq\nprint(\"Loading the Wav2Vec2 Processor...\")\n# wav2vec2_processor = Wav2Vec2Processor.from_pretrained(\"facebook/hubert-large-ls960-ft\")\nwav2vec2_processor = Wav2Vec2Processor.from_pretrained(\"/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp\")\nprint(\"Loading the HuBERT Model...\")\nhubert_model = HubertModel.from_pretrained(\"/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp\", from_tf = True)\n\ndef get_hubert_from_16k_wav(wav_16k_name):\n    speech_16k, _ = sf.read(wav_16k_name)\n    hubert = get_hubert_from_16k_speech(speech_16k)\n    return hubert\n\n@torch.no_grad()\ndef get_hubert_from_16k_speech(speech, device=\"cuda:1\"):\n    global hubert_model\n    hubert_model = hubert_model.to(device)\n    if speech.ndim ==2:\n        speech = speech[:, 0] # [T, 2] ==> [T,]\n    input_values_all = wav2vec2_processor(speech, return_tensors=\"pt\", sampling_rate=16000).input_values # [1, T]\n    input_values_all = input_values_all.to(device)\n    # For long audio sequence, due to the memory limitation, we cannot process them in one run\n    # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320\n    # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.\n    # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320\n    # We have the equation to calculate out time step: T = floor((t-k)/s)\n    # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip\n    # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N\n    kernel = 400\n    stride = 320\n    clip_length = stride * 1000\n    num_iter = input_values_all.shape[1] // clip_length\n    expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride\n    res_lst = []\n    for i in range(num_iter):\n        if i == 0:\n            start_idx = 0\n            end_idx = clip_length - stride + kernel\n        else:\n            start_idx = clip_length * i\n            end_idx = start_idx + (clip_length - stride + kernel)\n        input_values = input_values_all[:, start_idx: end_idx]\n        hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]\n        res_lst.append(hidden_states[0])\n    if num_iter > 0:\n        input_values = input_values_all[:, clip_length * num_iter:]\n    else:\n        input_values = input_values_all\n    # if input_values.shape[1] != 0:\n    if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it            \n        hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]\n        res_lst.append(hidden_states[0])\n    ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]\n    # assert ret.shape[0] == expected_T\n    assert abs(ret.shape[0] - expected_T) <= 1\n    if ret.shape[0] < expected_T:\n        ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0]))\n    else:\n        ret = ret[:expected_T]\n    return ret\n\n\nif __name__ == '__main__':\n    ## Process Single Long Audio for NeRF dataset\n    # person_id = 'May'\n    import os\n\n    # demo test\n\n    # wav_16k_name = f\"/train20/intern/permanent/lmlin2/data/audio_lesson_01-j-w.wav\"\n    
# npy_name = 'demo_test-j-w'\n    # demo_npy_name = f\"/train20/intern/permanent/lmlin2/data/{npy_name}.npy\"\n    # speech_16k, _ = sf.read(wav_16k_name)\n    # hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n    # np.save(demo_npy_name, hubert_hidden.detach().numpy())\n\n    # hdtf dataset\n    image_path = '/train20/intern/permanent/lmlin2/data/hdtf_image_50hz'\n    image_path = '/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz'\n    wav_path = '/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio'\n\n    for wavfile in os.listdir(wav_path):\n        frames = os.listdir(os.path.join(image_path, wavfile[0:-4]))\n        frames.sort()\n        num_frames = len(frames)\n        wav_16k_name = os.path.join(wav_path, wavfile)\n        # wav_16k_name = f\"/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio/RD_Radio1_000.wav\"           #(3749, 1024)\n        # wav_16k_name = f\"data/processed/videos/{person_id}/aud.wav\"          \n        # wav_16k_name = f\"/train20/intern/permanent/lmlin2/Flow/GeneFace-main/data/raw/val_wavs/zozo.wav\"  # 543 1024\n        # wav_16k_name = f\"/train20/intern/permanent/lmlin2/data/audio_lesson_01.wav\"\n        npy_name = wavfile[0:-4]\n        hubert_npy_name = f\"/train20/intern/permanent/lmlin2/data/hdtf_wav_hubert/{npy_name}.npy\"\n        speech_16k, _ = sf.read(wav_16k_name)\n        hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n        print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}')\n        np.save(hubert_npy_name, hubert_hidden.detach().numpy())\n\n\n    # crema dataset\n    # image_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'\n    # wav_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/audio'\n    # save_path = '/train20/intern/permanent/lmlin2/data/crema_wav_hubert'\n    # for id_name in tqdm(os.listdir(wav_path)):\n    #     for wavfile in os.listdir(os.path.join(wav_path, id_name)):\n    #         frame_dir = os.path.join(image_path, id_name, wavfile[0:-4])\n    #         if not exists(frame_dir):\n    #             print(f'{frame_dir} does not exist!')\n    #             continue\n    #         frames= os.listdir(frame_dir)\n    #         frames.sort()\n    #         num_frames = len(frames)\n    #         wav_16k_name = os.path.join(wav_path, id_name, wavfile)\n    #         npy_name = wavfile[0:-4]\n    #         save_dir = os.path.join(save_path,id_name)\n    #         hubert_npy_name = os.path.join(save_dir,npy_name+'.npy') \n    #         if exists(hubert_npy_name):\n    #             print(f'{hubert_npy_name} exists!')\n    #             continue\n    #         if not exists(save_dir):\n    #             os.makedirs(save_dir)\n            \n    #         speech_16k, _ = sf.read(wav_16k_name)\n    #         hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n    #         print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}')\n    #         np.save(hubert_npy_name, hubert_hidden.detach().numpy())\n\n    # ## Process short audio clips for LRS3 dataset\n    # import glob, os, tqdm\n    # lrs3_dir = '/home/yezhenhui/datasets/raw/lrs3_raw/'\n    # wav_16k_names = glob.glob(os.path.join(lrs3_dir, '*/*.wav'))\n    # for wav_16k_name in tqdm.tqdm(wav_16k_names, total=len(wav_16k_names)):\n    #     spk_id = wav_16k_name.split(\"/\")[-2]\n    #     clip_id = wav_16k_name.split(\"/\")[-1][:-4]\n    #     out_name = os.path.join(lrs3_dir, spk_id, clip_id+'_hubert.npy')\n    #     if os.path.exists(out_name):\n    #         continue\n    #     speech_16k, _ = 
sf.read(wav_16k_name)\n    #     hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n    #     np.save(out_name, hubert_hidden.detach().numpy())"
  },
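  {
    "path": "hubert_extract/data_gen/process_lrs3/hubert_framecount_example.py",
    "content": "\"\"\"\nIllustrative sketch (not part of the original pipeline): the comments in process_audio_hubert.py\ndescribe the HuBERT feature extractor as one big Conv1D with kernel k=400 and stride s=320 on\n16 kHz audio, giving T = (t - (k - s)) // s output steps for t input samples. This standalone\nscript just evaluates that arithmetic so the expected feature length can be sanity-checked against\na 25 fps frame count; the function name and the example durations are hypothetical.\n\"\"\"\n\n\ndef expected_hubert_steps(num_samples, kernel=400, stride=320):\n    # same formula as the expected_T computation in process_audio_hubert.py\n    return (num_samples - (kernel - stride)) // stride\n\n\nif __name__ == '__main__':\n    sr = 16000  # HuBERT expects 16 kHz audio\n    for seconds in (1, 4, 10):\n        num_samples = seconds * sr\n        t_hubert = expected_hubert_steps(num_samples)\n        n_video = seconds * 25  # frames of a 25 fps video of the same duration\n        # HuBERT yields roughly 50 steps per second, about twice the 25 fps frame rate, which is\n        # what the *_interpolate script compensates for by resampling the features to num_frames\n        print(f'{seconds:>3d} s audio -> {t_hubert} HuBERT steps vs {n_video} video frames')\n"
  },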
  {
    "path": "hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate.py",
    "content": "from genericpath import exists\nfrom transformers import Wav2Vec2Processor, HubertModel\nimport soundfile as sf\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\nimport fairseq\nimport decord\nfrom scipy.interpolate import interp1d\n\n\nprint(\"Loading the Wav2Vec2 Processor...\")\n# wav2vec2_processor = Wav2Vec2Processor.from_pretrained(\"facebook/hubert-large-ls960-ft\")\nwav2vec2_processor = Wav2Vec2Processor.from_pretrained(\"/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp\")\nprint(\"Loading the HuBERT Model...\")\nhubert_model = HubertModel.from_pretrained(\"/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp\", from_tf = True)\n\ndef get_hubert_from_16k_wav(wav_16k_name):\n    speech_16k, _ = sf.read(wav_16k_name)\n    hubert = get_hubert_from_16k_speech(speech_16k)\n    return hubert\n\n@torch.no_grad()\ndef get_hubert_from_16k_speech(speech, device=\"cuda:1\"):\n    global hubert_model\n    hubert_model = hubert_model.to(device)\n    if speech.ndim ==2:\n        speech = speech[:, 0] # [T, 2] ==> [T,]\n    input_values_all = wav2vec2_processor(speech, return_tensors=\"pt\", sampling_rate=16000).input_values # [1, T]\n    input_values_all = input_values_all.to(device)\n    # For long audio sequence, due to the memory limitation, we cannot process them in one run\n    # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320\n    # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.\n    # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320\n    # We have the equation to calculate out time step: T = floor((t-k)/s)\n    # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip\n    # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N\n    kernel = 400\n    stride = 320\n    clip_length = stride * 1000\n    num_iter = input_values_all.shape[1] // clip_length\n    expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride\n    res_lst = []\n    for i in range(num_iter):\n        if i == 0:\n            start_idx = 0\n            end_idx = clip_length - stride + kernel\n        else:\n            start_idx = clip_length * i\n            end_idx = start_idx + (clip_length - stride + kernel)\n        input_values = input_values_all[:, start_idx: end_idx]\n        hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]\n        res_lst.append(hidden_states[0])\n    if num_iter > 0:\n        input_values = input_values_all[:, clip_length * num_iter:]\n    else:\n        input_values = input_values_all\n    # if input_values.shape[1] != 0:\n    if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it            \n        hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]\n        res_lst.append(hidden_states[0])\n    ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]\n    # assert ret.shape[0] == expected_T\n    assert abs(ret.shape[0] - expected_T) <= 1\n    if ret.shape[0] < expected_T:\n        ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0]))\n    else:\n        ret = ret[:expected_T]\n    return ret\n\n# decord.bridge.set_bridge('torch')\nfrom tqdm import tqdm\nif __name__ == '__main__':\n    ## Process Single Long Audio for NeRF dataset\n    # person_id = 'May'\n    import 
os\n\n    # demo test\n\n    # wav_16k_name = f\"/train20/intern/permanent/lmlin2/data/audio_lesson_01-j-w.wav\"\n    # npy_name = 'demo_test-j-w'\n    # demo_npy_name = f\"/train20/intern/permanent/lmlin2/data/{npy_name}.npy\"\n    # speech_16k, _ = sf.read(wav_16k_name)\n    # hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n    # np.save(demo_npy_name, hubert_hidden.detach().numpy())\n\n    # hdtf dataset\n    # image_path = '/train20/intern/permanent/lmlin2/data/hdtf_image_50hz'\n    # video_path = '/train20/intern/permanent/hbcheng2/data/HDTF/video_25hz'\n    # wav_path = '/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio'\n    # save_path = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert'\n    # interpolate_path = \"/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate\"\n\n    # if not os.path.exists(interpolate_path):\n    #     os.makedirs(interpolate_path)\n\n    # if not os.path.exists(save_path):\n    #     os.makedirs(save_path)\n\n    # for wavfile in tqdm(os.listdir(wav_path)):\n    #     video = os.path.join(video_path, wavfile[0:-4] + '.mp4')\n    #     vr = decord.VideoReader(video)\n    #     num_frames = len(vr)\n    #     wav_16k_name = os.path.join(wav_path, wavfile)\n    #     # wav_16k_name = f\"/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio/RD_Radio1_000.wav\"           #(3749, 1024)\n    #     # wav_16k_name = f\"data/processed/videos/{person_id}/aud.wav\"          \n    #     # wav_16k_name = f\"/train20/intern/permanent/lmlin2/Flow/GeneFace-main/data/raw/val_wavs/zozo.wav\"  # 543 1024\n    #     # wav_16k_name = f\"/train20/intern/permanent/lmlin2/data/audio_lesson_01.wav\"\n    #     npy_name = wavfile[0:-4]\n    #     hubert_npy_name = f\"{save_path}/{npy_name}.npy\"\n    #     hubert_npy_name_interpolate = f\"{interpolate_path}/{npy_name}.npy\"\n    #     if os.path.exists(hubert_npy_name) and os.path.exists(hubert_npy_name_interpolate):\n    #         continue\n    #     speech_16k, _ = sf.read(wav_16k_name)\n    #     hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n    #     print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}')\n\n    #     hubert_hidden = hubert_hidden.detach().numpy()\n    #     interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0)\n    #     hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32)\n    #     # torch.nn.functional.interpolate(hubert_hidden.unsqueeze(0).permute(0,2,1).cuda(), size=num_frames, mode='linear', align_corners=False).squeeze(0).permute(1, 0).cpu()\n    #     np.save(hubert_npy_name, hubert_hidden)\n    #     np.save(hubert_npy_name_interpolate, hubert_feature_interpolated)\n\n\n    # crema dataset\n    image_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'\n    wav_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/audio'\n    save_path = '/train20/intern/permanent/hbcheng2/data/crema/hubert_25hz'\n    for id_name in tqdm(os.listdir(wav_path)):\n        for wavfile in os.listdir(os.path.join(wav_path, id_name)):\n            frame_dir = os.path.join(image_path, id_name, wavfile[0:-4])\n            if not exists(frame_dir):\n                print(f'{frame_dir} does not exist!')\n                continue\n            frames= os.listdir(frame_dir)\n            frames.sort()\n            num_frames = len(frames)\n            wav_16k_name = os.path.join(wav_path, id_name, wavfile)\n            npy_name = 
wavfile[0:-4]\n            save_dir = os.path.join(save_path,id_name)\n            hubert_npy_name = os.path.join(save_dir,npy_name+'.npy') \n            if exists(hubert_npy_name):\n                print(f'{hubert_npy_name} exists!')\n                continue\n            if not exists(save_dir):\n                os.makedirs(save_dir)\n            \n            speech_16k, _ = sf.read(wav_16k_name)\n            hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n            hubert_hidden = hubert_hidden.detach().numpy()\n            interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0)\n            hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32)\n            print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}')\n            np.save(hubert_npy_name, hubert_feature_interpolated)\n\n    # ## Process short audio clips for LRS3 dataset\n    # import glob, os, tqdm\n    # lrs3_dir = '/home/yezhenhui/datasets/raw/lrs3_raw/'\n    # wav_16k_names = glob.glob(os.path.join(lrs3_dir, '*/*.wav'))\n    # for wav_16k_name in tqdm.tqdm(wav_16k_names, total=len(wav_16k_names)):\n    #     spk_id = wav_16k_name.split(\"/\")[-2]\n    #     clip_id = wav_16k_name.split(\"/\")[-1][:-4]\n    #     out_name = os.path.join(lrs3_dir, spk_id, clip_id+'_hubert.npy')\n    #     if os.path.exists(out_name):\n    #         continue\n    #     speech_16k, _ = sf.read(wav_16k_name)\n    #     hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n    #     np.save(out_name, hubert_hidden.detach().numpy())"
  },
  {
    "path": "hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_batch.py",
    "content": "from genericpath import exists\nfrom transformers import Wav2Vec2Processor, HubertModel\nimport soundfile as sf\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\nimport fairseq\nimport decord\nfrom scipy.interpolate import interp1d\n\n\nprint(\"Loading the Wav2Vec2 Processor...\")\n# wav2vec2_processor = Wav2Vec2Processor.from_pretrained(\"facebook/hubert-large-ls960-ft\")\nwav2vec2_processor = Wav2Vec2Processor.from_pretrained(\"/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp\")\nprint(\"Loading the HuBERT Model...\")\nhubert_model = HubertModel.from_pretrained(\"/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp\", from_tf = True)\n\ndef get_hubert_from_16k_wav(wav_16k_name):\n    speech_16k, _ = sf.read(wav_16k_name)\n    hubert = get_hubert_from_16k_speech(speech_16k)\n    return hubert\n\n@torch.no_grad()\ndef get_hubert_from_16k_speech(speech, device=\"cuda:3\"):\n    global hubert_model\n    hubert_model = hubert_model.to(device)\n    if speech.ndim ==2:\n        speech = speech[:, 0] # [T, 2] ==> [T,]\n    input_values_all = wav2vec2_processor(speech, return_tensors=\"pt\", sampling_rate=16000).input_values # [1, T]\n    input_values_all = input_values_all.to(device)\n    # For long audio sequence, due to the memory limitation, we cannot process them in one run\n    # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320\n    # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.\n    # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320\n    # We have the equation to calculate out time step: T = floor((t-k)/s)\n    # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip\n    # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N\n    kernel = 400\n    stride = 320\n    clip_length = stride * 1000\n    num_iter = input_values_all.shape[1] // clip_length\n    expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride\n    res_lst = []\n    for i in range(num_iter):\n        if i == 0:\n            start_idx = 0\n            end_idx = clip_length - stride + kernel\n        else:\n            start_idx = clip_length * i\n            end_idx = start_idx + (clip_length - stride + kernel)\n        input_values = input_values_all[:, start_idx: end_idx]\n        hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]\n        res_lst.append(hidden_states[0])\n    if num_iter > 0:\n        input_values = input_values_all[:, clip_length * num_iter:]\n    else:\n        input_values = input_values_all\n    # if input_values.shape[1] != 0:\n    if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it            \n        hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]\n        res_lst.append(hidden_states[0])\n    ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]\n    # assert ret.shape[0] == expected_T\n    assert abs(ret.shape[0] - expected_T) <= 1\n    if ret.shape[0] < expected_T:\n        ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0]))\n    else:\n        ret = ret[:expected_T]\n    return ret\n\n# decord.bridge.set_bridge('torch')\nfrom tqdm import tqdm\nif __name__ == '__main__':\n    ## Process Single Long Audio for NeRF dataset\n    # person_id = 'May'\n    import 
os\n\n    # demo test\n\n    # wav_16k_name = f\"/train20/intern/permanent/lmlin2/data/audio_lesson_01-j-w.wav\"\n    # npy_name = 'demo_test-j-w'\n    # demo_npy_name = f\"/train20/intern/permanent/lmlin2/data/{npy_name}.npy\"\n    # speech_16k, _ = sf.read(wav_16k_name)\n    # hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n    # np.save(demo_npy_name, hubert_hidden.detach().numpy())\n\n    # hdtf dataset\n    # image_path = '/train20/intern/permanent/lmlin2/data/hdtf_image_50hz'\n    # video_path = '/train20/intern/permanent/hbcheng2/data/HDTF/video_25hz'\n    # wav_path = '/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio'\n    # save_path = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert'\n    # interpolate_path = \"/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate\"\n\n    # if not os.path.exists(interpolate_path):\n    #     os.makedirs(interpolate_path)\n\n    # if not os.path.exists(save_path):\n    #     os.makedirs(save_path)\n\n    # for wavfile in tqdm(os.listdir(wav_path)):\n    #     video = os.path.join(video_path, wavfile[0:-4] + '.mp4')\n    #     vr = decord.VideoReader(video)\n    #     num_frames = len(vr)\n    #     wav_16k_name = os.path.join(wav_path, wavfile)\n    #     # wav_16k_name = f\"/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio/RD_Radio1_000.wav\"           #(3749, 1024)\n    #     # wav_16k_name = f\"data/processed/videos/{person_id}/aud.wav\"          \n    #     # wav_16k_name = f\"/train20/intern/permanent/lmlin2/Flow/GeneFace-main/data/raw/val_wavs/zozo.wav\"  # 543 1024\n    #     # wav_16k_name = f\"/train20/intern/permanent/lmlin2/data/audio_lesson_01.wav\"\n    #     npy_name = wavfile[0:-4]\n    #     hubert_npy_name = f\"{save_path}/{npy_name}.npy\"\n    #     hubert_npy_name_interpolate = f\"{interpolate_path}/{npy_name}.npy\"\n    #     if os.path.exists(hubert_npy_name) and os.path.exists(hubert_npy_name_interpolate):\n    #         continue\n    #     speech_16k, _ = sf.read(wav_16k_name)\n    #     hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n    #     print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}')\n\n    #     hubert_hidden = hubert_hidden.detach().numpy()\n    #     interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0)\n    #     hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32)\n    #     # torch.nn.functional.interpolate(hubert_hidden.unsqueeze(0).permute(0,2,1).cuda(), size=num_frames, mode='linear', align_corners=False).squeeze(0).permute(1, 0).cpu()\n    #     np.save(hubert_npy_name, hubert_hidden)\n    #     np.save(hubert_npy_name_interpolate, hubert_feature_interpolated)\n\n\n    # crema dataset\n    # image_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'\n    # wav_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/audio'\n    # save_path = '/train20/intern/permanent/hbcheng2/data/crema/hubert_25hz'\n    wav_path = '/train20/intern/permanent/hbcheng2/data/ood_video/audio_clip_2'\n    save_path = '/train20/intern/permanent/hbcheng2/data/ood_video/audio_clip_hubert'\n    # for id_name in tqdm(os.listdir(wav_path)):\n    for wavfile in os.listdir(wav_path):\n            # frame_dir = os.path.join(image_path, id_name, wavfile[0:-4])\n            # if not exists(frame_dir):\n            #     print(f'{frame_dir} does not exist!')\n            #     continue\n            # frames= os.listdir(frame_dir)\n        
    # frames.sort()\n            # num_frames = len(frames)\n            wav_16k_name = os.path.join(wav_path, wavfile)\n            npy_name = wavfile[0:-4]\n            save_dir = os.path.join(save_path)\n            hubert_npy_name = os.path.join(save_dir,npy_name+'.npy') \n            # if exists(hubert_npy_name):\n            #     print(f'{hubert_npy_name} exists!')\n            #     continue\n            if not exists(save_dir):\n                os.makedirs(save_dir)\n            \n            speech_16k, _ = sf.read(wav_16k_name)\n            num_frames = int((speech_16k.shape[0] / 16000) * 25)\n            hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n            hubert_hidden = hubert_hidden.detach().numpy()\n            interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0)\n            hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32)\n            # print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}')\n            np.save(hubert_npy_name, hubert_feature_interpolated)\n\n    # ## Process short audio clips for LRS3 dataset\n    # import glob, os, tqdm\n    # lrs3_dir = '/home/yezhenhui/datasets/raw/lrs3_raw/'\n    # wav_16k_names = glob.glob(os.path.join(lrs3_dir, '*/*.wav'))\n    # for wav_16k_name in tqdm.tqdm(wav_16k_names, total=len(wav_16k_names)):\n    #     spk_id = wav_16k_name.split(\"/\")[-2]\n    #     clip_id = wav_16k_name.split(\"/\")[-1][:-4]\n    #     out_name = os.path.join(lrs3_dir, spk_id, clip_id+'_hubert.npy')\n    #     if os.path.exists(out_name):\n    #         continue\n    #     speech_16k, _ = sf.read(wav_16k_name)\n    #     hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n    #     np.save(out_name, hubert_hidden.detach().numpy())"
  },
  {
    "path": "hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py",
    "content": "import os\nimport sys\nos.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:128\"\n# adding path of PBnet\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nparent_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))\nif parent_dir not in sys.path:\n    sys.path.append(parent_dir)\n    print(parent_dir)\n\nfrom genericpath import exists\nfrom transformers import Wav2Vec2Processor, HubertModel\nimport soundfile as sf\nimport numpy as np\nimport torch\nfrom scipy.interpolate import interp1d\nimport subprocess\nimport os\nfrom tqdm import tqdm\nimport tempfile\n\nprint(\"Loading the Wav2Vec2 Processor...\")\n# wav2vec2_processor = Wav2Vec2Processor.from_pretrained(\"facebook/hubert-large-ls960-ft\")\nwav2vec2_processor = Wav2Vec2Processor.from_pretrained(\"./pretrain_models/hubert_ckp\")\nprint(\"Loading the HuBERT Model...\")\nhubert_model = HubertModel.from_pretrained(\"./pretrain_models/hubert_ckp\", from_tf = True)\n\ndef get_hubert_from_16k_wav(wav_16k_name):\n    speech_16k, _ = sf.read(wav_16k_name)\n    hubert = get_hubert_from_16k_speech(speech_16k)\n    return hubert\n\n@torch.no_grad()\ndef get_hubert_from_16k_speech(speech, device=\"cuda:0\"):\n    global hubert_model\n    print(f\"当前显存占用: {torch.cuda.memory_allocated()} 字节\")\n    print(f\"显存缓存占用: {torch.cuda.memory_reserved()} 字节\")\n    torch.cuda.empty_cache()\n    # 强制重置 PyTorch 的 CUDA 分配器\n    torch.cuda.empty_cache()\n    torch.cuda.reset_peak_memory_stats()\n    \n    # 可选：手动设置较大的初始缓存大小\n    torch.cuda.set_per_process_memory_fraction(0.9)  # 允许使用90%的显存\n    # 在加载模型前先检查显存状态\n    print(f\"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.2f} MB\")\n    print(f\"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB\")\n    print(f\"Cached memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB\")\n    print(torch.cuda.memory_summary())\n    \n    hubert_model = hubert_model.to(device)\n    if speech.ndim ==2:\n        speech = speech[:, 0] # [T, 2] ==> [T,]\n    input_values_all = wav2vec2_processor(speech, return_tensors=\"pt\", sampling_rate=16000).input_values # [1, T]\n    input_values_all = input_values_all.to(device)\n    # For long audio sequence, due to the memory limitation, we cannot process them in one run\n    # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320\n    # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.\n    # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320\n    # We have the equation to calculate out time step: T = floor((t-k)/s)\n    # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip\n    # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N\n    kernel = 400\n    stride = 320\n    clip_length = stride * 1000\n    num_iter = input_values_all.shape[1] // clip_length\n    expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride\n    res_lst = []\n    for i in range(num_iter):\n        if i == 0:\n            start_idx = 0\n            end_idx = clip_length - stride + kernel\n        else:\n            start_idx = clip_length * i\n            end_idx = start_idx + (clip_length - stride + kernel)\n        input_values = input_values_all[:, start_idx: end_idx]\n        hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]\n        
        res_lst.append(hidden_states[0])\n    if num_iter > 0:\n        input_values = input_values_all[:, clip_length * num_iter:]\n    else:\n        input_values = input_values_all\n    # if input_values.shape[1] != 0:\n    if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it            \n        hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]\n        res_lst.append(hidden_states[0])\n    ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]\n    # assert ret.shape[0] == expected_T\n    assert abs(ret.shape[0] - expected_T) <= 1\n    if ret.shape[0] < expected_T:\n        ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0]))\n    else:\n        ret = ret[:expected_T]\n    return ret\n\nimport argparse\ndef get_arguments():\n    \"\"\"Parse all the arguments provided from the CLI.\n\n    Returns:\n      A list of parsed arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"Flow Diffusion\")\n    parser.add_argument(\"--src_audio_path\", default='/train20/intern/permanent/hbcheng2/data/test_speed/target_audio.wav')\n    parser.add_argument(\"--save_path\", default='your/path/DAWN-pytorch/ood_data',\n                        help=\"\")\n    \n    return parser.parse_args()\n\n\n\ndef convert_wav_to_16k(input_file, output_file):\n    command = [\n        'ffmpeg',\n        '-y',  # overwrite the temporary output file that NamedTemporaryFile has already created\n        '-i', input_file,\n        '-ar', '16000',\n        output_file\n    ]\n    subprocess.run(command)\n\ndef delete_file(file_path):\n    try:\n        os.remove(file_path)\n        print(f\"File {file_path} has been deleted successfully.\")\n    except FileNotFoundError:\n        print(f\"File {file_path} not found.\")\n    except PermissionError:\n        print(f\"Permission denied: Unable to delete {file_path}.\")\n    except Exception as e:\n        print(f\"Error occurred while deleting {file_path}: {e}\")\n\n\nif __name__ == '__main__':\n\n    args = get_arguments()\n    wav_path = args.src_audio_path\n    wav_16k_name = wav_path\n    npy_name = args.save_path\n\n    output_wav_path = tempfile.NamedTemporaryFile('w', suffix='.wav', dir='./')\n    convert_wav_to_16k(wav_path, output_wav_path.name)\n\n    speech_16k, _ = sf.read(output_wav_path.name)\n    delete_file(output_wav_path.name)\n\n    # speech_16k, _ = sf.read(wav_path)\n\n    num_frames = int((speech_16k.shape[0] / 16000) * 25)\n    hubert_hidden = get_hubert_from_16k_speech(speech_16k, device = 'cuda:0')\n    hubert_hidden = hubert_hidden.detach().numpy()\n    interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0)\n    hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32)\n    print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}')\n    np.save(npy_name, hubert_feature_interpolated)\n"
  },
  {
    "path": "hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_single.py",
    "content": "from genericpath import exists\nfrom transformers import Wav2Vec2Processor, HubertModel\nimport soundfile as sf\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\nimport fairseq\nimport decord\nfrom scipy.interpolate import interp1d\n\n\nprint(\"Loading the Wav2Vec2 Processor...\")\n# wav2vec2_processor = Wav2Vec2Processor.from_pretrained(\"facebook/hubert-large-ls960-ft\")\nwav2vec2_processor = Wav2Vec2Processor.from_pretrained(\"/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp\")\nprint(\"Loading the HuBERT Model...\")\nhubert_model = HubertModel.from_pretrained(\"/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp\", from_tf = True)\n\ndef get_hubert_from_16k_wav(wav_16k_name):\n    speech_16k, _ = sf.read(wav_16k_name)\n    hubert = get_hubert_from_16k_speech(speech_16k)\n    return hubert\n\n@torch.no_grad()\ndef get_hubert_from_16k_speech(speech, device=\"cuda:1\"):\n    global hubert_model\n    hubert_model = hubert_model.to(device)\n    if speech.ndim ==2:\n        speech = speech[:, 0] # [T, 2] ==> [T,]\n    input_values_all = wav2vec2_processor(speech, return_tensors=\"pt\", sampling_rate=16000).input_values # [1, T]\n    input_values_all = input_values_all.to(device)\n    # For long audio sequence, due to the memory limitation, we cannot process them in one run\n    # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320\n    # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.\n    # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320\n    # We have the equation to calculate out time step: T = floor((t-k)/s)\n    # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip\n    # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N\n    kernel = 400\n    stride = 320\n    clip_length = stride * 1000\n    num_iter = input_values_all.shape[1] // clip_length\n    expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride\n    res_lst = []\n    for i in range(num_iter):\n        if i == 0:\n            start_idx = 0\n            end_idx = clip_length - stride + kernel\n        else:\n            start_idx = clip_length * i\n            end_idx = start_idx + (clip_length - stride + kernel)\n        input_values = input_values_all[:, start_idx: end_idx]\n        hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]\n        res_lst.append(hidden_states[0])\n    if num_iter > 0:\n        input_values = input_values_all[:, clip_length * num_iter:]\n    else:\n        input_values = input_values_all\n    # if input_values.shape[1] != 0:\n    if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it            \n        hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]\n        res_lst.append(hidden_states[0])\n    ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]\n    # assert ret.shape[0] == expected_T\n    assert abs(ret.shape[0] - expected_T) <= 1\n    if ret.shape[0] < expected_T:\n        ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0]))\n    else:\n        ret = ret[:expected_T]\n    return ret\n\nfrom tqdm import tqdm\nif __name__ == '__main__':\n\n    wav_path = '/train20/intern/permanent/hbcheng2/data/ood_video/audio_clip/minanyu-xuelingsun_2.wav'\n    wav_16k_name = 
wav_path\n    hubert_npy_name = '/train20/intern/permanent/hbcheng2/data/ood_video/audio_clip_hubert/minanyu-xuelingsun_2.npy'# os.path.join(wav_path,npy_name+'.npy') \n\n    \n    speech_16k, _ = sf.read(wav_16k_name)\n    num_frames = int((speech_16k.shape[0] / 16000) * 25)\n    hubert_hidden = get_hubert_from_16k_speech(speech_16k)\n    hubert_hidden = hubert_hidden.detach().numpy()\n    interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0)\n    hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32)\n    print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}')\n    np.save(hubert_npy_name, hubert_feature_interpolated)"
  },
  {
    "path": "hubert_extract/data_gen/process_lrs3/process_audio_mel_f0.py",
    "content": "import numpy as np\nimport torch\nimport glob\nimport os\nimport tqdm\nimport librosa\nimport parselmouth\nfrom utils.commons.pitch_utils import f0_to_coarse\nfrom utils.commons.multiprocess_utils import multiprocess_run_tqdm\n\n\ndef librosa_pad_lr(x, fsize, fshift, pad_sides=1):\n    '''compute right padding (final frame) or both sides padding (first and final frames)\n    '''\n    assert pad_sides in (1, 2)\n    # return int(fsize // 2)\n    pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]\n    if pad_sides == 1:\n        return 0, pad\n    else:\n        return pad // 2, pad // 2 + pad % 2\n\ndef extract_mel_from_fname(wav_path,\n                      fft_size=512,\n                      hop_size=320,\n                      win_length=512,\n                      window=\"hann\",\n                      num_mels=80,\n                      fmin=80,\n                      fmax=7600,\n                      eps=1e-6,\n                      sample_rate=16000,\n                      min_level_db=-100):\n    if isinstance(wav_path, str):\n        wav, _ = librosa.core.load(wav_path, sr=sample_rate)\n    else:\n        wav = wav_path\n\n    # get amplitude spectrogram\n    x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,\n                          win_length=win_length, window=window, center=False)\n    spc = np.abs(x_stft)  # (n_bins, T)\n\n    # get mel basis\n    fmin = 0 if fmin == -1 else fmin\n    fmax = sample_rate / 2 if fmax == -1 else fmax\n    mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)\n    mel = mel_basis @ spc\n\n    mel = np.log10(np.maximum(eps, mel))  # (n_mel_bins, T)\n    mel = mel.T\n\n    l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1)\n    wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)\n\n    return wav.T, mel\n\ndef extract_f0_from_wav_and_mel(wav, mel,\n                        hop_size=320,\n                        audio_sample_rate=16000,\n                        ):\n    time_step = hop_size / audio_sample_rate * 1000\n    f0_min = 80\n    f0_max = 750\n    f0 = parselmouth.Sound(wav, audio_sample_rate).to_pitch_ac(\n        time_step=time_step / 1000, voicing_threshold=0.6,\n        pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']\n\n    delta_l = len(mel) - len(f0)\n    assert np.abs(delta_l) <= 8\n    if delta_l > 0:\n        f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)\n    f0 = f0[:len(mel)]\n    pitch_coarse = f0_to_coarse(f0)\n    return f0, pitch_coarse\n\ndef extract_mel_f0_from_fname(fname, out_name=None):\n    assert fname.endswith(\".wav\")\n    if out_name is None:\n        out_name = fname[:-4] + '_audio.npy'\n\n    wav, mel = extract_mel_from_fname(fname)\n    f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)\n    out_dict = {\n        \"mel\": mel, # [T, 80]\n        \"f0\": f0,\n    }\n    np.save(out_name, out_dict)\n    return True\n\nif __name__ == '__main__':\n    import os, glob\n    lrs3_dir = \"/home/yezhenhui/datasets/raw/lrs3_raw\"\n    wav_name_pattern = os.path.join(lrs3_dir, \"*/*.wav\")\n    wav_names = glob.glob(wav_name_pattern)\n    wav_names = sorted(wav_names)\n    for _ in multiprocess_run_tqdm(extract_mel_f0_from_fname, args=wav_names, num_workers=32,desc='extracting Mel and f0'):\n        pass"
  },
  {
    "path": "misc.py",
    "content": "import cv2\nimport os\n\nimport requests\nimport torch\nimport torch.nn.functional as F\nimport torch.distributed as dist\nimport sys\nimport matplotlib.pyplot as plt\nfrom matplotlib.collections import LineCollection\nimport numpy as np\nimport flow_vis\nimport cv2\n\n\ndef fig2data(fig):\n    \"\"\"\n    @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it\n    @param fig a matplotlib figure\n    @return a numpy 3D array of RGBA values\n    \"\"\"\n    # draw the renderer\n    fig.canvas.draw()\n\n    # Get the RGBA buffer from the figure\n    w, h = fig.canvas.get_width_height()\n    buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8)\n    buf.shape = (w, h, 4)\n\n    # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode\n    buf = np.roll(buf, 3, axis=2)\n    return buf\n\n\ndef plot_grid(x, y, ax=None, **kwargs):\n    ax = ax or plt.gca()\n    segs1 = np.stack((x, y), axis=2)\n    segs2 = segs1.transpose(1, 0, 2)\n    ax.add_collection(LineCollection(segs1, **kwargs))\n    ax.add_collection(LineCollection(segs2, **kwargs))\n    ax.autoscale()\n\n\ndef grid2fig(warped_grid, grid_size=32, img_size=256):\n    dpi = 1000\n    # plt.ioff()\n    h_range = torch.linspace(-1, 1, grid_size)\n    w_range = torch.linspace(-1, 1, grid_size)\n    grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).flip(2)\n    flow_uv = grid.cpu().data.numpy()\n    fig, ax = plt.subplots()\n    grid_x, grid_y = warped_grid[..., 0], warped_grid[..., 1]\n    plot_grid(flow_uv[..., 0], flow_uv[..., 1], ax=ax, color=\"lightgrey\")\n    plot_grid(grid_x, grid_y, ax=ax, color=\"C0\")\n    plt.axis(\"off\")\n    plt.tight_layout(pad=0)\n    fig.set_size_inches(img_size/100, img_size/100)\n    fig.set_dpi(100)\n    out = fig2data(fig)[:, :, :3]\n    out = np.flipud(out)\n    out = np.fliplr(out)\n    plt.close()\n    plt.cla()\n    plt.clf()\n    return out\n\n\ndef flow2fig(warped_grid, id_grid, grid_size=32, img_size=128):\n    # h_range = torch.linspace(-1, 1, grid_size)\n    # w_range = torch.linspace(-1, 1, grid_size)\n    # id_grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).flip(2)\n    warped_flow = warped_grid - id_grid\n    img = flow_vis.flow_to_color(warped_flow)\n    img = cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_AREA)\n\n    return img\n\n\ndef conf2fig(conf, img_size=128):\n    conf = F.interpolate(conf.unsqueeze(dim=0), size=img_size).data.cpu().numpy()\n    conf = np.transpose(conf, [0, 2, 3, 1])\n    conf = np.array(conf[0, :, :, 0]*255, dtype=np.uint8)\n    return conf\n\n\nclass Logger(object):\n    def __init__(self, filename='default.log', stream=sys.stdout):\n        self.terminal = stream\n        self.log = open(filename, 'w')\n\n    def write(self, message):\n        self.terminal.write(message)\n        self.log.write(message)\n\n    def flush(self):\n        pass\n\n\ndef resize(im, desired_size, interpolation):\n    old_size = im.shape[:2]\n    ratio = float(desired_size)/max(old_size)\n    new_size = tuple(int(x*ratio) for x in old_size)\n\n    im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation)\n    delta_w = desired_size - new_size[1]\n    delta_h = desired_size - new_size[0]\n    top, bottom = delta_h//2, delta_h-(delta_h//2)\n    left, right = delta_w//2, delta_w-(delta_w//2)\n\n    color = [0, 0, 0]\n    new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)\n\n    return 
new_im\n\n\ndef resample(image, flow):\n    r\"\"\"Resamples an image using the provided flow.\n\n    Args:\n        image (NxCxHxW tensor) : Image to resample.\n        flow (Nx2xHxW tensor) : Optical flow to resample the image.\n    Returns:\n        output (NxCxHxW tensor) : Resampled image.\n    \"\"\"\n    assert flow.shape[1] == 2\n    b, c, h, w = image.size()\n    grid = get_grid(b, (h, w))\n    flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0),\n                      flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1)\n    final_grid = (grid + flow).permute(0, 2, 3, 1)\n    try:\n        output = F.grid_sample(image, final_grid, mode='bilinear',\n                               padding_mode='border', align_corners=True)\n    except Exception:\n        output = F.grid_sample(image, final_grid, mode='bilinear',\n                               padding_mode='border')\n    return output\n\n\ndef get_grid(batchsize, size, minval=-1.0, maxval=1.0):\n    r\"\"\"Get a grid ranging [-1, 1] of 2D/3D coordinates.\n\n    Args:\n        batchsize (int) : Batch size.\n        size (tuple) : (height, width) or (depth, height, width).\n        minval (float) : minimum value in returned grid.\n        maxval (float) : maximum value in returned grid.\n    Returns:\n        t_grid (4D tensor) : Grid of coordinates.\n    \"\"\"\n    if len(size) == 2:\n        rows, cols = size\n    elif len(size) == 3:\n        deps, rows, cols = size\n    else:\n        raise ValueError('Dimension can only be 2 or 3.')\n    x = torch.linspace(minval, maxval, cols)\n    x = x.view(1, 1, 1, cols)\n    x = x.expand(batchsize, 1, rows, cols)\n\n    y = torch.linspace(minval, maxval, rows)\n    y = y.view(1, 1, rows, 1)\n    y = y.expand(batchsize, 1, rows, cols)\n\n    t_grid = torch.cat([x, y], dim=1)\n\n    if len(size) == 3:\n        z = torch.linspace(minval, maxval, deps)\n        z = z.view(1, 1, deps, 1, 1)\n        z = z.expand(batchsize, 1, deps, rows, cols)\n\n        t_grid = t_grid.unsqueeze(2).expand(batchsize, 2, deps, rows, cols)\n        t_grid = torch.cat([t_grid, z], dim=1)\n\n    t_grid.requires_grad = False\n    return t_grid.to('cuda')\n\n\ndef get_checkpoint(checkpoint_path, url=''):\n    r\"\"\"Get the checkpoint path. 
If it does not exist yet, download it from\n    the url.\n\n    Args:\n        checkpoint_path (str): Checkpoint path.\n        url (str): URL to download checkpoint.\n    Returns:\n        (str): Full checkpoint path.\n    \"\"\"\n    if 'TORCH_HOME' not in os.environ:\n        os.environ['TORCH_HOME'] = os.getcwd()\n    save_dir = os.path.join(os.environ['TORCH_HOME'], 'checkpoints')\n    os.makedirs(save_dir, exist_ok=True)\n    full_checkpoint_path = os.path.join(save_dir, checkpoint_path)\n    if not os.path.exists(full_checkpoint_path):\n        os.makedirs(os.path.dirname(full_checkpoint_path), exist_ok=True)\n        if is_master():\n            print('Download {}'.format(url))\n            download_file_from_google_drive(url, full_checkpoint_path)\n    if dist.is_available() and dist.is_initialized():\n        dist.barrier()\n    return full_checkpoint_path\n\n\ndef download_file_from_google_drive(file_id, destination):\n    r\"\"\"Download a file from the google drive by using the file ID.\n\n    Args:\n        file_id: Google drive file ID\n        destination: Path to save the file.\n\n    Returns:\n\n    \"\"\"\n    URL = \"https://docs.google.com/uc?export=download\"\n    session = requests.Session()\n    response = session.get(URL, params={'id': file_id}, stream=True)\n    token = get_confirm_token(response)\n    if token:\n        params = {'id': file_id, 'confirm': token}\n        response = session.get(URL, params=params, stream=True)\n    save_response_content(response, destination)\n\n\ndef get_confirm_token(response):\n    r\"\"\"Get confirm token\n\n    Args:\n        response: Check if the file exists.\n\n    Returns:\n\n    \"\"\"\n    for key, value in response.cookies.items():\n        if key.startswith('download_warning'):\n            return value\n    return None\n\n\ndef save_response_content(response, destination):\n    r\"\"\"Save response content\n\n    Args:\n        response:\n        destination: Path to save the file.\n\n    Returns:\n\n    \"\"\"\n    chunk_size = 32768\n    with open(destination, \"wb\") as f:\n        for chunk in response.iter_content(chunk_size):\n            if chunk:\n                f.write(chunk)\n\n\ndef get_rank():\n    r\"\"\"Get rank of the thread.\"\"\"\n    rank = 0\n    if dist.is_available():\n        if dist.is_initialized():\n            rank = dist.get_rank()\n    return rank\n\n\ndef is_master():\n    r\"\"\"check if current process is the master\"\"\"\n    return get_rank() == 0\n\n"
  },
  {
    "path": "requirements.txt",
    "content": "absl-py==2.0.0\naccelerate==1.0.1\naiofiles==23.2.1\nalbumentations==1.3.1\nannotated-types==0.7.0\nantlr4-python3-runtime==4.8\nanyio==4.5.2\nastunparse==1.6.3\naudioread==3.0.1\nav==11.0.0\nbeautifulsoup4==4.12.3\nbitarray==2.8.2\nboto3==1.28.78\nbotocore==1.31.78\ncachetools==4.2.4\ncertifi==2023.7.22\ncffi==1.16.0\ncharset-normalizer==3.2.0\nclick==8.1.7\ncmake==3.30.1\ncolorama==0.4.6\ncoloredlogs==15.0.1\ncontourpy==1.1.1\ncycler==0.12.1\nCython==3.0.5\ndecorator==4.4.2\ndecord==0.6.0\neinops==0.7.0\neinops-exts==0.0.4\nexceptiongroup==1.2.2\nfairseq==0.12.2\nfastapi==0.115.8\nffmpeg==1.4\nffmpeg-python==0.2.0\nffmpy==0.5.0\nfilelock==3.13.1\nflatbuffers==23.5.26\nflow-vis==0.1\nfonttools==4.44.0\nfsspec==2023.10.0\nfuture==1.0.0\ngast==0.4.0\ngdown==5.1.0\ngoogle-auth==2.32.0\ngoogle-auth-oauthlib==0.4.6\ngoogle-pasta==0.2.0\ngradio==4.44.1\ngradio_client==1.3.0\ngrpcio==1.59.2\nh11==0.14.0\nh5py==3.10.0\nhttpcore==1.0.7\nhttpx==0.28.1\nhuggingface-hub==0.28.1\nhumanfriendly==10.0\nhydra-core==1.0.7\nidna==3.4\nimageio==2.31.5\nimageio-ffmpeg==0.4.9\nimportlib-metadata==6.8.0\nimportlib-resources==6.1.0\nJinja2==3.1.4\njmespath==1.0.1\njoblib==1.3.2\njson-tricks==3.17.3\nkeras==2.11.0\nkiwisolver==1.4.5\nlazy_loader==0.3\nlibclang==16.0.6\nlibrosa==0.7.1\nlit==18.1.8\nllvmlite==0.41.0\nlpips==0.1.4\nlxml==4.9.3\nMarkdown==3.5.1\nmarkdown-it-py==3.0.0\nMarkupSafe==2.1.3\nmatplotlib==3.7.3\nmdurl==0.1.2\nmoviepy==1.0.3\nmpmath==1.3.0\nmsgpack==1.0.7\nnatsort==8.4.0\nnetworkx==3.1\nnumba==0.58.0\nnumpy==1.24.3\nnvidia-cublas-cu11==11.10.3.66\nnvidia-cublas-cu12==12.1.3.1\nnvidia-cuda-nvrtc-cu11==11.7.99\nnvidia-cuda-runtime-cu11==11.7.99\nnvidia-cudnn-cu11==8.5.0.96\nnvidia-nvjitlink-cu12==12.8.61\noauthlib==3.2.2\nomegaconf==2.0.5\nonnx==1.17.0\nonnxruntime==1.19.2\nopencv-contrib-python==4.8.0.76\nopencv-python==4.7.0.72\nopencv-python-headless==4.8.1.78\nopt-einsum==3.3.0\norjson==3.10.15\npackaging==23.2\npandas==2.0.3\nPillow==10.0.1\nplatformdirs==3.11.0\npooch==1.7.0\nportalocker==2.8.2\nproglog==0.1.10\nprotobuf==3.20.2\npsutil==6.1.1\npyasn1==0.5.0\npyasn1-modules==0.3.0\npycparser==2.21\npydantic==2.10.6\npydantic_core==2.27.2\npydub==0.25.1\nPygments==2.19.1\npyparsing==3.1.1\nPySocks==1.7.1\npyspng==0.1.1\npython-dateutil==2.8.2\npython-multipart==0.0.20\npython_speech_features==0.6\npytz==2023.3.post1\nPyWavelets==1.4.1\nPyYAML==6.0.1\nqudida==0.0.4\nregex==2023.10.3\nrequests==2.31.0\nrequests-oauthlib==1.3.1\nresampy==0.4.2\nrich==13.9.4\nrotary-embedding-torch==0.3.5\nrsa==4.9\nruff==0.9.6\ns3transfer==0.7.0\nsacrebleu==2.3.1\nsacremoses==0.1.1\nsafetensors==0.5.2\nscenedetect==0.5.1\nscikit-image==0.21.0\nscikit-learn==1.3.1\nscipy==1.9.1\nsemantic-version==2.10.0\nsentencepiece==0.1.99\nshellingham==1.5.4\nsix==1.16.0\nsniffio==1.3.1\nsoundfile==0.12.1\nsoupsieve==2.5\nsoxr==0.3.7\nstarlette==0.44.0\nsympy==1.13.1\nsync-batchnorm==0.0.1\ntabulate==0.9.0\ntensorboard==2.11.2\ntensorboard-data-server==0.6.1\ntensorboard-plugin-wit==1.8.1\ntensorboardX==2.6.2.2\ntensorflow==2.11.1\ntensorflow-estimator==2.11.0\ntensorflow-io-gcs-filesystem==0.34.0\ntermcolor==2.3.0\nthreadpoolctl==3.2.0\ntifffile==2023.7.10\ntokenizers==0.20.3\ntomlkit==0.12.0\ntorch==1.13.0\ntorchaudio==0.13.0\ntorchvision==0.14.0\ntqdm==4.66.1\ntransformers==4.46.3\ntriton==2.0.0\ntyper==0.15.1\ntyping_extensions==4.12.2\ntzdata==2023.3\nurllib3==2.2.3\nuvicorn==0.33.0\nvisualize==0.5.1\nwebsockets==12.0\nWerkzeug==3.0.1\nwrapt==1.15.0\nzipp==3.17.0\n"
  },
  {
    "path": "run_ood_test/run_DM_v0_df_test_128_both_pose_blink.sh",
    "content": "\ntest_name=ood_test_1009 # $(date +\"%Y-%m-%d_%H-%M-%S\")\ntime_tag=tmp1009 # $(date +\"%Y-%m-%d_%H-%M-%S\")\naudio_path=WRA_MarcoRubio_000.wav \nimage_path=real_female_1.jpeg\ncache_path=cache/$time_tag\naudio_emb_path=cache/target_audio.npy\nvideo_output_path=cache/\n\nconda activate 3DDFA\ncd extract_init_states\npython demo_pose_extract_2d_lmk_img.py \\\n    --input $image_path \\\n    --output $cache_path\n\ncd ..\nconda activate DAWN\n\npython ./hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n    --src_audio_path $audio_path \\\n    --save_path $audio_emb_path\n\npython ./PBnet/src/evaluate/tvae_eval_single_both_eye_pose.py \\\n    --audio_path  $audio_emb_path \\\n    --init_pose_blink $cache_path \\\n    --ckpt './pretrain_models/pbnet_both/checkpoint_100000.pth.tar' \\\n    --output $cache_path\n\npython ./DM_3/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_128_2.py --gpu 0  \\\n    --source_img_path $image_path \\\n    --init_state_path $cache_path \\\n    --drive_blink_path $cache_path/dri_blink.npy \\\n    --drive_pose_path $cache_path/dri_pose.npy \\\n    --audio_emb_path $audio_emb_path \\\n    --save_path $video_output_path/$test_name \\\n    --src_audio_path $audio_path\n"
  },
  {
    "path": "run_ood_test/run_DM_v0_df_test_128_separate_pose_blink.sh",
    "content": "\ntest_name=ood_test_1009 # $(date +\"%Y-%m-%d_%H-%M-%S\")\ntime_tag=tmp1009 # $(date +\"%Y-%m-%d_%H-%M-%S\")\naudio_path=WRA_MarcoRubio_000.wav \nimage_path=real_female_1.jpeg\ncache_path=cache/$time_tag\naudio_emb_path=cache/target_audio.npy\nvideo_output_path=cache/\n\nsource activate\nconda activate 3DDFA\ncd extract_init_states\npython demo_pose_extract_2d_lmk_img.py \\\n    --input ../$image_path \\\n    --output ../$cache_path\n\ncd ..\nconda activate DAWN\n\npython ./hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n    --src_audio_path $audio_path \\\n    --save_path $audio_emb_path\n\n# conda activate LFDM_a40\npython ./PBnet/src/evaluate/tvae_eval_single.py \\\n    --audio_path  $audio_emb_path \\\n    --init_pose_blink $cache_path \\\n    --output $cache_path \\\n    --ckpt_pose ./pretrain_models/pbnet_seperate/pose/checkpoint_40000.pth.tar \\\n    --ckpt_blink ./pretrain_models/pbnet_seperate/blink/checkpoint_95000.pth.tar \n\n\npython your_path/DM_3/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_128_2.py --gpu 0  \\\n    --source_img_path $image_path \\\n    --init_state_path $cache_path \\\n    --drive_blink_path $cache_path/dri_blink.npy \\\n    --drive_pose_path $cache_path/dri_pose.npy \\\n    --audio_emb_path $audio_emb_path \\\n    --save_path $video_output_path/$test_name \\\n    --src_audio_path $audio_path\n"
  },
  {
    "path": "run_ood_test/run_DM_v0_df_test_256.sh",
    "content": "\n\nsource /home4/intern/hbcheng2/.bashrc\n\n\ntest_name=ood_test_1006 # $(date +\"%Y-%m-%d_%H-%M-%S\")\ntime_tag=tmp #$(date +\"%Y-%m-%d_%H-%M-%S\")\naudio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip7.wav\nimage_path=your/path/DAWN-pytorch/ood_data/ood_select_3/test4.jpeg\ncache_path=your/path/DAWN-pytorch/ood_data_3/$time_tag\naudio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip7.npy\n\nconda activate 3DDFA\ncd /train20/intern/permanent/hbcheng2/AIGC_related/3DDFA_V2-master\npython /train20/intern/permanent/hbcheng2/AIGC_related/3DDFA_V2-master/demo_pose_extract_2d_lmk_img.py \\\n    --input $image_path \\\n    --output $cache_path\n\n# conda activate LFDM_chb\n# cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main\n# python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n#     --src_audio_path $audio_path \\\n#     --save_path $audio_emb_path\n\nconda activate LFDM_chb\ncd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master\npython /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \\\n    --audio_path  $audio_emb_path \\\n    --init_pose_blink $cache_path \\\n    --output $cache_path\n\ncd your/path/DAWN-pytorch\n# source /home4/intern/hbcheng2/.bashrc\n\n# echo 'finish extracting init state'\npython your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0  \\\n    --source_img_path $image_path \\\n    --init_state_path $cache_path \\\n    --drive_blink_path $cache_path/dri_blink.npy \\\n    --drive_pose_path $cache_path/dri_pose.npy \\\n    --audio_emb_path $audio_emb_path \\\n    --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \\\n    --src_audio_path $audio_path\n\n# audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip1.wav\n# # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png\n# # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name\n# audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip1.npy\n\n\n# # conda activate LFDM_chb\n# # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main\n# # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n# #     --src_audio_path $audio_path \\\n# #     --save_path $audio_emb_path\n\n\n# cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master\n# python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \\\n#     --audio_path  $audio_emb_path \\\n#     --init_pose_blink $cache_path \\\n#     --output $cache_path\n\n# cd your/path/DAWN-pytorch\n# # source /home4/intern/hbcheng2/.bashrc\n# # conda activate LFDM_a40\n# # echo 'finish extracting init state'\n\n# python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0  \\\n#     --source_img_path $image_path \\\n#     --init_state_path $cache_path \\\n#     --drive_blink_path $cache_path/dri_blink.npy \\\n#     --drive_pose_path $cache_path/dri_pose.npy \\\n#     --audio_emb_path $audio_emb_path \\\n#     --save_path 
/train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \\\n#     --src_audio_path $audio_path\n\n# audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip2.wav\n# # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png\n# # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name\n# audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip2.npy\n\n# # conda activate LFDM_chb\n# # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main\n# # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n# #     --src_audio_path $audio_path \\\n# #     --save_path $audio_emb_path\n\n\n\n# cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master\n# python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \\\n#     --audio_path  $audio_emb_path \\\n#     --init_pose_blink $cache_path \\\n#     --output $cache_path\n\n# cd your/path/DAWN-pytorch\n# # source /home4/intern/hbcheng2/.bashrc\n# # conda activate LFDM_a40\n# # echo 'finish extracting init state'\n\n# python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0  \\\n#     --source_img_path $image_path \\\n#     --init_state_path $cache_path \\\n#     --drive_blink_path $cache_path/dri_blink.npy \\\n#     --drive_pose_path $cache_path/dri_pose.npy \\\n#     --audio_emb_path $audio_emb_path \\\n#     --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \\\n#     --src_audio_path $audio_path\n\n# audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip3.wav\n# # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png\n# # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name\n# audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip3.npy\n\n\n# # conda activate LFDM_chb\n# # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main\n# # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n# #     --src_audio_path $audio_path \\\n# #     --save_path $audio_emb_path\n\n\n# cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master\n# python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \\\n#     --audio_path  $audio_emb_path \\\n#     --init_pose_blink $cache_path \\\n#     --output $cache_path\n\n# cd your/path/DAWN-pytorch\n# # source /home4/intern/hbcheng2/.bashrc\n# # conda activate LFDM_a40\n# # echo 'finish extracting init state'\n# python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0  \\\n#     --source_img_path $image_path \\\n#     --init_state_path $cache_path \\\n#     --drive_blink_path $cache_path/dri_blink.npy \\\n#     --drive_pose_path $cache_path/dri_pose.npy \\\n#     --audio_emb_path $audio_emb_path \\\n#     --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \\\n#     --src_audio_path $audio_path\n\n# audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip4.wav\n# # 
image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png\n# # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name\n# audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip4.npy\n\n\n# # conda activate LFDM_chb\n# # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main\n# # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n# #     --src_audio_path $audio_path \\\n# #     --save_path $audio_emb_path\n\n\n# cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master\n# python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \\\n#     --audio_path  $audio_emb_path \\\n#     --init_pose_blink $cache_path \\\n#     --output $cache_path\n\n# cd your/path/DAWN-pytorch\n# # source /home4/intern/hbcheng2/.bashrc\n# # conda activate LFDM_a40\n# # echo 'finish extracting init state'\n# python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0  \\\n#     --source_img_path $image_path \\\n#     --init_state_path $cache_path \\\n#     --drive_blink_path $cache_path/dri_blink.npy \\\n#     --drive_pose_path $cache_path/dri_pose.npy \\\n#     --audio_emb_path $audio_emb_path \\\n#     --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \\\n#     --src_audio_path $audio_path\n\n# audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip5.wav\n# # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png\n# # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name\n# audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip5.npy\n\n\n# # conda activate LFDM_chb\n# # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main\n# # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n# #     --src_audio_path $audio_path \\\n# #     --save_path $audio_emb_path\n\n\n# cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master\n# python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \\\n#     --audio_path  $audio_emb_path \\\n#     --init_pose_blink $cache_path \\\n#     --output $cache_path\n\n# cd your/path/DAWN-pytorch\n# # source /home4/intern/hbcheng2/.bashrc\n# # conda activate LFDM_a40\n# # echo 'finish extracting init state'\n\n# python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0  \\\n#     --source_img_path $image_path \\\n#     --init_state_path $cache_path \\\n#     --drive_blink_path $cache_path/dri_blink.npy \\\n#     --drive_pose_path $cache_path/dri_pose.npy \\\n#     --audio_emb_path $audio_emb_path \\\n#     --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \\\n#     --src_audio_path $audio_path\n\n# audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip6.wav\n# # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png\n# # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name\n# 
audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip6.npy\n\n\n\n# # conda activate LFDM_chb\n# # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main\n# # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n# #     --src_audio_path $audio_path \\\n# #     --save_path $audio_emb_path\n\n\n# cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master\n# python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \\\n#     --audio_path  $audio_emb_path \\\n#     --init_pose_blink $cache_path \\\n#     --output $cache_path\n\n# cd your/path/DAWN-pytorch\n# # source /home4/intern/hbcheng2/.bashrc\n# # conda activate LFDM_a40\n# # echo 'finish extracting init state'\n# python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0  \\\n#     --source_img_path $image_path \\\n#     --init_state_path $cache_path \\\n#     --drive_blink_path $cache_path/dri_blink.npy \\\n#     --drive_pose_path $cache_path/dri_pose.npy \\\n#     --audio_emb_path $audio_emb_path \\\n#     --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \\\n#     --src_audio_path $audio_path\n\n\n\n# audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip0.wav\n# # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png\n# # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name\n# audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip0.npy\n\n\n\n# # conda activate LFDM_chb\n# # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main\n# # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n# #     --src_audio_path $audio_path \\\n# #     --save_path $audio_emb_path\n\n\n# cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master\n# python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \\\n#     --audio_path  $audio_emb_path \\\n#     --init_pose_blink $cache_path \\\n#     --output $cache_path\n\n# cd your/path/DAWN-pytorch\n# # source /home4/intern/hbcheng2/.bashrc\n# # conda activate LFDM_a40\n# # echo 'finish extracting init state'\n\n# python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0  \\\n#     --source_img_path $image_path \\\n#     --init_state_path $cache_path \\\n#     --drive_blink_path $cache_path/dri_blink.npy \\\n#     --drive_pose_path $cache_path/dri_pose.npy \\\n#     --audio_emb_path $audio_emb_path \\\n#     --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \\\n#     --src_audio_path $audio_path\n\n\n\n"
  },
  {
    "path": "run_ood_test/run_DM_v0_df_test_256_1.sh",
    "content": "test_name=ood_test_1009 # $(date +\"%Y-%m-%d_%H-%M-%S\")\ntime_tag=tmp1009 # $(date +\"%Y-%m-%d_%H-%M-%S\")\naudio_path=WRA_MarcoRubio_000.wav \nimage_path=real_female_1.jpeg\ncache_path=cache/$time_tag\naudio_emb_path=cache/target_audio.npy\nvideo_output_path=cache/\n\nconda activate 3DDFA\ncd extract_init_states\npython demo_pose_extract_2d_lmk_img.py \\\n    --input ../$image_path \\\n    --output ../$cache_path\n\ncd ..\nconda activate DAWN\n\npython ./hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n    --src_audio_path $audio_path \\\n    --save_path $audio_emb_path\n\npython ./PBnet/src/evaluate/tvae_eval_single_both_eye_pose.py \\\n    --audio_path  $audio_emb_path \\\n    --init_pose_blink $cache_path \\\n    --ckpt './pretrain_models/pbnet_both/checkpoint_100000.pth.tar' \\\n    --output $cache_path\n\npython ./DM_3/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0  \\\n    --source_img_path $image_path \\\n    --init_state_path $cache_path \\\n    --drive_blink_path $cache_path/dri_blink.npy \\\n    --drive_pose_path $cache_path/dri_pose.npy \\\n    --audio_emb_path $audio_emb_path \\\n    --save_path $video_output_path/$test_name \\\n    --src_audio_path $audio_path\n\n"
  },
  {
    "path": "run_ood_test/run_DM_v0_df_test_256_1_separate_pose_blink.sh",
    "content": "test_name=ood_test_1009 # $(date +\"%Y-%m-%d_%H-%M-%S\")\ntime_tag=tmp1009 # $(date +\"%Y-%m-%d_%H-%M-%S\")\naudio_path=WRA_MarcoRubio_000.wav \nimage_path=real_female_1.jpeg\ncache_path=cache/$time_tag\naudio_emb_path=cache/target_audio.npy\nvideo_output_path=cache/\n\nsource activate\n# conda activate 3DDFA\n# cd extract_init_states\n# python demo_pose_extract_2d_lmk_img.py \\\n#     --input ../$image_path \\\n#     --output ../$cache_path\n\n# cd ..\nconda activate DAWN\n\npython ./hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \\\n    --src_audio_path $audio_path \\\n    --save_path $audio_emb_path\n\n# python ./PBnet/src/evaluate/tvae_eval_single.py \\\n#     --audio_path  $audio_emb_path \\\n#     --init_pose_blink $cache_path \\\n#     --output $cache_path \\\n#     --ckpt_pose ./pretrain_models/pbnet_seperate/pose/checkpoint_40000.pth.tar \\\n#     --ckpt_blink ./pretrain_models/pbnet_seperate/blink/checkpoint_95000.pth.tar \n\n\n\n# python ./DM_3/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0  \\\n#     --source_img_path $image_path \\\n#     --init_state_path $cache_path \\\n#     --drive_blink_path $cache_path/dri_blink.npy \\\n#     --drive_pose_path $cache_path/dri_pose.npy \\\n#     --audio_emb_path $audio_emb_path \\\n#     --save_path $video_output_path/$test_name \\\n#     --src_audio_path $audio_path\n\n"
  },
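  {
    "path": "run_ood_test/run_unified_demo.py",
    "content": "# Hypothetical usage sketch, not part of the original release: drives the single-process\n# pipeline in unified_video_generator.py (3DDFA init-state extraction, HuBERT features,\n# PBnet pose/blink generation, diffusion video synthesis) programmatically instead of\n# through the step-by-step shell scripts above. The input file names are placeholders and\n# the fields mirror the argparse options of unified_video_generator.py. Run it from the\n# repository root inside the DAWN environment.\nimport sys\nsys.path.append('.')\n\nfrom argparse import Namespace\n\nfrom unified_video_generator import VideoGenerator\n\nif __name__ == '__main__':\n    args = Namespace(\n        audio_path='WRA_MarcoRubio_000.wav',   # driving speech (wav)\n        image_path='real_female_1.jpeg',       # source portrait\n        output_path='output',                  # where the rendered video is written\n        cache_path='cache/tmp',                # intermediate npy/wav files\n        resolution=128,                        # selects config/DAWN_128.yaml\n    )\n    VideoGenerator(args).run()\n"
  },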
  {
    "path": "sync_batchnorm/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : __init__.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nfrom .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d\nfrom .replicate import DataParallelWithCallback, patch_replication_callback\n"
  },
  {
    "path": "sync_batchnorm/batchnorm.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : batchnorm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport collections\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast\n\nfrom .comm import SyncMaster\n\n__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']\n\n\ndef _sum_ft(tensor):\n    \"\"\"sum over the first and last dimention\"\"\"\n    return tensor.sum(dim=0).sum(dim=-1)\n\n\ndef _unsqueeze_ft(tensor):\n    \"\"\"add new dementions at the front and the tail\"\"\"\n    return tensor.unsqueeze(0).unsqueeze(-1)\n\n\n_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])\n_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])\n\n\nclass _SynchronizedBatchNorm(_BatchNorm):\n    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):\n        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)\n\n        self._sync_master = SyncMaster(self._data_parallel_master)\n\n        self._is_parallel = False\n        self._parallel_id = None\n        self._slave_pipe = None\n\n    def forward(self, input):\n        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.\n        if not (self._is_parallel and self.training):\n            return F.batch_norm(\n                input, self.running_mean, self.running_var, self.weight, self.bias,\n                self.training, self.momentum, self.eps)\n\n        # Resize the input to (B, C, -1).\n        input_shape = input.size()\n        input = input.view(input.size(0), self.num_features, -1)\n\n        # Compute the sum and square-sum.\n        sum_size = input.size(0) * input.size(2)\n        input_sum = _sum_ft(input)\n        input_ssum = _sum_ft(input ** 2)\n\n        # Reduce-and-broadcast the statistics.\n        if self._parallel_id == 0:\n            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))\n        else:\n            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))\n\n        # Compute the output.\n        if self.affine:\n            # MJY:: Fuse the multiplication for speed.\n            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)\n        else:\n            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)\n\n        # Reshape it.\n        return output.view(input_shape)\n\n    def __data_parallel_replicate__(self, ctx, copy_id):\n        self._is_parallel = True\n        self._parallel_id = copy_id\n\n        # parallel_id == 0 means master device.\n        if self._parallel_id == 0:\n            ctx.sync_master = self._sync_master\n        else:\n            self._slave_pipe = ctx.sync_master.register_slave(copy_id)\n\n    def _data_parallel_master(self, intermediates):\n        \"\"\"Reduce the sum and square-sum, compute the statistics, and broadcast it.\"\"\"\n\n        # Always using same \"device order\" makes the ReduceAdd operation faster.\n        # Thanks to:: Tete Xiao (http://tetexiao.com/)\n        intermediates = sorted(intermediates, key=lambda i: 
i[1].sum.get_device())\n\n        to_reduce = [i[1][:2] for i in intermediates]\n        to_reduce = [j for i in to_reduce for j in i]  # flatten\n        target_gpus = [i[1].sum.get_device() for i in intermediates]\n\n        sum_size = sum([i[1].sum_size for i in intermediates])\n        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)\n        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)\n\n        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)\n\n        outputs = []\n        for i, rec in enumerate(intermediates):\n            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))\n\n        return outputs\n\n    def _compute_mean_std(self, sum_, ssum, size):\n        \"\"\"Compute the mean and standard-deviation with sum and square-sum. This method\n        also maintains the moving average on the master device.\"\"\"\n        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'\n        mean = sum_ / size\n        sumvar = ssum - sum_ * mean\n        unbias_var = sumvar / (size - 1)\n        bias_var = sumvar / size\n\n        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data\n        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data\n\n        return mean, bias_var.clamp(self.eps) ** -0.5\n\n\nclass SynchronizedBatchNorm1d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a\n    mini-batch.\n\n    .. math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm1d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of size\n            `batch_size x num_features [x width]`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. 
Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C)` or :math:`(N, C, L)`\n        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm1d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm1d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 2 and input.dim() != 3:\n            raise ValueError('expected 2D or 3D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)\n\n\nclass SynchronizedBatchNorm2d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Batch Normalization over a 4d input that is seen as a mini-batch\n    of 3d inputs\n\n    .. math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm2d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of\n            size batch_size x num_features x height x width\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. 
Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, H, W)`\n        - Output: :math:`(N, C, H, W)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm2d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm2d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 4:\n            raise ValueError('expected 4D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)\n\n\nclass SynchronizedBatchNorm3d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Batch Normalization over a 5d input that is seen as a mini-batch\n    of 4d inputs\n\n    .. math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm3d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm\n    or Spatio-temporal BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of\n            size batch_size x num_features x depth x height x width\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, D, H, W)`\n        - Output: :math:`(N, C, D, H, W)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm3d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm3d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 5:\n            raise ValueError('expected 5D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)\n"
  },
  {
    "path": "sync_batchnorm/comm.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : comm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport queue\nimport collections\nimport threading\n\n__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']\n\n\nclass FutureResult(object):\n    \"\"\"A thread-safe future implementation. Used only as one-to-one pipe.\"\"\"\n\n    def __init__(self):\n        self._result = None\n        self._lock = threading.Lock()\n        self._cond = threading.Condition(self._lock)\n\n    def put(self, result):\n        with self._lock:\n            assert self._result is None, 'Previous result has\\'t been fetched.'\n            self._result = result\n            self._cond.notify()\n\n    def get(self):\n        with self._lock:\n            if self._result is None:\n                self._cond.wait()\n\n            res = self._result\n            self._result = None\n            return res\n\n\n_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])\n_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])\n\n\nclass SlavePipe(_SlavePipeBase):\n    \"\"\"Pipe for master-slave communication.\"\"\"\n\n    def run_slave(self, msg):\n        self.queue.put((self.identifier, msg))\n        ret = self.result.get()\n        self.queue.put(True)\n        return ret\n\n\nclass SyncMaster(object):\n    \"\"\"An abstract `SyncMaster` object.\n\n    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should\n    call `register(id)` and obtain an `SlavePipe` to communicate with the master.\n    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,\n    and passed to a registered callback.\n    - After receiving the messages, the master device should gather the information and determine to message passed\n    back to each slave devices.\n    \"\"\"\n\n    def __init__(self, master_callback):\n        \"\"\"\n\n        Args:\n            master_callback: a callback to be invoked after having collected messages from slave devices.\n        \"\"\"\n        self._master_callback = master_callback\n        self._queue = queue.Queue()\n        self._registry = collections.OrderedDict()\n        self._activated = False\n\n    def __getstate__(self):\n        return {'master_callback': self._master_callback}\n\n    def __setstate__(self, state):\n        self.__init__(state['master_callback'])\n\n    def register_slave(self, identifier):\n        \"\"\"\n        Register an slave device.\n\n        Args:\n            identifier: an identifier, usually is the device id.\n\n        Returns: a `SlavePipe` object which can be used to communicate with the master device.\n\n        \"\"\"\n        if self._activated:\n            assert self._queue.empty(), 'Queue is not clean before next initialization.'\n            self._activated = False\n            self._registry.clear()\n        future = FutureResult()\n        self._registry[identifier] = _MasterRegistry(future)\n        return SlavePipe(identifier, self._queue, future)\n\n    def run_master(self, master_msg):\n        \"\"\"\n        Main entry for the master device in each forward pass.\n        The messages were first collected from each devices (including the master device), and then\n        an callback will 
be invoked to compute the message to be sent back to each devices\n        (including the master device).\n\n        Args:\n            master_msg: the message that the master want to send to itself. This will be placed as the first\n            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.\n\n        Returns: the message to be sent back to the master device.\n\n        \"\"\"\n        self._activated = True\n\n        intermediates = [(0, master_msg)]\n        for i in range(self.nr_slaves):\n            intermediates.append(self._queue.get())\n\n        results = self._master_callback(intermediates)\n        assert results[0][0] == 0, 'The first result should belongs to the master.'\n\n        for i, res in results:\n            if i == 0:\n                continue\n            self._registry[i].result.put(res)\n\n        for i in range(self.nr_slaves):\n            assert self._queue.get() is True\n\n        return results[0][1]\n\n    @property\n    def nr_slaves(self):\n        return len(self._registry)\n"
  },
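  {
    "path": "sync_batchnorm/example_syncmaster.py",
    "content": "# -*- coding: utf-8 -*-\n# Hypothetical illustration, not part of the upstream Synchronized-BatchNorm-PyTorch\n# package: a minimal sketch of the SyncMaster / SlavePipe handshake defined in comm.py,\n# using one master and one slave thread in place of GPU replicas. The callback simply\n# sums the messages, standing in for the mean / inv_std reduction performed by\n# _SynchronizedBatchNorm._data_parallel_master. Run from the repository root, e.g.\n# python -m sync_batchnorm.example_syncmaster\nimport threading\n\nfrom sync_batchnorm.comm import SyncMaster\n\n\ndef master_callback(intermediates):\n    # intermediates is a list of (identifier, message) pairs; reply to every device\n    # with the sum, keeping the master's entry (identifier 0) first.\n    total = sum(msg for _, msg in intermediates)\n    return [(identifier, total) for identifier, _ in intermediates]\n\n\nif __name__ == '__main__':\n    master = SyncMaster(master_callback)\n    pipe = master.register_slave(1)  # slaves register before the forward pass\n\n    results = {}\n\n    def slave():\n        # run_slave blocks until the master has collected every message and replied\n        results['slave'] = pipe.run_slave(2)\n\n    worker = threading.Thread(target=slave)\n    worker.start()\n    results['master'] = master.run_master(1)  # gathers 1 (master) + 2 (slave)\n    worker.join()\n    print(results)  # both sides receive 3\n"
  },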
  {
    "path": "sync_batchnorm/replicate.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : replicate.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport functools\n\nfrom torch.nn.parallel.data_parallel import DataParallel\n\n__all__ = [\n    'CallbackContext',\n    'execute_replication_callbacks',\n    'DataParallelWithCallback',\n    'patch_replication_callback'\n]\n\n\nclass CallbackContext(object):\n    pass\n\n\ndef execute_replication_callbacks(modules):\n    \"\"\"\n    Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.\n\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Note that, as all modules are isomorphism, we assign each sub-module with a context\n    (shared among multiple copies of this module on different devices).\n    Through this context, different copies can share some information.\n\n    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback\n    of any slave copies.\n    \"\"\"\n    master_copy = modules[0]\n    nr_modules = len(list(master_copy.modules()))\n    ctxs = [CallbackContext() for _ in range(nr_modules)]\n\n    for i, module in enumerate(modules):\n        for j, m in enumerate(module.modules()):\n            if hasattr(m, '__data_parallel_replicate__'):\n                m.__data_parallel_replicate__(ctxs[j], i)\n\n\nclass DataParallelWithCallback(DataParallel):\n    \"\"\"\n    Data Parallel with a replication callback.\n\n    An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by\n    original `replicate` function.\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n        # sync_bn.__data_parallel_replicate__ will be invoked.\n    \"\"\"\n\n    def replicate(self, module, device_ids):\n        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n    def update_num_frames(self, new_num_frames): \n        \n        self.unet.update_num_frames(new_num_frames)\n        self.gaussian_diffusion.update_num_frames(new_num_frames)\n\n\ndef patch_replication_callback(data_parallel):\n    \"\"\"\n    Monkey-patch an existing `DataParallel` object. Add the replication callback.\n    Useful when you have customized `DataParallel` implementation.\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])\n        > patch_replication_callback(sync_bn)\n        # this is equivalent to\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n    \"\"\"\n\n    assert isinstance(data_parallel, DataParallel)\n\n    old_replicate = data_parallel.replicate\n\n    @functools.wraps(old_replicate)\n    def new_replicate(module, device_ids):\n        modules = old_replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n    data_parallel.replicate = new_replicate\n"
  },
  {
    "path": "sync_batchnorm/replicate_ddp.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : replicate.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport functools\n\nfrom torch.nn.parallel.data_parallel import DataParallel\nfrom torch.nn.parallel import DistributedDataParallel\n\n__all__ = [\n    'CallbackContext',\n    'execute_replication_callbacks',\n    'DataParallelWithCallback_ddp',\n    'patch_replication_callback_ddp'\n]\n\n\nclass CallbackContext(object):\n    pass\n\n\ndef execute_replication_callbacks(modules):\n    \"\"\"\n    Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.\n\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Note that, as all modules are isomorphism, we assign each sub-module with a context\n    (shared among multiple copies of this module on different devices).\n    Through this context, different copies can share some information.\n\n    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback\n    of any slave copies.\n    \"\"\"\n    master_copy = modules[0]\n    nr_modules = len(list(master_copy.modules()))\n    ctxs = [CallbackContext() for _ in range(nr_modules)]\n\n    for i, module in enumerate(modules):\n        for j, m in enumerate(module.modules()):\n            if hasattr(m, '__data_parallel_replicate__'):\n                m.__data_parallel_replicate__(ctxs[j], i)\n\n\nclass DataParallelWithCallback_ddp(DistributedDataParallel):\n    \"\"\"\n    Data Parallel with a replication callback.\n\n    An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by\n    original `replicate` function.\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n        # sync_bn.__data_parallel_replicate__ will be invoked.\n    \"\"\"\n\n    def replicate(self, module, device_ids):\n        modules = super(DataParallelWithCallback_ddp, self).replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n    def update_num_frames(self, new_num_frames): \n        \n        self.unet.update_num_frames(new_num_frames)\n        self.gaussian_diffusion.update_num_frames(new_num_frames)\n\n\ndef patch_replication_callback_ddp(data_parallel):\n    \"\"\"\n    Monkey-patch an existing `DataParallel` object. 
Add the replication callback.\n    Useful when you have customized `DataParallel` implementation.\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])\n        > patch_replication_callback(sync_bn)\n        # this is equivalent to\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n    \"\"\"\n\n    assert isinstance(data_parallel, DistributedDataParallel)\n\n    old_replicate = data_parallel.replicate\n\n    @functools.wraps(old_replicate)\n    def new_replicate(module, device_ids):\n        modules = old_replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n    data_parallel.replicate = new_replicate\n"
  },
  {
    "path": "sync_batchnorm/unittest.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : unittest.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport unittest\n\nimport numpy as np\nfrom torch.autograd import Variable\n\n\ndef as_numpy(v):\n    if isinstance(v, Variable):\n        v = v.data\n    return v.cpu().numpy()\n\n\nclass TorchTestCase(unittest.TestCase):\n    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):\n        npa, npb = as_numpy(a), as_numpy(b)\n        self.assertTrue(\n                np.allclose(npa, npb, atol=atol),\n                'Tensor close check failed\\n{}\\n{}\\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())\n        )\n"
  },
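  {
    "path": "tools/hubert_to_video_rate_sketch.py",
    "content": "# Hypothetical standalone sketch, not shipped with the repo: it illustrates how\n# VideoGenerator.process_audio in unified_video_generator.py resamples HuBERT features\n# (one vector per ~20 ms of 16 kHz audio) to the 25 fps video frame rate with linear\n# interpolation. The 25 fps assumption and the shapes follow the main implementation;\n# the random array below merely stands in for real HuBERT output.\nimport numpy as np\nfrom scipy.interpolate import interp1d\n\n\ndef resample_to_video_rate(hubert_feats, num_audio_samples, sr=16000, fps=25):\n    \"\"\"Linearly interpolate (T_hubert, D) features to one row per video frame.\"\"\"\n    num_frames = int(num_audio_samples / sr * fps)\n    interp = interp1d(np.arange(hubert_feats.shape[0]), hubert_feats, kind='linear', axis=0)\n    return interp(np.linspace(0, hubert_feats.shape[0] - 1, num_frames)).astype(np.float32)\n\n\nif __name__ == '__main__':\n    fake_hubert = np.random.randn(499, 1024).astype(np.float32)  # roughly 10 s of audio\n    feats = resample_to_video_rate(fake_hubert, num_audio_samples=10 * 16000)\n    print(feats.shape)  # (250, 1024): 25 fps * 10 s\n"
  },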
  {
    "path": "unified_video_generator.py",
    "content": "import os\r\nimport os.path as osp\r\nimport argparse\r\nfrom pathlib import Path\r\nimport subprocess\r\nimport sys\r\nsys.path.append('.')\r\n\r\nimport os\r\nimport cv2\r\nimport yaml\r\nimport tempfile\r\nimport numpy as np\r\nimport torch\r\n\r\nimport soundfile as sf\r\nfrom scipy.interpolate import interp1d\r\nfrom extract_init_states.FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX\r\nfrom extract_init_states.TDDFA_ONNX import TDDFA_ONNX\r\nfrom extract_init_states.utils.pose import get_pose\r\nfrom extract_init_states.utils.functions import calculate_eye, calculate_bbox\r\n\r\nfrom transformers import AutoProcessor, HubertModel\r\n\r\nfrom PBnet.src.models.get_model import get_model as get_gen_model\r\nfrom PIL import Image\r\nfrom torchvision import transforms\r\n\r\nfrom pydub import AudioSegment\r\n\r\ndef inv_transform(x, min_vals, max_vals):\r\n    return x * (max_vals - min_vals) + min_vals\r\n\r\ndef load_args(filename):\r\n    with open(filename, \"rb\") as optfile:\r\n        opt = yaml.load(optfile, Loader=yaml.Loader)\r\n    return opt\r\n\r\nclass VideoGenerator:\r\n    def __init__(self, args):\r\n        self.audio_path = args.audio_path\r\n        self.image_path = args.image_path\r\n        self.output_path = args.output_path\r\n        self.cache_path = args.cache_path\r\n\r\n        self.resolution = args.resolution\r\n        \r\n        # Ensure output directories exist\r\n        os.makedirs(self.cache_path, exist_ok=True)\r\n        os.makedirs(self.output_path, exist_ok=True)\r\n        \r\n        # Set intermediate file paths\r\n        self.audio_emb_path = os.path.join(self.cache_path, 'target_audio.npy')\r\n\r\n        # Set ONNX runtime environment for 3DDFA\r\n        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'\r\n        os.environ['OMP_NUM_THREADS'] = '8'\r\n        \r\n        # Initialize configuration\r\n        self.config_path = './extract_init_states/configs/mb1_120x120.yml'\r\n        self.cfg = yaml.load(open(self.config_path), Loader=yaml.SafeLoader)\r\n        \r\n        # Initialize models\r\n        self.face_boxes = FaceBoxes_ONNX()\r\n        self.tddfa = TDDFA_ONNX(**self.cfg)\r\n\r\n        # HuBERT model configuration\r\n        print(\"Loading the Wav2Vec2 Processor...\")\r\n        self.wav2vec2_processor = AutoProcessor.from_pretrained(\"./pretrain_models/hubert-large-ls960-ft\")\r\n        print(\"Loading the HuBERT Model...\")\r\n        self.hubert_model = HubertModel.from_pretrained(\"./pretrain_models/hubert-large-ls960-ft\")\r\n        self.hubert_model.eval()\r\n        # PBnet related configuration\r\n        self.pbnet_pose_ckpt = './pretrain_models/pbnet_seperate/pose/checkpoint_40000.pth.tar'\r\n        self.pbnet_blink_ckpt = './pretrain_models/pbnet_seperate/blink/checkpoint_95000.pth.tar'\r\n        self.device = 'cuda:0'\r\n        \r\n        # PBnet model parameters\r\n        folder_p, _ = os.path.split(self.pbnet_pose_ckpt)\r\n        self.pose_params = load_args(os.path.join(folder_p, \"opt.yaml\"))\r\n        self.pose_params['device'] = self.device\r\n        self.pose_params['audio_dim'] = 1024\r\n        self.pose_params['pos_dim'] = 6\r\n        self.pose_params['eye_dim'] = 0\r\n        \r\n        \r\n        folder_b, _ = os.path.split(self.pbnet_blink_ckpt)\r\n        self.blink_params = load_args(os.path.join(folder_b, \"opt.yaml\"))\r\n        self.blink_params['device'] = self.device\r\n        self.blink_params['audio_dim'] = 1024\r\n        self.blink_params['pos_dim'] = 0\r\n      
  self.blink_params['eye_dim'] = 2\r\n\r\n        # Add normalization parameters\r\n        self.max_vals = torch.tensor([90, 90, 90,  1,\r\n            720,  1080]).to(torch.float32).reshape(1, 1, 6)\r\n        self.min_vals = torch.tensor([-90, -90, -90,  0,\r\n            0,  0]).to(torch.float32).reshape(1, 1, 6)\r\n\r\n        # Load models\r\n        model_p = get_gen_model(self.pose_params)\r\n        model_b = get_gen_model(self.blink_params)\r\n\r\n        # Load pretrained weights\r\n        state_dict_p = torch.load(self.pbnet_pose_ckpt, map_location=self.device)\r\n        state_dict_b = torch.load(self.pbnet_blink_ckpt, map_location=self.device)\r\n        model_p.load_state_dict(state_dict_p)\r\n        model_b.load_state_dict(state_dict_b)\r\n        model_p.eval()\r\n        model_b.eval()\r\n\r\n        self.model_p = model_p\r\n        self.model_b = model_b\r\n\r\n        # Add default video generation configuration\r\n        current_dir = osp.dirname(osp.abspath(__file__))\r\n        \r\n        # Load configuration file\r\n        config_path = osp.join(current_dir, 'config', f'DAWN_{int(self.resolution)}.yaml')\r\n        with open(config_path, 'r') as f:\r\n            self.video_config = yaml.safe_load(f)\r\n            \r\n        # Initialize the video generation model\r\n        self.video_model = self._init_video_model(self.video_config['model_config'])\r\n\r\n    # def switch_conda_env(self, env_name):\r\n    #     \"\"\"Helper for switching the conda environment.\"\"\"\r\n    #     # A subprocess call is needed here to run the conda command\r\n    #     subprocess.run(f\"conda activate {env_name}\", shell=True)\r\n        \r\n    def extract_pose(self):\r\n        \"\"\"Extract facial pose and landmark information from input image.\r\n        \r\n        This function uses 3DDFA-V2 model for face detection and pose estimation. Main steps include:\r\n        1. Load and initialize face detection and pose estimation models\r\n        2. Process input image\r\n        3. Extract facial pose and landmark information\r\n        4. 
Save results to specified paths\r\n\r\n        Output files:\r\n            - init_pose.npy: Numpy array file containing pose information\r\n            - init_eye_bbox.npy: Numpy array file containing eye and bounding box information\r\n        \"\"\"\r\n        \r\n        \r\n        # Set ONNX runtime environment\r\n        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'\r\n        os.environ['OMP_NUM_THREADS'] = '8'\r\n        \r\n        # Initialize configuration\r\n        # config_path = 'configs/mb1_120x120.yml'  # Make sure path is correct\r\n        cfg = yaml.load(open(self.config_path), Loader=yaml.SafeLoader)\r\n        \r\n        # Initialize models\r\n        face_boxes = FaceBoxes_ONNX()\r\n        tddfa = TDDFA_ONNX(**cfg)\r\n        \r\n        # Read input image\r\n        image = cv2.imread(self.image_path)\r\n        if image.shape[2] == 4:  # Handle RGBA images\r\n            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)\r\n        \r\n        # Face detection\r\n        boxes = face_boxes(image)\r\n        if len(boxes) == 0:\r\n            raise ValueError(f'No face detected in image: {self.image_path}')\r\n            return None\r\n        \r\n        # Get 3DMM parameters and ROI boxes\r\n        param_lst, roi_box_lst = tddfa(image, boxes)\r\n        \r\n        # Reconstruct vertices\r\n        dense_flag = True  # For generating dense landmarks\r\n        ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag)\r\n        \r\n        # Get pose information\r\n        pose = get_pose(image, param_lst, ver_lst, show_flag=False, wfp=None, wnp=None)\r\n        \r\n        # Calculate eye and bounding box information\r\n        lmk = ver_lst[0]\r\n        eye_bbox_result = np.zeros(8)\r\n        bbox = calculate_bbox(image, lmk)\r\n        left_ratio, right_ratio = calculate_eye(lmk)\r\n        \r\n        # Organize result data\r\n        eye_bbox_result[0] = left_ratio.sum()\r\n        eye_bbox_result[1] = right_ratio.sum()\r\n        eye_bbox_result[2:] = np.array(bbox)\r\n        \r\n        # Reshape arrays\r\n        pose = pose.reshape(1, 7)\r\n        eye_bbox_result = eye_bbox_result.reshape(1, -1)\r\n        \r\n        # Set save paths\r\n        eye_bbox_path = os.path.join(self.cache_path, 'init_eye_bbox.npy')\r\n        pose_path = os.path.join(self.cache_path, 'init_pose.npy')\r\n        \r\n        # Save results\r\n        np.save(eye_bbox_path, eye_bbox_result)\r\n        np.save(pose_path, pose)\r\n\r\n    def process_audio(self):\r\n        \"\"\"Process audio file and extract HuBERT features.\r\n        \r\n        This method performs the following steps:\r\n        1. Convert input audio to 16kHz sampling rate\r\n        2. Extract audio features using HuBERT model \r\n        3. Interpolate features to match video frame rate\r\n        4. 
Save processed features\r\n\r\n        Output files:\r\n            - target_audio.npy: Numpy array containing interpolated HuBERT features\r\n\r\n        Raises:\r\n            RuntimeError: If audio processing fails\r\n        \"\"\"\r\n        # self.switch_conda_env(\"DAWN\")\r\n        \r\n        try:\r\n            # Create temp file for 16kHz audio\r\n            with tempfile.NamedTemporaryFile('w', suffix='.wav', dir='./') as temp_wav:\r\n                # Convert audio sampling rate to 16kHz\r\n                self._convert_wav_to_16k(self.audio_path, temp_wav.name)\r\n                \r\n                # Read 16kHz audio\r\n                speech_16k, _ = sf.read(temp_wav.name)\r\n                \r\n                # Calculate target frame count (based on 25fps video)\r\n                num_frames = int((speech_16k.shape[0] / 16000) * 25)\r\n                \r\n                # Extract HuBERT features\r\n                hubert_hidden = self._get_hubert_from_16k_speech(speech_16k, device=self.device)\r\n                hubert_hidden = hubert_hidden.detach().numpy()\r\n                \r\n                # Linear interpolation of features\r\n                interp_func = interp1d(np.arange(hubert_hidden.shape[0]), \r\n                                    hubert_hidden, \r\n                                    kind='linear', \r\n                                    axis=0)\r\n                hubert_feature_interpolated = interp_func(\r\n                    np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)\r\n                ).astype(np.float32)\r\n                \r\n                print(f'Frame count: {num_frames}, HuBERT size: {hubert_hidden.shape[0]}')\r\n                \r\n                # Save processed features\r\n                np.save(self.audio_emb_path, hubert_feature_interpolated)\r\n                \r\n        except Exception as e:\r\n            raise RuntimeError(f\"Audio processing failed: {str(e)}\")\r\n\r\n    def generate_pose_blink(self):\r\n        \"\"\"Generate pose and blink data.\r\n        \r\n        This function uses the PBnet model to generate driving pose and blink data. Main steps include:\r\n        1. Load pretrained pose and blink models\r\n        2. Process input data (audio features, initial pose, initial blink)\r\n        3. Generate driving data\r\n        4. 
Save results\r\n        \r\n        Output files:\r\n            - dri_pose.npy: Generated pose data\r\n            - dri_blink.npy: Generated blink data\r\n        \"\"\"\r\n        \r\n        # Set input paths\r\n        init_pose_path = os.path.join(self.cache_path, 'init_pose.npy')\r\n        init_blink_path = os.path.join(self.cache_path, 'init_eye_bbox.npy')\r\n        \r\n        try:\r\n            # Load input data\r\n            init_pose = torch.from_numpy(np.load(init_pose_path))[:,:self.pose_params['pos_dim']].unsqueeze(0).to(torch.float32)\r\n            init_blink = torch.from_numpy(np.load(init_blink_path))[:,:self.blink_params['eye_dim']].unsqueeze(0).to(torch.float32)\r\n            audio = torch.from_numpy(np.load(self.audio_emb_path)).unsqueeze(0).to(torch.float32)\r\n        except Exception:\r\n            # Use default values when 3DDFA extraction fails\r\n            init_pose = torch.from_numpy(np.array([[0, 0, 0, 4.79e-04, 5.65e+01, 6.49e+01,]]))[:,:self.pose_params['pos_dim']].unsqueeze(0).to(torch.float32)\r\n            init_blink = torch.from_numpy(np.array([[0.3,0.3]]))[:,:self.blink_params['eye_dim']].unsqueeze(0).to(torch.float32)\r\n            audio = torch.from_numpy(np.load(self.audio_emb_path)).unsqueeze(0).to(torch.float32)\r\n        \r\n        # normalize\r\n        init_pose = (init_pose - self.min_vals) / (self.max_vals - self.min_vals)\r\n        \r\n        with torch.no_grad():\r\n            # Generate driving data\r\n            gendurations_seg = torch.tensor([audio.shape[1] - 0])\r\n            batch_p = self.model_p.generate(init_pose, audio, gendurations_seg, fact=1)\r\n            batch_b = self.model_b.generate(init_blink, audio, gendurations_seg, fact=1)\r\n            \r\n            # process the output\r\n            output_p = batch_p['output'].detach().cpu()\r\n            output_b = batch_b['output'].detach().cpu()\r\n            \r\n            output_p = output_p + init_pose\r\n            output_p = inv_transform(output_p, self.min_vals, self.max_vals)\r\n            output_b = output_b + init_blink\r\n            \r\n            # save results\r\n            output_pose_path = os.path.join(self.cache_path, 'dri_pose.npy')\r\n            output_blink_path = os.path.join(self.cache_path, 'dri_blink.npy')\r\n            np.save(output_pose_path, output_p[0])\r\n            np.save(output_blink_path, output_b[0])\r\n\r\n    def generate_final_video(self):\r\n        \"\"\"Generate the final video.\r\n        \r\n        Raises:\r\n            RuntimeError: If an error occurs during video generation\r\n        \"\"\"\r\n        try:  \r\n            # prepare the output dir\r\n            directory_name = os.path.splitext(os.path.basename(self.image_path))[0]\r\n            video_dir = os.path.join(self.output_path, directory_name, 'video')\r\n            img_dir = os.path.join(self.output_path, directory_name, 'img')\r\n            os.makedirs(video_dir, exist_ok=True)\r\n            os.makedirs(img_dir, exist_ok=True)\r\n            \r\n            # prepare input\r\n            image = Image.open(self.image_path).convert(\"RGB\")\r\n            transform = transforms.Compose([\r\n                transforms.Resize((self.video_config['input_size'], self.video_config['input_size'])),\r\n                transforms.ToTensor()\r\n            ])\r\n            image_tensor = transform(image) * 255\r\n            \r\n            # load the audio emb and condition (pose blink)\r\n            hubert_npy = 
np.load(self.audio_emb_path)\r\n            max_frames = min(self.video_config['max_n_frames'], hubert_npy.shape[0])\r\n            ref_hubert = torch.from_numpy(hubert_npy[:max_frames]).to(torch.float32)\r\n            \r\n            drive_poses = torch.from_numpy(np.load(os.path.join(self.cache_path, 'dri_pose.npy'))[:max_frames]).to(torch.float32)\r\n            drive_blink = torch.from_numpy(np.load(os.path.join(self.cache_path, 'dri_blink.npy'))[:max_frames]).to(torch.float32)\r\n            \r\n            try:\r\n                real_poses = torch.from_numpy(np.load(os.path.join(self.cache_path, 'init_pose.npy'))).to(torch.float32)\r\n                real_blink_bbox = torch.from_numpy(np.load(os.path.join(self.cache_path, 'init_eye_bbox.npy'))).to(torch.float32)\r\n            except Exception:\r\n                # default value\r\n                real_poses = torch.zeros(1, 7)\r\n                real_blink_bbox = torch.tensor([[0.3, 0.3, 64, 64, 192, 192, 256, 256]]).reshape(1, -1).to(torch.float32)\r\n            \r\n            # prepare init state\r\n            init_pose = real_poses[0].unsqueeze(0)\r\n            init_blink = real_blink_bbox[0,:2].unsqueeze(0)\r\n            \r\n            # process\r\n            drive_poses = drive_poses.permute(1,0)\r\n            drive_blink = drive_blink.permute(1,0)\r\n            real_blink_bbox = real_blink_bbox.permute(1,0)\r\n            \r\n            # temp file\r\n            with tempfile.NamedTemporaryFile('w', suffix='.wav') as temp_wav, \\\r\n                tempfile.NamedTemporaryFile('w', suffix='.mp4') as temp_video:\r\n                \r\n                # extract the audio seg\r\n                self._extract_audio_segment(self.audio_path, 0, max_frames, 25, temp_wav.name)\r\n                \r\n                # video writer\r\n                fourcc = cv2.VideoWriter_fourcc(*'mp4v')\r\n                video_writer = cv2.VideoWriter(\r\n                    temp_video.name, \r\n                    fourcc, \r\n                    25, \r\n                    (self.video_config['input_size'], self.video_config['input_size'])\r\n                )\r\n                \r\n                # ddim generation\r\n                with torch.no_grad():\r\n                    self.video_model.update_num_frames(max_frames)\r\n                    sample_output = self.video_model.sample_one_video(\r\n                        sample_img=image_tensor.unsqueeze(dim=0).cuda()/255.,\r\n                        sample_audio_hubert=ref_hubert.unsqueeze(dim=0).cuda(),\r\n                        sample_pose=drive_poses.unsqueeze(0).cuda(),\r\n                        sample_eye=drive_blink[:2].unsqueeze(0).cuda(),\r\n                        sample_bbox=real_blink_bbox[2:].unsqueeze(0).cuda(),\r\n                        init_pose=init_pose.cuda(),\r\n                        init_eye=init_blink.cuda(),\r\n                        cond_scale=self.video_config['cond_scale']\r\n                    )\r\n                \r\n                # write the frame\r\n                for frame_idx in range(max_frames):\r\n                    frame = self._process_output_frame(\r\n                        sample_output[\"sample_out_vid\"][:, :, frame_idx],\r\n                        mean=self.video_config['mean']\r\n                    )\r\n                    video_writer.write(frame)\r\n                    # save frames as png\r\n                    frame_name = f\"{frame_idx:03d}.png\"\r\n                    frame_path = os.path.join(img_dir, frame_name)\r\n      
              cv2.imwrite(frame_path, frame)\r\n                video_writer.release()\r\n                \r\n                # save the final video\r\n                output_video_path = os.path.join(video_dir, f\"{directory_name}.mp4\")\r\n                self._combine_video_audio(temp_wav.name, temp_video.name, output_video_path)\r\n            \r\n        except Exception as e:\r\n            raise RuntimeError(f\"! Video generation failed: {str(e)}\")\r\n\r\n    def run(self):\r\n        \"\"\"Execute the complete generation pipeline\"\"\"\r\n        print(\"1. Extracting pose information...\")\r\n        self.extract_pose()\r\n        \r\n        print(\"2. Processing audio...\")\r\n        self.process_audio()\r\n        \r\n        print(\"3. Generating pose and blink data...\")\r\n        self.generate_pose_blink()\r\n        \r\n        print(\"4. Generating final video...\")\r\n        self.generate_final_video()\r\n    \r\n    @staticmethod\r\n    def _convert_wav_to_16k(input_file, output_file):\r\n        \"\"\"Convert audio file to 16kHz sampling rate.\r\n\r\n        Args:\r\n            input_file (str): Path to input audio file\r\n            output_file (str): Path to output audio file\r\n        \"\"\"\r\n        command = [\r\n            'ffmpeg',\r\n            '-i', input_file,\r\n            '-ar', '16000',\r\n            '-y',  # Add -y parameter to automatically overwrite existing files\r\n            output_file\r\n        ]\r\n        subprocess.run(command)\r\n\r\n    @torch.no_grad()\r\n    def _get_hubert_from_16k_speech(self, speech, device=\"cuda:0\"):\r\n        \"\"\"Extract HuBERT features from 16kHz audio.\r\n\r\n        Args:\r\n            speech (numpy.ndarray): Input audio data\r\n            device (str): Computing device, defaults to \"cuda:0\"\r\n\r\n        Returns:\r\n            torch.Tensor: HuBERT feature tensor\r\n\r\n        Notes:\r\n            HuBERT model uses multi-layer CNN for processing:\r\n            - Total stride is 320 (5*2*2*2*2*2)\r\n            - Kernel size is 400\r\n            - Process long audio in segments to avoid memory issues\r\n        \"\"\"\r\n        self.hubert_model = self.hubert_model.to(device)\r\n        if speech.ndim == 2:\r\n            speech = speech[:, 0]  # [T, 2] ==> [T,]\r\n            \r\n        input_values_all = self.wav2vec2_processor(\r\n            speech, \r\n            return_tensors=\"pt\", \r\n            sampling_rate=16000\r\n        ).input_values.to(device)\r\n\r\n        # Set parameters for segment processing\r\n        kernel = 400\r\n        stride = 320\r\n        clip_length = stride * 1000\r\n        num_iter = input_values_all.shape[1] // clip_length\r\n        expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride\r\n        \r\n        # Process audio in segments\r\n        res_lst = []\r\n        for i in range(num_iter):\r\n            if i == 0:\r\n                start_idx = 0\r\n                end_idx = clip_length - stride + kernel\r\n            else:\r\n                start_idx = clip_length * i\r\n                end_idx = start_idx + (clip_length - stride + kernel)\r\n                \r\n            input_values = input_values_all[:, start_idx: end_idx]\r\n            hidden_states = self.hubert_model(input_values).last_hidden_state\r\n            res_lst.append(hidden_states[0])\r\n        \r\n        # the last seg\r\n        if num_iter > 0:\r\n            input_values = input_values_all[:, clip_length * num_iter:]\r\n        else:\r\n         
   input_values = input_values_all\r\n            \r\n        if input_values.shape[1] >= kernel:\r\n            hidden_states = self.hubert_model(input_values).last_hidden_state\r\n            res_lst.append(hidden_states[0])\r\n        \r\n        # concat the feature\r\n        ret = torch.cat(res_lst, dim=0).cpu()\r\n        \r\n        # check length\r\n        assert abs(ret.shape[0] - expected_T) <= 1\r\n        if ret.shape[0] < expected_T:\r\n            ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0]))\r\n        else:\r\n            ret = ret[:expected_T]\r\n            \r\n        return ret\r\n    \r\n\r\n    def _init_video_model(self, model_config):\r\n        \"\"\"Initialize the video generation model.\r\n        \r\n        Args:\r\n            model_config (dict): Model configuration dictionary\r\n        \r\n        Returns:\r\n            FlowDiffusion: Initialized video generation model\r\n        \"\"\"\r\n        from DM_3.modules.video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_test import FlowDiffusion\r\n        \r\n        model = FlowDiffusion(\r\n            is_train=model_config['is_train'],\r\n            sampling_timesteps=self.video_config['sampling_step'],\r\n            ddim_sampling_eta=self.video_config['ddim_sampling_eta'],\r\n            pose_dim=model_config['pose_dim'],\r\n            config_pth=model_config['config_pth'],\r\n            pretrained_pth=model_config['ae_pretrained_pth'],\r\n            win_width=self.video_config['win_width']\r\n        )\r\n        model.cuda()\r\n        \r\n        # load model\r\n        checkpoint = torch.load(model_config['diffusion_pretrained_pth'])\r\n        model.diffusion.load_state_dict(checkpoint['diffusion'])\r\n        model.eval()\r\n        \r\n        return model\r\n\r\n    def _process_output_frame(self, frame_batch, mean=(0.0, 0.0, 0.0), index=0):\r\n        \"\"\"Process the output frame data from the model.\r\n        \r\n        Args:\r\n            frame_batch (torch.Tensor): Batch of frame data\r\n            mean (tuple): Mean values\r\n            index (int): Batch index\r\n        \r\n        Returns:\r\n            numpy.ndarray: Processed frame in BGR format\r\n        \"\"\"\r\n        frame = frame_batch[index].permute(1, 2, 0).data.cpu().numpy().copy()\r\n        frame += np.array(mean)/255.0\r\n        frame = np.clip(frame, 0, 1)\r\n        frame = (frame * 255).astype(np.uint8)\r\n        return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\r\n    \r\n    def _extract_audio_segment(self, input_wav, start_frame, num_frames, fps, output_wav):\r\n        \"\"\"Extract audio segment.\r\n        \r\n        Args:\r\n            input_wav (str): Input audio path\r\n            start_frame (int): Start frame\r\n            num_frames (int): Number of frames\r\n            fps (int): Frames per second\r\n            output_wav (str): Output audio path\r\n        \"\"\"\r\n        \r\n        audio = AudioSegment.from_wav(input_wav)\r\n        frame_duration = 1000 / fps\r\n        start_time = start_frame * frame_duration\r\n        end_time = (start_frame + num_frames) * frame_duration\r\n        audio[start_time:end_time].export(output_wav, format=\"wav\")\r\n\r\n    def _combine_video_audio(self, audio_path, video_path, output_path):\r\n            \"\"\"Combine video and audio.\r\n            \r\n            Args:\r\n                audio_path (str): Path to audio file\r\n                video_path (str): Path to video file\r\n       
         output_path (str): Path to output file\r\n            \"\"\"\r\n            cmd = [\r\n                'ffmpeg', '-y',\r\n                '-i', audio_path,\r\n                '-i', video_path,\r\n                '-vcodec', 'copy',\r\n                '-ac', '2',\r\n                '-channel_layout', 'stereo',\r\n                '-pix_fmt', 'yuv420p',\r\n                '-shortest',  # output options must precede the output file to take effect\r\n                output_path\r\n            ]\r\n            subprocess.run(cmd)\r\n\r\ndef parse_args():\r\n    parser = argparse.ArgumentParser()\r\n    parser.add_argument('--audio_path', type=str, default= 'WRA_MarcoRubio_000.wav', help='Input audio path')\r\n    parser.add_argument('--image_path', type=str, default= 'real_female_1.jpeg', help='Input image path')\r\n    parser.add_argument('--output_path', type=str, default= 'output', help='Output video path')\r\n    parser.add_argument('--cache_path', type=str, default='cache/tmp', help='Cache file path')\r\n    parser.add_argument('--resolution', type=int, default=128, help='resolution')\r\n    return parser.parse_args()\r\n\r\ndef main():\r\n    args = parse_args()\r\n    generator = VideoGenerator(args)\r\n    generator.run()\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  }
]