[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n#/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n.idea/\n\n# custom ignores\n.DS_Store\n_.*\n\n# models and outputs\n# models/\noutputs/\n\n# Replicate\n.cog"
  },
  {
    "path": "README.md",
    "content": "# [ICLR2025] DisPose: Disentangling Pose Guidance for Controllable Human Image Animation\nThis repository is the official implementation of [DisPose](https://arxiv.org/abs/2412.09349).\n\n[![arXiv](https://img.shields.io/badge/arXiv-2412.09349-b31b1b.svg)](https://arxiv.org/abs/2412.09349)\n[![Project Page](https://img.shields.io/badge/Project-Website-green)](https://lihxxx.github.io/DisPose/)\n\n## 🔥 News\n- **`2025/01/23`**: DisPose was accepted to ICLR 2025.\n- **`2024/12/13`**: We released the inference code and checkpoints for DisPose.\n\n**📖 Table of Contents**\n- [DisPose: Disentangling Pose Guidance for Controllable Human Image Animation](#iclr2025-dispose-disentangling-pose-guidance-for-controllable-human-image-animation)\n  - [🎨 Gallery](#-gallery)\n  - [🧙 Method Overview](#-method-overview)\n  - [🔧 Preparations](#-preparations)\n    - [Setup repository and conda environment](#setup-repository-and-conda-environment)\n    - [Prepare model weights](#prepare-model-weights)\n  - [💫 Inference](#-inference)\n    - [Tips](#tips)\n  - [📣 Disclaimer](#-disclaimer)\n  - [💞 Acknowledgements](#-acknowledgements)\n  - [🔍 Citation](#-citation)\n\n## 🎨 Gallery\n<table class=\"center\">\n<tr>\n  <td><video src=\"https://github.com/user-attachments/assets/e2f5e263-3f86-4778-98b9-6d2d451b7516\" autoplay></video></td>\n  <td><video src=\"https://github.com/user-attachments/assets/f8e761e3-7a7a-4812-ad61-023b33034a42\" autoplay></video></td>\n  <td><video src=\"https://github.com/user-attachments/assets/9a6c7ea6-8c73-4a50-b594-f8eba239c405\" autoplay></video></td>\n  <td><video src=\"https://github.com/user-attachments/assets/a0f97ac4-429e-4ca9-a794-7c02b5dc5405\" autoplay></video></td>\n  <td><video src=\"https://github.com/user-attachments/assets/6e9d463c-f7c5-4de8-924b-1ad591e3a9a4\" autoplay></video></td>\n</tr>\n</table>\n\n## 🧙 Method Overview\nWe present **DisPose**, which mines more generalizable and effective control signals without requiring additional dense inputs by disentangling the sparse 
skeleton pose in human image animation into motion field guidance and keypoint correspondence.\n<div align='center'>\n<img src=\"https://anonymous.4open.science/r/DisPose-AB1D/pipeline.png\" class=\"interpolation-image\" alt=\"pipeline.\" height=\"80%\" width=\"80%\" />\n</div>\n\n## 🔧 Preparations\n### Setup repository and conda environment\nThe code requires `python>=3.10`, as well as `torch>=2.0.1` and `torchvision>=0.15.2`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both the PyTorch and TorchVision dependencies. The demo has been tested with CUDA 12.4.\n```\nconda create -n dispose python=3.10\nconda activate dispose\npip install -r requirements.txt\n```\n\n### Prepare model weights\n1. Download the weights of [DisPose](https://huggingface.co/lihxxx/DisPose) and put `DisPose.pth` into `./pretrained_weights/`.\n\n2. Download the weights of the other components and put them into `./pretrained_weights/`:\n   - [stable-video-diffusion-img2vid-xt-1-1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1/tree/main)\n   - [stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main)\n   - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)\n   - [MimicMotion](https://huggingface.co/tencent/MimicMotion/tree/main)\n3. Download the weights of [CMP](https://huggingface.co/MyNiuuu/MOFA-Video-Hybrid/resolve/main/models/cmp/experiments/semiauto_annot/resnet50_vip%2Bmpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar) and put it into `./mimicmotion/modules/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints`.\n\nFinally, these weights should be organized in `./pretrained_weights/` 
as follows:\n\n```\n./pretrained_weights/\n|-- MimicMotion_1-1.pth\n|-- DisPose.pth\n|-- DWPose\n|   |-- dw-ll_ucoco_384.onnx\n|   |-- yolox_l.onnx\n|-- stable-diffusion-v1-5\n|-- stable-video-diffusion-img2vid-xt-1-1\n```\n\n## 💫 Inference\n\nA sample configuration for testing is provided in `configs/test.yaml`. You can modify the configuration to suit your needs.\n\n```\nbash scripts/test.sh\n```\n\n### Tips\n- If your GPU memory is limited, try setting `decode_chunk_size` in `configs/test.yaml` to 1.\n- If you want to enhance the quality of the generated video, you can try post-processing such as face swapping ([insightface](https://github.com/deepinsight/insightface)) or frame interpolation ([IFRNet](https://github.com/ltkong218/IFRNet)).\n\n## 📣 Disclaimer\nThis is the official code release of DisPose.\nThe demo images and videos are collected from community users, and their copyrights remain with the original owners.\nPlease contact us if you would like any of them removed.\n\n## 💞 Acknowledgements\nWe sincerely appreciate the code releases of the following projects: [MimicMotion](https://github.com/Tencent/MimicMotion), [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone), and [CMP](https://github.com/XiaohangZhan/conditional-motion-propagation).\n\n## 🔍 Citation\n\n```\n@inproceedings{li2025dispose,\n  title={DisPose: Disentangling Pose Guidance for Controllable Human Image Animation},\n  author={Hongxiang Li and Yaowei Li and Yuhang Yang and Junjie Cao and Zhihong Zhu and Xuxin Cheng and Long Chen},\n  booktitle={The Thirteenth International Conference on Learning Representations},\n  year={2025},\n  url={https://openreview.net/forum?id=AumOa10MKG}\n}\n```\n"
  },
  {
    "path": "configs/test.yaml",
    "content": "# base svd model path\nbase_model_path: ./pretrained_weights/stable-video-diffusion-img2vid-xt-1-1\n# base dift model path\ndift_model_path: ./pretrained_weights/stable-diffusion-v1-5\n\n# checkpoint path\nckpt_path: ./pretrained_weights/MimicMotion_1-1.pth\ncontrolnet_path: ./pretrained_weights/DisPose.pth\n\ntest_case:\n  - ref_video_path: ./assets/example_data/videos/video1.mp4\n    ref_image_path: ./assets/example_data/images/ref1.png\n    num_frames: 16\n    resolution: 576\n    frames_overlap: 6\n    num_inference_steps: 25\n    noise_aug_strength: 0\n    guidance_scale: 2.0\n    sample_stride: 2\n    decode_chunk_size: 8\n    fps: 15\n    seed: 42\n\n  - ref_video_path: ./assets/example_data/videos/video2.mp4\n    ref_image_path: ./assets/example_data/images/ref2.png\n    num_frames: 16\n    resolution: 576\n    frames_overlap: 6\n    num_inference_steps: 25\n    noise_aug_strength: 0\n    guidance_scale: 2.0\n    sample_stride: 2\n    decode_chunk_size: 8\n    fps: 15\n    seed: 42\n\n  - ref_video_path: ./assets/example_data/videos/video3.mp4\n    ref_image_path: ./assets/example_data/images/ref3.png\n    num_frames: 16\n    resolution: 576\n    frames_overlap: 6\n    num_inference_steps: 25\n    noise_aug_strength: 0\n    guidance_scale: 2.0\n    sample_stride: 2\n    decode_chunk_size: 8\n    fps: 15\n    seed: 42"
  },
  {
    "path": "constants.py",
    "content": "# w/h aspect ratio\nASPECT_RATIO = 9 / 16\n"
  },
  {
    "path": "inference_ctrl.py",
    "content": "import os\nimport argparse\nimport logging\nimport math\nfrom datetime import datetime\nfrom pathlib import Path\n\nimport cv2\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom omegaconf import OmegaConf\nfrom torchvision import transforms\nfrom torchvision.datasets.folder import pil_loader\nfrom torchvision.transforms.functional import pil_to_tensor, resize, center_crop\nfrom torchvision.transforms.functional import to_pil_image\n\nfrom mimicmotion.utils.dift_utils import SDFeaturizer\nfrom mimicmotion.utils.utils import points_to_flows, bivariate_Gaussian, sample_inputs_flow, get_cmp_flow, pose2track\n\nfrom mimicmotion.utils.geglu_patch import patch_geglu_inplace\npatch_geglu_inplace()\n\nfrom constants import ASPECT_RATIO\nfrom mimicmotion.utils.loader import create_ctrl_pipeline\nfrom mimicmotion.utils.utils import save_to_mp4\nfrom mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose\nfrom mimicmotion.modules.cmp_model import CMP\n\nlogging.basicConfig(level=logging.INFO, format=\"%(asctime)s: [%(levelname)s] %(message)s\")\nlogger = logging.getLogger(__name__)\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef preprocess(video_path, image_path, dift_model_path, resolution=576, sample_stride=2):\n    \"\"\"Preprocess the reference image pose and the driving video pose.\n\n    Args:\n        video_path (str): path to the driving pose video\n        image_path (str): path to the reference image\n        dift_model_path (str): path to the Stable Diffusion model used to extract DIFT features\n        resolution (int, optional): target short-side resolution. Defaults to 576.\n        sample_stride (int, optional): video frame sampling stride. Defaults to 2.\n    \"\"\"\n    image_pixels = pil_loader(image_path)\n    image_pixels = pil_to_tensor(image_pixels)  # (c, h, w)\n    h, w = image_pixels.shape[-2:]\n    
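# Worked example (illustrative values, not from the repo): a 720x1280 (h, w)\n    # landscape reference with resolution=576 takes the h<w branch below and\n    # yields w_target = int(576 / (9/16) // 64) * 64 = 1024, h_target = 576.\n    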
############################ compute target h/w according to original aspect ratio ###############################\n    if h>w:\n        w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64\n    elif h==w:\n        w_target, h_target = resolution, resolution\n    else:\n        w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution\n    h_w_ratio = float(h) / float(w)\n    if h_w_ratio < h_target / w_target:\n        h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio)\n    else:\n        h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target\n    image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)\n    image_pixels = center_crop(image_pixels, [h_target, w_target])\n    # h_target, w_target = image_pixels.shape[-2:]\n    image_pixels = image_pixels.permute((1, 2, 0)).numpy()\n    ##################################### get video flow #################################################\n    transform = transforms.Compose(\n        [\n        \n        transforms.Resize((h_target, w_target), antialias=None), \n        transforms.CenterCrop((h_target, w_target)), \n        transforms.ToTensor()\n        ]\n    )\n    \n    ref_img = transform(PIL.Image.fromarray(image_pixels))\n\n    ##################################### get image&video pose value #################################################\n    image_pose, ref_point = get_image_pose(image_pixels)\n    ref_point_body, ref_point_head = ref_point[\"bodies\"], ref_point[\"faces\"]\n    video_pose, body_point, face_point = get_video_pose(video_path, image_pixels, sample_stride=sample_stride)\n    body_point_list = [ref_point_body] + body_point\n    face_point_list = [ref_point_head] + face_point\n\n    pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])\n    image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))\n    \n    dift_model = SDFeaturizer(sd_id = dift_model_path, 
weight_dtype=torch.float16)\n    category=\"human\"\n    prompt = f'photo of a {category}'\n    dift_ref_img = (image_pixels / 255.0 - 0.5) *2\n    dift_ref_img = torch.from_numpy(dift_ref_img).to(device, torch.float16)\n    dift_feats = dift_model.forward(dift_ref_img, prompt=prompt, t=[261,0], up_ft_index=[1,2], ensemble_size=8)\n\n\n    model_length = len(body_point_list)\n    traj_flow = points_to_flows(body_point_list, model_length, h_target, w_target)\n    blur_kernel = bivariate_Gaussian(kernel_size=199, sig_x=20, sig_y=20, theta=0, grid=None, isotropic=True)\n\n    for i in range(0, model_length-1):\n        traj_flow[i] = cv2.filter2D(traj_flow[i], -1, blur_kernel)\n\n    traj_flow = rearrange(traj_flow, \"f h w c -> f c h w\") \n    traj_flow = torch.from_numpy(traj_flow)\n    traj_flow = traj_flow.unsqueeze(0)\n\n    cmp = CMP(\n        './mimicmotion/modules/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml',\n        42000\n    ).to(device)\n    cmp.requires_grad_(False)\n\n    pc, ph, pw = ref_img.shape\n    poses, poses_subset = pose2track(body_point_list, ph, pw)\n    poses = torch.from_numpy(poses).permute(1,0,2)\n    poses_subset = torch.from_numpy(poses_subset).permute(1,0,2)\n\n    # pdb.set_trace()\n    val_controlnet_image, val_sparse_optical_flow, \\\n    val_mask, val_first_frame_384, \\\n        val_sparse_optical_flow_384, val_mask_384 = sample_inputs_flow(ref_img.unsqueeze(0).float(), poses.unsqueeze(0), poses_subset.unsqueeze(0))\n\n    fb, fl, fc, fh, fw = val_sparse_optical_flow.shape\n\n    val_controlnet_flow = get_cmp_flow(\n        cmp, \n        val_first_frame_384.unsqueeze(0).repeat(1, fl, 1, 1, 1).to(device), \n        val_sparse_optical_flow_384.to(device), \n        val_mask_384.to(device)\n    )\n\n    if fh != 384 or fw != 384:\n        scales = [fh / 384, fw / 384]\n        val_controlnet_flow = F.interpolate(val_controlnet_flow.flatten(0, 1), (fh, fw), mode='nearest').reshape(fb, fl, 2, fh, fw)\n       
 val_controlnet_flow[:, :, 0] *= scales[1]\n        val_controlnet_flow[:, :, 1] *= scales[0]\n\n    return torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1, val_controlnet_flow, val_controlnet_image, body_point_list, dift_feats, traj_flow\n\n\ndef run_pipeline(pipeline, image_pixels, pose_pixels,\n                controlnet_flow, controlnet_image, point_list, dift_feats, traj_flow,\n                device, task_config):\n    image_pixels = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5]\n    generator = torch.Generator(device=device)\n    generator.manual_seed(task_config.seed)\n    with torch.autocast(\"cuda\"):\n        frames = pipeline(\n            image_pixels, image_pose=pose_pixels, num_frames=pose_pixels.size(0),\n            tile_size=task_config.num_frames, tile_overlap=task_config.frames_overlap,\n            height=pose_pixels.shape[-2], width=pose_pixels.shape[-1], fps=7,\n            controlnet_flow=controlnet_flow, controlnet_image=controlnet_image, point_list=point_list, dift_feats=dift_feats, traj_flow=traj_flow,\n            noise_aug_strength=task_config.noise_aug_strength, num_inference_steps=task_config.num_inference_steps,\n            generator=generator, min_guidance_scale=task_config.guidance_scale,\n            max_guidance_scale=task_config.guidance_scale, decode_chunk_size=task_config.decode_chunk_size, output_type=\"pt\", device=device\n        ).frames.cpu()\n    video_frames = (frames * 255.0).to(torch.uint8)\n\n    # the batch size is 1; drop the first frame, which corresponds to the reference image\n    _video_frames = video_frames[0, 1:]\n\n    return _video_frames\n\n\n@torch.no_grad()\ndef main(args):\n    if not args.no_use_float16:\n        torch.set_default_dtype(torch.float16)\n\n    infer_config = OmegaConf.load(args.inference_config)\n    pipeline = 
create_ctrl_pipeline(infer_config, device)\n\n    for task in infer_config.test_case:\n        ############################################## Pre-process data ##############################################\n        pose_pixels, image_pixels, controlnet_flow, controlnet_image, point_list, dift_feats, traj_flow = preprocess(\n            task.ref_video_path, task.ref_image_path, infer_config.dift_model_path, \n            resolution=task.resolution, sample_stride=task.sample_stride\n        )\n        ########################################### Run MimicMotion pipeline ###########################################\n        _video_frames = run_pipeline(\n            pipeline, \n            image_pixels, pose_pixels, controlnet_flow, controlnet_image, point_list, dift_feats, traj_flow,\n            device, task\n        )\n        ################################### save results to output folder. ###########################################\n        save_to_mp4(\n            _video_frames, \n            f\"{args.output_dir}/{datetime.now().strftime('%Y%m%d')}_{args.name}/{datetime.now().strftime('%H%M%S')}_{os.path.basename(task.ref_image_path).split('.')[0]}_to_{os.path.basename(task.ref_video_path).split('.')[0]}\" \\\n            f\"_CFG{task.guidance_scale}_{task.num_frames}_{task.fps}.mp4\",\n            fps=task.fps,\n        )\n\ndef set_logger(log_file=None, log_level=logging.INFO):\n    log_handler = logging.FileHandler(log_file, \"w\")\n    log_handler.setFormatter(\n        logging.Formatter(\"[%(asctime)s][%(name)s][%(levelname)s]: %(message)s\")\n    )\n    log_handler.setLevel(log_level)\n    logger.addHandler(log_handler)\n\n\nif __name__ == \"__main__\":    \n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--log_file\", type=str, default=None)\n    parser.add_argument(\"--inference_config\", type=str, default=\"configs/test.yaml\") #ToDo\n    parser.add_argument(\"--output_dir\", type=str, default=\"outputs/\", help=\"path to output\")\n  
  parser.add_argument(\"--name\", type=str, default=\"\", help=\"run name appended to the output folder\")\n    parser.add_argument(\"--no_use_float16\",\n                        action=\"store_true\",\n                        help=\"disable float16 and run inference in float32\",\n    )\n    args = parser.parse_args()\n\n    Path(args.output_dir).mkdir(parents=True, exist_ok=True)\n    main(args)\n    logger.info(\"--- Finished ---\")\n"
  },
  {
    "path": "mimicmotion/__init__.py",
    "content": ""
  },
  {
    "path": "mimicmotion/dwpose/.gitignore",
    "content": "*.pyc\n"
  },
  {
    "path": "mimicmotion/dwpose/__init__.py",
    "content": ""
  },
  {
    "path": "mimicmotion/dwpose/dwpose_detector.py",
    "content": "import os\n\nimport numpy as np\nimport torch\n\nfrom .wholebody import Wholebody\n\nos.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\nclass DWposeDetector:\n    \"\"\"\n    A whole-body pose detector for image-like data.\n\n    Parameters:\n        model_det: (str) serialized ONNX format model path,\n                    such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx\n        model_pose: (str) serialized ONNX format model path,\n                    such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx\n        device: (str) 'cpu' or 'cuda:{device_id}'\n    \"\"\"\n    def __init__(self, model_det, model_pose, device='cpu'):\n        self.args = model_det, model_pose, device\n\n    def release_memory(self):\n        if hasattr(self, 'pose_estimation'):\n            del self.pose_estimation\n            import gc\n            gc.collect()\n\n    def __call__(self, oriImg):\n        # lazily instantiate the ONNX models on first call\n        if not hasattr(self, 'pose_estimation'):\n            self.pose_estimation = Wholebody(*self.args)\n\n        oriImg = oriImg.copy()\n        H, W, C = oriImg.shape\n        with torch.no_grad():\n            candidate, score = self.pose_estimation(oriImg)\n            nums, _, locs = candidate.shape\n            candidate[..., 0] /= float(W)\n            candidate[..., 1] /= float(H)\n            body = candidate[:, :18].copy()\n            body = body.reshape(nums * 18, locs)\n            subset = score[:, :18].copy()\n            for i in range(len(subset)):\n                for j in range(len(subset[i])):\n                    if subset[i][j] > 0.3:\n                        subset[i][j] = int(18 * i + j)\n                    else:\n                        subset[i][j] = -1\n\n            # un_visible = subset < 0.3\n            # candidate[un_visible] = -1\n\n            # foot = candidate[:, 18:24]\n\n            faces = candidate[:, 24:92]\n\n            hands = 
candidate[:, 92:113]\n            hands = np.vstack([hands, candidate[:, 113:]])\n\n            faces_score = score[:, 24:92]\n            hands_score = np.vstack([score[:, 92:113], score[:, 113:]])\n\n            bodies = dict(candidate=body, subset=subset, score=score[:, :18])\n            pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score)\n\n            return pose\n\ndwpose_detector = DWposeDetector(\n    model_det=\"./pretrained_weights/DWPose/yolox_l.onnx\",\n    model_pose=\"./pretrained_weights/DWPose/dw-ll_ucoco_384.onnx\",\n    device=device)\n"
  },
  {
    "path": "mimicmotion/dwpose/onnxdet.py",
    "content": "import cv2\nimport numpy as np\n\n\ndef nms(boxes, scores, nms_thr):\n    \"\"\"Single class NMS implemented in Numpy.\n\n    Args:\n        boxes (np.ndarray): shape=(N,4); N is number of boxes\n        scores (np.ndarray): the score of bboxes\n        nms_thr (float): the threshold in NMS \n\n    Returns:\n        List[int]: output bbox ids\n    \"\"\"\n    x1 = boxes[:, 0]\n    y1 = boxes[:, 1]\n    x2 = boxes[:, 2]\n    y2 = boxes[:, 3]\n\n    areas = (x2 - x1 + 1) * (y2 - y1 + 1)\n    order = scores.argsort()[::-1]\n\n    keep = []\n    while order.size > 0:\n        i = order[0]\n        keep.append(i)\n        xx1 = np.maximum(x1[i], x1[order[1:]])\n        yy1 = np.maximum(y1[i], y1[order[1:]])\n        xx2 = np.minimum(x2[i], x2[order[1:]])\n        yy2 = np.minimum(y2[i], y2[order[1:]])\n\n        w = np.maximum(0.0, xx2 - xx1 + 1)\n        h = np.maximum(0.0, yy2 - yy1 + 1)\n        inter = w * h\n        ovr = inter / (areas[i] + areas[order[1:]] - inter)\n\n        inds = np.where(ovr <= nms_thr)[0]\n        order = order[inds + 1]\n\n    return keep\n\ndef multiclass_nms(boxes, scores, nms_thr, score_thr):\n    \"\"\"Multiclass NMS implemented in Numpy. 
Class-aware version.\n\n    Args:\n        boxes (np.ndarray): shape=(N,4); N is number of boxes\n        scores (np.ndarray): the score of bboxes\n        nms_thr (float): the threshold in NMS \n        score_thr (float): the threshold of cls score\n\n    Returns:\n        np.ndarray: outputs bboxes coordinate\n    \"\"\"\n    final_dets = []\n    num_classes = scores.shape[1]\n    for cls_ind in range(num_classes):\n        cls_scores = scores[:, cls_ind]\n        valid_score_mask = cls_scores > score_thr\n        if valid_score_mask.sum() == 0:\n            continue\n        else:\n            valid_scores = cls_scores[valid_score_mask]\n            valid_boxes = boxes[valid_score_mask]\n            keep = nms(valid_boxes, valid_scores, nms_thr)\n            if len(keep) > 0:\n                cls_inds = np.ones((len(keep), 1)) * cls_ind\n                dets = np.concatenate(\n                    [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1\n                )\n                final_dets.append(dets)\n    if len(final_dets) == 0:\n        return None\n    return np.concatenate(final_dets, 0)\n\ndef demo_postprocess(outputs, img_size, p6=False):\n    grids = []\n    expanded_strides = []\n    strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]\n\n    hsizes = [img_size[0] // stride for stride in strides]\n    wsizes = [img_size[1] // stride for stride in strides]\n\n    for hsize, wsize, stride in zip(hsizes, wsizes, strides):\n        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))\n        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)\n        grids.append(grid)\n        shape = grid.shape[:2]\n        expanded_strides.append(np.full((*shape, 1), stride))\n\n    grids = np.concatenate(grids, 1)\n    expanded_strides = np.concatenate(expanded_strides, 1)\n    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides\n    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides\n\n    return outputs\n\ndef preprocess(img, 
input_size, swap=(2, 0, 1)):\n    if len(img.shape) == 3:\n        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114\n    else:\n        padded_img = np.ones(input_size, dtype=np.uint8) * 114\n\n    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])\n    resized_img = cv2.resize(\n        img,\n        (int(img.shape[1] * r), int(img.shape[0] * r)),\n        interpolation=cv2.INTER_LINEAR,\n    ).astype(np.uint8)\n    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img\n\n    padded_img = padded_img.transpose(swap)\n    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)\n    return padded_img, r\n\ndef inference_detector(session, oriImg):\n    \"\"\"run human detect \n    \"\"\"\n    input_shape = (640,640)\n    img, ratio = preprocess(oriImg, input_shape)\n\n    ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}\n    output = session.run(None, ort_inputs)\n    predictions = demo_postprocess(output[0], input_shape)[0]\n\n    boxes = predictions[:, :4]\n    scores = predictions[:, 4:5] * predictions[:, 5:]\n\n    boxes_xyxy = np.ones_like(boxes)\n    boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.\n    boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.\n    boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.\n    boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.\n    boxes_xyxy /= ratio\n    dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)\n    if dets is not None:\n        final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]\n        isscore = final_scores>0.3\n        iscat = final_cls_inds == 0\n        isbbox = [ i and j for (i, j) in zip(isscore, iscat)]\n        final_boxes = final_boxes[isbbox]\n    else:\n        final_boxes = np.array([])\n\n    return final_boxes\n"
  },
  {
    "path": "mimicmotion/dwpose/onnxpose.py",
    "content": "from typing import List, Tuple\n\nimport cv2\nimport numpy as np\nimport onnxruntime as ort\n\ndef preprocess(\n    img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)\n) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:\n    \"\"\"Do preprocessing for RTMPose model inference.\n\n    Args:\n        img (np.ndarray): Input image in shape.\n        input_size (tuple): Input image size in shape (w, h).\n\n    Returns:\n        tuple:\n        - resized_img (np.ndarray): Preprocessed image.\n        - center (np.ndarray): Center of image.\n        - scale (np.ndarray): Scale of image.\n    \"\"\"\n    # get shape of image\n    img_shape = img.shape[:2]\n    out_img, out_center, out_scale = [], [], []\n    if len(out_bbox) == 0:\n        out_bbox = [[0, 0, img_shape[1], img_shape[0]]]\n    for i in range(len(out_bbox)):\n        x0 = out_bbox[i][0]\n        y0 = out_bbox[i][1]\n        x1 = out_bbox[i][2]\n        y1 = out_bbox[i][3]\n        bbox = np.array([x0, y0, x1, y1])\n\n        # get center and scale\n        center, scale = bbox_xyxy2cs(bbox, padding=1.25)\n\n        # do affine transformation\n        resized_img, scale = top_down_affine(input_size, scale, center, img)\n\n        # normalize image\n        mean = np.array([123.675, 116.28, 103.53])\n        std = np.array([58.395, 57.12, 57.375])\n        resized_img = (resized_img - mean) / std\n\n        out_img.append(resized_img)\n        out_center.append(center)\n        out_scale.append(scale)\n\n    return out_img, out_center, out_scale\n\n\ndef inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:\n    \"\"\"Inference RTMPose model.\n\n    Args:\n        sess (ort.InferenceSession): ONNXRuntime session.\n        img (np.ndarray): Input image in shape.\n\n    Returns:\n        outputs (np.ndarray): Output of RTMPose model.\n    \"\"\"\n    all_out = []\n    # build input\n    for i in range(len(img)):\n        input = [img[i].transpose(2, 0, 1)]\n\n        # 
build output\n        sess_input = {sess.get_inputs()[0].name: input}\n        sess_output = []\n        for out in sess.get_outputs():\n            sess_output.append(out.name)\n\n        # run model\n        outputs = sess.run(sess_output, sess_input)\n        all_out.append(outputs)\n\n    return all_out\n\n\ndef postprocess(outputs: List[np.ndarray],\n                model_input_size: Tuple[int, int],\n                center: Tuple[int, int],\n                scale: Tuple[int, int],\n                simcc_split_ratio: float = 2.0\n                ) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Postprocess for RTMPose model output.\n\n    Args:\n        outputs (np.ndarray): Output of RTMPose model.\n        model_input_size (tuple): RTMPose model Input image size.\n        center (tuple): Center of bbox in shape (x, y).\n        scale (tuple): Scale of bbox in shape (w, h).\n        simcc_split_ratio (float): Split ratio of simcc.\n\n    Returns:\n        tuple:\n        - keypoints (np.ndarray): Rescaled keypoints.\n        - scores (np.ndarray): Model predict scores.\n    \"\"\"\n    all_key = []\n    all_score = []\n    for i in range(len(outputs)):\n        # use simcc to decode\n        simcc_x, simcc_y = outputs[i]\n        keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)\n\n        # rescale keypoints\n        keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2\n        all_key.append(keypoints[0])\n        all_score.append(scores[0])\n\n    return np.array(all_key), np.array(all_score)\n\n\ndef bbox_xyxy2cs(bbox: np.ndarray,\n                 padding: float = 1.) 
-> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Transform the bbox format from (x1,y1,x2,y2) into (center, scale)\n\n    Args:\n        bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted\n            as (left, top, right, bottom)\n        padding (float): BBox padding factor that the scale will be\n            multiplied by. Default: 1.0\n\n    Returns:\n        tuple: A tuple containing center and scale.\n        - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or\n            (n, 2)\n        - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or\n            (n, 2)\n    \"\"\"\n    # convert single bbox from (4, ) to (1, 4)\n    dim = bbox.ndim\n    if dim == 1:\n        bbox = bbox[None, :]\n\n    # get bbox center and scale\n    x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])\n    center = np.hstack([x1 + x2, y1 + y2]) * 0.5\n    scale = np.hstack([x2 - x1, y2 - y1]) * padding\n\n    if dim == 1:\n        center = center[0]\n        scale = scale[0]\n\n    return center, scale\n\n\ndef _fix_aspect_ratio(bbox_scale: np.ndarray,\n                      aspect_ratio: float) -> np.ndarray:\n    \"\"\"Extend the scale to match the given aspect ratio.\n\n    Args:\n        bbox_scale (np.ndarray): The image scale (w, h) in shape (2, )\n        aspect_ratio (float): The ratio of ``w/h``\n\n    Returns:\n        np.ndarray: The reshaped image scale in (2, )\n    \"\"\"\n    w, h = np.hsplit(bbox_scale, [1])\n    bbox_scale = np.where(w > h * aspect_ratio,\n                          np.hstack([w, w / aspect_ratio]),\n                          np.hstack([h * aspect_ratio, h]))\n    return bbox_scale\n\n\ndef _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:\n    \"\"\"Rotate a point by an angle.\n\n    Args:\n        pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )\n        angle_rad (float): rotation angle in radians\n\n    Returns:\n        np.ndarray: Rotated point in shape (2, )\n    \"\"\"\n    sn, cs = np.sin(angle_rad), 
np.cos(angle_rad)\n    rot_mat = np.array([[cs, -sn], [sn, cs]])\n    return rot_mat @ pt\n\n\ndef _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:\n    \"\"\"To calculate the affine matrix, three pairs of points are required. This\n    function is used to get the 3rd point, given 2D points a & b.\n\n    The 3rd point is defined by rotating vector `a - b` by 90 degrees\n    anticlockwise, using b as the rotation center.\n\n    Args:\n        a (np.ndarray): The 1st point (x,y) in shape (2, )\n        b (np.ndarray): The 2nd point (x,y) in shape (2, )\n\n    Returns:\n        np.ndarray: The 3rd point.\n    \"\"\"\n    direction = a - b\n    c = b + np.r_[-direction[1], direction[0]]\n    return c\n\n\ndef get_warp_matrix(center: np.ndarray,\n                    scale: np.ndarray,\n                    rot: float,\n                    output_size: Tuple[int, int],\n                    shift: Tuple[float, float] = (0., 0.),\n                    inv: bool = False) -> np.ndarray:\n    \"\"\"Calculate the affine transformation matrix that can warp the bbox area\n    in the input image to the output size.\n\n    Args:\n        center (np.ndarray[2, ]): Center of the bounding box (x, y).\n        scale (np.ndarray[2, ]): Scale of the bounding box\n            wrt [width, height].\n        rot (float): Rotation angle (degree).\n        output_size (np.ndarray[2, ] | list(2,)): Size of the\n            destination heatmaps.\n        shift (0-100%): Shift translation ratio wrt the width/height.\n            Default (0., 0.).\n        inv (bool): Option to inverse the affine transform direction.\n            (inv=False: src->dst or inv=True: dst->src)\n\n    Returns:\n        np.ndarray: A 2x3 transformation matrix\n    \"\"\"\n    shift = np.array(shift)\n    src_w = scale[0]\n    dst_w = output_size[0]\n    dst_h = output_size[1]\n\n    # compute transformation matrix\n    rot_rad = np.deg2rad(rot)\n    src_dir = _rotate_point(np.array([0., src_w * -0.5]), 
rot_rad)\n    dst_dir = np.array([0., dst_w * -0.5])\n\n    # get three reference points of the src rectangle in the original image\n    src = np.zeros((3, 2), dtype=np.float32)\n    src[0, :] = center + scale * shift\n    src[1, :] = center + src_dir + scale * shift\n    src[2, :] = _get_3rd_point(src[0, :], src[1, :])\n\n    # get three reference points of the dst rectangle in the input image\n    dst = np.zeros((3, 2), dtype=np.float32)\n    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]\n    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir\n    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])\n\n    if inv:\n        warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))\n    else:\n        warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))\n\n    return warp_mat\n\n\ndef top_down_affine(input_size: Tuple[int, int], bbox_scale: np.ndarray,\n                    bbox_center: np.ndarray,\n                    img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Get the bbox image as the model input by affine transform.\n\n    Args:\n        input_size (tuple): The input size (w, h) of the model.\n        bbox_scale (np.ndarray): The bbox scale of the img.\n        bbox_center (np.ndarray): The bbox center of the img.\n        img (np.ndarray): The original image.\n\n    Returns:\n        tuple: A tuple containing the cropped image and bbox scale.\n        - np.ndarray[float32]: img after affine transform.\n        - np.ndarray[float32]: bbox scale after affine transform.\n    \"\"\"\n    w, h = input_size\n    warp_size = (int(w), int(h))\n\n    # reshape bbox to fixed aspect ratio\n    bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)\n\n    # get the affine matrix\n    center = bbox_center\n    scale = bbox_scale\n    rot = 0\n    warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))\n\n    # do affine transform\n    img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)\n\n    return img, bbox_scale\n\n\ndef get_simcc_maximum(simcc_x: np.ndarray,\n                      simcc_y: 
np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Get maximum response location and value from simcc representations.\n\n    Note:\n        instance number: N\n        num_keypoints: K\n        heatmap height: H\n        heatmap width: W\n\n    Args:\n        simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)\n        simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)\n\n    Returns:\n        tuple:\n        - locs (np.ndarray): locations of maximum heatmap responses in shape\n            (K, 2) or (N, K, 2)\n        - vals (np.ndarray): values of maximum heatmap responses in shape\n            (K,) or (N, K)\n    \"\"\"\n    N, K, Wx = simcc_x.shape\n    simcc_x = simcc_x.reshape(N * K, -1)\n    simcc_y = simcc_y.reshape(N * K, -1)\n\n    # get maximum value locations\n    x_locs = np.argmax(simcc_x, axis=1)\n    y_locs = np.argmax(simcc_y, axis=1)\n    locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)\n    max_val_x = np.amax(simcc_x, axis=1)\n    max_val_y = np.amax(simcc_y, axis=1)\n\n    # get maximum value across x and y axis\n    mask = max_val_x > max_val_y\n    max_val_x[mask] = max_val_y[mask]\n    vals = max_val_x\n    locs[vals <= 0.] 
= -1\n\n    # reshape\n    locs = locs.reshape(N, K, 2)\n    vals = vals.reshape(N, K)\n\n    return locs, vals\n\n\ndef decode(simcc_x: np.ndarray, simcc_y: np.ndarray,\n           simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Modulate simcc distribution with Gaussian.\n\n    Args:\n        simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.\n        simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.\n        simcc_split_ratio (int): The split ratio of simcc.\n\n    Returns:\n        tuple: A tuple containing center and scale.\n        - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)\n        - np.ndarray[float32]: scores in shape (K,) or (n, K)\n    \"\"\"\n    keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)\n    keypoints /= simcc_split_ratio\n\n    return keypoints, scores\n\n\ndef inference_pose(session, out_bbox, oriImg):\n    \"\"\"run pose detect \n\n    Args:\n        session (ort.InferenceSession): ONNXRuntime session.\n        out_bbox (np.ndarray): bbox list\n        oriImg (np.ndarray): Input image in shape.\n\n    Returns:\n        tuple:\n        - keypoints (np.ndarray): Rescaled keypoints.\n        - scores (np.ndarray): Model predict scores.\n    \"\"\"\n    h, w = session.get_inputs()[0].shape[2:]\n    model_input_size = (w, h)\n    # preprocess for rtm-pose model inference.\n    resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)\n    # run pose estimation for processed img\n    outputs = inference(session, resized_img)\n    # postprocess for rtm-pose model output.\n    keypoints, scores = postprocess(outputs, model_input_size, center, scale)\n\n    return keypoints, scores\n"
  },
  {
    "path": "mimicmotion/dwpose/preprocess.py",
    "content": "from tqdm import tqdm\nimport decord\nimport numpy as np\n\nfrom .util import draw_pose\nfrom .dwpose_detector import dwpose_detector as dwprocessor\n\ndef get_video_pose(\n        video_path: str, \n        ref_image: np.ndarray, \n        sample_stride: int=1):\n    \"\"\"preprocess ref image pose and video pose\n\n    Args:\n        video_path (str): video pose path\n        ref_image (np.ndarray): reference image \n        sample_stride (int, optional): Defaults to 1.\n\n    Returns:\n        np.ndarray: sequence of video pose\n    \"\"\"\n    # select ref-keypoint from reference pose for pose rescale\n    ref_pose = dwprocessor(ref_image)\n    ref_keypoint_id = [0, 1, 2, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]\n    ref_keypoint_id = [i for i in ref_keypoint_id \\\n        if len(ref_pose['bodies']['subset']) > 0 and ref_pose['bodies']['subset'][0][i] >= .0]\n    ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id]\n\n    height, width, _ = ref_image.shape\n\n    # read input video\n    vr = decord.VideoReader(video_path, ctx=decord.cpu(0))\n    sample_stride *= max(1, int(vr.get_avg_fps() / 24))\n\n    frames = vr.get_batch(list(range(0, len(vr), sample_stride))).asnumpy()\n    detected_poses = [dwprocessor(frm) for frm in tqdm(frames, desc=\"DWPose\")]\n    dwprocessor.release_memory()\n\n    detected_bodies = np.stack(\n        [p['bodies']['candidate'] for p in detected_poses if p['bodies']['candidate'].shape[0] == 18])[:,\n                      ref_keypoint_id]\n    # compute linear-rescale params\n    ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1)\n    fh, fw, _ = vr[0].shape\n    ax = ay / (fh / fw / height * width)\n    bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax)\n    a = np.array([ax, ay])\n    b = np.array([bx, by])\n    output_pose = []\n    # pose rescale \n    body_point = []\n    face_point = []\n    for 
detected_pose in detected_poses:\n        detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] * a + b\n        detected_pose['faces'] = detected_pose['faces'] * a + b\n        detected_pose['hands'] = detected_pose['hands'] * a + b\n        im = draw_pose(detected_pose, height, width)\n        output_pose.append(np.array(im))\n        body_point.append(detected_pose['bodies'])\n        face_point.append(detected_pose['faces'])\n    return np.stack(output_pose), body_point, face_point\n\n\ndef get_image_pose(ref_image):\n    \"\"\"process image pose\n\n    Args:\n        ref_image (np.ndarray): reference image pixel value\n\n    Returns:\n        np.ndarray: pose visual image in RGB-mode\n    \"\"\"\n    height, width, _ = ref_image.shape\n    ref_pose = dwprocessor(ref_image)\n    pose_img = draw_pose(ref_pose, height, width)\n    return np.array(pose_img), ref_pose\n"
  },
  {
    "path": "mimicmotion/dwpose/util.py",
    "content": "import math\nimport numpy as np\nimport matplotlib\nimport cv2\nimport pdb\n\neps = 0.01\n\ndef alpha_blend_color(color, alpha):\n    \"\"\"blend color according to point conf\n    \"\"\"\n    return [int(c * alpha) for c in color]\n\ndef draw_bodypose(canvas, candidate, subset, score):\n    H, W, C = canvas.shape\n    candidate = np.array(candidate)\n    subset = np.array(subset)\n\n    stickwidth = 4\n\n    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \\\n               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \\\n               [1, 16], [16, 18], [3, 17], [6, 18]]\n\n    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \\\n              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \\\n              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]\n\n    for i in range(17):\n        for n in range(len(subset)):\n            index = subset[n][np.array(limbSeq[i]) - 1]\n            conf = score[n][np.array(limbSeq[i]) - 1]\n            if conf[0] < 0.3 or conf[1] < 0.3:\n                continue\n            Y = candidate[index.astype(int), 0] * float(W)\n            X = candidate[index.astype(int), 1] * float(H)\n            mX = np.mean(X)\n            mY = np.mean(Y)\n            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))\n            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)\n            cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1]))\n\n    canvas = (canvas * 0.6).astype(np.uint8)\n\n    for i in range(18):\n        for n in range(len(subset)):\n            index = int(subset[n][i])\n            if index == -1:\n                continue\n            x, y = candidate[index][0:2]\n         
   conf = score[n][i]\n            x = int(x * W)\n            y = int(y * H)\n            cv2.circle(canvas, (int(x), int(y)), 4, alpha_blend_color(colors[i], conf), thickness=-1)\n\n    return canvas\n\ndef draw_handpose(canvas, all_hand_peaks, all_hand_scores):\n    H, W, C = canvas.shape\n\n    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \\\n             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]\n\n    for peaks, scores in zip(all_hand_peaks, all_hand_scores):\n\n        for ie, e in enumerate(edges):\n            x1, y1 = peaks[e[0]]\n            x2, y2 = peaks[e[1]]\n            x1 = int(x1 * W)\n            y1 = int(y1 * H)\n            x2 = int(x2 * W)\n            y2 = int(y2 * H)\n            score = int(scores[e[0]] * scores[e[1]] * 255)\n            if x1 > eps and y1 > eps and x2 > eps and y2 > eps:\n                cv2.line(canvas, (x1, y1), (x2, y2), \n                         matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * score, thickness=2)\n\n        for i, keypoint in enumerate(peaks):\n            x, y = keypoint\n            x = int(x * W)\n            y = int(y * H)\n            score = int(scores[i] * 255)\n            if x > eps and y > eps:\n                cv2.circle(canvas, (x, y), 4, (0, 0, score), thickness=-1)\n    return canvas\n\ndef draw_facepose(canvas, all_lmks, all_scores):\n    H, W, C = canvas.shape\n    for lmks, scores in zip(all_lmks, all_scores):\n        for lmk, score in zip(lmks, scores):\n            x, y = lmk\n            x = int(x * W)\n            y = int(y * H)\n            conf = int(score * 255)\n            if x > eps and y > eps:\n                cv2.circle(canvas, (x, y), 3, (conf, conf, conf), thickness=-1)\n    return canvas\n\ndef draw_pose(pose, H, W, ref_w=2160):\n    \"\"\"visualize dwpose outputs\n\n    Args:\n        pose (dict): DWposeDetector outputs in dwpose_detector.py\n        H 
(int): height\n        W (int): width\n        ref_w (int, optional) Defaults to 2160.\n\n    Returns:\n        np.ndarray: image pixel value in RGB mode\n    \"\"\"\n    bodies = pose['bodies']\n    faces = pose['faces']\n    hands = pose['hands']\n    candidate = bodies['candidate']\n    subset = bodies['subset']\n\n    sz = min(H, W)\n    sr = (ref_w / sz) if sz != ref_w else 1\n\n    ########################################## create zero canvas ##################################################\n    canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8)\n\n    ########################################### draw body pose #####################################################\n    canvas = draw_bodypose(canvas, candidate, subset, score=bodies['score'])\n\n    ########################################### draw hand pose #####################################################\n    canvas = draw_handpose(canvas, hands, pose['hands_score'])\n\n    ########################################### draw face pose #####################################################\n    canvas = draw_facepose(canvas, faces, pose['faces_score'])\n\n    return cv2.cvtColor(cv2.resize(canvas, (W, H)), cv2.COLOR_BGR2RGB).transpose(2, 0, 1)\n"
  },
  {
    "path": "mimicmotion/dwpose/wholebody.py",
    "content": "import numpy as np\nimport onnxruntime as ort\n\nfrom .onnxdet import inference_detector\nfrom .onnxpose import inference_pose\n\n\nclass Wholebody:\n    \"\"\"detect human pose by dwpose\n    \"\"\"\n    def __init__(self, model_det, model_pose, device=\"cpu\"):\n        providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']\n        provider_options = None if device == 'cpu' else [{'device_id': 0}]\n\n        self.session_det = ort.InferenceSession(\n            path_or_bytes=model_det, providers=providers,  provider_options=provider_options\n        )\n        self.session_pose = ort.InferenceSession(\n            path_or_bytes=model_pose, providers=providers, provider_options=provider_options\n        )\n    \n    def __call__(self, oriImg):\n        \"\"\"call to process dwpose-detect\n\n        Args:\n            oriImg (np.ndarray): detected image\n\n        \"\"\"\n        det_result = inference_detector(self.session_det, oriImg)\n        keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)\n\n        keypoints_info = np.concatenate(\n            (keypoints, scores[..., None]), axis=-1)\n        # compute neck joint\n        neck = np.mean(keypoints_info[:, [5, 6]], axis=1)\n        # neck score when visualizing pred\n        neck[:, 2:4] = np.logical_and(\n            keypoints_info[:, 5, 2:4] > 0.3,\n            keypoints_info[:, 6, 2:4] > 0.3).astype(int)\n        new_keypoints_info = np.insert(\n            keypoints_info, 17, neck, axis=1)\n        mmpose_idx = [\n            17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3\n        ]\n        openpose_idx = [\n            1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17\n        ]\n        new_keypoints_info[:, openpose_idx] = \\\n            new_keypoints_info[:, mmpose_idx]\n        keypoints_info = new_keypoints_info\n\n        keypoints, scores = keypoints_info[\n            ..., :2], keypoints_info[..., 2]\n        \n        
return keypoints, scores\n"
  },
  {
    "path": "mimicmotion/modules/__init__.py",
    "content": ""
  },
  {
    "path": "mimicmotion/modules/attention.py",
    "content": "from dataclasses import dataclass\nfrom typing import Any, Dict, Optional\n\nimport torch\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock\nfrom diffusers.models.embeddings import TimestepEmbedding, Timesteps\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.models.resnet import AlphaBlender\nfrom diffusers.utils import BaseOutput\nfrom inspect import isfunction\nimport math\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfrom einops import rearrange, repeat\n\n@dataclass\nclass TransformerTemporalModelOutput(BaseOutput):\n    \"\"\"\n    The output of [`TransformerTemporalModel`].\n\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):\n            The hidden states output conditioned on `encoder_hidden_states` input.\n    \"\"\"\n\n    sample: torch.FloatTensor\n\n\nclass TransformerTemporalModel(ModelMixin, ConfigMixin):\n    \"\"\"\n    A Transformer model for video-like data.\n\n    Parameters:\n        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.\n        in_channels (`int`, *optional*):\n            The number of channels in the input and output (specify if the input is **continuous**).\n        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.\n        attention_bias (`bool`, *optional*):\n            Configure if the `TransformerBlock` attention should contain a bias parameter.\n        sample_size (`int`, *optional*): The 
width of the latent images (specify if the input is **discrete**).\n            This is fixed during training since it is used to learn a number of position embeddings.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`):\n            Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported\n            activation functions.\n        norm_elementwise_affine (`bool`, *optional*):\n            Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.\n        double_self_attention (`bool`, *optional*):\n            Configure if each `TransformerBlock` should contain two self-attention layers.\n        positional_embeddings: (`str`, *optional*):\n            The type of positional embeddings to apply to the sequence input before passing use.\n        num_positional_embeddings: (`int`, *optional*):\n            The maximum length of the sequence over which to apply positional embeddings.\n    \"\"\"\n\n    @register_to_config\n    def __init__(\n            self,\n            num_attention_heads: int = 16,\n            attention_head_dim: int = 88,\n            in_channels: Optional[int] = None,\n            out_channels: Optional[int] = None,\n            num_layers: int = 1,\n            dropout: float = 0.0,\n            norm_num_groups: int = 32,\n            cross_attention_dim: Optional[int] = None,\n            attention_bias: bool = False,\n            sample_size: Optional[int] = None,\n            activation_fn: str = \"geglu\",\n            norm_elementwise_affine: bool = True,\n            double_self_attention: bool = True,\n            positional_embeddings: Optional[str] = None,\n            num_positional_embeddings: Optional[int] = None,\n    ):\n        super().__init__()\n        self.num_attention_heads = num_attention_heads\n        self.attention_head_dim = attention_head_dim\n        inner_dim = num_attention_heads * 
attention_head_dim\n\n        self.in_channels = in_channels\n\n        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n        self.proj_in = nn.Linear(in_channels, inner_dim)\n\n        # 3. Define transformers blocks\n        self.transformer_blocks = nn.ModuleList(\n            [\n                BasicTransformerBlock(\n                    inner_dim,\n                    num_attention_heads,\n                    attention_head_dim,\n                    dropout=dropout,\n                    cross_attention_dim=cross_attention_dim,\n                    activation_fn=activation_fn,\n                    attention_bias=attention_bias,\n                    double_self_attention=double_self_attention,\n                    norm_elementwise_affine=norm_elementwise_affine,\n                    positional_embeddings=positional_embeddings,\n                    num_positional_embeddings=num_positional_embeddings,\n                )\n                for d in range(num_layers)\n            ]\n        )\n\n        self.proj_out = nn.Linear(inner_dim, in_channels)\n\n    def forward(\n            self,\n            hidden_states: torch.FloatTensor,\n            encoder_hidden_states: Optional[torch.LongTensor] = None,\n            timestep: Optional[torch.LongTensor] = None,\n            class_labels: torch.LongTensor = None,\n            num_frames: int = 1,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            return_dict: bool = True,\n    ) -> TransformerTemporalModelOutput:\n        \"\"\"\n        The [`TransformerTemporal`] forward method.\n\n        Args:\n            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, \n                `torch.FloatTensor` of shape `(batch size, channel, height, width)`if continuous): Input hidden_states.\n            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, 
*optional*):\n                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to\n                self-attention.\n            timestep ( `torch.LongTensor`, *optional*):\n                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.\n            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):\n                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in\n                `AdaLayerZeroNorm`.\n            num_frames (`int`, *optional*, defaults to 1):\n                The number of frames to be processed per batch. This is used to reshape the hidden states.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in [diffusers.models.attention_processor](\n                https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n\n        Returns:\n            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is\n                returned, otherwise a `tuple` where the first element is the sample tensor.\n        \"\"\"\n        # 1. 
Input\n        batch_frames, channel, height, width = hidden_states.shape\n        batch_size = batch_frames // num_frames\n\n        residual = hidden_states\n\n        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)\n        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)\n\n        hidden_states = self.norm(hidden_states)\n        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)\n\n        hidden_states = self.proj_in(hidden_states)\n\n        # 2. Blocks\n        for block in self.transformer_blocks:\n            hidden_states = block(\n                hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                timestep=timestep,\n                cross_attention_kwargs=cross_attention_kwargs,\n                class_labels=class_labels,\n            )\n\n        # 3. Output\n        hidden_states = self.proj_out(hidden_states)\n        hidden_states = (\n            hidden_states[None, None, :]\n            .reshape(batch_size, height, width, num_frames, channel)\n            .permute(0, 3, 4, 1, 2)\n            .contiguous()\n        )\n        hidden_states = hidden_states.reshape(batch_frames, channel, height, width)\n\n        output = hidden_states + residual\n\n        if not return_dict:\n            return (output,)\n\n        return TransformerTemporalModelOutput(sample=output)\n\n\nclass TransformerSpatioTemporalModel(nn.Module):\n    \"\"\"\n    A Transformer model for video-like data.\n\n    Parameters:\n        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.\n        in_channels (`int`, *optional*):\n            The number of channels in the input and output (specify if the input is **continuous**).\n        out_channels (`int`, *optional*):\n 
           The number of channels in the output (specify if the input is **continuous**).\n        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.\n        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.\n    \"\"\"\n\n    def __init__(\n            self,\n            num_attention_heads: int = 16,\n            attention_head_dim: int = 88,\n            in_channels: int = 320,\n            out_channels: Optional[int] = None,\n            num_layers: int = 1,\n            cross_attention_dim: Optional[int] = None,\n    ):\n        super().__init__()\n        self.num_attention_heads = num_attention_heads\n        self.attention_head_dim = attention_head_dim\n\n        inner_dim = num_attention_heads * attention_head_dim\n        self.inner_dim = inner_dim\n\n        # 2. Define input layers\n        self.in_channels = in_channels\n        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)\n        self.proj_in = nn.Linear(in_channels, inner_dim)\n\n        # 3. 
Define transformers blocks\n        self.transformer_blocks = nn.ModuleList(\n            [\n                BasicTransformerBlock(\n                    inner_dim,\n                    num_attention_heads,\n                    attention_head_dim,\n                    cross_attention_dim=cross_attention_dim,\n                )\n                for d in range(num_layers)\n            ]\n        )\n\n        time_mix_inner_dim = inner_dim\n        self.temporal_transformer_blocks = nn.ModuleList(\n            [\n                TemporalBasicTransformerBlock(\n                    inner_dim,\n                    time_mix_inner_dim,\n                    num_attention_heads,\n                    attention_head_dim,\n                    cross_attention_dim=cross_attention_dim,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n\n        time_embed_dim = in_channels * 4\n        self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)\n        self.time_proj = Timesteps(in_channels, True, 0)\n        self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy=\"learned_with_images\")\n\n        # 4. Define output layers\n        self.out_channels = in_channels if out_channels is None else out_channels\n        # TODO: should use out_channels for continuous projections\n        self.proj_out = nn.Linear(inner_dim, in_channels)\n\n        self.gradient_checkpointing = False\n\n    def forward(\n            self,\n            hidden_states: torch.Tensor,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            image_only_indicator: Optional[torch.Tensor] = None,\n            return_dict: bool = True,\n    ):\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):\n                Input hidden_states.\n            num_frames (`int`):\n                The number of frames to be processed per batch. 
This is used to reshape the hidden states.\n            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):\n                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to\n                self-attention.\n            image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):\n                A tensor indicating whether the input contains only images. 1 indicates that the input contains only\n                images, 0 indicates that the input contains video frames.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] \n                instead of a plain tuple.\n\n        Returns:\n            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is\n                returned, otherwise a `tuple` where the first element is the sample tensor.\n        \"\"\"\n        # 1. 
Input\n        batch_frames, _, height, width = hidden_states.shape\n        num_frames = image_only_indicator.shape[-1]\n        batch_size = batch_frames // num_frames\n\n        time_context = encoder_hidden_states\n        time_context_first_timestep = time_context[None, :].reshape(\n            batch_size, num_frames, -1, time_context.shape[-1]\n        )[:, 0]\n        time_context = time_context_first_timestep[None, :].broadcast_to(\n            height * width, batch_size, 1, time_context.shape[-1]\n        )\n        time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])\n\n        residual = hidden_states\n\n        hidden_states = self.norm(hidden_states)\n        inner_dim = hidden_states.shape[1]\n        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)\n        hidden_states = torch.utils.checkpoint.checkpoint(self.proj_in, hidden_states, use_reentrant=False)\n\n        num_frames_emb = torch.arange(num_frames, device=hidden_states.device)\n        num_frames_emb = num_frames_emb.repeat(batch_size, 1)\n        num_frames_emb = num_frames_emb.reshape(-1)\n        t_emb = self.time_proj(num_frames_emb)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors,\n        # but time_embedding might actually be running in fp16, so we need to cast here.\n        # There might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=hidden_states.dtype)\n\n        emb = self.time_pos_embed(t_emb)\n        emb = emb[:, None, :]\n\n        # 2. 
Blocks\n        for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):\n            if self.gradient_checkpointing:\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    block,\n                    hidden_states,\n                    None,\n                    encoder_hidden_states,\n                    None,\n                    use_reentrant=False,\n                )\n            else:\n                hidden_states = block(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                )\n\n            hidden_states_mix = hidden_states + emb\n\n            if self.gradient_checkpointing:\n                hidden_states_mix = torch.utils.checkpoint.checkpoint(\n                    temporal_block,\n                    hidden_states_mix,\n                    num_frames,\n                    time_context,\n                    use_reentrant=False,\n                )\n            else:\n                hidden_states_mix = temporal_block(\n                    hidden_states_mix,\n                    num_frames=num_frames,\n                    encoder_hidden_states=time_context,\n                )\n            hidden_states = self.time_mixer(\n                x_spatial=hidden_states,\n                x_temporal=hidden_states_mix,\n                image_only_indicator=image_only_indicator,\n            )\n\n        # 3. 
Output\n        hidden_states = torch.utils.checkpoint.checkpoint(self.proj_out, hidden_states, use_reentrant=False)\n        hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()\n\n        output = hidden_states + residual\n\n        if not return_dict:\n            return (output,)\n\n        return TransformerTemporalModelOutput(sample=output)\n\n\n# from ldm.modules.diffusionmodules.util import checkpoint\n\n\ndef exists(val):\n    return val is not None\n\n\ndef uniq(arr):\n    return {el: True for el in arr}.keys()\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef max_neg_value(t):\n    return -torch.finfo(t.dtype).max\n\n\ndef init_(tensor):\n    dim = tensor.shape[-1]\n    std = 1 / math.sqrt(dim)\n    tensor.uniform_(-std, std)\n    return tensor\n\n\n# feedforward\nclass GEGLU(nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, x):\n        x, gate = self.proj(x).chunk(2, dim=-1)\n        return x * F.gelu(gate)\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = default(dim_out, dim)\n        project_in = nn.Sequential(\n            nn.Linear(dim, inner_dim),\n            nn.GELU()\n        ) if not glu else GEGLU(dim, inner_dim)\n\n        self.net = nn.Sequential(\n            project_in,\n            nn.Dropout(dropout),\n            nn.Linear(inner_dim, dim_out)\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef Normalize(in_channels):\n    return torch.nn.GroupNorm(num_groups=32, 
num_channels=in_channels, eps=1e-6, affine=True)\n\n\nclass LinearAttention(nn.Module):\n    def __init__(self, dim, heads=4, dim_head=32):\n        super().__init__()\n        self.heads = heads\n        hidden_dim = dim_head * heads\n        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)\n        self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n\n    def forward(self, x):\n        b, c, h, w = x.shape\n        qkv = self.to_qkv(x)\n        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)\n        k = k.softmax(dim=-1)\n        context = torch.einsum('bhdn,bhen->bhde', k, v)\n        out = torch.einsum('bhde,bhdn->bhen', context, q)\n        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)\n        return self.to_out(out)\n\n\nclass SpatialSelfAttention(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.k = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.v = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.proj_out = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=1,\n                                        stride=1,\n                                        padding=0)\n\n    
def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b, c, h, w = q.shape\n        q = rearrange(q, 'b c h w -> b (h w) c')\n        k = rearrange(k, 'b c h w -> b c (h w)')\n        w_ = torch.einsum('bij,bjk->bik', q, k)\n\n        w_ = w_ * (int(c)**(-0.5))\n        w_ = torch.nn.functional.softmax(w_, dim=2)\n\n        # attend to values\n        v = rearrange(v, 'b c h w -> b c (h w)')\n        w_ = rearrange(w_, 'b i j -> b j i')\n        h_ = torch.einsum('bij,bjk->bik', v, w_)\n        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)\n        h_ = self.proj_out(h_)\n\n        return x + h_\n\n\nclass CrossAttention(nn.Module):\n    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):\n        super().__init__()\n        inner_dim = dim_head * heads\n        context_dim = default(context_dim, query_dim)\n\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, query_dim),\n            nn.Dropout(dropout)\n        )\n\n    def forward(self, x, context=None, mask=None):\n        h = self.heads\n\n        q = self.to_q(x)\n        context = default(context, x)\n        k = self.to_k(context)\n        v = self.to_v(context)\n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))\n\n        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale\n\n        if exists(mask):\n            mask = rearrange(mask, 'b ... 
-> b (...)')\n            max_neg_value = -torch.finfo(sim.dtype).max\n            mask = repeat(mask, 'b j -> (b h) () j', h=h)\n            sim.masked_fill_(~mask, max_neg_value)\n\n        # attention, what we cannot get enough of\n        attn = sim.softmax(dim=-1)\n\n        out = einsum('b i j, b j d -> b i d', attn, v)\n        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)\n        return self.to_out(out)\n\n\nclass BasicTransformerBlock(nn.Module):\n    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True):\n        super().__init__()\n        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention\n        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)\n        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,\n                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n        self.norm3 = nn.LayerNorm(dim)\n\n    def forward(self, x, context=None):\n        x = self.attn1(self.norm1(x)) + x\n        x = self.attn2(self.norm2(x), context=context) + x\n        x = self.ff(self.norm3(x)) + x\n        return x\n\n\nclass SpatialTransformer(nn.Module):\n    \"\"\"\n    Transformer block for image-like data.\n    First, project the input (aka embedding)\n    and reshape to b, t, d.\n    Then apply standard transformer action.\n    Finally, reshape to image\n    \"\"\"\n    def __init__(self, in_channels, n_heads=8, d_head=64,\n                 depth=1, dropout=0., context_dim=None):\n        super().__init__()\n        self.in_channels = in_channels\n        inner_dim = n_heads * d_head\n        self.norm = Normalize(in_channels)\n\n        self.proj_in = nn.Conv2d(in_channels,\n                                 inner_dim,\n                                 kernel_size=1,\n                            
     stride=1,\n                                 padding=0)\n\n        self.transformer_blocks = nn.ModuleList(\n            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)\n                for d in range(depth)]\n        )\n\n        self.proj_out = zero_module(nn.Conv2d(inner_dim,\n                                              in_channels,\n                                              kernel_size=1,\n                                              stride=1,\n                                              padding=0))\n\n    def forward(self, x, context=None):\n        # note: if no context is given, cross-attention defaults to self-attention\n        b, c, h, w = x.shape\n        x_in = x\n        x = self.norm(x)\n        x = self.proj_in(x)\n        x = rearrange(x, 'b c h w -> b (h w) c')\n        for block in self.transformer_blocks:\n            x = block(x, context=context)\n        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)\n        x = self.proj_out(x)\n        return x + x_in, x"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/config.yaml",
    "content": "model:\n    arch: CMP\n    total_iter: 140000\n    lr_steps: [80000, 120000]\n    lr_mults: [0.1, 0.1]\n    lr: 0.1\n    optim: SGD\n    warmup_lr: []\n    warmup_steps: []\n    module:\n        arch: CMP\n        image_encoder: alexnet_fcn_32x\n        sparse_encoder: shallownet32x\n        flow_decoder: MotionDecoderPlain\n        skip_layer: False\n        img_enc_dim: 256\n        sparse_enc_dim: 16\n        output_dim: 198\n        decoder_combo: [1,2,4]\n        pretrained_image_encoder: False\n        flow_criterion: \"DiscreteLoss\"\n        nbins: 99\n        fmax: 50\ndata:\n    workers: 2\n    batch_size: 12\n    batch_size_test: 1\n    data_mean: [123.675, 116.28, 103.53] # RGB\n    data_div: [58.395, 57.12, 57.375]\n    short_size: 416\n    crop_size: [384, 384]\n    sample_strategy: ['grid', 'watershed']\n    sample_bg_ratio: 0.000025\n    nms_ks: 81\n    max_num_guide: 150\n\n    flow_file_type: \"jpg\"\n    image_flow_aug:\n        flip: False\n    flow_aug:\n        reverse: False\n        scale: False\n        rotate: False\n    train_source:\n        - data/yfcc/lists/train.txt\n        - data/youtube9000/lists/train.txt\n    val_source:\n        - data/yfcc/lists/val.txt\n    memcached: False\ntrainer:\n    initial_val: True\n    print_freq: 100\n    val_freq: 10000\n    save_freq: 10000\n    val_iter: -1\n    val_disp_start_iter: 0\n    val_disp_end_iter: 16\n    loss_record: ['loss_flow']\n    tensorboard: False\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=2 --node_rank=$1 \\\n    --master_addr=\"192.168.1.1\" main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 10000 \\\n    --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 10000 \\\n        --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=2 --node_rank=$1 \\\n    --master_addr=\"192.168.1.1\" main.py \\\n    --config $work_path/config.yaml --launcher pytorch\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 70000 \\\n    --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 70000 \\\n        --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/config.yaml",
    "content": "model:\n    arch: CMP\n    total_iter: 70000\n    lr_steps: [40000, 60000]\n    lr_mults: [0.1, 0.1]\n    lr: 0.1\n    optim: SGD\n    warmup_lr: []\n    warmup_steps: []\n    module:\n        arch: CMP\n        image_encoder: alexnet_fcn_32x\n        sparse_encoder: shallownet32x\n        flow_decoder: MotionDecoderPlain\n        skip_layer: False\n        img_enc_dim: 256\n        sparse_enc_dim: 16\n        output_dim: 198\n        decoder_combo: [1,2,4]\n        pretrained_image_encoder: False\n        flow_criterion: \"DiscreteLoss\"\n        nbins: 99\n        fmax: 50\ndata:\n    workers: 2\n    batch_size: 12\n    batch_size_test: 1\n    data_mean: [123.675, 116.28, 103.53] # RGB\n    data_div: [58.395, 57.12, 57.375]\n    short_size: 416\n    crop_size: [384, 384]\n    sample_strategy: ['grid', 'watershed']\n    sample_bg_ratio: 0.00015625\n    nms_ks: 41\n    max_num_guide: 150\n\n    flow_file_type: \"jpg\"\n    image_flow_aug:\n        flip: False\n    flow_aug:\n        reverse: False\n        scale: False\n        rotate: False\n    train_source:\n        - data/yfcc/lists/train.txt\n    val_source:\n        - data/yfcc/lists/val.txt\n    memcached: False\ntrainer:\n    initial_val: True\n    print_freq: 100\n    val_freq: 10000\n    save_freq: 10000\n    val_iter: -1\n    val_disp_start_iter: 0\n    val_disp_end_iter: 16\n    loss_record: ['loss_flow']\n    tensorboard: False\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=2 --node_rank=$1 \\\n    --master_addr=\"192.168.1.1\" main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 10000 \\\n    --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 10000 \\\n        --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=2 --node_rank=$1 \\\n    --master_addr=\"192.168.1.1\" main.py \\\n    --config $work_path/config.yaml --launcher pytorch\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 70000 \\\n    --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 70000 \\\n        --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/config.yaml",
    "content": "model:\n    arch: CMP\n    total_iter: 140000\n    lr_steps: [80000, 120000]\n    lr_mults: [0.1, 0.1]\n    lr: 0.1\n    optim: SGD\n    warmup_lr: []\n    warmup_steps: []\n    module:\n        arch: CMP\n        image_encoder: alexnet_fcn_32x\n        sparse_encoder: shallownet32x\n        flow_decoder: MotionDecoderPlain\n        skip_layer: False\n        img_enc_dim: 256\n        sparse_enc_dim: 16\n        output_dim: 198\n        decoder_combo: [1,2,4]\n        pretrained_image_encoder: False\n        flow_criterion: \"DiscreteLoss\"\n        nbins: 99\n        fmax: 50\ndata:\n    workers: 2\n    batch_size: 12\n    batch_size_test: 1\n    data_mean: [123.675, 116.28, 103.53] # RGB\n    data_div: [58.395, 57.12, 57.375]\n    short_size: 416\n    crop_size: [384, 384]\n    sample_strategy: ['grid', 'watershed']\n    sample_bg_ratio: 0.00015625\n    nms_ks: 41\n    max_num_guide: 150\n\n    flow_file_type: \"jpg\"\n    image_flow_aug:\n        flip: False\n    flow_aug:\n        reverse: False\n        scale: False\n        rotate: False\n    train_source:\n        - data/yfcc/lists/train.txt\n    val_source:\n        - data/yfcc/lists/val.txt\n    memcached: False\ntrainer:\n    initial_val: True\n    print_freq: 100\n    val_freq: 10000\n    save_freq: 10000\n    val_iter: -1\n    val_disp_start_iter: 0\n    val_disp_end_iter: 16\n    loss_record: ['loss_flow']\n    tensorboard: False\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 10000 \\\n    --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 10000 \\\n        --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 main.py \\\n    --config $work_path/config.yaml --launcher pytorch\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 70000 \\\n    --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 70000 \\\n        --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/config.yaml",
    "content": "model:\n    arch: CMP\n    total_iter: 70000\n    lr_steps: [40000, 60000]\n    lr_mults: [0.1, 0.1]\n    lr: 0.1\n    optim: SGD\n    warmup_lr: []\n    warmup_steps: []\n    module:\n        arch: CMP\n        image_encoder: resnet50\n        sparse_encoder: shallownet8x\n        flow_decoder: MotionDecoderPlain\n        skip_layer: False\n        img_enc_dim: 256\n        sparse_enc_dim: 16\n        output_dim: 198\n        decoder_combo: [1,2,4]\n        pretrained_image_encoder: False\n        flow_criterion: \"DiscreteLoss\"\n        nbins: 99\n        fmax: 50\ndata:\n    workers: 2\n    batch_size: 10\n    batch_size_test: 1\n    data_mean: [123.675, 116.28, 103.53] # RGB\n    data_div: [58.395, 57.12, 57.375]\n    short_size: 416\n    crop_size: [320, 320]\n    sample_strategy: ['grid', 'watershed']\n    sample_bg_ratio: 0.00015625\n    nms_ks: 15\n    max_num_guide: -1\n\n    flow_file_type: \"jpg\"\n    image_flow_aug:\n        flip: False\n    flow_aug:\n        reverse: False\n        scale: False\n        rotate: False\n    train_source:\n        - data/yfcc/lists/train.txt\n        - data/youtube9000/lists/train.txt\n        - data/VIP/lists/train.txt\n        - data/MPII/lists/train.txt\n    val_source:\n        - data/yfcc/lists/val.txt\n    memcached: False\ntrainer:\n    initial_val: True\n    print_freq: 100\n    val_freq: 10000\n    save_freq: 10000\n    val_iter: -1\n    val_disp_start_iter: 0\n    val_disp_end_iter: 16\n    loss_record: ['loss_flow']\n    tensorboard: False\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=2 --node_rank=$1 \\\n    --master_addr=\"192.168.1.1\" main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 10000 \\\n    --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 10000 \\\n        --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=2 --node_rank=$1 \\\n    --master_addr=\"192.168.1.1\" main.py \\\n    --config $work_path/config.yaml --launcher pytorch\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 70000 \\\n    --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 70000 \\\n        --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/config.yaml",
    "content": "model:\n    arch: CMP\n    total_iter: 42000\n    lr_steps: [24000, 36000]\n    lr_mults: [0.1, 0.1]\n    lr: 0.1\n    optim: SGD\n    warmup_lr: []\n    warmup_steps: []\n    module:\n        arch: CMP\n        image_encoder: resnet50\n        sparse_encoder: shallownet8x\n        flow_decoder: MotionDecoderPlain\n        skip_layer: False\n        img_enc_dim: 256\n        sparse_enc_dim: 16\n        output_dim: 198\n        decoder_combo: [1,2,4]\n        pretrained_image_encoder: False\n        flow_criterion: \"DiscreteLoss\"\n        nbins: 99\n        fmax: 50\ndata:\n    workers: 2\n    batch_size: 16\n    batch_size_test: 1\n    data_mean: [123.675, 116.28, 103.53] # RGB\n    data_div: [58.395, 57.12, 57.375]\n    short_size: 333\n    crop_size: [256, 256]\n    sample_strategy: ['grid', 'watershed']\n    sample_bg_ratio: 0.00005632\n    nms_ks: 49\n    max_num_guide: -1\n\n    flow_file_type: \"jpg\"\n    image_flow_aug:\n        flip: False\n    flow_aug:\n        reverse: False\n        scale: False\n        rotate: False\n    train_source:\n        - data/yfcc/lists/train.txt\n    val_source:\n        - data/yfcc/lists/val.txt\n    memcached: False\ntrainer:\n    initial_val: True\n    print_freq: 100\n    val_freq: 10000\n    save_freq: 10000\n    val_iter: -1\n    val_disp_start_iter: 0\n    val_disp_end_iter: 16\n    loss_record: ['loss_flow']\n    tensorboard: False\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=2 --node_rank=$1 \\\n    --master_addr=\"192.168.1.1\" main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 10000 \\\n    --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 10000 \\\n        --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=2 --node_rank=$1 \\\n    --master_addr=\"192.168.1.1\" main.py \\\n    --config $work_path/config.yaml --launcher pytorch\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 42000 \\\n    --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 42000 \\\n        --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/config.yaml",
    "content": "model:\n    arch: CMP\n    total_iter: 42000\n    lr_steps: [24000, 36000]\n    lr_mults: [0.1, 0.1]\n    lr: 0.1\n    optim: SGD\n    warmup_lr: []\n    warmup_steps: []\n    module:\n        arch: CMP\n        image_encoder: resnet50\n        sparse_encoder: shallownet8x\n        flow_decoder: MotionDecoderPlain\n        skip_layer: False\n        img_enc_dim: 256\n        sparse_enc_dim: 16\n        output_dim: 198\n        decoder_combo: [1,2,4]\n        pretrained_image_encoder: False\n        flow_criterion: \"DiscreteLoss\"\n        nbins: 99\n        fmax: 50\ndata:\n    workers: 2\n    batch_size: 10\n    batch_size_test: 1\n    data_mean: [123.675, 116.28, 103.53] # RGB\n    data_div: [58.395, 57.12, 57.375]\n    short_size: 416\n    crop_size: [320, 320]\n    sample_strategy: ['grid', 'watershed']\n    sample_bg_ratio: 0.00003629\n    nms_ks: 67\n    max_num_guide: -1\n\n    flow_file_type: \"jpg\"\n    image_flow_aug:\n        flip: False\n    flow_aug:\n        reverse: False\n        scale: False\n        rotate: False\n    train_source:\n        - data/yfcc/lists/train.txt\n    val_source:\n        - data/yfcc/lists/val.txt\n    memcached: False\ntrainer:\n    initial_val: True\n    print_freq: 100\n    val_freq: 10000\n    save_freq: 10000\n    val_iter: -1\n    val_disp_start_iter: 0\n    val_disp_end_iter: 16\n    loss_record: ['loss_flow']\n    tensorboard: False\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=2 --node_rank=$1 \\\n    --master_addr=\"192.168.1.1\" main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 10000 \\\n    --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 10000 \\\n        --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 \\\n    --nnodes=2 --node_rank=$1 \\\n    --master_addr=\"192.168.1.1\" main.py \\\n    --config $work_path/config.yaml --launcher pytorch\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 42000 \\\n    --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 42000 \\\n        --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml",
    "content": "model:\n    arch: CMP\n    total_iter: 42000\n    lr_steps: [24000, 36000]\n    lr_mults: [0.1, 0.1]\n    lr: 0.1\n    optim: SGD\n    warmup_lr: []\n    warmup_steps: []\n    module:\n        arch: CMP\n        image_encoder: resnet50\n        sparse_encoder: shallownet8x\n        flow_decoder: MotionDecoderSkipLayer\n        skip_layer: True\n        img_enc_dim: 256\n        sparse_enc_dim: 16\n        output_dim: 198\n        decoder_combo: [1,2,4]\n        pretrained_image_encoder: False\n        flow_criterion: \"DiscreteLoss\"\n        nbins: 99\n        fmax: 50\ndata:\n    workers: 2\n    batch_size: 8\n    batch_size_test: 1\n    data_mean: [123.675, 116.28, 103.53] # RGB\n    data_div: [58.395, 57.12, 57.375]\n    short_size: 416\n    crop_size: [384, 384]\n    sample_strategy: ['grid', 'watershed']\n    sample_bg_ratio: 5.74e-5\n    nms_ks: 41\n    max_num_guide: -1\n\n    flow_file_type: \"jpg\"\n    image_flow_aug:\n        flip: False\n    flow_aug:\n        reverse: False\n        scale: False\n        rotate: False\n    train_source:\n        - data/VIP/lists/train.txt\n        - data/MPII/lists/train.txt\n    val_source:\n        - data/VIP/lists/randval.txt\n    memcached: False\ntrainer:\n    initial_val: True\n    print_freq: 100\n    val_freq: 5000\n    save_freq: 5000\n    val_iter: -1\n    val_disp_start_iter: 0\n    val_disp_end_iter: 16\n    loss_record: ['loss_flow']\n    tensorboard: True\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 10000 \\\n    --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 10000 \\\n        --resume\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 main.py \\\n    --config $work_path/config.yaml --launcher pytorch\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \\\n    --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py \\\n        --config $work_path/config.yaml --launcher slurm\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npython -m torch.distributed.launch --nproc_per_node=8 main.py \\\n    --config $work_path/config.yaml --launcher pytorch \\\n    --load-iter 42000 \\\n    --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate_slurm.sh",
    "content": "#!/bin/bash\nwork_path=$(dirname $0)\npartition=$1\nGLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition \\\n    -n8 --gres=gpu:8 --ntasks-per-node=8 \\\n    python -u main.py --config $work_path/config.yaml --launcher slurm \\\n        --load-iter 42000 \\\n        --validate\n"
  },
  {
    "path": "mimicmotion/modules/cmp/losses.py",
    "content": "import torch\nimport numpy as np\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nimport random\nimport math\n\ndef MultiChannelSoftBinaryCrossEntropy(input, target, reduction='mean'):\n    '''\n    input: N x 38 x H x W --> 19N x 2 x H x W\n    target: N x 19 x H x W --> 19N x 1 x H x W\n    '''\n    input = input.view(-1, 2, input.size(2), input.size(3))\n    target = target.view(-1, 1, input.size(2), input.size(3))\n\n    logsoftmax = nn.LogSoftmax(dim=1)\n    if reduction == 'mean':\n        return torch.mean(torch.sum(-target * logsoftmax(input), dim=1))\n    else:\n        return torch.sum(torch.sum(-target * logsoftmax(input), dim=1))\n\nclass EdgeAwareLoss():\n    def __init__(self, nc=2, loss_type=\"L1\", reduction='mean'):\n        assert loss_type in ['L1', 'BCE'], \"Undefined loss type: {}\".format(loss_type)\n        self.nc = nc\n        self.loss_type = loss_type\n        self.kernelx = Variable(torch.Tensor([[1,0,-1],[2,0,-2],[1,0,-1]]).cuda())\n        self.kernelx = self.kernelx.repeat(nc,1,1,1)\n        self.kernely = Variable(torch.Tensor([[1,2,1],[0,0,0],[-1,-2,-1]]).cuda())\n        self.kernely = self.kernely.repeat(nc,1,1,1)\n        self.bias = Variable(torch.zeros(nc).cuda())\n        self.reduction = reduction\n        if loss_type == 'L1':\n            self.loss = nn.SmoothL1Loss(reduction=reduction)\n        elif loss_type == 'BCE':\n            self.loss = self.bce2d\n\n    def bce2d(self, input, target):\n        assert not target.requires_grad\n        beta = 1 - torch.mean(target)\n        weights = 1 - beta + (2 * beta - 1)  * target\n        loss = nn.functional.binary_cross_entropy(input, target, weights, reduction=self.reduction)\n        return loss\n\n    def get_edge(self, var):\n        assert var.size(1) == self.nc, \\\n            \"input size at dim 1 should be consistent with nc, {} vs {}\".format(var.size(1), self.nc)\n        outputx = 
nn.functional.conv2d(var, self.kernelx, bias=self.bias, padding=1, groups=self.nc)\n        outputy = nn.functional.conv2d(var, self.kernely, bias=self.bias, padding=1, groups=self.nc)\n        eps=1e-05\n        return torch.sqrt(outputx.pow(2) + outputy.pow(2) + eps).mean(dim=1, keepdim=True)\n\n    def __call__(self, input, target):\n        size = target.shape[2:4]\n        input = nn.functional.interpolate(input, size=size, mode=\"bilinear\", align_corners=True)\n        target_edge = self.get_edge(target)\n        if self.loss_type == 'L1':\n            return self.loss(self.get_edge(input), target_edge)\n        elif self.loss_type == 'BCE':\n            raise NotImplementedError\n            #target_edge = torch.sign(target_edge - 0.1)\n            #pred = self.get_edge(nn.functional.sigmoid(input))\n            #return self.loss(pred, target_edge)\n\ndef KLD(mean, logvar):\n    return -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())\n\nclass DiscreteLoss(nn.Module):\n    def __init__(self, nbins, fmax):\n        super().__init__()\n        self.loss = nn.CrossEntropyLoss()\n        assert nbins % 2 == 1, \"nbins should be odd\"\n        self.nbins = nbins\n        self.fmax = fmax\n        self.step = 2 * fmax / float(nbins)\n\n    def tobin(self, target):\n        target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)\n        quantized_target = torch.floor((target + self.fmax) / self.step)\n        return quantized_target.type(torch.cuda.LongTensor)\n\n    def __call__(self, input, target):\n        size = target.shape[2:4]\n        if input.shape[2] != size[0] or input.shape[3] != size[1]:\n            input = nn.functional.interpolate(input, size=size, mode=\"bilinear\", align_corners=True)\n        target = self.tobin(target)\n        assert input.size(1) == self.nbins * 2\n        target[target>=99]=98  # clamp off-by-one in bin quantization. 
We have [0 ~ 99] in GT flow, but nbins = 99\n        return self.loss(input[:,:self.nbins,...], target[:,0,...]) + self.loss(input[:,self.nbins:,...], target[:,1,...])\n\nclass MultiDiscreteLoss():\n    def __init__(self, nbins=19, fmax=47.5, reduction='mean', xy_weight=(1., 1.), quantize_strategy='linear'):\n        self.loss = nn.CrossEntropyLoss(reduction=reduction)\n        assert nbins % 2 == 1, \"nbins should be odd\"\n        self.nbins = nbins\n        self.fmax = fmax\n        self.step = 2 * fmax / float(nbins)\n        self.x_weight, self.y_weight = xy_weight\n        self.quantize_strategy = quantize_strategy\n\n    def tobin(self, target):\n        target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)\n        if self.quantize_strategy == \"linear\":\n            quantized_target = torch.floor((target + self.fmax) / self.step)\n        elif self.quantize_strategy == \"quadratic\":\n            ind = target.data > 0\n            quantized_target = target.clone()\n            quantized_target[ind] = torch.floor(self.nbins * torch.sqrt(target[ind] / (4 * self.fmax)) + self.nbins / 2.)\n            quantized_target[~ind] = torch.floor(-self.nbins * torch.sqrt(-target[~ind] / (4 * self.fmax)) + self.nbins / 2.)\n        return quantized_target.type(torch.cuda.LongTensor)\n\n    def __call__(self, input, target):\n        size = target.shape[2:4]\n        target = self.tobin(target)\n        if isinstance(input, list):\n            input = [nn.functional.interpolate(ip, size=size, mode=\"bilinear\", align_corners=True) for ip in input]\n            return sum([self.x_weight * self.loss(input[k][:,:self.nbins,...], target[:,0,...]) + self.y_weight * self.loss(input[k][:,self.nbins:,...], target[:,1,...]) for k in range(len(input))]) / float(len(input))\n        else:\n            input = nn.functional.interpolate(input, size=size, mode=\"bilinear\", align_corners=True)\n            return self.x_weight * self.loss(input[:,:self.nbins,...], 
target[:,0,...]) + self.y_weight * self.loss(input[:,self.nbins:,...], target[:,1,...])\n\nclass MultiL1Loss():\n    def __init__(self, reduction='mean'):\n        self.loss = nn.SmoothL1Loss(reduction=reduction)\n\n    def __call__(self, input, target):\n        size = target.shape[2:4]\n        if isinstance(input, list):\n            input = [nn.functional.interpolate(ip, size=size, mode=\"bilinear\", align_corners=True) for ip in input]\n            return sum([self.loss(input[k], target) for k in range(len(input))]) / float(len(input))\n        else:\n            input = nn.functional.interpolate(input, size=size, mode=\"bilinear\", align_corners=True)\n            return self.loss(input, target)\n\nclass MultiMSELoss():\n    def __init__(self):\n        self.loss = nn.MSELoss()\n    \n    def __call__(self, predicts, targets):\n        loss = 0\n        for predict, target in zip(predicts, targets):\n            loss += self.loss(predict, target)\n        return loss\n        \nclass JointDiscreteLoss():\n    def __init__(self, nbins=19, fmax=47.5, reduction='mean', quantize_strategy='linear'):\n        self.loss = nn.CrossEntropyLoss(reduction=reduction)\n        assert nbins % 2 == 1, \"nbins should be odd\"\n        self.nbins = nbins\n        self.fmax = fmax\n        self.step = 2 * fmax / float(nbins)\n        self.quantize_strategy = quantize_strategy\n        \n    def tobin(self, target):\n        target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)\n        if self.quantize_strategy == \"linear\":\n            quantized_target = torch.floor((target + self.fmax) / self.step)\n        elif self.quantize_strategy == \"quadratic\":\n            ind = target.data > 0\n            quantized_target = target.clone()\n            quantized_target[ind] = torch.floor(self.nbins * torch.sqrt(target[ind] / (4 * self.fmax)) + self.nbins / 2.)\n            quantized_target[~ind] = torch.floor(-self.nbins * torch.sqrt(-target[~ind] / (4 * self.fmax)) + 
self.nbins / 2.)\n        else:\n            raise Exception(\"No such quantize strategy: {}\".format(self.quantize_strategy))\n        joint_target = quantized_target[:,0,:,:] * self.nbins + quantized_target[:,1,:,:]\n        return joint_target.type(torch.cuda.LongTensor)\n\n    def __call__(self, input, target):\n        target = self.tobin(target)\n        assert input.size(1) == self.nbins ** 2\n        return self.loss(input, target)\n\nclass PolarDiscreteLoss():\n    def __init__(self, abins=30, rbins=20, fmax=50., reduction='mean', ar_weight=(1., 1.), quantize_strategy='linear'):\n        self.loss = nn.CrossEntropyLoss(reduction=reduction)\n        self.fmax = fmax\n        self.rbins = rbins\n        self.abins = abins\n        self.a_weight, self.r_weight = ar_weight\n        self.quantize_strategy = quantize_strategy\n\n    def tobin(self, target):\n        indxneg = target.data[:,0,:,:] < 0\n        eps = torch.zeros(target.data[:,0,:,:].size()).cuda()\n        epsind = target.data[:,0,:,:] == 0\n        eps[epsind] += 1e-5\n        angle = torch.atan(target.data[:,1,:,:] / (target.data[:,0,:,:] + eps))\n        angle[indxneg] += np.pi\n        angle += np.pi / 2 # 0 to 2pi\n        angle = torch.clamp(angle, 0, 2 * np.pi - 1e-3)\n        radius = torch.sqrt(target.data[:,0,:,:] ** 2 + target.data[:,1,:,:] ** 2)\n        radius = torch.clamp(radius, 0, self.fmax - 1e-3)\n        quantized_angle = torch.floor(self.abins * angle / (2 * np.pi))\n        if self.quantize_strategy == 'linear':\n            quantized_radius = torch.floor(self.rbins * radius / self.fmax)\n        elif self.quantize_strategy == 'quadratic':\n            quantized_radius = torch.floor(self.rbins * torch.sqrt(radius / self.fmax))\n        else:\n            raise Exception(\"No such quantize strategy: {}\".format(self.quantize_strategy))\n        quantized_target = torch.autograd.Variable(torch.cat([torch.unsqueeze(quantized_angle, 1), torch.unsqueeze(quantized_radius, 1)], 
dim=1))\n        return quantized_target.type(torch.cuda.LongTensor)\n\n    def __call__(self, input, target):\n        target = self.tobin(target)\n        assert (target >= 0).all() and (target[:,0,:,:] < self.abins).all() and (target[:,1,:,:] < self.rbins).all()\n        return self.a_weight * self.loss(input[:,:self.abins,...], target[:,0,...]) + self.r_weight * self.loss(input[:,self.abins:,...], target[:,1,...])\n\nclass WeightedDiscreteLoss():\n    def __init__(self, nbins=19, fmax=47.5, reduction='mean'):\n        self.loss = CrossEntropy2d(reduction=reduction)\n        assert nbins % 2 == 1, \"nbins should be odd\"\n        self.nbins = nbins\n        self.fmax = fmax\n        self.step = 2 * fmax / float(nbins)\n        self.weight = np.ones((nbins), dtype=np.float32)\n        self.weight[int(self.fmax / self.step)] = 0.01\n        self.weight = torch.from_numpy(self.weight).cuda()\n\n    def tobin(self, target):\n        target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)\n        return torch.floor((target + self.fmax) / self.step).type(torch.cuda.LongTensor)\n\n    def __call__(self, input, target):\n        target = self.tobin(target)\n        assert (target >= 0).all() and (target < self.nbins).all()\n        return self.loss(input[:,:self.nbins,...], target[:,0,...]) + self.loss(input[:,self.nbins:,...], target[:,1,...], self.weight)\n\n\nclass CrossEntropy2d(nn.Module):\n    def __init__(self, reduction='mean', ignore_label=-1):\n        super(CrossEntropy2d, self).__init__()\n        self.ignore_label = ignore_label\n        self.reduction = reduction\n\n    def forward(self, predict, target, weight=None):\n        \"\"\"\n            Args:\n                predict:(n, c, h, w)\n                target:(n, h, w)\n                weight (Tensor, optional): a manual rescaling weight given to each class.\n                                           If given, has to be a Tensor of size \"nclasses\"\n        \"\"\"\n        assert not 
target.requires_grad\n        assert predict.dim() == 4\n        assert target.dim() == 3\n        assert predict.size(0) == target.size(0), \"{0} vs {1} \".format(predict.size(0), target.size(0))\n        assert predict.size(2) == target.size(1), \"{0} vs {1} \".format(predict.size(2), target.size(1))\n        assert predict.size(3) == target.size(2), \"{0} vs {1} \".format(predict.size(3), target.size(2))\n        n, c, h, w = predict.size()\n        target_mask = (target >= 0) * (target != self.ignore_label)\n        target = target[target_mask]\n        predict = predict.transpose(1, 2).transpose(2, 3).contiguous()\n        predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)\n        loss = F.cross_entropy(predict, target, weight=weight, reduction=self.reduction)\n        return loss\n\n#class CrossPixelSimilarityLoss():\n#    '''\n#        Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py\n#    '''\n#    def __init__(self, sigma=0.0036, sampling_size=512):\n#        self.sigma = sigma\n#        self.sampling_size = sampling_size\n#        self.epsilon = 1.0e-15\n#        self.embed_norm = True # loss does not decrease no matter it is true or false.\n#\n#    def __call__(self, embeddings, flows):\n#        '''\n#            embedding: Variable Nx256xHxW (not hyper-column)\n#            flows: Variable Nx2xHxW\n#        '''\n#        assert flows.size(1) == 2\n#\n#        # flow normalization\n#        positive_mask = (flows > 0)\n#        flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. 
+ 1), max=1.)\n#        flows[positive_mask] = -flows[positive_mask]\n#\n#        # embedding normalization\n#        if self.embed_norm:\n#            embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True)\n#\n#        # Spatially random sampling (512 samples)\n#        flows_flatten = flows.view(flows.shape[0], 2, -1)\n#        random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())\n#        flows_sample = torch.index_select(flows_flatten, 2, random_locations)\n#\n#        # K_f\n#        k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) -\n#                                        torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,\n#                                        keepdim=False) ** 2\n#        exp_k_f = torch.exp(-k_f / 2. / self.sigma)\n#\n#        \n#        # mask\n#        eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())\n#        mask = torch.ones_like(exp_k_f) - eye\n#\n#        # S_f\n#        masked_exp_k_f = torch.mul(mask, exp_k_f) + eye\n#        s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)\n#\n#        # K_theta\n#        embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1)\n#        embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations)\n#        embeddings_sample_norm = torch.norm(embeddings_sample, p=2, dim=1, keepdim=True)\n#        k_theta = 0.25 * (torch.matmul(embeddings_sample.permute(0, 2, 1), embeddings_sample)) / (self.epsilon + torch.matmul(embeddings_sample_norm.permute(0, 2, 1), embeddings_sample_norm))\n#        exp_k_theta = torch.exp(k_theta)\n#\n#        # S_theta\n#        masked_exp_k_theta = torch.mul(mask, exp_k_theta) + math.exp(-0.75) * eye\n#        s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)\n#\n#        # loss\n#        loss = 
-torch.mean(torch.mul(s_f, torch.log(s_theta)))\n#\n#        return loss\n\nclass CrossPixelSimilarityLoss():\n    '''\n        Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py\n    '''\n    def __init__(self, sigma=0.01, sampling_size=512):\n        self.sigma = sigma\n        self.sampling_size = sampling_size\n        self.epsilon = 1.0e-15\n        self.embed_norm = True # loss does not decrease no matter it is true or false.\n\n    def __call__(self, embeddings, flows):\n        '''\n            embedding: Variable Nx256xHxW (not hyper-column)\n            flows: Variable Nx2xHxW\n        '''\n        assert flows.size(1) == 2\n\n        # flow normalization\n        positive_mask = (flows > 0)\n        flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.)\n        flows[positive_mask] = -flows[positive_mask]\n\n        # embedding normalization\n        if self.embed_norm:\n            embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True)\n\n        # Spatially random sampling (512 samples)\n        flows_flatten = flows.view(flows.shape[0], 2, -1)\n        random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())\n        flows_sample = torch.index_select(flows_flatten, 2, random_locations)\n\n        # K_f\n        k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) -\n                                        torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,\n                                        keepdim=False) ** 2\n        exp_k_f = torch.exp(-k_f / 2. 
/ self.sigma)\n\n        \n        # mask\n        eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())\n        mask = torch.ones_like(exp_k_f) - eye\n\n        # S_f\n        masked_exp_k_f = torch.mul(mask, exp_k_f) + eye\n        s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)\n\n        # K_theta\n        embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1)\n        embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations)\n        embeddings_sample_norm = torch.norm(embeddings_sample, p=2, dim=1, keepdim=True)\n        k_theta = 0.25 * (torch.matmul(embeddings_sample.permute(0, 2, 1), embeddings_sample)) / (self.epsilon + torch.matmul(embeddings_sample_norm.permute(0, 2, 1), embeddings_sample_norm))\n        exp_k_theta = torch.exp(k_theta)\n\n        # S_theta\n        masked_exp_k_theta = torch.mul(mask, exp_k_theta) + eye\n        s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)\n\n        # loss\n        loss = -torch.mean(torch.mul(s_f, torch.log(s_theta)))\n\n        return loss\n\n\nclass CrossPixelSimilarityFullLoss():\n    '''\n        Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py\n    '''\n    def __init__(self, sigma=0.01):\n        self.sigma = sigma\n        self.epsilon = 1.0e-15\n        self.embed_norm = True # loss does not decrease no matter it is true or false.\n\n    def __call__(self, embeddings, flows):\n        '''\n            embedding: Variable Nx256xHxW (not hyper-column)\n            flows: Variable Nx2xHxW\n        '''\n        assert flows.size(1) == 2\n\n        # downsample flow\n        factor = flows.shape[2] // embeddings.shape[2]\n        flows = nn.functional.avg_pool2d(flows, factor, factor)\n        assert flows.shape[2] == embeddings.shape[2]\n\n        # flow normalization\n        positive_mask = (flows > 0)\n        flows = 
-torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.)\n        flows[positive_mask] = -flows[positive_mask]\n\n        # embedding normalization\n        if self.embed_norm:\n            embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True)\n\n        # Spatially random sampling (512 samples)\n        flows_flatten = flows.view(flows.shape[0], 2, -1)\n        #random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())\n        #flows_sample = torch.index_select(flows_flatten, 2, random_locations)\n\n        # K_f\n        k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_flatten, dim=-1).permute(0, 3, 2, 1) -\n                                        torch.unsqueeze(flows_flatten, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,\n                                        keepdim=False) ** 2\n        exp_k_f = torch.exp(-k_f / 2. / self.sigma)\n\n        \n        # mask\n        eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())\n        mask = torch.ones_like(exp_k_f) - eye\n\n        # S_f\n        masked_exp_k_f = torch.mul(mask, exp_k_f) + eye\n        s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)\n\n        # K_theta\n        embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1)\n        #embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations)\n        embeddings_flatten_norm = torch.norm(embeddings_flatten, p=2, dim=1, keepdim=True)\n        k_theta = 0.25 * (torch.matmul(embeddings_flatten.permute(0, 2, 1), embeddings_flatten)) / (self.epsilon + torch.matmul(embeddings_flatten_norm.permute(0, 2, 1), embeddings_flatten_norm))\n        exp_k_theta = torch.exp(k_theta)\n\n        # S_theta\n        masked_exp_k_theta = torch.mul(mask, exp_k_theta) + eye\n        s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)\n\n        # loss\n        
loss = -torch.mean(torch.mul(s_f, torch.log(s_theta)))\n\n        return loss\n\n\ndef get_column(embeddings, index, full_size):\n    col = []\n    for embd in embeddings:\n        ind = (index.float() / full_size * embd.size(2)).long()\n        col.append(torch.index_select(embd.view(embd.shape[0], embd.shape[1], -1), 2, ind))\n    return torch.cat(col, dim=1) # N x coldim x sparsenum\n\nclass CrossPixelSimilarityColumnLoss(nn.Module):\n    '''\n        Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py\n    '''\n    def __init__(self, sigma=0.0036, sampling_size=512):\n        super(CrossPixelSimilarityColumnLoss, self).__init__()\n        self.sigma = sigma\n        self.sampling_size = sampling_size\n        self.epsilon = 1.0e-15\n        self.embed_norm = True # loss does not decrease no matter it is true or false.\n        self.mlp = nn.Sequential(\n            nn.Linear(96 + 96 + 384 + 256 + 4096, 256),\n            nn.ReLU(inplace=True),\n            nn.Linear(256, 16))\n\n    def forward(self, feats, flows):\n        '''\n            embedding: Variable Nx256xHxW (not hyper-column)\n            flows: Variable Nx2xHxW\n        '''\n        assert flows.size(1) == 2\n\n        # flow normalization\n        positive_mask = (flows > 0)\n        flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. 
+ 1), max=1.)\n        flows[positive_mask] = -flows[positive_mask]\n\n        # Spatially random sampling (512 samples)\n        flows_flatten = flows.view(flows.shape[0], 2, -1)\n        random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())\n        flows_sample = torch.index_select(flows_flatten, 2, random_locations)\n\n        # K_f\n        k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) -\n                                        torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,\n                                        keepdim=False) ** 2\n        exp_k_f = torch.exp(-k_f / 2. / self.sigma)\n\n        \n        # mask\n        eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())\n        mask = torch.ones_like(exp_k_f) - eye\n\n        # S_f\n        masked_exp_k_f = torch.mul(mask, exp_k_f) + eye\n        s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)\n\n\n        # column\n        column = get_column(feats, random_locations, flows.shape[2])\n        embedding = self.mlp(column)\n        # K_theta\n        embedding_norm = torch.norm(embedding, p=2, dim=1, keepdim=True)\n        k_theta = 0.25 * (torch.matmul(embedding.permute(0, 2, 1), embedding)) / (self.epsilon + torch.matmul(embedding_norm.permute(0, 2, 1), embedding_norm))\n        exp_k_theta = torch.exp(k_theta)\n\n        # S_theta\n        masked_exp_k_theta = torch.mul(mask, exp_k_theta) + math.exp(-0.75) * eye\n        s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)\n\n        # loss\n        loss = -torch.mean(torch.mul(s_f, torch.log(s_theta)))\n\n        return loss\n\n\ndef print_info(name, var):\n    print(name, var.size(), torch.max(var).data.cpu()[0], torch.min(var).data.cpu()[0], torch.mean(var).data.cpu()[0])\n\n\ndef MaskL1Loss(input, target, mask):\n    input_size = 
input.size()\n    res = torch.sum(torch.abs(input * mask - target * mask))\n    total = torch.sum(mask).item()\n    if total > 0:\n        res = res / (total * input_size[1])\n    return res\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/__init__.py",
    "content": "from .single_stage_model import *\nfrom .cmp import *\nfrom . import modules\nfrom . import backbone\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/backbone/__init__.py",
    "content": "from .resnet import *\nfrom .alexnet import *\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/backbone/alexnet.py",
    "content": "import torch.nn as nn\nimport math\n\nclass AlexNetBN_FCN(nn.Module):\n\n    def __init__(self, output_dim=256, stride=[4, 2, 2, 2], dilation=[1, 1], padding=[1, 1]):\n        super(AlexNetBN_FCN, self).__init__()\n        BN = nn.BatchNorm2d\n\n        self.conv1 = nn.Sequential(\n            nn.Conv2d(3, 96, kernel_size=11, stride=stride[0], padding=5),\n            BN(96),\n            nn.ReLU(inplace=True))\n        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=stride[1], padding=1)\n        self.conv2 = nn.Sequential(\n            nn.Conv2d(96, 256, kernel_size=5, padding=2),\n            BN(256),\n            nn.ReLU(inplace=True))\n        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=stride[2], padding=1)\n        self.conv3 = nn.Sequential(\n            nn.Conv2d(256, 384, kernel_size=3, padding=1),\n            BN(384),\n            nn.ReLU(inplace=True))\n        self.conv4 = nn.Sequential(\n            nn.Conv2d(384, 384, kernel_size=3, padding=padding[0], dilation=dilation[0]),\n            BN(384),\n            nn.ReLU(inplace=True))\n        self.conv5 = nn.Sequential(\n            nn.Conv2d(384, 256, kernel_size=3, padding=padding[1], dilation=dilation[1]),\n            BN(256),\n            nn.ReLU(inplace=True))\n        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=stride[3], padding=1)\n\n        self.fc6 = nn.Sequential(\n            nn.Conv2d(256, 4096, kernel_size=3, stride=1, padding=1),\n            BN(4096),\n            nn.ReLU(inplace=True))\n        self.drop6 = nn.Dropout(0.5)\n        self.fc7 = nn.Sequential(\n            nn.Conv2d(4096, 4096, kernel_size=1, stride=1, padding=0),\n            BN(4096),\n            nn.ReLU(inplace=True))\n        self.drop7 = nn.Dropout(0.5)\n        self.conv8 = nn.Conv2d(4096, output_dim, kernel_size=1)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]\n                
scale = math.sqrt(2. / fan_in)\n                m.weight.data.uniform_(-scale, scale)\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def forward(self, x, ret_feat=False):\n        if ret_feat:\n            raise NotImplementedError\n        x = self.conv1(x)\n        x = self.pool1(x)\n        x = self.conv2(x)\n        x = self.pool2(x)\n        x = self.conv3(x)\n        x = self.conv4(x)\n        x = self.conv5(x)\n        x = self.pool5(x)\n        x = self.fc6(x)\n        x = self.drop6(x)\n        x = self.fc7(x)\n        x = self.drop7(x)\n        x = self.conv8(x)\n        return x\n\ndef alexnet_fcn_32x(output_dim, pretrained=False, **kwargs):\n    assert pretrained == False\n    model = AlexNetBN_FCN(output_dim=output_dim, **kwargs)\n    return model\n\ndef alexnet_fcn_8x(output_dim, use_ppm=False, pretrained=False, **kwargs):\n    assert pretrained == False\n    model = AlexNetBN_FCN(output_dim=output_dim, stride=[2, 2, 2, 1], **kwargs)\n    return model\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/backbone/resnet.py",
    "content": "import torch.nn as nn\nimport math\nimport torch.utils.model_zoo as model_zoo\n\nBN = None\n\n\nmodel_urls = {\n    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',\n}\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = BN(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = BN(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = BN(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = BN(planes)\n        
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n        self.bn3 = BN(planes * 4)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n\n    def __init__(self, output_dim, block, layers):\n\n        global BN\n\n        BN = nn.BatchNorm2d\n\n        self.inplanes = 64\n        super(ResNet, self).__init__()\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = BN(64)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n\n        # 512 * block.expansion covers both BasicBlock (512) and Bottleneck (2048)\n        self.conv5 = nn.Conv2d(512 * block.expansion, output_dim, kernel_size=1)\n\n        ## dilation\n        for n, m in self.layer3.named_modules():\n            if 'conv2' in n:\n                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)\n            elif 'downsample.0' in n:\n                m.stride = (1, 1)\n        for n, m in self.layer4.named_modules():\n            if 'conv2' in n:\n                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)\n            elif 'downsample.0' in n:\n                m.stride = (1, 1)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. / n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                BN(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, img, ret_feat=False):\n        x = self.conv1(img) # 1/2\n        x = self.bn1(x)\n        conv1 = self.relu(x) # 1/2\n        pool1 = self.maxpool(conv1) # 1/4\n\n        layer1 = self.layer1(pool1) # 1/4\n        layer2 = self.layer2(layer1) # 1/8\n        layer3 = self.layer3(layer2) # 1/8\n        layer4 = self.layer4(layer3) # 1/8\n        out = self.conv5(layer4)\n\n        if ret_feat:\n            return out, [img, conv1, layer1] # 3, 64, 256\n        else:\n            return out\n\ndef resnet18(output_dim, pretrained=False):\n    model = ResNet(output_dim, BasicBlock, [2, 2, 2, 2])\n    if pretrained:\n        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)\n    return model\n\n\ndef resnet34(output_dim, pretrained=False):\n    model = ResNet(output_dim, BasicBlock, [3, 4, 6, 3])\n    if pretrained:\n        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)\n    return model\n\n\ndef resnet50(output_dim, 
pretrained=False):\n    model = ResNet(output_dim, Bottleneck, [3, 4, 6, 3])\n    if pretrained:\n        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)\n    return model\n\ndef resnet101(output_dim, pretrained=False):\n    model = ResNet(output_dim, Bottleneck, [3, 4, 23, 3])\n    if pretrained:\n        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)\n    return model\n\n\ndef resnet152(output_dim, pretrained=False):\n    model = ResNet(output_dim, Bottleneck, [3, 8, 36, 3])\n    if pretrained:\n        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)\n    return model\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/cmp.py",
    "content": "import torch\nimport torch.nn as nn\n\nimport mimicmotion.modules.cmp.losses as losses\nimport mimicmotion.modules.cmp.utils as utils\n\nfrom . import SingleStageModel\n\nclass CMP(SingleStageModel):\n\n    def __init__(self, params, dist_model=False):\n        super(CMP, self).__init__(params, dist_model)\n        model_params = params['module']\n\n        # define loss\n        if model_params['flow_criterion'] == 'L1':\n            self.flow_criterion = nn.SmoothL1Loss()\n        elif model_params['flow_criterion'] == 'L2':\n            self.flow_criterion = nn.MSELoss()\n        elif model_params['flow_criterion'] == 'DiscreteLoss':\n            self.flow_criterion = losses.DiscreteLoss(\n                nbins=model_params['nbins'], fmax=model_params['fmax'])\n        else:\n            raise Exception(\"No such flow loss: {}\".format(model_params['flow_criterion']))\n\n        self.fuser = utils.Fuser(nbins=model_params['nbins'],\n                                 fmax=model_params['fmax'])\n        self.model_params = model_params\n\n    def eval(self, ret_loss=True):\n        with torch.no_grad():\n            cmp_output = self.model(self.image_input, self.sparse_input)\n        if self.model_params['flow_criterion'] == \"DiscreteLoss\":\n            self.flow = self.fuser.convert_flow(cmp_output)\n        else:\n            self.flow = cmp_output\n        if self.flow.shape[2] != self.image_input.shape[2]:\n            self.flow = nn.functional.interpolate(\n                self.flow, size=self.image_input.shape[2:4],\n                mode=\"bilinear\", align_corners=True)\n\n        ret_tensors = {\n            'flow_tensors': [self.flow, self.flow_target],\n            'common_tensors': [],\n            'rgb_tensors': []} # except for image_input\n\n        if ret_loss:\n            if cmp_output.shape[2] != self.flow_target.shape[2]:\n                cmp_output = nn.functional.interpolate(\n                    cmp_output, 
size=self.flow_target.shape[2:4],\n                    mode=\"bilinear\", align_corners=True)\n            loss_flow = self.flow_criterion(cmp_output, self.flow_target) / self.world_size\n            return ret_tensors, {'loss_flow': loss_flow}\n        else:   \n            return ret_tensors\n\n    def step(self):\n        cmp_output = self.model(self.image_input, self.sparse_input)\n        loss_flow = self.flow_criterion(cmp_output, self.flow_target) / self.world_size\n        self.optim.zero_grad()\n        loss_flow.backward()\n        utils.average_gradients(self.model)\n        self.optim.step()\n        return {'loss_flow': loss_flow}\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/modules/__init__.py",
    "content": "from .warp import *\nfrom .others import *\nfrom .shallownet import *\nfrom .decoder import *\nfrom .cmp import *\n\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/modules/cmp.py",
    "content": "import torch\nimport torch.nn as nn\nimport mimicmotion.modules.cmp.models as models\n\n\nclass CMP(nn.Module):\n\n    def __init__(self, params):\n        super(CMP, self).__init__()\n        img_enc_dim = params['img_enc_dim']\n        sparse_enc_dim = params['sparse_enc_dim']\n        output_dim = params['output_dim']\n        pretrained = params['pretrained_image_encoder']\n        decoder_combo = params['decoder_combo']\n        self.skip_layer = params['skip_layer']\n        if self.skip_layer:\n            assert params['flow_decoder'] == \"MotionDecoderSkipLayer\"\n\n        self.image_encoder = models.backbone.__dict__[params['image_encoder']](\n            img_enc_dim, pretrained)\n        self.flow_encoder = models.modules.__dict__[params['sparse_encoder']](\n            sparse_enc_dim)\n        self.flow_decoder = models.modules.__dict__[params['flow_decoder']](\n            input_dim=img_enc_dim+sparse_enc_dim,\n            output_dim=output_dim, combo=decoder_combo)\n\n    def forward(self, image, sparse):\n        sparse_enc = self.flow_encoder(sparse)\n        if self.skip_layer:\n            img_enc, skip_feat = self.image_encoder(image, ret_feat=True)\n            flow_dec = self.flow_decoder(torch.cat((img_enc, sparse_enc), dim=1), skip_feat)\n        else:\n            img_enc = self.image_encoder(image)\n            flow_dec = self.flow_decoder(torch.cat((img_enc, sparse_enc), dim=1))\n        return flow_dec\n\n\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/modules/decoder.py",
    "content": "import torch\nimport torch.nn as nn\nimport math\n\nclass MotionDecoderPlain(nn.Module):\n\n    def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4]):\n        super(MotionDecoderPlain, self).__init__()\n        BN = nn.BatchNorm2d\n\n        self.combo = combo\n        for c in combo:\n            assert c in [1,2,4,8], \"invalid combo: {}\".format(combo)\n\n        if 1 in combo:\n            self.decoder1 = nn.Sequential(\n                nn.Conv2d(input_dim, 128, kernel_size=3, padding=1),\n                BN(128),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(128, 128, kernel_size=3, padding=1),\n                BN(128),\n                nn.ReLU(inplace=True))\n\n        if 2 in combo:\n            self.decoder2 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=2, stride=2),\n                nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),\n                BN(128),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(128, 128, kernel_size=3, padding=1),\n                BN(128),\n                nn.ReLU(inplace=True))\n\n        if 4 in combo:\n            self.decoder4 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=4, stride=4),\n                nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),\n                BN(128),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(128, 128, kernel_size=3, padding=1),\n                BN(128),\n                nn.ReLU(inplace=True))\n\n        if 8 in combo:\n            self.decoder8 = nn.Sequential(\n                nn.MaxPool2d(kernel_size=8, stride=8),\n                nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),\n                BN(128),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(128, 128, kernel_size=3, padding=1),\n                BN(128),\n                nn.ReLU(inplace=True))\n\n        self.head = nn.Conv2d(128 * len(self.combo), output_dim, 
kernel_size=1, padding=0)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]\n                scale = math.sqrt(2. / fan_in)\n                m.weight.data.uniform_(-scale, scale)\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm2d):\n                if not m.weight is None:\n                    m.weight.data.fill_(1)\n                if not m.bias is None:\n                    m.bias.data.zero_()\n\n    def forward(self, x):\n        \n        cat_list = []\n        if 1 in self.combo:\n            x1 = self.decoder1(x)\n            cat_list.append(x1)\n        if 2 in self.combo:\n            x2 = nn.functional.interpolate(\n                self.decoder2(x), size=(x.size(2), x.size(3)),\n                mode=\"bilinear\", align_corners=True)\n            cat_list.append(x2)\n        if 4 in self.combo:\n            x4 = nn.functional.interpolate(\n                self.decoder4(x), size=(x.size(2), x.size(3)),\n                mode=\"bilinear\", align_corners=True)\n            cat_list.append(x4)\n        if 8 in self.combo:\n            x8 = nn.functional.interpolate(\n                self.decoder8(x), size=(x.size(2), x.size(3)),\n                mode=\"bilinear\", align_corners=True)\n            cat_list.append(x8)\n           \n        cat = torch.cat(cat_list, dim=1)\n        flow = self.head(cat)\n        return flow\n\n\nclass MotionDecoderSkipLayer(nn.Module):\n\n    def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]):\n        super(MotionDecoderSkipLayer, self).__init__()\n\n        BN = nn.BatchNorm2d\n\n        self.decoder1 = nn.Sequential(\n            nn.Conv2d(input_dim, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            
nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True))\n\n        self.decoder2 = nn.Sequential(\n            nn.MaxPool2d(kernel_size=2, stride=2),\n            nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True))\n\n        self.decoder4 = nn.Sequential(\n            nn.MaxPool2d(kernel_size=4, stride=4),\n            nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True))\n\n        self.decoder8 = nn.Sequential(\n            nn.MaxPool2d(kernel_size=8, stride=8),\n            nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True))\n\n        self.fusion8 = nn.Sequential(\n            nn.Conv2d(512, 256, kernel_size=3, padding=1),\n            BN(256),\n            nn.ReLU(inplace=True))\n\n        self.skipconv4 = nn.Sequential(\n            nn.Conv2d(256, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True))\n        self.fusion4 = nn.Sequential(\n            nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1),\n            BN(128),\n            
nn.ReLU(inplace=True))\n\n        self.skipconv2 = nn.Sequential(\n            nn.Conv2d(64, 32, kernel_size=3, padding=1),\n            BN(32),\n            nn.ReLU(inplace=True))\n        self.fusion2 = nn.Sequential(\n            nn.Conv2d(128 + 32, 64, kernel_size=3, padding=1),\n            BN(64),\n            nn.ReLU(inplace=True))\n\n        self.head = nn.Conv2d(64, output_dim, kernel_size=1, padding=0)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]\n                scale = math.sqrt(2. / fan_in)\n                m.weight.data.uniform_(-scale, scale)\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm2d):\n                if not m.weight is None:\n                    m.weight.data.fill_(1)\n                if not m.bias is None:\n                    m.bias.data.zero_()\n\n    def forward(self, x, skip_feat):\n        layer1, layer2, layer4 = skip_feat\n\n        x1 = self.decoder1(x)\n        x2 = nn.functional.interpolate(\n            self.decoder2(x), size=(x1.size(2), x1.size(3)),\n            mode=\"bilinear\", align_corners=True)\n        x4 = nn.functional.interpolate(\n            self.decoder4(x), size=(x1.size(2), x1.size(3)),\n            mode=\"bilinear\", align_corners=True)\n        x8 = nn.functional.interpolate(\n            self.decoder8(x), size=(x1.size(2), x1.size(3)),\n            mode=\"bilinear\", align_corners=True)\n        cat = torch.cat([x1, x2, x4, x8], dim=1)\n        f8 = self.fusion8(cat)\n\n        f8_up = nn.functional.interpolate(\n            f8, size=(layer4.size(2), layer4.size(3)),\n            mode=\"bilinear\", align_corners=True)\n        f4 = self.fusion4(torch.cat([f8_up, self.skipconv4(layer4)], dim=1))\n\n        f4_up = nn.functional.interpolate(\n            f4, size=(layer2.size(2), layer2.size(3)),\n            
mode=\"bilinear\", align_corners=True)\n        f2 = self.fusion2(torch.cat([f4_up, self.skipconv2(layer2)], dim=1))\n\n        flow = self.head(f2)\n        return flow\n\n\nclass MotionDecoderFlowNet(nn.Module):\n\n    def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]):\n        super(MotionDecoderFlowNet, self).__init__()\n        global BN\n\n        BN = nn.BatchNorm2d\n\n        self.decoder1 = nn.Sequential(\n            nn.Conv2d(input_dim, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True))\n\n        self.decoder2 = nn.Sequential(\n            nn.MaxPool2d(kernel_size=2, stride=2),\n            nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True))\n\n        self.decoder4 = nn.Sequential(\n            nn.MaxPool2d(kernel_size=4, stride=4),\n            nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True))\n\n        self.decoder8 = nn.Sequential(\n            nn.MaxPool2d(kernel_size=8, stride=8),\n            nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, 
padding=1),\n            BN(128),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, 128, kernel_size=3, padding=1),\n            BN(128),\n            nn.ReLU(inplace=True))\n\n        self.fusion8 = nn.Sequential(\n            nn.Conv2d(512, 256, kernel_size=3, padding=1),\n            BN(256),\n            nn.ReLU(inplace=True))\n\n        # flownet head\n        self.predict_flow8 = predict_flow(256, output_dim)\n        self.predict_flow4 = predict_flow(384 + output_dim, output_dim)\n        self.predict_flow2 = predict_flow(192 + output_dim, output_dim)\n        self.predict_flow1 = predict_flow(67 + output_dim, output_dim)\n\n        self.upsampled_flow8_to_4 = nn.ConvTranspose2d(\n            output_dim, output_dim, 4, 2, 1, bias=False)\n        self.upsampled_flow4_to_2 = nn.ConvTranspose2d(\n            output_dim, output_dim, 4, 2, 1, bias=False)\n        self.upsampled_flow2_to_1 = nn.ConvTranspose2d(\n            output_dim, output_dim, 4, 2, 1, bias=False)\n\n        self.deconv8 = deconv(256, 128)\n        self.deconv4 = deconv(384 + output_dim, 128)\n        self.deconv2 = deconv(192 + output_dim, 64)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]\n                scale = math.sqrt(2. 
/ fan_in)\n                m.weight.data.uniform_(-scale, scale)\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm2d):\n                if not m.weight is None:\n                    m.weight.data.fill_(1)\n                if not m.bias is None:\n                    m.bias.data.zero_()\n\n    def forward(self, x, skip_feat):\n        layer1, layer2, layer4 = skip_feat # 3, 64, 256\n\n        # propagation nets\n        x1 = self.decoder1(x)\n        x2 = nn.functional.interpolate(\n            self.decoder2(x), size=(x1.size(2), x1.size(3)),\n            mode=\"bilinear\", align_corners=True)\n        x4 = nn.functional.interpolate(\n            self.decoder4(x), size=(x1.size(2), x1.size(3)),\n            mode=\"bilinear\", align_corners=True)\n        x8 = nn.functional.interpolate(\n            self.decoder8(x), size=(x1.size(2), x1.size(3)),\n            mode=\"bilinear\", align_corners=True)\n        cat = torch.cat([x1, x2, x4, x8], dim=1)\n        feat8 = self.fusion8(cat) # 256\n\n        # flownet head\n        flow8 = self.predict_flow8(feat8)\n        flow8_up = self.upsampled_flow8_to_4(flow8)\n        out_deconv8 = self.deconv8(feat8) # 128\n\n        concat4 = torch.cat((layer4, out_deconv8, flow8_up), dim=1) # 384 + out\n        flow4 = self.predict_flow4(concat4)\n        flow4_up = self.upsampled_flow4_to_2(flow4)\n        out_deconv4 = self.deconv4(concat4) # 128\n\n        concat2 = torch.cat((layer2, out_deconv4, flow4_up), dim=1) # 192 + out\n        flow2 = self.predict_flow2(concat2)\n        flow2_up = self.upsampled_flow2_to_1(flow2)\n        out_deconv2 = self.deconv2(concat2) # 64\n\n        concat1 = torch.cat((layer1, out_deconv2, flow2_up), dim=1) # 67 + out\n        flow1 = self.predict_flow1(concat1)\n\n        return [flow1, flow2, flow4, flow8]\n\n\ndef predict_flow(in_planes, out_planes):\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3,\n 
                    stride=1, padding=1, bias=True)\n\n\ndef deconv(in_planes, out_planes):\n    return nn.Sequential(\n        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4,\n                           stride=2, padding=1, bias=True),\n        nn.LeakyReLU(0.1, inplace=True)\n    )\n\n\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/modules/others.py",
    "content": "import torch.nn as nn\n\nclass FixModule(nn.Module):\n\n    def __init__(self, m):\n        super(FixModule, self).__init__()\n        self.module = m\n\n    def forward(self, *args, **kwargs):\n        return self.module(*args, **kwargs)\n\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/modules/shallownet.py",
    "content": "import torch.nn as nn\nimport math\n\nclass ShallowNet(nn.Module):\n\n    def __init__(self, input_dim=4, output_dim=16, stride=[2, 2, 2]):\n        super(ShallowNet, self).__init__()\n        global BN\n\n        BN = nn.BatchNorm2d\n\n        self.features = nn.Sequential(\n            nn.Conv2d(input_dim, 16, kernel_size=5, stride=stride[0], padding=2),\n            nn.BatchNorm2d(16),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=stride[1], stride=stride[1]),\n            nn.Conv2d(16, output_dim, kernel_size=3, padding=1),\n            nn.BatchNorm2d(output_dim),\n            nn.ReLU(inplace=True),\n            nn.AvgPool2d(kernel_size=stride[2], stride=stride[2]),\n        )\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]\n                scale = math.sqrt(2. / fan_in)\n                m.weight.data.uniform_(-scale, scale)\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm2d):\n                if not m.weight is None:\n                    m.weight.data.fill_(1)\n                if not m.bias is None:\n                    m.bias.data.zero_()\n\n    def forward(self, x):\n        x = self.features(x)\n        return x\n\n\ndef shallownet8x(output_dim):\n    model = ShallowNet(output_dim=output_dim, stride=[2,2,2])\n    return model\n\ndef shallownet32x(output_dim, **kwargs):\n    model = ShallowNet(output_dim=output_dim, stride=[2,2,8])\n    return model\n\n\n\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/modules/warp.py",
"content": "import torch\nimport torch.nn as nn\n\nclass WarpingLayerBWFlow(nn.Module):\n\n    def __init__(self):\n        super(WarpingLayerBWFlow, self).__init__()\n\n    def forward(self, image, flow):\n        flow_for_grid = torch.zeros_like(flow)\n        flow_for_grid[:,0,:,:] = flow[:,0,:,:] / ((flow.size(3) - 1.0) / 2.0)\n        flow_for_grid[:,1,:,:] = flow[:,1,:,:] / ((flow.size(2) - 1.0) / 2.0)\n\n        torchHorizontal = torch.linspace(\n            -1.0, 1.0, image.size(3)).view(\n            1, 1, 1, image.size(3)).expand(\n            image.size(0), 1, image.size(2), image.size(3))\n        torchVertical = torch.linspace(\n            -1.0, 1.0, image.size(2)).view(\n            1, 1, image.size(2), 1).expand(\n            image.size(0), 1, image.size(2), image.size(3))\n        grid = torch.cat([torchHorizontal, torchVertical], 1).cuda()\n\n        grid = (grid + flow_for_grid).permute(0, 2, 3, 1)\n        return torch.nn.functional.grid_sample(image, grid)\n\n\nclass WarpingLayerFWFlow(nn.Module):\n\n    def __init__(self):\n        super(WarpingLayerFWFlow, self).__init__()\n        self.initialized = False\n\n    def forward(self, image, flow, ret_mask = False):\n        n, h, w = image.size(0), image.size(2), image.size(3)\n\n        if not self.initialized or n != self.meshx.shape[0] or h * w != self.meshx.shape[1]:\n            self.meshx = torch.arange(w).view(1, 1, w).expand(\n                n, h, w).contiguous().view(n, -1).cuda()\n            self.meshy = torch.arange(h).view(1, h, 1).expand(\n                n, h, w).contiguous().view(n, -1).cuda()\n            self.warped_image = torch.zeros((n, 3, h, w), dtype=torch.float32).cuda()\n            if ret_mask:\n                self.hole_mask = torch.ones((n, 1, h, w), dtype=torch.float32).cuda()\n            self.initialized = True\n\n        v = (flow[:,0,:,:] ** 2 + flow[:,1,:,:] ** 2).view(n, -1)\n        _, sortidx = torch.sort(v, dim=1)\n\n        warped_meshx = 
self.meshx + flow[:,0,:,:].long().view(n, -1)\n        warped_meshy = self.meshy + flow[:,1,:,:].long().view(n, -1)\n        \n        warped_meshx = torch.clamp(warped_meshx, 0, w - 1)\n        warped_meshy = torch.clamp(warped_meshy, 0, h - 1)\n        \n        self.warped_image.zero_()\n        if ret_mask:\n            self.hole_mask.fill_(1.)\n        for i in range(n):\n            for c in range(3):\n                ind = sortidx[i]\n                self.warped_image[i,c,warped_meshy[i][ind],warped_meshx[i][ind]] = image[i,c,self.meshy[i][ind],self.meshx[i][ind]]\n            if ret_mask:\n                self.hole_mask[i,0,warped_meshy[i],warped_meshx[i]] = 0.\n        if ret_mask:\n            return self.warped_image, self.hole_mask\n        else:\n            return self.warped_image\n"
  },
  {
    "path": "mimicmotion/modules/cmp/models/single_stage_model.py",
    "content": "import os\nimport torch\nimport torch.backends.cudnn as cudnn\nimport torch.distributed as dist\n\nimport mimicmotion.modules.cmp.models as models\nimport mimicmotion.modules.cmp.utils as utils\n\n\nclass SingleStageModel(object):\n\n    def __init__(self, params, dist_model=False):\n        model_params = params['module']\n        self.model = models.modules.__dict__[params['module']['arch']](model_params)\n        utils.init_weights(self.model, init_type='xavier')\n        self.model.cuda()\n        if dist_model:\n            self.model = utils.DistModule(self.model)\n            self.world_size = dist.get_world_size()\n        else:\n            self.model = models.modules.FixModule(self.model)\n            self.world_size = 1\n\n        if params['optim'] == 'SGD':\n            self.optim = torch.optim.SGD(\n                self.model.parameters(), lr=params['lr'],\n                momentum=0.9, weight_decay=0.0001)\n        elif params['optim'] == 'Adam':\n            self.optim = torch.optim.Adam(\n                self.model.parameters(), lr=params['lr'],\n                betas=(params['beta1'], 0.999))\n        else:   \n            raise Exception(\"No such optimizer: {}\".format(params['optim']))\n\n        cudnn.benchmark = True\n\n    def set_input(self, image_input, sparse_input, flow_target=None, rgb_target=None):\n        self.image_input = image_input\n        self.sparse_input = sparse_input\n        self.flow_target = flow_target\n        self.rgb_target = rgb_target\n\n    def eval(self, ret_loss=True):\n        pass\n\n    def step(self):\n        pass\n\n    def load_state(self, path, Iter, resume=False):\n        path = os.path.join(path, \"ckpt_iter_{}.pth.tar\".format(Iter))\n\n        if resume:\n            utils.load_state(path, self.model, self.optim)\n        else:\n            utils.load_state(path, self.model)\n\n    def load_pretrain(self, load_path):\n        utils.load_state(load_path, self.model)\n\n    def 
save_state(self, path, Iter):\n        path = os.path.join(path, \"ckpt_iter_{}.pth.tar\".format(Iter))\n\n        torch.save({\n            'step': Iter,\n            'state_dict': self.model.state_dict(),\n            'optimizer': self.optim.state_dict()}, path)\n\n    def switch_to(self, phase):\n        if phase == 'train':\n            self.model.train()\n        else:\n            self.model.eval()\n"
  },
  {
    "path": "mimicmotion/modules/cmp/utils/__init__.py",
    "content": "from .common_utils import *\nfrom .data_utils import *\nfrom .distributed_utils import *\nfrom .visualize_utils import *\nfrom .scheduler import *\nfrom . import flowlib\n"
  },
  {
    "path": "mimicmotion/modules/cmp/utils/common_utils.py",
    "content": "import os\nimport logging\nimport numpy as np\n\nimport torch\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', init_gain=0.02):\n    \"\"\"Initialize network weights.\n    Parameters:\n        net (network)   -- network to be initialized\n        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal\n        init_gain (float)    -- scaling factor for normal, xavier and orthogonal.\n    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might\n    work better for some applications. Feel free to try yourself.\n    \"\"\"\n    def init_func(m):  # define the initialization function\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, init_gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=init_gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=init_gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.\n            init.normal_(m.weight.data, 1.0, init_gain)\n            init.constant_(m.bias.data, 0.0)\n\n    net.apply(init_func)  # apply the initialization function <init_func>\n\ndef create_logger(name, log_file, level=logging.INFO):\n    l = logging.getLogger(name)\n    formatter = logging.Formatter('[%(asctime)s] %(message)s')\n    fh = 
logging.FileHandler(log_file)\n    fh.setFormatter(formatter)\n    sh = logging.StreamHandler()\n    sh.setFormatter(formatter)\n    l.setLevel(level)\n    l.addHandler(fh)\n    l.addHandler(sh)\n    return l\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n    def __init__(self, length=0):\n        self.length = length\n        self.reset()\n\n    def reset(self):\n        if self.length > 0:\n            self.history = []\n        else:\n            self.count = 0\n            self.sum = 0.0\n        self.val = 0.0\n        self.avg = 0.0\n\n    def update(self, val):\n        if self.length > 0:\n            self.history.append(val)\n            if len(self.history) > self.length:\n                del self.history[0]\n\n            self.val = self.history[-1]\n            self.avg = np.mean(self.history)\n        else:\n            self.val = val\n            self.sum += val\n            self.count += 1\n            self.avg = self.sum / self.count\n\ndef accuracy(output, target, topk=(1,)):\n    \"\"\"Computes the precision@k for the specified values of k\"\"\"\n    maxk = max(topk)\n    batch_size = target.size(0)\n\n    _, pred = output.topk(maxk, 1, True, True)\n    pred = pred.t()\n    correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n    res = []\n    for k in topk:\n        # reshape() handles the non-contiguous slice of the transposed\n        # tensor; torch's keyword is keepdim, not numpy's keepdims\n        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n        res.append(correct_k.mul_(100.0 / batch_size))\n    return res\n\ndef load_state(path, model, optimizer=None):\n    def map_func(storage, location):\n        return storage.cuda()\n    if os.path.isfile(path):\n        print(\"=> loading checkpoint '{}'\".format(path))\n        checkpoint = torch.load(path, map_location=map_func)\n        model.load_state_dict(checkpoint['state_dict'], strict=False)\n        ckpt_keys = set(checkpoint['state_dict'].keys())\n        own_keys = set(model.state_dict().keys())\n        missing_keys = own_keys - 
ckpt_keys\n        # print(ckpt_keys)\n        # print(own_keys)\n        # for k in missing_keys:\n        #     print('caution: missing keys from checkpoint {}: {}'.format(path, k))\n\n        last_iter = checkpoint['step']\n        if optimizer is not None:\n            optimizer.load_state_dict(checkpoint['optimizer'])\n            print(\"=> also loaded optimizer from checkpoint '{}' (iter {})\"\n                  .format(path, last_iter))\n        return last_iter\n    else:\n        print(\"=> no checkpoint found at '{}'\".format(path))\n\n\n"
  },
  {
    "path": "mimicmotion/modules/cmp/utils/data_utils.py",
    "content": "from PIL import Image, ImageOps\nimport scipy.ndimage as ndimage\nimport cv2\nimport random\nimport numpy as np\nfrom scipy.ndimage.filters import maximum_filter\nfrom scipy import signal\ncv2.ocl.setUseOpenCL(False)\n\ndef get_edge(data, blur=False):\n    if blur:\n        data = cv2.GaussianBlur(data, (3, 3), 1.)\n    sobel = np.array([[1,0,-1],[2,0,-2],[1,0,-1]]).astype(np.float32)\n    ch_edges = []\n    for k in range(data.shape[2]):\n        edgex = signal.convolve2d(data[:,:,k], sobel, boundary='symm', mode='same')\n        edgey = signal.convolve2d(data[:,:,k], sobel.T, boundary='symm', mode='same')\n        ch_edges.append(np.sqrt(edgex**2 + edgey**2))\n    return sum(ch_edges)\n\ndef get_max(score, bbox):\n    u = max(0, bbox[0])\n    d = min(score.shape[0], bbox[1])\n    l = max(0, bbox[2])\n    r = min(score.shape[1], bbox[3])\n    return score[u:d,l:r].max()\n\ndef nms(score, ks):\n    assert ks % 2 == 1\n    ret_score = score.copy()\n    maxpool = maximum_filter(score, footprint=np.ones((ks, ks)))\n    ret_score[score < maxpool] = 0.\n    return ret_score\n\ndef image_flow_crop(img1, img2, flow, crop_size, phase):\n    assert len(crop_size) == 2\n    pad_h = max(crop_size[0] - img1.height, 0)\n    pad_w = max(crop_size[1] - img1.width, 0)\n    pad_h_half = int(pad_h / 2)\n    pad_w_half = int(pad_w / 2)\n    if pad_h > 0 or pad_w > 0:\n        flow_expand = np.zeros((img1.height + pad_h, img1.width + pad_w, 2), dtype=np.float32)\n        flow_expand[pad_h_half:pad_h_half+img1.height, pad_w_half:pad_w_half+img1.width, :] = flow\n        flow = flow_expand\n        border = (pad_w_half, pad_h_half, pad_w - pad_w_half, pad_h - pad_h_half)\n        img1 = ImageOps.expand(img1, border=border, fill=(0,0,0))\n        img2 = ImageOps.expand(img2, border=border, fill=(0,0,0))\n    if phase == 'train':\n        hoff = int(np.random.rand() * (img1.height - crop_size[0]))\n        woff = int(np.random.rand() * (img1.width - crop_size[1]))\n    
else:\n        hoff = (img1.height - crop_size[0]) // 2\n        woff = (img1.width - crop_size[1]) // 2\n\n    img1 = img1.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))\n    img2 = img2.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))\n    flow = flow[hoff:hoff+crop_size[0], woff:woff+crop_size[1], :]\n    offset = (hoff, woff)\n    return img1, img2, flow, offset\n\ndef image_crop(img, crop_size):\n    pad_h = max(crop_size[0] - img.height, 0)\n    pad_w = max(crop_size[1] - img.width, 0)\n    pad_h_half = int(pad_h / 2)\n    pad_w_half = int(pad_w / 2)\n    if pad_h > 0 or pad_w > 0:\n        border = (pad_w_half, pad_h_half, pad_w - pad_w_half, pad_h - pad_h_half)\n        img = ImageOps.expand(img, border=border, fill=(0,0,0))\n    hoff = (img.height - crop_size[0]) // 2\n    woff = (img.width - crop_size[1]) // 2\n    return img.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0])), (pad_w_half, pad_h_half)\n\ndef image_flow_resize(img1, img2, flow, short_size=None, long_size=None):\n    assert (short_size is None) ^ (long_size is None)\n    w, h = img1.width, img1.height\n    if short_size is not None:\n        if w < h:\n            neww = short_size\n            newh = int(short_size / float(w) * h)\n        else:\n            neww = int(short_size / float(h) * w)\n            newh = short_size\n    else:\n        if w < h:\n            neww = int(long_size / float(h) * w)\n            newh = long_size\n        else:\n            neww = long_size\n            newh = int(long_size / float(w) * h)\n    img1 = img1.resize((neww, newh), Image.BICUBIC)\n    img2 = img2.resize((neww, newh), Image.BICUBIC)\n    ratio = float(newh) / h\n    flow = cv2.resize(flow.copy(), (neww, newh), interpolation=cv2.INTER_LINEAR) * ratio\n    return img1, img2, flow, ratio\n\ndef image_resize(img, short_size=None, long_size=None):\n    assert (short_size is None) ^ (long_size is None)\n    w, h = img.width, img.height\n    if short_size is not None:\n      
  if w < h:\n            neww = short_size\n            newh = int(short_size / float(w) * h)\n        else:\n            neww = int(short_size / float(h) * w)\n            newh = short_size\n    else:\n        if w < h:\n            neww = int(long_size / float(h) * w)\n            newh = long_size\n        else:\n            neww = long_size\n            newh = int(long_size / float(w) * h)\n    img = img.resize((neww, newh), Image.BICUBIC)\n    return img, [w, h]\n\n\ndef image_pose_crop(img, posemap, crop_size, scale):\n    assert len(crop_size) == 2\n    assert crop_size[0] <= img.height\n    assert crop_size[1] <= img.width\n    hoff = (img.height - crop_size[0]) // 2\n    woff = (img.width - crop_size[1]) // 2\n    img = img.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))\n    posemap = posemap[hoff//scale:hoff//scale+crop_size[0]//scale, woff//scale:woff//scale+crop_size[1]//scale,:]\n    return img, posemap\n\ndef neighbor_elim(ph, pw, d):\n    # np.int was removed in NumPy 1.24; use the builtin int\n    valid = np.ones((len(ph))).astype(int)\n    h_dist = np.fabs(np.tile(ph[:,np.newaxis], [1,len(ph)]) - np.tile(ph.T[np.newaxis,:], [len(ph),1]))\n    w_dist = np.fabs(np.tile(pw[:,np.newaxis], [1,len(pw)]) - np.tile(pw.T[np.newaxis,:], [len(pw),1]))\n    idx1, idx2 = np.where((h_dist < d) & (w_dist < d))\n    for i,j in zip(idx1, idx2):\n        if valid[i] and valid[j] and i != j:\n            if np.random.rand() > 0.5:\n                valid[i] = 0\n            else:\n                valid[j] = 0\n    valid_idx = np.where(valid==1)\n    return ph[valid_idx], pw[valid_idx]\n\ndef remove_border(mask):\n        mask[0,:] = 0\n        mask[:,0] = 0\n        mask[mask.shape[0]-1,:] = 0\n        mask[:,mask.shape[1]-1] = 0\n\ndef flow_sampler(flow, strategy=['grid'], bg_ratio=1./6400, nms_ks=15, max_num_guide=-1, guidepoint=None):\n    assert bg_ratio >= 0 and bg_ratio <= 1, \"sampling ratio must be in [0, 1]\"\n    for s in strategy:\n        assert s in ['grid', 'uniform', 'gradnms', 'watershed', 'single', 
'full', 'specified'], \"No such strategy: {}\".format(s)\n    h = flow.shape[0]\n    w = flow.shape[1]\n    ds = max(1, max(h, w) // 400) # reduce computation\n\n    if 'full' in strategy:\n        sparse = flow.copy()\n        mask = np.ones(flow.shape, dtype=int)\n        return sparse, mask\n\n    pts_h = []\n    pts_w = []\n    if 'grid' in strategy:\n        stride = int(np.sqrt(1./bg_ratio))\n        mesh_start_h = int((h - h // stride * stride) / 2)\n        mesh_start_w = int((w - w // stride * stride) / 2)\n        mesh = np.meshgrid(np.arange(mesh_start_h, h, stride), np.arange(mesh_start_w, w, stride))\n        pts_h.append(mesh[0].flat)\n        pts_w.append(mesh[1].flat)\n    if 'uniform' in strategy:\n        pts_h.append(np.random.randint(0, h, int(bg_ratio * h * w)))\n        pts_w.append(np.random.randint(0, w, int(bg_ratio * h * w)))\n    if \"gradnms\" in strategy:\n        ks = w // ds // 20\n        edge = get_edge(flow[::ds,::ds,:])\n        kernel = np.ones((ks, ks), dtype=np.float32) / (ks * ks)\n        subkernel = np.ones((ks//2, ks//2), dtype=np.float32) / (ks//2 * ks//2)\n        score = signal.convolve2d(edge, kernel, boundary='symm', mode='same')\n        subscore = signal.convolve2d(edge, subkernel, boundary='symm', mode='same')\n        score = score / score.max() - subscore / subscore.max()\n        nms_res = nms(score, nms_ks)\n        pth, ptw = np.where(nms_res > 0.1)\n        pts_h.append(pth * ds)\n        pts_w.append(ptw * ds)\n    if \"watershed\" in strategy:\n        edge = get_edge(flow[::ds,::ds,:])\n        edge /= max(edge.max(), 0.01)\n        edge = (edge > 0.1).astype(np.float32)\n        watershed = ndimage.distance_transform_edt(1-edge)\n        nms_res = nms(watershed, nms_ks)\n        remove_border(nms_res)\n        pth, ptw = np.where(nms_res > 0)\n        pth, ptw = neighbor_elim(pth, ptw, (nms_ks-1)/2)\n        pts_h.append(pth * ds)\n        pts_w.append(ptw * ds)\n    if \"single\" in strategy:\n        
pth, ptw = np.where((flow[:,:,0] != 0) | (flow[:,:,1] != 0))\n        randidx = np.random.randint(len(pth))\n        pts_h.append(pth[randidx:randidx+1])\n        pts_w.append(ptw[randidx:randidx+1])\n    if 'specified' in strategy:\n        assert guidepoint is not None, \"if using \\\"specified\\\", switch \\\"with_info\\\" on.\"\n        pts_h.append(guidepoint[:,1])\n        pts_w.append(guidepoint[:,0])\n\n    pts_h = np.concatenate(pts_h)\n    pts_w = np.concatenate(pts_w)\n\n    if max_num_guide == -1:\n        max_num_guide = np.inf\n\n    randsel = np.random.permutation(len(pts_h))\n    selidx = randsel[np.arange(min(max_num_guide, len(randsel)))]\n    pts_h = pts_h[selidx]\n    pts_w = pts_w[selidx]\n\n    sparse = np.zeros(flow.shape, dtype=flow.dtype)\n    mask = np.zeros(flow.shape, dtype=int)\n\n    sparse[:, :, 0][(pts_h, pts_w)] = flow[:, :, 0][(pts_h, pts_w)]\n    sparse[:, :, 1][(pts_h, pts_w)] = flow[:, :, 1][(pts_h, pts_w)]\n\n    mask[:,:,0][(pts_h, pts_w)] = 1\n    mask[:,:,1][(pts_h, pts_w)] = 1\n    return sparse, mask\n\ndef image_flow_aug(img1, img2, flow, flip_horizon=True):\n    if flip_horizon:\n        if random.random() < 0.5:\n            img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)\n            img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)\n            flow = flow[:,::-1,:].copy()\n            flow[:,:,0] = -flow[:,:,0]\n    return img1, img2, flow\n\ndef flow_aug(flow, reverse=True, scale=True, rotate=True):\n    if reverse:\n        if random.random() < 0.5:\n            flow = -flow\n    if scale:\n        rand_scale = random.uniform(0.5, 2.0)\n        flow = flow * rand_scale\n    if rotate and random.random() < 0.5:\n        length = np.sqrt(np.square(flow[:,:,0]) + np.square(flow[:,:,1]))\n        # arctan2 keeps the quadrant and handles zero denominators,\n        # unlike arctan of the ratio\n        alpha = np.arctan2(flow[:,:,1], flow[:,:,0])\n        theta = random.uniform(0, np.pi*2)\n        flow[:,:,0] = length * np.cos(alpha + theta)\n        flow[:,:,1] = length * np.sin(alpha + theta)\n    return 
flow\n\ndef draw_gaussian(img, pt, sigma, type='Gaussian'):\n    # Check that any part of the gaussian is in-bounds\n    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]\n    br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]\n    if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or\n            br[0] < 0 or br[1] < 0):\n        # If not, just return the image as is\n        return img\n\n    # Generate gaussian\n    size = 6 * sigma + 1\n    x = np.arange(0, size, 1, float)\n    y = x[:, np.newaxis]\n    x0 = y0 = size // 2\n    # The gaussian is not normalized, we want the center value to equal 1\n    if type == 'Gaussian':\n        g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))\n    elif type == 'Cauchy':\n        g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)\n\n    # Usable gaussian range\n    g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]\n    g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]\n    # Image range\n    img_x = max(0, ul[0]), min(br[0], img.shape[1])\n    img_y = max(0, ul[1]), min(br[1], img.shape[0])\n\n    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]\n    return img\n\n\n"
  },
  {
    "path": "mimicmotion/modules/cmp/utils/distributed_utils.py",
    "content": "import os\nimport subprocess\nimport numpy as np\nimport multiprocessing as mp\nimport math\n\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data.sampler import Sampler\nfrom torch.nn import Module\n\nclass DistModule(Module):\n    def __init__(self, module):\n        super(DistModule, self).__init__()\n        self.module = module\n        broadcast_params(self.module)\n    def forward(self, *inputs, **kwargs):\n        return self.module(*inputs, **kwargs)\n    def train(self, mode=True):\n        super(DistModule, self).train(mode)\n        self.module.train(mode)\n\ndef average_gradients(model):\n    \"\"\" average gradients \"\"\"\n    for param in model.parameters():\n        if param.requires_grad:\n            dist.all_reduce(param.grad.data)\n\ndef broadcast_params(model):\n    \"\"\" broadcast model parameters \"\"\"\n    for p in model.state_dict().values():\n        dist.broadcast(p, 0)\n\ndef dist_init(launcher, backend='nccl', **kwargs):\n    if mp.get_start_method(allow_none=True) is None:\n        mp.set_start_method('spawn')\n    if launcher == 'pytorch':\n        _init_dist_pytorch(backend, **kwargs)\n    elif launcher == 'mpi':\n        _init_dist_mpi(backend, **kwargs)\n    elif launcher == 'slurm':\n        _init_dist_slurm(backend, **kwargs)\n    else:\n        raise ValueError('Invalid launcher type: {}'.format(launcher))\n\ndef _init_dist_pytorch(backend, **kwargs):\n    rank = int(os.environ['RANK'])\n    num_gpus = torch.cuda.device_count()\n    torch.cuda.set_device(rank % num_gpus)\n    dist.init_process_group(backend=backend, **kwargs)\n\ndef _init_dist_mpi(backend, **kwargs):\n    raise NotImplementedError\n\ndef _init_dist_slurm(backend, port=10086, **kwargs):\n    proc_id = int(os.environ['SLURM_PROCID'])\n    ntasks = int(os.environ['SLURM_NTASKS'])\n    node_list = os.environ['SLURM_NODELIST']\n    num_gpus = torch.cuda.device_count()\n    torch.cuda.set_device(proc_id % num_gpus)\n    addr = 
subprocess.getoutput(\n        'scontrol show hostname {} | head -n1'.format(node_list))\n    os.environ['MASTER_PORT'] = str(port)\n    os.environ['MASTER_ADDR'] = addr\n    os.environ['WORLD_SIZE'] = str(ntasks)\n    os.environ['RANK'] = str(proc_id)\n    dist.init_process_group(backend=backend)\n\ndef gather_tensors(input_array):\n    world_size = dist.get_world_size()\n    ## gather shapes first\n    myshape = input_array.shape\n    mycount = input_array.size\n    shape_tensor = torch.Tensor(np.array(myshape)).cuda()\n    all_shape = [torch.Tensor(np.array(myshape)).cuda() for i in range(world_size)]\n    dist.all_gather(all_shape, shape_tensor)\n    ## compute largest shapes\n    all_shape = [x.cpu().numpy() for x in all_shape]\n    all_count = [int(x.prod()) for x in all_shape]\n    all_shape = [list(map(int, x)) for x in all_shape]\n    max_count = max(all_count)\n    ## padding tensors and gather them\n    output_tensors = [torch.Tensor(max_count).cuda() for i in range(world_size)]\n    padded_input_array = np.zeros(max_count)\n    padded_input_array[:mycount] = input_array.reshape(-1)\n    input_tensor = torch.Tensor(padded_input_array).cuda()\n    dist.all_gather(output_tensors, input_tensor)\n    ## unpadding gathered tensors\n    padded_output = [x.cpu().numpy() for x in output_tensors]\n    output = [x[:all_count[i]].reshape(all_shape[i]) for i,x in enumerate(padded_output)]\n    return output\n\ndef gather_tensors_batch(input_array, part_size=10):\n    # gather\n    rank = dist.get_rank()\n    all_features = []\n    part_num = input_array.shape[0] // part_size + 1 if input_array.shape[0] % part_size != 0 else input_array.shape[0] // part_size\n    for i in range(part_num):\n        part_feat = input_array[i * part_size:min((i+1)*part_size, input_array.shape[0]),...]\n        assert part_feat.shape[0] > 0, \"rank: {}, length of part features should > 0\".format(rank)\n        print(\"rank: {}, gather part: {}/{}, length: {}\".format(rank, i, part_num, 
len(part_feat)))\n        gather_part_feat = gather_tensors(part_feat)\n        all_features.append(gather_part_feat)\n    print(\"rank: {}, gather done.\".format(rank))\n    all_features = np.concatenate([np.concatenate([all_features[i][j] for i in range(part_num)], axis=0) for j in range(len(all_features[0]))], axis=0)\n    return all_features\n\ndef reduce_tensors(tensor):\n    reduced_tensor = tensor.clone()\n    dist.all_reduce(reduced_tensor)\n    return reduced_tensor\n\nclass DistributedSequentialSampler(Sampler):\n    def __init__(self, dataset, world_size=None, rank=None):\n        if world_size == None:\n            world_size = dist.get_world_size()\n        if rank == None:\n            rank = dist.get_rank()\n        self.dataset = dataset\n        self.world_size = world_size\n        self.rank = rank\n        assert len(self.dataset) >= self.world_size, '{} vs {}'.format(len(self.dataset), self.world_size)\n        sub_num = int(math.ceil(len(self.dataset) * 1.0 / self.world_size))\n        self.beg = sub_num * self.rank\n        #self.end = min(self.beg+sub_num, len(self.dataset))\n        self.end = self.beg + sub_num\n        self.padded_ind = list(range(len(self.dataset))) + list(range(sub_num * self.world_size - len(self.dataset)))\n\n    def __iter__(self):\n        indices = [self.padded_ind[i] for i in range(self.beg, self.end)]\n        return iter(indices)\n\n    def __len__(self):\n        return self.end - self.beg\n\nclass GivenIterationSampler(Sampler):\n    def __init__(self, dataset, total_iter, batch_size, last_iter=-1):\n        self.dataset = dataset\n        self.total_iter = total_iter\n        self.batch_size = batch_size\n        self.last_iter = last_iter\n\n        self.total_size = self.total_iter * self.batch_size\n        self.indices = self.gen_new_list()\n        self.call = 0\n\n    def __iter__(self):\n        if self.call == 0:\n            self.call = 1\n            return iter(self.indices[(self.last_iter + 1) * 
self.batch_size:])\n        else:\n            raise RuntimeError(\"this sampler is not designed to be called more than once!!\")\n\n    def gen_new_list(self):\n\n        # each process shuffle all list with same seed, and pick one piece according to rank\n        np.random.seed(0)\n\n        all_size = self.total_size\n        indices = np.arange(len(self.dataset))\n        indices = indices[:all_size]\n        num_repeat = (all_size-1) // indices.shape[0] + 1\n        indices = np.tile(indices, num_repeat)\n        indices = indices[:all_size]\n\n        np.random.shuffle(indices)\n\n        assert len(indices) == self.total_size\n\n        return indices\n\n    def __len__(self):\n        return self.total_size\n\n\nclass DistributedGivenIterationSampler(Sampler):\n    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1):\n        if world_size is None:\n            world_size = dist.get_world_size()\n        if rank is None:\n            rank = dist.get_rank()\n        assert rank < world_size\n        self.dataset = dataset\n        self.total_iter = total_iter\n        self.batch_size = batch_size\n        self.world_size = world_size\n        self.rank = rank\n        self.last_iter = last_iter\n\n        self.total_size = self.total_iter*self.batch_size\n\n        self.indices = self.gen_new_list()\n        self.call = 0\n\n    def __iter__(self):\n        if self.call == 0:\n            self.call = 1\n            return iter(self.indices[(self.last_iter+1)*self.batch_size:])\n        else:\n            raise RuntimeError(\"this sampler is not designed to be called more than once!!\")\n\n    def gen_new_list(self):\n\n        # each process shuffle all list with same seed, and pick one piece according to rank\n        np.random.seed(0)\n\n        all_size = self.total_size * self.world_size\n        indices = np.arange(len(self.dataset))\n        indices = indices[:all_size]\n        num_repeat = (all_size-1) // 
indices.shape[0] + 1\n        indices = np.tile(indices, num_repeat)\n        indices = indices[:all_size]\n\n        np.random.shuffle(indices)\n        beg = self.total_size * self.rank\n        indices = indices[beg:beg+self.total_size]\n\n        assert len(indices) == self.total_size\n\n        return indices\n\n    def __len__(self):\n        # note here we do not take last iter into consideration, since __len__\n        # should only be used for displaying, the correct remaining size is\n        # handled by dataloader\n        #return self.total_size - (self.last_iter+1)*self.batch_size\n        return self.total_size\n\n\n"
  },
  {
    "path": "mimicmotion/modules/cmp/utils/flowlib.py",
    "content": "#!/usr/bin/python\n\"\"\"\n# ==============================\n# flowlib.py\n# library for optical flow processing\n# Author: Ruoteng Li\n# Date: 6th Aug 2016\n# ==============================\n\"\"\"\n#import png\nimport numpy as np\nfrom PIL import Image\nimport io\n\nUNKNOWN_FLOW_THRESH = 1e7\nSMALLFLOW = 0.0\nLARGEFLOW = 1e8\n\n\"\"\"\n=============\nFlow Section\n=============\n\"\"\"\n\ndef write_flow(flow, filename):\n    \"\"\"\n    write optical flow in Middlebury .flo format\n    :param flow: optical flow map\n    :param filename: optical flow file path to be saved\n    :return: None\n    \"\"\"\n    f = open(filename, 'wb')\n    magic = np.array([202021.25], dtype=np.float32)\n    (height, width) = flow.shape[0:2]\n    w = np.array([width], dtype=np.int32)\n    h = np.array([height], dtype=np.int32)\n    magic.tofile(f)\n    w.tofile(f)\n    h.tofile(f)\n    flow.tofile(f)\n    f.close()\n\n\ndef save_flow_image(flow, image_file):\n    \"\"\"\n    save flow visualization into image file\n    :param flow: optical flow data\n    :param flow_fil\n    :return: None\n    \"\"\"\n    flow_img = flow_to_image(flow)\n    img_out = Image.fromarray(flow_img)\n    img_out.save(image_file)\n\ndef segment_flow(flow):\n    h = flow.shape[0]\n    w = flow.shape[1]\n    u = flow[:, :, 0]\n    v = flow[:, :, 1]\n\n    idx = ((abs(u) > LARGEFLOW) | (abs(v) > LARGEFLOW))\n    idx2 = (abs(u) == SMALLFLOW)\n    class0 = (v == 0) & (u == 0)\n    u[idx2] = 0.00001\n    tan_value = v / u\n\n    class1 = (tan_value < 1) & (tan_value >= 0) & (u > 0) & (v >= 0)\n    class2 = (tan_value >= 1) & (u >= 0) & (v >= 0)\n    class3 = (tan_value < -1) & (u <= 0) & (v >= 0)\n    class4 = (tan_value < 0) & (tan_value >= -1) & (u < 0) & (v >= 0)\n    class8 = (tan_value >= -1) & (tan_value < 0) & (u > 0) & (v <= 0)\n    class7 = (tan_value < -1) & (u >= 0) & (v <= 0)\n    class6 = (tan_value >= 1) & (u <= 0) & (v <= 0)\n    class5 = (tan_value >= 0) & (tan_value < 1) & (u < 0) 
& (v <= 0)\n\n    seg = np.zeros((h, w))\n\n    seg[class1] = 1\n    seg[class2] = 2\n    seg[class3] = 3\n    seg[class4] = 4\n    seg[class5] = 5\n    seg[class6] = 6\n    seg[class7] = 7\n    seg[class8] = 8\n    seg[class0] = 0\n    seg[idx] = 0\n\n    return seg\n\ndef flow_to_image(flow):\n    \"\"\"\n    Convert flow into middlebury color code image\n    :param flow: optical flow map\n    :return: optical flow image in middlebury color\n    \"\"\"\n    u = flow[:, :, 0]\n    v = flow[:, :, 1]\n\n    maxu = -999.\n    maxv = -999.\n    minu = 999.\n    minv = 999.\n\n    idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)\n    u[idxUnknow] = 0\n    v[idxUnknow] = 0\n\n    maxu = max(maxu, np.max(u))\n    minu = min(minu, np.min(u))\n\n    maxv = max(maxv, np.max(v))\n    minv = min(minv, np.min(v))\n\n    rad = np.sqrt(u ** 2 + v ** 2)\n    maxrad = max(5, np.max(rad))\n    #maxrad = max(-1, 99)\n\n    u = u/(maxrad + np.finfo(float).eps)\n    v = v/(maxrad + np.finfo(float).eps)\n\n    img = compute_color(u, v)\n\n    idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)\n    img[idx] = 0\n\n    return np.uint8(img)\n\ndef disp_to_flowfile(disp, filename):\n    \"\"\"\n    Read KITTI disparity file in png format\n    :param disp: disparity matrix\n    :param filename: the flow file name to save\n    :return: None\n    \"\"\"\n    f = open(filename, 'wb')\n    magic = np.array([202021.25], dtype=np.float32)\n    (height, width) = disp.shape[0:2]\n    w = np.array([width], dtype=np.int32)\n    h = np.array([height], dtype=np.int32)\n    empty_map = np.zeros((height, width), dtype=np.float32)\n    data = np.dstack((disp, empty_map))\n    magic.tofile(f)\n    w.tofile(f)\n    h.tofile(f)\n    data.tofile(f)\n    f.close()\n\ndef compute_color(u, v):\n    \"\"\"\n    compute optical flow color map\n    :param u: optical flow horizontal map\n    :param v: optical flow vertical map\n    :return: optical flow in color code\n    \"\"\"\n    
[h, w] = u.shape\n    img = np.zeros([h, w, 3])\n    nanIdx = np.isnan(u) | np.isnan(v)\n    u[nanIdx] = 0\n    v[nanIdx] = 0\n\n    colorwheel = make_color_wheel()\n    ncols = np.size(colorwheel, 0)\n\n    rad = np.sqrt(u**2+v**2)\n\n    a = np.arctan2(-v, -u) / np.pi\n\n    fk = (a+1) / 2 * (ncols - 1) + 1\n\n    k0 = np.floor(fk).astype(int)\n\n    k1 = k0 + 1\n    k1[k1 == ncols+1] = 1\n    f = fk - k0\n\n    for i in range(0, np.size(colorwheel,1)):\n        tmp = colorwheel[:, i]\n        col0 = tmp[k0-1] / 255\n        col1 = tmp[k1-1] / 255\n        col = (1-f) * col0 + f * col1\n\n        idx = rad <= 1\n        col[idx] = 1-rad[idx]*(1-col[idx])\n        notidx = np.logical_not(idx)\n\n        col[notidx] *= 0.75\n        img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))\n\n    return img\n\n\ndef make_color_wheel():\n    \"\"\"\n    Generate color wheel according to the Middlebury color code\n    :return: Color wheel\n    \"\"\"\n    RY = 15\n    YG = 6\n    GC = 4\n    CB = 11\n    BM = 13\n    MR = 6\n\n    ncols = RY + YG + GC + CB + BM + MR\n\n    colorwheel = np.zeros([ncols, 3])\n\n    col = 0\n\n    # RY\n    colorwheel[0:RY, 0] = 255\n    colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))\n    col += RY\n\n    # YG\n    colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))\n    colorwheel[col:col+YG, 1] = 255\n    col += YG\n\n    # GC\n    colorwheel[col:col+GC, 1] = 255\n    colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))\n    col += GC\n\n    # CB\n    colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))\n    colorwheel[col:col+CB, 2] = 255\n    col += CB\n\n    # BM\n    colorwheel[col:col+BM, 2] = 255\n    colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))\n    col += BM\n\n    # MR\n    colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))\n    colorwheel[col:col+MR, 0] 
= 255\n\n    return colorwheel\n\n\ndef read_flo_file(filename, memcached=False):\n    \"\"\"\n    Read from Middlebury .flo file\n    :param filename: name of the flow file\n    :return: optical flow data in matrix\n    \"\"\"\n    if memcached:\n        filename = io.BytesIO(filename)\n    f = open(filename, 'rb')\n    magic = np.fromfile(f, np.float32, count=1)[0]\n    data2d = None\n\n    if 202021.25 != magic:\n        print('Magic number incorrect. Invalid .flo file')\n    else:\n        w = np.fromfile(f, np.int32, count=1)[0]\n        h = np.fromfile(f, np.int32, count=1)[0]\n        data2d = np.fromfile(f, np.float32, count=2 * w * h)\n        # reshape data into 3D array (columns, rows, channels)\n        data2d = np.resize(data2d, (h, w, 2))\n    f.close()\n    return data2d\n\n\n# fast resample layer\ndef resample(img, sz):\n    \"\"\"\n    img: flow map to be resampled\n    sz: new flow map size. Must be [height,width]\n    \"\"\"\n    original_image_size = img.shape\n    in_height = img.shape[0]\n    in_width = img.shape[1]\n    out_height = sz[0]\n    out_width = sz[1]\n    out_flow = np.zeros((out_height, out_width, 2))\n    # find scale\n    height_scale = float(in_height) / float(out_height)\n    width_scale = float(in_width) / float(out_width)\n\n    [x,y] = np.meshgrid(range(out_width), range(out_height))\n    xx = x * width_scale\n    yy = y * height_scale\n    x0 = np.floor(xx).astype(np.int32)\n    x1 = x0 + 1\n    y0 = np.floor(yy).astype(np.int32)\n    y1 = y0 + 1\n\n    x0 = np.clip(x0,0,in_width-1)\n    x1 = np.clip(x1,0,in_width-1)\n    y0 = np.clip(y0,0,in_height-1)\n    y1 = np.clip(y1,0,in_height-1)\n\n    Ia = img[y0,x0,:]\n    Ib = img[y1,x0,:]\n    Ic = img[y0,x1,:]\n    Id = img[y1,x1,:]\n\n    wa = (y1-yy) * (x1-xx)\n    wb = (yy-y0) * (x1-xx)\n    wc = (y1-yy) * (xx-x0)\n    wd = (yy-y0) * (xx-x0)\n    out_flow[:,:,0] = (Ia[:,:,0]*wa + Ib[:,:,0]*wb + Ic[:,:,0]*wc + Id[:,:,0]*wd) * out_width / in_width\n    out_flow[:,:,1] = 
(Ia[:,:,1]*wa + Ib[:,:,1]*wb + Ic[:,:,1]*wc + Id[:,:,1]*wd) * out_height / in_height\n\n    return out_flow\n"
  },
  {
    "path": "mimicmotion/modules/cmp/utils/scheduler.py",
    "content": "import torch\nfrom bisect import bisect_right\n\nclass _LRScheduler(object):\n    def __init__(self, optimizer, last_iter=-1):\n        if not isinstance(optimizer, torch.optim.Optimizer):\n            raise TypeError('{} is not an Optimizer'.format(\n                type(optimizer).__name__))\n        self.optimizer = optimizer\n        if last_iter == -1:\n            for group in optimizer.param_groups:\n                group.setdefault('initial_lr', group['lr'])\n        else:\n            for i, group in enumerate(optimizer.param_groups):\n                if 'initial_lr' not in group:\n                    raise KeyError(\"param 'initial_lr' is not specified \"\n                                   \"in param_groups[{}] when resuming an optimizer\".format(i))\n        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))\n        self.last_iter = last_iter\n\n    def _get_new_lr(self):\n        raise NotImplementedError\n\n    def get_lr(self):\n        return list(map(lambda group: group['lr'], self.optimizer.param_groups))\n\n    def step(self, this_iter=None):\n        if this_iter is None:\n            this_iter = self.last_iter + 1\n        self.last_iter = this_iter\n        for param_group, lr in zip(self.optimizer.param_groups, self._get_new_lr()):\n            param_group['lr'] = lr\n\nclass _WarmUpLRSchedulerOld(_LRScheduler):\n\n    def __init__(self, optimizer, base_lr, warmup_lr, warmup_steps, last_iter=-1):\n        self.base_lr = base_lr\n        self.warmup_steps = warmup_steps\n        if warmup_steps == 0:\n            self.warmup_lr = base_lr\n        else:\n            self.warmup_lr = warmup_lr\n        super(_WarmUpLRSchedulerOld, self).__init__(optimizer, last_iter)\n    \n    def _get_warmup_lr(self):\n        if self.warmup_steps > 0 and self.last_iter < self.warmup_steps:\n            # first compute relative scale for self.base_lr, then multiply to base_lr\n            scale = 
((self.last_iter/self.warmup_steps)*(self.warmup_lr - self.base_lr) + self.base_lr)/self.base_lr\n            #print('last_iter: {}, warmup_lr: {}, base_lr: {}, scale: {}'.format(self.last_iter, self.warmup_lr, self.base_lr, scale))\n            return [scale * base_lr for base_lr in self.base_lrs]\n        else:\n            return None\n\nclass _WarmUpLRScheduler(_LRScheduler):\n\n    def __init__(self, optimizer, base_lr, warmup_lr, warmup_steps, last_iter=-1):\n        self.base_lr = base_lr\n        self.warmup_lr = warmup_lr\n        self.warmup_steps = warmup_steps\n        assert isinstance(warmup_lr, list)\n        assert isinstance(warmup_steps, list)\n        assert len(warmup_lr) == len(warmup_steps)\n        super(_WarmUpLRScheduler, self).__init__(optimizer, last_iter)\n    \n    def _get_warmup_lr(self):\n        pos = bisect_right(self.warmup_steps, self.last_iter)\n        if pos >= len(self.warmup_steps):\n            return None\n        else:\n            if pos == 0:\n                curr_lr = self.base_lr + self.last_iter * (self.warmup_lr[pos] - self.base_lr) / self.warmup_steps[pos]\n            else:\n                curr_lr = self.warmup_lr[pos - 1] + (self.last_iter - self.warmup_steps[pos - 1]) * (self.warmup_lr[pos] - self.warmup_lr[pos - 1]) / (self.warmup_steps[pos] - self.warmup_steps[pos - 1])\n        scale = curr_lr / self.base_lr\n        return [scale * base_lr for base_lr in self.base_lrs]\n\nclass StepLRScheduler(_WarmUpLRScheduler):\n    def __init__(self, optimizer, milestones, lr_mults, base_lr, warmup_lr, warmup_steps, last_iter=-1):\n        super(StepLRScheduler, self).__init__(optimizer, base_lr, warmup_lr, warmup_steps, last_iter)\n\n        assert len(milestones) == len(lr_mults), \"{} vs {}\".format(milestones, lr_mults)\n        for x in milestones:\n            assert isinstance(x, int)\n        if not list(milestones) == sorted(milestones):\n            raise ValueError('Milestones should be a list of'\n           
                  ' increasing integers. Got {}'.format(milestones))\n        self.milestones = milestones\n        self.lr_mults = [1.0]\n        for x in lr_mults:\n            self.lr_mults.append(self.lr_mults[-1]*x)\n    \n    def _get_new_lr(self):\n        warmup_lrs = self._get_warmup_lr()\n        if warmup_lrs is not None:\n            return warmup_lrs\n\n        pos = bisect_right(self.milestones, self.last_iter)\n        if len(self.warmup_lr) == 0:\n            scale = self.lr_mults[pos]\n        else:\n            scale = self.warmup_lr[-1] * self.lr_mults[pos] / self.base_lr\n        return [base_lr * scale for base_lr in self.base_lrs]\n"
  },
  {
    "path": "mimicmotion/modules/cmp/utils/visualize_utils.py",
    "content": "import numpy as np\n\nimport torch\nfrom . import flowlib\n\nclass Fuser(object):\n    def __init__(self, nbins, fmax):\n        self.nbins = nbins\n        self.fmax = fmax\n        self.step = 2 * fmax / float(nbins)\n        self.mesh = torch.arange(nbins).view(1,-1,1,1).float().cuda() * self.step - fmax + self.step / 2\n\n    def convert_flow(self, flow_prob):\n        flow_probx = torch.nn.functional.softmax(flow_prob[:, :self.nbins, :, :], dim=1)\n        flow_proby = torch.nn.functional.softmax(flow_prob[:, self.nbins:, :, :], dim=1)\n        flow_probx = flow_probx * self.mesh\n        flow_proby = flow_proby * self.mesh\n        flow = torch.cat([flow_probx.sum(dim=1, keepdim=True), flow_proby.sum(dim=1, keepdim=True)], dim=1)\n        return flow\n\ndef visualize_tensor_old(image, mask, flow_pred, flow_target, warped, rgb_gen, image_target, image_mean, image_div):\n    together = [\n        draw_cross(unormalize(image.cpu(), mean=image_mean, div=image_div), mask.cpu(), radius=int(image.size(3) / 50.)),\n        flow_to_image(flow_pred.detach().cpu()),\n        flow_to_image(flow_target.detach().cpu())]\n    if warped is not None:\n        together.append(torch.clamp(unormalize(warped.detach().cpu(), mean=image_mean, div=image_div), 0, 255))\n    if rgb_gen is not None:\n        together.append(torch.clamp(unormalize(rgb_gen.detach().cpu(), mean=image_mean, div=image_div), 0, 255))\n    if image_target is not None:\n        together.append(torch.clamp(unormalize(image_target.cpu(), mean=image_mean, div=image_div), 0, 255))\n    together = torch.cat(together, dim=3)\n    return together\n\ndef visualize_tensor(image, mask, flow_tensors, common_tensors, rgb_tensors, image_mean, image_div):\n    together = [\n        draw_cross(unormalize(image.cpu(), mean=image_mean, div=image_div), mask.cpu(), radius=int(image.size(3) / 50.))]\n    for ft in flow_tensors:\n        together.append(flow_to_image(ft.cpu()))\n    for ct in common_tensors:\n      
  together.append(torch.clamp(ct.cpu(), 0, 255))\n    for rt in rgb_tensors:\n        together.append(torch.clamp(unormalize(rt.cpu(), mean=image_mean, div=image_div), 0, 255))\n    together = torch.cat(together, dim=3)\n    return together\n\n\ndef unormalize(tensor, mean, div):\n    for c, (m, d) in enumerate(zip(mean, div)):\n        tensor[:,c,:,:].mul_(d).add_(m)\n    return tensor\n\n\ndef flow_to_image(flow):\n    flow = flow.numpy()\n    flow_img = np.array([flowlib.flow_to_image(fl.transpose((1,2,0))).transpose((2,0,1)) for fl in flow]).astype(np.float32)\n    return torch.from_numpy(flow_img)\n\ndef shift_tensor(input, offh, offw):\n    new = torch.zeros(input.size())\n    h = input.size(2)\n    w = input.size(3)\n    new[:,:,max(0,offh):min(h,h+offh),max(0,offw):min(w,w+offw)] = input[:,:,max(0,-offh):min(h,h-offh),max(0,-offw):min(w,w-offw)]\n    return new\n\ndef draw_block(mask, radius=5):\n    '''\n    input:  tensor (NxCxHxW)\n    output: block_mask (Nx1xHxW)\n    '''\n    all_mask = []\n    mask = mask[:,0:1,:,:]\n    for offh in range(-radius, radius+1):\n        for offw in range(-radius, radius+1):\n            all_mask.append(shift_tensor(mask, offh, offw))\n    block_mask = sum(all_mask)\n    block_mask[block_mask > 0] = 1\n    return block_mask\n\ndef expand_block(sparse, radius=5):\n    '''\n    input:  sparse (NxCxHxW)\n    output: block_sparse (NxCxHxW)\n    '''\n    all_sparse = []\n    for offh in range(-radius, radius+1):\n        for offw in range(-radius, radius+1):\n            all_sparse.append(shift_tensor(sparse, offh, offw))\n    block_sparse = sum(all_sparse)\n    return block_sparse\n\ndef draw_cross(tensor, mask, radius=5, thickness=2):\n    '''\n    input:  tensor (NxCxHxW)\n            mask (NxXxHxW)\n    output: new_tensor (NxCxHxW)\n    '''\n    all_mask = []\n    mask = mask[:,0:1,:,:]\n    for off in range(-radius, radius+1):\n        for t in range(-thickness, thickness+1):\n            
all_mask.append(shift_tensor(mask, off, t))\n            all_mask.append(shift_tensor(mask, t, off))\n    cross_mask = sum(all_mask)\n    new_tensor = tensor.clone()\n    new_tensor[:,0:1,:,:][cross_mask > 0] = 255.0\n    new_tensor[:,1:2,:,:][cross_mask > 0] = 0.0\n    new_tensor[:,2:3,:,:][cross_mask > 0] = 0.0\n    return new_tensor\n"
  },
  {
    "path": "mimicmotion/modules/cmp_model.py",
    "content": "from typing import Any, Dict, List, Optional, Tuple, Union\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom diffusers.configuration_utils import register_to_config\nfrom diffusers.utils import BaseOutput\n\nimport mimicmotion.modules.cmp.models as cmp_models\nimport mimicmotion.modules.cmp.utils as cmp_utils\n\nimport yaml\nimport os\nimport torchvision.transforms as transforms\n\n\nclass ArgObj(object):\n    def __init__(self):\n        pass\n\n\nclass CMP(nn.Module):\n    def __init__(self, configfn, load_iter):\n        super().__init__()\n        args = ArgObj()\n        with open(configfn) as f:\n            config = yaml.full_load(f)\n        for k, v in config.items():\n            setattr(args, k, v)\n        setattr(args, 'load_iter', load_iter)\n        setattr(args, 'exp_path', os.path.dirname(configfn))\n        \n        self.model = cmp_models.__dict__[args.model['arch']](args.model, dist_model=False)\n        self.model.load_state(\"{}/checkpoints\".format(args.exp_path), args.load_iter, False)        \n        self.model.switch_to('eval')\n        \n        self.data_mean = args.data['data_mean']\n        self.data_div = args.data['data_div']\n        \n        self.img_transform = transforms.Compose([\n            transforms.Normalize(self.data_mean, self.data_div)])\n        \n        self.args = args\n        self.fuser = cmp_utils.Fuser(args.model['module']['nbins'], args.model['module']['fmax'])\n        torch.cuda.synchronize()\n\n    def run(self, image, sparse, mask):\n        dtype = image.dtype\n        image = image * 2 - 1\n        self.model.set_input(image.float(), torch.cat([sparse, mask], dim=1).float(), None)\n        try:\n            cmp_output = self.model.model(self.model.image_input.to(torch.float16), self.model.sparse_input.to(torch.float16))\n        except:\n            cmp_output = self.model.model(self.model.image_input.to(torch.float32), 
self.model.sparse_input.to(torch.float32))\n        flow = self.fuser.convert_flow(cmp_output)\n        if flow.shape[2] != self.model.image_input.shape[2]:\n            flow = nn.functional.interpolate(\n                flow, size=self.model.image_input.shape[2:4],\n                mode=\"bilinear\", align_corners=True)\n\n        return flow.to(dtype)  # [b, 2, h, w]"
  },
  {
    "path": "mimicmotion/modules/controlnet.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn, einsum\nfrom torch.nn import functional as F\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import FromOriginalControlNetMixin\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.models.attention_processor import (\n    ADDED_KV_ATTENTION_PROCESSORS,\n    CROSS_ATTENTION_PROCESSORS,\n    AttentionProcessor,\n    AttnAddedKVProcessor,\n    AttnProcessor,\n)\nfrom diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.models.unets.unet_3d_blocks  import (\n    get_down_block, get_up_block,UNetMidBlockSpatioTemporal,\n)\nfrom diffusers.models import UNetSpatioTemporalConditionModel\nfrom .point_adapter import PointAdapter\nfrom einops import rearrange, repeat\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\nimport pdb  \nfrom inspect import isfunction\n\ndef exists(val):\n    return val is not None\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n@dataclass\nclass ControlNetOutput(BaseOutput):\n    \"\"\"\n    The output 
of [`ControlNetModel`].\n\n    Args:\n        down_block_res_samples (`tuple[torch.Tensor]`):\n            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should\n            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be\n            used to condition the original UNet's downsampling activations.\n        mid_block_res_sample (`torch.Tensor`):\n            The activation of the middle block (the lowest sample resolution). The tensor should be of shape\n            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.\n            Output can be used to condition the original UNet's middle block activation.\n    \"\"\"\n\n    down_block_res_samples: Tuple[torch.Tensor]\n    mid_block_res_sample: torch.Tensor\n\n\nclass ControlNetConditioningEmbeddingSVD(nn.Module):\n    \"\"\"\n    Quoting from https://arxiv.org/abs/2302.05543: \"Stable Diffusion uses a pre-processing method similar to VQ-GAN\n    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized\n    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the\n    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides\n    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full\n    model) to encode image-space conditions ... into feature maps ...\"\n    \"\"\"\n\n    def __init__(\n        self,\n        conditioning_embedding_channels: int,\n        conditioning_channels: int = 3,\n        flow_channels: int = 2,\n        dift_channels: int = 640,\n        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),\n        feature_out_channels: Tuple[int, ...] 
= (160, 160, 256, 256),\n    ):\n        super().__init__()\n\n        self.conv_in_flow = nn.Conv2d(flow_channels, block_out_channels[0], kernel_size=3, padding=1)\n\n        self.blocks_flow = nn.ModuleList([])\n\n        for i in range(len(block_out_channels) - 1):\n            channel_in = block_out_channels[i]\n            channel_out = block_out_channels[i + 1]\n            self.blocks_flow.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))\n            self.blocks_flow.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))\n\n        self.conv_out_flow = zero_module(\n            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)\n        )\n\n        self.conv_in_traj = nn.Conv2d(flow_channels, block_out_channels[0], kernel_size=3, padding=1)\n\n        self.blocks_traj = nn.ModuleList([])\n\n        for i in range(len(block_out_channels) - 1):\n            channel_in = block_out_channels[i]\n            channel_out = block_out_channels[i + 1]\n            self.blocks_traj.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))\n            self.blocks_traj.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))\n\n        self.conv_out_traj = zero_module(\n            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)\n        )\n        \n        self.conv_final = nn.Conv2d(conditioning_embedding_channels, conditioning_embedding_channels, kernel_size=3, padding=1)\n\n        \n    def forward(self, flow_conditioning, traj_conditioning):\n\n        # flow cond ###########\n        zero_flow = torch.zeros_like(flow_conditioning[:, 0, :, :, :]).to(flow_conditioning.device, flow_conditioning.dtype)\n        flow_conditioning = torch.cat([zero_flow.unsqueeze(1),flow_conditioning], dim=1)\n        flow_conditioning = rearrange(flow_conditioning, \"b f c h w -> (b f) c h w\")\n        \n        
embedding_flow = self.conv_in_flow(flow_conditioning)\n        embedding_flow = F.silu(embedding_flow)\n\n        for block_flow in self.blocks_flow:\n            embedding_flow = block_flow(embedding_flow)\n            embedding_flow = F.silu(embedding_flow)\n\n        embedding_flow = self.conv_out_flow(embedding_flow)\n\n        # traj cond ###########\n        zero_traj = torch.zeros_like(traj_conditioning[:, 0, :, :, :]).to(traj_conditioning.device, traj_conditioning.dtype)\n        traj_conditioning = torch.cat([zero_traj.unsqueeze(1),traj_conditioning], dim=1)\n        traj_conditioning = rearrange(traj_conditioning, \"b f c h w -> (b f) c h w\")\n        \n        embedding_traj = self.conv_in_traj(traj_conditioning)\n        embedding_traj = F.silu(embedding_traj)\n\n        for block_traj in self.blocks_traj:\n            embedding_traj = block_traj(embedding_traj)\n            embedding_traj = F.silu(embedding_traj)\n\n        embedding_traj = self.conv_out_traj(embedding_traj)\n\n        embedding = self.conv_final(embedding_flow + embedding_traj)\n\n        return embedding\n\n\nclass ControlNetSVDModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):\n    r\"\"\"\n    A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditional state, and a timestep, and returns a\n    sample-shaped output.\n\n    This model inherits from [`ModelMixin`]. 
Check the superclass documentation for its generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):\n            Height and width of input/output sample.\n        in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.\n        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlockSpatioTemporal\", \"CrossAttnDownBlockSpatioTemporal\", \"CrossAttnDownBlockSpatioTemporal\", \"DownBlockSpatioTemporal\")`):\n            The tuple of downsample blocks to use.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlockSpatioTemporal\", \"CrossAttnUpBlockSpatioTemporal\", \"CrossAttnUpBlockSpatioTemporal\", \"CrossAttnUpBlockSpatioTemporal\")`):\n            The tuple of upsample blocks to use.\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        addition_time_embed_dim (`int`, defaults to 256):\n            Dimension to encode the additional time ids.\n        projection_class_embeddings_input_dim (`int`, defaults to 768):\n            The dimension of the projection of encoded `added_time_ids`.\n        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.\n        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for\n            [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],\n            [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].\n        num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):\n            The number of attention heads.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 8,\n        out_channels: int = 4,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlockSpatioTemporal\",\n            \"CrossAttnDownBlockSpatioTemporal\",\n            \"CrossAttnDownBlockSpatioTemporal\",\n            \"DownBlockSpatioTemporal\",\n        ),\n        up_block_types: Tuple[str] = (\n            \"UpBlockSpatioTemporal\",\n            \"CrossAttnUpBlockSpatioTemporal\",\n            \"CrossAttnUpBlockSpatioTemporal\",\n            \"CrossAttnUpBlockSpatioTemporal\",\n        ),\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        addition_time_embed_dim: int = 256,\n        projection_class_embeddings_input_dim: int = 768,\n        layers_per_block: Union[int, Tuple[int]] = 2,\n        cross_attention_dim: Union[int, Tuple[int]] = 1024,\n        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,\n        num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),\n        num_frames: int = 25,\n        conditioning_channels: int = 3,\n        conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (16, 32, 96, 256),\n    ):\n        super().__init__()\n        self.sample_size = sample_size\n\n        print(\"layers per block is\", layers_per_block)\n        \n        # Check inputs\n        if len(down_block_types) != len(up_block_types):\n            
raise ValueError(\n                f\"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.\"\n            )\n\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        # input\n        self.conv_in = nn.Conv2d(\n            in_channels,\n            block_out_channels[0],\n            kernel_size=3,\n            padding=1,\n        )\n\n        # time\n        time_embed_dim = block_out_channels[0] * 4\n\n        self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)\n        timestep_input_dim = block_out_channels[0]\n\n        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)\n\n        self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)\n        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n\n        self.down_blocks = nn.ModuleList([])\n        self.controlnet_down_blocks = nn.ModuleList([])\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = (num_attention_heads,) * len(down_block_types)\n\n        if isinstance(cross_attention_dim, int):\n            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)\n\n        if isinstance(layers_per_block, int):\n            layers_per_block = [layers_per_block] * len(down_block_types)\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)\n\n        blocks_time_embed_dim = time_embed_dim\n        self.point_adapter = PointAdapter()\n\n        self.controlnet_cond_embedding = ControlNetConditioningEmbeddingSVD(\n            conditioning_embedding_channels=block_out_channels[0],\n            block_out_channels=conditioning_embedding_out_channels,\n            conditioning_channels=conditioning_channels,\n        )\n        \n        # down\n        output_channel = block_out_channels[0]\n        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)\n        controlnet_block = zero_module(controlnet_block)\n        
self.controlnet_down_blocks.append(controlnet_block)\n\n        \n        \n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block[i],\n                transformer_layers_per_block=transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=1e-5,\n                cross_attention_dim=cross_attention_dim[i],\n                num_attention_heads=num_attention_heads[i],\n                resnet_act_fn=\"silu\",\n            )\n            self.down_blocks.append(down_block)\n            \n            for _ in range(layers_per_block[i]):\n                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)\n                controlnet_block = zero_module(controlnet_block)\n                self.controlnet_down_blocks.append(controlnet_block)\n\n            if not is_final_block:\n                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)\n                controlnet_block = zero_module(controlnet_block)\n                self.controlnet_down_blocks.append(controlnet_block)\n\n\n        # mid\n        mid_block_channel = block_out_channels[-1]\n        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)\n        controlnet_block = zero_module(controlnet_block)\n        self.controlnet_mid_block = controlnet_block\n\n        \n        self.mid_block = UNetMidBlockSpatioTemporal(\n            block_out_channels[-1],\n            temb_channels=blocks_time_embed_dim,\n            
transformer_layers_per_block=transformer_layers_per_block[-1],\n            cross_attention_dim=cross_attention_dim[-1],\n            num_attention_heads=num_attention_heads[-1],\n        )\n\n    @property\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model,\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(\n            name: str,\n            module: torch.nn.Module,\n            processors: Dict[str, AttentionProcessor],\n        ):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor(return_deprecated_lora=True)\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. 
This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):\n            processor = AttnProcessor()\n        else:\n            raise ValueError(\n                f\"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}\"\n            )\n\n        self.set_attn_processor(processor)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if hasattr(module, \"gradient_checkpointing\"):\n            module.gradient_checkpointing = value\n\n    # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking\n    def enable_forward_chunking(self, 
chunk_size: Optional[int] = None, dim: int = 0) -> None:\n        \"\"\"\n        Sets the attention processor to use [feed forward\n        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).\n\n        Parameters:\n            chunk_size (`int`, *optional*):\n                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually\n                over each tensor of dim=`dim`.\n            dim (`int`, *optional*, defaults to `0`):\n                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)\n                or dim=1 (sequence length).\n        \"\"\"\n        if dim not in [0, 1]:\n            raise ValueError(f\"Make sure to set `dim` to either 0 or 1, not {dim}\")\n\n        # By default chunk size is 1\n        chunk_size = chunk_size or 1\n\n        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):\n            if hasattr(module, \"set_chunk_feed_forward\"):\n                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)\n\n            for child in module.children():\n                fn_recursive_feed_forward(child, chunk_size, dim)\n\n        for module in self.children():\n            fn_recursive_feed_forward(module, chunk_size, dim)\n    \n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        added_time_ids: torch.Tensor,\n        controlnet_cond: torch.FloatTensor = None,\n        controlnet_flow: torch.FloatTensor = None,\n        traj_flow: torch.FloatTensor = None,\n        pose_image: torch.FloatTensor = None,\n        pose_latents: torch.FloatTensor = None,\n        dift_feat: torch.FloatTensor = None,\n        # ref_point = None,\n        point_list = None,\n        ref_point_emb = None,\n        image_only_indicator: Optional[torch.Tensor] = None,\n     
   return_dict: bool = True,\n        guess_mode: bool = False,\n        conditioning_scale: float = 1.0,\n    ) -> Union[ControlNetOutput, Tuple]:\n        r\"\"\"\n        The [`UNetSpatioTemporalConditionModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.\n            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n                The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.\n            added_time_ids (`torch.FloatTensor`):\n                The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal\n                embeddings and added to the time embeddings.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`ControlNetOutput`] instead of a plain tuple.\n        Returns:\n            [`ControlNetOutput`] or `tuple`:\n                If `return_dict` is True, a [`ControlNetOutput`] is returned; otherwise a `tuple`\n                `(down_block_res_samples, mid_block_res_sample, loss_mask)` is returned.\n        \"\"\"\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        batch_size, num_frames = sample.shape[:2]\n        timesteps = timesteps.expand(batch_size)\n\n        t_emb = self.time_proj(timesteps)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb)\n\n        time_embeds = self.add_time_proj(added_time_ids.flatten())\n        time_embeds = time_embeds.reshape((batch_size, -1))\n        time_embeds = time_embeds.to(emb.dtype)\n        aug_emb = self.add_embedding(time_embeds)\n        emb = emb + aug_emb\n\n        # Flatten the batch and frames dimensions\n        # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]\n        sample = sample.flatten(0, 1)\n        # Repeat the embeddings num_video_frames times\n        # emb: [batch, channels] -> [batch * frames, channels]\n        emb = emb.repeat_interleave(num_frames, dim=0)\n        # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]\n        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)\n\n        # 2. 
pre-process\n        sample = self.conv_in(sample)\n\n        bz, _, h, w = controlnet_cond.size()\n        #controlnet cond\n\n        if controlnet_flow is not None and traj_flow is not None:\n            cond_flow = self.controlnet_cond_embedding(controlnet_flow, traj_flow)\n            sample = sample + cond_flow\n\n        # get dift feat\n        adapter_state = []\n        loss_mask = None\n        if point_list is not None:\n            adapter_state, loss_mask = self.point_adapter(point_list, (w,h), ref_point_emb, pose_latents, loss_type=\"local\")\n            if not self.training:\n                for k, v in enumerate(adapter_state):\n                    adapter_state[k] = torch.cat([v] * 2, dim=0)\n\n        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    image_only_indicator=image_only_indicator,\n                )\n            else:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    image_only_indicator=image_only_indicator,\n                )\n            if len(adapter_state) > 0:\n                additional_residuals = adapter_state.pop(0)\n                sample = sample + additional_residuals.flatten(0, 1)\n\n            down_block_res_samples += res_samples\n\n\n        # 4. 
mid\n        sample = self.mid_block(\n            hidden_states=sample,\n            temb=emb,\n            encoder_hidden_states=encoder_hidden_states,\n            image_only_indicator=image_only_indicator,\n        )\n\n        controlnet_down_block_res_samples = ()\n\n        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):\n            down_block_res_sample = controlnet_block(down_block_res_sample)\n            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)\n\n        down_block_res_samples = controlnet_down_block_res_samples\n\n        mid_block_res_sample = self.controlnet_mid_block(sample)\n\n        # 6. scaling\n        down_block_res_samples = [res_sample * conditioning_scale for res_sample in down_block_res_samples]\n        mid_block_res_sample = mid_block_res_sample * conditioning_scale\n\n        if not return_dict:\n            return (down_block_res_samples, mid_block_res_sample, loss_mask)\n\n        return ControlNetOutput(\n            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample\n        )\n\n    @classmethod\n    def from_unet(\n        cls,\n        unet: UNetSpatioTemporalConditionModel,\n        controlnet_conditioning_channel_order: str = \"rgb\",\n        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),\n        load_weights_from_unet: bool = True,\n        conditioning_channels: int = 3,\n    ):\n        r\"\"\"\n        Instantiate a [`ControlNetModel`] from a [`UNetSpatioTemporalConditionModel`].\n\n        Parameters:\n            unet (`UNetSpatioTemporalConditionModel`):\n                The UNet model weights to copy to the [`ControlNetModel`]. 
All configuration options are also copied\n                where applicable.\n        \"\"\"\n\n        transformer_layers_per_block = (\n            unet.config.transformer_layers_per_block if \"transformer_layers_per_block\" in unet.config else 1\n        )\n        encoder_hid_dim = unet.config.encoder_hid_dim if \"encoder_hid_dim\" in unet.config else None\n        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if \"encoder_hid_dim_type\" in unet.config else None\n        addition_embed_type = unet.config.addition_embed_type if \"addition_embed_type\" in unet.config else None\n        addition_time_embed_dim = (\n            unet.config.addition_time_embed_dim if \"addition_time_embed_dim\" in unet.config else None\n        )\n        controlnet = cls(\n            in_channels=unet.config.in_channels,\n            down_block_types=unet.config.down_block_types,\n            block_out_channels=unet.config.block_out_channels,\n            addition_time_embed_dim=unet.config.addition_time_embed_dim,\n            transformer_layers_per_block=unet.config.transformer_layers_per_block,\n            cross_attention_dim=unet.config.cross_attention_dim,\n            num_attention_heads=unet.config.num_attention_heads,\n            num_frames=unet.config.num_frames,\n            sample_size=unet.config.sample_size,\n            layers_per_block=unet.config.layers_per_block,\n            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,\n            conditioning_channels=conditioning_channels,\n            conditioning_embedding_out_channels=conditioning_embedding_out_channels,\n        )\n        # controlnet rgb channel order is ignored; set so it makes no difference by default\n\n        if load_weights_from_unet:\n            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())\n            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())\n            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())\n\n            # if controlnet.class_embedding:\n            #     controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())\n\n            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())\n            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())\n\n        return controlnet\n\n    @property\n    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model,\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor(return_deprecated_lora=True)\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor\n    def set_attn_processor(\n        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False\n    ):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated 
processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor, _remove_lora=_remove_lora)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"), _remove_lora=_remove_lora)\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):\n            processor = AttnAddedKVProcessor()\n        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):\n            
processor = AttnProcessor()\n        else:\n            raise ValueError(\n                f\"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}\"\n            )\n\n        self.set_attn_processor(processor, _remove_lora=True)\n\n    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice\n    def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module splits the input tensor in slices to compute attention in\n        several steps. This is useful for saving some memory in exchange for a small decrease in speed.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, input to the attention heads is halved, so attention is computed in two steps. If\n                `\"max\"`, maximum amount of memory is saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_sliceable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_sliceable_dims(module)\n\n        num_sliceable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_sliceable_layers * [1]\n\n        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any child that exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n\ndef zero_module(module):\n    \"\"\"Zero-initialize every parameter of the module and return it.\"\"\"\n    for p in module.parameters():\n        nn.init.zeros_(p)\n    return module\n"
  },
  {
    "path": "mimicmotion/modules/point_adapter.py",
    "content": "import random\nfrom typing import List\nfrom einops import rearrange, repeat\n\nimport torch\nimport torch.nn as nn\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.modeling_utils import ModelMixin\n\nclass MLP(nn.Module):\n    def __init__(self, in_dim, out_dim, mid_dim=128):\n        super().__init__()\n        self.mlp = nn.Sequential(\n            nn.Linear(in_dim, mid_dim, bias=True),\n            nn.SiLU(inplace=False),\n            nn.Linear(mid_dim, out_dim, bias=True)\n        )\n\n    def forward(self, x):\n        return self.mlp(x)\n\ndef vectorized_bilinear_interpolation(level_adapter_state, coords, frame_idx, interpolated_values):\n    x = coords[:, 0]\n    y = coords[:, 1]\n\n    x1 = x.floor().long()\n    y1 = y.floor().long()\n    x2 = x1 + 1\n    y2 = y1 + 1\n\n    x1 = torch.clamp(x1, 0, level_adapter_state.shape[3] - 1)\n    y1 = torch.clamp(y1, 0, level_adapter_state.shape[2] - 1)\n    x2 = torch.clamp(x2, 0, level_adapter_state.shape[3] - 1)\n    y2 = torch.clamp(y2, 0, level_adapter_state.shape[2] - 1)\n\n    x_frac = x - x1.float()\n    y_frac = y - y1.float()\n\n    w11 = (1 - x_frac) * (1 - y_frac)\n    w21 = x_frac * (1 - y_frac)\n    w12 = (1 - x_frac) * y_frac\n    w22 = x_frac * y_frac\n\n    for i, (x1_val, y1_val, x2_val, y2_val, w11_val, w21_val, w12_val, w22_val, interpolated_value) in enumerate(zip(x1, y1, x2, y2, w11, w21, w12, w22, interpolated_values)):\n        level_adapter_state[frame_idx, :, y1_val, x1_val] += interpolated_value * w11_val\n        level_adapter_state[frame_idx, :, y1_val, x2_val] += interpolated_value * w21_val\n        level_adapter_state[frame_idx, :, y2_val, x1_val] += interpolated_value * w12_val\n        level_adapter_state[frame_idx, :, y2_val, x2_val] += interpolated_value * w22_val\n\n    return level_adapter_state\n\ndef bilinear_interpolation(level_adapter_state, x, y, frame_idx, interpolated_value):\n    # 
note the boundary\n    x1 = int(x)\n    y1 = int(y)\n    x2 = x1 + 1\n    y2 = y1 + 1\n    x_frac = x - x1\n    y_frac = y - y1\n\n    x1, x2 = max(min(x1, level_adapter_state.shape[3] - 1), 0), max(min(x2, level_adapter_state.shape[3] - 1), 0)\n    y1, y2 = max(min(y1, level_adapter_state.shape[2] - 1), 0), max(min(y2, level_adapter_state.shape[2] - 1), 0)\n\n    w11 = (1 - x_frac) * (1 - y_frac)\n    w21 = x_frac * (1 - y_frac)\n    w12 = (1 - x_frac) * y_frac\n    w22 = x_frac * y_frac\n\n    level_adapter_state[frame_idx, :, y1, x1] += interpolated_value * w11\n    level_adapter_state[frame_idx, :, y1, x2] += interpolated_value * w21\n    level_adapter_state[frame_idx, :, y2, x1] += interpolated_value * w12\n    level_adapter_state[frame_idx, :, y2, x2] += interpolated_value * w22\n\n    return level_adapter_state\n\nclass PointAdapter(nn.Module):\n\n    def __init__(\n        self,\n        embedding_channels=1280,\n        channels=[320, 640, 1280, 1280],\n        downsample_rate=[16, 32, 64, 64],\n        mid_dim=128\n    ):\n        super().__init__()\n\n        self.model_list = nn.ModuleList()\n\n        for ch in channels:\n            self.model_list.append(MLP(embedding_channels, ch, mid_dim))\n\n        self.downsample_rate = downsample_rate\n        self.embedding_channels = embedding_channels\n        self.channels = channels\n        self.radius = 4\n\n    def generate_loss_mask(self, batch_size, point_tracker, num_frames, h, w, loss_type):\n        downsample_rate = self.downsample_rate[0]\n        level_w, level_h = w // downsample_rate, h // downsample_rate\n        if loss_type == 'global':\n            loss_mask = torch.ones((batch_size, num_frames, 4, level_h, level_w))\n        else:\n            loss_mask = torch.zeros((batch_size, num_frames, 4, level_h, level_w))\n            for batch_idx in range(batch_size):\n                for frame_idx in range(num_frames):\n                    if self.training:\n                        keypoints, 
subsets = point_tracker[frame_idx][\"candidate\"][batch_idx], point_tracker[frame_idx][\"subset\"][batch_idx][0]\n                    else:\n                        keypoints, subsets = point_tracker[frame_idx][\"candidate\"], point_tracker[frame_idx][\"subset\"][0]\n                        assert batch_size == 1\n                    for point_idx, (keypoint, subset) in enumerate(zip(keypoints, subsets)):\n                        if subset != -1:\n                            px, py = keypoint[0] * level_w, keypoint[1] * level_h\n\n                            x1 = int(px) - self.radius\n                            y1 = int(py) - self.radius\n                            x2 = int(px) + self.radius\n                            y2 = int(py) + self.radius\n\n                            x1, x2 = max(min(x1, level_w - 1), 0), max(min(x2, level_w - 1), 0)\n                            y1, y2 = max(min(y1, level_h - 1), 0), max(min(y2, level_h - 1), 0)\n                            loss_mask[batch_idx][frame_idx][:, y1:y2, x1:x2] = 1.0\n\n        return loss_mask\n\n    def forward(self, point_tracker, size, point_embedding, pose_latents, index_list=None, drop_rate=0.0, loss_type='global') -> List[torch.Tensor]:\n        w, h = size\n        num_frames = len(point_tracker)\n        batch_size, num_points, _ = point_embedding.shape\n\n        loss_mask = self.generate_loss_mask(batch_size, point_tracker, num_frames, h, w, loss_type)\n\n        downsample_rate = self.downsample_rate[0]\n        level_w, level_h = w // downsample_rate, h // downsample_rate\n        level_adapter_state = torch.zeros((batch_size, num_frames, self.embedding_channels, level_h, level_w)).to(point_embedding.device, dtype=point_embedding.dtype)\n        level_mask = torch.zeros((batch_size, num_frames, level_h, level_w)).to(point_embedding.device, dtype=point_embedding.dtype)\n        level_count = torch.ones((batch_size, num_frames, level_h, level_w)).to(point_embedding.device, 
dtype=point_embedding.dtype)\n        for batch_idx in range(batch_size):\n            for frame_idx in range(num_frames):\n                if self.training:\n                    keypoints, subsets = point_tracker[frame_idx][\"candidate\"][batch_idx], point_tracker[frame_idx][\"subset\"][batch_idx][0]\n                else:\n                    keypoints, subsets = point_tracker[frame_idx][\"candidate\"], point_tracker[frame_idx][\"subset\"][0]\n                    assert batch_size == 1\n                for point_idx, (keypoint, subset) in enumerate(zip(keypoints, subsets)):\n                    if keypoint.min() < 0:\n                        continue\n                    px, py = keypoint[0] * level_w, keypoint[1] * level_h\n                    px, py = max(min(int(px), level_w - 1), 0), max(min(int(py), level_h - 1), 0)\n                    if subset != -1:\n                        if point_embedding[batch_idx, point_idx].mean() != 0 or random.random() > drop_rate:\n                            if level_mask[batch_idx, frame_idx, py, px] !=0:\n                                level_count[batch_idx, frame_idx, py, px] +=1\n                            level_adapter_state[batch_idx, frame_idx, :, py, px] += point_embedding[batch_idx, point_idx]\n                            level_mask[batch_idx, frame_idx, py, px] = 1.0\n        \n        adapter_state = []\n        level_adapter_state = level_adapter_state/level_count.unsqueeze(2)\n        level_adapter_state = rearrange(level_adapter_state, \"b f c h w-> b f h w c\")\n        for level_idx, module in enumerate(self.model_list):\n            downsample_rate = self.downsample_rate[level_idx]\n            level_w, level_h = w // downsample_rate, h // downsample_rate\n\n            point_feat = module(level_adapter_state)\n            point_feat = point_feat * level_mask.unsqueeze(-1)\n\n            point_feat = rearrange(point_feat, \"b f h w c-> (b f) c h w\")\n            point_feat = nn.Upsample(size=(level_h, 
level_w), mode='bilinear')(point_feat)\n\n            temp_mask = rearrange(level_mask, \"b f h w-> (b f) h w\")\n            temp_mask = nn.Upsample(size=(level_h, level_w), mode='nearest')(temp_mask.unsqueeze(1))\n            point_feat = point_feat * temp_mask\n\n            point_feat = rearrange(point_feat, \"(b f) c h w-> b f c h w\", b=batch_size)\n            adapter_state.append(point_feat)\n        \n        return adapter_state, loss_mask\n"
  },
  {
    "path": "mimicmotion/modules/pose_net.py",
    "content": "from pathlib import Path\n\nimport einops\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\n\n\nclass PoseNet(nn.Module):\n    \"\"\"a tiny conv network for introducing pose sequence as the condition\n    \"\"\"\n    def __init__(self, noise_latent_channels=320, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        # multiple convolution layers\n        self.conv_layers = nn.Sequential(\n            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),\n            nn.SiLU(),\n\n            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),\n            nn.SiLU(),\n\n            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),\n            nn.SiLU(),\n\n            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),\n            nn.SiLU()\n        )\n\n        # Final projection layer\n        self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)\n\n        # Initialize layers\n        self._initialize_weights()\n\n        self.scale = nn.Parameter(torch.ones(1) * 2)\n\n    def _initialize_weights(self):\n        \"\"\"Initialize weights with He. initialization and zero out the biases\n        \"\"\"\n        for m in self.conv_layers:\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels\n                init.normal_(m.weight, mean=0.0, std=np.sqrt(2. 
/ n))\n                if m.bias is not None:\n                    init.zeros_(m.bias)\n        init.zeros_(self.final_proj.weight)\n        if self.final_proj.bias is not None:\n            init.zeros_(self.final_proj.bias)\n\n    def forward(self, x):\n        if x.ndim == 5:\n            x = einops.rearrange(x, \"b f c h w -> (b f) c h w\")\n        x = self.conv_layers(x)\n        x = self.final_proj(x)\n\n        return x * self.scale\n\n    @classmethod\n    def from_pretrained(cls, pretrained_model_path):\n        \"\"\"Load pretrained PoseNet weights from a checkpoint file.\n        \"\"\"\n        if not Path(pretrained_model_path).exists():\n            raise FileNotFoundError(f\"There is no model file in {pretrained_model_path}\")\n\n        state_dict = torch.load(pretrained_model_path, map_location=\"cpu\")\n        model = PoseNet(noise_latent_channels=320)\n        model.load_state_dict(state_dict, strict=True)\n        print(f\"loaded PoseNet's pretrained weights from {pretrained_model_path}.\")\n\n        return model\n"
  },
  {
    "path": "mimicmotion/modules/unet.py",
    "content": "from dataclasses import dataclass\nfrom typing import Dict, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import UNet2DConditionLoadersMixin\nfrom diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor\nfrom diffusers.models.embeddings import TimestepEmbedding, Timesteps\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.utils import BaseOutput, logging\n\nfrom diffusers.models.unets.unet_3d_blocks import get_down_block, get_up_block, UNetMidBlockSpatioTemporal\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@dataclass\nclass UNetSpatioTemporalConditionOutput(BaseOutput):\n    \"\"\"\n    The output of [`UNetSpatioTemporalConditionModel`].\n\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):\n            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.\n    \"\"\"\n\n    sample: torch.FloatTensor = None\n\n\n
class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):\n    r\"\"\"\n    A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditioning state,\n    and a timestep, and returns a sample-shaped output.\n\n    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):\n            Height and width of input/output sample.\n        in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.\n        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlockSpatioTemporal\", \n            \"CrossAttnDownBlockSpatioTemporal\", \"CrossAttnDownBlockSpatioTemporal\", \"DownBlockSpatioTemporal\")`):\n            The tuple of downsample blocks to use.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlockSpatioTemporal\", \n            \"CrossAttnUpBlockSpatioTemporal\", \"CrossAttnUpBlockSpatioTemporal\", \"CrossAttnUpBlockSpatioTemporal\")`):\n            The tuple of upsample blocks to use.\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        addition_time_embed_dim: (`int`, defaults to 256):\n            Dimension to encode the additional time ids.\n        projection_class_embeddings_input_dim (`int`, defaults to 768):\n            The dimension of the projection of encoded `added_time_ids`.\n        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.\n        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for\n            [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], \n            [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],\n            [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].\n        num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):\n            The number of attention heads.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n            self,\n            sample_size: Optional[int] = None,\n            in_channels: int = 8,\n            out_channels: int = 4,\n            down_block_types: Tuple[str] = (\n                    \"CrossAttnDownBlockSpatioTemporal\",\n                    \"CrossAttnDownBlockSpatioTemporal\",\n                    \"CrossAttnDownBlockSpatioTemporal\",\n                    \"DownBlockSpatioTemporal\",\n            ),\n            up_block_types: Tuple[str] = (\n                    \"UpBlockSpatioTemporal\",\n                    \"CrossAttnUpBlockSpatioTemporal\",\n                    \"CrossAttnUpBlockSpatioTemporal\",\n                    \"CrossAttnUpBlockSpatioTemporal\",\n            ),\n            block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n            addition_time_embed_dim: int = 256,\n            projection_class_embeddings_input_dim: int = 768,\n            layers_per_block: Union[int, Tuple[int]] = 2,\n            cross_attention_dim: Union[int, Tuple[int]] = 1024,\n            transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,\n            num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),\n            num_frames: int = 25,\n    ):\n        super().__init__()\n\n        self.sample_size = sample_size\n\n        # Check inputs\n        if len(down_block_types) != len(up_block_types):\n            raise ValueError(\n                f\"Must provide the same 
number of `down_block_types` as `up_block_types`. \" \\\n                f\"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.\"\n            )\n\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. \" \\\n                f\"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. \" \\\n                f\"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `cross_attention_dim` as `down_block_types`. \" \\\n                f\"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `layers_per_block` as `down_block_types`. \" \\\n                f\"`layers_per_block`: {layers_per_block}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        # input\n        self.conv_in = nn.Conv2d(\n            in_channels,\n            block_out_channels[0],\n            kernel_size=3,\n            padding=1,\n        )\n\n        # time\n        time_embed_dim = block_out_channels[0] * 4\n\n        self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)\n        timestep_input_dim = block_out_channels[0]\n\n        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)\n\n        self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)\n        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n\n        self.down_blocks = nn.ModuleList([])\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = (num_attention_heads,) * len(down_block_types)\n\n        if isinstance(cross_attention_dim, int):\n            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)\n\n        if isinstance(layers_per_block, int):\n            layers_per_block = [layers_per_block] * len(down_block_types)\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)\n\n        blocks_time_embed_dim = time_embed_dim\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block[i],\n                transformer_layers_per_block=transformer_layers_per_block[i],\n                in_channels=input_channel,\n                
out_channels=output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=1e-5,\n                cross_attention_dim=cross_attention_dim[i],\n                num_attention_heads=num_attention_heads[i],\n                resnet_act_fn=\"silu\",\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        self.mid_block = UNetMidBlockSpatioTemporal(\n            block_out_channels[-1],\n            temb_channels=blocks_time_embed_dim,\n            transformer_layers_per_block=transformer_layers_per_block[-1],\n            cross_attention_dim=cross_attention_dim[-1],\n            num_attention_heads=num_attention_heads[-1],\n        )\n\n        # count how many layers upsample the images\n        self.num_upsamplers = 0\n\n        # up\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_num_attention_heads = list(reversed(num_attention_heads))\n        reversed_layers_per_block = list(reversed(layers_per_block))\n        reversed_cross_attention_dim = list(reversed(cross_attention_dim))\n        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))\n\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block = get_up_block(\n                up_block_type,\n                
num_layers=reversed_layers_per_block[i] + 1,\n                transformer_layers_per_block=reversed_transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=1e-5,\n                resolution_idx=i,\n                cross_attention_dim=reversed_cross_attention_dim[i],\n                num_attention_heads=reversed_num_attention_heads[i],\n                resnet_act_fn=\"silu\",\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)\n        self.conv_act = nn.SiLU()\n\n        self.conv_out = nn.Conv2d(\n            block_out_channels[0],\n            out_channels,\n            kernel_size=3,\n            padding=1,\n        )\n\n    @property\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model,\n            indexed by their weight names.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(\n                name: str,\n                module: torch.nn.Module,\n                processors: Dict[str, AttentionProcessor],\n        ):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor(return_deprecated_lora=True)\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            
fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. 
Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):\n            processor = AttnProcessor()\n        else:\n            raise ValueError(\n                f\"Cannot call `set_default_attn_processor` \" \\\n                f\"when attention processors are of type {next(iter(self.attn_processors.values()))}\"\n            )\n\n        self.set_attn_processor(processor)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if hasattr(module, \"gradient_checkpointing\"):\n            module.gradient_checkpointing = value\n\n    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking\n    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:\n        \"\"\"\n        Sets the attention processor to use [feed forward\n        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).\n\n        Parameters:\n            chunk_size (`int`, *optional*):\n                The chunk size of the feed-forward layers. 
If not specified, will run feed-forward layer individually\n                over each tensor of dim=`dim`.\n            dim (`int`, *optional*, defaults to `0`):\n                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)\n                or dim=1 (sequence length).\n        \"\"\"\n        if dim not in [0, 1]:\n            raise ValueError(f\"Make sure to set `dim` to either 0 or 1, not {dim}\")\n\n        # By default chunk size is 1\n        chunk_size = chunk_size or 1\n\n        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):\n            if hasattr(module, \"set_chunk_feed_forward\"):\n                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)\n\n            for child in module.children():\n                fn_recursive_feed_forward(child, chunk_size, dim)\n\n        for module in self.children():\n            fn_recursive_feed_forward(module, chunk_size, dim)\n\n    def forward(\n            self,\n            sample: torch.FloatTensor,\n            timestep: Union[torch.Tensor, float, int],\n            encoder_hidden_states: torch.Tensor,\n            added_time_ids: torch.Tensor,\n            pose_latents: torch.Tensor = None,\n            down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n            mid_block_additional_residual: Optional[torch.Tensor] = None,\n            image_only_indicator: bool = False,\n            return_dict: bool = True,\n    ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:\n        r\"\"\"\n        The [`UNetSpatioTemporalConditionModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.\n            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n  
              The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.\n            added_time_ids: (`torch.FloatTensor`):\n                The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal\n                embeddings and added to the time embeddings.\n            pose_latents: (`torch.FloatTensor`):\n                The additional latents for pose sequences.\n            image_only_indicator (`bool`, *optional*, defaults to `False`):\n                Whether or not training with all images.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] \n                instead of a plain tuple.\n        Returns:\n            [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:\n                If `return_dict` is True, \n                an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, \n                otherwise a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        batch_size, num_frames = sample.shape[:2]\n        timesteps = timesteps.expand(batch_size)\n\n        t_emb = self.time_proj(timesteps)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb)\n\n        time_embeds = self.add_time_proj(added_time_ids.flatten())\n        time_embeds = time_embeds.reshape((batch_size, -1))\n        time_embeds = time_embeds.to(emb.dtype)\n        aug_emb = self.add_embedding(time_embeds)\n        emb = emb + aug_emb\n\n        # Flatten the batch and frames dimensions\n        # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]\n        sample = sample.flatten(0, 1)\n        # Repeat the embeddings num_video_frames times\n        # emb: [batch, channels] -> [batch * frames, channels]\n        emb = emb.repeat_interleave(num_frames, dim=0)\n        # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]\n        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)\n\n        # 2. 
pre-process\n        sample = self.conv_in(sample)\n        if pose_latents is not None:\n            sample = sample + pose_latents\n\n        image_only_indicator = torch.ones(batch_size, num_frames, dtype=sample.dtype, device=sample.device) \\\n            if image_only_indicator else torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)\n\n        # 3. down\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    image_only_indicator=image_only_indicator,\n                )\n            else:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    image_only_indicator=image_only_indicator,\n                )\n\n            down_block_res_samples += res_samples\n\n        # add the ControlNet residuals once, after all down blocks have run\n        if down_block_additional_residuals is not None:\n            new_down_block_res_samples = ()\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 4. 
mid\n        sample = self.mid_block(\n            hidden_states=sample,\n            temb=emb,\n            encoder_hidden_states=encoder_hidden_states,\n            image_only_indicator=image_only_indicator,\n        )\n        if mid_block_additional_residual is not None:\n            sample = sample + mid_block_additional_residual\n\n        # 5. up\n        for i, upsample_block in enumerate(self.up_blocks):\n            res_samples = down_block_res_samples[-len(upsample_block.resnets):]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    image_only_indicator=image_only_indicator,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    image_only_indicator=image_only_indicator,\n                )\n\n        # 6. post-process\n        sample = self.conv_norm_out(sample)\n        sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        # 7. Reshape back to original shape\n        sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])\n\n        if not return_dict:\n            return (sample,)\n\n        return UNetSpatioTemporalConditionOutput(sample=sample)\n"
  },
  {
    "path": "mimicmotion/pipelines/pipeline_ctrl.py",
    "content": "import inspect\nfrom dataclasses import dataclass\nfrom typing import Callable, Dict, List, Optional, Union\n\nimport PIL.Image\nimport einops\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\n\nfrom diffusers.image_processor import VaeImageProcessor, PipelineImageInput\nfrom diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps\nfrom diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion \\\n    import _resize_with_antialiasing\nfrom diffusers.schedulers import EulerDiscreteScheduler\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\nfrom mimicmotion.modules.controlnet import ControlNetSVDModel\n\n\nfrom ..modules.pose_net import PoseNet\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef _append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\")\n    return x[(...,) + (None,) * dims_to_append]\n\n\n# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid\ndef tensor2vid(video: torch.Tensor, processor: \"VaeImageProcessor\", output_type: str = \"np\"):\n    batch_size, channels, num_frames, height, width = video.shape\n    outputs = []\n    for batch_idx in range(batch_size):\n        batch_vid = video[batch_idx].permute(1, 0, 2, 3)\n        batch_output = processor.postprocess(batch_vid, output_type)\n\n        outputs.append(batch_output)\n\n    if output_type == 
\"np\":\n        outputs = np.stack(outputs)\n\n    elif output_type == \"pt\":\n        outputs = torch.stack(outputs)\n\n    elif output_type != \"pil\":\n        raise ValueError(f\"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']\")\n\n    return outputs\n\n\n@dataclass\nclass MimicMotionPipelineOutput(BaseOutput):\n    r\"\"\"\n    Output class for mimicmotion pipeline.\n\n    Args:\n        frames (`List[List[PIL.Image.Image]]`, `np.ndarray`, or `torch.Tensor`):\n            List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,\n            num_frames, height, width, num_channels)`.\n    \"\"\"\n\n    frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]\n\n\nclass Ctrl_Pipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline to generate video from an input image using Stable Video Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    Args:\n        vae ([`AutoencoderKLTemporalDecoder`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):\n            Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K]\n            (https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).\n        unet ([`UNetSpatioTemporalConditionModel`]):\n            A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.\n        scheduler ([`EulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images.\n        pose_net ([`PoseNet`]):\n           
 A `PoseNet` model to inject pose signals into the unet.\n    \"\"\"\n\n    model_cpu_offload_seq = \"image_encoder->unet->vae\"\n    _callback_tensor_inputs = [\"latents\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKLTemporalDecoder,\n        image_encoder: CLIPVisionModelWithProjection,\n        unet: UNetSpatioTemporalConditionModel,\n        controlnet: ControlNetSVDModel,\n        scheduler: EulerDiscreteScheduler,\n        feature_extractor: CLIPImageProcessor,\n        pose_net: PoseNet,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            image_encoder=image_encoder,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n            pose_net=pose_net,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n    def _encode_image(\n        self, \n        image: PipelineImageInput, \n        device: Union[str, torch.device], \n        num_videos_per_prompt: int, \n        do_classifier_free_guidance: bool):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.image_processor.pil_to_numpy(image)\n            image = self.image_processor.numpy_to_pt(image)\n\n            # We normalize the image before resizing to match with the original implementation.\n            # Then we unnormalize it after resizing.\n            image = image * 2.0 - 1.0\n            image = _resize_with_antialiasing(image, (224, 224))\n            image = (image + 1.0) / 2.0\n\n            # Normalize the image for CLIP input\n            image = self.feature_extractor(\n                images=image,\n                do_normalize=True,\n                do_center_crop=False,\n                do_resize=False,\n                
do_rescale=False,\n                return_tensors=\"pt\",\n            ).pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        image_embeddings = self.image_encoder(image).image_embeds\n        image_embeddings = image_embeddings.unsqueeze(1)\n\n        # duplicate image embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = image_embeddings.shape\n        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)\n        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            negative_image_embeddings = torch.zeros_like(image_embeddings)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the negative and conditional image embeddings into a single batch\n            # to avoid doing two forward passes\n            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])\n\n        return image_embeddings\n\n    def _encode_pose_image(\n        self, \n        pose_image: torch.Tensor, \n        do_classifier_free_guidance: bool,\n    ):\n        # Get latents_pose\n        pose_latents = self.pose_net(pose_image)\n\n        if do_classifier_free_guidance:\n            negative_pose_latents = torch.zeros_like(pose_latents)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the negative and conditional pose latents into a single batch\n            # to avoid doing two forward passes\n            pose_latents = torch.cat([negative_pose_latents, pose_latents])\n\n        return pose_latents\n    \n    def _encode_vae_image(\n        self,\n        image: torch.Tensor,\n        device: Union[str, torch.device],\n        num_videos_per_prompt: int,\n        do_classifier_free_guidance: bool,\n    ):\n        image = image.to(device=device, dtype=self.vae.dtype)\n        image_latents = self.vae.encode(image).latent_dist.mode()\n\n        if do_classifier_free_guidance:\n            negative_image_latents = torch.zeros_like(image_latents)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the negative and conditional image latents into a single batch\n            # to avoid doing two forward passes\n            image_latents = torch.cat([negative_image_latents, image_latents])\n\n        # duplicate image_latents for each generation per prompt, using mps friendly method\n        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)\n\n        return image_latents\n\n    def _get_add_time_ids(\n        self,\n        fps: int,\n        motion_bucket_id: int,\n        noise_aug_strength: float,\n        dtype: torch.dtype,\n        batch_size: int,\n        num_videos_per_prompt: int,\n        do_classifier_free_guidance: bool,\n    ):\n        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]\n\n        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, \" \\\n                f\"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. 
\" \\\n                f\"Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)\n\n        if do_classifier_free_guidance:\n            add_time_ids = torch.cat([add_time_ids, add_time_ids])\n\n        return add_time_ids\n\n    def decode_latents(\n        self, \n        latents: torch.Tensor, \n        num_frames: int, \n        decode_chunk_size: int = 8):\n        # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]\n        latents = latents.flatten(0, 1)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n\n        forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward\n        accepts_num_frames = \"num_frames\" in set(inspect.signature(forward_vae_fn).parameters.keys())\n\n        # decode decode_chunk_size frames at a time to avoid OOM\n        frames = []\n        for i in range(0, latents.shape[0], decode_chunk_size):\n            num_frames_in = latents[i: i + decode_chunk_size].shape[0]\n            decode_kwargs = {}\n            if accepts_num_frames:\n                # we only pass num_frames_in if it's expected\n                decode_kwargs[\"num_frames\"] = num_frames_in\n\n            frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample\n            frames.append(frame.cpu())\n        frames = torch.cat(frames, dim=0)\n\n        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]\n        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)\n\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        frames = frames.float()\n        return frames\n\n    def check_inputs(self, image, height, width):\n  
      if (\n                not isinstance(image, torch.Tensor)\n                and not isinstance(image, PIL.Image.Image)\n                and not isinstance(image, list)\n        ):\n            raise ValueError(\n                \"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is\"\n                f\" {type(image)}\"\n            )\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n    def prepare_latents(\n        self,\n        batch_size: int,\n        num_frames: int,\n        num_channels_latents: int,\n        height: int,\n        width: int,\n        dtype: torch.dtype,\n        device: Union[str, torch.device],\n        generator: torch.Generator,\n        latents: Optional[torch.Tensor] = None,\n    ):\n        shape = (\n            batch_size,\n            num_frames,\n            num_channels_latents // 2,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        if isinstance(self.guidance_scale, (int, float)):\n            return self.guidance_scale > 1\n        return self.guidance_scale.max() > 1\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        image: 
Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],\n        image_pose: torch.FloatTensor,\n        controlnet_flow: torch.FloatTensor,\n        controlnet_image: torch.FloatTensor,\n        traj_flow: torch.FloatTensor,\n        point_list,\n        dift_feats,\n        height: int = 576,\n        width: int = 1024,\n        num_frames: Optional[int] = None,\n        tile_size: Optional[int] = 16,\n        tile_overlap: Optional[int] = 4,\n        num_inference_steps: int = 25,\n        min_guidance_scale: float = 1.0,\n        max_guidance_scale: float = 3.0,\n        fps: int = 7,\n        controlnet_cond_scale: float = 1.0,\n        motion_bucket_id: int = 127,\n        noise_aug_strength: float = 0.02,\n        image_only_indicator: bool = False,\n        decode_chunk_size: Optional[int] = None,\n        num_videos_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        output_type: Optional[str] = \"pil\",\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        return_dict: bool = True,\n        device: Optional[Union[str, torch.device]] = None,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):\n                Image or images to guide image generation. 
If you provide a tensor, it needs to be compatible with\n                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/\n                feature_extractor/preprocessor_config.json).\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_frames (`int`, *optional*):\n                The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` \n                and to 25 for `stable-video-diffusion-img2vid-xt`.\n            num_inference_steps (`int`, *optional*, defaults to 25):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference. This parameter is modulated by `strength`.\n            min_guidance_scale (`float`, *optional*, defaults to 1.0):\n                The minimum guidance scale. Used for the classifier free guidance with first frame.\n            max_guidance_scale (`float`, *optional*, defaults to 3.0):\n                The maximum guidance scale. Used for the classifier free guidance with last frame.\n            fps (`int`, *optional*, defaults to 7):\n                Frames per second. The rate at which the generated images shall be exported to a video after generation.\n                Note that Stable Video Diffusion's UNet was micro-conditioned on `fps - 1` during training.\n            motion_bucket_id (`int`, *optional*, defaults to 127):\n                The motion bucket ID. Used as conditioning for the generation. 
\n                The higher the number, the more motion will be in the video.\n            noise_aug_strength (`float`, *optional*, defaults to 0.02):\n                The amount of noise added to the init image, \n                the higher it is the less the video will look like the init image. Increase it for more motion.\n            image_only_indicator (`bool`, *optional*, defaults to False):\n                Whether to treat the inputs as a batch of images instead of videos.\n            decode_chunk_size (`int`, *optional*):\n                The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency\n                between frames, but also the higher the memory consumption. \n                By default, the decoder will decode all frames at once for maximal quality. \n                Reduce `decode_chunk_size` to reduce memory usage.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                The number of videos to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. 
The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            device:\n                The device on which the pipeline runs.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, \n                [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list of lists with the generated frames.\n\n        Examples:\n\n        ```py\n        import torch\n\n        from diffusers import StableVideoDiffusionPipeline\n        from diffusers.utils import load_image, export_to_video\n\n        pipe = StableVideoDiffusionPipeline.from_pretrained(\n            \"stabilityai/stable-video-diffusion-img2vid-xt\", torch_dtype=torch.float16, variant=\"fp16\")\n        pipe.to(\"cuda\")\n\n        image = load_image(\n        \"https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200\")\n        image = image.resize((1024, 576))\n\n        frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]\n        export_to_video(frames, \"generated.mp4\", fps=7)\n        ```\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        num_frames = num_frames if num_frames is not None else self.unet.config.num_frames\n        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(image, height, width)\n\n        # 2. Define call parameters\n        if isinstance(image, PIL.Image.Image):\n            batch_size = 1\n        elif isinstance(image, list):\n            batch_size = len(image)\n        else:\n            batch_size = image.shape[0]\n        device = device if device is not None else self._execution_device\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        self._guidance_scale = max_guidance_scale\n\n        # 3. Encode input image\n        self.image_encoder.to(device)\n        image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)\n        self.image_encoder.cpu()\n\n        # NOTE: Stable Video Diffusion was conditioned on fps - 1, which\n        # is why it is reduced here.\n        fps = fps - 1\n\n        # 4. 
Encode input image using VAE\n        image = self.image_processor.preprocess(image, height=height, width=width).to(device)\n        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)\n        image = image + noise_aug_strength * noise\n\n        self.vae.to(device)\n        image_latents = self._encode_vae_image(\n            image,\n            device=device,\n            num_videos_per_prompt=num_videos_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n        )\n        image_latents = image_latents.to(image_embeddings.dtype)\n        self.vae.cpu()\n\n        # Repeat the image latents for each frame so we can concatenate them with the noise\n        # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]\n        image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)\n\n        pose_latents = self._encode_pose_image(\n            image_pose, do_classifier_free_guidance=self.do_classifier_free_guidance,\n        ).to(device)\n        pose_latents = einops.rearrange(pose_latents, '(b f) c h w -> b f c h w', f=num_frames)\n\n        # #### get point feature ##################################\n        bz, _, w, h = controlnet_image.size()\n        # #### get ref point feature ##############################\n        ref_point_emb = []\n        tgt_point_emb = []\n        ref_point = point_list[0]\n        assert bz == 1\n        rescale_ref_dift = nn.Upsample(size=(h, w), mode='bilinear')(dift_feats[0].squeeze(2))\n        for b in range(bz):\n            init_embedding = torch.zeros((18, 1280))\n            for point_idx, (keypoint, subset) in enumerate(zip(ref_point[\"candidate\"], ref_point[\"subset\"][0])):\n                px, py = keypoint[0] * w, keypoint[1] * h\n                point_x, point_y = max(min(int(px), w - 1), 0), max(min(int(py), h - 1), 0)\n                # point_x, point_y = max(min(w, w * int(keypoint[0]) - 
1), 0), max(min(h, h * int(keypoint[1]) - 1), 0)\n                if subset != -1:\n                    # point_x, point_y = int(torch.floor(x)), int(torch.floor(y))\n                    point_embedding = rescale_ref_dift[b, :, point_y, point_x]\n                    init_embedding[point_idx] = point_embedding\n            ref_point_emb.append(init_embedding)\n        ref_point_emb = torch.stack(ref_point_emb).to(device=self.controlnet.device)\n\n        # 5. Get Added Time IDs\n        added_time_ids = self._get_add_time_ids(\n            fps,\n            motion_bucket_id,\n            noise_aug_strength,\n            image_embeddings.dtype,\n            batch_size,\n            num_videos_per_prompt,\n            self.do_classifier_free_guidance,\n        )\n        added_time_ids = added_time_ids.to(device)\n\n        # 6. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None)\n\n        # 7. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_videos_per_prompt,\n            tile_size,\n            num_channels_latents,\n            height,\n            width,\n            image_embeddings.dtype,\n            device,\n            generator,\n            latents,\n        ) # [1, 72, 4, h//8, w//8]\n        latents = latents.repeat(1, num_frames // tile_size + 1, 1, 1, 1)[:, :num_frames] # [1, num_frames, 4, h//8, w//8]\n\n        ref_point_emb = ref_point_emb.to(device, latents.dtype)\n\n        controlnet_flow = torch.cat([controlnet_flow] * 2) if self.do_classifier_free_guidance else controlnet_flow\n        controlnet_flow = controlnet_flow.to(device, latents.dtype)\n\n        traj_flow = torch.cat([traj_flow] * 2) if self.do_classifier_free_guidance else traj_flow\n        traj_flow = traj_flow.to(device, latents.dtype)\n\n        ctrl_image_pose = image_pose.unsqueeze(0)\n        neg_ctrl_image_pose = torch.zeros_like(ctrl_image_pose)\n        ctrl_image_pose = torch.cat([neg_ctrl_image_pose, ctrl_image_pose]) if self.do_classifier_free_guidance else ctrl_image_pose\n        ctrl_image_pose = ctrl_image_pose.to(device, latents.dtype)\n\n        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)\n\n        # 9. Prepare guidance scale\n        guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)\n        guidance_scale = guidance_scale.to(device, latents.dtype)\n        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)\n        guidance_scale = _append_dims(guidance_scale, latents.ndim)\n\n        self._guidance_scale = guidance_scale\n\n        # 10. Denoising loop\n        self._num_timesteps = len(timesteps)\n        indices = [[0, *range(i + 1, min(i + tile_size, num_frames))] for i in\n                   range(0, num_frames - tile_size + 1, tile_size - tile_overlap)]\n        if indices[-1][-1] < num_frames - 1:\n            indices.append([0, *range(num_frames - tile_size + 1, num_frames)])\n\n        self.pose_net.to(device)\n        self.unet.to(device)\n        self.controlnet.to(device)\n        self.controlnet.eval()\n        with torch.cuda.device(device):\n            torch.cuda.empty_cache()\n\n        with self.progress_bar(total=len(timesteps) * len(indices)) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # Concatenate image_latents over channels dimension\n                latent_model_input = 
torch.cat([latent_model_input, image_latents], dim=2)\n\n                # predict the noise residual\n                noise_pred = torch.zeros_like(image_latents)\n                noise_pred_cnt = image_latents.new_zeros((num_frames,))\n                weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size\n                weight = torch.minimum(weight, 2 - weight)\n                for idx in indices:\n                    flow_idx = [i - 1 for i in idx[1:]]\n                    point_input = [point_list[i] for i in idx]\n                    down_block_res_samples, mid_block_res_sample, _ = self.controlnet(\n                        latent_model_input[:, idx],\n                        t,\n                        encoder_hidden_states=image_embeddings,\n                        controlnet_cond=controlnet_image,\n                        controlnet_flow=controlnet_flow[:, flow_idx],\n                        traj_flow=traj_flow[:, flow_idx],\n                        pose_latents=pose_latents[:, idx].flatten(0, 1),\n                        # pose_image=ctrl_image_pose[:, idx],\n                        point_list=point_input,\n                        dift_feat=dift_feats[1],\n                        ref_point_emb=ref_point_emb,\n                        added_time_ids=added_time_ids,\n                        conditioning_scale=controlnet_cond_scale,\n                        guess_mode=False,\n                        return_dict=False,\n                    )\n\n                    _noise_pred = self.unet(\n                        latent_model_input[:, idx],\n                        t,\n                        encoder_hidden_states=image_embeddings,\n                        added_time_ids=added_time_ids,\n                        pose_latents=pose_latents[:, idx].flatten(0, 1),\n                        down_block_additional_residuals=down_block_res_samples,\n                        mid_block_additional_residual=mid_block_res_sample,\n                        image_only_indicator=image_only_indicator,\n                        return_dict=False,\n                    )[0]\n                    noise_pred[:, idx] += _noise_pred * weight[:, None, None, None]\n\n                    noise_pred_cnt[idx] += weight\n                    progress_bar.update()\n                noise_pred.div_(noise_pred_cnt[:, None, None, None])\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n\n        self.pose_net.cpu()\n        self.unet.cpu()\n\n        if output_type != \"latent\":\n            self.vae.decoder.to(device)\n            frames = self.decode_latents(latents, num_frames, decode_chunk_size)\n            frames = tensor2vid(frames, self.image_processor, output_type=output_type)\n        else:\n            frames = latents\n\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return frames\n\n        return MimicMotionPipelineOutput(frames=frames)\n"
  },
  {
    "path": "mimicmotion/pipelines/pipeline_mimicmotion.py",
"content": "import inspect\nfrom dataclasses import dataclass\nfrom typing import Callable, Dict, List, Optional, Union\n\nimport PIL.Image\nimport einops\nimport numpy as np\nimport torch\nfrom diffusers.image_processor import VaeImageProcessor, PipelineImageInput\nfrom diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps\nfrom diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion \\\n    import _resize_with_antialiasing\nfrom diffusers.schedulers import EulerDiscreteScheduler\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\nfrom ..modules.pose_net import PoseNet\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef _append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\")\n    return x[(...,) + (None,) * dims_to_append]\n\n\n# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid\ndef tensor2vid(video: torch.Tensor, processor: \"VaeImageProcessor\", output_type: str = \"np\"):\n    batch_size, channels, num_frames, height, width = video.shape\n    outputs = []\n    for batch_idx in range(batch_size):\n        batch_vid = video[batch_idx].permute(1, 0, 2, 3)\n        batch_output = processor.postprocess(batch_vid, output_type)\n\n        outputs.append(batch_output)\n\n    if output_type == \"np\":\n        outputs = np.stack(outputs)\n\n    elif output_type == \"pt\":\n        outputs = 
torch.stack(outputs)\n\n    elif output_type != \"pil\":\n        raise ValueError(f\"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']\")\n\n    return outputs\n\n\n@dataclass\nclass MimicMotionPipelineOutput(BaseOutput):\n    r\"\"\"\n    Output class for the MimicMotion pipeline.\n\n    Args:\n        frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]):\n            List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,\n            num_frames, height, width, num_channels)`.\n    \"\"\"\n\n    frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]\n\n\nclass MimicMotionPipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline to generate video from an input image using Stable Video Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    Args:\n        vae ([`AutoencoderKLTemporalDecoder`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):\n            Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K]\n            (https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).\n        unet ([`UNetSpatioTemporalConditionModel`]):\n            A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.\n        scheduler ([`EulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images.\n        pose_net ([`PoseNet`]):\n            A `PoseNet` module to inject pose signals into the UNet.\n    \"\"\"\n\n    model_cpu_offload_seq = 
\"image_encoder->unet->vae\"\n    _callback_tensor_inputs = [\"latents\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKLTemporalDecoder,\n        image_encoder: CLIPVisionModelWithProjection,\n        unet: UNetSpatioTemporalConditionModel,\n        scheduler: EulerDiscreteScheduler,\n        feature_extractor: CLIPImageProcessor,\n        pose_net: PoseNet,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            image_encoder=image_encoder,\n            unet=unet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n            pose_net=pose_net,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n    def _encode_image(\n        self, \n        image: PipelineImageInput, \n        device: Union[str, torch.device], \n        num_videos_per_prompt: int, \n        do_classifier_free_guidance: bool):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.image_processor.pil_to_numpy(image)\n            image = self.image_processor.numpy_to_pt(image)\n\n            # We normalize the image before resizing to match with the original implementation.\n            # Then we unnormalize it after resizing.\n            image = image * 2.0 - 1.0\n            image = _resize_with_antialiasing(image, (224, 224))\n            image = (image + 1.0) / 2.0\n\n            # Normalize the image for CLIP input\n            image = self.feature_extractor(\n                images=image,\n                do_normalize=True,\n                do_center_crop=False,\n                do_resize=False,\n                do_rescale=False,\n                return_tensors=\"pt\",\n            ).pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        image_embeddings = self.image_encoder(image).image_embeds\n        image_embeddings = image_embeddings.unsqueeze(1)\n\n        # duplicate image embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = image_embeddings.shape\n        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)\n        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            negative_image_embeddings = torch.zeros_like(image_embeddings)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the negative and conditional image embeddings into a single batch\n            # to avoid doing two forward passes\n            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])\n\n        return image_embeddings\n\n    def _encode_pose_image(\n        self, \n        pose_image: torch.Tensor, \n        do_classifier_free_guidance: bool,\n    ):\n        # Get latents_pose\n        pose_latents = self.pose_net(pose_image)\n\n        if do_classifier_free_guidance:\n            negative_pose_latents = torch.zeros_like(pose_latents)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the negative and conditional pose latents into a single batch\n            # to avoid doing two forward passes\n            pose_latents = torch.cat([negative_pose_latents, pose_latents])\n\n        return pose_latents\n    \n    def _encode_vae_image(\n        self,\n        image: torch.Tensor,\n        device: Union[str, torch.device],\n        num_videos_per_prompt: int,\n        do_classifier_free_guidance: bool,\n    ):\n        image = image.to(device=device, dtype=self.vae.dtype)\n        image_latents = self.vae.encode(image).latent_dist.mode()\n\n        if do_classifier_free_guidance:\n            negative_image_latents = 
torch.zeros_like(image_latents)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and conditional latents into a single batch\n            # to avoid doing two forward passes\n            image_latents = torch.cat([negative_image_latents, image_latents])\n\n        # duplicate image_latents for each generation per prompt, using mps friendly method\n        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)\n\n        return image_latents\n\n    def _get_add_time_ids(\n        self,\n        fps: int,\n        motion_bucket_id: int,\n        noise_aug_strength: float,\n        dtype: torch.dtype,\n        batch_size: int,\n        num_videos_per_prompt: int,\n        do_classifier_free_guidance: bool,\n    ):\n        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]\n\n        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, \" \\\n                f\"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. \" \\\n                f\"Please check `unet.config.addition_time_embed_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)\n\n        if do_classifier_free_guidance:\n            add_time_ids = torch.cat([add_time_ids, add_time_ids])\n\n        return add_time_ids\n\n    def decode_latents(\n        self, \n        latents: torch.Tensor, \n        num_frames: int, \n        decode_chunk_size: int = 8):\n        # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]\n        latents = latents.flatten(0, 1)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n\n        forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward\n        accepts_num_frames = \"num_frames\" in set(inspect.signature(forward_vae_fn).parameters.keys())\n\n        # decode decode_chunk_size frames at a time to avoid OOM\n        frames = []\n        for i in range(0, latents.shape[0], decode_chunk_size):\n            num_frames_in = latents[i: i + decode_chunk_size].shape[0]\n            decode_kwargs = {}\n            if accepts_num_frames:\n                # we only pass num_frames_in if it's expected\n                decode_kwargs[\"num_frames\"] = num_frames_in\n\n            frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample\n            frames.append(frame.cpu())\n        frames = torch.cat(frames, dim=0)\n\n        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]\n        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)\n\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        frames = frames.float()\n        return frames\n\n    def check_inputs(self, image, height, width):\n  
      if (\n                not isinstance(image, torch.Tensor)\n                and not isinstance(image, PIL.Image.Image)\n                and not isinstance(image, list)\n        ):\n            raise ValueError(\n                \"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is\"\n                f\" {type(image)}\"\n            )\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n    def prepare_latents(\n        self,\n        batch_size: int,\n        num_frames: int,\n        num_channels_latents: int,\n        height: int,\n        width: int,\n        dtype: torch.dtype,\n        device: Union[str, torch.device],\n        generator: torch.Generator,\n        latents: Optional[torch.Tensor] = None,\n    ):\n        shape = (\n            batch_size,\n            num_frames,\n            num_channels_latents // 2,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        if isinstance(self.guidance_scale, (int, float)):\n            return self.guidance_scale > 1\n        return self.guidance_scale.max() > 1\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        image: 
Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],\n        image_pose: torch.FloatTensor,\n        height: int = 576,\n        width: int = 1024,\n        num_frames: Optional[int] = None,\n        tile_size: Optional[int] = 16,\n        tile_overlap: Optional[int] = 4,\n        num_inference_steps: int = 25,\n        min_guidance_scale: float = 1.0,\n        max_guidance_scale: float = 3.0,\n        fps: int = 7,\n        motion_bucket_id: int = 127,\n        noise_aug_strength: float = 0.02,\n        image_only_indicator: bool = False,\n        decode_chunk_size: Optional[int] = None,\n        num_videos_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        output_type: Optional[str] = \"pil\",\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        return_dict: bool = True,\n        device: Optional[Union[str, torch.device]] = None,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):\n                Image or images to guide image generation. 
If you provide a tensor, it needs to be compatible with\n                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/\n                feature_extractor/preprocessor_config.json).\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_frames (`int`, *optional*):\n                The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` \n                and to 25 for `stable-video-diffusion-img2vid-xt`.\n            num_inference_steps (`int`, *optional*, defaults to 25):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            min_guidance_scale (`float`, *optional*, defaults to 1.0):\n                The minimum guidance scale, applied to the first frame; classifier-free guidance is interpolated\n                linearly from this value to `max_guidance_scale` across frames.\n            max_guidance_scale (`float`, *optional*, defaults to 3.0):\n                The maximum guidance scale, applied to the last frame.\n            fps (`int`, *optional*, defaults to 7):\n                Frames per second. The rate at which the generated images shall be exported to a video after generation.\n                Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.\n            motion_bucket_id (`int`, *optional*, defaults to 127):\n                The motion bucket ID. Used as conditioning for the generation. 
\n                The higher the number, the more motion will be in the video.\n            noise_aug_strength (`float`, *optional*, defaults to 0.02):\n                The amount of noise added to the init image; \n                the higher it is, the less the video will look like the init image. Increase it for more motion.\n            image_only_indicator (`bool`, *optional*, defaults to False):\n                Whether to treat the inputs as a batch of images instead of videos.\n            decode_chunk_size (`int`, *optional*):\n                The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency\n                between frames, but also the higher the memory consumption. \n                By default, the decoder will decode all frames at once for maximal quality. \n                Reduce `decode_chunk_size` to reduce memory usage.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                The number of videos to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            callback_on_step_end (`Callable`, *optional*):\n                A function called at the end of each denoising step. 
The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            device:\n                The device on which the pipeline runs.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, \n                [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list of lists with the generated frames.\n\n        Examples:\n\n        ```py\n        import torch\n\n        from diffusers import StableVideoDiffusionPipeline\n        from diffusers.utils import load_image, export_to_video\n\n        pipe = StableVideoDiffusionPipeline.from_pretrained(\n            \"stabilityai/stable-video-diffusion-img2vid-xt\", torch_dtype=torch.float16, variant=\"fp16\")\n        pipe.to(\"cuda\")\n\n        image = load_image(\n        \"https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200\")\n        image = image.resize((1024, 576))\n\n        frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]\n        
export_to_video(frames, \"generated.mp4\", fps=7)\n        ```\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        num_frames = num_frames if num_frames is not None else self.unet.config.num_frames\n        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(image, height, width)\n\n        # 2. Define call parameters\n        if isinstance(image, PIL.Image.Image):\n            batch_size = 1\n        elif isinstance(image, list):\n            batch_size = len(image)\n        else:\n            batch_size = image.shape[0]\n        device = device if device is not None else self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        self._guidance_scale = max_guidance_scale\n\n        # 3. Encode input image\n        self.image_encoder.to(device)\n        image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)\n        self.image_encoder.cpu()\n\n        # NOTE: Stable Diffusion Video was conditioned on fps - 1, which\n        # is why it is reduced here.\n        fps = fps - 1\n\n        # 4. 
Encode input image using VAE\n        image = self.image_processor.preprocess(image, height=height, width=width).to(device)\n        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)\n        image = image + noise_aug_strength * noise\n\n        self.vae.to(device)\n        image_latents = self._encode_vae_image(\n            image,\n            device=device,\n            num_videos_per_prompt=num_videos_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n        )\n        image_latents = image_latents.to(image_embeddings.dtype)\n        self.vae.cpu()\n\n        # Repeat the image latents for each frame so we can concatenate them with the noise\n        # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]\n        image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)\n\n        pose_latents = self._encode_pose_image(\n            image_pose, do_classifier_free_guidance=self.do_classifier_free_guidance,\n        ).to(device)\n        pose_latents = einops.rearrange(pose_latents, '(b f) c h w -> b f c h w', f=num_frames)\n\n        # 5. Get Added Time IDs\n        added_time_ids = self._get_add_time_ids(\n            fps,\n            motion_bucket_id,\n            noise_aug_strength,\n            image_embeddings.dtype,\n            batch_size,\n            num_videos_per_prompt,\n            self.do_classifier_free_guidance,\n        )\n        added_time_ids = added_time_ids.to(device)\n\n        # 6. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None)\n\n        # 7. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_videos_per_prompt,\n            tile_size,\n            num_channels_latents,\n            height,\n            width,\n            image_embeddings.dtype,\n            device,\n            generator,\n            latents,\n        ) # [1, tile_size, 4, h//8, w//8]\n        latents = latents.repeat(1, num_frames // tile_size + 1, 1, 1, 1)[:, :num_frames] # [1, num_frames, 4, h//8, w//8]\n\n        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)\n\n        # 9. Prepare guidance scale\n        guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)\n        guidance_scale = guidance_scale.to(device, latents.dtype)\n        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)\n        guidance_scale = _append_dims(guidance_scale, latents.ndim)\n\n        self._guidance_scale = guidance_scale\n\n        # 10. 
Denoising loop\n        self._num_timesteps = len(timesteps)\n        indices = [[0, *range(i + 1, min(i + tile_size, num_frames))] for i in\n                   range(0, num_frames - tile_size + 1, tile_size - tile_overlap)]\n        if indices[-1][-1] < num_frames - 1:\n            indices.append([0, *range(num_frames - tile_size + 1, num_frames)])\n\n        self.pose_net.to(device)\n        self.unet.to(device)\n\n        with torch.cuda.device(device):\n            torch.cuda.empty_cache()\n\n        with self.progress_bar(total=len(timesteps) * len(indices)) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # Concatenate image_latents over channels dimension\n                latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)\n\n                # predict the noise residual\n                noise_pred = torch.zeros_like(image_latents)\n                noise_pred_cnt = image_latents.new_zeros((num_frames,))\n                weight = (torch.arange(tile_size, device=device) + 0.5) * 2. 
/ tile_size\n                weight = torch.minimum(weight, 2 - weight)\n                for idx in indices:\n                    _noise_pred = self.unet(\n                        latent_model_input[:, idx],\n                        t,\n                        encoder_hidden_states=image_embeddings,\n                        added_time_ids=added_time_ids,\n                        pose_latents=pose_latents[:, idx].flatten(0, 1),\n                        image_only_indicator=image_only_indicator,\n                        return_dict=False,\n                    )[0]\n                    noise_pred[:, idx] += _noise_pred * weight[:, None, None, None]\n\n                    noise_pred_cnt[idx] += weight\n           
         progress_bar.update()\n                noise_pred.div_(noise_pred_cnt[:, None, None, None])\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n\n        self.pose_net.cpu()\n        self.unet.cpu()\n\n        if not output_type == \"latent\":\n            self.vae.decoder.to(device)\n            frames = self.decode_latents(latents, num_frames, decode_chunk_size)\n            frames = tensor2vid(frames, self.image_processor, output_type=output_type)\n        else:\n            frames = latents\n\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return frames\n\n        return MimicMotionPipelineOutput(frames=frames)\n"
  },
  {
    "path": "mimicmotion/utils/__init__.py",
    "content": ""
  },
  {
    "path": "mimicmotion/utils/dift_utils.py",
    "content": "import gc\nfrom typing import Any, Dict, Optional, Union\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom diffusers import DDIMScheduler, StableDiffusionPipeline\nfrom diffusers.models.unet_2d_condition import UNet2DConditionModel\nfrom PIL import Image, ImageDraw\nimport pdb\n\nclass MyUNet2DConditionModel(UNet2DConditionModel):\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        up_ft_indices,\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None\n    ):\n        r\"\"\"\n        Args:\n            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor\n            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps\n            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under\n                `self.processor` in\n                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).\n        \"\"\"\n        # By default samples have to be AT least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of 
`default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            # logger.info(\"Forward upsample size to force interpolation output size.\")\n            forward_upsample_size = True\n\n        # prepare attention_mask\n        if attention_mask is not None:\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == 'mps'\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. 
so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=self.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError('class_labels should be provided when num_class_embeds > 0')\n\n            if self.config.class_embed_type == 'timestep':\n                class_labels = self.time_proj(class_labels)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)\n            emb = emb + class_emb\n\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        # 3. down\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, 'has_cross_attention') and downsample_block.has_cross_attention:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n\n            down_block_res_samples += res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            sample = self.mid_block(\n                sample,\n                emb,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                cross_attention_kwargs=cross_attention_kwargs,\n            )\n\n        # 5. 
up\n        up_ft = {}\n\n        for i, upsample_block in enumerate(self.up_blocks):\n\n            if i > np.max(up_ft_indices):\n                break\n\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets):]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, 'has_cross_attention') and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size\n                )\n\n            if i in up_ft_indices:\n                up_ft[i] = sample.detach()\n\n        output = {}\n        output['up_ft'] = up_ft\n\n        return output\n\n\nclass OneStepSDPipeline(StableDiffusionPipeline):\n    @torch.no_grad()\n    def __call__(\n        self,\n        img_tensor,\n        t,\n        up_ft_indices,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None\n    ):\n\n        device = self._execution_device\n        latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor\n        t = 
torch.tensor(t, dtype=torch.long, device=device)\n        noise = torch.randn_like(latents).to(device)\n        latents_noisy = self.scheduler.add_noise(latents, noise, t)\n        unet_output = self.unet(latents_noisy, t, up_ft_indices, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs)\n        return unet_output\n\n\nclass SDFeaturizer:\n    def __init__(self, sd_id='pretrained_models/stable-diffusion-v1-4', weight_dtype=torch.float32):\n        unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder='unet', variant=\"fp16\").to(weight_dtype)\n        onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None, variant=\"fp16\").to(weight_dtype)\n        onestep_pipe.vae.decoder = None\n        onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder='scheduler')\n        gc.collect()\n        onestep_pipe = onestep_pipe.to('cuda')\n        onestep_pipe.enable_attention_slicing()\n        self.pipe = onestep_pipe\n\n        null_prompt = ''\n        self.null_prompt_embeds = self.pipe.encode_prompt(\n            prompt=null_prompt,\n            device='cuda',\n            num_images_per_prompt=1,\n            do_classifier_free_guidance=False)[0] # [1, 77, dim]\n        \n    @torch.no_grad()\n    def forward(self,\n                img_tensor,\n                prompt,\n                t=[261,0],\n                up_ft_index=[1,2],\n                ensemble_size=8):\n        '''\n        Args:\n            img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W]\n            prompt: the prompt to use, a string\n            t: the timesteps to use, a list of two ints in the range [0, 1000)\n            up_ft_index: which upsampling blocks of the U-Net to extract features from, a list of indices chosen from [0, 1, 2, 3]\n            ensemble_size: the number of repeated images used in the batch to extract features\n        Return:\n            unet_ft1, unet_ft2: a pair of torch tensors, each in the 
shape of [1, c, h, w]\n        '''\n        img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w\n        prompt_embeds = self.pipe.encode_prompt(\n            prompt=prompt,\n            device='cuda',\n            num_images_per_prompt=1,\n            do_classifier_free_guidance=False)[0] # [1, 77, dim]\n        prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1)\n\n        unet_ft1 = self.pipe(\n            img_tensor=img_tensor,\n            t=t[0],\n            up_ft_indices=[up_ft_index[0]],\n            prompt_embeds=prompt_embeds)\n\n        unet_ft1 = unet_ft1['up_ft'][up_ft_index[0]] # ensem, c, h, w\n        unet_ft1 = unet_ft1.mean(0, keepdim=True) # 1,c,h,w\n\n        null_prompt_embeds = self.null_prompt_embeds.repeat(ensemble_size, 1, 1)\n        unet_ft2 = self.pipe(\n            img_tensor=img_tensor,\n            t=t[1],\n            up_ft_indices=[up_ft_index[1]],\n            prompt_embeds=null_prompt_embeds)\n        \n        unet_ft2 = unet_ft2['up_ft'][up_ft_index[1]] # ensem, c, h, w\n        unet_ft2 = unet_ft2.mean(0, keepdim=True) # 1,c,h,w\n\n        return unet_ft1, unet_ft2\n\n\nclass DIFT_Demo:\n    def __init__(self, source_img, source_dift, source_img_size):\n        self.source_dift = source_dift  # NCHW # torch.Size([1, 1280, 28, 48])\n        self.source_img = source_img\n        self.source_img_size = source_img_size\n\n    @torch.no_grad()\n    def query(self, target_img, target_dift, target_img_size, query_point, target_point, visualize=False):\n        num_channel = self.source_dift.size(1)\n        cos = nn.CosineSimilarity(dim=1)\n        source_x, source_y = int(np.round(query_point[1])), int(np.round(query_point[0]))\n\n        src_ft = self.source_dift\n        src_ft = nn.Upsample(size=self.source_img_size, mode='bilinear')(src_ft)\n        src_vec = src_ft[0, :, source_y, source_x].view(1, num_channel, 1, 1)  # 1, C, 1, 1\n\n        tgt_ft = nn.Upsample(size=target_img_size, 
mode='bilinear')(target_dift)\n        cos_map = cos(src_vec, tgt_ft).cpu().numpy()  # N, H, W  (1, 448, 768)\n\n        max_yx = np.unravel_index(cos_map[0].argmax(), cos_map[0].shape)\n        target_x, target_y = int(np.round(target_point[1])), int(np.round(target_point[0]))\n\n        if visualize:\n            heatmap = cos_map[0]\n            heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap))\n\n            cmap = plt.get_cmap('viridis')\n            heatmap_color = (cmap(heatmap) * 255)[..., :3].astype(np.uint8)\n\n            alpha, radius, color = 0.5, 3, (0, 255, 0)\n            blended_image = Image.blend(target_img, Image.fromarray(heatmap_color), alpha=alpha)\n            draw = ImageDraw.Draw(blended_image)\n            draw.ellipse((max_yx[1] - radius, max_yx[0] - radius, max_yx[1] + radius, max_yx[0] + radius), fill=color)\n            draw.ellipse((target_x - radius, target_y - radius, target_x + radius, target_y + radius), fill=color)\n        else:\n            blended_image = None\n        dift_feat, confidence = tgt_ft[0, :, target_y, target_x], cos_map[0, target_y, target_x]\n        return dift_feat, confidence, blended_image\n"
  },
  {
    "path": "mimicmotion/utils/flow_utils.py",
    "content": "from PIL import Image, ImageOps\nimport scipy.ndimage as ndimage\nimport cv2\nimport random\nimport numpy as np\nfrom scipy.ndimage import maximum_filter\nfrom scipy import signal\ncv2.ocl.setUseOpenCL(False)\nimport torch\nimport torch.nn as nn\nclass ForwardWarp(nn.Module):\n    \"\"\"Differentiable forward (splatting) warp with Gaussian corner weights.\"\"\"\n\n    def __init__(\n        self,\n    ):\n        super(ForwardWarp, self).__init__()\n\n    def forward(self, img, flo):\n        \"\"\"\n        -img: image (N, C, H, W)\n        -flo: optical flow (N, 2, H, W)\n        elements of flo are in [0, H] and [0, W] for dy and dx respectively\n\n        \"\"\"\n\n        # (x1, y1)\t\t(x1, y2)\n        # +---------------+\n        # |\t\t\t\t  |\n        # |\to(x, y) \t  |\n        # |\t\t\t\t  |\n        # |\t\t\t\t  |\n        # |\t\t\t\t  |\n        # |\t\t\t\t  |\n        # +---------------+\n        # (x2, y1)\t\t(x2, y2)\n\n        N, C, _, _ = img.size()\n\n        # translate start-point optical flow to end-point optical flow\n        y = flo[:, 0:1, :, :]\n        x = flo[:, 1:2, :, :]\n\n        x = x.repeat(1, C, 1, 1)\n        y = y.repeat(1, C, 1, 1)\n\n        # Four points of the square: (x1, y1), (x1, y2), (x2, y1), (x2, y2)\n        x1 = torch.floor(x)\n        x2 = x1 + 1\n        y1 = torch.floor(y)\n        y2 = y1 + 1\n\n        # firstly, get gaussian weights\n        w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2)\n\n        # secondly, sample each weighted corner\n        img11, o11 = self.sample_one(img, x1, y1, w11)\n        img12, o12 = self.sample_one(img, x1, y2, w12)\n        img21, o21 = self.sample_one(img, x2, y1, w21)\n        img22, o22 = self.sample_one(img, x2, y2, w22)\n\n        imgw = img11 + img12 + img21 + img22\n        o = o11 + o12 + o21 + o22\n\n        return imgw, o\n\n    def get_gaussian_weights(self, x, y, x1, x2, y1, y2):\n        w11 = torch.exp(-((x - x1) ** 2 + (y - y1) ** 2))\n        w12 = torch.exp(-((x - x1) ** 2 + (y - y2) ** 
2))\n        w21 = torch.exp(-((x - x2) ** 2 + (y - y1) ** 2))\n        w22 = torch.exp(-((x - x2) ** 2 + (y - y2) ** 2))\n\n        return w11, w12, w21, w22\n\n    def sample_one(self, img, shiftx, shifty, weight):\n        \"\"\"\n        Input:\n                -img (N, C, H, W)\n                -shiftx, shifty (N, c, H, W)\n        \"\"\"\n\n        N, C, H, W = img.size()\n\n        # flatten all (all restored as Tensors)\n        flat_shiftx = shiftx.view(-1)\n        flat_shifty = shifty.view(-1)\n        flat_basex = (\n            torch.arange(0, H, requires_grad=False)\n            .view(-1, 1)[None, None]\n            .cuda()\n            .long()\n            .repeat(N, C, 1, W)\n            .view(-1)\n        )\n        flat_basey = (\n            torch.arange(0, W, requires_grad=False)\n            .view(1, -1)[None, None]\n            .cuda()\n            .long()\n            .repeat(N, C, H, 1)\n            .view(-1)\n        )\n        flat_weight = weight.view(-1)\n        flat_img = img.view(-1)\n\n        # The corresponding positions in I1\n        idxn = (\n            torch.arange(0, N, requires_grad=False)\n            .view(N, 1, 1, 1)\n            .long()\n            .cuda()\n            .repeat(1, C, H, W)\n            .view(-1)\n        )\n        idxc = (\n            torch.arange(0, C, requires_grad=False)\n            .view(1, C, 1, 1)\n            .long()\n            .cuda()\n            .repeat(N, 1, H, W)\n            .view(-1)\n        )\n        # ttype = flat_basex.type()\n        idxx = flat_shiftx.long() + flat_basex\n        idxy = flat_shifty.long() + flat_basey\n\n        # recording the inside part the shifted\n        mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W)\n\n        # Mask off points out of boundaries\n        ids = idxn * C * H * W + idxc * H * W + idxx * W + idxy\n        ids_mask = torch.masked_select(ids, mask).clone().cuda()\n\n        # (zero part - gt) -> difference\n        # difference back 
propagate -> No influence! Whether we do need mask? mask?\n        # put (add) them together\n        # Note here! accmulate fla must be true for proper bp\n        img_warp = torch.zeros(\n            [\n                N * C * H * W,\n            ]\n        ).cuda()\n        img_warp.put_(\n            ids_mask, torch.masked_select(flat_img * flat_weight, mask), accumulate=True\n        )\n\n        one_warp = torch.zeros(\n            [\n                N * C * H * W,\n            ]\n        ).cuda()\n        one_warp.put_(ids_mask, torch.masked_select(flat_weight, mask), accumulate=True)\n\n        return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W)\n\ndef get_edge(data, blur=False):\n    if blur:\n        data = cv2.GaussianBlur(data, (3, 3), 1.)\n    sobel = np.array([[1,0,-1],[2,0,-2],[1,0,-1]]).astype(np.float32)\n    ch_edges = []\n    for k in range(data.shape[2]):\n        edgex = signal.convolve2d(data[:,:,k], sobel, boundary='symm', mode='same')\n        edgey = signal.convolve2d(data[:,:,k], sobel.T, boundary='symm', mode='same')\n        ch_edges.append(np.sqrt(edgex**2 + edgey**2))\n    return sum(ch_edges)\n\ndef get_max(score, bbox):\n    u = max(0, bbox[0])\n    d = min(score.shape[0], bbox[1])\n    l = max(0, bbox[2])\n    r = min(score.shape[1], bbox[3])\n    return score[u:d,l:r].max()\n\ndef nms(score, ks):\n    assert ks % 2 == 1\n    ret_score = score.copy()\n    maxpool = maximum_filter(score, footprint=np.ones((ks, ks)))\n    ret_score[score < maxpool] = 0.\n    return ret_score\n\ndef image_flow_crop(img1, img2, flow, crop_size, phase):\n    assert len(crop_size) == 2\n    pad_h = max(crop_size[0] - img1.height, 0)\n    pad_w = max(crop_size[1] - img1.width, 0)\n    pad_h_half = int(pad_h / 2)\n    pad_w_half = int(pad_w / 2)\n    if pad_h > 0 or pad_w > 0:\n        flow_expand = np.zeros((img1.height + pad_h, img1.width + pad_w, 2), dtype=np.float32)\n        flow_expand[pad_h_half:pad_h_half+img1.height, 
pad_w_half:pad_w_half+img1.width, :] = flow\n        flow = flow_expand\n        border = (pad_w_half, pad_h_half, pad_w - pad_w_half, pad_h - pad_h_half)\n        img1 = ImageOps.expand(img1, border=border, fill=(0,0,0))\n        img2 = ImageOps.expand(img2, border=border, fill=(0,0,0))\n    if phase == 'train':\n        hoff = int(np.random.rand() * (img1.height - crop_size[0]))\n        woff = int(np.random.rand() * (img1.width - crop_size[1]))\n    else:\n        hoff = (img1.height - crop_size[0]) // 2\n        woff = (img1.width - crop_size[1]) // 2\n\n    img1 = img1.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))\n    img2 = img2.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))\n    flow = flow[hoff:hoff+crop_size[0], woff:woff+crop_size[1], :]\n    offset = (hoff, woff)\n    return img1, img2, flow, offset\n\ndef image_crop(img, crop_size):\n    pad_h = max(crop_size[0] - img.height, 0)\n    pad_w = max(crop_size[1] - img.width, 0)\n    pad_h_half = int(pad_h / 2)\n    pad_w_half = int(pad_w / 2)\n    if pad_h > 0 or pad_w > 0:\n        border = (pad_w_half, pad_h_half, pad_w - pad_w_half, pad_h - pad_h_half)\n        img = ImageOps.expand(img, border=border, fill=(0,0,0))\n    hoff = (img.height - crop_size[0]) // 2\n    woff = (img.width - crop_size[1]) // 2\n    return img.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0])), (pad_w_half, pad_h_half)\n\ndef image_flow_resize(img1, img2, flow, short_size=None, long_size=None):\n    assert (short_size is None) ^ (long_size is None)\n    w, h = img1.width, img1.height\n    if short_size is not None:\n        if w < h:\n            neww = short_size\n            newh = int(short_size / float(w) * h)\n        else:\n            neww = int(short_size / float(h) * w)\n            newh = short_size\n    else:\n        if w < h:\n            neww = int(long_size / float(h) * w)\n            newh = long_size\n        else:\n            neww = long_size\n            newh = int(long_size / 
float(w) * h)\n    img1 = img1.resize((neww, newh), Image.BICUBIC)\n    img2 = img2.resize((neww, newh), Image.BICUBIC)\n    ratio = float(newh) / h\n    flow = cv2.resize(flow.copy(), (neww, newh), interpolation=cv2.INTER_LINEAR) * ratio\n    return img1, img2, flow, ratio\n\ndef image_resize(img, short_size=None, long_size=None):\n    assert (short_size is None) ^ (long_size is None)\n    w, h = img.width, img.height\n    if short_size is not None:\n        if w < h:\n            neww = short_size\n            newh = int(short_size / float(w) * h)\n        else:\n            neww = int(short_size / float(h) * w)\n            newh = short_size\n    else:\n        if w < h:\n            neww = int(long_size / float(h) * w)\n            newh = long_size\n        else:\n            neww = long_size\n            newh = int(long_size / float(w) * h)\n    img = img.resize((neww, newh), Image.BICUBIC)\n    return img, [w, h]\n\n\ndef image_pose_crop(img, posemap, crop_size, scale):\n    assert len(crop_size) == 2\n    assert crop_size[0] <= img.height\n    assert crop_size[1] <= img.width\n    hoff = (img.height - crop_size[0]) // 2\n    woff = (img.width - crop_size[1]) // 2\n    img = img.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))\n    posemap = posemap[hoff//scale:hoff//scale+crop_size[0]//scale, woff//scale:woff//scale+crop_size[1]//scale,:]\n    return img, posemap\n\ndef neighbor_elim(ph, pw, d):\n    valid = np.ones((len(ph))).astype(np.int64)\n    h_dist = np.fabs(np.tile(ph[:,np.newaxis], [1,len(ph)]) - np.tile(ph.T[np.newaxis,:], [len(ph),1]))\n    w_dist = np.fabs(np.tile(pw[:,np.newaxis], [1,len(pw)]) - np.tile(pw.T[np.newaxis,:], [len(pw),1]))\n    idx1, idx2 = np.where((h_dist < d) & (w_dist < d))\n    for i,j in zip(idx1, idx2):\n        if valid[i] and valid[j] and i != j:\n            if np.random.rand() > 0.5:\n                valid[i] = 0\n            else:\n                valid[j] = 0\n    valid_idx = np.where(valid==1)\n    return 
ph[valid_idx], pw[valid_idx]\n\ndef remove_border(mask):\n    mask[0,:] = 0\n    mask[:,0] = 0\n    mask[mask.shape[0]-1,:] = 0\n    mask[:,mask.shape[1]-1] = 0\n\ndef flow_sampler(flow, strategy=['grid'], bg_ratio=1./6400, nms_ks=15, max_num_guide=-1, guidepoint=None):\n    assert bg_ratio >= 0 and bg_ratio <= 1, \"sampling ratio must be in [0, 1]\"\n    for s in strategy:\n        assert s in ['grid', 'uniform', 'gradnms', 'watershed', 'single', 'full', 'specified'], \"No such strategy: {}\".format(s)\n    h = flow.shape[0]\n    w = flow.shape[1]\n    ds = max(1, max(h, w) // 400) # reduce computation\n\n    if 'full' in strategy:\n        sparse = flow.copy()\n        mask = np.ones(flow.shape, dtype=np.int64)\n        return sparse, mask\n\n    pts_h = []\n    pts_w = []\n    if 'grid' in strategy:\n        stride = int(np.sqrt(1./bg_ratio))\n        mesh_start_h = int((h - h // stride * stride) / 2)\n        mesh_start_w = int((w - w // stride * stride) / 2)\n        mesh = np.meshgrid(np.arange(mesh_start_h, h, stride), np.arange(mesh_start_w, w, stride))\n        pts_h.append(mesh[0].flat)\n        pts_w.append(mesh[1].flat)\n    if 'uniform' in strategy:\n        pts_h.append(np.random.randint(0, h, int(bg_ratio * h * w)))\n        pts_w.append(np.random.randint(0, w, int(bg_ratio * h * w)))\n    if \"gradnms\" in strategy:\n        ks = w // ds // 20\n        edge = get_edge(flow[::ds,::ds,:])\n        kernel = np.ones((ks, ks), dtype=np.float32) / (ks * ks)\n        subkernel = np.ones((ks//2, ks//2), dtype=np.float32) / (ks//2 * ks//2)\n        score = signal.convolve2d(edge, kernel, boundary='symm', mode='same')\n        subscore = signal.convolve2d(edge, subkernel, boundary='symm', mode='same')\n        score = score / score.max() - subscore / subscore.max()\n        nms_res = nms(score, nms_ks)\n        pth, ptw = np.where(nms_res > 0.1)\n        pts_h.append(pth * ds)\n        pts_w.append(ptw * ds)\n    if \"watershed\" in strategy:\n  
      edge = get_edge(flow[::ds,::ds,:])\n        edge /= max(edge.max(), 0.01)\n        edge = (edge > 0.1).astype(np.float32)\n        watershed = ndimage.distance_transform_edt(1-edge)\n        nms_res = nms(watershed, nms_ks)\n        remove_border(nms_res)\n        pth, ptw = np.where(nms_res > 0)\n        pth, ptw = neighbor_elim(pth, ptw, (nms_ks-1)/2)\n        pts_h.append(pth * ds)\n        pts_w.append(ptw * ds)\n    if \"single\" in strategy:\n        pth, ptw = np.where((flow[:,:,0] != 0) | (flow[:,:,1] != 0))\n        randidx = np.random.randint(len(pth))\n        pts_h.append(pth[randidx:randidx+1])\n        pts_w.append(ptw[randidx:randidx+1])\n    if 'specified' in strategy:\n        assert guidepoint is not None, \"if using \\\"specified\\\", switch \\\"with_info\\\" on.\"\n        pts_h.append(guidepoint[:,1])\n        pts_w.append(guidepoint[:,0])\n\n    pts_h = np.concatenate(pts_h)\n    pts_w = np.concatenate(pts_w)\n\n    if max_num_guide == -1:\n        max_num_guide = np.inf\n\n    randsel = np.random.permutation(len(pts_h))[:len(pts_h)]\n    selidx = randsel[np.arange(min(max_num_guide, len(randsel)))]\n    pts_h = pts_h[selidx]\n    pts_w = pts_w[selidx]\n\n    sparse = np.zeros(flow.shape, dtype=flow.dtype)\n    mask = np.zeros(flow.shape, dtype=np.int64)\n    \n    sparse[:, :, 0][(pts_h, pts_w)] = flow[:, :, 0][(pts_h, pts_w)]\n    sparse[:, :, 1][(pts_h, pts_w)] = flow[:, :, 1][(pts_h, pts_w)]\n    \n    mask[:,:,0][(pts_h, pts_w)] = 1\n    mask[:,:,1][(pts_h, pts_w)] = 1\n    return sparse, mask\n\ndef image_flow_aug(img1, img2, flow, flip_horizon=True):\n    if flip_horizon:\n        if random.random() < 0.5:\n            img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)\n            img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)\n            flow = flow[:,::-1,:].copy()\n            flow[:,:,0] = -flow[:,:,0]\n    return img1, img2, flow\n\ndef flow_aug(flow, reverse=True, scale=True, rotate=True):\n    if reverse:\n        if 
random.random() < 0.5:\n            flow = -flow\n    if scale:\n        rand_scale = random.uniform(0.5, 2.0)\n        flow = flow * rand_scale\n    if rotate and random.random() < 0.5:\n        length = np.sqrt(np.square(flow[:,:,0]) + np.square(flow[:,:,1]))\n        # arctan2 keeps the correct quadrant and is safe when flow[:,:,0] == 0\n        alpha = np.arctan2(flow[:,:,1], flow[:,:,0])\n        theta = random.uniform(0, np.pi*2)\n        flow[:,:,0] = length * np.cos(alpha + theta)\n        flow[:,:,1] = length * np.sin(alpha + theta)\n    return flow\n\ndef draw_gaussian(img, pt, sigma, type='Gaussian'):\n    # Check that any part of the gaussian is in-bounds\n    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]\n    br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]\n    if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or\n            br[0] < 0 or br[1] < 0):\n        # If not, just return the image as is\n        return img\n\n    # Generate gaussian\n    size = 6 * sigma + 1\n    x = np.arange(0, size, 1, float)\n    y = x[:, np.newaxis]\n    x0 = y0 = size // 2\n    # The gaussian is not normalized, we want the center value to equal 1\n    if type == 'Gaussian':\n        g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))\n    elif type == 'Cauchy':\n        g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)\n\n    # Usable gaussian range\n    g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]\n    g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]\n    # Image range\n    img_x = max(0, ul[0]), min(br[0], img.shape[1])\n    img_y = max(0, ul[1]), min(br[1], img.shape[0])\n\n    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]\n    return img\n\n\n"
  },
  {
    "path": "mimicmotion/utils/geglu_patch.py",
    "content": "import diffusers.models.activations\n\n\ndef patch_geglu_inplace():\n    \"\"\"Patch GEGLU with inplace multiplication to save GPU memory.\"\"\"\n    def forward(self, hidden_states):\n        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)\n        return hidden_states.mul_(self.gelu(gate))\n    diffusers.models.activations.GEGLU.forward = forward\n"
  },
  {
    "path": "mimicmotion/utils/loader.py",
    "content": "import logging\n\nimport torch\nimport torch.utils.checkpoint\nfrom diffusers.models import AutoencoderKLTemporalDecoder\nfrom diffusers.schedulers import EulerDiscreteScheduler\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\nfrom ..modules.unet import UNetSpatioTemporalConditionModel\nfrom ..modules.pose_net import PoseNet\nfrom ..modules.controlnet import ControlNetSVDModel\nfrom ..pipelines.pipeline_mimicmotion import MimicMotionPipeline\nfrom ..pipelines.pipeline_ctrl import Ctrl_Pipeline\n\n\nlogger = logging.getLogger(__name__)\n\nclass MimicMotionModel(torch.nn.Module):\n    def __init__(self, base_model_path):\n        \"\"\"construct base model components and load pretrained SVD weights for all components except the pose-net\n        Args:\n            base_model_path (str): pretrained svd model path\n        \"\"\"\n        super().__init__()\n        self.unet = UNetSpatioTemporalConditionModel.from_config(\n            UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder=\"unet\"))\n        self.vae = AutoencoderKLTemporalDecoder.from_pretrained(\n            base_model_path, subfolder=\"vae\", torch_dtype=torch.float16, variant=\"fp16\")\n        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n            base_model_path, subfolder=\"image_encoder\", torch_dtype=torch.float16, variant=\"fp16\")\n        self.noise_scheduler = EulerDiscreteScheduler.from_pretrained(\n            base_model_path, subfolder=\"scheduler\")\n        self.feature_extractor = CLIPImageProcessor.from_pretrained(\n            base_model_path, subfolder=\"feature_extractor\")\n        # pose_net\n        self.pose_net = PoseNet(noise_latent_channels=self.unet.config.block_out_channels[0])\n\ndef create_ctrl_pipeline(infer_config, device):\n    \"\"\"create ControlNet-conditioned mimicmotion pipeline and load pretrained weights\n\n    Args:\n        infer_config (str): \n        device (str or torch.device): \"cpu\" or \"cuda:{device_id}\"\n 
   \"\"\"\n    mimicmotion_models = MimicMotionModel(infer_config.base_model_path)\n    mimicmotion_models.load_state_dict(torch.load(infer_config.ckpt_path, map_location=\"cpu\"), strict=False)\n    controlnet = ControlNetSVDModel.from_unet(mimicmotion_models.unet).to(device=mimicmotion_models.unet.device)\n    controlnet.load_state_dict(torch.load(infer_config.controlnet_path, map_location=\"cpu\"),strict=False)\n    pipeline = Ctrl_Pipeline(\n        vae=mimicmotion_models.vae, \n        image_encoder=mimicmotion_models.image_encoder, \n        unet=mimicmotion_models.unet, \n        controlnet=controlnet,\n        scheduler=mimicmotion_models.noise_scheduler,\n        feature_extractor=mimicmotion_models.feature_extractor, \n        pose_net=mimicmotion_models.pose_net\n    )\n    return pipeline\n\ndef create_pipeline(infer_config, device):\n    \"\"\"create mimicmotion pipeline and load pretrained weight\n\n    Args:\n        infer_config (str): \n        device (str or torch.device): \"cpu\" or \"cuda:{device_id}\"\n    \"\"\"\n    mimicmotion_models = MimicMotionModel(infer_config.base_model_path)\n    # .to(device=device).eval()\n    mimicmotion_models.load_state_dict(torch.load(infer_config.ckpt_path, map_location=\"cpu\"), strict=False)\n    pipeline = MimicMotionPipeline(\n        vae=mimicmotion_models.vae, \n        image_encoder=mimicmotion_models.image_encoder, \n        unet=mimicmotion_models.unet, \n        scheduler=mimicmotion_models.noise_scheduler,\n        feature_extractor=mimicmotion_models.feature_extractor, \n        pose_net=mimicmotion_models.pose_net\n    )\n    return pipeline\n\n"
  },
  {
    "path": "mimicmotion/utils/utils.py",
    "content": "import logging\nfrom pathlib import Path\nimport av\nfrom PIL import Image\nimport os\nfrom scipy.interpolate import PchipInterpolator\nimport numpy as np\nimport pdb\nimport torch\nimport torch.nn.functional as F\nfrom torchvision.io import write_video\n\nlogger = logging.getLogger(__name__)\n\n@torch.no_grad()\ndef get_cmp_flow(cmp, frames, sparse_optical_flow, mask):\n    '''\n        frames: [b, 13, 3, 384, 384] (0, 1) tensor\n        sparse_optical_flow: [b, 13, 2, 384, 384] (-384, 384) tensor\n        mask: [b, 13, 2, 384, 384] {0, 1} tensor\n    '''\n    # print(frames.shape)\n    dtype = frames.dtype\n    b, t, c, h, w = sparse_optical_flow.shape\n    assert h == 384 and w == 384\n    frames = frames.flatten(0, 1)  # [b*13, 3, 256, 256]\n    sparse_optical_flow = sparse_optical_flow.flatten(0, 1)  # [b*13, 2, 256, 256]\n    mask = mask.flatten(0, 1)  # [b*13, 2, 256, 256]\n\n    # print(frames.shape)\n    # print(sparse_optical_flow.shape)\n    # print(mask.shape)\n\n    # assert False\n\n    cmp_flow = []\n    for i in range(b*t):\n        tmp_flow = cmp.run(frames[i:i+1].float(), sparse_optical_flow[i:i+1].float(), mask[i:i+1].float())  # [b*13, 2, 256, 256]\n        cmp_flow.append(tmp_flow)\n    cmp_flow = torch.cat(cmp_flow, dim=0)\n    cmp_flow = cmp_flow.reshape(b, t, 2, h, w)\n\n    return cmp_flow.to(dtype=dtype)\n\n\n\ndef sample_optical_flow(A, B, h, w):\n    b, l, k, _ = A.shape\n\n    sparse_optical_flow = torch.zeros((b, l, h, w, 2), dtype=B.dtype, device=B.device)\n    mask = torch.zeros((b, l, h, w), dtype=torch.uint8, device=B.device)\n\n    x_coords = A[..., 0].long()\n    y_coords = A[..., 1].long()\n\n    x_coords = torch.clip(x_coords, 0, h - 1)\n    y_coords = torch.clip(y_coords, 0, w - 1)\n\n    b_idx = torch.arange(b)[:, None, None].repeat(1, l, k)\n    l_idx = torch.arange(l)[None, :, None].repeat(b, 1, k)\n\n    sparse_optical_flow[b_idx, l_idx, x_coords, y_coords] = B\n\n    mask[b_idx, l_idx, x_coords, y_coords] 
= 1\n\n    mask = mask.unsqueeze(-1).repeat(1, 1, 1, 1, 2)\n\n    return sparse_optical_flow, mask\n\n\n@torch.no_grad()\ndef get_sparse_flow(poses, h, w, t):\n\n    poses = torch.flip(poses, dims=[3])\n\n    pose_flow = (poses - poses[:, 0:1].repeat(1, t, 1, 1))[:, 1:]  # forward optical flow relative to the first frame\n    according_poses = poses[:, 0:1].repeat(1, t - 1, 1, 1)\n    \n    pose_flow = torch.flip(pose_flow, dims=[3])\n\n    b, t, K, _ = pose_flow.shape\n\n    sparse_optical_flow, mask = sample_optical_flow(according_poses, pose_flow, h, w)\n\n    return sparse_optical_flow.permute(0, 1, 4, 2, 3), mask.permute(0, 1, 4, 2, 3)\n\ndef sample_inputs_flow(first_frame, poses, poses_subset):\n\n    pb, pc, ph, pw = first_frame.shape\n    \n    # print(poses.shape)\n\n    pl = poses.shape[1]\n\n    sparse_optical_flow, mask = get_sparse_flow(poses, ph, pw, pl)\n\n    if ph != 384 or pw != 384:\n\n        first_frame_384 = F.interpolate(first_frame, (384, 384))  # [3, 384, 384]\n\n        poses_384 = torch.zeros_like(poses)\n        poses_384[:, :, :, 0] = poses[:, :, :, 0] / pw * 384\n        poses_384[:, :, :, 1] = poses[:, :, :, 1] / ph * 384\n\n        sparse_optical_flow_384, mask_384 = get_sparse_flow(poses_384, 384, 384, pl)\n    \n    else:\n        first_frame_384, poses_384 = first_frame, poses\n        sparse_optical_flow_384, mask_384 = sparse_optical_flow, mask\n    \n    controlnet_image = first_frame\n\n    return controlnet_image, sparse_optical_flow, mask, first_frame_384, sparse_optical_flow_384, mask_384\n\ndef pose2track(points_list, height, width):\n    track_points = np.zeros((18, len(points_list), 2)) # 18 x f x 2\n    track_points_subsets = np.zeros((18, len(points_list), 1)) # 18 x f x 1\n    for f in range(len(points_list)):\n        candidates, subsets, scores = points_list[f]['candidate'], points_list[f]['subset'][0], points_list[f]['score']\n        for i in range(18):\n            if subsets[i] == -1:\n                track_points_subsets[i][f] = -1\n            else:\n   
             # track_points[i][f][0] = candidates[i][0]\n                # track_points[i][f][1] = candidates[i][1]\n                track_points[i][f][0] = max(min(candidates[i][0] * width, width-1), 0)\n                track_points[i][f][1] = max(min(candidates[i][1] * height, height-1), 0)\n                track_points_subsets[i][f] = i\n    \n    return track_points, track_points_subsets\n\ndef pose2track_batch(points_list, height, width, batch_size):\n    track_points = np.zeros((batch_size, 18, len(points_list), 2)) # 18 x f x 2\n    track_points_subsets = np.zeros((batch_size, 18, len(points_list), 1)) # 18 x f x 2\n    for batch_idx in range(batch_size):\n        for f in range(len(points_list)):\n            candidates, subsets, scores = points_list[f]['candidate'][batch_idx], points_list[f]['subset'][batch_idx][0], points_list[f]['score'][batch_idx]\n            for i in range(18):\n                if subsets[i] == -1:\n                    track_points_subsets[batch_idx][i][f] = -1\n                else:\n                    # track_points[i][f][0] = candidates[i][0]\n                    # track_points[i][f][1] = candidates[i][1]\n                    track_points[batch_idx][i][f][0] = max(min(candidates[i][0] * width, width-1), 0)\n                    track_points[batch_idx][i][f][1] = max(min(candidates[i][1] * height, height-1), 0)\n                    track_points_subsets[batch_idx][i][f] = i\n    \n    return track_points, track_points_subsets\n\ndef points_to_flows_batch(points_list, model_length, height, width, batch_size):\n\n    track_points, track_points_subsets = pose2track_batch(points_list, height, width, batch_size)\n    # model_length = track_points.shape[1]\n    input_drag = np.zeros((batch_size, model_length - 1, height, width, 2))\n    for batch_idx in range(batch_size):\n        for splited_track, points_subset in zip(track_points[batch_idx], track_points_subsets[batch_idx]):\n            if len(splited_track) == 1: # stationary point\n  
              displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])\n                splited_track = tuple([splited_track[0], displacement_point])\n            # interpolate the track\n            # splited_track = interpolate_trajectory(splited_track, model_length)\n            # splited_track = splited_track[:model_length]\n            if len(splited_track) < model_length:\n                splited_track = splited_track + [splited_track[-1]] * (model_length -len(splited_track))\n            for i in range(model_length - 1):\n                if points_subset[i]!=-1:\n                    start_point = splited_track[i]\n                    end_point = splited_track[i+1]\n                    input_drag[batch_idx][i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]\n                    input_drag[batch_idx][i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]\n    return input_drag\n\ndef points_to_flows(points_list, model_length, height, width):\n    \n    track_points, track_points_subsets = pose2track(points_list, height, width)\n    # model_length = track_points.shape[1]\n    input_drag = np.zeros((model_length - 1, height, width, 2))\n\n    for splited_track, points_subset in zip(track_points, track_points_subsets):\n        if len(splited_track) == 1: # stationary point\n            displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])\n            splited_track = tuple([splited_track[0], displacement_point])\n        # interpolate the track\n        # splited_track = interpolate_trajectory(splited_track, model_length)\n        # splited_track = splited_track[:model_length]\n        if len(splited_track) < model_length:\n            splited_track = splited_track + [splited_track[-1]] * (model_length -len(splited_track))\n        for i in range(model_length - 1):\n            if points_subset[i]!=-1:\n                start_point = splited_track[i]\n             
   end_point = splited_track[i+1]\n                input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]\n                input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]\n    return input_drag\n\ndef interpolate_trajectory(points, n_points):\n    x = [point[0] for point in points]\n    y = [point[1] for point in points]\n\n    t = np.linspace(0, 1, len(points))\n\n    fx = PchipInterpolator(t, x)\n    fy = PchipInterpolator(t, y)\n\n    new_t = np.linspace(0, 1, n_points)\n\n    new_x = fx(new_t)\n    new_y = fy(new_t)\n    new_points = list(zip(new_x, new_y))\n\n    return new_points\n\n\ndef bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):\n    \"\"\"Generate a bivariate isotropic or anisotropic Gaussian kernel.\n    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.\n    Args:\n        kernel_size (int):\n        sig_x (float):\n        sig_y (float):\n        theta (float): Radian measurement.\n        grid (ndarray, optional): generated by :func:`mesh_grid`,\n            with the shape (K, K, 2), K is the kernel size. 
Default: None\n        isotropic (bool):\n    Returns:\n        kernel (ndarray): normalized kernel.\n    \"\"\"\n    if grid is None:\n        grid, _, _ = mesh_grid(kernel_size)\n    if isotropic:\n        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])\n    else:\n        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)\n    kernel = pdf2(sigma_matrix, grid)\n    kernel = kernel / np.sum(kernel)\n    return kernel\n\ndef mesh_grid(kernel_size):\n    \"\"\"Generate the mesh grid, centering at zero.\n    Args:\n        kernel_size (int):\n    Returns:\n        xy (ndarray): with the shape (kernel_size, kernel_size, 2)\n        xx (ndarray): with the shape (kernel_size, kernel_size)\n        yy (ndarray): with the shape (kernel_size, kernel_size)\n    \"\"\"\n    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)\n    xx, yy = np.meshgrid(ax, ax)\n    xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,\n                                                                           1))).reshape(kernel_size, kernel_size, 2)\n    return xy, xx, yy\n\n\ndef pdf2(sigma_matrix, grid):\n    \"\"\"Calculate PDF of the bivariate Gaussian distribution.\n    Args:\n        sigma_matrix (ndarray): with the shape (2, 2)\n        grid (ndarray): generated by :func:`mesh_grid`,\n            with the shape (K, K, 2), K is the kernel size.\n    Returns:\n        kernel (ndarray): un-normalized kernel.\n    \"\"\"\n    inverse_sigma = np.linalg.inv(sigma_matrix)\n    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))\n    return kernel\n\n\ndef sigma_matrix2(sig_x, sig_y, theta):\n    \"\"\"Calculate the rotated sigma matrix (two dimensional matrix).\n    Args:\n        sig_x (float):\n        sig_y (float):\n        theta (float): Radian measurement.\n    Returns:\n        ndarray: Rotated sigma matrix.\n    \"\"\"\n    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])\n    u_matrix = 
np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])\n    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))\n\n\ndef save_to_mp4(frames, save_path, fps=7):\n    frames = frames.permute((0, 2, 3, 1))  # (f, c, h, w) to (f, h, w, c)\n    Path(save_path).parent.mkdir(parents=True, exist_ok=True)\n    write_video(save_path, frames, fps=fps)\n\ndef read_frames(video_path):\n    container = av.open(video_path)\n\n    video_stream = next(s for s in container.streams if s.type == \"video\")\n    frames = []\n    for packet in container.demux(video_stream):\n        for frame in packet.decode():\n            image = Image.frombytes(\n                \"RGB\",\n                (frame.width, frame.height),\n                frame.to_rgb().to_ndarray(),\n            )\n            frames.append(image)\n\n    return frames\n\ndef get_fps(video_path):\n    container = av.open(video_path)\n    video_stream = next(s for s in container.streams if s.type == \"video\")\n    fps = video_stream.average_rate\n    container.close()\n    return fps\n\n\ndef save_videos_from_pil(pil_images, path, fps=8):\n    import av\n\n    save_fmt = Path(path).suffix\n    os.makedirs(os.path.dirname(path), exist_ok=True)\n    width, height = pil_images[0].size\n\n    if save_fmt == \".mp4\":\n        codec = \"libx264\"\n        container = av.open(path, \"w\")\n        stream = container.add_stream(codec, rate=fps)\n\n        stream.width = width\n        stream.height = height\n        stream.pix_fmt = 'yuv420p'\n        stream.bit_rate = 10000000   \n        stream.options[\"crf\"] = \"18\"\n\n        for pil_image in pil_images:\n            # pil_image = Image.fromarray(image_arr).convert(\"RGB\")\n            av_frame = av.VideoFrame.from_image(pil_image)\n            container.mux(stream.encode(av_frame))\n        container.mux(stream.encode())\n        container.close()\n\n    elif save_fmt == \".gif\":\n        pil_images[0].save(\n            fp=path,\n            
format=\"GIF\",\n            append_images=pil_images[1:],\n            save_all=True,\n            duration=(1 / fps * 1000),\n            loop=0,\n        )\n    else:\n        raise ValueError(\"Unsupported file type. Use .mp4 or .gif.\")\n\n"
  },
  {
    "path": "mimicmotion/utils/visualizer.py",
    "content": "\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\nimport os\nimport numpy as np\nimport imageio\nimport torch\n\nfrom matplotlib import cm\nimport torch.nn.functional as F\nimport torchvision.transforms as transforms\nimport matplotlib.pyplot as plt\nfrom PIL import Image, ImageDraw\n\n\ndef read_video_from_path(path):\n    try:\n        reader = imageio.get_reader(path)\n    except Exception as e:\n        print(\"Error opening video file: \", e)\n        return None\n    frames = []\n    for i, im in enumerate(reader):\n        frames.append(np.array(im))\n    return np.stack(frames)\n\n\ndef draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):\n    # Create a draw object\n    draw = ImageDraw.Draw(rgb)\n    # Calculate the bounding box of the circle\n    left_up_point = (coord[0] - radius, coord[1] - radius)\n    right_down_point = (coord[0] + radius, coord[1] + radius)\n    # Draw the circle\n    draw.ellipse(\n        [left_up_point, right_down_point],\n        fill=tuple(color) if visible else None,\n        outline=tuple(color),\n    )\n    return rgb\n\n\ndef draw_line(rgb, coord_y, coord_x, color, linewidth):\n    draw = ImageDraw.Draw(rgb)\n    draw.line(\n        (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),\n        fill=tuple(color),\n        width=linewidth,\n    )\n    return rgb\n\n\ndef add_weighted(rgb, alpha, original, beta, gamma):\n    return (rgb * alpha + original * beta + gamma).astype(\"uint8\")\n\n\nclass Visualizer:\n    def __init__(\n        self,\n        save_dir: str = \"./results\",\n        grayscale: bool = False,\n        pad_value: int = 0,\n        fps: int = 8,\n        mode: str = \"rainbow\",  # 'cool', 'optical_flow'\n        linewidth: int = 2,\n        show_first_frame: int = 0,\n        tracks_leave_trace: int = 0,  # -1 for infinite\n   
 ):\n        self.mode = mode\n        self.save_dir = save_dir\n        if mode == \"rainbow\":\n            self.color_map = cm.get_cmap(\"gist_rainbow\")\n        elif mode == \"cool\":\n            self.color_map = cm.get_cmap(mode)\n        self.show_first_frame = show_first_frame\n        self.grayscale = grayscale\n        self.tracks_leave_trace = tracks_leave_trace\n        self.pad_value = pad_value\n        self.linewidth = linewidth\n        self.fps = fps\n\n    def visualize(\n        self,\n        video: torch.Tensor,  # (B,T,C,H,W)\n        tracks: torch.Tensor,  # (B,T,N,2)\n        visibility: torch.Tensor = None,  # (B, T, N, 1) bool\n        gt_tracks: torch.Tensor = None,  # (B,T,N,2)\n        segm_mask: torch.Tensor = None,  # (B,1,H,W)\n        filename: str = \"video\",\n        writer=None,  # tensorboard Summary Writer, used for visualization during training\n        step: int = 0,\n        query_frame: int = 0,\n        save_video: bool = True,\n        compensate_for_camera_motion: bool = False,\n    ):\n        if compensate_for_camera_motion:\n            assert segm_mask is not None\n        if segm_mask is not None:\n            coords = tracks[0, query_frame].round().long()\n            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()\n\n        video = F.pad(\n            video,\n            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),\n            \"constant\",\n            255,\n        )\n        tracks = tracks + self.pad_value\n\n        if self.grayscale:\n            transform = transforms.Grayscale()\n            video = transform(video)\n            video = video.repeat(1, 1, 3, 1, 1)\n\n        res_video = self.draw_tracks_on_video(\n            video=video,\n            tracks=tracks,\n            visibility=visibility,\n            segm_mask=segm_mask,\n            gt_tracks=gt_tracks,\n            query_frame=query_frame,\n            
compensate_for_camera_motion=compensate_for_camera_motion,\n        )\n        if save_video:\n            self.save_video(res_video, filename=filename, writer=writer, step=step)\n        return res_video\n\n    def save_video(self, video, filename, writer=None, step=0):\n        if writer is not None:\n            writer.add_video(\n                filename,\n                video.to(torch.uint8),\n                global_step=step,\n                fps=self.fps,\n            )\n        else:\n            os.makedirs(self.save_dir, exist_ok=True)\n            wide_list = list(video.unbind(1))\n            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]\n\n            # Prepare the video file path\n            save_path = os.path.join(self.save_dir, f\"{filename}.mp4\")\n\n            # Create a writer object\n            video_writer = imageio.get_writer(save_path, fps=self.fps)\n\n            # Write frames to the video file\n            for frame in wide_list[2:-1]:\n                video_writer.append_data(frame)\n\n            video_writer.close()\n\n            # print(f\"Video saved to {save_path}\")\n\n    def draw_tracks_on_video(\n        self,\n        video: torch.Tensor,\n        tracks: torch.Tensor,\n        visibility: torch.Tensor = None,\n        segm_mask: torch.Tensor = None,\n        gt_tracks=None,\n        query_frame: int = 0,\n        compensate_for_camera_motion=False,\n    ):\n        B, T, C, H, W = video.shape\n        _, _, N, D = tracks.shape\n\n        assert D == 2\n        assert C == 3\n        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C\n        tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2\n        if gt_tracks is not None:\n            gt_tracks = gt_tracks[0].detach().cpu().numpy()\n\n        res_video = []\n\n        # process input video\n        for rgb in video:\n            res_video.append(rgb.copy())\n        vector_colors = np.zeros((T, N, 
3))\n\n        if self.mode == \"optical_flow\":\n            import flow_vis\n\n            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])\n        elif segm_mask is None:\n            if self.mode == \"rainbow\":\n                y_min, y_max = (\n                    tracks[query_frame, :, 1].min(),\n                    tracks[query_frame, :, 1].max(),\n                )\n                norm = plt.Normalize(y_min, y_max)\n                for n in range(N):\n                    color = self.color_map(norm(tracks[query_frame, n, 1]))\n                    color = np.array(color[:3])[None] * 255\n                    vector_colors[:, n] = np.repeat(color, T, axis=0)\n            else:\n                # color changes with time\n                for t in range(T):\n                    color = np.array(self.color_map(t / T)[:3])[None] * 255\n                    vector_colors[t] = np.repeat(color, N, axis=0)\n        else:\n            if self.mode == \"rainbow\":\n                vector_colors[:, segm_mask <= 0, :] = 255\n\n                y_min, y_max = (\n                    tracks[0, segm_mask > 0, 1].min(),\n                    tracks[0, segm_mask > 0, 1].max(),\n                )\n                norm = plt.Normalize(y_min, y_max)\n                for n in range(N):\n                    if segm_mask[n] > 0:\n                        color = self.color_map(norm(tracks[0, n, 1]))\n                        color = np.array(color[:3])[None] * 255\n                        vector_colors[:, n] = np.repeat(color, T, axis=0)\n\n            else:\n                # color changes with segm class\n                segm_mask = segm_mask.cpu()\n                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)\n                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0\n                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0\n                vector_colors = np.repeat(color[None], T, axis=0)\n\n     
   #  draw tracks\n        if self.tracks_leave_trace != 0:\n            for t in range(query_frame + 1, T):\n                first_ind = (\n                    max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0\n                )\n                curr_tracks = tracks[first_ind : t + 1]\n                curr_colors = vector_colors[first_ind : t + 1]\n                if compensate_for_camera_motion:\n                    diff = (\n                        tracks[first_ind : t + 1, segm_mask <= 0]\n                        - tracks[t : t + 1, segm_mask <= 0]\n                    ).mean(1)[:, None]\n\n                    curr_tracks = curr_tracks - diff\n                    curr_tracks = curr_tracks[:, segm_mask > 0]\n                    curr_colors = curr_colors[:, segm_mask > 0]\n\n                res_video[t] = self._draw_pred_tracks(\n                    res_video[t],\n                    curr_tracks,\n                    curr_colors,\n                )\n                if gt_tracks is not None:\n                    res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])\n\n        #  draw points\n        for t in range(query_frame, T):\n            img = Image.fromarray(np.uint8(res_video[t]))\n            for i in range(N):\n                coord = (tracks[t, i, 0], tracks[t, i, 1])\n                visible = True\n                if visibility is not None:\n                    visible = visibility[0, t, i]\n                if coord[0] != 0 and coord[1] != 0:\n                    if not compensate_for_camera_motion or (\n                        compensate_for_camera_motion and segm_mask[i] > 0\n                    ):\n                        img = draw_circle(\n                            img,\n                            coord=coord,\n                            radius=int(self.linewidth * 2),\n                            color=vector_colors[t, i].astype(int),\n                            visible=visible,\n        
                 )\n            res_video[t] = np.array(img)\n\n        #  construct the final rgb sequence\n        if self.show_first_frame > 0:\n            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]\n        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()\n\n    def _draw_pred_tracks(\n        self,\n        rgb: np.ndarray,  # H x W x 3\n        tracks: np.ndarray,  # T x N x 2\n        vector_colors: np.ndarray,\n        alpha: float = 0.5,\n    ):\n        T, N, _ = tracks.shape\n        rgb = Image.fromarray(np.uint8(rgb))\n        for s in range(T - 1):\n            vector_color = vector_colors[s]\n            original = rgb.copy()\n            alpha = (s / T) ** 2\n            for i in range(N):\n                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))\n                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))\n                if coord_y[0] != 0 and coord_y[1] != 0:\n                    rgb = draw_line(\n                        rgb,\n                        coord_y,\n                        coord_x,\n                        vector_color[i].astype(int),\n                        self.linewidth,\n                    )\n            if self.tracks_leave_trace > 0:\n                rgb = Image.fromarray(\n                    np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))\n                )\n        rgb = np.array(rgb)\n        return rgb\n\n    def _draw_gt_tracks(\n        self,\n        rgb: np.ndarray,  # H x W x 3,\n        gt_tracks: np.ndarray,  # T x N x 2\n    ):\n        T, N, _ = gt_tracks.shape\n        color = np.array((211, 0, 0))\n        rgb = Image.fromarray(np.uint8(rgb))\n        for t in range(T):\n            for i in range(N):\n                # keep gt_tracks intact; overwriting it would break later iterations\n                gt_track = gt_tracks[t][i]\n                #  draw a red cross\n                if gt_track[0] > 0 and gt_track[1] > 0:\n                    length = self.linewidth * 3\n                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)\n                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)\n                    rgb = draw_line(\n                        rgb,\n                        coord_y,\n                        coord_x,\n                        color,\n                        self.linewidth,\n                    )\n                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)\n                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)\n                    rgb = draw_line(\n                        rgb,\n                        coord_y,\n                        coord_x,\n                        color,\n                        self.linewidth,\n                    )\n        rgb = np.array(rgb)\n        return rgb\n\n\n\n########## optical flow visualization ##########\n\nUNKNOWN_FLOW_THRESH = 1e7\nSMALLFLOW = 0.0\nLARGEFLOW = 1e8\n\n\ndef vis_flow_to_video(optical_flow, num_frames):\n    '''\n    optical_flow: T-1 x H x W x C\n    '''\n    video = []\n    for i in range(1, num_frames):\n        # optical_flow has T-1 entries indexed from 0, so frame i uses entry i-1\n        flow_img = flow_to_image(optical_flow[i - 1])\n        flow_img = torch.Tensor(flow_img) # H x W x 3\n        video.append(flow_img)\n    video = torch.stack(video, dim=0) # T-1 x H x W x 3\n    return video\n\n\n# from https://github.com/gengshan-y/VCN\ndef flow_to_image(flow):\n    \"\"\"\n    Convert flow into middlebury color code image\n    :param flow: optical flow map\n    :return: optical flow image in middlebury color\n    \"\"\"\n    u = flow[:, :, 0]\n    v = flow[:, :, 1]\n\n    maxu = -999.\n    maxv = -999.\n    minu = 999.\n    minv = 999.\n\n    idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)\n    u[idxUnknow] = 0\n    v[idxUnknow] = 0\n\n    maxu = max(maxu, np.max(u))\n    minu = min(minu, np.min(u))\n\n    maxv = max(maxv, np.max(v))\n    minv = min(minv, np.min(v))\n\n    rad = np.sqrt(u ** 2 + v ** 2)\n    
maxrad = max(-1, np.max(rad))\n\n    u = u / (maxrad + np.finfo(float).eps)\n    v = v / (maxrad + np.finfo(float).eps)\n\n    img = compute_color(u, v)\n\n    idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)\n    img[idx] = 0\n\n    return np.uint8(img)\n\n\ndef compute_color(u, v):\n    \"\"\"\n    compute optical flow color map\n    :param u: optical flow horizontal map\n    :param v: optical flow vertical map\n    :return: optical flow in color code\n    \"\"\"\n    [h, w] = u.shape\n    img = np.zeros([h, w, 3])\n    nanIdx = np.isnan(u) | np.isnan(v)\n    u[nanIdx] = 0\n    v[nanIdx] = 0\n\n    colorwheel = make_color_wheel()\n    ncols = np.size(colorwheel, 0)\n\n    rad = np.sqrt(u ** 2 + v ** 2)\n\n    a = np.arctan2(-v, -u) / np.pi\n\n    fk = (a + 1) / 2 * (ncols - 1) + 1\n\n    k0 = np.floor(fk).astype(int)\n\n    k1 = k0 + 1\n    k1[k1 == ncols + 1] = 1\n    f = fk - k0\n\n    for i in range(0, np.size(colorwheel, 1)):\n        tmp = colorwheel[:, i]\n        col0 = tmp[k0 - 1] / 255\n        col1 = tmp[k1 - 1] / 255\n        col = (1 - f) * col0 + f * col1\n\n        idx = rad <= 1\n        col[idx] = 1 - rad[idx] * (1 - col[idx]) # smaller flow -> brighter color, so static or slow-moving regions stand out in the visualization\n        notidx = np.logical_not(idx)\n\n        col[notidx] *= 0.75 # larger flow -> darker color\n        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))\n\n    return img\n\n\ndef make_color_wheel():\n    \"\"\"\n    Generate a color wheel according to the Middlebury color code\n    :return: Color wheel\n    \"\"\"\n    RY = 15\n    YG = 6\n    GC = 4\n    CB = 11\n    BM = 13\n    MR = 6\n\n    ncols = RY + YG + GC + CB + BM + MR\n\n    colorwheel = np.zeros([ncols, 3])\n\n    col = 0\n\n    # RY\n    colorwheel[0:RY, 0] = 255\n    colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))\n    col += RY\n\n    # YG\n    colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))\n    colorwheel[col:col + YG, 1] = 255\n    col += YG\n\n    # GC\n    colorwheel[col:col + GC, 1] = 255\n    colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))\n    col += GC\n\n    # CB\n    colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))\n    colorwheel[col:col + CB, 2] = 255\n    col += CB\n\n    # BM\n    colorwheel[col:col + BM, 2] = 255\n    colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))\n    col += BM\n\n    # MR\n    colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))\n    colorwheel[col:col + MR, 0] = 255\n\n    return colorwheel\n"
  },
  {
    "path": "requirements.txt",
    "content": "accelerate\ntorch\ntorchvision\nPillow\nnumpy\nomegaconf\ndecord\neinops\nmatplotlib\ndiffusers==0.27.0\nscipy\nav==12.0.0\nimageio\nopencv_contrib_python\ntransformers\nhuggingface_hub==0.25.2\nonnxruntime\n"
  },
  {
    "path": "scripts/test.sh",
    "content": "CUDA_VISIBLE_DEVICES=0 python inference_ctrl.py \\\n    --inference_config configs/test.yaml \\\n    --name test \\\n    # --no_use_float16 for V100\n"
  }
]