Repository: KlingAIResearch/LivePortrait Branch: main Commit: 49784e879821 Files: 145 Total size: 752.1 KB Directory structure: gitextract_l1jm0s2t/ ├── .gitignore ├── .vscode/ │ └── settings.json ├── LICENSE ├── app.py ├── app_animals.py ├── assets/ │ ├── .gitignore │ ├── docs/ │ │ ├── changelog/ │ │ │ ├── 2024-07-10.md │ │ │ ├── 2024-07-19.md │ │ │ ├── 2024-07-24.md │ │ │ ├── 2024-08-02.md │ │ │ ├── 2024-08-05.md │ │ │ ├── 2024-08-06.md │ │ │ ├── 2024-08-19.md │ │ │ └── 2025-01-01.md │ │ ├── directory-structure.md │ │ ├── how-to-install-ffmpeg.md │ │ └── speed.md │ ├── examples/ │ │ └── driving/ │ │ ├── aggrieved.pkl │ │ ├── d1.pkl │ │ ├── d2.pkl │ │ ├── d5.pkl │ │ ├── d7.pkl │ │ ├── d8.pkl │ │ ├── laugh.pkl │ │ ├── open_lip.pkl │ │ ├── shake_face.pkl │ │ ├── shy.pkl │ │ ├── talking.pkl │ │ └── wink.pkl │ └── gradio/ │ ├── gradio_description_animate_clear.md │ ├── gradio_description_animation.md │ ├── gradio_description_retargeting.md │ ├── gradio_description_retargeting_video.md │ ├── gradio_description_upload.md │ ├── gradio_description_upload_animal.md │ └── gradio_title.md ├── inference.py ├── inference_animals.py ├── pretrained_weights/ │ └── .gitkeep ├── readme.md ├── readme_zh_cn.md ├── requirements.txt ├── requirements_base.txt ├── requirements_macOS.txt ├── speed.py └── src/ ├── config/ │ ├── __init__.py │ ├── argument_config.py │ ├── base_config.py │ ├── crop_config.py │ ├── inference_config.py │ └── models.yaml ├── gradio_pipeline.py ├── live_portrait_pipeline.py ├── live_portrait_pipeline_animal.py ├── live_portrait_wrapper.py ├── modules/ │ ├── __init__.py │ ├── appearance_feature_extractor.py │ ├── convnextv2.py │ ├── dense_motion.py │ ├── motion_extractor.py │ ├── spade_generator.py │ ├── stitching_retargeting_network.py │ ├── util.py │ └── warping_network.py └── utils/ ├── __init__.py ├── animal_landmark_runner.py ├── camera.py ├── check_windows_port.py ├── crop.py ├── cropper.py ├── dependencies/ │ ├── XPose/ │ │ ├── config_model/ │ │ │ ├── 
UniPose_SwinT.py │ │ │ └── coco_transformer.py │ │ ├── models/ │ │ │ ├── UniPose/ │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── backbone.py │ │ │ │ ├── deformable_transformer.py │ │ │ │ ├── fuse_modules.py │ │ │ │ ├── mask_generate.py │ │ │ │ ├── ops/ │ │ │ │ │ ├── functions/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ │ ├── modules/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── ms_deform_attn.py │ │ │ │ │ │ └── ms_deform_attn_key_aware.py │ │ │ │ │ ├── setup.py │ │ │ │ │ ├── src/ │ │ │ │ │ │ ├── cpu/ │ │ │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ │ │ ├── cuda/ │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ │ │ ├── ms_deform_attn.h │ │ │ │ │ │ └── vision.cpp │ │ │ │ │ └── test.py │ │ │ │ ├── position_encoding.py │ │ │ │ ├── swin_transformer.py │ │ │ │ ├── transformer_deformable.py │ │ │ │ ├── transformer_vanilla.py │ │ │ │ ├── unipose.py │ │ │ │ └── utils.py │ │ │ ├── __init__.py │ │ │ └── registry.py │ │ ├── predefined_keypoints.py │ │ ├── transforms.py │ │ └── util/ │ │ ├── addict.py │ │ ├── box_ops.py │ │ ├── config.py │ │ ├── keypoint_ops.py │ │ └── misc.py │ └── insightface/ │ ├── __init__.py │ ├── app/ │ │ ├── __init__.py │ │ ├── common.py │ │ └── face_analysis.py │ ├── data/ │ │ ├── __init__.py │ │ ├── image.py │ │ ├── objects/ │ │ │ └── meanshape_68.pkl │ │ ├── pickle_object.py │ │ └── rec_builder.py │ ├── model_zoo/ │ │ ├── __init__.py │ │ ├── arcface_onnx.py │ │ ├── attribute.py │ │ ├── inswapper.py │ │ ├── landmark.py │ │ ├── model_store.py │ │ ├── model_zoo.py │ │ ├── retinaface.py │ │ └── scrfd.py │ └── utils/ │ ├── __init__.py │ ├── constant.py │ ├── download.py │ ├── face_align.py │ ├── filesystem.py │ ├── storage.py │ └── transform.py ├── face_analysis_diy.py ├── filter.py ├── helper.py ├── human_landmark_runner.py ├── io.py ├── resources/ │ ├── clip_embedding_68.pkl │ ├── 
clip_embedding_9.pkl │ └── lip_array.pkl ├── retargeting_utils.py ├── rprint.py ├── timer.py ├── video.py └── viz.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ **/__pycache__/ *.py[cod] **/*.py[cod] *$py.class # Model weights **/*.pth **/*.onnx pretrained_weights/*.md pretrained_weights/docs pretrained_weights/liveportrait pretrained_weights/liveportrait_animals # Ipython notebook *.ipynb # Temporary files or benchmark resources animations/* tmp/* .vscode/launch.json **/*.DS_Store gradio_temp/** # Windows dependencies ffmpeg/ LivePortrait_env/ # XPose build files src/utils/dependencies/XPose/models/UniPose/ops/build src/utils/dependencies/XPose/models/UniPose/ops/dist src/utils/dependencies/XPose/models/UniPose/ops/MultiScaleDeformableAttention.egg-info ================================================ FILE: .vscode/settings.json ================================================ { "[python]": { "editor.tabSize": 4 }, "files.eol": "\n", "files.insertFinalNewline": true, "files.trimFinalNewlines": true, "files.trimTrailingWhitespace": true, "files.exclude": { "**/.git": true, "**/.svn": true, "**/.hg": true, "**/CVS": true, "**/.DS_Store": true, "**/Thumbs.db": true, "**/*.crswap": true, "**/__pycache__": true } } ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 Kuaishou Visual Generation and Interaction Center Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to 
permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. --- The code of InsightFace is released under the MIT License. The models of InsightFace are for non-commercial research purposes only. If you want to use the LivePortrait project for commercial purposes, you should remove and replace InsightFace’s detection models to fully comply with the MIT license. ================================================ FILE: app.py ================================================ # coding: utf-8 """ The entrance of the gradio for human """ import os import tyro import subprocess import gradio as gr import os.path as osp from src.utils.helper import load_description from src.gradio_pipeline import GradioPipeline from src.config.crop_config import CropConfig from src.config.argument_config import ArgumentConfig from src.config.inference_config import InferenceConfig def partial_fields(target_class, kwargs): return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) def fast_check_ffmpeg(): try: subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True) return True except: return False # set tyro theme tyro.extras.set_accent_color("bright_cyan") args = tyro.cli(ArgumentConfig) ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg") if osp.exists(ffmpeg_dir): os.environ["PATH"] += (os.pathsep + ffmpeg_dir) if not fast_check_ffmpeg(): 
raise ImportError( "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html" ) # specify configs for inference inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig # global_tab_selection = None gradio_pipeline = GradioPipeline( inference_cfg=inference_cfg, crop_cfg=crop_cfg, args=args ) if args.gradio_temp_dir not in (None, ''): os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir os.makedirs(args.gradio_temp_dir, exist_ok=True) def gpu_wrapped_execute_video(*args, **kwargs): return gradio_pipeline.execute_video(*args, **kwargs) def gpu_wrapped_execute_image_retargeting(*args, **kwargs): return gradio_pipeline.execute_image_retargeting(*args, **kwargs) def gpu_wrapped_execute_video_retargeting(*args, **kwargs): return gradio_pipeline.execute_video_retargeting(*args, **kwargs) def reset_sliders(*args, **kwargs): return 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.5, True, True # assets title_md = "assets/gradio/gradio_title.md" example_portrait_dir = "assets/examples/source" example_video_dir = "assets/examples/driving" data_examples_i2v = [ [osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False], [osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False], [osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False], [osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False], [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False], [osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True], ] 
data_examples_v2v = [ [osp.join(example_portrait_dir, "s13.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7], # [osp.join(example_portrait_dir, "s14.mp4"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False, False, 3e-7], # [osp.join(example_portrait_dir, "s15.mp4"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False, False, 3e-7], [osp.join(example_portrait_dir, "s18.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, 3e-7], # [osp.join(example_portrait_dir, "s19.mp4"), osp.join(example_video_dir, "d6.mp4"), True, True, True, False, False, 3e-7], [osp.join(example_portrait_dir, "s20.mp4"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False, 3e-7], ] #################### interface logic #################### # Define components first retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.5, step=0.05, label="crop scale") video_retargeting_source_scale = gr.Number(minimum=1.8, maximum=3.2, value=2.3, step=0.05, label="crop scale") driving_smooth_observation_variance_retargeting = gr.Number(value=3e-6, label="motion smooth strength", minimum=1e-11, maximum=1e-2, step=1e-8) video_retargeting_silence = gr.Checkbox(value=False, label="keeping the lip silent") eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio") lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio") video_lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio") head_pitch_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative pitch") head_yaw_slider = gr.Slider(minimum=-25, maximum=25, value=0, step=1, label="relative yaw") head_roll_slider = gr.Slider(minimum=-15.0, maximum=15.0, value=0, step=1, label="relative roll") mov_x = gr.Slider(minimum=-0.19, maximum=0.19, value=0.0, step=0.01, label="x-axis movement") mov_y = gr.Slider(minimum=-0.19, 
maximum=0.19, value=0.0, step=0.01, label="y-axis movement") mov_z = gr.Slider(minimum=0.9, maximum=1.2, value=1.0, step=0.01, label="z-axis movement") lip_variation_zero = gr.Slider(minimum=-0.09, maximum=0.09, value=0, step=0.01, label="pouting") lip_variation_one = gr.Slider(minimum=-20.0, maximum=15.0, value=0, step=0.01, label="pursing 😐") lip_variation_two = gr.Slider(minimum=0.0, maximum=15.0, value=0, step=0.01, label="grin 😁") lip_variation_three = gr.Slider(minimum=-90.0, maximum=120.0, value=0, step=1.0, label="lip close <-> open") smile = gr.Slider(minimum=-0.3, maximum=1.3, value=0, step=0.01, label="smile 😄") wink = gr.Slider(minimum=0, maximum=39, value=0, step=0.01, label="wink 😉") eyebrow = gr.Slider(minimum=-30, maximum=30, value=0, step=0.01, label="eyebrow 🤨") eyeball_direction_x = gr.Slider(minimum=-30.0, maximum=30.0, value=0, step=0.01, label="eye gaze (horizontal) 👀") eyeball_direction_y = gr.Slider(minimum=-63.0, maximum=63.0, value=0, step=0.01, label="eye gaze (vertical) 🙄") retargeting_input_image = gr.Image(type="filepath") retargeting_input_video = gr.Video() output_image = gr.Image(type="numpy") output_image_paste_back = gr.Image(type="numpy") retargeting_output_image = gr.Image(type="numpy") retargeting_output_image_paste_back = gr.Image(type="numpy") output_video = gr.Video(autoplay=False) output_video_paste_back = gr.Video(autoplay=False) with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo: gr.HTML(load_description(title_md)) gr.Markdown(load_description("assets/gradio/gradio_description_upload.md")) with gr.Row(): with gr.Column(): with gr.Tabs(): with gr.TabItem("🖼️ Source Image") as tab_image: with gr.Accordion(open=True, label="Source Image"): source_image_input = gr.Image(type="filepath") gr.Examples( examples=[ [osp.join(example_portrait_dir, "s9.jpg")], [osp.join(example_portrait_dir, "s6.jpg")], [osp.join(example_portrait_dir, "s10.jpg")], [osp.join(example_portrait_dir, 
"s5.jpg")], [osp.join(example_portrait_dir, "s7.jpg")], [osp.join(example_portrait_dir, "s12.jpg")], [osp.join(example_portrait_dir, "s22.jpg")], [osp.join(example_portrait_dir, "s23.jpg")], ], inputs=[source_image_input], cache_examples=False, ) with gr.TabItem("🎞️ Source Video") as tab_video: with gr.Accordion(open=True, label="Source Video"): source_video_input = gr.Video() gr.Examples( examples=[ [osp.join(example_portrait_dir, "s13.mp4")], # [osp.join(example_portrait_dir, "s14.mp4")], # [osp.join(example_portrait_dir, "s15.mp4")], [osp.join(example_portrait_dir, "s18.mp4")], # [osp.join(example_portrait_dir, "s19.mp4")], [osp.join(example_portrait_dir, "s20.mp4")], ], inputs=[source_video_input], cache_examples=False, ) tab_selection = gr.Textbox(visible=False) tab_image.select(lambda: "Image", None, tab_selection) tab_video.select(lambda: "Video", None, tab_selection) with gr.Accordion(open=True, label="Cropping Options for Source Image or Video"): with gr.Row(): flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)") scale = gr.Number(value=2.3, label="source crop scale", minimum=1.8, maximum=3.2, step=0.05) vx_ratio = gr.Number(value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01) vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01) with gr.Column(): with gr.Tabs(): with gr.TabItem("🎞️ Driving Video") as v_tab_video: with gr.Accordion(open=True, label="Driving Video"): driving_video_input = gr.Video() gr.Examples( examples=[ [osp.join(example_video_dir, "d0.mp4")], [osp.join(example_video_dir, "d18.mp4")], [osp.join(example_video_dir, "d19.mp4")], [osp.join(example_video_dir, "d14.mp4")], [osp.join(example_video_dir, "d6.mp4")], [osp.join(example_video_dir, "d20.mp4")], ], inputs=[driving_video_input], cache_examples=False, ) with gr.TabItem("🖼️ Driving Image") as v_tab_image: with gr.Accordion(open=True, label="Driving Image"): driving_image_input = gr.Image(type="filepath") 
gr.Examples( examples=[ [osp.join(example_video_dir, "d30.jpg")], [osp.join(example_video_dir, "d9.jpg")], [osp.join(example_video_dir, "d19.jpg")], [osp.join(example_video_dir, "d8.jpg")], [osp.join(example_video_dir, "d12.jpg")], [osp.join(example_video_dir, "d38.jpg")], ], inputs=[driving_image_input], cache_examples=False, ) with gr.TabItem("📁 Driving Pickle") as v_tab_pickle: with gr.Accordion(open=True, label="Driving Pickle"): driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"]) gr.Examples( examples=[ [osp.join(example_video_dir, "d1.pkl")], [osp.join(example_video_dir, "d2.pkl")], [osp.join(example_video_dir, "d5.pkl")], [osp.join(example_video_dir, "d7.pkl")], [osp.join(example_video_dir, "d8.pkl")], ], inputs=[driving_video_pickle_input], cache_examples=False, ) v_tab_selection = gr.Textbox(visible=False) v_tab_video.select(lambda: "Video", None, v_tab_selection) v_tab_image.select(lambda: "Image", None, v_tab_selection) v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection) # with gr.Accordion(open=False, label="Animation Instructions"): # gr.Markdown(load_description("assets/gradio/gradio_description_animation.md")) with gr.Accordion(open=True, label="Cropping Options for Driving Video"): with gr.Row(): flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving)") scale_crop_driving_video = gr.Number(value=2.2, label="driving crop scale", minimum=1.8, maximum=3.2, step=0.05) vx_ratio_crop_driving_video = gr.Number(value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01) vy_ratio_crop_driving_video = gr.Number(value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01) with gr.Row(): with gr.Accordion(open=True, label="Animation Options"): with gr.Row(): flag_normalize_lip = gr.Checkbox(value=False, label="normalize lip") flag_relative_input = gr.Checkbox(value=True, label="relative motion") flag_remap_input = gr.Checkbox(value=True, label="paste-back") flag_stitching_input 
= gr.Checkbox(value=True, label="stitching") animation_region = gr.Radio(["exp", "pose", "lip", "eyes", "all"], value="all", label="animation region") driving_option_input = gr.Radio(['expression-friendly', 'pose-friendly'], value="expression-friendly", label="driving option (i2v)") driving_multiplier = gr.Number(value=1.0, label="driving multiplier (i2v)", minimum=0.0, maximum=2.0, step=0.02) driving_smooth_observation_variance = gr.Number(value=3e-7, label="motion smooth strength (v2v)", minimum=1e-11, maximum=1e-2, step=1e-8) gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md")) with gr.Row(): process_button_animation = gr.Button("🚀 Animate", variant="primary") with gr.Row(): with gr.Column(): output_video_i2v = gr.Video(autoplay=False, label="The animated video in the original image space") with gr.Column(): output_video_concat_i2v = gr.Video(autoplay=False, label="The animated video") with gr.Row(): with gr.Column(): output_image_i2i = gr.Image(type="numpy", label="The animated image in the original image space", visible=False) with gr.Column(): output_image_concat_i2i = gr.Image(type="numpy", label="The animated image", visible=False) with gr.Row(): process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, driving_image_input, output_video_i2v, output_video_concat_i2v, output_image_i2i, output_image_concat_i2i], value="🧹 Clear") with gr.Row(): # Examples gr.Markdown("## You could also choose the examples below by one click ⬇️") with gr.Row(): with gr.Tabs(): with gr.TabItem("🖼️ Portrait Animation"): gr.Examples( examples=data_examples_i2v, fn=gpu_wrapped_execute_video, inputs=[ source_image_input, driving_video_input, flag_relative_input, flag_do_crop_input, flag_remap_input, flag_crop_driving_video_input, ], outputs=[output_image, output_image_paste_back], examples_per_page=len(data_examples_i2v), cache_examples=False, ) with gr.TabItem("🎞️ Portrait Video 
Editing"): gr.Examples( examples=data_examples_v2v, fn=gpu_wrapped_execute_video, inputs=[ source_video_input, driving_video_input, flag_relative_input, flag_do_crop_input, flag_remap_input, flag_crop_driving_video_input, driving_smooth_observation_variance, ], outputs=[output_image, output_image_paste_back], examples_per_page=len(data_examples_v2v), cache_examples=False, ) # Retargeting Image gr.Markdown(load_description("assets/gradio/gradio_description_retargeting.md"), visible=True) with gr.Row(visible=True): flag_do_crop_input_retargeting_image = gr.Checkbox(value=True, label="do crop (source)") flag_stitching_retargeting_input = gr.Checkbox(value=True, label="stitching") retargeting_source_scale.render() eye_retargeting_slider.render() lip_retargeting_slider.render() with gr.Row(visible=True): with gr.Column(): with gr.Accordion(open=True, label="Facial movement sliders"): with gr.Row(visible=True): head_pitch_slider.render() head_yaw_slider.render() head_roll_slider.render() with gr.Row(visible=True): mov_x.render() mov_y.render() mov_z.render() with gr.Column(): with gr.Accordion(open=True, label="Facial expression sliders"): with gr.Row(visible=True): lip_variation_zero.render() lip_variation_one.render() lip_variation_two.render() with gr.Row(visible=True): lip_variation_three.render() smile.render() wink.render() with gr.Row(visible=True): eyebrow.render() eyeball_direction_x.render() eyeball_direction_y.render() with gr.Row(visible=True): reset_button = gr.Button("🔄 Reset") reset_button.click( fn=reset_sliders, inputs=None, outputs=[ head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z, lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y, retargeting_source_scale, flag_stitching_retargeting_input, flag_do_crop_input_retargeting_image ] ) with gr.Row(visible=True): with gr.Column(): with gr.Accordion(open=True, label="Retargeting Image Input"): 
retargeting_input_image.render() gr.Examples( examples=[ [osp.join(example_portrait_dir, "s9.jpg")], [osp.join(example_portrait_dir, "s6.jpg")], [osp.join(example_portrait_dir, "s10.jpg")], [osp.join(example_portrait_dir, "s5.jpg")], [osp.join(example_portrait_dir, "s7.jpg")], [osp.join(example_portrait_dir, "s12.jpg")], [osp.join(example_portrait_dir, "s22.jpg")], # [osp.join(example_portrait_dir, "s23.jpg")], [osp.join(example_portrait_dir, "s42.jpg")], ], inputs=[retargeting_input_image], cache_examples=False, ) with gr.Column(): with gr.Accordion(open=True, label="Retargeting Result"): retargeting_output_image.render() with gr.Column(): with gr.Accordion(open=True, label="Paste-back Result"): retargeting_output_image_paste_back.render() with gr.Row(visible=True): process_button_reset_retargeting = gr.ClearButton( [ retargeting_input_image, retargeting_output_image, retargeting_output_image_paste_back, ], value="🧹 Clear" ) # Retargeting Video gr.Markdown(load_description("assets/gradio/gradio_description_retargeting_video.md"), visible=True) with gr.Row(visible=True): flag_do_crop_input_retargeting_video = gr.Checkbox(value=True, label="do crop (source)") video_retargeting_source_scale.render() video_lip_retargeting_slider.render() driving_smooth_observation_variance_retargeting.render() video_retargeting_silence.render() with gr.Row(visible=True): process_button_retargeting_video = gr.Button("🚗 Retargeting Video", variant="primary") with gr.Row(visible=True): with gr.Column(): with gr.Accordion(open=True, label="Retargeting Video Input"): retargeting_input_video.render() gr.Examples( examples=[ [osp.join(example_portrait_dir, "s13.mp4")], # [osp.join(example_portrait_dir, "s18.mp4")], # [osp.join(example_portrait_dir, "s20.mp4")], [osp.join(example_portrait_dir, "s29.mp4")], [osp.join(example_portrait_dir, "s32.mp4")], [osp.join(example_video_dir, "d3.mp4")], ], inputs=[retargeting_input_video], cache_examples=False, ) with gr.Column(): with 
gr.Accordion(open=True, label="Retargeting Result"): output_video.render() with gr.Column(): with gr.Accordion(open=True, label="Paste-back Result"): output_video_paste_back.render() with gr.Row(visible=True): process_button_reset_retargeting = gr.ClearButton( [ video_lip_retargeting_slider, retargeting_input_video, output_video, output_video_paste_back ], value="🧹 Clear" ) # binding functions for buttons process_button_animation.click( fn=gpu_wrapped_execute_video, inputs=[ source_image_input, source_video_input, driving_video_input, driving_image_input, driving_video_pickle_input, flag_normalize_lip, flag_relative_input, flag_do_crop_input, flag_remap_input, flag_stitching_input, animation_region, driving_option_input, driving_multiplier, flag_crop_driving_video_input, scale, vx_ratio, vy_ratio, scale_crop_driving_video, vx_ratio_crop_driving_video, vy_ratio_crop_driving_video, driving_smooth_observation_variance, tab_selection, v_tab_selection, ], outputs=[output_video_i2v, output_video_i2v, output_video_concat_i2v, output_video_concat_i2v, output_image_i2i, output_image_i2i, output_image_concat_i2i, output_image_concat_i2i], show_progress=True ) retargeting_input_image.change( fn=gradio_pipeline.init_retargeting_image, inputs=[retargeting_source_scale, eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image], outputs=[eye_retargeting_slider, lip_retargeting_slider] ) sliders = [eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z, lip_variation_zero, lip_variation_one, lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y] for slider in sliders: # NOTE: gradio >= 4.0.0 may cause slow response slider.change( fn=gpu_wrapped_execute_image_retargeting, inputs=[ eye_retargeting_slider, lip_retargeting_slider, head_pitch_slider, head_yaw_slider, head_roll_slider, mov_x, mov_y, mov_z, lip_variation_zero, lip_variation_one, 
lip_variation_two, lip_variation_three, smile, wink, eyebrow, eyeball_direction_x, eyeball_direction_y, retargeting_input_image, retargeting_source_scale, flag_stitching_retargeting_input, flag_do_crop_input_retargeting_image ], outputs=[retargeting_output_image, retargeting_output_image_paste_back], ) process_button_retargeting_video.click( fn=gpu_wrapped_execute_video_retargeting, inputs=[video_lip_retargeting_slider, retargeting_input_video, video_retargeting_source_scale, driving_smooth_observation_variance_retargeting, video_retargeting_silence, flag_do_crop_input_retargeting_video], outputs=[output_video, output_video_paste_back], show_progress=True ) demo.launch( server_port=args.server_port, share=args.share, server_name=args.server_name ) ================================================ FILE: app_animals.py ================================================ # coding: utf-8 """ The entrance of the gradio for animal """ import os import tyro import subprocess import gradio as gr import os.path as osp from src.utils.helper import load_description from src.gradio_pipeline import GradioPipelineAnimal from src.config.crop_config import CropConfig from src.config.argument_config import ArgumentConfig from src.config.inference_config import InferenceConfig def partial_fields(target_class, kwargs): return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) def fast_check_ffmpeg(): try: subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True) return True except: return False # set tyro theme tyro.extras.set_accent_color("bright_cyan") args = tyro.cli(ArgumentConfig) ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg") if osp.exists(ffmpeg_dir): os.environ["PATH"] += (os.pathsep + ffmpeg_dir) if not fast_check_ffmpeg(): raise ImportError( "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. 
https://ffmpeg.org/download.html" ) # specify configs for inference inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig gradio_pipeline_animal: GradioPipelineAnimal = GradioPipelineAnimal( inference_cfg=inference_cfg, crop_cfg=crop_cfg, args=args ) if args.gradio_temp_dir not in (None, ''): os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir os.makedirs(args.gradio_temp_dir, exist_ok=True) def gpu_wrapped_execute_video(*args, **kwargs): return gradio_pipeline_animal.execute_video(*args, **kwargs) # assets title_md = "assets/gradio/gradio_title.md" example_portrait_dir = "assets/examples/source" example_video_dir = "assets/examples/driving" data_examples_i2v = [ [osp.join(example_portrait_dir, "s41.jpg"), osp.join(example_video_dir, "d3.mp4"), True, False, False, False], [osp.join(example_portrait_dir, "s40.jpg"), osp.join(example_video_dir, "d6.mp4"), True, False, False, False], [osp.join(example_portrait_dir, "s25.jpg"), osp.join(example_video_dir, "d19.mp4"), True, False, False, False], ] data_examples_i2v_pickle = [ [osp.join(example_portrait_dir, "s25.jpg"), osp.join(example_video_dir, "wink.pkl"), True, False, False, False], [osp.join(example_portrait_dir, "s40.jpg"), osp.join(example_video_dir, "talking.pkl"), True, False, False, False], [osp.join(example_portrait_dir, "s41.jpg"), osp.join(example_video_dir, "aggrieved.pkl"), True, False, False, False], ] #################### interface logic #################### # Define components first output_image = gr.Image(type="numpy") output_image_paste_back = gr.Image(type="numpy") output_video_i2v = gr.Video(autoplay=False) output_video_concat_i2v = gr.Video(autoplay=False) output_video_i2v_gif = gr.Image(type="numpy") with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo: gr.HTML(load_description(title_md)) 
gr.Markdown(load_description("assets/gradio/gradio_description_upload_animal.md")) with gr.Row(): with gr.Column(): with gr.Accordion(open=True, label="🐱 Source Animal Image"): source_image_input = gr.Image(type="filepath") gr.Examples( examples=[ [osp.join(example_portrait_dir, "s25.jpg")], [osp.join(example_portrait_dir, "s30.jpg")], [osp.join(example_portrait_dir, "s31.jpg")], [osp.join(example_portrait_dir, "s32.jpg")], [osp.join(example_portrait_dir, "s33.jpg")], [osp.join(example_portrait_dir, "s39.jpg")], [osp.join(example_portrait_dir, "s40.jpg")], [osp.join(example_portrait_dir, "s41.jpg")], [osp.join(example_portrait_dir, "s38.jpg")], [osp.join(example_portrait_dir, "s36.jpg")], ], inputs=[source_image_input], cache_examples=False, ) with gr.Accordion(open=True, label="Cropping Options for Source Image"): with gr.Row(): flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)") scale = gr.Number(value=2.3, label="source crop scale", minimum=1.8, maximum=3.2, step=0.05) vx_ratio = gr.Number(value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01) vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01) with gr.Column(): with gr.Tabs(): with gr.TabItem("📁 Driving Pickle") as tab_pickle: with gr.Accordion(open=True, label="Driving Pickle"): driving_video_pickle_input = gr.File() gr.Examples( examples=[ [osp.join(example_video_dir, "wink.pkl")], [osp.join(example_video_dir, "shy.pkl")], [osp.join(example_video_dir, "aggrieved.pkl")], [osp.join(example_video_dir, "open_lip.pkl")], [osp.join(example_video_dir, "laugh.pkl")], [osp.join(example_video_dir, "talking.pkl")], [osp.join(example_video_dir, "shake_face.pkl")], ], inputs=[driving_video_pickle_input], cache_examples=False, ) with gr.TabItem("🎞️ Driving Video") as tab_video: with gr.Accordion(open=True, label="Driving Video"): driving_video_input = gr.Video() gr.Examples( examples=[ # [osp.join(example_video_dir, "d0.mp4")], # 
[osp.join(example_video_dir, "d18.mp4")], [osp.join(example_video_dir, "d19.mp4")], [osp.join(example_video_dir, "d14.mp4")], [osp.join(example_video_dir, "d6.mp4")], [osp.join(example_video_dir, "d3.mp4")], ], inputs=[driving_video_input], cache_examples=False, ) tab_selection = gr.Textbox(visible=False) tab_pickle.select(lambda: "Pickle", None, tab_selection) tab_video.select(lambda: "Video", None, tab_selection) with gr.Accordion(open=True, label="Cropping Options for Driving Video"): with gr.Row(): flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving)") scale_crop_driving_video = gr.Number(value=2.2, label="driving crop scale", minimum=1.8, maximum=3.2, step=0.05) vx_ratio_crop_driving_video = gr.Number(value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01) vy_ratio_crop_driving_video = gr.Number(value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01) with gr.Row(): with gr.Accordion(open=False, label="Animation Options"): with gr.Row(): flag_stitching = gr.Checkbox(value=False, label="stitching (not recommended)") flag_remap_input = gr.Checkbox(value=False, label="paste-back (not recommended)") driving_multiplier = gr.Number(value=1.0, label="driving multiplier", minimum=0.0, maximum=2.0, step=0.02) gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md")) with gr.Row(): process_button_animation = gr.Button("🚀 Animate", variant="primary") with gr.Row(): with gr.Column(): with gr.Accordion(open=True, label="The animated video in the cropped image space"): output_video_i2v.render() with gr.Column(): with gr.Accordion(open=True, label="The animated gif in the cropped image space"): output_video_i2v_gif.render() with gr.Column(): with gr.Accordion(open=True, label="The animated video"): output_video_concat_i2v.render() with gr.Row(): process_button_reset = gr.ClearButton([source_image_input, driving_video_input, output_video_i2v, output_video_concat_i2v, output_video_i2v_gif], 
value="🧹 Clear") with gr.Row(): # Examples gr.Markdown("## You could also choose the examples below by one click ⬇️") with gr.Row(): with gr.Tabs(): with gr.TabItem("📁 Driving Pickle") as tab_video: gr.Examples( examples=data_examples_i2v_pickle, fn=gpu_wrapped_execute_video, inputs=[ source_image_input, driving_video_pickle_input, flag_do_crop_input, flag_stitching, flag_remap_input, flag_crop_driving_video_input, ], outputs=[output_image, output_image_paste_back, output_video_i2v_gif], examples_per_page=len(data_examples_i2v_pickle), cache_examples=False, ) with gr.TabItem("🎞️ Driving Video") as tab_video: gr.Examples( examples=data_examples_i2v, fn=gpu_wrapped_execute_video, inputs=[ source_image_input, driving_video_input, flag_do_crop_input, flag_stitching, flag_remap_input, flag_crop_driving_video_input, ], outputs=[output_image, output_image_paste_back, output_video_i2v_gif], examples_per_page=len(data_examples_i2v), cache_examples=False, ) process_button_animation.click( fn=gpu_wrapped_execute_video, inputs=[ source_image_input, driving_video_input, driving_video_pickle_input, flag_do_crop_input, flag_remap_input, driving_multiplier, flag_stitching, flag_crop_driving_video_input, scale, vx_ratio, vy_ratio, scale_crop_driving_video, vx_ratio_crop_driving_video, vy_ratio_crop_driving_video, tab_selection, ], outputs=[output_video_i2v, output_video_concat_i2v, output_video_i2v_gif], show_progress=True ) demo.launch( server_port=args.server_port, share=args.share, server_name=args.server_name ) ================================================ FILE: assets/.gitignore ================================================ examples/driving/*.pkl examples/driving/*_crop.mp4 ================================================ FILE: assets/docs/changelog/2024-07-10.md ================================================ ## 2024/07/10 **First, thank you all for your attention, support, sharing, and contributions to LivePortrait!** ❤️ The popularity of LivePortrait has exceeded our 
expectations. If you encounter any issues or other problems and we do not respond promptly, please accept our apologies. We are still actively updating and improving this repository. ### Updates - Audio and video concatenating: If the driving video contains audio, it will automatically be included in the generated video. Additionally, the generated video will maintain the same FPS as the driving video. If you run LivePortrait on Windows, you need to install the `ffprobe` and `ffmpeg` executables, see issue [#94](https://github.com/KlingTeam/LivePortrait/issues/94). - Driving video auto-cropping: Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enabled by default; you can enable it with `--flag_crop_driving_video`. - Motion template making: Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving` option. ### About driving video - For a guide on using your own driving video, see the [driving video auto-cropping](https://github.com/KlingTeam/LivePortrait/tree/main?tab=readme-ov-file#driving-video-auto-cropping) section. 
### Others - If you encounter a black box problem, disable half-precision inference by using `--no_flag_use_half_precision`, reported by issue [#40](https://github.com/KlingTeam/LivePortrait/issues/40), [#48](https://github.com/KlingTeam/LivePortrait/issues/48), [#62](https://github.com/KlingTeam/LivePortrait/issues/62). ================================================ FILE: assets/docs/changelog/2024-07-19.md ================================================ ## 2024/07/19 **Once again, we would like to express our heartfelt gratitude for your love, attention, and support for LivePortrait! 🎉** We are excited to announce the release of an implementation of Portrait Video Editing (aka v2v) today! Special thanks to the hard work of the LivePortrait team: [Dingyun Zhang](https://github.com/Mystery099), [Zhizhou Zhong](https://github.com/zzzweakman), and [Jianzhu Guo](https://github.com/cleardusk). ### Updates - Portrait video editing (v2v): Implemented a version of Portrait Video Editing (aka v2v). Ensure you have `pykalman` package installed, which has been added in [`requirements_base.txt`](../../../requirements_base.txt). You can specify the source video using the `-s` or `--source` option, adjust the temporal smoothness of motion with `--driving_smooth_observation_variance`, enable head pose motion transfer with `--flag_video_editing_head_rotation`, and ensure the eye-open scalar of each source frame matches the first source frame before animation with `--flag_source_video_eye_retargeting`. - More options in Gradio: We have upgraded the Gradio interface and added more options. These include `Cropping Options for Source Image or Video` and `Cropping Options for Driving Video`, providing greater flexibility and control.

LivePortrait
The Gradio Interface for LivePortrait

### Community Contributions - **ONNX/TensorRT Versions of LivePortrait:** Explore optimized versions of LivePortrait for faster performance: - [FasterLivePortrait](https://github.com/warmshao/FasterLivePortrait) by [warmshao](https://github.com/warmshao) ([#150](https://github.com/KlingTeam/LivePortrait/issues/150)) - [Efficient-Live-Portrait](https://github.com/aihacker111/Efficient-Live-Portrait) by [aihacker111](https://github.com/aihacker111/Efficient-Live-Portrait) ([#126](https://github.com/KlingTeam/LivePortrait/issues/126), [#142](https://github.com/KlingTeam/LivePortrait/issues/142)) - **LivePortrait with [X-Pose](https://github.com/IDEA-Research/X-Pose) Detection:** Check out [LivePortrait](https://github.com/ShiJiaying/LivePortrait) by [ShiJiaying](https://github.com/ShiJiaying) for enhanced detection capabilities using X-pose, see [#119](https://github.com/KlingTeam/LivePortrait/issues/119). ================================================ FILE: assets/docs/changelog/2024-07-24.md ================================================ ## 2024/07/24 ### Updates - **Portrait pose editing:** You can change the `relative pitch`, `relative yaw`, and `relative roll` in the Gradio interface to adjust the pose of the source portrait. - **Detection threshold:** We have added a `--det_thresh` argument with a default value of 0.15 to increase recall, meaning more types of faces (e.g., monkeys, human-like) will be detected. You can set it to other values, e.g., 0.5, by using `python app.py --det_thresh 0.5`.

LivePortrait
Pose Editing in the Gradio Interface

================================================ FILE: assets/docs/changelog/2024-08-02.md ================================================ ## 2024/08/02
Animals Singing Dance Monkey 🎤
🎉 We are excited to announce the release of a new version featuring animals mode, along with several other updates. Special thanks to the dedicated efforts of the LivePortrait team. 💪 We also provided a one-click installer for Windows users; check out the details [here](./2024-08-05.md). ### Updates on Animals mode We are pleased to announce the release of the animals mode, which is fine-tuned on approximately 230K frames of various animals (mostly cats and dogs). The trained weights have been updated in the `liveportrait_animals` subdirectory, available on [HuggingFace](https://huggingface.co/KlingTeam/LivePortrait/tree/main/) or [Google Drive](https://drive.google.com/drive/u/0/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib). You should [download the weights](https://github.com/KlingTeam/LivePortrait?tab=readme-ov-file#2-download-pretrained-weights) before running. There are two ways to run this mode. > Please note that we have not trained the stitching and retargeting modules for the animals model due to several technical issues. _This may be addressed in future updates._ Therefore, we recommend **disabling stitching by setting the `--no_flag_stitching`** option when running the model. Additionally, `paste-back` is also not recommended. #### Install X-Pose We have chosen [X-Pose](https://github.com/IDEA-Research/X-Pose) as the keypoints detector for animals. This relies on `transformers==4.22.0` and `pillow>=10.2.0` (which are already updated in `requirements.txt`) and requires building an OP named `MultiScaleDeformableAttention`. Refer to the [PyTorch installation](https://github.com/KlingTeam/LivePortrait?tab=readme-ov-file#for-linux-or-windows-users) for Linux and Windows users. 
Next, build the OP `MultiScaleDeformableAttention` by running: ```bash cd src/utils/dependencies/XPose/models/UniPose/ops python setup.py build install cd - # this returns to the previous directory ``` To run the model, use the `inference_animals.py` script: ```bash python inference_animals.py -s assets/examples/source/s39.jpg -d assets/examples/driving/wink.pkl --no_flag_stitching --driving_multiplier 1.75 ``` Alternatively, you can use Gradio for a more user-friendly interface. Launch it with: ```bash python app_animals.py # --server_port 8889 --server_name "0.0.0.0" --share ``` > [!WARNING] > [X-Pose](https://github.com/IDEA-Research/X-Pose) is only for Non-commercial Scientific Research Purposes, you should remove and replace it with other detectors if you use it for commercial purposes. ### Updates on Humans mode - **Driving Options**: We have introduced an `expression-friendly` driving option to **reduce head wobbling**, now set as the default. While it may be less effective with large head poses, you can also select the `pose-friendly` option, which is the same as the previous version. This can be set using `--driving_option` or selected in the Gradio interface. Additionally, we added a `--driving_multiplier` option to adjust driving intensity, with a default value of 1, which can also be set in the Gradio interface. - **Retargeting Video in Gradio**: We have implemented a video retargeting feature. You can specify a `target lip-open ratio` to adjust the mouth movement in the source video. For instance, setting it to 0 will close the mouth in the source video 🤐. ### Others - [**Poe supports LivePortrait**](https://poe.com/LivePortrait). Check out the news on [X](https://x.com/poe_platform/status/1816136105781256260). - [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) (1.1K 🌟) now includes MediaPipe as an alternative to InsightFace, ensuring the license remains under MIT and Apache 2.0. 
- [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) features real-time portrait pose/expression editing and animation, and is registered with ComfyUI-Manager. **Below are some screenshots of the new features and improvements:** | ![The Gradio Interface of Animals Mode](../animals-mode-gradio-2024-08-02.jpg) | |:---:| | **The Gradio Interface of Animals Mode** | | ![Driving Options and Multiplier](../driving-option-multiplier-2024-08-02.jpg) | |:---:| | **Driving Options and Multiplier** | | ![The Feature of Retargeting Video](../retargeting-video-2024-08-02.jpg) | |:---:| | **The Feature of Retargeting Video** | ================================================ FILE: assets/docs/changelog/2024-08-05.md ================================================ ## One-click Windows Installer ### Download the installer from HuggingFace ```bash # !pip install -U "huggingface_hub[cli]" huggingface-cli download cleardusk/LivePortrait-Windows LivePortrait-Windows-v20240806.zip --local-dir ./ ``` If you cannot access HuggingFace, you can use [hf-mirror](https://hf-mirror.com/) to download: ```bash # !pip install -U "huggingface_hub[cli]" export HF_ENDPOINT=https://hf-mirror.com huggingface-cli download cleardusk/LivePortrait-Windows LivePortrait-Windows-v20240806.zip --local-dir ./ ``` Alternatively, you can manually download it from the [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) page. Then, simply unzip the package `LivePortrait-Windows-v20240806.zip` and double-click `run_windows_human.bat` for the Humans mode, or `run_windows_animal.bat` for the **Animals mode**. 
================================================ FILE: assets/docs/changelog/2024-08-06.md ================================================ ## Precise Portrait Editing Inspired by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) ([@PowerHouseMan](https://github.com/PowerHouseMan)), we have implemented a version of Precise Portrait Editing in the Gradio interface. With each adjustment of the slider, the edited image updates in real-time. You can click the `🔄 Reset` button to reset all slider parameters. However, the performance may not be as fast as the ComfyUI plugin.

LivePortrait
Precise Portrait Editing in the Gradio Interface

================================================ FILE: assets/docs/changelog/2024-08-19.md ================================================ ## Image Driven and Regional Control

LivePortrait
Image Drives an Image

You can now **use an image as a driving signal** to drive the source image or video! Additionally, we **have refined the driving options to support expressions, pose, lips, eyes, or all** (all is consistent with the previous default method), which we call regional control. The control is becoming more and more precise! 🎯 > Please note that image-based driving or regional control may not perform well in certain cases. Feel free to try different options, and be patient. 😊 > [!Note] > We recognize that the project now offers more options, which have become increasingly complex, but due to our limited team capacity and resources, we haven’t fully documented them yet. We ask for your understanding and will work to improve the documentation over time. Contributions via PRs are welcome! If anyone is considering donating or sponsoring, feel free to leave a message in the GitHub Issues or Discussions. We will set up a payment account to reward the team members or support additional efforts in maintaining the project. 💖 ### CLI Usage It's very simple to use an image as a driving reference. Just set the `-d` argument to the driving image: ```bash python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d30.jpg ``` To change the `animation_region` option, set the `--animation_region` argument to `exp`, `pose`, `lip`, `eyes`, or `all`. For example, to drive only the lip region, run: ```bash # only driving the lip region python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d0.mp4 --animation_region lip ``` ### Gradio Interface

LivePortrait
Image-driven Portrait Animation and Regional Control

### More Detailed Explanation **flag_relative_motion**: When using an image as the driving input, setting `--flag_relative_motion` to true will apply the motion deformation between the driving image and its canonical form. If set to false, the absolute motion of the driving image is used, which may amplify expression driving strength but could also cause identity leakage. This option corresponds to the `relative motion` toggle in the Gradio interface. Additionally, if both source and driving inputs are images, the output will be an image. If the source is a video and the driving input is an image, the output will be a video, with each frame driven by the image's motion. The Gradio interface automatically saves and displays the output in the appropriate format. **animation_region**: This argument offers five options: - `exp`: Only the expression of the driving input influences the source. - `pose`: Only the head pose drives the source. - `lip`: Only lip movement drives the source. - `eyes`: Only eye movement drives the source. - `all`: All motions from the driving input are applied. You can also select these options directly in the Gradio interface. **Editing the Lip Region of the Source Video to a Neutral Expression**: In response to requests for a more neutral lip region in the `Retargeting Video` of the Gradio interface, we've added a `keeping the lip silent` option. When selected, the animated video's lip region will adopt a neutral expression. However, this may cause inter-frame jitter or identity leakage, as it uses a mode similar to absolute driving. Note that the neutral expression may sometimes feature a slightly open mouth. **Others**: When both source and driving inputs are videos, the output motion may be a blend of both, due to the default setting of `--flag_relative_motion`. This option uses relative driving, where the motion offset of the current driving frame relative to the first driving frame is added to the source frame's motion. 
In contrast, `--no_flag_relative_motion` applies the driving frame's motion directly as the final driving motion. For CLI usage, to retain only the driving video's motion in the output, use: ```bash python inference.py --no_flag_relative_motion ``` In the Gradio interface, simply uncheck the relative motion option. Note that absolute driving may cause jitter or identity leakage in the animated video. ================================================ FILE: assets/docs/changelog/2025-01-01.md ================================================ ## 2025/01/01 **We’re thrilled that cats 🐱 are now speaking and singing across the internet!** 🎶 In this update, we’ve improved the [Animals model](https://huggingface.co/KlingTeam/LivePortrait/tree/main/liveportrait_animals/base_models_v1.1) with more data. While you might notice only a slight improvement for cats (if at all 😼), dogs have gotten a slightly better upgrade. For example, the model is now better at recognizing their mouths instead of mistaking them for noses. 🐶
Before vs. After (v1.1)
The new version (v1.1) Animals Model has been updated on [HuggingFace](https://huggingface.co/KlingTeam/LivePortrait/tree/main/liveportrait_animals/base_models_v1.1). The new version is enabled by default. > [!IMPORTANT] > Note: Make sure to update your weights to use the new version. If you prefer to use the original version, simply modify the configuration in [inference_config.py](../../../src/config/inference_config.py#L29) ```python version_animals = "" # old version # version_animals = "_v1.1" # new (v1.1) version ``` ================================================ FILE: assets/docs/directory-structure.md ================================================ ## The directory structure of `pretrained_weights` ```text pretrained_weights ├── insightface │ └── models │ └── buffalo_l │ ├── 2d106det.onnx │ └── det_10g.onnx ├── liveportrait │ ├── base_models │ │ ├── appearance_feature_extractor.pth │ │ ├── motion_extractor.pth │ │ ├── spade_generator.pth │ │ └── warping_module.pth │ ├── landmark.onnx │ └── retargeting_models │ └── stitching_retargeting_module.pth └── liveportrait_animals ├── base_models │ ├── appearance_feature_extractor.pth │ ├── motion_extractor.pth │ ├── spade_generator.pth │ └── warping_module.pth ├── retargeting_models │ └── stitching_retargeting_module.pth └── xpose.pth ``` ================================================ FILE: assets/docs/how-to-install-ffmpeg.md ================================================ ## Install FFmpeg Make sure you have `ffmpeg` and `ffprobe` installed on your system. If you don't have them installed, follow the instructions below. 
> [!Note] > The installation instructions are copied from [SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) 🤗 ### Conda Users ```bash conda install ffmpeg ``` ### Ubuntu/Debian Users ```bash sudo apt install ffmpeg sudo apt install libsox-dev conda install -c conda-forge 'ffmpeg<7' ``` ### Windows Users Download and place [ffmpeg.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffmpeg.exe) and [ffprobe.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffprobe.exe) in the LivePortrait root directory. ### MacOS Users ```bash brew install ffmpeg ``` ================================================ FILE: assets/docs/speed.md ================================================ ### Speed Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`: | Model | Parameters(M) | Model Size(MB) | Inference(ms) | |-----------------------------------|:-------------:|:--------------:|:-------------:| | Appearance Feature Extractor | 0.84 | 3.3 | 0.82 | | Motion Extractor | 28.12 | 108 | 0.84 | | Spade Generator | 55.37 | 212 | 7.59 | | Warping Module | 45.53 | 174 | 5.21 | | Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 | *Note: The values for the Stitching and Retargeting Modules represent the combined parameter counts and total inference time of three sequential MLP networks.* ================================================ FILE: assets/gradio/gradio_description_animate_clear.md ================================================
Step 3: Click the 🚀 Animate button below to generate, or click 🧹 Clear to erase the results
================================================ FILE: assets/gradio/gradio_description_animation.md ================================================ 🔥 To animate the source image or video with the driving video, please follow these steps:
1. In the Animation Options for Source Image or Video section, we recommend enabling the do crop (source) option if faces occupy a small portion of your source image or video.
2. In the Animation Options for Driving Video section, the relative head rotation and smooth strength options only take effect if the source input is a video.
3. Press the 🚀 Animate button and wait for a moment. Your animated video will appear in the result block. This may take a few moments. If the input is a source video, the length of the animated video is the minimum of the length of the source video and the driving video.
4. If you want to upload your own driving video, follow these best practices: - Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by checking `do crop (driving video)`. - Focus on the head area, similar to the example videos. - Minimize shoulder movement. - Make sure the first frame of the driving video is a frontal face with a **neutral expression**.
================================================ FILE: assets/gradio/gradio_description_retargeting.md ================================================

Retargeting and Editing Portraits

Upload a source portrait, and the eyes-open ratio and lip-open ratio will be auto-calculated. Adjust the sliders to see instant edits. Feel free to experiment! 🎨

😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!

================================================ FILE: assets/gradio/gradio_description_retargeting_video.md ================================================

Retargeting Video

Upload a Source Video as Retargeting Input, then drag the sliders and click the 🚗 Retargeting Video button. You can try running it multiple times.
🤐 Set target lip-open ratio to 0 to see what's going on!

================================================ FILE: assets/gradio/gradio_description_upload.md ================================================
Step 1: Upload a Source Image or Video (any aspect ratio) ⬇️
Note: Results are better if the Source Video has the same FPS as the Driving Video.
Step 2: Upload a Driving Video (any aspect ratio) ⬇️
Tips: Focus on the head, minimize shoulder movement, neutral expression in first frame.
================================================ FILE: assets/gradio/gradio_description_upload_animal.md ================================================
Step 1: Upload a Source Animal Image (any aspect ratio) ⬇️
Step 2: Upload a Driving Pickle or Driving Video (any aspect ratio) ⬇️
Tips: Focus on the head, minimize shoulder movement, neutral expression in first frame.
================================================ FILE: assets/gradio/gradio_title.md ================================================

LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control

  Project Page      
================================================ FILE: inference.py ================================================ # coding: utf-8 """ The entrance of humans """ import os import os.path as osp import tyro import subprocess from src.config.argument_config import ArgumentConfig from src.config.inference_config import InferenceConfig from src.config.crop_config import CropConfig from src.live_portrait_pipeline import LivePortraitPipeline def partial_fields(target_class, kwargs): return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) def fast_check_ffmpeg(): try: subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True) return True except: return False def fast_check_args(args: ArgumentConfig): if not osp.exists(args.source): raise FileNotFoundError(f"source info not found: {args.source}") if not osp.exists(args.driving): raise FileNotFoundError(f"driving info not found: {args.driving}") def main(): # set tyro theme tyro.extras.set_accent_color("bright_cyan") args = tyro.cli(ArgumentConfig) ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg") if osp.exists(ffmpeg_dir): os.environ["PATH"] += (os.pathsep + ffmpeg_dir) if not fast_check_ffmpeg(): raise ImportError( "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. 
https://ffmpeg.org/download.html" ) fast_check_args(args) # specify configs for inference inference_cfg = partial_fields(InferenceConfig, args.__dict__) crop_cfg = partial_fields(CropConfig, args.__dict__) live_portrait_pipeline = LivePortraitPipeline( inference_cfg=inference_cfg, crop_cfg=crop_cfg ) # run live_portrait_pipeline.execute(args) if __name__ == "__main__": main() ================================================ FILE: inference_animals.py ================================================ # coding: utf-8 """ The entrance of animal """ import os import os.path as osp import tyro import subprocess from src.config.argument_config import ArgumentConfig from src.config.inference_config import InferenceConfig from src.config.crop_config import CropConfig from src.live_portrait_pipeline_animal import LivePortraitPipelineAnimal def partial_fields(target_class, kwargs): return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) def fast_check_ffmpeg(): try: subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True) return True except: return False def fast_check_args(args: ArgumentConfig): if not osp.exists(args.source): raise FileNotFoundError(f"source info not found: {args.source}") if not osp.exists(args.driving): raise FileNotFoundError(f"driving info not found: {args.driving}") def main(): # set tyro theme tyro.extras.set_accent_color("bright_cyan") args = tyro.cli(ArgumentConfig) ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg") if osp.exists(ffmpeg_dir): os.environ["PATH"] += (os.pathsep + ffmpeg_dir) if not fast_check_ffmpeg(): raise ImportError( "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. 
https://ffmpeg.org/download.html" ) fast_check_args(args) # specify configs for inference inference_cfg = partial_fields(InferenceConfig, args.__dict__) crop_cfg = partial_fields(CropConfig, args.__dict__) live_portrait_pipeline_animal = LivePortraitPipelineAnimal( inference_cfg=inference_cfg, crop_cfg=crop_cfg ) # run live_portrait_pipeline_animal.execute(args) if __name__ == "__main__": main() ================================================ FILE: pretrained_weights/.gitkeep ================================================ ================================================ FILE: readme.md ================================================

LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control

Jianzhu Guo 1*†Dingyun Zhang 1,2*Xiaoqiang Liu 1Zhizhou Zhong 1,3Yuan Zhang 1
Pengfei Wan 1Di Zhang 1
1 Kuaishou Technology  2 University of Science and Technology of China  3 Fudan University 
* Equal contributions Project lead

Windows one-click installer  HuggingFace online demo

arXiv link  project homepage  HF space  Featured by HelloGitHub  GitHub stars

English | 简体中文

LivePortrait showcase GIF

🔥 For more results, visit our homepage 🔥

## 🔥 Updates - **`2025/06/01`**: 🌍 Over the past year, **LivePortrait** has 🚀 become an efficient portrait-animation (humans, cats and dogs) solution adopted by major video platforms—Kuaishou, Douyin, Jianying, WeChat Channels—as well as numerous startups and creators. 🎉 - **`2025/01/01`**: 🐶 We updated a new version of the Animals model with more data, see [**here**](./assets/docs/changelog/2025-01-01.md). - **`2024/10/18`**: ❗ We have updated the versions of the `transformers` and `gradio` libraries to avoid security vulnerabilities. Details [here](https://github.com/KlingTeam/LivePortrait/pull/421/files). - **`2024/08/29`**: 📦 We update the Windows [one-click installer](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240829.zip) and support auto-updates, see [changelog](https://huggingface.co/cleardusk/LivePortrait-Windows#20240829). - **`2024/08/19`**: 🖼️ We support **image driven mode** and **regional control**. For details, see [**here**](./assets/docs/changelog/2024-08-19.md). - **`2024/08/06`**: 🎨 We support **precise portrait editing** in the Gradio interface, inspired by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). See [**here**](./assets/docs/changelog/2024-08-06.md). - **`2024/08/05`**: 📦 Windows users can now download the [one-click installer](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) for Humans mode and **Animals mode** now! For details, see [**here**](./assets/docs/changelog/2024-08-05.md). - **`2024/08/02`**: 😸 We released a version of the **Animals model**, along with several other updates and improvements. Check out the details [**here**](./assets/docs/changelog/2024-08-02.md)! - **`2024/07/25`**: 📦 Windows users can now download the package from [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/tree/main). Simply unzip and double-click `run_windows.bat` to enjoy! 
- **`2024/07/24`**: 🎨 We support pose editing for source portraits in the Gradio interface. We’ve also lowered the default detection threshold to increase recall. [Have fun](assets/docs/changelog/2024-07-24.md)! - **`2024/07/19`**: ✨ We support 🎞️ **portrait video editing (aka v2v)**! More to see [here](assets/docs/changelog/2024-07-19.md). - **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu](https://github.com/jeethu)'s PR [#143](https://github.com/KlingTeam/LivePortrait/pull/143). - **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md). - **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KlingTeam/LivePortrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)! - **`2024/07/04`**: 😊 We released the initial version of the inference code and models. Continuous updates, stay tuned! - **`2024/07/04`**: 🔥 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168). ## Introduction 📖 This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168). We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) 💖. ## Getting Started 🏁 ### 1. Clone the code and prepare the environment 🛠️ > [!Note] > Make sure your system has [`git`](https://git-scm.com/), [`conda`](https://anaconda.org/anaconda/conda), and [`FFmpeg`](https://ffmpeg.org/download.html) installed. For details on FFmpeg installation, see [**how to install FFmpeg**](assets/docs/how-to-install-ffmpeg.md). 
```bash git clone https://github.com/KlingTeam/LivePortrait cd LivePortrait # create env using conda conda create -n LivePortrait python=3.10 conda activate LivePortrait ``` #### For Linux 🐧 or Windows 🪟 Users [X-Pose](https://github.com/IDEA-Research/X-Pose), required by Animals mode, is a dependency that needs to be installed. The step of `Check your CUDA versions` is **optional** if you only want to run Humans mode.
Check your CUDA versions Firstly, check your current CUDA version by: ```bash nvcc -V # example versions: 11.1, 11.8, 12.1, etc. ``` Then, install the corresponding torch version. Here are examples for different CUDA versions. Visit the [PyTorch Official Website](https://pytorch.org/get-started/previous-versions) for installation commands if your CUDA version is not listed: ```bash # for CUDA 11.1 pip install torch==1.10.1+cu111 torchvision==0.11.2 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html # for CUDA 11.8 pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118 # for CUDA 12.1 pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121 # ... ``` **Note**: On Windows systems, some higher versions of CUDA (such as 12.4, 12.6, etc.) may lead to unknown issues. You may consider downgrading CUDA to version 11.8 for stability. See the [downgrade guide](https://github.com/dimitribarbot/sd-webui-live-portrait/blob/main/assets/docs/how-to-install-xpose.md#cuda-toolkit-118) by [@dimitribarbot](https://github.com/dimitribarbot).
Finally, install the remaining dependencies: ```bash pip install -r requirements.txt ``` #### For macOS  with Apple Silicon Users The [X-Pose](https://github.com/IDEA-Research/X-Pose) dependency does not support macOS, so you can skip its installation. While Humans mode works as usual, Animals mode is not supported. Use the provided requirements file for macOS with Apple Silicon: ```bash # for macOS with Apple Silicon users pip install -r requirements_macOS.txt ``` ### 2. Download pretrained weights 📥 The easiest way to download the pretrained weights is from HuggingFace: ```bash # !pip install -U "huggingface_hub[cli]" huggingface-cli download KlingTeam/LivePortrait --local-dir pretrained_weights --exclude "*.git*" "README.md" "docs" ``` If you cannot access to Huggingface, you can use [hf-mirror](https://hf-mirror.com/) to download: ```bash # !pip install -U "huggingface_hub[cli]" export HF_ENDPOINT=https://hf-mirror.com huggingface-cli download KlingTeam/LivePortrait --local-dir pretrained_weights --exclude "*.git*" "README.md" "docs" ``` Alternatively, you can download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). Unzip and place them in `./pretrained_weights`. Ensuring the directory structure is as or contains [**this**](assets/docs/directory-structure.md). ### 3. Inference 🚀 #### Fast hands-on (humans) 👤 ```bash # For Linux and Windows users python inference.py # For macOS users with Apple Silicon (Intel is not tested). NOTE: this maybe 20x slower than RTX 4090 PYTORCH_ENABLE_MPS_FALLBACK=1 python inference.py ``` If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image or video, and generated result.

image

Or, you can change the input by specifying the `-s` and `-d` arguments: ```bash # source input is an image python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 # source input is a video ✨ python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d0.mp4 # see all options python inference.py -h ``` #### Fast hands-on (animals) 🐱🐶 Animals mode is ONLY tested on Linux and Windows with NVIDIA GPUs. You need to build an OP named `MultiScaleDeformableAttention` first (refer to the `Check your CUDA versions` step above if needed), which is used by [X-Pose](https://github.com/IDEA-Research/X-Pose), a general keypoint detection framework. ```bash cd src/utils/dependencies/XPose/models/UniPose/ops python setup.py build install cd - # equal to cd ../../../../../../../ ``` Then ```bash python inference_animals.py -s assets/examples/source/s39.jpg -d assets/examples/driving/wink.pkl --driving_multiplier 1.75 --no_flag_stitching ``` If the script runs successfully, you will get an output mp4 file named `animations/s39--wink_concat.mp4`.

image

#### Driving video auto-cropping 📢📢📢 > [!IMPORTANT] > To use your own driving video, we **recommend**: ⬇️ > - Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by `--flag_crop_driving_video`. > - Focus on the head area, similar to the example videos. > - Minimize shoulder movement. > - Make sure the first frame of the driving video is a frontal face with a **neutral expression**. Below is an auto-cropping example using `--flag_crop_driving_video`: ```bash python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 --flag_crop_driving_video ``` If the auto-cropping results are not satisfactory, you can adjust the scale and offset with the `--scale_crop_driving_video` and `--vy_ratio_crop_driving_video` options, or crop the video manually. #### Motion template making You can also use the auto-generated motion template files ending with `.pkl` to speed up inference and **protect privacy**, for example: ```bash python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl # portrait animation python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d5.pkl # portrait video editing ``` ### 4. Gradio interface 🤗 We also provide a Gradio interface for a better experience; just run: ```bash # For Linux and Windows users (and macOS with Intel??) python app.py # humans mode # For macOS with Apple Silicon users, Intel not supported, this may be 20x slower than RTX 4090 PYTORCH_ENABLE_MPS_FALLBACK=1 python app.py # humans mode ``` We also provide a Gradio interface for animals mode, which is only tested on Linux with NVIDIA GPUs: ```bash python app_animals.py # animals mode 🐱🐶 ``` You can specify the `--server_port`, `--share`, `--server_name` arguments to satisfy your needs! 🚀 We also provide an acceleration option `--flag_do_torch_compile`. The first-time inference triggers an optimization process (about one minute), making subsequent inferences 20-30% faster.
Performance gains may vary with different CUDA versions. ```bash # enable torch.compile for faster inference python app.py --flag_do_torch_compile ``` **Note**: This method is not supported on Windows and macOS. **Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KlingTeam/LivePortrait) 🤗** ### 5. Inference speed evaluation 🚀🚀🚀 We have also provided a script to evaluate the inference speed of each module: ```bash # For NVIDIA GPU python speed.py ``` The results are [**here**](./assets/docs/speed.md). ## Community Resources 🤗 Discover the invaluable resources contributed by our community to enhance your LivePortrait experience. ### Community-developed Projects | Repo | Description | Author / Links | |------|------|--------| | [**ditto-talkinghead**](https://github.com/antgroup/ditto-talkinghead) | Real-time audio-driven talking head. | [ArXiv](https://arxiv.org/abs/2411.19509), [Homepage](https://digital-avatar.github.io/ai/Ditto/) | | [**FasterLivePortrait**](https://github.com/warmshao/FasterLivePortrait) | Faster real-time version using TensorRT. | [@warmshao](https://github.com/warmshao) | | [**AdvancedLivePortrait-WebUI**](https://github.com/jhj0517/AdvancedLivePortrait-WebUI) | Dedicated Gradio-based WebUI started from [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait). | [@jhj0517](https://github.com/jhj0517) | | [**FacePoke**](https://github.com/jbilcke-hf/FacePoke) | A real-time head transformation app, controlled by your mouse! | [@jbilcke-hf](https://github.com/jbilcke-hf) | | [**FaceFusion**](https://github.com/facefusion/facefusion) | FaceFusion 3.0 integrates LivePortrait as `expression_restorer` and `face_editor` processors. | [@henryruhs](https://github.com/henryruhs) | | [**sd-webui-live-portrait**](https://github.com/dimitribarbot/sd-webui-live-portrait) | WebUI extension of LivePortrait, adding a tab to the original Stable Diffusion WebUI to benefit from LivePortrait features.
| [@dimitribarbot](https://github.com/dimitribarbot) | | [**ComfyUI-LivePortraitKJ**](https://github.com/kijai/ComfyUI-LivePortraitKJ) | A ComfyUI node to use LivePortrait, with MediaPipe as an alternative to InsightFace. | [@kijai](https://github.com/kijai) | | [**ComfyUI-AdvancedLivePortrait**](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) | A faster ComfyUI node with real-time preview that has inspired many other community-developed tools and projects. | [@PowerHouseMan](https://github.com/PowerHouseMan) | | [**comfyui-liveportrait**](https://github.com/shadowcz007/comfyui-liveportrait) | A ComfyUI node to use LivePortrait, supporting multi-faces, expression interpolation etc., with a [tutorial](https://www.bilibili.com/video/BV1JW421R7sP). | [@shadowcz007](https://github.com/shadowcz007) | ### Playgrounds, 🤗 HuggingFace Spaces and Others - [FacePoke Space](https://huggingface.co/spaces/jbilcke-hf/FacePoke) - [Expression Editor Space](https://huggingface.co/spaces/fffiloni/expression-editor) - [Expression Editor Replicate](https://replicate.com/fofr/expression-editor) - [Face Control Realtime Demo](https://fal.ai/demos/face-control) on FAL - [Replicate Playground](https://replicate.com/fofr/live-portrait) - Nuke can use LivePortrait through a ComfyUI node, details [here](https://x.com/bilawalsidhu/status/1837349806475276338) - LivePortrait lives on [Poe](https://poe.com/LivePortrait) ### Video Tutorials - [Workflow of LivePortrait Video to Video](https://youtu.be/xfzK_6cTs58?si=aYjgypeJBkhc46VL) by [@curiousrefuge](https://www.youtube.com/@curiousrefuge) - [Google Colab tutorial](https://youtu.be/59Y9ePAXTp0?si=KzEWhklBlporW7D8) by [@Planet Ai](https://www.youtube.com/@planetai217) - [Paper reading](https://youtu.be/fD0P6UWSu8I?si=Vn5wxUa8qSu1jv4l) by [@TwoMinutePapers](https://www.youtube.com/@TwoMinutePapers) - [ComfyUI Advanced LivePortrait](https://youtu.be/q0Vf-ZZsbzI?si=nbs3npleH-dVCt28) by [TutoView](https://www.youtube.com/@TutoView) - 
[LivePortrait exploration](https://www.youtube.com/watch?v=vsvlbTEqgXQ) and [A deep dive into LivePortrait](https://youtu.be/cucaEEDYmsw?si=AtPaDWc5G-a4E8dD) by [TheoreticallyMedia](https://www.youtube.com/@TheoreticallyMedia) - [LivePortrait hands-on tutorial](https://www.youtube.com/watch?v=uyjSTAOY7yI) by [@AI Search](https://www.youtube.com/@theAIsearch) - [ComfyUI tutorial](https://www.youtube.com/watch?v=8-IcDDmiUMM) by [@Sebastian Kamph](https://www.youtube.com/@sebastiankamph) - A [tutorial](https://www.bilibili.com/video/BV1cf421i7Ly) on Bilibili And so MANY amazing contributions from our community, too many to list 💖 ## Acknowledgements 💐 We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) and [X-Pose](https://github.com/IDEA-Research/X-Pose) repositories, for their open research and contributions. ## Ethics Considerations 🛡️ Portrait animation technologies come with social risks, particularly the potential for misuse in creating deepfakes. To mitigate these risks, it’s crucial to follow ethical guidelines and adopt responsible usage practices. At present, the synthesized results contain visual artifacts that may help in detecting deepfakes. Please note that we do not assume any legal responsibility for the use of the results generated by this project.
## Citation 💖 If you find LivePortrait useful for your project or research, welcome to 🌟 this repo and cite our work using the following BibTeX: ```bibtex @article{guo2024liveportrait, title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control}, author = {Guo, Jianzhu and Zhang, Dingyun and Liu, Xiaoqiang and Zhong, Zhizhou and Zhang, Yuan and Wan, Pengfei and Zhang, Di}, journal = {arXiv preprint arXiv:2407.03168}, year = {2024} } ``` *Long live in arXiv.* ## Contact 📧 [**Jianzhu Guo (郭建珠)**](https://guojianzhu.com); **guojianzhu1994@gmail.com** ## Star History 🌟
Click to view Star chart

Star History Chart

================================================ FILE: readme_zh_cn.md ================================================

LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control

郭建珠 1*†张丁芸 1,2*刘晓强 1钟智舟 1,3张渊 1万鹏飞 1张迪 1
1 快手科技  2 中国科学技术大学  3 复旦大学 
* Equal contributions Project lead

Windows 一键安装包  HuggingFace 在线体验

arXiv link  project homepage  HF space  Featured by HelloGitHub  GitHub stars

English | 简体中文

LivePortrait 效果展示 GIF

🔥 更多效果,请访问我们的 主页 🔥

## 🔥 更新日志 - **`2025/06/01`**:🌍 过去一年里,LivePortrait 🚀 已成为高效的人像与宠物(猫狗)动画解决方案,被快手、抖音、剪映、视频号等主流视频平台,以及众多初创公司和创作者所采用。🎉 - **`2025/01/01`**:🐶 我们更新了一版动物模型(使用了更多动物数据),具体查看[**这里**](./assets/docs/changelog/2025-01-01.md). - **`2024/10/18`**:❗ 我们更新了`transformers`,`gradio`库的版本避免安全漏洞,具体查看[这里](https://github.com/KlingTeam/LivePortrait/pull/421/files). - **`2024/08/29`**:📦 我们更新了Windows[一键安装程序](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240829.zip)并支持自动更新, 详情建[这里](https://huggingface.co/cleardusk/LivePortrait-Windows#20240829)。 - **`2024/08/19`**:🖼️ 我们支持了**图像驱动模式**和**区域控制**。详情请见[**这里**](./assets/docs/changelog/2024-08-19.md)。 - **`2024/08/06`**:🎨 我们在Gradio界面支持**精确的人像编辑**, 受到[ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait)启发。详见[**这里**](./assets/docs/changelog/2024-08-06.md)。 - **`2024/08/05`**:📦 Windows用户现在可以下载[一键安装程序](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip),支持**人类模式**和**动物模式**!详情见[**这里**](./assets/docs/changelog/2024-08-05.md)。 - **`2024/08/02`**:😸 我们发布了**动物模型**版本,以及其他一些更新和改进。查看详情[**这里**](./assets/docs/changelog/2024-08-02.md)! - **`2024/07/25`**:📦 Windows用户现在可以从 [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/tree/main) 或 [百度云](https://pan.baidu.com/s/1FWsWqKe0eNfXrwjEhhCqlw?pwd=86q2) 下载软件包。解压并双击`run_windows.bat`即可享受! - **`2024/07/24`**:🎨 我们在Gradio界面支持源人像的姿势编辑。我们还降低了默认检测阈值以增加召回率。[玩得开心](assets/docs/changelog/2024-07-24.md)! 
- **`2024/07/19`**:✨ 我们支持🎞️ **人像视频编辑(aka v2v)**!更多信息见[**这里**](assets/docs/changelog/2024-07-19.md)。 - **`2024/07/17`**:🍎 我们支持macOS搭载Apple Silicon,修改来自 [jeethu](https://github.com/jeethu) 的PR [#143](https://github.com/KlingTeam/LivePortrait/pull/143) 。 - **`2024/07/10`**:💪我们支持音频和视频拼接、驱动视频自动裁剪以及制作模板以保护隐私。更多信息见[这里](assets/docs/changelog/2024-07-10.md)。 - **`2024/07/09`**:🤗 我们发布了[HuggingFace Space](https://huggingface.co/spaces/KlingTeam/LivePortrait),感谢HF团队和[Gradio](https://github.com/gradio-app/gradio)! - **`2024/07/04`**:😊 我们发布了初始版本的推理代码和模型。持续更新,敬请关注! - **`2024/07/04`**:🔥 我们发布了[主页](https://liveportrait.github.io)和在[arXiv](https://arxiv.org/pdf/2407.03168)上的技术报告。 ## 介绍 📖 此仓库名为**LivePortrait**,包含我们论文([LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168))的官方PyTorch实现。 我们正在积极更新和改进此仓库。如果您发现任何错误或有建议,欢迎提出问题或提交合并请求💖。 ## 上手指南 🏁 ### 1. 克隆代码和安装运行环境 🛠️ > [!Note] > 确保您的系统已安装[`git`](https://git-scm.com/)、[`conda`](https://anaconda.org/anaconda/conda)和[`FFmpeg`](https://ffmpeg.org/download.html)。有关FFmpeg安装的详细信息,见[**如何安装FFmpeg**](assets/docs/how-to-install-ffmpeg.md)。 ```bash git clone https://github.com/KlingTeam/LivePortrait cd LivePortrait # 使用conda创建环境 conda create -n LivePortrait python=3.10 conda activate LivePortrait ``` #### 对于Linux或Windows用户 [X-Pose](https://github.com/IDEA-Research/X-Pose)需要您的`torch`版本与CUDA版本兼容。 首先,通过以下命令检查您当前的CUDA版本: ```bash nvcc -V # example versions: 11.1, 11.8, 12.1, etc. 
``` 然后,安装相应版本的torch。以下是不同CUDA版本的示例。如果您的CUDA版本未列出,请访问[PyTorch官方网站](https://pytorch.org/get-started/previous-versions)获取安装命令: ```bash # for CUDA 11.1 pip install torch==1.10.1+cu111 torchvision==0.11.2 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html # for CUDA 11.8 pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118 # for CUDA 12.1 pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121 # ... ``` **注意**:在Windows系统上,一些过高版本的CUDA(12.4、12.6等)可能会导致未知的问题,您可以考虑降低您的CUDA版本到11.8,这是我们测试的一个较为稳定的版本。降级方法可以参考 [@dimitribarbot](https://github.com/dimitribarbot) 提供的[文档](https://github.com/dimitribarbot/sd-webui-live-portrait/blob/main/assets/docs/how-to-install-xpose.md#cuda-toolkit-118). 最后,安装其余依赖项: ```bash pip install -r requirements.txt ``` #### 对于搭载Apple Silicon的macOS用户 [X-Pose](https://github.com/IDEA-Research/X-Pose)依赖项不支持macOS,因此您可以跳过其安装。人类模式照常工作,但不支持动物模式。使用为搭载Apple Silicon的macOS提供的requirements文件: ```bash # 对于搭载Apple Silicon的macOS用户 pip install -r requirements_macOS.txt ``` ### 2. 下载预训练权重(Pretrained weights) 📥 从HuggingFace下载预训练权重的最简单方法是: ```bash # !pip install -U "huggingface_hub[cli]" huggingface-cli download KlingTeam/LivePortrait --local-dir pretrained_weights --exclude "*.git*" "README.md" "docs" ``` 若您不能访问HuggingFace平台,你可以访问其镜像网站[hf-mirror](https://hf-mirror.com/)进行下载操作: ```bash # !pip install -U "huggingface_hub[cli]" export HF_ENDPOINT=https://hf-mirror.com huggingface-cli download KlingTeam/LivePortrait --local-dir pretrained_weights --exclude "*.git*" "README.md" "docs" ``` 或者,您可以从[Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib)或[百度云](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn)(进行中)下载所有预训练权重。解压并将它们放置在`./pretrained_weights`目录下。 确保目录结构如所示包含[**本仓库该路径**](assets/docs/directory-structure.md)其中展示的内容。 ### 3. 
推理 🚀 #### 快速上手(人类模型)👤 ```bash # 对于Linux和Windows用户 python inference.py # 对于搭载Apple Silicon的macOS用户(Intel未测试)。注意:这可能比RTX 4090慢20倍 PYTORCH_ENABLE_MPS_FALLBACK=1 python inference.py ``` 如果脚本成功运行,您将得到一个名为`animations/s6--d0_concat.mp4`的输出mp4文件。此文件包含以下结果:驱动视频、输入图像或视频以及生成结果。

image

或者,您可以通过指定`-s`和`-d`参数来更改输入: ```bash # 源输入是图像 python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 # 源输入是视频 ✨ python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d0.mp4 # 更多选项请见 python inference.py -h ``` #### 快速上手(动物模型) 🐱🐶 动物模式仅在Linux和Windows上经过测试,并且需要NVIDIA GPU。 您需要首先构建一个名为`MultiScaleDeformableAttention`的OP,该OP由[X-Pose](https://github.com/IDEA-Research/X-Pose)使用,这是一个通用的关键点检测框架。 ```bash cd src/utils/dependencies/XPose/models/UniPose/ops python setup.py build install cd - # 等同于 cd ../../../../../../../ ``` 然后执行 ```bash python inference_animals.py -s assets/examples/source/s39.jpg -d assets/examples/driving/wink.pkl --driving_multiplier 1.75 --no_flag_stitching ``` 如果脚本成功运行,您将得到一个名为`animations/s39--wink_concat.mp4`的输出mp4文件。

image

#### 驱动视频自动裁剪 📢📢📢 > [!IMPORTANT] > 使用您自己的驱动视频时,我们**推荐**: ⬇️ > > - 将其裁剪为**1:1**的宽高比(例如,512x512或256x256像素),或通过`--flag_crop_driving_video`启用自动裁剪。 > - 专注于头部区域,类似于示例视频。 > - 最小化肩部运动。 > - 确保驱动视频的第一帧是具有**中性表情**的正面面部。 以下是通过`--flag_crop_driving_video`自动裁剪的示例: ```bash python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 --flag_crop_driving_video ``` 如果自动裁剪的结果不理想,您可以修改`--scale_crop_driving_video`、`--vy_ratio_crop_driving_video`选项来调整比例和偏移,或者手动进行调整。 #### 动作模板制作 您也可以使用以`.pkl`结尾的自动生成的动作模板文件来加快推理速度,并**保护隐私**,例如: ```bash python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl # 人像动画 python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d5.pkl # 人像视频编辑 ``` ### 4. Gradio 界面 🤗 我们还提供了Gradio界面 ,以获得更好的体验,只需运行: ```bash # 对于Linux和Windows用户(以及搭载Intel的macOS??) python app.py # 人类模型模式 # 对于搭载Apple Silicon的macOS用户,不支持Intel,这可能比RTX 4090慢20倍 PYTORCH_ENABLE_MPS_FALLBACK=1 python app.py # 人类模型模式 ``` 我们还为动物模式提供了Gradio界面,这仅在Linux上经过NVIDIA GPU测试: ```bash python app_animals.py # animals mode 🐱🐶 ``` 您可以指定`--server_port`、`--share`、`--server_name`参数以满足您的需求! 🚀我们还提供了一个加速选项`--flag_do_torch_compile`。第一次推理触发优化过程(约一分钟),使后续推理速度提高20-30%。不同CUDA版本的性能提升可能有所不同。 ```bash # 启用torch.compile以进行更快的推理 python app.py --flag_do_torch_compile ``` **注意**:此方法在Windows和macOS上不受支持。 **或者,在[HuggingFace](https://huggingface.co/spaces/KlingTeam/LivePortrait)上轻松尝试**🤗。 ### 5. 推理速度预估 🚀🚀🚀 我们还提供了一个脚本来评估每个模块的推理速度: ```bash # 对于NVIDIA GPU python speed.py ``` 结果在[**本仓库该文件展示**](./assets/docs/speed.md). 
## 社区资源 🤗 ### 社区项目 | 仓库 | 描述 | 作者 / 链接 | |------|------|--------| | [**ditto-talkinghead**](https://github.com/antgroup/ditto-talkinghead) | 实时音频驱动。 | [论文](https://arxiv.org/abs/2411.19509), [主页](https://digital-avatar.github.io/ai/Ditto/) | | [**FasterLivePortrait**](https://github.com/warmshao/FasterLivePortrait) | 基于TensorRT加速更快的实时版本。 | [@warmshao](https://github.com/warmshao) | | [**AdvancedLivePortrait-WebUI**](https://github.com/jhj0517/AdvancedLivePortrait-WebUI) | Dedicated gradio based WebUI started from [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) | [@jhj0517](https://github.com/jhj0517) | | [**FacePoke**](https://github.com/jbilcke-hf/FacePoke) | 一个实时的头部姿态表情控制应用,通过鼠标控制! | [@jbilcke-hf](https://github.com/jbilcke-hf) | | [**FaceFusion**](https://github.com/facefusion/facefusion) | FaceFusion 3.0 集成了 LivePortrait 作为 `expression_restorer` 和 `face_editor` 处理器。 | [@henryruhs](https://github.com/henryruhs) | | [**sd-webui-live-portrait**](https://github.com/dimitribarbot/sd-webui-live-portrait) | LivePortrait 的 WebUI 扩展,在原版 Stable Diffusion WebUI 中添加了一个标签以使用 LivePortrait 的功能。 | [@dimitribarbot](https://github.com/dimitribarbot) | | [**ComfyUI-LivePortraitKJ**](https://github.com/kijai/ComfyUI-LivePortraitKJ) | 一个用于 LivePortrait 的 ComfyUI 节点,使用 MediaPipe 作为 Insightface 的替代方案。 | [@kijai](https://github.com/kijai) | | [**ComfyUI-AdvancedLivePortrait**](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) | 一个更快的 ComfyUI 节点,具有实时预览功能,启发了许多社区开发的工具和项目。 | [@PowerHouseMan](https://github.com/PowerHouseMan) | | [**comfyui-liveportrait**](https://github.com/shadowcz007/comfyui-liveportrait) | 一个用于 LivePortrait 的 ComfyUI 节点,支持多面部、表情插值等功能,并有[教程](https://www.bilibili.com/video/BV1JW421R7sP)。 | [@shadowcz007](https://github.com/shadowcz007) | ### Playgrounds, 🤗 HuggingFace Spaces 以及其它 - [FacePoke Space](https://huggingface.co/spaces/jbilcke-hf/FacePoke) - [Expression Editor 
Space](https://huggingface.co/spaces/fffiloni/expression-editor) - [Expression Editor Replicate](https://replicate.com/fofr/expression-editor) - [Face Control Realtime Demo](https://fal.ai/demos/face-control) on FAL - [Replicate Playground](https://replicate.com/fofr/live-portrait) - Nuke 可以通过 ComfyUI 节点使用 LivePortrait,详情见[这里](https://x.com/bilawalsidhu/status/1837349806475276338) - LivePortrait 在 [Poe](https://poe.com/LivePortrait) 上运行 ### 视频教程 - [LivePortrait 视频转视频的工作流程](https://youtu.be/xfzK_6cTs58?si=aYjgypeJBkhc46VL) 由 [@curiousrefuge](https://www.youtube.com/@curiousrefuge) 制作 - [Google Colab 教程](https://youtu.be/59Y9ePAXTp0?si=KzEWhklBlporW7D8) 由 [@Planet Ai](https://www.youtube.com/@planetai217) 制作 - [论文解读](https://youtu.be/fD0P6UWSu8I?si=Vn5wxUa8qSu1jv4l) 由 [@TwoMinutePapers](https://www.youtube.com/@TwoMinutePapers) 制作 - [ComfyUI 高级 LivePortrait 教程](https://youtu.be/q0Vf-ZZsbzI?si=nbs3npleH-dVCt28) 由 [TutoView](https://www.youtube.com/@TutoView) 制作 - [LivePortrait 探索](https://www.youtube.com/watch?v=vsvlbTEqgXQ) 和 [LivePortrait 深入探讨](https://youtu.be/cucaEEDYmsw?si=AtPaDWc5G-a4E8dD) 由 [TheoreticallyMedia](https://www.youtube.com/@TheoreticallyMedia) 制作 - [LivePortrait 实战教程](https://www.youtube.com/watch?v=uyjSTAOY7yI) 由 [@AI Search](https://www.youtube.com/@theAIsearch) 制作 - [ComfyUI 教程](https://www.youtube.com/watch?v=8-IcDDmiUMM) 由 [@Sebastian Kamph](https://www.youtube.com/@sebastiankamph) 制作 - B 站上的[教程](https://www.bilibili.com/video/BV1cf421i7Ly) 还有来自社区的无数令人惊叹的贡献,未能一一列举 💖 ## 致谢 💐 我们要感谢[FOMM](https://github.com/AliaksandrSiarohin/first-order-model)、[Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis)、[SPADE](https://github.com/NVlabs/SPADE)、[InsightFace](https://github.com/deepinsight/insightface)和[X-Pose](https://github.com/IDEA-Research/X-Pose)仓库的贡献者,感谢他们的开放研究和贡献。 ## 道德考量 🛡️ 
肖像动画技术伴随着社会风险,特别是在创建深度伪造(deepfakes)时可能被滥用。为了减轻这些风险,遵循道德指南并采取负责任的使用实践至关重要。目前,生成的结果包含一些视觉伪影,这些伪影可能有助于检测深度伪造。请注意,我们不对本项目生成的结果的使用承担任何法律责任。 ## 引用 💖 如果您发现LivePortrait对您的研究有用,欢迎引用我们的工作,使用以下BibTeX: ```bibtex @article{guo2024liveportrait, title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control}, author = {Guo, Jianzhu and Zhang, Dingyun and Liu, Xiaoqiang and Zhong, Zhizhou and Zhang, Yuan and Wan, Pengfei and Zhang, Di}, journal = {arXiv preprint arXiv:2407.03168}, year = {2024} } ``` ## 联系方式 📧 [**Jianzhu Guo (郭建珠)**](https://guojianzhu.com); **guojianzhu1994@gmail.com**; ## Star History 🌟
点击展开查看项目 Star 曲线

Star History Chart

================================================ FILE: requirements.txt ================================================ -r requirements_base.txt onnxruntime-gpu==1.18.0 transformers==4.38.0 ================================================ FILE: requirements_base.txt ================================================ numpy==1.26.4 pyyaml==6.0.1 opencv-python==4.10.0.84 scipy==1.13.1 imageio==2.34.2 lmdb==1.4.1 tqdm==4.66.4 rich==13.7.1 ffmpeg-python==0.2.0 onnx==1.16.1 scikit-image==0.24.0 albumentations==1.4.10 matplotlib==3.9.0 imageio-ffmpeg==0.5.1 tyro==0.8.5 gradio==5.1.0 pykalman==0.9.7 pillow>=10.2.0 ================================================ FILE: requirements_macOS.txt ================================================ -r requirements_base.txt --extra-index-url https://download.pytorch.org/whl/cpu torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 onnxruntime-silicon==1.16.3 ================================================ FILE: speed.py ================================================ # coding: utf-8 """ Benchmark the inference speed of each module in LivePortrait. 
TODO: heavy GPT style, need to refactor """ import torch torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution import yaml import time import numpy as np from src.utils.helper import load_model, concat_feat from src.config.inference_config import InferenceConfig def initialize_inputs(batch_size=1, device_id=0): """ Generate random input tensors and move them to GPU """ feature_3d = torch.randn(batch_size, 32, 16, 64, 64).to(device_id).half() kp_source = torch.randn(batch_size, 21, 3).to(device_id).half() kp_driving = torch.randn(batch_size, 21, 3).to(device_id).half() source_image = torch.randn(batch_size, 3, 256, 256).to(device_id).half() generator_input = torch.randn(batch_size, 256, 64, 64).to(device_id).half() eye_close_ratio = torch.randn(batch_size, 3).to(device_id).half() lip_close_ratio = torch.randn(batch_size, 2).to(device_id).half() feat_stitching = concat_feat(kp_source, kp_driving).half() feat_eye = concat_feat(kp_source, eye_close_ratio).half() feat_lip = concat_feat(kp_source, lip_close_ratio).half() inputs = { 'feature_3d': feature_3d, 'kp_source': kp_source, 'kp_driving': kp_driving, 'source_image': source_image, 'generator_input': generator_input, 'feat_stitching': feat_stitching, 'feat_eye': feat_eye, 'feat_lip': feat_lip } return inputs def load_and_compile_models(cfg, model_config): """ Load and compile models for inference """ appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor') motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor') warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module') spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator') stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module') models_with_params = [ ('Appearance Feature Extractor', 
appearance_feature_extractor), ('Motion Extractor', motion_extractor), ('Warping Network', warping_module), ('SPADE Decoder', spade_generator) ] compiled_models = {} for name, model in models_with_params: model = model.half() model = torch.compile(model, mode='max-autotune') # Optimize for inference model.eval() # Switch to evaluation mode compiled_models[name] = model retargeting_models = ['stitching', 'eye', 'lip'] for retarget in retargeting_models: module = stitching_retargeting_module[retarget].half() module = torch.compile(module, mode='max-autotune') # Optimize for inference module.eval() # Switch to evaluation mode stitching_retargeting_module[retarget] = module return compiled_models, stitching_retargeting_module def warm_up_models(compiled_models, stitching_retargeting_module, inputs): """ Warm up models to prepare them for benchmarking """ print("Warm up start!") with torch.no_grad(): for _ in range(10): compiled_models['Appearance Feature Extractor'](inputs['source_image']) compiled_models['Motion Extractor'](inputs['source_image']) compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source']) compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required stitching_retargeting_module['stitching'](inputs['feat_stitching']) stitching_retargeting_module['eye'](inputs['feat_eye']) stitching_retargeting_module['lip'](inputs['feat_lip']) print("Warm up end!") def measure_inference_times(compiled_models, stitching_retargeting_module, inputs): """ Measure inference times for each model """ times = {name: [] for name in compiled_models.keys()} times['Stitching and Retargeting Modules'] = [] overall_times = [] with torch.no_grad(): for _ in range(100): torch.cuda.synchronize() overall_start = time.time() start = time.time() compiled_models['Appearance Feature Extractor'](inputs['source_image']) torch.cuda.synchronize() times['Appearance Feature Extractor'].append(time.time() - start) start = 
time.time() compiled_models['Motion Extractor'](inputs['source_image']) torch.cuda.synchronize() times['Motion Extractor'].append(time.time() - start) start = time.time() compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source']) torch.cuda.synchronize() times['Warping Network'].append(time.time() - start) start = time.time() compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required torch.cuda.synchronize() times['SPADE Decoder'].append(time.time() - start) start = time.time() stitching_retargeting_module['stitching'](inputs['feat_stitching']) stitching_retargeting_module['eye'](inputs['feat_eye']) stitching_retargeting_module['lip'](inputs['feat_lip']) torch.cuda.synchronize() times['Stitching and Retargeting Modules'].append(time.time() - start) overall_times.append(time.time() - overall_start) return times, overall_times def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times): """ Print benchmark results with average and standard deviation of inference times """ average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()} std_times = {name: np.std(times[name]) * 1000 for name in times.keys()} for name, model in compiled_models.items(): num_params = sum(p.numel() for p in model.parameters()) num_params_in_millions = num_params / 1e6 print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M") for index, retarget in enumerate(retargeting_models): num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters()) num_params_in_millions = num_params / 1e6 print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M") for name, avg_time in average_times.items(): std_time = std_times[name] print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)") def main(): """ Main function to benchmark 
speed and model parameters """ # Load configuration cfg = InferenceConfig() model_config_path = cfg.models_config with open(model_config_path, 'r') as file: model_config = yaml.safe_load(file) # Sample input tensors inputs = initialize_inputs(device_id = cfg.device_id) # Load and compile models compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config) # Warm up models warm_up_models(compiled_models, stitching_retargeting_module, inputs) # Measure inference times times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs) # Print benchmark results print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times) if __name__ == "__main__": main() ================================================ FILE: src/config/__init__.py ================================================ ================================================ FILE: src/config/argument_config.py ================================================ # coding: utf-8 """ All configs for user """ from dataclasses import dataclass import tyro from typing_extensions import Annotated from typing import Optional, Literal from .base_config import PrintableConfig, make_abs_path @dataclass(repr=False) # use repr from PrintableConfig class ArgumentConfig(PrintableConfig): ########## input arguments ########## source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait (human/animal) or video (human) driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format) output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video ########## inference arguments ########## flag_use_half_precision: bool = True # whether to use half precision (FP16). 
If black boxes appear, it might be due to GPU incompatibility; set to False. flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video device_id: int = 0 # gpu device id flag_force_cpu: bool = False # force cpu inference, WIP! flag_normalize_lip: bool = False # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering flag_eye_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame flag_lip_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large or the source image is an animal flag_relative_motion: bool = True # whether to use relative motion flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space driving_option: Literal["expression-friendly", "pose-friendly"] = "expression-friendly" # "expression-friendly" or "pose-friendly"; "expression-friendly" would adapt the driving motion with the global multiplier, and could be used when the source is a human image driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly" driving_smooth_observation_variance: float = 3e-7 # smooth strength 
scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose, "all" means all regions ########## source crop arguments ########## det_thresh: float = 0.15 # detection threshold scale: float = 2.3 # the ratio of face area is smaller if scale is larger vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True source_max_dim: int = 1280 # the max dim of height and width of source image or video, you can change it to a larger number, e.g., 1920 source_division: int = 2 # make sure the height and width of source image or video can be divided by this number ########## driving crop arguments ########## scale_crop_driving_video: float = 2.2 # scale factor for cropping driving video vx_ratio_crop_driving_video: float = 0. 
# adjust x offset vy_ratio_crop_driving_video: float = -0.1 # adjust y offset ########## gradio arguments ########## server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server share: bool = False # whether to share the server to public server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation gradio_temp_dir: Optional[str] = None # directory to save gradio temp files ================================================ FILE: src/config/base_config.py ================================================ # coding: utf-8 """ pretty printing class """ from __future__ import annotations import os.path as osp from typing import Tuple def make_abs_path(fn): return osp.join(osp.dirname(osp.realpath(__file__)), fn) class PrintableConfig: # pylint: disable=too-few-public-methods """Printable Config defining str function""" def __repr__(self): lines = [self.__class__.__name__ + ":"] for key, val in vars(self).items(): if isinstance(val, Tuple): flattened_val = "[" for item in val: flattened_val += str(item) + "\n" flattened_val = flattened_val.rstrip("\n") val = flattened_val + "]" lines += f"{key}: {str(val)}".split("\n") return "\n ".join(lines) ================================================ FILE: src/config/crop_config.py ================================================ # coding: utf-8 """ parameters used for crop faces """ from dataclasses import dataclass from .base_config import PrintableConfig, make_abs_path @dataclass(repr=False) # use repr from PrintableConfig class CropConfig(PrintableConfig): insightface_root: str = make_abs_path("../../pretrained_weights/insightface") landmark_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait/landmark.onnx") xpose_config_file_path: str = make_abs_path("../utils/dependencies/XPose/config_model/UniPose_SwinT.py") xpose_embedding_cache_path: str =
make_abs_path('../utils/resources/clip_embedding') xpose_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait_animals/xpose.pth") device_id: int = 0 # gpu device id flag_force_cpu: bool = False # force cpu inference, WIP det_thresh: float = 0.1 # detection threshold ########## source image or video cropping option ########## dsize: int = 512 # crop size scale: float = 2.3 # scale factor vx_ratio: float = 0 # vx ratio vy_ratio: float = -0.125 # vy ratio +up, -down max_face_num: int = 0 # max face number, 0 means no limit flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True animal_face_type: str = "animal_face_9" # animal_face_68 -> 68 landmark points, animal_face_9 -> 9 landmarks ########## driving video auto cropping option ########## scale_crop_driving_video: float = 2.2 # 2.0 # scale factor for cropping driving video vx_ratio_crop_driving_video: float = 0.0 # adjust x offset vy_ratio_crop_driving_video: float = -0.1 # adjust y offset direction: str = "large-small" # direction of cropping ================================================ FILE: src/config/inference_config.py ================================================ # coding: utf-8 """ config dataclass used for inference """ import cv2 from numpy import ndarray import pickle as pkl from dataclasses import dataclass, field from typing import Literal, Tuple from .base_config import PrintableConfig, make_abs_path def load_lip_array(): with open(make_abs_path('../utils/resources/lip_array.pkl'), 'rb') as f: return pkl.load(f) @dataclass(repr=False) # use repr from PrintableConfig class InferenceConfig(PrintableConfig): # HUMAN MODEL CONFIG, NOT EXPORTED PARAMS models_config: str = make_abs_path('./models.yaml') # portrait animation config checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint of F checkpoint_M: str =
make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint of M checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint of G checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint of W checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint of S and R_eyes, R_lip # ANIMAL MODEL CONFIG, NOT EXPORTED PARAMS # version_animals = "" # old version version_animals = "_v1.1" # new (v1.1) version checkpoint_F_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/appearance_feature_extractor.pth') # path to checkpoint of F checkpoint_M_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/motion_extractor.pth') # path to checkpoint of M checkpoint_G_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/spade_generator.pth') # path to checkpoint of G checkpoint_W_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/warping_module.pth') # path to checkpoint of W checkpoint_S_animal: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint of S and R_eyes, R_lip, NOTE: use human temporarily!
# EXPORTED PARAMS flag_use_half_precision: bool = True flag_crop_driving_video: bool = False device_id: int = 0 flag_normalize_lip: bool = True flag_source_video_eye_retargeting: bool = False flag_eye_retargeting: bool = False flag_lip_retargeting: bool = False flag_stitching: bool = True flag_relative_motion: bool = True flag_pasteback: bool = True flag_do_crop: bool = True flag_do_rot: bool = True flag_force_cpu: bool = False flag_do_torch_compile: bool = False driving_option: str = "pose-friendly" # "expression-friendly" or "pose-friendly" driving_multiplier: float = 1.0 driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy source_max_dim: int = 1280 # the max dim of height and width of source image or video source_division: int = 2 # make sure the height and width of source image or video can be divided by this number animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose # NOT EXPORTED PARAMS lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video anchor_frame: int = 0 # TO IMPLEMENT input_shape: Tuple[int, int] = (256, 256) # input shape output_format: Literal['mp4', 'gif'] = 'mp4' # output video format crf: int = 15 # crf for output video output_fps: int = 25 # default output fps mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR)) lip_array: ndarray = field(default_factory=load_lip_array) size_gif: int = 256 # default gif size, TO IMPLEMENT ================================================ FILE: src/config/models.yaml 
================================================ model_params: appearance_feature_extractor_params: # the F in the paper image_channel: 3 block_expansion: 64 num_down_blocks: 2 max_features: 512 reshape_channel: 32 reshape_depth: 16 num_resblocks: 6 motion_extractor_params: # the M in the paper num_kp: 21 backbone: convnextv2_tiny warping_module_params: # the W in the paper num_kp: 21 block_expansion: 64 max_features: 512 num_down_blocks: 2 reshape_channel: 32 estimate_occlusion_map: True dense_motion_params: block_expansion: 32 max_features: 1024 num_blocks: 5 reshape_depth: 16 compress: 4 spade_generator_params: # the G in the paper upscale: 2 # represents upsample factor 256x256 -> 512x512 block_expansion: 64 max_features: 512 num_down_blocks: 2 stitching_retargeting_module_params: # the S in the paper stitching: input_size: 126 # (21*3)*2 hidden_sizes: [128, 128, 64] output_size: 65 # (21*3)+2(tx,ty) lip: input_size: 65 # (21*3)+2 hidden_sizes: [128, 128, 64] output_size: 63 # (21*3) eye: input_size: 66 # (21*3)+3 hidden_sizes: [256, 256, 128, 128, 64] output_size: 63 # (21*3) ================================================ FILE: src/gradio_pipeline.py ================================================ # coding: utf-8 """ Pipeline for gradio """ import os.path as osp import os import cv2 from rich.progress import track import gradio as gr import numpy as np import torch from .config.argument_config import ArgumentConfig from .live_portrait_pipeline import LivePortraitPipeline from .live_portrait_pipeline_animal import LivePortraitPipelineAnimal from .utils.io import load_img_online, load_video, resize_to_limit from .utils.filter import smooth from .utils.rprint import rlog as log from .utils.crop import prepare_paste_back, paste_back from .utils.camera import get_rotation_matrix from .utils.video import get_fps, has_audio_stream, concat_frames, images2video, add_audio_to_video from .utils.helper import is_square_video, mkdir, dct2device, basename from 
.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio def update_args(args, user_args): """update the args according to user inputs """ for k, v in user_args.items(): if hasattr(args, k): setattr(args, k, v) return args class GradioPipeline(LivePortraitPipeline): """gradio for human """ def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig): super().__init__(inference_cfg, crop_cfg) # self.live_portrait_wrapper = self.live_portrait_wrapper self.args = args @torch.no_grad() def update_delta_new_eyeball_direction(self, eyeball_direction_x, eyeball_direction_y, delta_new, **kwargs): if eyeball_direction_x > 0: delta_new[0, 11, 0] += eyeball_direction_x * 0.0007 delta_new[0, 15, 0] += eyeball_direction_x * 0.001 else: delta_new[0, 11, 0] += eyeball_direction_x * 0.001 delta_new[0, 15, 0] += eyeball_direction_x * 0.0007 delta_new[0, 11, 1] += eyeball_direction_y * -0.001 delta_new[0, 15, 1] += eyeball_direction_y * -0.001 blink = -eyeball_direction_y / 2. delta_new[0, 11, 1] += blink * -0.001 delta_new[0, 13, 1] += blink * 0.0003 delta_new[0, 15, 1] += blink * -0.001 delta_new[0, 16, 1] += blink * 0.0003 return delta_new @torch.no_grad() def update_delta_new_smile(self, smile, delta_new, **kwargs): delta_new[0, 20, 1] += smile * -0.01 delta_new[0, 14, 1] += smile * -0.02 delta_new[0, 17, 1] += smile * 0.0065 delta_new[0, 17, 2] += smile * 0.003 delta_new[0, 13, 1] += smile * -0.00275 delta_new[0, 16, 1] += smile * -0.00275 delta_new[0, 3, 1] += smile * -0.0035 delta_new[0, 7, 1] += smile * -0.0035 return delta_new @torch.no_grad() def update_delta_new_wink(self, wink, delta_new, **kwargs): delta_new[0, 11, 1] += wink * 0.001 delta_new[0, 13, 1] += wink * -0.0003 delta_new[0, 17, 0] += wink * 0.0003 delta_new[0, 17, 1] += wink * 0.0003 delta_new[0, 3, 1] += wink * -0.0003 return delta_new @torch.no_grad() def update_delta_new_eyebrow(self, eyebrow, delta_new, **kwargs): if eyebrow > 0: delta_new[0, 1, 1] += eyebrow * 0.001 delta_new[0, 2, 
1] += eyebrow * -0.001 else: delta_new[0, 1, 0] += eyebrow * -0.001 delta_new[0, 2, 0] += eyebrow * 0.001 delta_new[0, 1, 1] += eyebrow * 0.0003 delta_new[0, 2, 1] += eyebrow * -0.0003 return delta_new @torch.no_grad() def update_delta_new_lip_variation_zero(self, lip_variation_zero, delta_new, **kwargs): delta_new[0, 19, 0] += lip_variation_zero return delta_new @torch.no_grad() def update_delta_new_lip_variation_one(self, lip_variation_one, delta_new, **kwargs): delta_new[0, 14, 1] += lip_variation_one * 0.001 delta_new[0, 3, 1] += lip_variation_one * -0.0005 delta_new[0, 7, 1] += lip_variation_one * -0.0005 delta_new[0, 17, 2] += lip_variation_one * -0.0005 return delta_new @torch.no_grad() def update_delta_new_lip_variation_two(self, lip_variation_two, delta_new, **kwargs): delta_new[0, 20, 2] += lip_variation_two * -0.001 delta_new[0, 20, 1] += lip_variation_two * -0.001 delta_new[0, 14, 1] += lip_variation_two * -0.001 return delta_new @torch.no_grad() def update_delta_new_lip_variation_three(self, lip_variation_three, delta_new, **kwargs): delta_new[0, 19, 1] += lip_variation_three * 0.001 delta_new[0, 19, 2] += lip_variation_three * 0.0001 delta_new[0, 17, 1] += lip_variation_three * -0.0001 return delta_new @torch.no_grad() def update_delta_new_mov_x(self, mov_x, delta_new, **kwargs): delta_new[0, 5, 0] += mov_x return delta_new @torch.no_grad() def update_delta_new_mov_y(self, mov_y, delta_new, **kwargs): delta_new[0, 5, 1] += mov_y return delta_new @torch.no_grad() def execute_video( self, input_source_image_path=None, input_source_video_path=None, input_driving_video_path=None, input_driving_image_path=None, input_driving_video_pickle_path=None, flag_normalize_lip=False, flag_relative_input=True, flag_do_crop_input=True, flag_remap_input=True, flag_stitching_input=True, animation_region="all", driving_option_input="pose-friendly", driving_multiplier=1.0, flag_crop_driving_video_input=True, # flag_video_editing_head_rotation=False, scale=2.3, 
vx_ratio=0.0, vy_ratio=-0.125, scale_crop_driving_video=2.2, vx_ratio_crop_driving_video=0.0, vy_ratio_crop_driving_video=-0.1, driving_smooth_observation_variance=3e-7, tab_selection=None, v_tab_selection=None ): """ for video-driven portrait animation or video editing """ if tab_selection == 'Image': input_source_path = input_source_image_path elif tab_selection == 'Video': input_source_path = input_source_video_path else: input_source_path = input_source_image_path if v_tab_selection == 'Video': input_driving_path = input_driving_video_path elif v_tab_selection == 'Image': input_driving_path = input_driving_image_path elif v_tab_selection == 'Pickle': input_driving_path = input_driving_video_pickle_path else: input_driving_path = input_driving_video_path if input_source_path is not None and input_driving_path is not None: if osp.exists(input_driving_path) and v_tab_selection == 'Video' and not flag_crop_driving_video_input and is_square_video(input_driving_path) is False: flag_crop_driving_video_input = True log("The driving video is not square, it will be cropped to square automatically.") gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2) args_user = { 'source': input_source_path, 'driving': input_driving_path, 'flag_normalize_lip' : flag_normalize_lip, 'flag_relative_motion': flag_relative_input, 'flag_do_crop': flag_do_crop_input, 'flag_pasteback': flag_remap_input, 'flag_stitching': flag_stitching_input, 'animation_region': animation_region, 'driving_option': driving_option_input, 'driving_multiplier': driving_multiplier, 'flag_crop_driving_video': flag_crop_driving_video_input, 'scale': scale, 'vx_ratio': vx_ratio, 'vy_ratio': vy_ratio, 'scale_crop_driving_video': scale_crop_driving_video, 'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video, 'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video, 'driving_smooth_observation_variance': driving_smooth_observation_variance, } # update config from 
user input self.args = update_args(self.args, args_user) self.live_portrait_wrapper.update_config(self.args.__dict__) self.cropper.update_config(self.args.__dict__) output_path, output_path_concat = self.execute(self.args) gr.Info("Run successfully!", duration=2) if output_path.endswith(".jpg"): return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True) else: return output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) else: raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5) @torch.no_grad() def execute_image_retargeting( self, input_eye_ratio: float, input_lip_ratio: float, input_head_pitch_variation: float, input_head_yaw_variation: float, input_head_roll_variation: float, mov_x: float, mov_y: float, mov_z: float, lip_variation_zero: float, lip_variation_one: float, lip_variation_two: float, lip_variation_three: float, smile: float, wink: float, eyebrow: float, eyeball_direction_x: float, eyeball_direction_y: float, input_image, retargeting_source_scale: float, flag_stitching_retargeting_input=True, flag_do_crop_input_retargeting_image=True): """ for single image retargeting """ if input_head_pitch_variation is None or input_head_yaw_variation is None or input_head_roll_variation is None: raise gr.Error("Invalid relative pose input 💥!", duration=5) # disposable feature f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \ self.prepare_retargeting_image( input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=flag_do_crop_input_retargeting_image) if input_eye_ratio is None or input_lip_ratio is None: raise gr.Error("Invalid ratio 
input 💥!", duration=5) else: device = self.live_portrait_wrapper.device # inference_cfg = self.live_portrait_wrapper.inference_cfg x_s_user = x_s_user.to(device) f_s_user = f_s_user.to(device) R_s_user = R_s_user.to(device) R_d_user = R_d_user.to(device) mov_x = torch.tensor(mov_x).to(device) mov_y = torch.tensor(mov_y).to(device) mov_z = torch.tensor(mov_z).to(device) eyeball_direction_x = torch.tensor(eyeball_direction_x).to(device) eyeball_direction_y = torch.tensor(eyeball_direction_y).to(device) smile = torch.tensor(smile).to(device) wink = torch.tensor(wink).to(device) eyebrow = torch.tensor(eyebrow).to(device) lip_variation_zero = torch.tensor(lip_variation_zero).to(device) lip_variation_one = torch.tensor(lip_variation_one).to(device) lip_variation_two = torch.tensor(lip_variation_two).to(device) lip_variation_three = torch.tensor(lip_variation_three).to(device) x_c_s = x_s_info['kp'].to(device) delta_new = x_s_info['exp'].to(device) scale_new = x_s_info['scale'].to(device) t_new = x_s_info['t'].to(device) R_d_new = (R_d_user @ R_s_user.permute(0, 2, 1)) @ R_s_user if eyeball_direction_x != 0 or eyeball_direction_y != 0: delta_new = self.update_delta_new_eyeball_direction(eyeball_direction_x, eyeball_direction_y, delta_new) if smile != 0: delta_new = self.update_delta_new_smile(smile, delta_new) if wink != 0: delta_new = self.update_delta_new_wink(wink, delta_new) if eyebrow != 0: delta_new = self.update_delta_new_eyebrow(eyebrow, delta_new) if lip_variation_zero != 0: delta_new = self.update_delta_new_lip_variation_zero(lip_variation_zero, delta_new) if lip_variation_one != 0: delta_new = self.update_delta_new_lip_variation_one(lip_variation_one, delta_new) if lip_variation_two != 0: delta_new = self.update_delta_new_lip_variation_two(lip_variation_two, delta_new) if lip_variation_three != 0: delta_new = self.update_delta_new_lip_variation_three(lip_variation_three, delta_new) if mov_x != 0: delta_new = self.update_delta_new_mov_x(-mov_x, delta_new) if 
mov_y !=0 : delta_new = self.update_delta_new_mov_y(mov_y, delta_new) x_d_new = mov_z * scale_new * (x_c_s @ R_d_new + delta_new) + t_new eyes_delta, lip_delta = None, None if input_eye_ratio != self.source_eye_ratio: combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user) eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor) if input_lip_ratio != self.source_lip_ratio: combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user) lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor) print(lip_delta) x_d_new = x_d_new + \ (eyes_delta if eyes_delta is not None else 0) + \ (lip_delta if lip_delta is not None else 0) if flag_stitching_retargeting_input: x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new) out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new) out = self.live_portrait_wrapper.parse_output(out['out'])[0] if flag_do_crop_input_retargeting_image: out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori) else: out_to_ori_blend = out return out, out_to_ori_blend @torch.no_grad() def prepare_retargeting_image( self, input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=True): """ for single image retargeting """ if input_image is not None: # gr.Info("Upload successfully!", duration=2) args_user = {'scale': retargeting_source_scale} self.args = update_args(self.args, args_user) self.cropper.update_config(self.args.__dict__) inference_cfg = self.live_portrait_wrapper.inference_cfg ######## process source portrait ######## img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=2) if flag_do_crop: crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg) I_s = 
self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256']) source_lmk_user = crop_info['lmk_crop'] crop_M_c2o = crop_info['M_c2o'] mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) else: I_s = self.live_portrait_wrapper.prepare_source(img_rgb) source_lmk_user = self.cropper.calc_lmk_from_cropped_image(img_rgb) crop_M_c2o = None mask_ori = None x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) x_d_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation x_d_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation x_d_info_user_roll = x_s_info['roll'] + input_head_roll_variation R_s_user = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) R_d_user = get_rotation_matrix(x_d_info_user_pitch, x_d_info_user_yaw, x_d_info_user_roll) ############################################ f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s) x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info) return f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb else: raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5) @torch.no_grad() def init_retargeting_image(self, retargeting_source_scale: float, source_eye_ratio: float, source_lip_ratio:float, input_image = None): """ initialize the retargeting slider """ if input_image != None: args_user = {'scale': retargeting_source_scale} self.args = update_args(self.args, args_user) self.cropper.update_config(self.args.__dict__) # inference_cfg = self.live_portrait_wrapper.inference_cfg ######## process source portrait ######## img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16) log(f"Load source image from {input_image}.") crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg) if crop_info is None: raise gr.Error("Source portrait NO face detected", duration=2) source_eye_ratio = 
calc_eye_close_ratio(crop_info['lmk_crop'][None]) source_lip_ratio = calc_lip_close_ratio(crop_info['lmk_crop'][None]) self.source_eye_ratio = round(float(source_eye_ratio.mean()), 2) self.source_lip_ratio = round(float(source_lip_ratio[0][0]), 2) log("Calculating eyes-open and lip-open ratios successfully!") return self.source_eye_ratio, self.source_lip_ratio else: return source_eye_ratio, source_lip_ratio @torch.no_grad() def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, video_retargeting_silence=False, flag_do_crop_input_retargeting_video=True): """ retargeting the lip-open ratio of each source frame """ # disposable feature device = self.live_portrait_wrapper.device if not video_retargeting_silence: f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \ self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video) if input_lip_ratio is None: raise gr.Error("Invalid ratio input 💥!", duration=5) else: inference_cfg = self.live_portrait_wrapper.inference_cfg I_p_pstbk_lst = None if flag_do_crop_input_retargeting_video: I_p_pstbk_lst = [] I_p_lst = [] for i in track(range(n_frames), description='Retargeting video...', total=n_frames): x_s_user_i = x_s_user_lst[i].to(device) f_s_user_i = f_s_user_lst[i].to(device) lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i] x_d_i_new = x_s_user_i + lip_delta_retargeting x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new) out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new) I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0] I_p_lst.append(I_p_i) if flag_do_crop_input_retargeting_video: I_p_pstbk = 
paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i]) I_p_pstbk_lst.append(I_p_pstbk) else: inference_cfg = self.live_portrait_wrapper.inference_cfg f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames = \ self.prepare_video_lip_silence(input_video, device, flag_do_crop=flag_do_crop_input_retargeting_video) I_p_pstbk_lst = None if flag_do_crop_input_retargeting_video: I_p_pstbk_lst = [] I_p_lst = [] for i in track(range(n_frames), description='Silencing lip...', total=n_frames): x_s_user_i = x_s_user_lst[i].to(device) f_s_user_i = f_s_user_lst[i].to(device) x_d_i_new = x_d_i_new_lst[i] x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new) out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new) I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0] I_p_lst.append(I_p_i) if flag_do_crop_input_retargeting_video: I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i]) I_p_pstbk_lst.append(I_p_pstbk) mkdir(self.args.output_dir) flag_source_has_audio = has_audio_stream(input_video) ######### build the final concatenation result ######### # source frame | generation frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst) wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4') images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps) if flag_source_has_audio: # final result with concatenation wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4') add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio) os.replace(wfp_concat_with_audio, wfp_concat) log(f"Replace {wfp_concat_with_audio} with {wfp_concat}") # save the animated result wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4') if 
I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps) else: images2video(I_p_lst, wfp=wfp, fps=source_fps) ######### build the final result ######### if flag_source_has_audio: wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4') add_audio_to_video(wfp, input_video, wfp_with_audio) os.replace(wfp_with_audio, wfp) log(f"Replace {wfp_with_audio} with {wfp}") gr.Info("Run successfully!", duration=2) return wfp_concat, wfp @torch.no_grad() def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True): """ for video retargeting """ if input_video is not None: # gr.Info("Upload successfully!", duration=2) args_user = {'scale': retargeting_source_scale} self.args = update_args(self.args, args_user) self.cropper.update_config(self.args.__dict__) inference_cfg = self.live_portrait_wrapper.inference_cfg ######## process source video ######## source_rgb_lst = load_video(input_video) source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst] source_fps = int(get_fps(input_video)) n_frames = len(source_rgb_lst) log(f"Load source video from {input_video}. 
FPS is {source_fps}") if flag_do_crop: ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg) log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.') if len(ret_s["frame_crop_lst"]) != n_frames: n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"])) img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst'] mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst] else: source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst) img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256 source_M_c2o_lst, mask_ori_lst = None, None c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst) # save the motion template I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst) source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps) c_d_lip_retargeting = [input_lip_ratio] f_s_user_lst, x_s_user_lst, lip_delta_retargeting_lst = [], [], [] for i in track(range(n_frames), description='Preparing retargeting video...', total=n_frames): x_s_info = source_template_dct['motion'][i] x_s_info = dct2device(x_s_info, device) x_s_user = x_s_info['x_s'] source_lmk = source_lmk_crop_lst[i] img_crop_256x256 = img_crop_256x256_lst[i] I_s = I_s_lst[i] f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s) combined_lip_ratio_tensor_retargeting = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_retargeting, source_lmk) lip_delta_retargeting = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor_retargeting) f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); 
lip_delta_retargeting_lst.append(lip_delta_retargeting.cpu().numpy().astype(np.float32)) lip_delta_retargeting_lst_smooth = smooth(lip_delta_retargeting_lst, lip_delta_retargeting_lst[0].shape, device, driving_smooth_observation_variance_retargeting) return f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames else: # when press the clear button, go here raise gr.Error("Please upload a source video as the retargeting input 🤗🤗🤗", duration=5) @torch.no_grad() def prepare_video_lip_silence(self, input_video, device, flag_do_crop=True): """ for keeping lips in the source video silent """ if input_video is not None: inference_cfg = self.live_portrait_wrapper.inference_cfg ######## process source video ######## source_rgb_lst = load_video(input_video) source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst] source_fps = int(get_fps(input_video)) n_frames = len(source_rgb_lst) log(f"Load source video from {input_video}. 
FPS is {source_fps}") if flag_do_crop: ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg) log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.') if len(ret_s["frame_crop_lst"]) != n_frames: n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"])) img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst'] mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst] else: source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst) img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256 source_M_c2o_lst, mask_ori_lst = None, None c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst) # save the motion template I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst) source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps) f_s_user_lst, x_s_user_lst, x_d_i_new_lst = [], [], [] for i in track(range(n_frames), description='Preparing silencing lip...', total=n_frames): x_s_info = source_template_dct['motion'][i] x_s_info = dct2device(x_s_info, device) scale_s = x_s_info['scale'] x_s_user = x_s_info['x_s'] x_c_s = x_s_info['kp'] R_s = x_s_info['R'] t_s = x_s_info['t'] delta_new = torch.zeros_like(x_s_info['exp']) + torch.from_numpy(inference_cfg.lip_array).to(dtype=torch.float32, device=device) for eyes_idx in [11, 13, 15, 16, 18]: delta_new[:, eyes_idx, :] = x_s_info['exp'][:, eyes_idx, :] source_lmk = source_lmk_crop_lst[i] img_crop_256x256 = img_crop_256x256_lst[i] I_s = I_s_lst[i] f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s) x_d_i_new = scale_s * (x_c_s @ R_s + delta_new) + t_s f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); 
x_d_i_new_lst.append(x_d_i_new) return f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames else: # when press the clear button, go here raise gr.Error("Please upload a source video as the input 🤗🤗🤗", duration=5) class GradioPipelineAnimal(LivePortraitPipelineAnimal): """gradio for animal """ def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig): inference_cfg.flag_crop_driving_video = True # ensure the face_analysis_wrapper is enabled super().__init__(inference_cfg, crop_cfg) # self.live_portrait_wrapper_animal = self.live_portrait_wrapper_animal self.args = args @torch.no_grad() def execute_video( self, input_source_image_path=None, input_driving_video_path=None, input_driving_video_pickle_path=None, flag_do_crop_input=False, flag_remap_input=False, driving_multiplier=1.0, flag_stitching=False, flag_crop_driving_video_input=False, scale=2.3, vx_ratio=0.0, vy_ratio=-0.125, scale_crop_driving_video=2.2, vx_ratio_crop_driving_video=0.0, vy_ratio_crop_driving_video=-0.1, tab_selection=None, ): """ for video-driven potrait animation """ input_source_path = input_source_image_path if tab_selection == 'Video': input_driving_path = input_driving_video_path elif tab_selection == 'Pickle': input_driving_path = input_driving_video_pickle_path else: input_driving_path = input_driving_video_pickle_path if input_source_path is not None and input_driving_path is not None: if osp.exists(input_driving_path) and tab_selection == 'Video' and is_square_video(input_driving_path) is False: flag_crop_driving_video_input = True log("The driving video is not square, it will be cropped to square automatically.") gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2) args_user = { 'source': input_source_path, 'driving': input_driving_path, 'flag_do_crop': flag_do_crop_input, 'flag_pasteback': flag_remap_input, 'driving_multiplier': driving_multiplier, 
'flag_stitching': flag_stitching, 'flag_crop_driving_video': flag_crop_driving_video_input, 'scale': scale, 'vx_ratio': vx_ratio, 'vy_ratio': vy_ratio, 'scale_crop_driving_video': scale_crop_driving_video, 'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video, 'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video, } # update config from user input self.args = update_args(self.args, args_user) self.live_portrait_wrapper_animal.update_config(self.args.__dict__) self.cropper.update_config(self.args.__dict__) # video driven animation video_path, video_path_concat, video_gif_path = self.execute(self.args) gr.Info("Run successfully!", duration=2) return video_path, video_path_concat, video_gif_path else: raise gr.Error("Please upload the source animal image, and driving video 🤗🤗🤗", duration=5) ================================================ FILE: src/live_portrait_pipeline.py ================================================ # coding: utf-8 """ Pipeline of LivePortrait (Human) """ import torch torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) import numpy as np import os import os.path as osp from rich.progress import track from .config.argument_config import ArgumentConfig from .config.inference_config import InferenceConfig from .config.crop_config import CropConfig from .utils.cropper import Cropper from .utils.camera import get_rotation_matrix from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream from .utils.crop import prepare_paste_back, paste_back from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image, is_square_video, calc_motion_multiplier from .utils.filter import smooth from .utils.rprint import rlog as log # from .utils.viz import viz_lmk from .live_portrait_wrapper import 
LivePortraitWrapper def make_abs_path(fn): return osp.join(osp.dirname(osp.realpath(__file__)), fn) class LivePortraitPipeline(object): def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig): self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg) self.cropper: Cropper = Cropper(crop_cfg=crop_cfg) def make_motion_template(self, I_lst, c_eyes_lst, c_lip_lst, **kwargs): n_frames = I_lst.shape[0] template_dct = { 'n_frames': n_frames, 'output_fps': kwargs.get('output_fps', 25), 'motion': [], 'c_eyes_lst': [], 'c_lip_lst': [], } for i in track(range(n_frames), description='Making motion templates...', total=n_frames): # collect s, R, δ and t for inference I_i = I_lst[i] x_i_info = self.live_portrait_wrapper.get_kp_info(I_i) x_s = self.live_portrait_wrapper.transform_keypoint(x_i_info) R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll']) item_dct = { 'scale': x_i_info['scale'].cpu().numpy().astype(np.float32), 'R': R_i.cpu().numpy().astype(np.float32), 'exp': x_i_info['exp'].cpu().numpy().astype(np.float32), 't': x_i_info['t'].cpu().numpy().astype(np.float32), 'kp': x_i_info['kp'].cpu().numpy().astype(np.float32), 'x_s': x_s.cpu().numpy().astype(np.float32), } template_dct['motion'].append(item_dct) c_eyes = c_eyes_lst[i].astype(np.float32) template_dct['c_eyes_lst'].append(c_eyes) c_lip = c_lip_lst[i].astype(np.float32) template_dct['c_lip_lst'].append(c_lip) return template_dct def execute(self, args: ArgumentConfig): # for convenience inf_cfg = self.live_portrait_wrapper.inference_cfg device = self.live_portrait_wrapper.device crop_cfg = self.cropper.crop_cfg ######## load source input ######## flag_is_source_video = False source_fps = None if is_image(args.source): flag_is_source_video = False img_rgb = load_image_rgb(args.source) img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division) log(f"Load source image from {args.source}") source_rgb_lst = 
[img_rgb] elif is_video(args.source): flag_is_source_video = True source_rgb_lst = load_video(args.source) source_rgb_lst = [resize_to_limit(img, inf_cfg.source_max_dim, inf_cfg.source_division) for img in source_rgb_lst] source_fps = int(get_fps(args.source)) log(f"Load source video from {args.source}, FPS is {source_fps}") else: # source input is an unknown format raise Exception(f"Unknown source format: {args.source}") ######## process driving info ######## flag_load_from_template = is_template(args.driving) driving_rgb_crop_256x256_lst = None wfp_template = None if flag_load_from_template: # NOTE: load from template, it is fast, but the cropping video is None log(f"Load from template: {args.driving}, NOT the video, so the cropping video and audio are both NULL.", style='bold green') driving_template_dct = load(args.driving) c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst'] driving_n_frames = driving_template_dct['n_frames'] flag_is_driving_video = True if driving_n_frames > 1 else False if flag_is_source_video and flag_is_driving_video: n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames elif flag_is_source_video and not flag_is_driving_video: n_frames = len(source_rgb_lst) else: n_frames = driving_n_frames # set output_fps output_fps = driving_template_dct.get('output_fps', inf_cfg.output_fps) log(f'The FPS of template: {output_fps}') if args.flag_crop_driving_video: log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.") elif osp.exists(args.driving): if is_video(args.driving): flag_is_driving_video = True # load from video file, AND make motion template output_fps = int(get_fps(args.driving)) log(f"Load 
driving video from: {args.driving}, FPS is {output_fps}") driving_rgb_lst = load_video(args.driving) elif is_image(args.driving): flag_is_driving_video = False driving_img_rgb = load_image_rgb(args.driving) output_fps = 25 log(f"Load driving image from {args.driving}") driving_rgb_lst = [driving_img_rgb] else: raise Exception(f"{args.driving} is not a supported type!") ######## make motion template ######## log("Start making driving motion template...") driving_n_frames = len(driving_rgb_lst) if flag_is_source_video and flag_is_driving_video: n_frames = min(len(source_rgb_lst), driving_n_frames) # minimum number as the number of the animated frames driving_rgb_lst = driving_rgb_lst[:n_frames] elif flag_is_source_video and not flag_is_driving_video: n_frames = len(source_rgb_lst) else: n_frames = driving_n_frames if inf_cfg.flag_crop_driving_video or (not is_square_video(args.driving)): ret_d = self.cropper.crop_driving_video(driving_rgb_lst) log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.') if len(ret_d["frame_crop_lst"]) is not n_frames and flag_is_driving_video: n_frames = min(n_frames, len(ret_d["frame_crop_lst"])) driving_rgb_crop_lst, driving_lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst'] driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst] else: driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst) driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256 ####################################### c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_ratio(driving_lmk_crop_lst) # save the motion template I_d_lst = self.live_portrait_wrapper.prepare_videos(driving_rgb_crop_256x256_lst) driving_template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps) wfp_template = remove_suffix(args.driving) + '.pkl' dump(wfp_template, driving_template_dct) 
log(f"Dump motion template to {wfp_template}") else: raise Exception(f"{args.driving} does not exist!") if not flag_is_driving_video: c_d_eyes_lst = c_d_eyes_lst*n_frames c_d_lip_lst = c_d_lip_lst*n_frames ######## prepare for pasteback ######## I_p_pstbk_lst = None if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: I_p_pstbk_lst = [] log("Prepared pasteback mask done.") I_p_lst = [] R_d_0, x_d_0_info = None, None flag_normalize_lip = inf_cfg.flag_normalize_lip # not overwrite flag_source_video_eye_retargeting = inf_cfg.flag_source_video_eye_retargeting # not overwrite lip_delta_before_animation, eye_delta_before_animation = None, None ######## process source info ######## if flag_is_source_video: log(f"Start making source motion template...") source_rgb_lst = source_rgb_lst[:n_frames] if inf_cfg.flag_do_crop: ret_s = self.cropper.crop_source_video(source_rgb_lst, crop_cfg) log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.') if len(ret_s["frame_crop_lst"]) is not n_frames: n_frames = min(n_frames, len(ret_s["frame_crop_lst"])) img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst'] else: source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst) img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256 c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst) # save the motion template I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst) source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps) key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d' # compatible with previous keys if inf_cfg.flag_relative_motion: if flag_is_driving_video: x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - 
driving_template_dct['motion'][0]['exp'] for i in range(n_frames)] x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance) else: x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + (driving_template_dct['motion'][0]['exp'] - inf_cfg.lip_array) for i in range(n_frames)] x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst] if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": if flag_is_driving_video: x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)] x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance) else: x_d_r_lst = [source_template_dct['motion'][i]['R'] for i in range(n_frames)] x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in x_d_r_lst] else: if flag_is_driving_video: x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)] x_d_exp_lst_smooth = smooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, device, inf_cfg.driving_smooth_observation_variance) else: x_d_exp_lst = [driving_template_dct['motion'][0]['exp']] x_d_exp_lst_smooth = [torch.tensor(x_d_exp[0], dtype=torch.float32, device=device) for x_d_exp in x_d_exp_lst]*n_frames if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": if flag_is_driving_video: x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)] x_d_r_lst_smooth = smooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, device, inf_cfg.driving_smooth_observation_variance) else: x_d_r_lst = [driving_template_dct['motion'][0][key_r]] x_d_r_lst_smooth = [torch.tensor(x_d_r[0], dtype=torch.float32, device=device) for x_d_r in 
x_d_r_lst]*n_frames else: # if the input is a source image, process it only once if inf_cfg.flag_do_crop: crop_info = self.cropper.crop_source_image(source_rgb_lst[0], crop_cfg) if crop_info is None: raise Exception("No face detected in the source image!") source_lmk = crop_info['lmk_crop'] img_crop_256x256 = crop_info['img_crop_256x256'] else: source_lmk = self.cropper.calc_lmk_from_cropped_image(source_rgb_lst[0]) img_crop_256x256 = cv2.resize(source_rgb_lst[0], (256, 256)) # force to resize to 256x256 I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) x_c_s = x_s_info['kp'] R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) f_s = self.live_portrait_wrapper.extract_feature_3d(I_s) x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info) # let lip-open scalar to be 0 at first if flag_normalize_lip and inf_cfg.flag_relative_motion and source_lmk is not None: c_d_lip_before_animation = [0.] 
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold: lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) ######## animate ######## if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video): log(f"The animated video consists of {n_frames} frames.") else: log(f"The output of image-driven portrait animation is an image.") for i in track(range(n_frames), description='🚀Animating...', total=n_frames): if flag_is_source_video: # source video x_s_info = source_template_dct['motion'][i] x_s_info = dct2device(x_s_info, device) source_lmk = source_lmk_crop_lst[i] img_crop_256x256 = img_crop_256x256_lst[i] I_s = I_s_lst[i] f_s = self.live_portrait_wrapper.extract_feature_3d(I_s) x_c_s = x_s_info['kp'] R_s = x_s_info['R'] x_s =x_s_info['x_s'] # let lip-open scalar to be 0 at first if the input is a video if flag_normalize_lip and inf_cfg.flag_relative_motion and source_lmk is not None: c_d_lip_before_animation = [0.] 
combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) if combined_lip_ratio_tensor_before_animation[0][0] >= inf_cfg.lip_normalize_threshold: lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) else: lip_delta_before_animation = None # let eye-open scalar to be the same as the first frame if the latter is eye-open state if flag_source_video_eye_retargeting and source_lmk is not None: if i == 0: combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0] c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]] if c_d_eye_before_animation_frame_zero[0][0] < inf_cfg.source_video_eye_retargeting_threshold: c_d_eye_before_animation_frame_zero = [[0.39]] combined_eye_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, source_lmk) eye_delta_before_animation = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation) if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: # prepare for paste back mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0])) if flag_is_source_video and not flag_is_driving_video: x_d_i_info = driving_template_dct['motion'][0] else: x_d_i_info = driving_template_dct['motion'][i] x_d_i_info = dct2device(x_d_i_info, device) R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys if i == 0: # cache the first frame R_d_0 = R_d_i x_d_0_info = x_d_i_info.copy() delta_new = x_s_info['exp'].clone() if inf_cfg.flag_relative_motion: if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": R_new = x_d_r_lst_smooth[i] if flag_is_source_video else (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s else: R_new = R_s if 
inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp": if flag_is_source_video: for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]: delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2] delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] else: if flag_is_driving_video: delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) else: delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device)) elif inf_cfg.animation_region == "lip": for lip_idx in [6, 12, 14, 17, 19, 20]: if flag_is_source_video: delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] elif flag_is_driving_video: delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :] else: delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - torch.from_numpy(inf_cfg.lip_array).to(dtype=torch.float32, device=device)))[:, lip_idx, :] elif inf_cfg.animation_region == "eyes": for eyes_idx in [11, 13, 15, 16, 18]: if flag_is_source_video: delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] elif flag_is_driving_video: delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :] else: delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - 0))[:, eyes_idx, :] if inf_cfg.animation_region == "all": scale_new = x_s_info['scale'] if flag_is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) else: scale_new = x_s_info['scale'] if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": t_new = x_s_info['t'] if flag_is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) else: t_new = x_s_info['t'] else: if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": R_new = 
x_d_r_lst_smooth[i] if flag_is_source_video else R_d_i else: R_new = R_s if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "exp": for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]: delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :] if flag_is_source_video else x_d_i_info['exp'][:, idx, :] delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1] if flag_is_source_video else x_d_i_info['exp'][:, 3:5, 1] delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2] if flag_is_source_video else x_d_i_info['exp'][:, 5, 2] delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2] if flag_is_source_video else x_d_i_info['exp'][:, 8, 2] delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:] if flag_is_source_video else x_d_i_info['exp'][:, 9, 1:] elif inf_cfg.animation_region == "lip": for lip_idx in [6, 12, 14, 17, 19, 20]: delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, lip_idx, :] elif inf_cfg.animation_region == "eyes": for eyes_idx in [11, 13, 15, 16, 18]: delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :] if flag_is_source_video else x_d_i_info['exp'][:, eyes_idx, :] scale_new = x_s_info['scale'] if inf_cfg.animation_region == "all" or inf_cfg.animation_region == "pose": t_new = x_d_i_info['t'] else: t_new = x_s_info['t'] t_new[..., 2].fill_(0) # zero tz x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new if inf_cfg.flag_relative_motion and inf_cfg.driving_option == "expression-friendly" and not flag_is_source_video and flag_is_driving_video: if i == 0: x_d_0_new = x_d_i_new motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new) # motion_multiplier *= inf_cfg.driving_multiplier x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier x_d_i_new = x_d_diff + x_s # Algorithm 1: if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting: # without stitching or retargeting if flag_normalize_lip and lip_delta_before_animation is not None: x_d_i_new += 
lip_delta_before_animation if flag_source_video_eye_retargeting and eye_delta_before_animation is not None: x_d_i_new += eye_delta_before_animation else: pass elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting: # with stitching and without retargeting if flag_normalize_lip and lip_delta_before_animation is not None: x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation else: x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) if flag_source_video_eye_retargeting and eye_delta_before_animation is not None: x_d_i_new += eye_delta_before_animation else: eyes_delta, lip_delta = None, None if inf_cfg.flag_eye_retargeting and source_lmk is not None: c_d_eyes_i = c_d_eyes_lst[i] combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk) # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor) if inf_cfg.flag_lip_retargeting and source_lmk is not None: c_d_lip_i = c_d_lip_lst[i] combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk) # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor) if inf_cfg.flag_relative_motion: # use x_s x_d_i_new = x_s + \ (eyes_delta if eyes_delta is not None else 0) + \ (lip_delta if lip_delta is not None else 0) else: # use x_d,i x_d_i_new = x_d_i_new + \ (eyes_delta if eyes_delta is not None else 0) + \ (lip_delta if lip_delta is not None else 0) if inf_cfg.flag_stitching: x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) x_d_i_new = x_s + (x_d_i_new - x_s) * inf_cfg.driving_multiplier out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new) I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0] I_p_lst.append(I_p_i) if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and 
inf_cfg.flag_stitching: # TODO: the paste back procedure is slow, considering optimize it using multi-threading or GPU if flag_is_source_video: I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_float) else: I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], source_rgb_lst[0], mask_ori_float) I_p_pstbk_lst.append(I_p_pstbk) mkdir(args.output_dir) wfp_concat = None ######### build the final concatenation result ######### # driving frame | source frame | generation if flag_is_source_video and flag_is_driving_video: frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst) elif flag_is_source_video and not flag_is_driving_video: if flag_load_from_template: frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256_lst, I_p_lst) else: frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst*n_frames, img_crop_256x256_lst, I_p_lst) else: frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst) if flag_is_driving_video or (flag_is_source_video and not flag_is_driving_video): flag_source_has_audio = flag_is_source_video and has_audio_stream(args.source) flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving) wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4') # NOTE: update output fps output_fps = source_fps if flag_is_source_video else output_fps images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps) if flag_source_has_audio or flag_driving_has_audio: # final result with concatenation wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4') audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source log(f"Audio is selected from {audio_from_which_video}, concat mode") 
add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio) os.replace(wfp_concat_with_audio, wfp_concat) log(f"Replace {wfp_concat_with_audio} with {wfp_concat}") # save the animated result wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4') if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps) else: images2video(I_p_lst, wfp=wfp, fps=output_fps) ######### build the final result ######### if flag_source_has_audio or flag_driving_has_audio: wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4') audio_from_which_video = args.driving if ((flag_driving_has_audio and args.audio_priority == 'driving') or (not flag_source_has_audio)) else args.source log(f"Audio is selected from {audio_from_which_video}") add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio) os.replace(wfp_with_audio, wfp) log(f"Replace {wfp_with_audio} with {wfp}") # final log if wfp_template not in (None, ''): log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green') log(f'Animated video: {wfp}') log(f'Animated video with concat: {wfp_concat}') else: wfp_concat = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.jpg') cv2.imwrite(wfp_concat, frames_concatenated[0][..., ::-1]) wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.jpg') if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: cv2.imwrite(wfp, I_p_pstbk_lst[0][..., ::-1]) else: cv2.imwrite(wfp, frames_concatenated[0][..., ::-1]) # final log log(f'Animated image: {wfp}') log(f'Animated image with concat: {wfp_concat}') return wfp, wfp_concat ================================================ FILE: src/live_portrait_pipeline_animal.py 
================================================ # coding: utf-8 """ Pipeline of LivePortrait (Animal) """ import warnings warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.") warnings.filterwarnings("ignore", message="torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly.") warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None") import torch torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) import numpy as np import os import os.path as osp from rich.progress import track from .config.argument_config import ArgumentConfig from .config.inference_config import InferenceConfig from .config.crop_config import CropConfig from .utils.cropper import Cropper from .utils.camera import get_rotation_matrix from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream, video2gif from .utils.crop import _transform_img, prepare_paste_back, paste_back from .utils.io import load_image_rgb, load_video, resize_to_limit, dump, load from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix, is_image, calc_motion_multiplier from .utils.rprint import rlog as log # from .utils.viz import viz_lmk from .live_portrait_wrapper import LivePortraitWrapperAnimal def make_abs_path(fn): return osp.join(osp.dirname(osp.realpath(__file__)), fn) class LivePortraitPipelineAnimal(object): def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig): self.live_portrait_wrapper_animal: LivePortraitWrapperAnimal = LivePortraitWrapperAnimal(inference_cfg=inference_cfg) self.cropper: Cropper = Cropper(crop_cfg=crop_cfg, image_type='animal_face', flag_use_half_precision=inference_cfg.flag_use_half_precision) def make_motion_template(self, 
I_lst, **kwargs): n_frames = I_lst.shape[0] template_dct = { 'n_frames': n_frames, 'output_fps': kwargs.get('output_fps', 25), 'motion': [], } for i in track(range(n_frames), description='Making driving motion templates...', total=n_frames): # collect s, R, δ and t for inference I_i = I_lst[i] x_i_info = self.live_portrait_wrapper_animal.get_kp_info(I_i) R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll']) item_dct = { 'scale': x_i_info['scale'].cpu().numpy().astype(np.float32), 'R': R_i.cpu().numpy().astype(np.float32), 'exp': x_i_info['exp'].cpu().numpy().astype(np.float32), 't': x_i_info['t'].cpu().numpy().astype(np.float32), } template_dct['motion'].append(item_dct) return template_dct def execute(self, args: ArgumentConfig): # for convenience inf_cfg = self.live_portrait_wrapper_animal.inference_cfg device = self.live_portrait_wrapper_animal.device crop_cfg = self.cropper.crop_cfg ######## load source input ######## if is_image(args.source): img_rgb = load_image_rgb(args.source) img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division) log(f"Load source image from {args.source}") else: # source input is an unknown format raise Exception(f"Unknown source format: {args.source}") ######## process driving info ######## flag_load_from_template = is_template(args.driving) driving_rgb_crop_256x256_lst = None wfp_template = None if flag_load_from_template: # NOTE: load from template, it is fast, but the cropping video is None log(f"Load from template: {args.driving}, NOT the video, so the cropping video and audio are both NULL.", style='bold green') driving_template_dct = load(args.driving) n_frames = driving_template_dct['n_frames'] # set output_fps output_fps = driving_template_dct.get('output_fps', inf_cfg.output_fps) log(f'The FPS of template: {output_fps}') if args.flag_crop_driving_video: log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.") elif 
osp.exists(args.driving) and is_video(args.driving): # load from video file, AND make motion template output_fps = int(get_fps(args.driving)) log(f"Load driving video from: {args.driving}, FPS is {output_fps}") driving_rgb_lst = load_video(args.driving) n_frames = len(driving_rgb_lst) ######## make motion template ######## log("Start making driving motion template...") if inf_cfg.flag_crop_driving_video: ret_d = self.cropper.crop_driving_video(driving_rgb_lst) log(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.') if len(ret_d["frame_crop_lst"]) is not n_frames: n_frames = min(n_frames, len(ret_d["frame_crop_lst"])) driving_rgb_crop_lst = ret_d['frame_crop_lst'] driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst] else: driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256 ####################################### # save the motion template I_d_lst = self.live_portrait_wrapper_animal.prepare_videos(driving_rgb_crop_256x256_lst) driving_template_dct = self.make_motion_template(I_d_lst, output_fps=output_fps) wfp_template = remove_suffix(args.driving) + '.pkl' dump(wfp_template, driving_template_dct) log(f"Dump motion template to {wfp_template}") else: raise Exception(f"{args.driving} not exists or unsupported driving info types!") ######## prepare for pasteback ######## I_p_pstbk_lst = None if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: I_p_pstbk_lst = [] log("Prepared pasteback mask done.") ######## process source info ######## if inf_cfg.flag_do_crop: crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg) if crop_info is None: raise Exception("No animal face detected in the source image!") img_crop_256x256 = crop_info['img_crop_256x256'] else: img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256 I_s = self.live_portrait_wrapper_animal.prepare_source(img_crop_256x256) x_s_info 
= self.live_portrait_wrapper_animal.get_kp_info(I_s) x_c_s = x_s_info['kp'] R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) f_s = self.live_portrait_wrapper_animal.extract_feature_3d(I_s) x_s = self.live_portrait_wrapper_animal.transform_keypoint(x_s_info) if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) ######## animate ######## I_p_lst = [] for i in track(range(n_frames), description='🚀Animating...', total=n_frames): x_d_i_info = driving_template_dct['motion'][i] x_d_i_info = dct2device(x_d_i_info, device) R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys delta_new = x_d_i_info['exp'] t_new = x_d_i_info['t'] t_new[..., 2].fill_(0) # zero tz scale_new = x_s_info['scale'] x_d_i = scale_new * (x_c_s @ R_d_i + delta_new) + t_new if i == 0: x_d_0 = x_d_i motion_multiplier = calc_motion_multiplier(x_s, x_d_0) x_d_diff = (x_d_i - x_d_0) * motion_multiplier x_d_i = x_d_diff + x_s if not inf_cfg.flag_stitching: pass else: x_d_i = self.live_portrait_wrapper_animal.stitching(x_s, x_d_i) x_d_i = x_s + (x_d_i - x_s) * inf_cfg.driving_multiplier out = self.live_portrait_wrapper_animal.warp_decode(f_s, x_s, x_d_i) I_p_i = self.live_portrait_wrapper_animal.parse_output(out['out'])[0] I_p_lst.append(I_p_i) if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float) I_p_pstbk_lst.append(I_p_pstbk) mkdir(args.output_dir) wfp_concat = None flag_driving_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving) ######### build the final concatenation result ######### # driving frame | source image | generation frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst) wfp_concat = 
osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat.mp4') images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps) if flag_driving_has_audio: # final result with concatenation wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_concat_with_audio.mp4') audio_from_which_video = args.driving add_audio_to_video(wfp_concat, audio_from_which_video, wfp_concat_with_audio) os.replace(wfp_concat_with_audio, wfp_concat) log(f"Replace {wfp_concat_with_audio} with {wfp_concat}") # save the animated result wfp = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}.mp4') if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps) else: images2video(I_p_lst, wfp=wfp, fps=output_fps) ######### build the final result ######### if flag_driving_has_audio: wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source)}--{basename(args.driving)}_with_audio.mp4') audio_from_which_video = args.driving add_audio_to_video(wfp, audio_from_which_video, wfp_with_audio) os.replace(wfp_with_audio, wfp) log(f"Replace {wfp_with_audio} with {wfp}") # final log if wfp_template not in (None, ''): log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green') log(f'Animated video: {wfp}') log(f'Animated video with concat: {wfp_concat}') # build the gif wfp_gif = video2gif(wfp) log(f'Animated gif: {wfp_gif}') return wfp, wfp_concat, wfp_gif ================================================ FILE: src/live_portrait_wrapper.py ================================================ # coding: utf-8 """ Wrappers for LivePortrait core functions """ import contextlib import os.path as osp import numpy as np import cv2 import torch import yaml from .utils.timer import Timer from .utils.helper import load_model, concat_feat 
from .utils.camera import headpose_pred_to_degree, get_rotation_matrix from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio from .config.inference_config import InferenceConfig from .utils.rprint import rlog as log class LivePortraitWrapper(object): """ Wrapper for Human """ def __init__(self, inference_cfg: InferenceConfig): self.inference_cfg = inference_cfg self.device_id = inference_cfg.device_id self.compile = inference_cfg.flag_do_torch_compile if inference_cfg.flag_force_cpu: self.device = 'cpu' else: try: if torch.backends.mps.is_available(): self.device = 'mps' else: self.device = 'cuda:' + str(self.device_id) except: self.device = 'cuda:' + str(self.device_id) model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader) # init F self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, self.device, 'appearance_feature_extractor') log(f'Load appearance_feature_extractor from {osp.realpath(inference_cfg.checkpoint_F)} done.') # init M self.motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, self.device, 'motion_extractor') log(f'Load motion_extractor from {osp.realpath(inference_cfg.checkpoint_M)} done.') # init W self.warping_module = load_model(inference_cfg.checkpoint_W, model_config, self.device, 'warping_module') log(f'Load warping_module from {osp.realpath(inference_cfg.checkpoint_W)} done.') # init G self.spade_generator = load_model(inference_cfg.checkpoint_G, model_config, self.device, 'spade_generator') log(f'Load spade_generator from {osp.realpath(inference_cfg.checkpoint_G)} done.') # init S and R if inference_cfg.checkpoint_S is not None and osp.exists(inference_cfg.checkpoint_S): self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, self.device, 'stitching_retargeting_module') log(f'Load stitching_retargeting_module from {osp.realpath(inference_cfg.checkpoint_S)} done.') else: 
self.stitching_retargeting_module = None # Optimize for inference if self.compile: torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution self.warping_module = torch.compile(self.warping_module, mode='max-autotune') self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune') self.timer = Timer() def inference_ctx(self): if self.device == "mps": ctx = contextlib.nullcontext() else: ctx = torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision) return ctx def update_config(self, user_args): for k, v in user_args.items(): if hasattr(self.inference_cfg, k): setattr(self.inference_cfg, k, v) def prepare_source(self, img: np.ndarray) -> torch.Tensor: """ construct the input as standard img: HxWx3, uint8, 256x256 """ h, w = img.shape[:2] if h != self.inference_cfg.input_shape[0] or w != self.inference_cfg.input_shape[1]: x = cv2.resize(img, (self.inference_cfg.input_shape[0], self.inference_cfg.input_shape[1])) else: x = img.copy() if x.ndim == 3: x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1 elif x.ndim == 4: x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1 else: raise ValueError(f'img ndim should be 3 or 4: {x.ndim}') x = np.clip(x, 0, 1) # clip to 0~1 x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW x = x.to(self.device) return x def prepare_videos(self, imgs) -> torch.Tensor: """ construct the input as standard imgs: NxBxHxWx3, uint8 """ if isinstance(imgs, list): _imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1 elif isinstance(imgs, np.ndarray): _imgs = imgs else: raise ValueError(f'imgs type error: {type(imgs)}') y = _imgs.astype(np.float32) / 255. 
y = np.clip(y, 0, 1) # clip to 0~1 y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW y = y.to(self.device) return y def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor: """ get the appearance feature of the image by F x: Bx3xHxW, normalized to 0~1 """ with torch.no_grad(), self.inference_ctx(): feature_3d = self.appearance_feature_extractor(x) return feature_3d.float() def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict: """ get the implicit keypoint information x: Bx3xHxW, normalized to 0~1 flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp' """ with torch.no_grad(), self.inference_ctx(): kp_info = self.motion_extractor(x) if self.inference_cfg.flag_use_half_precision: # float the dict for k, v in kp_info.items(): if isinstance(v, torch.Tensor): kp_info[k] = v.float() flag_refine_info: bool = kwargs.get('flag_refine_info', True) if flag_refine_info: bs = kp_info['kp'].shape[0] kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] # Bx1 kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] # Bx1 kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] # Bx1 kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) # BxNx3 kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) # BxNx3 return kp_info def get_pose_dct(self, kp_info: dict) -> dict: pose_dct = dict( pitch=headpose_pred_to_degree(kp_info['pitch']).item(), yaw=headpose_pred_to_degree(kp_info['yaw']).item(), roll=headpose_pred_to_degree(kp_info['roll']).item(), ) return pose_dct def get_fs_and_kp_info(self, source_prepared, driving_first_frame): # get the canonical keypoints of source image by M source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True) source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll']) # get the canonical keypoints of first driving 
frame by M driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True) driving_first_frame_rotation = get_rotation_matrix( driving_first_frame_kp_info['pitch'], driving_first_frame_kp_info['yaw'], driving_first_frame_kp_info['roll'] ) # get feature volume by F source_feature_3d = self.extract_feature_3d(source_prepared) return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation def transform_keypoint(self, kp_info: dict): """ transform the implicit keypoints with the pose, shift, and expression deformation kp: BxNx3 """ kp = kp_info['kp'] # (bs, k, 3) pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll'] t, exp = kp_info['t'], kp_info['exp'] scale = kp_info['scale'] pitch = headpose_pred_to_degree(pitch) yaw = headpose_pred_to_degree(yaw) roll = headpose_pred_to_degree(roll) bs = kp.shape[0] if kp.ndim == 2: num_kp = kp.shape[1] // 3 # Bx(num_kpx3) else: num_kp = kp.shape[1] # Bxnum_kpx3 rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3) # Eqn.2: s * (R * x_c,s + exp) + t kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3) kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3) kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty return kp_transformed def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor: """ kp_source: BxNx3 eye_close_ratio: Bx3 Return: Bx(3*num_kp) """ feat_eye = concat_feat(kp_source, eye_close_ratio) with torch.no_grad(): delta = self.stitching_retargeting_module['eye'](feat_eye) return delta.reshape(-1, kp_source.shape[1], 3) def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor: """ kp_source: BxNx3 lip_close_ratio: Bx2 Return: Bx(3*num_kp) """ feat_lip = concat_feat(kp_source, lip_close_ratio) with torch.no_grad(): delta = self.stitching_retargeting_module['lip'](feat_lip) return 
delta.reshape(-1, kp_source.shape[1], 3) def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: """ kp_source: BxNx3 kp_driving: BxNx3 Return: Bx(3*num_kp+2) """ feat_stiching = concat_feat(kp_source, kp_driving) with torch.no_grad(): delta = self.stitching_retargeting_module['stitching'](feat_stiching) return delta def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: """ conduct the stitching kp_source: Bxnum_kpx3 kp_driving: Bxnum_kpx3 """ if self.stitching_retargeting_module is not None: bs, num_kp = kp_source.shape[:2] kp_driving_new = kp_driving.clone() delta = self.stitch(kp_source, kp_driving_new) delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3 delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2 kp_driving_new += delta_exp kp_driving_new[..., :2] += delta_tx_ty return kp_driving_new return kp_driving def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: """ get the image after the warping of the implicit keypoints feature_3d: Bx32x16x64x64, feature volume kp_source: BxNx3 kp_driving: BxNx3 """ # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i)) with torch.no_grad(), self.inference_ctx(): if self.compile: # Mark the beginning of a new CUDA Graph step torch.compiler.cudagraph_mark_step_begin() # get decoder input ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving) # decode ret_dct['out'] = self.spade_generator(feature=ret_dct['out']) # float the dict if self.inference_cfg.flag_use_half_precision: for k, v in ret_dct.items(): if isinstance(v, torch.Tensor): ret_dct[k] = v.float() return ret_dct def parse_output(self, out: torch.Tensor) -> np.ndarray: """ construct the output as standard return: 1xHxWx3, uint8 """ out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3 out = np.clip(out, 0, 1) # clip to 0~1 out = np.clip(out * 255, 0, 
255).astype(np.uint8) # 0~1 -> 0~255 return out def calc_ratio(self, lmk_lst): input_eye_ratio_lst = [] input_lip_ratio_lst = [] for lmk in lmk_lst: # for eyes retargeting input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None])) # for lip retargeting input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None])) return input_eye_ratio_lst, input_lip_ratio_lst def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk): c_s_eyes = calc_eye_close_ratio(source_lmk[None]) c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device) c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device) # [c_s,eyes, c_d,eyes,i] combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1) return combined_eye_ratio_tensor def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk): c_s_lip = calc_lip_close_ratio(source_lmk[None]) c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device) c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) # 1x1 # [c_s,lip, c_d,lip,i] combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2 return combined_lip_ratio_tensor class LivePortraitWrapperAnimal(LivePortraitWrapper): """ Wrapper for Animal """ def __init__(self, inference_cfg: InferenceConfig): # super().__init__(inference_cfg) # 调用父类的初始化方法 self.inference_cfg = inference_cfg self.device_id = inference_cfg.device_id self.compile = inference_cfg.flag_do_torch_compile if inference_cfg.flag_force_cpu: self.device = 'cpu' else: try: if torch.backends.mps.is_available(): self.device = 'mps' else: self.device = 'cuda:' + str(self.device_id) except: self.device = 'cuda:' + str(self.device_id) model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader) # init F self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F_animal, model_config, self.device, 'appearance_feature_extractor') log(f'Load appearance_feature_extractor from 
{osp.realpath(inference_cfg.checkpoint_F_animal)} done.') # init M self.motion_extractor = load_model(inference_cfg.checkpoint_M_animal, model_config, self.device, 'motion_extractor') log(f'Load motion_extractor from {osp.realpath(inference_cfg.checkpoint_M_animal)} done.') # init W self.warping_module = load_model(inference_cfg.checkpoint_W_animal, model_config, self.device, 'warping_module') log(f'Load warping_module from {osp.realpath(inference_cfg.checkpoint_W_animal)} done.') # init G self.spade_generator = load_model(inference_cfg.checkpoint_G_animal, model_config, self.device, 'spade_generator') log(f'Load spade_generator from {osp.realpath(inference_cfg.checkpoint_G_animal)} done.') # init S and R if inference_cfg.checkpoint_S_animal is not None and osp.exists(inference_cfg.checkpoint_S_animal): self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S_animal, model_config, self.device, 'stitching_retargeting_module') log(f'Load stitching_retargeting_module from {osp.realpath(inference_cfg.checkpoint_S_animal)} done.') else: self.stitching_retargeting_module = None # Optimize for inference if self.compile: torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution self.warping_module = torch.compile(self.warping_module, mode='max-autotune') self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune') self.timer = Timer() ================================================ FILE: src/modules/__init__.py ================================================ ================================================ FILE: src/modules/appearance_feature_extractor.py ================================================ # coding: utf-8 """ Appearance extractor(F) defined in paper, which maps the source image s to a 3D appearance feature volume. 
""" import torch from torch import nn from .util import SameBlock2d, DownBlock2d, ResBlock3d class AppearanceFeatureExtractor(nn.Module): def __init__(self, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks): super(AppearanceFeatureExtractor, self).__init__() self.image_channel = image_channel self.block_expansion = block_expansion self.num_down_blocks = num_down_blocks self.max_features = max_features self.reshape_channel = reshape_channel self.reshape_depth = reshape_depth self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) down_blocks = [] for i in range(num_down_blocks): in_features = min(max_features, block_expansion * (2 ** i)) out_features = min(max_features, block_expansion * (2 ** (i + 1))) down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) self.down_blocks = nn.ModuleList(down_blocks) self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) self.resblocks_3d = torch.nn.Sequential() for i in range(num_resblocks): self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) def forward(self, source_image): out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256 for i in range(len(self.down_blocks)): out = self.down_blocks[i](out) out = self.second(out) bs, c, h, w = out.shape # ->Bx512x64x64 f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64 f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64 return f_s ================================================ FILE: src/modules/convnextv2.py ================================================ # coding: utf-8 """ This moudle is adapted to the ConvNeXtV2 version for the extraction of implicit keypoints, poses, and expression deformation. 
""" import torch import torch.nn as nn # from timm.models.layers import trunc_normal_, DropPath from .util import LayerNorm, DropPath, trunc_normal_, GRN __all__ = ['convnextv2_tiny'] class Block(nn.Module): """ ConvNeXtV2 Block. Args: dim (int): Number of input channels. drop_path (float): Stochastic depth rate. Default: 0.0 """ def __init__(self, dim, drop_path=0.): super().__init__() self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv self.norm = LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.grn = GRN(4 * dim) self.pwconv2 = nn.Linear(4 * dim, dim) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): input = x x = self.dwconv(x) x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.grn(x) x = self.pwconv2(x) x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) x = input + self.drop_path(x) return x class ConvNeXtV2(nn.Module): """ ConvNeXt V2 Args: in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] drop_path_rate (float): Stochastic depth rate. Default: 0. head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
""" def __init__( self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., **kwargs ): super().__init__() self.depths = depths self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers stem = nn.Sequential( nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm(dims[0], eps=1e-6, data_format="channels_first") ) self.downsample_layers.append(stem) for i in range(3): downsample_layer = nn.Sequential( LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), ) self.downsample_layers.append(downsample_layer) self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] cur = 0 for i in range(4): stage = nn.Sequential( *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] ) self.stages.append(stage) cur += depths[i] self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer # NOTE: the output semantic items num_bins = kwargs.get('num_bins', 66) num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints # print('dims[-1]: ', dims[-1]) self.fc_scale = nn.Linear(dims[-1], 1) # scale self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins self.fc_t = nn.Linear(dims[-1], 3) # translation self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): trunc_normal_(m.weight, std=.02) nn.init.constant_(m.bias, 0) def forward_features(self, x): for i in range(4): x = self.downsample_layers[i](x) x = self.stages[i](x) return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) def forward(self, x): x = 
self.forward_features(x) # implicit keypoints kp = self.fc_kp(x) # pose and expression deformation pitch = self.fc_pitch(x) yaw = self.fc_yaw(x) roll = self.fc_roll(x) t = self.fc_t(x) exp = self.fc_exp(x) scale = self.fc_scale(x) ret_dct = { 'pitch': pitch, 'yaw': yaw, 'roll': roll, 't': t, 'exp': exp, 'scale': scale, 'kp': kp, # canonical keypoint } return ret_dct def convnextv2_tiny(**kwargs): model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) return model ================================================ FILE: src/modules/dense_motion.py ================================================ # coding: utf-8 """ The module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving """ from torch import nn import torch.nn.functional as F import torch from .util import Hourglass, make_coordinate_grid, kp2gaussian class DenseMotionNetwork(nn.Module): def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True): super(DenseMotionNetwork, self).__init__() self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! 
NOTE: computation cost is large self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G self.norm = nn.BatchNorm3d(compress, affine=True) self.num_kp = num_kp self.flag_estimate_occlusion_map = estimate_occlusion_map if self.flag_estimate_occlusion_map: self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3) else: self.occlusion = None def create_sparse_motions(self, feature, kp_driving, kp_source): bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64) identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3) identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3) coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3) k = coordinate_grid.shape[1] # NOTE: there lacks an one-order flow driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3) # adding background feature identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1) sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3) return sparse_motions def create_deformed_feature(self, feature, sparse_motions): bs, _, d, h, w = feature.shape feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w) feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w) sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False) sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w) return sparse_deformed def create_heatmap_representations(self, feature, kp_driving, kp_source): spatial_size = feature.shape[3:] # (d=16, h=64, w=64) gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, 
num_kp, d, h, w) gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w) heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w) # adding background feature zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.dtype).to(heatmap.device) heatmap = torch.cat([zeros, heatmap], dim=1) heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w) return heatmap def forward(self, feature, kp_driving, kp_source): bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64) feature = self.compress(feature) # (bs, 4, 16, 64, 64) feature = self.norm(feature) # (bs, 4, 16, 64, 64) feature = F.relu(feature) # (bs, 4, 16, 64, 64) out_dict = dict() # 1. deform 3d feature sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3) deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64) # 2. (bs, 1+num_kp, d, h, w) heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) # (bs, 1+num_kp, 1, d, h, w) input = torch.cat([heatmap, deformed_feature], dim=2) # (bs, 1+num_kp, c=5, d=16, h=64, w=64) input = input.view(bs, -1, d, h, w) # (bs, (1+num_kp)*c=105, d=16, h=64, w=64) prediction = self.hourglass(input) mask = self.mask(prediction) mask = F.softmax(mask, dim=1) # (bs, 1+num_kp, d=16, h=64, w=64) out_dict['mask'] = mask mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w) deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) mask take effect in this place deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3) out_dict['deformation'] = deformation if self.flag_estimate_occlusion_map: bs, _, d, h, w = prediction.shape prediction_reshape = prediction.view(bs, -1, h, w) occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape)) # 
Bx1x64x64 out_dict['occlusion_map'] = occlusion_map return out_dict ================================================ FILE: src/modules/motion_extractor.py ================================================ # coding: utf-8 """ Motion extractor(M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image """ from torch import nn import torch from .convnextv2 import convnextv2_tiny from .util import filter_state_dict model_dict = { 'convnextv2_tiny': convnextv2_tiny, } class MotionExtractor(nn.Module): def __init__(self, **kwargs): super(MotionExtractor, self).__init__() # default is convnextv2_base backbone = kwargs.get('backbone', 'convnextv2_tiny') self.detector = model_dict.get(backbone)(**kwargs) def load_pretrained(self, init_path: str): if init_path not in (None, ''): state_dict = torch.load(init_path, map_location=lambda storage, loc: storage)['model'] state_dict = filter_state_dict(state_dict, remove_name='head') ret = self.detector.load_state_dict(state_dict, strict=False) print(f'Load pretrained model from {init_path}, ret: {ret}') def forward(self, x): out = self.detector(x) return out ================================================ FILE: src/modules/spade_generator.py ================================================ # coding: utf-8 """ Spade decoder(G) defined in the paper, which input the warped feature to generate the animated image. 
""" import torch from torch import nn import torch.nn.functional as F from .util import SPADEResnetBlock class SPADEDecoder(nn.Module): def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2): for i in range(num_down_blocks): input_channels = min(max_features, block_expansion * (2 ** (i + 1))) self.upscale = upscale super().__init__() norm_G = 'spadespectralinstance' label_num_channels = input_channels # 256 self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1) self.G_middle_0 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) self.G_middle_1 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) self.G_middle_2 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) self.G_middle_3 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) self.G_middle_4 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) self.G_middle_5 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) self.up_0 = SPADEResnetBlock(2 * input_channels, input_channels, norm_G, label_num_channels) self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels) self.up = nn.Upsample(scale_factor=2) if self.upscale is None or self.upscale <= 1: self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1) else: self.conv_img = nn.Sequential( nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1), nn.PixelShuffle(upscale_factor=2) ) def forward(self, feature): seg = feature # Bx256x64x64 x = self.fc(feature) # Bx512x64x64 x = self.G_middle_0(x, seg) x = self.G_middle_1(x, seg) x = self.G_middle_2(x, seg) x = self.G_middle_3(x, seg) x = self.G_middle_4(x, seg) x = self.G_middle_5(x, seg) x = self.up(x) # Bx512x64x64 -> Bx512x128x128 x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128 x = self.up(x) # Bx256x128x128 -> 
Bx256x256x256 x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256 x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW x = torch.sigmoid(x) # Bx3xHxW return x ================================================ FILE: src/modules/stitching_retargeting_network.py ================================================ # coding: utf-8 """ Stitching module(S) and two retargeting modules(R) defined in the paper. - The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in the stitching region. - The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially when a person with small eyes drives a person with larger eyes. - The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that the lips are in a closed state, which facilitates better animation driving. """ from torch import nn class StitchingRetargetingNetwork(nn.Module): def __init__(self, input_size, hidden_sizes, output_size): super(StitchingRetargetingNetwork, self).__init__() layers = [] for i in range(len(hidden_sizes)): if i == 0: layers.append(nn.Linear(input_size, hidden_sizes[i])) else: layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i])) layers.append(nn.ReLU(inplace=True)) layers.append(nn.Linear(hidden_sizes[-1], output_size)) self.mlp = nn.Sequential(*layers) def initialize_weights_to_zero(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.zeros_(m.weight) nn.init.zeros_(m.bias) def forward(self, x): return self.mlp(x) ================================================ FILE: src/modules/util.py ================================================ # coding: utf-8 """ This file defines various neural network modules and utility functions, including convolutional and residual blocks, normalizations, and functions for spatial transformation and tensor manipulation. 
""" from torch import nn import torch.nn.functional as F import torch import torch.nn.utils.spectral_norm as spectral_norm import math import warnings import collections.abc from itertools import repeat def kp2gaussian(kp, spatial_size, kp_variance): """ Transform a keypoint into gaussian like representation """ mean = kp coordinate_grid = make_coordinate_grid(spatial_size, mean) number_of_leading_dimensions = len(mean.shape) - 1 shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape coordinate_grid = coordinate_grid.view(*shape) repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) coordinate_grid = coordinate_grid.repeat(*repeats) # Preprocess kp shape shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) mean = mean.view(*shape) mean_sub = (coordinate_grid - mean) out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) return out def make_coordinate_grid(spatial_size, ref, **kwargs): d, h, w = spatial_size x = torch.arange(w).type(ref.dtype).to(ref.device) y = torch.arange(h).type(ref.dtype).to(ref.device) z = torch.arange(d).type(ref.dtype).to(ref.device) # NOTE: must be right-down-in x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner yy = y.view(1, -1, 1).repeat(d, 1, w) xx = x.view(1, 1, -1).repeat(d, h, 1) zz = z.view(-1, 1, 1).repeat(1, h, w) meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3) return meshed class ConvT2d(nn.Module): """ Upsampling block for use in decoder. 
""" def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1): super(ConvT2d, self).__init__() self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding) self.norm = nn.InstanceNorm2d(out_features) def forward(self, x): out = self.convT(x) out = self.norm(out) out = F.leaky_relu(out) return out class ResBlock3d(nn.Module): """ Res block, preserve spatial resolution. """ def __init__(self, in_features, kernel_size, padding): super(ResBlock3d, self).__init__() self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) self.norm1 = nn.BatchNorm3d(in_features, affine=True) self.norm2 = nn.BatchNorm3d(in_features, affine=True) def forward(self, x): out = self.norm1(x) out = F.relu(out) out = self.conv1(out) out = self.norm2(out) out = F.relu(out) out = self.conv2(out) out += x return out class UpBlock3d(nn.Module): """ Upsampling block for use in decoder. """ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): super(UpBlock3d, self).__init__() self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = nn.BatchNorm3d(out_features, affine=True) def forward(self, x): out = F.interpolate(x, scale_factor=(1, 2, 2)) out = self.conv(out) out = self.norm(out) out = F.relu(out) return out class DownBlock2d(nn.Module): """ Downsampling block for use in encoder. 
""" def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): super(DownBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = nn.BatchNorm2d(out_features, affine=True) self.pool = nn.AvgPool2d(kernel_size=(2, 2)) def forward(self, x): out = self.conv(x) out = self.norm(out) out = F.relu(out) out = self.pool(out) return out class DownBlock3d(nn.Module): """ Downsampling block for use in encoder. """ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): super(DownBlock3d, self).__init__() ''' self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups, stride=(1, 2, 2)) ''' self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = nn.BatchNorm3d(out_features, affine=True) self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2)) def forward(self, x): out = self.conv(x) out = self.norm(out) out = F.relu(out) out = self.pool(out) return out class SameBlock2d(nn.Module): """ Simple block, preserve spatial resolution. 
""" def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False): super(SameBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = nn.BatchNorm2d(out_features, affine=True) if lrelu: self.ac = nn.LeakyReLU() else: self.ac = nn.ReLU() def forward(self, x): out = self.conv(x) out = self.norm(out) out = self.ac(out) return out class Encoder(nn.Module): """ Hourglass Encoder """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Encoder, self).__init__() down_blocks = [] for i in range(num_blocks): down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1)) self.down_blocks = nn.ModuleList(down_blocks) def forward(self, x): outs = [x] for down_block in self.down_blocks: outs.append(down_block(outs[-1])) return outs class Decoder(nn.Module): """ Hourglass Decoder """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Decoder, self).__init__() up_blocks = [] for i in range(num_blocks)[::-1]: in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) out_filters = min(max_features, block_expansion * (2 ** i)) up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) self.up_blocks = nn.ModuleList(up_blocks) self.out_filters = block_expansion + in_features self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) self.norm = nn.BatchNorm3d(self.out_filters, affine=True) def forward(self, x): out = x.pop() for up_block in self.up_blocks: out = up_block(out) skip = x.pop() out = torch.cat([out, skip], dim=1) out = self.conv(out) out = self.norm(out) out = F.relu(out) return out class Hourglass(nn.Module): """ Hourglass 
architecture. """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Hourglass, self).__init__() self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) self.out_filters = self.decoder.out_filters def forward(self, x): return self.decoder(self.encoder(x)) class SPADE(nn.Module): def __init__(self, norm_nc, label_nc): super().__init__() self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) nhidden = 128 self.mlp_shared = nn.Sequential( nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1), nn.ReLU()) self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) def forward(self, x, segmap): normalized = self.param_free_norm(x) segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') actv = self.mlp_shared(segmap) gamma = self.mlp_gamma(actv) beta = self.mlp_beta(actv) out = normalized * (1 + gamma) + beta return out class SPADEResnetBlock(nn.Module): def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1): super().__init__() # Attributes self.learned_shortcut = (fin != fout) fmiddle = min(fin, fout) self.use_se = use_se # create conv layers self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) if self.learned_shortcut: self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) # apply spectral norm if specified if 'spectral' in norm_G: self.conv_0 = spectral_norm(self.conv_0) self.conv_1 = spectral_norm(self.conv_1) if self.learned_shortcut: self.conv_s = spectral_norm(self.conv_s) # define normalization layers self.norm_0 = SPADE(fin, label_nc) self.norm_1 = SPADE(fmiddle, label_nc) if self.learned_shortcut: self.norm_s = SPADE(fin, label_nc) def forward(self, x, seg1): x_s = 
self.shortcut(x, seg1) dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) out = x_s + dx return out def shortcut(self, x, seg1): if self.learned_shortcut: x_s = self.conv_s(self.norm_s(x, seg1)) else: x_s = x return x_s def actvn(self, x): return F.leaky_relu(x, 2e-1) def filter_state_dict(state_dict, remove_name='fc'): new_state_dict = {} for key in state_dict: if remove_name in key: continue new_state_dict[key] = state_dict[key] return new_state_dict class GRN(nn.Module): """ GRN (Global Response Normalization) layer """ def __init__(self, dim): super().__init__() self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) def forward(self, x): Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) return self.gamma * (x * Nx) + self.beta + x class LayerNorm(nn.Module): r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). 
""" def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): super().__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) self.eps = eps self.data_format = data_format if self.data_format not in ["channels_last", "channels_first"]: raise NotImplementedError self.normalized_shape = (normalized_shape, ) def forward(self, x): if self.data_format == "channels_last": return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) elif self.data_format == "channels_first": u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x def _no_grad_trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1. + math.erf(x / math.sqrt(2.))) / 2. if (mean < a - 2 * std) or (mean > b + 2 * std): warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2) with torch.no_grad(): # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. 
tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.)) tensor.add_(mean) # Clamp to ensure it's in the proper range tensor.clamp_(min=a, max=b) return tensor def drop_path(x, drop_prob=0., training=False, scale_by_keep=True): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. """ if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0 and scale_by_keep: random_tensor.div_(keep_prob) return x * random_tensor class DropPath(nn.Module): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
""" def __init__(self, drop_prob=None, scale_by_keep=True): super(DropPath, self).__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep def forward(self, x): return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): return _no_grad_trunc_normal_(tensor, mean, std, a, b) # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) return parse to_2tuple = _ntuple(2) ================================================ FILE: src/modules/warping_network.py ================================================ # coding: utf-8 """ Warping field estimator(W) defined in the paper, which generates a warping field using the implicit keypoint representations x_s and x_d, and employs this flow field to warp the source feature volume f_s. """ from torch import nn import torch.nn.functional as F from .util import SameBlock2d from .dense_motion import DenseMotionNetwork class WarpingNetwork(nn.Module): def __init__( self, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, estimate_occlusion_map=False, dense_motion_params=None, **kwargs ): super(WarpingNetwork, self).__init__() self.upscale = kwargs.get('upscale', 1) self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True) if dense_motion_params is not None: self.dense_motion_network = DenseMotionNetwork( num_kp=num_kp, feature_channel=reshape_channel, estimate_occlusion_map=estimate_occlusion_map, **dense_motion_params ) else: self.dense_motion_network = None self.third = SameBlock2d(max_features, block_expansion * (2 ** num_down_blocks), kernel_size=(3, 3), padding=(1, 1), lrelu=True) self.fourth = nn.Conv2d(in_channels=block_expansion * (2 ** num_down_blocks), out_channels=block_expansion * (2 ** num_down_blocks), kernel_size=1, stride=1) self.estimate_occlusion_map = estimate_occlusion_map def 
deform_input(self, inp, deformation): return F.grid_sample(inp, deformation, align_corners=False) def forward(self, feature_3d, kp_driving, kp_source): if self.dense_motion_network is not None: # Feature warper, Transforming feature representation according to deformation and occlusion dense_motion = self.dense_motion_network( feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source ) if 'occlusion_map' in dense_motion: occlusion_map = dense_motion['occlusion_map'] # Bx1x64x64 else: occlusion_map = None deformation = dense_motion['deformation'] # Bx16x64x64x3 out = self.deform_input(feature_3d, deformation) # Bx32x16x64x64 bs, c, d, h, w = out.shape # Bx32x16x64x64 out = out.view(bs, c * d, h, w) # -> Bx512x64x64 out = self.third(out) # -> Bx256x64x64 out = self.fourth(out) # -> Bx256x64x64 if self.flag_use_occlusion_map and (occlusion_map is not None): out = out * occlusion_map ret_dct = { 'occlusion_map': occlusion_map, 'deformation': deformation, 'out': out, } return ret_dct ================================================ FILE: src/utils/__init__.py ================================================ ================================================ FILE: src/utils/animal_landmark_runner.py ================================================ # coding: utf-8 """ face detectoin and alignment using XPose """ import os import pickle import torch import numpy as np from PIL import Image from torchvision.ops import nms from .timer import Timer from .rprint import rlog as log from .helper import clean_state_dict from .dependencies.XPose import transforms as T from .dependencies.XPose.models import build_model from .dependencies.XPose.predefined_keypoints import * from .dependencies.XPose.util import box_ops from .dependencies.XPose.util.config import Config class XPoseRunner(object): def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs): self.device_id = kwargs.get("device_id", 0) self.flag_use_half_precision 
= kwargs.get("flag_use_half_precision", True) self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu" self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device) self.timer = Timer() # Load cached embeddings if available try: with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f: self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f) with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f: self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f) print("Loaded cached embeddings from file.") except Exception: raise ValueError("Could not load clip embeddings from file, please check your file path.") def load_animal_model(self, model_config_path, model_checkpoint_path, device): args = Config.fromfile(model_config_path) args.device = device model = build_model(args) checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) model.eval() return model def load_image(self, input_image): image_pil = input_image.convert("RGB") transform = T.Compose([ T.RandomResize([800], max_size=1333), # NOTE: fixed size to 800 T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) image, _ = transform(image_pil, None) return image_pil, image def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold): instance_list = instance_text_prompt.split(',') if len(keypoint_text_prompt) == 9: # torch.Size([1, 512]) torch.Size([9, 512]) ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9 elif len(keypoint_text_prompt) ==68: # torch.Size([1, 512]) torch.Size([68, 512]) ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68 else: raise ValueError("Invalid number of keypoint embeddings.") target = { "instance_text_prompt": 
instance_list, "keypoint_text_prompt": keypoint_text_prompt, "object_embeddings_text": ins_text_embeddings.float(), "kpts_embeddings_text": torch.cat((kpt_text_embeddings.float(), torch.zeros(100 - kpt_text_embeddings.shape[0], 512, device=self.device)), dim=0), "kpt_vis_text": torch.cat((torch.ones(kpt_text_embeddings.shape[0], device=self.device), torch.zeros(100 - kpt_text_embeddings.shape[0], device=self.device)), dim=0) } self.model = self.model.to(self.device) image = image.to(self.device) with torch.no_grad(): with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.flag_use_half_precision): outputs = self.model(image[None], [target]) logits = outputs["pred_logits"].sigmoid()[0] boxes = outputs["pred_boxes"][0] keypoints = outputs["pred_keypoints"][0][:, :2 * len(keypoint_text_prompt)] logits_filt = logits.cpu().clone() boxes_filt = boxes.cpu().clone() keypoints_filt = keypoints.cpu().clone() filt_mask = logits_filt.max(dim=1)[0] > box_threshold logits_filt = logits_filt[filt_mask] boxes_filt = boxes_filt[filt_mask] keypoints_filt = keypoints_filt[filt_mask] keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), logits_filt.max(dim=1)[0], iou_threshold=IoU_threshold) filtered_boxes = boxes_filt[keep_indices] filtered_keypoints = keypoints_filt[keep_indices] return filtered_boxes, filtered_keypoints def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold): if keypoint_text_example in globals(): keypoint_dict = globals()[keypoint_text_example] elif instance_text_prompt in globals(): keypoint_dict = globals()[instance_text_prompt] else: keypoint_dict = globals()["animal"] keypoint_text_prompt = keypoint_dict.get("keypoints") keypoint_skeleton = keypoint_dict.get("skeleton") image_pil, image = self.load_image(input_image) boxes_filt, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold) size = image_pil.size H, W = 
size[1], size[0] keypoints_filt = keypoints_filt[0].squeeze(0) kp = np.array(keypoints_filt.cpu()) num_kpts = len(keypoint_text_prompt) Z = kp[:num_kpts * 2] * np.array([W, H] * num_kpts) Z = Z.reshape(num_kpts * 2) x = Z[0::2] y = Z[1::2] return np.stack((x, y), axis=1) def warmup(self): self.timer.tic() img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)) self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0) elapse = self.timer.toc() log(f'XPoseRunner warmup time: {elapse:.3f}s') ================================================ FILE: src/utils/camera.py ================================================ # coding: utf-8 """ functions for processing and transforming 3D facial keypoints """ import numpy as np import torch import torch.nn.functional as F PI = np.pi def headpose_pred_to_degree(pred): """ pred: (bs, 66) or (bs, 1) or others """ if pred.ndim > 1 and pred.shape[1] == 66: # NOTE: note that the average is modified to 97.5 device = pred.device idx_tensor = [idx for idx in range(0, 66)] idx_tensor = torch.FloatTensor(idx_tensor).to(device) pred = F.softmax(pred, dim=1) degree = torch.sum(pred*idx_tensor, axis=1) * 3 - 97.5 return degree return pred def get_rotation_matrix(pitch_, yaw_, roll_): """ the input is in degree """ # transform to radian pitch = pitch_ / 180 * PI yaw = yaw_ / 180 * PI roll = roll_ / 180 * PI device = pitch.device if pitch.ndim == 1: pitch = pitch.unsqueeze(1) if yaw.ndim == 1: yaw = yaw.unsqueeze(1) if roll.ndim == 1: roll = roll.unsqueeze(1) # calculate the euler matrix bs = pitch.shape[0] ones = torch.ones([bs, 1]).to(device) zeros = torch.zeros([bs, 1]).to(device) x, y, z = pitch, yaw, roll rot_x = torch.cat([ ones, zeros, zeros, zeros, torch.cos(x), -torch.sin(x), zeros, torch.sin(x), torch.cos(x) ], dim=1).reshape([bs, 3, 3]) rot_y = torch.cat([ torch.cos(y), zeros, torch.sin(y), zeros, ones, zeros, -torch.sin(y), zeros, torch.cos(y) ], dim=1).reshape([bs, 3, 3]) rot_z = torch.cat([ torch.cos(z), 
-torch.sin(z), zeros, torch.sin(z), torch.cos(z), zeros, zeros, zeros, ones ], dim=1).reshape([bs, 3, 3]) rot = rot_z @ rot_y @ rot_x return rot.permute(0, 2, 1) # transpose ================================================ FILE: src/utils/check_windows_port.py ================================================ import socket import sys if len(sys.argv) != 2: print("Usage: python check_port.py ") sys.exit(1) port = int(sys.argv[1]) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(1) result = sock.connect_ex(('127.0.0.1', port)) if result == 0: print("LISTENING") else: print("NOT LISTENING") sock.close ================================================ FILE: src/utils/crop.py ================================================ # coding: utf-8 """ cropping function and the related preprocess functions for cropping """ import numpy as np import os.path as osp from math import sin, cos, acos, degrees import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread from .rprint import rprint as print DTYPE = np.float32 CV2_INTERP = cv2.INTER_LINEAR def make_abs_path(fn): return osp.join(osp.dirname(osp.realpath(__file__)), fn) def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None): """ conduct similarity or affine transformation to the image, do not do border operation! 
img: M: 2x3 matrix or 3x3 matrix dsize: target shape (width, height) """ if isinstance(dsize, tuple) or isinstance(dsize, list): _dsize = tuple(dsize) else: _dsize = (dsize, dsize) if borderMode is not None: return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags, borderMode=borderMode, borderValue=(0, 0, 0)) else: return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags) def _transform_pts(pts, M): """ conduct similarity or affine transformation to the pts pts: Nx2 ndarray M: 2x3 matrix or 3x3 matrix return: Nx2 """ return pts @ M[:2, :2].T + M[:2, 2] def parse_pt2_from_pt101(pt101, use_lip=True): """ parsing the 2 points according to the 101 points, which cancels the roll """ # the former version use the eye center, but it is not robust, now use interpolation pt_left_eye = np.mean(pt101[[39, 42, 45, 48]], axis=0) # left eye center pt_right_eye = np.mean(pt101[[51, 54, 57, 60]], axis=0) # right eye center if use_lip: # use lip pt_center_eye = (pt_left_eye + pt_right_eye) / 2 pt_center_lip = (pt101[75] + pt101[81]) / 2 pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0) else: pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0) return pt2 def parse_pt2_from_pt106(pt106, use_lip=True): """ parsing the 2 points according to the 106 points, which cancels the roll """ pt_left_eye = np.mean(pt106[[33, 35, 40, 39]], axis=0) # left eye center pt_right_eye = np.mean(pt106[[87, 89, 94, 93]], axis=0) # right eye center if use_lip: # use lip pt_center_eye = (pt_left_eye + pt_right_eye) / 2 pt_center_lip = (pt106[52] + pt106[61]) / 2 pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0) else: pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0) return pt2 def parse_pt2_from_pt203(pt203, use_lip=True): """ parsing the 2 points according to the 203 points, which cancels the roll """ pt_left_eye = np.mean(pt203[[0, 6, 12, 18]], axis=0) # left eye center pt_right_eye = np.mean(pt203[[24, 30, 36, 42]], axis=0) # right eye center if use_lip: # use lip pt_center_eye = 
(pt_left_eye + pt_right_eye) / 2 pt_center_lip = (pt203[48] + pt203[66]) / 2 pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0) else: pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0) return pt2 def parse_pt2_from_pt68(pt68, use_lip=True): """ parsing the 2 points according to the 68 points, which cancels the roll """ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55], dtype=np.int32) - 1 if use_lip: pt5 = np.stack([ np.mean(pt68[lm_idx[[1, 2]], :], 0), # left eye np.mean(pt68[lm_idx[[3, 4]], :], 0), # right eye pt68[lm_idx[0], :], # nose pt68[lm_idx[5], :], # lip pt68[lm_idx[6], :] # lip ], axis=0) pt2 = np.stack([ (pt5[0] + pt5[1]) / 2, (pt5[3] + pt5[4]) / 2 ], axis=0) else: pt2 = np.stack([ np.mean(pt68[lm_idx[[1, 2]], :], 0), # left eye np.mean(pt68[lm_idx[[3, 4]], :], 0), # right eye ], axis=0) return pt2 def parse_pt2_from_pt5(pt5, use_lip=True): """ parsing the 2 points according to the 5 points, which cancels the roll """ if use_lip: pt2 = np.stack([ (pt5[0] + pt5[1]) / 2, (pt5[3] + pt5[4]) / 2 ], axis=0) else: pt2 = np.stack([ pt5[0], pt5[1] ], axis=0) return pt2 def parse_pt2_from_pt9(pt9, use_lip=True): ''' parsing the 2 points according to the 9 points, which cancels the roll ['right eye right', 'right eye left', 'left eye right', 'left eye left', 'nose tip', 'lip right', 'lip left', 'upper lip', 'lower lip'] ''' if use_lip: pt9 = np.stack([ (pt9[2] + pt9[3]) / 2, # left eye (pt9[0] + pt9[1]) / 2, # right eye pt9[4], (pt9[5] + pt9[6] ) / 2 # lip ], axis=0) pt2 = np.stack([ (pt9[0] + pt9[1]) / 2, # eye pt9[3] # lip ], axis=0) else: pt2 = np.stack([ (pt9[2] + pt9[3]) / 2, (pt9[0] + pt9[1]) / 2, ], axis=0) return pt2 def parse_pt2_from_pt_x(pts, use_lip=True): if pts.shape[0] == 101: pt2 = parse_pt2_from_pt101(pts, use_lip=use_lip) elif pts.shape[0] == 106: pt2 = parse_pt2_from_pt106(pts, use_lip=use_lip) elif pts.shape[0] == 68: pt2 = parse_pt2_from_pt68(pts, use_lip=use_lip) elif pts.shape[0] == 5: pt2 = parse_pt2_from_pt5(pts, use_lip=use_lip) elif 
pts.shape[0] == 203: pt2 = parse_pt2_from_pt203(pts, use_lip=use_lip) elif pts.shape[0] > 101: # take the first 101 points pt2 = parse_pt2_from_pt101(pts[:101], use_lip=use_lip) elif pts.shape[0] == 9: pt2 = parse_pt2_from_pt9(pts, use_lip=use_lip) else: raise Exception(f'Unknow shape: {pts.shape}') if not use_lip: # NOTE: to compile with the latter code, need to rotate the pt2 90 degrees clockwise manually v = pt2[1] - pt2[0] pt2[1, 0] = pt2[0, 0] - v[1] pt2[1, 1] = pt2[0, 1] + v[0] return pt2 def parse_rect_from_landmark( pts, scale=1.5, need_square=True, vx_ratio=0, vy_ratio=0, use_deg_flag=False, **kwargs ): """parsing center, size, angle from 101/68/5/x landmarks vx_ratio: the offset ratio along the pupil axis x-axis, multiplied by size vy_ratio: the offset ratio along the pupil axis y-axis, multiplied by size, which is used to contain more forehead area judge with pts.shape """ pt2 = parse_pt2_from_pt_x(pts, use_lip=kwargs.get('use_lip', True)) uy = pt2[1] - pt2[0] l = np.linalg.norm(uy) if l <= 1e-3: uy = np.array([0, 1], dtype=DTYPE) else: uy /= l ux = np.array((uy[1], -uy[0]), dtype=DTYPE) # the rotation degree of the x-axis, the clockwise is positive, the counterclockwise is negative (image coordinate system) # print(uy) # print(ux) angle = acos(ux[0]) if ux[1] < 0: angle = -angle # rotation matrix M = np.array([ux, uy]) # calculate the size which contains the angle degree of the bbox, and the center center0 = np.mean(pts, axis=0) rpts = (pts - center0) @ M.T # (M @ P.T).T = P @ M.T lt_pt = np.min(rpts, axis=0) rb_pt = np.max(rpts, axis=0) center1 = (lt_pt + rb_pt) / 2 size = rb_pt - lt_pt if need_square: m = max(size[0], size[1]) size[0] = m size[1] = m size *= scale # scale size center = center0 + ux * center1[0] + uy * center1[1] # counterclockwise rotation, equivalent to M.T @ center1.T center = center + ux * (vx_ratio * size) + uy * \ (vy_ratio * size) # considering the offset in vx and vy direction if use_deg_flag: angle = degrees(angle) return 
center, size, angle def parse_bbox_from_landmark(pts, **kwargs): center, size, angle = parse_rect_from_landmark(pts, **kwargs) cx, cy = center w, h = size # calculate the vertex positions before rotation bbox = np.array([ [cx-w/2, cy-h/2], # left, top [cx+w/2, cy-h/2], [cx+w/2, cy+h/2], # right, bottom [cx-w/2, cy+h/2] ], dtype=DTYPE) # construct rotation matrix bbox_rot = bbox.copy() R = np.array([ [np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)] ], dtype=DTYPE) # calculate the relative position of each vertex from the rotation center, then rotate these positions, and finally add the coordinates of the rotation center bbox_rot = (bbox_rot - center) @ R.T + center return { 'center': center, # 2x1 'size': size, # scalar 'angle': angle, # rad, counterclockwise 'bbox': bbox, # 4x2 'bbox_rot': bbox_rot, # 4x2 } def crop_image_by_bbox(img, bbox, lmk=None, dsize=512, angle=None, flag_rot=False, **kwargs): left, top, right, bot = bbox if int(right - left) != int(bot - top): print(f'right-left {right-left} != bot-top {bot-top}') size = right - left src_center = np.array([(left + right) / 2, (top + bot) / 2], dtype=DTYPE) tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE) s = dsize / size # scale if flag_rot and angle is not None: costheta, sintheta = cos(angle), sin(angle) cx, cy = src_center[0], src_center[1] # ori center tcx, tcy = tgt_center[0], tgt_center[1] # target center # need to infer M_o2c = np.array( [[s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)], [-s * sintheta, s * costheta, tcy - s * (-sintheta * cx + costheta * cy)]], dtype=DTYPE ) else: M_o2c = np.array( [[s, 0, tgt_center[0] - s * src_center[0]], [0, s, tgt_center[1] - s * src_center[1]]], dtype=DTYPE ) # if flag_rot and angle is None: # print('angle is None, but flag_rotate is True', style="bold yellow") img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None)) lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else 
None M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)]) M_c2o = np.linalg.inv(M_o2c) # cv2.imwrite('crop.jpg', img_crop) return { 'img_crop': img_crop, 'lmk_crop': lmk_crop, 'M_o2c': M_o2c, 'M_c2o': M_c2o, } def _estimate_similar_transform_from_pts( pts, dsize, scale=1.5, vx_ratio=0, vy_ratio=-0.1, flag_do_rot=True, **kwargs ): """ calculate the affine matrix of the cropped image from sparse points, the original image to the cropped image, the inverse is the cropped image to the original image pts: landmark, 101 or 68 points or other points, Nx2 scale: the larger scale factor, the smaller face ratio vx_ratio: x shift vy_ratio: y shift, the smaller the y shift, the lower the face region rot_flag: if it is true, conduct correction """ center, size, angle = parse_rect_from_landmark( pts, scale=scale, vx_ratio=vx_ratio, vy_ratio=vy_ratio, use_lip=kwargs.get('use_lip', True) ) s = dsize / size[0] # scale tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE) # center of dsize if flag_do_rot: costheta, sintheta = cos(angle), sin(angle) cx, cy = center[0], center[1] # ori center tcx, tcy = tgt_center[0], tgt_center[1] # target center # need to infer M_INV = np.array( [[s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)], [-s * sintheta, s * costheta, tcy - s * (-sintheta * cx + costheta * cy)]], dtype=DTYPE ) else: M_INV = np.array( [[s, 0, tgt_center[0] - s * center[0]], [0, s, tgt_center[1] - s * center[1]]], dtype=DTYPE ) M_INV_H = np.vstack([M_INV, np.array([0, 0, 1])]) M = np.linalg.inv(M_INV_H) # M_INV is from the original image to the cropped image, M is from the cropped image to the original image return M_INV, M[:2, ...] 
def crop_image(img, pts: np.ndarray, **kwargs): dsize = kwargs.get('dsize', 224) scale = kwargs.get('scale', 1.5) # 1.5 | 1.6 vy_ratio = kwargs.get('vy_ratio', -0.1) # -0.0625 | -0.1 M_INV, _ = _estimate_similar_transform_from_pts( pts, dsize=dsize, scale=scale, vy_ratio=vy_ratio, flag_do_rot=kwargs.get('flag_do_rot', True), ) img_crop = _transform_img(img, M_INV, dsize) # origin to crop pt_crop = _transform_pts(pts, M_INV) M_o2c = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)]) M_c2o = np.linalg.inv(M_o2c) ret_dct = { 'M_o2c': M_o2c, # from the original image to the cropped image 3x3 'M_c2o': M_c2o, # from the cropped image to the original image 3x3 'img_crop': img_crop, # the cropped image 'pt_crop': pt_crop, # the landmarks of the cropped image } return ret_dct def average_bbox_lst(bbox_lst): if len(bbox_lst) == 0: return None bbox_arr = np.array(bbox_lst) return np.mean(bbox_arr, axis=0).tolist() def prepare_paste_back(mask_crop, crop_M_c2o, dsize): """prepare mask for later image paste back """ mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize) mask_ori = mask_ori.astype(np.float32) / 255. 
return mask_ori def paste_back(img_crop, M_c2o, img_ori, mask_ori): """paste back the image """ dsize = (img_ori.shape[1], img_ori.shape[0]) result = _transform_img(img_crop, M_c2o, dsize=dsize) result = np.clip(mask_ori * result + (1 - mask_ori) * img_ori, 0, 255).astype(np.uint8) return result ================================================ FILE: src/utils/cropper.py ================================================ # coding: utf-8 import os.path as osp import torch import numpy as np import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) from PIL import Image from typing import List, Tuple, Union from dataclasses import dataclass, field from ..config.crop_config import CropConfig from .crop import ( average_bbox_lst, crop_image, crop_image_by_bbox, parse_bbox_from_landmark, ) from .io import contiguous from .rprint import rlog as log from .face_analysis_diy import FaceAnalysisDIY from .human_landmark_runner import LandmarkRunner as HumanLandmark def make_abs_path(fn): return osp.join(osp.dirname(osp.realpath(__file__)), fn) @dataclass class Trajectory: start: int = -1 # start frame end: int = -1 # end frame lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list M_c2o_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # M_c2o list frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame crop list class Cropper(object): def __init__(self, **kwargs) -> None: self.crop_cfg: CropConfig = kwargs.get("crop_cfg", None) self.image_type = kwargs.get("image_type", 'human_face') device_id = kwargs.get("device_id", 0) flag_force_cpu = kwargs.get("flag_force_cpu", False) if flag_force_cpu: device = "cpu" face_analysis_wrapper_provider = 
["CPUExecutionProvider"] else: try: if torch.backends.mps.is_available(): # Shape inference currently fails with CoreMLExecutionProvider # for the retinaface model device = "mps" face_analysis_wrapper_provider = ["CPUExecutionProvider"] else: device = "cuda" face_analysis_wrapper_provider = ["CUDAExecutionProvider"] except: device = "cuda" face_analysis_wrapper_provider = ["CUDAExecutionProvider"] self.face_analysis_wrapper = FaceAnalysisDIY( name="buffalo_l", root=self.crop_cfg.insightface_root, providers=face_analysis_wrapper_provider, ) self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512), det_thresh=self.crop_cfg.det_thresh) self.face_analysis_wrapper.warmup() self.human_landmark_runner = HumanLandmark( ckpt_path=self.crop_cfg.landmark_ckpt_path, onnx_provider=device, device_id=device_id, ) self.human_landmark_runner.warmup() if self.image_type == "animal_face": from .animal_landmark_runner import XPoseRunner as AnimalLandmarkRunner self.animal_landmark_runner = AnimalLandmarkRunner( model_config_path=self.crop_cfg.xpose_config_file_path, model_checkpoint_path=self.crop_cfg.xpose_ckpt_path, embeddings_cache_path=self.crop_cfg.xpose_embedding_cache_path, flag_use_half_precision=kwargs.get("flag_use_half_precision", True), ) self.animal_landmark_runner.warmup() def update_config(self, user_args): for k, v in user_args.items(): if hasattr(self.crop_cfg, k): setattr(self.crop_cfg, k, v) def crop_source_image(self, img_rgb_: np.ndarray, crop_cfg: CropConfig): # crop a source image and get neccessary information img_rgb = img_rgb_.copy() # copy it img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) if self.image_type == "human_face": src_face = self.face_analysis_wrapper.get( img_bgr, flag_do_landmark_2d_106=True, direction=crop_cfg.direction, max_face_num=crop_cfg.max_face_num, ) if len(src_face) == 0: log("No face detected in the source image.") return None elif len(src_face) > 1: log(f"More than one face detected in the image, only pick one face 
by rule {crop_cfg.direction}.") # NOTE: temporarily only pick the first face, to support multiple face in the future src_face = src_face[0] lmk = src_face.landmark_2d_106 # this is the 106 landmarks from insightface else: tmp_dct = { 'animal_face_9': 'animal_face', 'animal_face_68': 'face' } img_rgb_pil = Image.fromarray(img_rgb) lmk = self.animal_landmark_runner.run( img_rgb_pil, 'face', tmp_dct[crop_cfg.animal_face_type], 0, 0 ) # crop the face ret_dct = crop_image( img_rgb, # ndarray lmk, # 106x2 or Nx2 dsize=crop_cfg.dsize, scale=crop_cfg.scale, vx_ratio=crop_cfg.vx_ratio, vy_ratio=crop_cfg.vy_ratio, flag_do_rot=crop_cfg.flag_do_rot, ) # update a 256x256 version for network input ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA) if self.image_type == "human_face": lmk = self.human_landmark_runner.run(img_rgb, lmk) ret_dct["lmk_crop"] = lmk ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize else: # 68x2 or 9x2 ret_dct["lmk_crop"] = lmk return ret_dct def calc_lmk_from_cropped_image(self, img_rgb_, **kwargs): direction = kwargs.get("direction", "large-small") src_face = self.face_analysis_wrapper.get( contiguous(img_rgb_[..., ::-1]), # convert to BGR flag_do_landmark_2d_106=True, direction=direction, ) if len(src_face) == 0: log("No face detected in the source image.") return None elif len(src_face) > 1: log(f"More than one face detected in the image, only pick one face by rule {direction}.") src_face = src_face[0] lmk = src_face.landmark_2d_106 lmk = self.human_landmark_runner.run(img_rgb_, lmk) return lmk # TODO: support skipping frame with NO FACE def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig, **kwargs): """Tracking based landmarks/alignment and cropping""" trajectory = Trajectory() direction = kwargs.get("direction", "large-small") for idx, frame_rgb in enumerate(source_rgb_lst): if idx == 0 or trajectory.start == -1: src_face = self.face_analysis_wrapper.get( 
contiguous(frame_rgb[..., ::-1]), flag_do_landmark_2d_106=True, direction=crop_cfg.direction, max_face_num=crop_cfg.max_face_num, ) if len(src_face) == 0: log(f"No face detected in the frame #{idx}") continue elif len(src_face) > 1: log(f"More than one face detected in the source frame_{idx}, only pick one face by rule {direction}.") src_face = src_face[0] lmk = src_face.landmark_2d_106 lmk = self.human_landmark_runner.run(frame_rgb, lmk) trajectory.start, trajectory.end = idx, idx else: # TODO: add IOU check for tracking lmk = self.human_landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1]) trajectory.end = idx trajectory.lmk_lst.append(lmk) # crop the face ret_dct = crop_image( frame_rgb, # ndarray lmk, # 106x2 or Nx2 dsize=crop_cfg.dsize, scale=crop_cfg.scale, vx_ratio=crop_cfg.vx_ratio, vy_ratio=crop_cfg.vy_ratio, flag_do_rot=crop_cfg.flag_do_rot, ) # update a 256x256 version for network input ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA) ret_dct["lmk_crop_256x256"] = ret_dct["pt_crop"] * 256 / crop_cfg.dsize trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop_256x256"]) trajectory.lmk_crop_lst.append(ret_dct["lmk_crop_256x256"]) trajectory.M_c2o_lst.append(ret_dct['M_c2o']) return { "frame_crop_lst": trajectory.frame_rgb_crop_lst, "lmk_crop_lst": trajectory.lmk_crop_lst, "M_c2o_lst": trajectory.M_c2o_lst, } def crop_driving_video(self, driving_rgb_lst, **kwargs): """Tracking based landmarks/alignment and cropping""" trajectory = Trajectory() direction = kwargs.get("direction", "large-small") for idx, frame_rgb in enumerate(driving_rgb_lst): if idx == 0 or trajectory.start == -1: src_face = self.face_analysis_wrapper.get( contiguous(frame_rgb[..., ::-1]), flag_do_landmark_2d_106=True, direction=direction, ) if len(src_face) == 0: log(f"No face detected in the frame #{idx}") continue elif len(src_face) > 1: log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule 
{direction}.") src_face = src_face[0] lmk = src_face.landmark_2d_106 lmk = self.human_landmark_runner.run(frame_rgb, lmk) trajectory.start, trajectory.end = idx, idx else: lmk = self.human_landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1]) trajectory.end = idx trajectory.lmk_lst.append(lmk) ret_bbox = parse_bbox_from_landmark( lmk, scale=self.crop_cfg.scale_crop_driving_video, vx_ratio_crop_driving_video=self.crop_cfg.vx_ratio_crop_driving_video, vy_ratio=self.crop_cfg.vy_ratio_crop_driving_video, )["bbox"] bbox = [ ret_bbox[0, 0], ret_bbox[0, 1], ret_bbox[2, 0], ret_bbox[2, 1], ] # 4, trajectory.bbox_lst.append(bbox) # bbox trajectory.frame_rgb_lst.append(frame_rgb) global_bbox = average_bbox_lst(trajectory.bbox_lst) for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)): ret_dct = crop_image_by_bbox( frame_rgb, global_bbox, lmk=lmk, dsize=kwargs.get("dsize", 512), flag_rot=False, borderValue=(0, 0, 0), ) trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop"]) trajectory.lmk_crop_lst.append(ret_dct["lmk_crop"]) return { "frame_crop_lst": trajectory.frame_rgb_crop_lst, "lmk_crop_lst": trajectory.lmk_crop_lst, } def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs): """Tracking based landmarks/alignment""" trajectory = Trajectory() direction = kwargs.get("direction", "large-small") for idx, frame_rgb_crop in enumerate(driving_rgb_crop_lst): if idx == 0 or trajectory.start == -1: src_face = self.face_analysis_wrapper.get( contiguous(frame_rgb_crop[..., ::-1]), # convert to BGR flag_do_landmark_2d_106=True, direction=direction, ) if len(src_face) == 0: log(f"No face detected in the frame #{idx}") raise Exception(f"No face detected in the frame #{idx}") elif len(src_face) > 1: log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.") src_face = src_face[0] lmk = src_face.landmark_2d_106 lmk = self.human_landmark_runner.run(frame_rgb_crop, lmk) trajectory.start, 
trajectory.end = idx, idx else: lmk = self.human_landmark_runner.run(frame_rgb_crop, trajectory.lmk_lst[-1]) trajectory.end = idx trajectory.lmk_lst.append(lmk) return trajectory.lmk_lst ================================================ FILE: src/utils/dependencies/XPose/config_model/UniPose_SwinT.py ================================================ _base_ = ['coco_transformer.py'] use_label_enc = True num_classes=2 lr = 0.0001 param_dict_type = 'default' lr_backbone = 1e-05 lr_backbone_names = ['backbone.0'] lr_linear_proj_names = ['reference_points', 'sampling_offsets'] lr_linear_proj_mult = 0.1 ddetr_lr_param = False batch_size = 2 weight_decay = 0.0001 epochs = 12 lr_drop = 11 save_checkpoint_interval = 100 clip_max_norm = 0.1 onecyclelr = False multi_step_lr = False lr_drop_list = [33, 45] modelname = 'UniPose' frozen_weights = None backbone = 'swin_T_224_1k' dilation = False position_embedding = 'sine' pe_temperatureH = 20 pe_temperatureW = 20 return_interm_indices = [1, 2, 3] backbone_freeze_keywords = None enc_layers = 6 dec_layers = 6 unic_layers = 0 pre_norm = False dim_feedforward = 2048 hidden_dim = 256 dropout = 0.0 nheads = 8 num_queries = 900 query_dim = 4 num_patterns = 0 pdetr3_bbox_embed_diff_each_layer = False pdetr3_refHW = -1 random_refpoints_xy = False fix_refpoints_hw = -1 dabdetr_yolo_like_anchor_update = False dabdetr_deformable_encoder = False dabdetr_deformable_decoder = False use_deformable_box_attn = False box_attn_type = 'roi_align' dec_layer_number = None num_feature_levels = 4 enc_n_points = 4 dec_n_points = 4 decoder_layer_noise = False dln_xy_noise = 0.2 dln_hw_noise = 0.2 add_channel_attention = False add_pos_value = False two_stage_type = 'standard' two_stage_pat_embed = 0 two_stage_add_query_num = 0 two_stage_bbox_embed_share = False two_stage_class_embed_share = False two_stage_learn_wh = False two_stage_default_hw = 0.05 two_stage_keep_all_tokens = False num_select = 50 transformer_activation = 'relu' batch_norm_type = 
'FrozenBatchNorm2d' masks = False decoder_sa_type = 'sa' # ['sa', 'ca_label', 'ca_content'] matcher_type = 'HungarianMatcher' # or SimpleMinsumMatcher decoder_module_seq = ['sa', 'ca', 'ffn'] nms_iou_threshold = -1 dec_pred_bbox_embed_share = True dec_pred_class_embed_share = True use_dn = True dn_number = 100 dn_box_noise_scale = 1.0 dn_label_noise_ratio = 0.5 dn_label_coef=1.0 dn_bbox_coef=1.0 embed_init_tgt = True dn_labelbook_size = 2000 match_unstable_error = True # for ema use_ema = True ema_decay = 0.9997 ema_epoch = 0 use_detached_boxes_dec_out = False max_text_len = 256 shuffle_type = None use_text_enhancer = True use_fusion_layer = True use_checkpoint = False # True use_transformer_ckpt = True text_encoder_type = 'bert-base-uncased' use_text_cross_attention = True text_dropout = 0.0 fusion_dropout = 0.0 fusion_droppath = 0.1 num_body_points=68 binary_query_selection = False use_cdn = True ffn_extra_layernorm = False fix_size=False ================================================ FILE: src/utils/dependencies/XPose/config_model/coco_transformer.py ================================================ data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] data_aug_max_size = 1333 data_aug_scales2_resize = [400, 500, 600] data_aug_scales2_crop = [384, 600] data_aug_scale_overlap = None ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/__init__.py ================================================ # ------------------------------------------------------------------------ # Conditional DETR # Copyright (c) 2021 Microsoft. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Copied from DETR (https://github.com/facebookresearch/detr) # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
# ------------------------------------------------------------------------ from .unipose import build_unipose ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/attention.py ================================================ # ------------------------------------------------------------------------ # UniPose # url: https://github.com/IDEA-Research/UniPose # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # ED-Pose # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Conditional DETR # Copyright (c) 2021 Microsoft. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from codes in torch.nn # ------------------------------------------------------------------------ """ MultiheadAttention that support query, key, and value to have different dimensions. Query, key, and value projections are removed. 
Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873 and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837 """ import warnings import torch from torch.nn.modules.linear import Linear from torch.nn.init import constant_ from torch.nn.modules.module import Module from torch._jit_internal import Optional, Tuple try: from torch.overrides import has_torch_function, handle_torch_function except: from torch._overrides import has_torch_function, handle_torch_function from torch.nn.functional import linear, pad, softmax, dropout Tensor = torch.Tensor class MultiheadAttention(Module): r"""Allows the model to jointly attend to information from different representation subspaces. See reference: Attention Is All You Need .. math:: \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) Args: embed_dim: total dimension of the model. num_heads: parallel attention heads. dropout: a Dropout layer on attn_output_weights. Default: 0.0. bias: add bias as module parameter. Default: True. add_bias_kv: add bias to the key and value sequences at dim=0. add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1. kdim: total number of features in key. Default: None. vdim: total number of features in value. Default: None. Note: if kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features. 
Examples:: >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value) """ bias_k: Optional[torch.Tensor] bias_v: Optional[torch.Tensor] def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): super(MultiheadAttention, self).__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" vdim = vdim if vdim is not None else embed_dim self.out_proj = Linear(vdim , vdim) self.in_proj_bias = None self.in_proj_weight = None self.bias_k = self.bias_v = None self.q_proj_weight = None self.k_proj_weight = None self.v_proj_weight = None self.add_zero_attn = add_zero_attn self._reset_parameters() def _reset_parameters(self): constant_(self.out_proj.bias, 0.) def __setstate__(self, state): # Support loading old MultiheadAttention checkpoints generated by v1.1.0 if '_qkv_same_embed_dim' not in state: state['_qkv_same_embed_dim'] = True super(MultiheadAttention, self).__setstate__(state) def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None): # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] r""" Args: query, key, value: map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details. key_padding_mask: if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask and a value is True, the corresponding value on the attention layer will be ignored. 
When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. Shape: - Inputs: - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is the embedding dimension. - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the position with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight. - Outputs: - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. 
- attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. """ if not self._qkv_same_embed_dim: return multi_head_attention_forward( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight, out_dim=self.vdim) else: return multi_head_attention_forward( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, out_dim=self.vdim) def multi_head_attention_forward(query: Tensor, key: Tensor, value: Tensor, embed_dim_to_check: int, num_heads: int, in_proj_weight: Tensor, in_proj_bias: Tensor, bias_k: Optional[Tensor], bias_v: Optional[Tensor], add_zero_attn: bool, dropout_p: float, out_proj_weight: Tensor, out_proj_bias: Tensor, training: bool = True, key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, use_separate_proj_weight: bool = False, q_proj_weight: Optional[Tensor] = None, k_proj_weight: Optional[Tensor] = None, v_proj_weight: Optional[Tensor] = None, static_k: Optional[Tensor] = None, static_v: Optional[Tensor] = None, out_dim: Optional[Tensor] = None ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query, key, value: map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details. embed_dim_to_check: total dimension of the model. num_heads: parallel attention heads. 
in_proj_weight, in_proj_bias: input projection weight and bias. bias_k, bias_v: bias of the key and value sequences to be added at dim=0. add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1. dropout_p: probability of an element to be zeroed. out_proj_weight, out_proj_bias: the output projection weight and bias. training: apply dropout if is ``True``. key_padding_mask: if provided, specified padding elements in the key will be ignored by the attention. This is an binary mask. When the value is True, the corresponding value on the attention layer will be filled with -inf. need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. use_separate_proj_weight: the function accept the proj. weights for query, key, and value in different forms. If false, in_proj_weight will be used, which is a combination of q_proj_weight, k_proj_weight, v_proj_weight. q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. static_k, static_v: static key and value used for attention operators. Shape: Inputs: - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is the embedding dimension. - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. 
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight. - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. Outputs: - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. 
""" if not torch.jit.is_scripting(): tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): return handle_torch_function( multi_head_attention_forward, tens_ops, query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=training, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, use_separate_proj_weight=use_separate_proj_weight, q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) tgt_len, bsz, embed_dim = query.size() assert embed_dim == embed_dim_to_check # allow MHA to have different sizes for the feature dimension assert key.size(0) == value.size(0) and key.size(1) == value.size(1) head_dim = embed_dim // num_heads v_head_dim = out_dim // num_heads assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" scaling = float(head_dim) ** -0.5 q = query * scaling k = key v = value if attn_mask is not None: assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) if attn_mask.dtype == torch.uint8: warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. 
Use bool tensor instead.") attn_mask = attn_mask.to(torch.bool) if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: raise RuntimeError('The size of the 2D attn_mask is not correct.') elif attn_mask.dim() == 3: if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: raise RuntimeError('The size of the 3D attn_mask is not correct.') else: raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") key_padding_mask = key_padding_mask.to(torch.bool) if bias_k is not None and bias_v is not None: if static_k is None and static_v is None: k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) if attn_mask is not None: attn_mask = pad(attn_mask, (0, 1)) if key_padding_mask is not None: key_padding_mask = pad(key_padding_mask, (0, 1)) else: assert static_k is None, "bias cannot be added to static key." assert static_v is None, "bias cannot be added to static value." 
else: assert bias_k is None assert bias_v is None q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) if k is not None: k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if v is not None: v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1) if static_k is not None: assert static_k.size(0) == bsz * num_heads assert static_k.size(2) == head_dim k = static_k if static_v is not None: assert static_v.size(0) == bsz * num_heads assert static_v.size(2) == v_head_dim v = static_v src_len = k.size(1) if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len if add_zero_attn: src_len += 1 k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) if attn_mask is not None: attn_mask = pad(attn_mask, (0, 1)) if key_padding_mask is not None: key_padding_mask = pad(key_padding_mask, (0, 1)) attn_output_weights = torch.bmm(q, k.transpose(1, 2)) assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_output_weights.masked_fill_(attn_mask, float('-inf')) else: attn_output_weights += attn_mask if key_padding_mask is not None: attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) attn_output_weights = attn_output_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'), ) attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) # attn_output_weights = softmax( # attn_output_weights, dim=-1) attn_output_weights = softmax( attn_output_weights - attn_output_weights.max(dim=-1, keepdim=True)[0], dim=-1) attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == 
[bsz * num_heads, tgt_len, v_head_dim] attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim) attn_output = linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) return attn_output, attn_output_weights.sum(dim=1) / num_heads else: return attn_output, None ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/backbone.py ================================================ # ------------------------------------------------------------------------ # UniPose # url: https://github.com/IDEA-Research/UniPose # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Conditional DETR # Copyright (c) 2021 Microsoft. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Copied from DETR (https://github.com/facebookresearch/detr) # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # ------------------------------------------------------------------------ """ Backbone modules. """ import torch import torch.nn.functional as F import torchvision from torch import nn from torchvision.models._utils import IntermediateLayerGetter from typing import Dict, List from util.misc import NestedTensor, is_main_process from .position_encoding import build_position_encoding from .swin_transformer import build_swin_transformer class FrozenBatchNorm2d(torch.nn.Module): """ BatchNorm2d where the batch statistics and the affine parameters are fixed. Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than torchvision.models.resnet[18,34,50,101] produce nans. 
""" def __init__(self, n): super(FrozenBatchNorm2d, self).__init__() self.register_buffer("weight", torch.ones(n)) self.register_buffer("bias", torch.zeros(n)) self.register_buffer("running_mean", torch.zeros(n)) self.register_buffer("running_var", torch.ones(n)) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): num_batches_tracked_key = prefix + "num_batches_tracked" if num_batches_tracked_key in state_dict: del state_dict[num_batches_tracked_key] super(FrozenBatchNorm2d, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) def forward(self, x): # move reshapes to the beginning # to make it fuser-friendly w = self.weight.reshape(1, -1, 1, 1) b = self.bias.reshape(1, -1, 1, 1) rv = self.running_var.reshape(1, -1, 1, 1) rm = self.running_mean.reshape(1, -1, 1, 1) eps = 1e-5 scale = w * (rv + eps).rsqrt() bias = b - rm * scale return x * scale + bias class BackboneBase(nn.Module): def __init__( self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_indices: list, ): super().__init__() for name, parameter in backbone.named_parameters(): if ( not train_backbone or "layer2" not in name and "layer3" not in name and "layer4" not in name ): parameter.requires_grad_(False) return_layers = {} for idx, layer_index in enumerate(return_interm_indices): return_layers.update( {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)} ) self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) self.num_channels = num_channels def forward(self, tensor_list: NestedTensor): xs = self.body(tensor_list.tensors) out: Dict[str, NestedTensor] = {} for name, x in xs.items(): m = tensor_list.mask assert m is not None mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] out[name] = NestedTensor(x, mask) # import ipdb; ipdb.set_trace() return out class 
Backbone(BackboneBase): """ResNet backbone with frozen BatchNorm.""" def __init__( self, name: str, train_backbone: bool, dilation: bool, return_interm_indices: list, batch_norm=FrozenBatchNorm2d, ): if name in ["resnet18", "resnet34", "resnet50", "resnet101"]: backbone = getattr(torchvision.models, name)( replace_stride_with_dilation=[False, False, dilation], pretrained=is_main_process(), norm_layer=batch_norm, ) else: raise NotImplementedError("Why you can get here with name {}".format(name)) # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available." assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] num_channels_all = [256, 512, 1024, 2048] num_channels = num_channels_all[4 - len(return_interm_indices) :] super().__init__(backbone, train_backbone, num_channels, return_interm_indices) class Joiner(nn.Sequential): def __init__(self, backbone, position_embedding): super().__init__(backbone, position_embedding) def forward(self, tensor_list: NestedTensor): xs = self[0](tensor_list) out: List[NestedTensor] = [] pos = [] for name, x in xs.items(): out.append(x) # position encoding pos.append(self[1](x).to(x.tensors.dtype)) return out, pos def build_backbone(args): """ Useful args: - backbone: backbone name - lr_backbone: - dilation - return_interm_indices: available: [0,1,2,3], [1,2,3], [3] - backbone_freeze_keywords: - use_checkpoint: for swin only for now """ position_embedding = build_position_encoding(args) train_backbone = True if not train_backbone: raise ValueError("Please set lr_backbone > 0") return_interm_indices = args.return_interm_indices assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] args.backbone_freeze_keywords use_checkpoint = getattr(args, "use_checkpoint", False) if args.backbone in ["resnet50", "resnet101"]: backbone = Backbone( args.backbone, train_backbone, args.dilation, return_interm_indices, 
batch_norm=FrozenBatchNorm2d, ) bb_num_channels = backbone.num_channels elif args.backbone in [ "swin_T_224_1k", "swin_B_224_22k", "swin_B_384_22k", "swin_L_224_22k", "swin_L_384_22k", ]: pretrain_img_size = int(args.backbone.split("_")[-2]) backbone = build_swin_transformer( args.backbone, pretrain_img_size=pretrain_img_size, out_indices=tuple(return_interm_indices), dilation=False, use_checkpoint=use_checkpoint, ) bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :] else: raise NotImplementedError("Unknown backbone {}".format(args.backbone)) assert len(bb_num_channels) == len( return_interm_indices ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}" model = Joiner(backbone, position_embedding) model.num_channels = bb_num_channels assert isinstance( bb_num_channels, List ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels)) return model ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/deformable_transformer.py ================================================ # ------------------------------------------------------------------------ # UniPose # url: https://github.com/IDEA-Research/UniPose # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # ED-Pose # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # DINO # Copyright (c) 2022 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from DETR (https://github.com/facebookresearch/detr) # Copyright (c) Facebook, Inc. and its affiliates. 
# All Rights Reserved.
# ------------------------------------------------------------------------

import math
import copy
import warnings
import torch
import torch.utils.checkpoint as checkpoint
from torch import nn, Tensor
from typing import Optional

from util.misc import inverse_sigmoid

from .transformer_vanilla import TransformerEncoderLayer
from .fuse_modules import BiAttentionBlock
from .utils import gen_encoder_output_proposals, MLP, _get_activation_fn, gen_sineembed_for_position, get_sine_pos_embed
from .ops.modules import MSDeformAttn


class DeformableTransformer(nn.Module):
    """Deformable encoder/decoder transformer (DINO / ED-Pose lineage) used by UniPose.

    Constraints enforced by the constructor:
      * only the deformable encoder and decoder are implemented;
        ``deformable_encoder=False`` or ``deformable_decoder=False`` raises
        NotImplementedError;
      * ``query_dim`` must be 4, ``learnable_tgt_init`` must be True,
        ``layer_share_type`` must be None and ``normalize_before`` must be False;
      * ``two_stage_type`` must be 'no' or 'standard';
      * ``decoder_sa_type`` must be one of 'sa', 'ca_label', 'ca_content'.
        The default 'ca' does NOT satisfy this assertion, so callers must
        pass it explicitly.

    ``module_seq``, ``key_aware_type``, ``add_pos_value`` and
    ``rm_enc_query_scale`` are accepted but not referenced by this class.

    In 'standard' two-stage mode, ``forward`` calls ``self.enc_out_class_embed``
    and ``self.enc_out_bbox_embed``, which are initialised to None here.
    Presumably the owning model assigns them; verify against the caller.
    """

    def __init__(self, d_model=256, nhead=8,
                 num_queries=300,
                 num_encoder_layers=6,
                 num_unicoder_layers=0,
                 num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False, query_dim=4,
                 num_patterns=0,
                 modulate_hw_attn=False,
                 # for deformable encoder
                 deformable_encoder=False,
                 deformable_decoder=False,
                 num_feature_levels=1,
                 enc_n_points=4,
                 dec_n_points=4,
                 use_deformable_box_attn=False,
                 box_attn_type='roi_align',
                 # init query
                 learnable_tgt_init=False,
                 decoder_query_perturber=None,
                 add_channel_attention=False,
                 add_pos_value=False,
                 random_refpoints_xy=False,
                 # two stage
                 two_stage_type='no',
                 two_stage_pat_embed=0,
                 two_stage_add_query_num=0,
                 two_stage_learn_wh=False,
                 two_stage_keep_all_tokens=False,
                 # evo of #anchors
                 dec_layer_number=None,
                 rm_enc_query_scale=True,
                 rm_dec_query_scale=True,
                 rm_self_attn_layers=None,
                 key_aware_type=None,
                 # layer share
                 layer_share_type=None,
                 # for detach
                 rm_detach=None,
                 decoder_sa_type='ca',
                 # NOTE: mutable default kept for interface compatibility; never mutated or read.
                 module_seq=['sa', 'ca', 'ffn'],
                 # for dn
                 embed_init_tgt=False,
                 use_detached_boxes_dec_out=False,
                 use_text_enhancer=False,
                 use_fusion_layer=False,
                 use_checkpoint=False,
                 use_transformer_ckpt=False,
                 use_text_cross_attention=False,
                 text_dropout=0.1,
                 fusion_dropout=0.1,
                 fusion_droppath=0.0,
                 binary_query_selection=False,
                 ffn_extra_layernorm=False,
                 ):
        super().__init__()
        self.num_feature_levels = num_feature_levels
        self.num_encoder_layers = num_encoder_layers
        self.num_unicoder_layers = num_unicoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.deformable_encoder = deformable_encoder
        self.deformable_decoder = deformable_decoder
        self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
        self.num_queries = num_queries
        self.random_refpoints_xy = random_refpoints_xy
        self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
        self.ffn_extra_layernorm = ffn_extra_layernorm
        assert query_dim == 4

        self.binary_query_selection = binary_query_selection
        if self.binary_query_selection:
            # One score per token, used instead of class logits for top-k selection.
            self.binary_query_selection_layer = nn.Linear(d_model, 1)

        if num_feature_levels > 1:
            assert deformable_encoder, "only support deformable_encoder for num_feature_levels > 1"
        if use_deformable_box_attn:
            # Was `deformable_encoder or deformable_encoder` (typo, checked the encoder twice).
            assert deformable_encoder or deformable_decoder

        assert layer_share_type in [None, 'encoder', 'decoder', 'both']
        enc_layer_share = layer_share_type in ['encoder', 'both']
        dec_layer_share = layer_share_type in ['decoder', 'both']
        # Layer sharing is not supported in this port, whatever the flags above say.
        assert layer_share_type is None

        self.decoder_sa_type = decoder_sa_type
        assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']

        # choose encoder layer type
        if deformable_encoder:
            encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
                                                              dropout, activation,
                                                              num_feature_levels, nhead, enc_n_points,
                                                              add_channel_attention=add_channel_attention,
                                                              use_deformable_box_attn=use_deformable_box_attn,
                                                              box_attn_type=box_attn_type)
        else:
            raise NotImplementedError

        if use_text_enhancer:
            text_enhance_layer = TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead // 2,
                dim_feedforward=dim_feedforward // 2,
                dropout=text_dropout
            )
        else:
            text_enhance_layer = None

        if use_fusion_layer:
            feature_fusion_layer = BiAttentionBlock(
                v_dim=d_model,
                l_dim=d_model,
                embed_dim=dim_feedforward // 2,
                num_heads=nhead // 2,
                dropout=fusion_dropout,
                drop_path=fusion_droppath
            )
        else:
            feature_fusion_layer = None

        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        assert encoder_norm is None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, d_model=d_model,
            num_queries=num_queries,
            enc_layer_share=enc_layer_share,
            text_enhance_layer=text_enhance_layer,
            feature_fusion_layer=feature_fusion_layer,
            use_checkpoint=use_checkpoint,
            use_transformer_ckpt=use_transformer_ckpt,
        )

        # choose decoder layer type
        if deformable_decoder:
            decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
                                                              dropout, activation,
                                                              num_feature_levels, nhead, dec_n_points,
                                                              use_text_cross_attention=use_text_cross_attention,
                                                              ffn_extra_layernorm=ffn_extra_layernorm, )
        else:
            raise NotImplementedError

        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec,
                                          d_model=d_model, query_dim=query_dim,
                                          modulate_hw_attn=modulate_hw_attn,
                                          num_feature_levels=num_feature_levels,
                                          deformable_decoder=deformable_decoder,
                                          decoder_query_perturber=decoder_query_perturber,
                                          dec_layer_number=dec_layer_number, rm_dec_query_scale=rm_dec_query_scale,
                                          dec_layer_share=dec_layer_share,
                                          use_detached_boxes_dec_out=use_detached_boxes_dec_out
                                          )

        self.d_model = d_model
        self.nhead = nhead
        self.dec_layers = num_decoder_layers
        self.num_queries = num_queries  # useful for single stage model only
        self.num_patterns = num_patterns
        if not isinstance(num_patterns, int):
            # Was `Warning(...)`, which only built an exception object and never reported anything.
            warnings.warn("num_patterns should be int but {}".format(type(num_patterns)))
            self.num_patterns = 0
        # NOTE(review): `self.patterns` is never created in this class, so
        # num_patterns > 0 with two_stage_type == 'no' raises AttributeError in forward.

        if num_feature_levels > 1:
            if self.num_encoder_layers > 0:
                # Uninitialised here; filled by nn.init.normal_ in _reset_parameters.
                self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
            else:
                self.level_embed = None

        self.learnable_tgt_init = learnable_tgt_init
        assert learnable_tgt_init, "why not learnable_tgt_init"
        self.embed_init_tgt = embed_init_tgt
        if (two_stage_type != 'no' and embed_init_tgt) or (two_stage_type == 'no'):
            self.tgt_embed = nn.Embedding(self.num_queries, d_model)
            nn.init.normal_(self.tgt_embed.weight.data)
        else:
            self.tgt_embed = None

        # for two stage
        self.two_stage_type = two_stage_type
        self.two_stage_pat_embed = two_stage_pat_embed
        self.two_stage_add_query_num = two_stage_add_query_num
        self.two_stage_learn_wh = two_stage_learn_wh
        assert two_stage_type in ['no', 'standard'], "unknown param {} of two_stage_type".format(two_stage_type)
        if two_stage_type == 'standard':
            # anchor selection at the output of encoder
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)

            if two_stage_pat_embed > 0:
                self.pat_embed_for_2stage = nn.Parameter(torch.Tensor(two_stage_pat_embed, d_model))
                nn.init.normal_(self.pat_embed_for_2stage)

            if two_stage_add_query_num > 0:
                self.tgt_embed = nn.Embedding(self.two_stage_add_query_num, d_model)

            if two_stage_learn_wh:
                self.two_stage_wh_embedding = nn.Embedding(1, 2)
            else:
                self.two_stage_wh_embedding = None

        if two_stage_type == 'no':
            self.init_ref_points(num_queries)  # init self.refpoint_embed

        self.enc_out_class_embed = None
        self.enc_out_bbox_embed = None

        # evolution of anchors
        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            if self.two_stage_type != 'no' or num_patterns == 0:
                assert dec_layer_number[0] == num_queries, \
                    f"dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})"
            else:
                assert dec_layer_number[0] == num_queries * num_patterns, \
                    f"dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})"

        self._reset_parameters()

        self.rm_self_attn_layers = rm_self_attn_layers
        if rm_self_attn_layers is not None:
            print("Removing the self-attn in {} decoder layers".format(rm_self_attn_layers))
            for lid, dec_layer in enumerate(self.decoder.layers):
                if lid in rm_self_attn_layers:
                    dec_layer.rm_self_attn_modules()

        self.rm_detach = rm_detach
        if self.rm_detach:
            assert isinstance(rm_detach, list)
            assert any(i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach)
        self.decoder.rm_detach = rm_detach

    def _reset_parameters(self):
        """Xavier-init every >1-D parameter, then apply module-specific initialisers."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if self.num_feature_levels > 1 and self.level_embed is not None:
            nn.init.normal_(self.level_embed)

        if self.two_stage_learn_wh:
            # inverse-sigmoid of 0.05, i.e. boxes start at 5% of the image size
            nn.init.constant_(self.two_stage_wh_embedding.weight, math.log(0.05 / (1 - 0.05)))

    def get_valid_ratio(self, mask):
        """Return per-image (w, h) fractions of the non-padded region.

        ``mask`` is indexed as (batch, H, W), with True marking padding; the
        result has shape (batch, 2).
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def init_ref_points(self, use_num_queries):
        """Create ``self.refpoint_embed``: one learnable 4-d (unsigmoided) anchor per query."""
        self.refpoint_embed = nn.Embedding(use_num_queries, 4)

        if self.random_refpoints_xy:
            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
            # NOTE(review): this sets requires_grad on a temporary slice of `.data`;
            # it likely does not freeze the xy part of the parameter.
            self.refpoint_embed.weight.data[:, :2].requires_grad = False

    def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, attn_mask2=None,
                text_dict=None, dn_meta=None, targets=None, kpt_embed=None):
        """
        Input:
            - srcs: List of multi features [bs, ci, hi, wi]
            - masks: List of multi masks [bs, hi, wi]
            - refpoint_embed: [bs, num_dn, 4]. None in infer
            - pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
            - tgt: [bs, num_dn, d_model]. None in infer
            - text_dict: required despite the None default. Its 'encoded_text',
              'text_token_mask', 'position_ids' and 'text_self_attention_masks'
              entries are read unconditionally, and 'encoded_text' is
              overwritten with the encoder's text output.

        Returns:
            hs: (n_dec, bs, nq, d_model)
            references: sigmoid coordinates, (n_dec+1, bs, nq, 4)
            hs_enc: (1, bs, nq, d_model) in 'standard' mode, else None
            ref_enc: sigmoid coordinates (1, bs, nq, 4) in 'standard' mode, else None
            init_box_proposal: sigmoid initial boxes
        """
        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)

            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            mask = mask.flatten(1)  # bs, hw
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.num_feature_levels > 1 and self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hxw}, c
        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \sum{hxw}
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)  # bs, \sum{hxw}, c
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        # Offset of each level's first token inside the flattened sequence.
        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        #########################################################
        # Begin Encoder
        #########################################################
        memory, memory_text = self.encoder(
            src_flatten,
            pos=lvl_pos_embed_flatten,
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            key_padding_mask=mask_flatten,
            memory_text=text_dict['encoded_text'],
            # Inverted mask: False = real token (attend), True = padding (ignore).
            text_attention_mask=~text_dict['text_token_mask'],
            position_ids=text_dict['position_ids'],
            text_self_attention_masks=text_dict['text_self_attention_masks'],
        )
        #########################################################
        # End Encoder
        # - memory: bs, \sum{hw}, c
        # - mask_flatten: bs, \sum{hw}
        # - lvl_pos_embed_flatten: bs, \sum{hw}, c
        #########################################################
        text_dict['encoded_text'] = memory_text

        if self.two_stage_type == 'standard':
            if self.two_stage_learn_wh:
                input_hw = self.two_stage_wh_embedding.weight[0]
            else:
                input_hw = None
            output_memory, output_proposals = gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes,
                                                                           input_hw)
            output_memory = self.enc_output_norm(self.enc_output(output_memory))

            if self.two_stage_pat_embed > 0:
                bs, nhw, _ = output_memory.shape
                # output_memory: bs, n, 256; self.pat_embed_for_2stage: k, 256
                output_memory = output_memory.repeat(1, self.two_stage_pat_embed, 1)
                _pats = self.pat_embed_for_2stage.repeat_interleave(nhw, 0)
                output_memory = output_memory + _pats
                output_proposals = output_proposals.repeat(1, self.two_stage_pat_embed, 1)

            if self.two_stage_add_query_num > 0:
                assert refpoint_embed is not None
                output_memory = torch.cat((output_memory, tgt), dim=1)
                output_proposals = torch.cat((output_proposals, refpoint_embed), dim=1)

            if self.binary_query_selection:
                topk_logits = self.binary_query_selection_layer(output_memory).squeeze(-1)
            else:
                if text_dict is not None:
                    enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
                else:
                    enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)

                topk_logits = enc_outputs_class_unselected.max(-1)[0]
            enc_outputs_coord_unselected = self.enc_out_bbox_embed(
                output_memory) + output_proposals  # (bs, \sum{hw}, 4) unsigmoid
            topk = self.num_queries

            topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]  # bs, nq

            # gather boxes
            refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1,
                                                   topk_proposals.unsqueeze(-1).repeat(1, 1, 4))  # unsigmoid
            refpoint_embed_ = refpoint_embed_undetach.detach()
            init_box_proposal = torch.gather(output_proposals, 1,
                                             topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid()  # sigmoid

            # gather tgt
            tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
            if self.embed_init_tgt:
                tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)  # nq, bs, d_model
            else:
                tgt_ = tgt_undetach.detach()

            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_

        elif self.two_stage_type == 'no':
            tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)  # nq, bs, d_model
            refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)  # nq, bs, 4

            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_

            if self.num_patterns > 0:
                tgt_embed = tgt.repeat(1, self.num_patterns, 1)
                refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
                tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(self.num_queries,
                                                                             1)  # 1, n_q*n_pat, d_model
                tgt = tgt_embed + tgt_pat

            init_box_proposal = refpoint_embed_.sigmoid()

        else:
            raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
        #########################################################
        # End preparing tgt
        # - tgt: bs, NQ, d_model
        # - refpoint_embed(unsigmoid): bs, NQ, d_model
        #########################################################

        #########################################################
        # Begin Decoder
        #########################################################
        hs, references = self.decoder(
            tgt=tgt.transpose(0, 1),
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=mask_flatten,
            pos=lvl_pos_embed_flatten.transpose(0, 1),
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            tgt_mask=attn_mask,
            tgt_mask2=attn_mask2,
            memory_text=text_dict['encoded_text'],
            # Inverted mask: False = real token (attend), True = padding (ignore).
            text_attention_mask=~text_dict['text_token_mask'],
            text_dict=text_dict,
            dn_meta=dn_meta,
            targets=targets,
            kpt_embed=kpt_embed
        )
        #########################################################
        # End Decoder
        # hs: n_dec, bs, nq, d_model
        # references: n_dec+1, bs, nq, query_dim
        #########################################################

        #########################################################
        # Begin postprocess
        #########################################################
        if self.two_stage_type == 'standard':
            if self.two_stage_keep_all_tokens:
                hs_enc = output_memory.unsqueeze(0)
                ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
                init_box_proposal = output_proposals
            else:
                hs_enc = tgt_undetach.unsqueeze(0)
                ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
        else:
            hs_enc = ref_enc = None
        #########################################################
        # End postprocess
        # hs_enc: (1, bs, nq, d_model) or None
        # ref_enc: (1, bs, nq, query_dim) or None
        #########################################################

        return hs, references, hs_enc, ref_enc, init_box_proposal
        # hs: (n_dec, bs, nq, d_model)
        # references: sigmoid coordinates.
(n_dec+1, bs, bq, 4) # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None # ref_enc: sigmoid coordinates. \ # (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None class TransformerEncoder(nn.Module): def __init__(self, encoder_layer, num_layers, d_model=256, num_queries=300, enc_layer_share=False, text_enhance_layer=None, feature_fusion_layer=None, use_checkpoint=False, use_transformer_ckpt=False, ): """_summary_ Args: encoder_layer (_type_): _description_ num_layers (_type_): _description_ norm (_type_, optional): _description_. Defaults to None. d_model (int, optional): _description_. Defaults to 256. num_queries (int, optional): _description_. Defaults to 300. enc_layer_share (bool, optional): _description_. Defaults to False. """ super().__init__() # prepare layers self.layers = [] self.text_layers = [] self.fusion_layers = [] if num_layers > 0: self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share) if text_enhance_layer is not None: self.text_layers = _get_clones(text_enhance_layer, num_layers, layer_share=enc_layer_share) if feature_fusion_layer is not None: self.fusion_layers = _get_clones(feature_fusion_layer, num_layers, layer_share=enc_layer_share) else: self.layers = [] del encoder_layer if text_enhance_layer is not None: self.text_layers = [] del text_enhance_layer if feature_fusion_layer is not None: self.fusion_layers = [] del feature_fusion_layer self.query_scale = None self.num_queries = num_queries self.num_layers = num_layers self.d_model = d_model self.use_checkpoint = use_checkpoint self.use_transformer_ckpt = use_transformer_ckpt @staticmethod def get_reference_points(spatial_shapes, valid_ratios, device): reference_points_list = [] for lvl, (H_, W_) in enumerate(spatial_shapes): ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),) ref_y = ref_y.reshape(-1)[None] / 
(valid_ratios[:, None, lvl, 1] * H_) ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) ref = torch.stack((ref_x, ref_y), -1) reference_points_list.append(ref) reference_points = torch.cat(reference_points_list, 1) reference_points = reference_points[:, :, None] * valid_ratios[:, None] return reference_points def forward(self, # for images src: Tensor, pos: Tensor, spatial_shapes: Tensor, level_start_index: Tensor, valid_ratios: Tensor, key_padding_mask: Tensor, # for texts memory_text: Tensor = None, text_attention_mask: Tensor = None, pos_text: Tensor = None, text_self_attention_masks: Tensor = None, position_ids: Tensor = None, ): """ Input: - src: [bs, sum(hi*wi), 256] - pos: pos embed for src. [bs, sum(hi*wi), 256] - spatial_shapes: h,w of each level [num_level, 2] - level_start_index: [num_level] start point of level in sum(hi*wi). - valid_ratios: [bs, num_level, 2] - key_padding_mask: [bs, sum(hi*wi)] - memory_text: bs, n_text, 256 - text_attention_mask: bs, n_text False for no padding; True for padding - pos_text: bs, n_text, 256 - position_ids: bs, n_text Intermedia: - reference_points: [bs, sum(hi*wi), num_level, 2] Outpus: - output: [bs, sum(hi*wi), 256] """ output = src # preparation and reshape if self.num_layers > 0: reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) if self.text_layers: # generate pos_text bs, n_text, text_dim = memory_text.shape if pos_text is None and position_ids is None: pos_text = torch.arange(n_text, device=memory_text.device).float().unsqueeze(0).unsqueeze(-1).repeat(bs, 1, 1) pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False) if position_ids is not None: pos_text = get_sine_pos_embed(position_ids[..., None], num_pos_feats=256, exchange_xy=False) # main process for layer_id, layer in enumerate(self.layers): # if output.isnan().any() or memory_text.isnan().any(): # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO': # import ipdb; 
ipdb.set_trace() if self.fusion_layers: if self.use_checkpoint: output, memory_text = checkpoint.checkpoint( self.fusion_layers[layer_id], output, memory_text, key_padding_mask, text_attention_mask ) else: output, memory_text = self.fusion_layers[layer_id](v=output, l=memory_text, attention_mask_v=key_padding_mask, attention_mask_l=text_attention_mask) if self.text_layers: memory_text = self.text_layers[layer_id]( src=memory_text.transpose(0, 1), src_mask=~text_self_attention_masks, # note we use ~ for mask here src_key_padding_mask=text_attention_mask, pos=(pos_text.transpose(0, 1) if pos_text is not None else None) ).transpose(0, 1) # main process if self.use_transformer_ckpt: output = checkpoint.checkpoint( layer, output, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask ) else: output = layer(src=output, pos=pos, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, key_padding_mask=key_padding_mask) return output, memory_text class TransformerDecoder(nn.Module): def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, d_model=256, query_dim=4, modulate_hw_attn=False, num_feature_levels=1, deformable_decoder=False, decoder_query_perturber=None, dec_layer_number=None, # number of queries each layer in decoder rm_dec_query_scale=False, dec_layer_share=False, dec_layer_dropout_prob=None, use_detached_boxes_dec_out=False, num_box_decoder_layers=2, num_body_points=68, ): super().__init__() if num_layers > 0: self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share) else: self.layers = [] self.num_layers = num_layers self.norm = norm self.return_intermediate = return_intermediate assert return_intermediate, "support return_intermediate only" self.query_dim = query_dim assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim) self.num_feature_levels = num_feature_levels self.use_detached_boxes_dec_out = 
use_detached_boxes_dec_out self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2) if not deformable_decoder: self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2) else: self.query_pos_sine_scale = None if rm_dec_query_scale: self.query_scale = None else: raise NotImplementedError self.query_scale = MLP(d_model, d_model, d_model, 2) self.bbox_embed = None self.class_embed = None self.pose_embed = None self.pose_hw_embed = None self.d_model = d_model self.modulate_hw_attn = modulate_hw_attn self.deformable_decoder = deformable_decoder if not deformable_decoder and modulate_hw_attn: self.ref_anchor_head = MLP(d_model, d_model, 2, 2) else: self.ref_anchor_head = None self.decoder_query_perturber = decoder_query_perturber self.box_pred_damping = None self.dec_layer_number = dec_layer_number if dec_layer_number is not None: assert isinstance(dec_layer_number, list) assert len(dec_layer_number) == num_layers # assert dec_layer_number[0] == self.dec_layer_dropout_prob = dec_layer_dropout_prob if dec_layer_dropout_prob is not None: assert isinstance(dec_layer_dropout_prob, list) assert len(dec_layer_dropout_prob) == num_layers for i in dec_layer_dropout_prob: assert 0.0 <= i <= 1.0 self.rm_detach = None self.num_body_points = num_body_points self.hw = nn.Embedding(17, 2) self.num_box_decoder_layers = num_box_decoder_layers self.kpt_index = [x for x in range(50 * (self.num_body_points + 1)) if x % (self.num_body_points + 1) != 0] self.hw_append = nn.Embedding(self.num_body_points-17, 2) def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, tgt_mask2: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2 # for memory level_start_index: Optional[Tensor] = None, # num_levels spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 
valid_ratios: Optional[Tensor] = None, # for text memory_text: Optional[Tensor] = None, text_attention_mask: Optional[Tensor] = None, text_dict: Optional[Tensor] = None, dn_meta: Optional[Tensor] = None, targets: Optional[Tensor] = None, kpt_embed: Optional[Tensor] = None ): """ Input: - tgt: nq, bs, d_model - memory: hw, bs, d_model - pos: hw, bs, d_model - refpoints_unsigmoid: nq, bs, 2/4 - valid_ratios/spatial_shapes: bs, nlevel, 2 """ output = tgt output += self.hw.weight[0, 0] * 0.0 intermediate = [] reference_points = refpoints_unsigmoid.sigmoid() ref_points = [reference_points] effect_num_dn = dn_meta['pad_size'] if self.training else 0 inter_select_number = 50 for layer_id, layer in enumerate(self.layers): if reference_points.shape[-1] == 4: reference_points_input = reference_points[:, :, None] \ * torch.cat([valid_ratios, valid_ratios], -1)[None, :] # nq, bs, nlevel, 4 else: assert reference_points.shape[-1] == 2 reference_points_input = reference_points[:, :, None] * valid_ratios[None, :] query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # nq, bs, 256*2 # conditional query raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256 pos_scale = self.query_scale(output) if self.query_scale is not None else 1 query_pos = pos_scale * raw_query_pos # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1': # if query_pos.isnan().any() | query_pos.isinf().any(): # import ipdb; ipdb.set_trace() # main process output = layer( tgt=output, tgt_query_pos=query_pos, tgt_query_sine_embed=query_sine_embed, tgt_key_padding_mask=tgt_key_padding_mask, tgt_reference_points=reference_points_input, memory_text=memory_text, text_attention_mask=text_attention_mask, memory=memory, memory_key_padding_mask=memory_key_padding_mask, memory_level_start_index=level_start_index, memory_spatial_shapes=spatial_shapes, memory_pos=pos, self_attn_mask=tgt_mask, cross_attn_mask=memory_mask ) if output.isnan().any() | output.isinf().any(): print(f"output 
layer_id {layer_id} is nan") try: num_nan = output.isnan().sum().item() num_inf = output.isinf().sum().item() print(f"num_nan {num_nan}, num_inf {num_inf}") except Exception as e: print(e) intermediate.append(self.norm(output)) # iter update if layer_id < self.num_box_decoder_layers: reference_before_sigmoid = inverse_sigmoid(reference_points) delta_unsig = self.bbox_embed[layer_id](output) outputs_unsig = delta_unsig + reference_before_sigmoid new_reference_points = outputs_unsig.sigmoid() # select # ref points as anchors if layer_id == self.num_box_decoder_layers - 1: dn_output = output[:effect_num_dn] dn_new_reference_points = new_reference_points[:effect_num_dn] class_unselected = self.class_embed[layer_id](output.transpose(0, 1), text_dict)[:, effect_num_dn:].transpose(0, 1) topk_proposals = torch.topk(class_unselected.max(-1)[0], inter_select_number, dim=0)[1] new_reference_points_for_box = torch.gather(new_reference_points[effect_num_dn:], 0, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) new_output_for_box = torch.gather(output[effect_num_dn:], 0, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)) keypoint_embed=kpt_embed.transpose(0, 1) new_output_for_keypoint = keypoint_embed[None, :, :, :].repeat(new_output_for_box.shape[0],1,1,1) delta_xy = self.pose_embed[-1](new_output_for_keypoint)[..., :2] keypoint_xy = (inverse_sigmoid(new_reference_points_for_box[..., :2][:, None]) + delta_xy).sigmoid() num_queries, _, bs, _ = keypoint_xy.shape aa = torch.cat((self.hw.weight,self.hw_append.weight),dim=0) keypoint_wh_weight = aa.unsqueeze(0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid() keypoint_wh = keypoint_wh_weight * new_reference_points_for_box[..., 2:][:, None] new_reference_points_for_keypoint = torch.cat((keypoint_xy, keypoint_wh), dim=-1) new_reference_points = torch.cat( (new_reference_points_for_box.unsqueeze(1), new_reference_points_for_keypoint), dim=1).flatten(0, 1) output = torch.cat((new_output_for_box.unsqueeze(1), 
new_output_for_keypoint), dim=1).flatten(0, 1) new_reference_points = torch.cat((dn_new_reference_points, new_reference_points), dim=0) output = torch.cat((dn_output, output), dim=0) tgt_mask = tgt_mask2 if layer_id >= self.num_box_decoder_layers: reference_before_sigmoid = inverse_sigmoid(reference_points) output_bbox_dn = output[:effect_num_dn] output_bbox_norm = output[effect_num_dn:][0::(self.num_body_points + 1)] reference_before_sigmoid_bbox_dn = reference_before_sigmoid[:effect_num_dn] reference_before_sigmoid_bbox_norm = reference_before_sigmoid[effect_num_dn:][ 0::(self.num_body_points + 1)] delta_unsig_dn = self.bbox_embed[layer_id](output_bbox_dn) delta_unsig_norm = self.bbox_embed[layer_id](output_bbox_norm) outputs_unsig_dn = delta_unsig_dn + reference_before_sigmoid_bbox_dn outputs_unsig_norm = delta_unsig_norm + reference_before_sigmoid_bbox_norm new_reference_points_for_box_dn = outputs_unsig_dn.sigmoid() new_reference_points_for_box_norm = outputs_unsig_norm.sigmoid() output_kpt = output[effect_num_dn:].index_select(0, torch.tensor(self.kpt_index, device=output.device)) delta_xy_unsig = self.pose_embed[layer_id - self.num_box_decoder_layers](output_kpt) outputs_unsig = reference_before_sigmoid[effect_num_dn:].index_select(0, torch.tensor(self.kpt_index, device=output.device)).clone() ## delta_hw_unsig = self.pose_hw_embed[layer_id - self.num_box_decoder_layers](output_kpt) outputs_unsig[..., :2] += delta_xy_unsig[..., :2] outputs_unsig[..., 2:] += delta_hw_unsig new_reference_points_for_keypoint = outputs_unsig.sigmoid() bs = new_reference_points_for_box_norm.shape[1] new_reference_points_norm = torch.cat((new_reference_points_for_box_norm.unsqueeze(1), new_reference_points_for_keypoint.view(-1, self.num_body_points, bs, 4)), dim=1).flatten(0, 1) new_reference_points = torch.cat((new_reference_points_for_box_dn, new_reference_points_norm), dim=0) if self.rm_detach and 'dec' in self.rm_detach: reference_points = new_reference_points else: 
reference_points = new_reference_points.detach() # if layer_id != self.num_layers - 1: if self.use_detached_boxes_dec_out: ref_points.append(reference_points) else: ref_points.append(new_reference_points) return [ [itm_out.transpose(0, 1) for itm_out in intermediate], [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points] ] class DeformableTransformerEncoderLayer(nn.Module): def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_levels=4, n_heads=8, n_points=4, add_channel_attention=False, use_deformable_box_attn=False, box_attn_type='roi_align', ): super().__init__() # self attention self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) self.dropout1 = nn.Dropout(dropout) self.norm1 = nn.LayerNorm(d_model) # ffn self.linear1 = nn.Linear(d_model, d_ffn) self.activation = _get_activation_fn(activation, d_model=d_ffn) self.dropout2 = nn.Dropout(dropout) self.linear2 = nn.Linear(d_ffn, d_model) self.dropout3 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(d_model) # channel attention self.add_channel_attention = add_channel_attention if add_channel_attention: self.activ_channel = _get_activation_fn('dyrelu', d_model=d_model) self.norm_channel = nn.LayerNorm(d_model) @staticmethod def with_pos_embed(tensor, pos): return tensor if pos is None else tensor + pos def forward_ffn(self, src): src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) src = src + self.dropout3(src2) src = self.norm2(src) return src def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None): # self attention # import ipdb; ipdb.set_trace() src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, key_padding_mask) src = src + self.dropout1(src2) src = self.norm1(src) # ffn src = self.forward_ffn(src) # channel attn if self.add_channel_attention: src = self.norm_channel(src + self.activ_channel(src)) return src class 
DeformableTransformerDecoderLayer(nn.Module): def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_levels=4, n_heads=8, n_points=4, use_text_feat_guide=False, use_text_cross_attention=False, ffn_extra_layernorm=False ): super().__init__() # cross attention # self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() self.norm1 = nn.LayerNorm(d_model) # cross attention text if use_text_cross_attention: self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() self.catext_norm = nn.LayerNorm(d_model) # self attention self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() self.norm2 = nn.LayerNorm(d_model) # ffn self.linear1 = nn.Linear(d_model, d_ffn) self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1) self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() self.linear2 = nn.Linear(d_ffn, d_model) self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() self.norm3 = nn.LayerNorm(d_model) if ffn_extra_layernorm: raise NotImplementedError('ffn_extra_layernorm not implemented') self.norm_ext = nn.LayerNorm(d_ffn) else: self.norm_ext = None self.key_aware_proj = None self.use_text_feat_guide = use_text_feat_guide assert not use_text_feat_guide self.use_text_cross_attention = use_text_cross_attention def rm_self_attn_modules(self): self.self_attn = None self.dropout2 = None self.norm2 = None @staticmethod def with_pos_embed(tensor, pos): return tensor if pos is None else tensor + pos def forward_ffn(self, tgt, ipdb_flag=False): with torch.cuda.amp.autocast(enabled=False): tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout4(tgt2) 
tgt = self.norm3(tgt) return tgt def forward(self, # for tgt tgt: Optional[Tensor], # nq, bs, d_model tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos)) tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos) tgt_key_padding_mask: Optional[Tensor] = None, tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4 memory_text: Optional[Tensor] = None, # bs, num_token, d_model text_attention_mask: Optional[Tensor] = None, # bs, num_token # for memory memory: Optional[Tensor] = None, # hw, bs, d_model memory_key_padding_mask: Optional[Tensor] = None, memory_level_start_index: Optional[Tensor] = None, # num_levels memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 memory_pos: Optional[Tensor] = None, # pos for memory # sa self_attn_mask: Optional[Tensor] = None, # mask used for self-attention cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention ): """ Input: - tgt/tgt_query_pos: nq, bs, d_model - """ assert cross_attn_mask is None # self attention if self.self_attn is not None: # import ipdb; ipdb.set_trace() q = k = self.with_pos_embed(tgt, tgt_query_pos) tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1': # if tgt.isnan().any() | tgt.isinf().any() : # import ipdb; ipdb.set_trace() if self.use_text_cross_attention: tgt2 = self.ca_text(self.with_pos_embed(tgt, tgt_query_pos), memory_text.transpose(0, 1), memory_text.transpose(0, 1), key_padding_mask=text_attention_mask)[0] tgt = tgt + self.catext_dropout(tgt2) tgt = self.catext_norm(tgt) # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1': # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO': # import ipdb; ipdb.set_trace() # if tgt.isnan().any() | tgt.isinf().any() : # import ipdb; ipdb.set_trace() tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1), tgt_reference_points.transpose(0, 
1).contiguous(), memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1) tgt = tgt + self.dropout1(tgt2) tgt = self.norm1(tgt) # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1': # tgtk = tgt.clone() # if tgt.isnan().any() | tgt.isinf().any() : # import ipdb; ipdb.set_trace() # ffn tgt = self.forward_ffn(tgt) # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1': # if tgt.isnan().any() | tgt.isinf().any() : # tgtk = self.forward_ffn(tgtk, ipdb_flag=True) # import ipdb; ipdb.set_trace() return tgt def _get_clones(module, N, layer_share=False): # import ipdb; ipdb.set_trace() if layer_share: return nn.ModuleList([module for i in range(N)]) else: return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) def build_deformable_transformer(args): decoder_query_perturber = None if args.decoder_layer_noise: from .utils import RandomBoxPerturber decoder_query_perturber = RandomBoxPerturber( x_noise_scale=args.dln_xy_noise, y_noise_scale=args.dln_xy_noise, w_noise_scale=args.dln_hw_noise, h_noise_scale=args.dln_hw_noise) use_detached_boxes_dec_out = False try: use_detached_boxes_dec_out = args.use_detached_boxes_dec_out except: use_detached_boxes_dec_out = False binary_query_selection = False try: binary_query_selection = args.binary_query_selection except: binary_query_selection = False ffn_extra_layernorm = False try: ffn_extra_layernorm = args.ffn_extra_layernorm except: print('ffn_extra_layernorm not found, set to False') ffn_extra_layernorm = False return DeformableTransformer( d_model=args.hidden_dim, dropout=args.dropout, nhead=args.nheads, num_queries=args.num_queries, dim_feedforward=args.dim_feedforward, num_encoder_layers=args.enc_layers, num_unicoder_layers=args.unic_layers, num_decoder_layers=args.dec_layers, normalize_before=args.pre_norm, return_intermediate_dec=True, query_dim=args.query_dim, activation=args.transformer_activation, num_patterns=args.num_patterns, modulate_hw_attn=True, 
deformable_encoder=True, deformable_decoder=True, num_feature_levels=args.num_feature_levels, enc_n_points=args.enc_n_points, dec_n_points=args.dec_n_points, use_deformable_box_attn=args.use_deformable_box_attn, box_attn_type=args.box_attn_type, learnable_tgt_init=True, decoder_query_perturber=decoder_query_perturber, add_channel_attention=args.add_channel_attention, add_pos_value=args.add_pos_value, random_refpoints_xy=args.random_refpoints_xy, # two stage two_stage_type=args.two_stage_type, # ['no', 'standard', 'early'] two_stage_pat_embed=args.two_stage_pat_embed, two_stage_add_query_num=args.two_stage_add_query_num, two_stage_learn_wh=args.two_stage_learn_wh, two_stage_keep_all_tokens=args.two_stage_keep_all_tokens, dec_layer_number=args.dec_layer_number, rm_self_attn_layers=None, key_aware_type=None, layer_share_type=None, rm_detach=None, decoder_sa_type=args.decoder_sa_type, module_seq=args.decoder_module_seq, embed_init_tgt=args.embed_init_tgt, use_detached_boxes_dec_out=use_detached_boxes_dec_out, use_text_enhancer=args.use_text_enhancer, use_fusion_layer=args.use_fusion_layer, use_checkpoint=args.use_checkpoint, use_transformer_ckpt=args.use_transformer_ckpt, use_text_cross_attention=args.use_text_cross_attention, text_dropout=args.text_dropout, fusion_dropout=args.fusion_dropout, fusion_droppath=args.fusion_droppath, binary_query_selection=binary_query_selection, ffn_extra_layernorm=ffn_extra_layernorm, ) ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/fuse_modules.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F # from timm.models.layers import DropPath from src.modules.util import DropPath class FeatureResizer(nn.Module): """ This class takes as input a set of embeddings of dimension C1 and outputs a set of embedding of dimension C2, after a linear transformation, dropout and normalization (LN). 
""" def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True): super().__init__() self.do_ln = do_ln # Object feature encoding self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True) self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12) self.dropout = nn.Dropout(dropout) def forward(self, encoder_features): x = self.fc(encoder_features) if self.do_ln: x = self.layer_norm(x) output = self.dropout(x) return output def l1norm(X, dim, eps=1e-8): """L1-normalize columns of X """ norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps X = torch.div(X, norm) return X def l2norm(X, dim, eps=1e-8): """L2-normalize columns of X """ norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps X = torch.div(X, norm) return X def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8): """ query: (n_context, queryL, d) context: (n_context, sourceL, d) """ batch_size_q, queryL = query.size(0), query.size(1) batch_size, sourceL = context.size(0), context.size(1) # Get attention # --> (batch, d, queryL) queryT = torch.transpose(query, 1, 2) # (batch, sourceL, d)(batch, d, queryL) # --> (batch, sourceL, queryL) attn = torch.bmm(context, queryT) if raw_feature_norm == "softmax": # --> (batch*sourceL, queryL) attn = attn.view(batch_size * sourceL, queryL) attn = nn.Softmax()(attn) # --> (batch, sourceL, queryL) attn = attn.view(batch_size, sourceL, queryL) elif raw_feature_norm == "l2norm": attn = l2norm(attn, 2) elif raw_feature_norm == "clipped_l2norm": attn = nn.LeakyReLU(0.1)(attn) attn = l2norm(attn, 2) else: raise ValueError("unknown first norm type:", raw_feature_norm) # --> (batch, queryL, sourceL) attn = torch.transpose(attn, 1, 2).contiguous() # --> (batch*queryL, sourceL) attn = attn.view(batch_size * queryL, sourceL) attn = nn.Softmax()(attn * smooth) # --> (batch, queryL, sourceL) attn = attn.view(batch_size, queryL, sourceL) # --> (batch, sourceL, queryL) attnT = torch.transpose(attn, 1, 2).contiguous() # --> 
(batch, d, sourceL) contextT = torch.transpose(context, 1, 2) # (batch x d x sourceL)(batch x sourceL x queryL) # --> (batch, d, queryL) weightedContext = torch.bmm(contextT, attnT) # --> (batch, queryL, d) weightedContext = torch.transpose(weightedContext, 1, 2) return weightedContext, attnT class BiMultiHeadAttention(nn.Module): def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None): super(BiMultiHeadAttention, self).__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.v_dim = v_dim self.l_dim = l_dim assert ( self.head_dim * self.num_heads == self.embed_dim ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." self.scale = self.head_dim ** (-0.5) self.dropout = dropout self.v_proj = nn.Linear(self.v_dim, self.embed_dim) self.l_proj = nn.Linear(self.l_dim, self.embed_dim) self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim) self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim) self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim) self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim) self.stable_softmax_2d = True self.clamp_min_for_underflow = True self.clamp_max_for_overflow = True self._reset_parameters() def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def _reset_parameters(self): nn.init.xavier_uniform_(self.v_proj.weight) self.v_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.l_proj.weight) self.l_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.values_v_proj.weight) self.values_v_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.values_l_proj.weight) self.values_l_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.out_v_proj.weight) self.out_v_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.out_l_proj.weight) self.out_l_proj.bias.data.fill_(0) def forward(self, v, l, 
attention_mask_v=None, attention_mask_l=None): """_summary_ Args: v (_type_): bs, n_img, dim l (_type_): bs, n_text, dim attention_mask_v (_type_, optional): _description_. bs, n_img attention_mask_l (_type_, optional): _description_. bs, n_text Returns: _type_: _description_ """ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO': # import ipdb; ipdb.set_trace() bsz, tgt_len, _ = v.size() query_states = self.v_proj(v) * self.scale key_states = self._shape(self.l_proj(l), -1, bsz) value_v_states = self._shape(self.values_v_proj(v), -1, bsz) value_l_states = self._shape(self.values_l_proj(l), -1, bsz) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_v_states = value_v_states.view(*proj_shape) value_l_states = value_l_states.view(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" ) if self.stable_softmax_2d: attn_weights = attn_weights - attn_weights.max() if self.clamp_min_for_underflow: attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range if self.clamp_max_for_overflow: attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range attn_weights_T = attn_weights.transpose(1, 2) attn_weights_l = (attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[ 0]) if self.clamp_min_for_underflow: attn_weights_l = torch.clamp(attn_weights_l, min=-50000) # Do not increase -50000, data type half has quite limited range if self.clamp_max_for_overflow: attn_weights_l = torch.clamp(attn_weights_l, max=50000) # Do not increase 50000, data type half has 
quite limited range # mask vison for language if attention_mask_v is not None: attention_mask_v = attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1) attn_weights_l.masked_fill_(attention_mask_v, float('-inf')) attn_weights_l = attn_weights_l.softmax(dim=-1) # mask language for vision if attention_mask_l is not None: attention_mask_l = attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1) attn_weights.masked_fill_(attention_mask_l, float('-inf')) attn_weights_v = attn_weights.softmax(dim=-1) attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training) attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training) attn_output_v = torch.bmm(attn_probs_v, value_l_states) attn_output_l = torch.bmm(attn_probs_l, value_v_states) if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}" ) if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim): raise ValueError( f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}" ) attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output_v = attn_output_v.transpose(1, 2) attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim) attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim) attn_output_l = attn_output_l.transpose(1, 2) attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim) attn_output_v = self.out_v_proj(attn_output_v) attn_output_l = self.out_l_proj(attn_output_l) return attn_output_v, attn_output_l # Bi-Direction MHA (text->image, image->text) class BiAttentionBlock(nn.Module): def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, drop_path=.0, init_values=1e-4, cfg=None): """ Inputs: embed_dim - Dimensionality 
of input and attention feature vectors hidden_dim - Dimensionality of hidden layer in feed-forward network (usually 2-4x larger than embed_dim) num_heads - Number of heads to use in the Multi-Head Attention block dropout - Amount of dropout to apply in the feed-forward network """ super(BiAttentionBlock, self).__init__() # pre layer norm self.layer_norm_v = nn.LayerNorm(v_dim) self.layer_norm_l = nn.LayerNorm(l_dim) self.attn = BiMultiHeadAttention(v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout) # add layer scale for training stability self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=False) self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=False) def forward(self, v, l, attention_mask_v=None, attention_mask_l=None): v = self.layer_norm_v(v) l = self.layer_norm_l(l) delta_v, delta_l = self.attn(v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l) # v, l = v + delta_v, l + delta_l v = v + self.drop_path(self.gamma_v * delta_v) l = l + self.drop_path(self.gamma_l * delta_l) return v, l ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/mask_generate.py ================================================ import torch def prepare_for_mask(kpt_mask): tgt_size2 = 50 * 69 attn_mask2 = torch.ones(kpt_mask.shape[0], 8, tgt_size2, tgt_size2).to('cuda') < 0 group_bbox_kpt = 69 num_group=50 for matchj in range(num_group * group_bbox_kpt): sj = (matchj // group_bbox_kpt) * group_bbox_kpt ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt if sj > 0: attn_mask2[:,:,matchj, :sj] = True if ej < num_group * group_bbox_kpt: attn_mask2[:,:,matchj, ej:] = True bs, length = kpt_mask.shape equal_mask = kpt_mask[:, :, None] == kpt_mask[:, None, :] equal_mask= equal_mask.unsqueeze(1).repeat(1,8,1,1) for idx in range(num_group): start_idx = idx * length end_idx 
= (idx + 1) * length attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][equal_mask] = False attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][~equal_mask] = True input_query_label = None input_query_bbox = None attn_mask = None dn_meta = None return input_query_label, input_query_bbox, attn_mask, attn_mask2.flatten(0,1), dn_meta def post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss): if dn_meta and dn_meta['pad_size'] > 0: output_known_class = [outputs_class_i[:, :dn_meta['pad_size'], :] for outputs_class_i in outputs_class] output_known_coord = [outputs_coord_i[:, :dn_meta['pad_size'], :] for outputs_coord_i in outputs_coord] outputs_class = [outputs_class_i[:, dn_meta['pad_size']:, :] for outputs_class_i in outputs_class] outputs_coord = [outputs_coord_i[:, dn_meta['pad_size']:, :] for outputs_coord_i in outputs_coord] out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1]} if aux_loss: out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord) dn_meta['output_known_lbs_bboxes'] = out return outputs_class, outputs_coord ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/functions/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

# Re-export the autograd Function at package level (imported as `...ops.functions.MSDeformAttnFunction`).
from .ms_deform_attn_func import MSDeformAttnFunction

================================================
FILE: src/utils/dependencies/XPose/models/UniPose/ops/functions/ms_deform_attn_func.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable

# Compiled extension (the "MultiScaleDeformableAttention" module declared in ops/setup.py);
# importing this file raises ImportError when the extension has not been built/installed.
import MultiScaleDeformableAttention as MSDA


class MSDeformAttnFunction(Function):
    """Autograd wrapper around the compiled multi-scale deformable attention kernels.

    Both passes delegate to the extension: ``MSDA.ms_deform_attn_forward`` and
    ``MSDA.ms_deform_attn_backward``.
    """

    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
        """Run the fused forward kernel and stash its inputs for ``backward``.

        ``im2col_step`` is kept on ``ctx`` (it is a plain int, not a tensor) and is
        also forwarded to the kernel.
        """
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients with the fused backward kernel.

        Returns one entry per ``forward`` input (after ``ctx``), in order:
        grad for ``value``, ``None`` for the spatial shapes, ``None`` for the level
        start indices, grads for ``sampling_locations`` and ``attention_weights``,
        and ``None`` for ``im2col_step``.
        """
        value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = \
            MSDA.ms_deform_attn_backward(
                value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)

        # Shapes and level indices are integer metadata, so they receive no gradient.
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    """Pure-PyTorch reference implementation of multi-scale deformable attention.

    Shapes, as unpacked/reshaped below:
        value:              (N_, S_, M_, D_), levels concatenated along dim 1
        value_spatial_shapes: iterable of (H_, W_) per level; sum of H_*W_ must equal S_
        sampling_locations: (N_, Lq_, M_, L_, P_, 2), normalized to [0, 1]
        attention_weights:  (N_, Lq_, M_, L_, P_)
    Returns:
        Tensor of shape (N_, Lq_, M_*D_).
    """
    # for debug and test only,
    # need to use cuda version instead
    N_, S_, M_, D_ = value.shape
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    # grid_sample expects coordinates in [-1, 1]; map [0, 1] -> [-1, 1].
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
                                          mode='bilinear', padding_mode='zeros', align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
    # stack levels -> (N_*M_, D_, Lq_, L_, P_), flatten to L_*P_, weight and sum over samples
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
    return output.transpose(1, 2).contiguous()

================================================
FILE: src/utils/dependencies/XPose/models/UniPose/ops/modules/__init__.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from .ms_deform_attn import MSDeformAttn ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/modules/ms_deform_attn.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import print_function from __future__ import division import warnings import math, os import sys sys.path.append(os.path.dirname(os.path.abspath(__file__))) import torch from torch import nn import torch.nn.functional as F from torch.nn.init import xavier_uniform_, constant_ from src.utils.dependencies.XPose.models.UniPose.ops.functions.ms_deform_attn_func import MSDeformAttnFunction def _is_power_of_2(n): if (not isinstance(n, int)) or (n < 0): raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) return (n & (n-1) == 0) and n != 0 class MSDeformAttn(nn.Module): def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False): """ Multi-Scale Deformable Attention Module :param d_model hidden dimension :param n_levels number of feature 
levels :param n_heads number of attention heads :param n_points number of sampling points per attention head per feature level """ super().__init__() if d_model % n_heads != 0: raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) _d_per_head = d_model // n_heads # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation if not _is_power_of_2(_d_per_head): warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " "which is more efficient in our CUDA implementation.") self.im2col_step = 64 self.d_model = d_model self.n_levels = n_levels self.n_heads = n_heads self.n_points = n_points self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) self.value_proj = nn.Linear(d_model, d_model) self.output_proj = nn.Linear(d_model, d_model) self.use_4D_normalizer = use_4D_normalizer self._reset_parameters() def _reset_parameters(self): constant_(self.sampling_offsets.weight.data, 0.) thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) for i in range(self.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) constant_(self.attention_weights.weight.data, 0.) constant_(self.attention_weights.bias.data, 0.) xavier_uniform_(self.value_proj.weight.data) constant_(self.value_proj.bias.data, 0.) xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) 
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): """ :param query (N, Length_{query}, C) :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements :return output (N, Length_{query}, C) """ N, Len_q, _ = query.shape N, Len_in, _ = input_flatten.shape assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in value = self.value_proj(input_flatten) if input_padding_mask is not None: value = value.masked_fill(input_padding_mask[..., None], float(0)) value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) # N, Len_q, n_heads, n_levels, n_points, 2 # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO': # import ipdb; ipdb.set_trace() if reference_points.shape[-1] == 2: offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) sampling_locations = reference_points[:, :, None, :, None, :] \ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] elif reference_points.shape[-1] == 4: if 
self.use_4D_normalizer: offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) sampling_locations = reference_points[:, :, None, :, None, :2] \ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] * reference_points[:, :, None, :, None, 2:] * 0.5 else: sampling_locations = reference_points[:, :, None, :, None, :2] \ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 else: raise ValueError( 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO': # import ipdb; ipdb.set_trace() # for amp if value.dtype == torch.float16: # for mixed precision output = MSDeformAttnFunction.apply( value.to(torch.float32), input_spatial_shapes, input_level_start_index, sampling_locations.to(torch.float32), attention_weights, self.im2col_step) output = output.to(torch.float16) output = self.output_proj(output) return output output = MSDeformAttnFunction.apply( value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) output = self.output_proj(output) return output ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/modules/ms_deform_attn_key_aware.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import print_function from __future__ import division import warnings import math, os import torch from torch import nn import torch.nn.functional as F from torch.nn.init import xavier_uniform_, constant_ try: from src.utils.dependencies.XPose.models.UniPose.ops.functions import MSDeformAttnFunction except: warnings.warn('Failed to import MSDeformAttnFunction.') def _is_power_of_2(n): if (not isinstance(n, int)) or (n < 0): raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) return (n & (n-1) == 0) and n != 0 class MSDeformAttn(nn.Module): def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False): """ Multi-Scale Deformable Attention Module :param d_model hidden dimension :param n_levels number of feature levels :param n_heads number of attention heads :param n_points number of sampling points per attention head per feature level """ super().__init__() if d_model % n_heads != 0: raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) _d_per_head = d_model // n_heads # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation if not _is_power_of_2(_d_per_head): warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " "which is more efficient in our CUDA implementation.") self.im2col_step = 64 self.d_model = d_model self.n_levels = n_levels self.n_heads = n_heads self.n_points = n_points self.sampling_offsets = nn.Linear(d_model, n_heads * 
n_levels * n_points * 2) self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) self.value_proj = nn.Linear(d_model, d_model) self.output_proj = nn.Linear(d_model, d_model) self.use_4D_normalizer = use_4D_normalizer self._reset_parameters() def _reset_parameters(self): constant_(self.sampling_offsets.weight.data, 0.) thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) for i in range(self.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) constant_(self.attention_weights.weight.data, 0.) constant_(self.attention_weights.bias.data, 0.) xavier_uniform_(self.value_proj.weight.data) constant_(self.value_proj.bias.data, 0.) xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) 
def forward(self, query, key, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): """ :param query (N, Length_{query}, C) :param key (N, 1, C) :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements :return output (N, Length_{query}, C) """ N, Len_q, _ = query.shape N, Len_in, _ = input_flatten.shape assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in value = self.value_proj(input_flatten) if input_padding_mask is not None: value = value.masked_fill(input_padding_mask[..., None], float(0)) value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) # N, Len_q, n_heads, n_levels, n_points, 2 # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO': # import ipdb; ipdb.set_trace() if reference_points.shape[-1] == 2: offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) sampling_locations = reference_points[:, :, None, :, None, :] \ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] elif reference_points.shape[-1] == 
4: if self.use_4D_normalizer: offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) sampling_locations = reference_points[:, :, None, :, None, :2] \ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] * reference_points[:, :, None, :, None, 2:] * 0.5 else: sampling_locations = reference_points[:, :, None, :, None, :2] \ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 else: raise ValueError( 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) output = MSDeformAttnFunction.apply( value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) output = self.output_proj(output) return output ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/setup.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ import os import glob import torch from torch.utils.cpp_extension import CUDA_HOME from torch.utils.cpp_extension import CppExtension from torch.utils.cpp_extension import CUDAExtension from setuptools import find_packages from setuptools import setup requirements = ["torch", "torchvision"] def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "src") main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) sources = main_file + source_cpu extension = CppExtension extra_compile_args = {"cxx": []} define_macros = [] # import ipdb; ipdb.set_trace() if torch.cuda.is_available() and CUDA_HOME is not None: extension = CUDAExtension sources += source_cuda define_macros += [("WITH_CUDA", None)] extra_compile_args["nvcc"] = [ "-DCUDA_HAS_FP16=1", "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", ] else: raise NotImplementedError('Cuda is not availabel') sources = [os.path.join(extensions_dir, s) for s in sources] include_dirs = [extensions_dir] ext_modules = [ extension( "MultiScaleDeformableAttention", sources, include_dirs=include_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args, ) ] return ext_modules setup( name="MultiScaleDeformableAttention", version="1.0", author="Weijie Su", url="https://github.com/fundamentalvision/Deformable-DETR", description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 
packages=find_packages(exclude=("configs", "tests",)), ext_modules=get_extensions(), cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, ) ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/src/cpu/ms_deform_attn_cpu.cpp ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #include #include #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { AT_ERROR("Not implement on cpu"); } std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { AT_ERROR("Not implement on cpu"); } ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/src/cpu/ms_deform_attn_cpu.h ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #pragma once #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step); std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step); ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.cu ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #include #include "cuda/ms_deform_im2col_cuda.cuh" #include #include #include #include at::Tensor ms_deform_attn_cuda_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); const int batch = value.size(0); const int spatial_size = value.size(1); const int num_heads = value.size(2); const int channels = value.size(3); const int num_levels = spatial_shapes.size(0); const int num_query = sampling_loc.size(1); const int num_point = sampling_loc.size(4); const int im2col_step_ = std::min(batch, im2col_step); AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); auto output = at::zeros({batch, num_query, 
num_heads, channels}, value.options()); const int batch_n = im2col_step_; auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); auto per_value_size = spatial_size * num_heads * channels; auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; for (int n = 0; n < batch/im2col_step_; ++n) { auto columns = output_n.select(0, n); AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] { ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), value.data() + n * im2col_step_ * per_value_size, spatial_shapes.data(), level_start_index.data(), sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, attn_weight.data() + n * im2col_step_ * per_attn_weight_size, batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, columns.data()); })); } output = output.view({batch, num_query, num_heads*channels}); return output; } std::vector ms_deform_attn_cuda_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); AT_ASSERTM(level_start_index.type().is_cuda(), 
"level_start_index must be a CUDA tensor"); AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); const int batch = value.size(0); const int spatial_size = value.size(1); const int num_heads = value.size(2); const int channels = value.size(3); const int num_levels = spatial_shapes.size(0); const int num_query = sampling_loc.size(1); const int num_point = sampling_loc.size(4); const int im2col_step_ = std::min(batch, im2col_step); AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); auto grad_value = at::zeros_like(value); auto grad_sampling_loc = at::zeros_like(sampling_loc); auto grad_attn_weight = at::zeros_like(attn_weight); const int batch_n = im2col_step_; auto per_value_size = spatial_size * num_heads * channels; auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); for (int n = 0; n < batch/im2col_step_; ++n) { auto grad_output_g = grad_output_n.select(0, n); AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] { ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), grad_output_g.data(), value.data() + n * im2col_step_ * per_value_size, spatial_shapes.data(), level_start_index.data(), sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, attn_weight.data() + n * im2col_step_ * per_attn_weight_size, batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value.data() + n * im2col_step_ * per_value_size, grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); })); } return { 
grad_value, grad_sampling_loc, grad_attn_weight }; } ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.h ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #pragma once #include at::Tensor ms_deform_attn_cuda_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step); std::vector ms_deform_attn_cuda_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step); ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/src/cuda/ms_deform_im2col_cuda.cuh ================================================ /*! ************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 **************************************************************************
 * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
 * Copyright (c) 2018 Microsoft
 **************************************************************************
 */

// NOTE(review): in this extracted copy everything that sat between angle
// brackets appears to be missing: the header names of the #include lines
// below, the template parameter lists after each `template` keyword
// (presumably `typename scalar_t`, plus `unsigned int blockSize` for the
// *_blocksize_aware_* kernels -- TODO confirm against upstream), and the
// launch configurations inside the `<<<...>>>` kernel launches. The tokens
// below are kept exactly as found; restore those pieces before compiling.
#include
#include
#include
#include
#include
#include

// Grid-stride loop: each thread starts at its global thread index and
// advances by the total number of launched threads until `n` is reached.
#define CUDA_KERNEL_LOOP(i, n)                          \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \
      i < (n);                                          \
      i += blockDim.x * gridDim.x)

// Threads per block used by ms_deformable_im2col_cuda, and the upper bound on
// threads per block chosen by ms_deformable_col2im_cuda.
const int CUDA_NUM_THREADS = 1024;

// Number of blocks of `num_threads` threads needed to cover N items
// (ceiling division).
inline int GET_BLOCKS(const int N, const int num_threads)
{
  return (N + num_threads - 1) / num_threads;
}


// Bilinearly samples channel `c` of head `m` at fractional pixel position
// (h, w) of one feature level.
// The strides below imply the level is stored row-major as
// (height, width, nheads, channels).
// Corners outside [0, height-1] x [0, width-1] contribute 0 (zero padding).
template __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
  // Integer corners surrounding (h, w).
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // lh/lw: fractional offset from the low corner; hh/hw: its complement.
  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  // Each corner is read only if it lies inside the level; otherwise it stays 0.
  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
  }

  // Bilinear weights: each corner is weighted by the area of the opposite
  // sub-rectangle.
  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}


// Backward of ms_deform_attn_im2col_bilinear for one (head, channel) sample.
//  - Scatters top_grad * attn_weight * corner_weight into grad_value with
//    atomicAdd. Threads for different queries/points may hit the same
//    value element.
//  - WRITES (does not accumulate) this thread's gradient w.r.t. the
//    attention weight to *grad_attn_weight, and w.r.t. the sampling
//    location to grad_sampling_loc[0] (w) and grad_sampling_loc[1] (h).
//    Callers pass a per-thread scratch slot here and reduce afterwards.
// The `width` / `height` factors are the chain rule through
// w = loc_w * width - 0.5 and h = loc_h * height - 0.5, as computed by the
// calling kernels.
template __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                   const scalar_t &top_grad,
                                                   const scalar_t &attn_weight,
                                                   scalar_t* &grad_value,
                                                   scalar_t* grad_sampling_loc,
                                                   scalar_t* grad_attn_weight)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  // d(val)/dh and d(val)/dw, accumulated per in-bounds corner below.
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  *grad_attn_weight = top_grad * val;
  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}


// Same as ms_deform_attn_col2im_bilinear, but the sampling-location and
// attention-weight gradients are atomically ADDED straight into the
// (global) output pointers instead of written to a per-thread slot.
// Used by ms_deformable_col2im_gpu_kernel_gm, which has no shared-memory
// reduction.
template __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                   const scalar_t &top_grad,
                                                   const scalar_t &attn_weight,
                                                   scalar_t* &grad_value,
                                                   scalar_t* grad_sampling_loc,
                                                   scalar_t* grad_attn_weight)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  atomicAdd(grad_attn_weight, top_grad * val);
  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}


// Forward kernel. Each flat `index` is one output element; it decodes as
// (batch, query, head, channel), with channel varying fastest. The kernel sums
// attn_weight * bilinear(value) over all levels and sampling points.
// Layouts implied by the pointer arithmetic:
//   data_sampling_loc : (batch, query, head, level, point, 2), pairs (w, h)
//   data_attn_weight  : (batch, query, head, level, point)
//   data_spatial_shapes: (level, 2), pairs (h, w)
//   data_value        : per batch, spatial_size rows of (head, channel);
//                       level l starts at row data_level_start_index[l]
//   data_col          : written at `index`
template __global__ void ms_deformable_im2col_gpu_kernel(const int n,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    // (batch, query, head) triple flattened; indexes per-sample weights/locations.
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    // q_col is decoded but not used: the query offset is already in sampling_index.
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    scalar_t *data_col_ptr = data_col + index;
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
    scalar_t col = 0;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map location to pixel coordinates (loc * size - 0.5); presumably
        // loc is normalized to [0, 1] -- confirm with the Python caller.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Points whose 2x2 neighbourhood is entirely outside the level are skipped.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
      }
    }
    *data_col_ptr = col;
  }
}


// Backward kernel, variant: static shared memory of size `blockSize`, reduced
// serially by thread 0. The dispatcher uses it for channels in {1,2,4,8,16,32}.
// Each thread handles one (batch, query, head, channel) element, like the
// forward kernel. Per sampling point:
//   1. each thread writes its loc/weight gradient into its shared-memory slot;
//   2. the block reduces over the slots;
//   3. thread 0 stores the per-(b, q, m, l, p) result.
// grad_value is updated by atomicAdd inside ms_deform_attn_col2im_bilinear.
//
// NOTE(review): assumptions shared by all shared-memory col2im kernels below
// (the launch configuration is not visible in this copy -- confirm):
//  * One block must cover exactly the channels of one (batch, query, head).
//    This holds if blockDim.x == channels, or == blockSize here. Otherwise
//    the reduction mixes different sampling points, or reads unwritten
//    shared slots.
//  * __syncthreads() is called inside CUDA_KERNEL_LOOP. Every thread of a
//    block must therefore run the same number of iterations (n a multiple
//    of blockDim.x).
//  * grad_sampling_loc / grad_attn_weight are kernel parameters advanced in
//    place inside the loop and never reset. A thread that executes more than
//    one loop iteration would write at wrong offsets, so the grid is
//    presumably sized to give every index its own thread (e.g.
//    GET_BLOCKS(n, num_threads) blocks).
template __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Per-thread scratch: 2 loc-gradient entries (w, h) + 1 weight-gradient entry.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    // Advance the output pointers to this sample's first (level, point) entry.
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's slots so out-of-range points contribute 0 to the reduction.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        // Serial reduction by thread 0 over all blockSize slots.
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // This loop variable shadows the outer `tid`; only thread 0 runs it.
          for (unsigned int tid = 1; tid < blockSize; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        // Keep slots intact until thread 0 has read them before the next point overwrites them.
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}


// Backward kernel, variant: static shared memory of size `blockSize` with a
// tree reduction (halving `s` each step). Correct only when blockSize is a
// power of two; the dispatcher uses it for channels in {64,...,1024}.
// Same assumptions as the _v1 variant above.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Tree reduction: at each step the lower half adds in the upper half.
        for (unsigned int s=blockSize/2; s>0; s>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}


// Backward kernel, variant: dynamic shared memory sized from blockDim.x, with
// a serial reduction by thread 0. The dispatcher uses it for channels < 64
// that are not a power of two.
// The launch must provide at least 3 * blockDim.x * sizeof(scalar_t) bytes of
// dynamic shared memory: 2 * blockDim.x loc entries + blockDim.x weight entries.
// Same assumptions as the blocksize_aware_reduce_v1 variant above.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // Shadows the outer `tid`; only thread 0 runs this loop.
          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}


// Backward kernel, variant: dynamic shared memory with a tree reduction that
// tolerates a non-power-of-two blockDim.x. `spre` tracks the current active
// length. When it is odd, thread 0 also folds in the leftover element at
// index 2*s. The dispatcher uses it for channels >= 64 that are not a power
// of two. Shared-memory size and other assumptions are as in shm_reduce_v1.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            // Odd active length: element 2*s has no partner; only tid == 0 satisfies this.
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}


// Backward kernel, variant: identical to shm_reduce_v2, except that thread 0
// atomically ADDS the block's partial sums instead of storing them. Several
// blocks then cover the channels of one (batch, query, head). The dispatcher
// uses it when channels > 1024 and channels is a multiple of 1024, with 1024
// threads per block.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Atomic, because other blocks add partial sums for the same sample.
        if (tid == 0)
        {
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}


// Backward kernel, fallback with no shared memory: every thread atomically
// adds its loc/weight gradients directly to global memory via
// ms_deform_attn_col2im_bilinear_gm. The dispatcher uses it when
// channels > 1024 and channels is not a multiple of 1024.
// NOTE(review): it also advances the grad_sampling_loc / grad_attn_weight
// parameters in place inside CUDA_KERNEL_LOOP (see the note on the
// blocksize_aware_reduce_v1 kernel) -- confirm one loop iteration per thread.
template __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_sampling_loc, grad_attn_weight);
        }
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}


// Host-side launcher for the forward kernel over
// batch_size * num_query * num_heads * channels output elements, with
// CUDA_NUM_THREADS threads per block, on `stream`.
// Launch errors are only printed to stdout; this function returns void, so
// callers get no failure status.
template void ms_deformable_im2col_cuda(cudaStream_t stream,
                              const scalar_t* data_value,
                              const int64_t* data_spatial_shapes,
                              const int64_t* data_level_start_index,
                              const scalar_t* data_sampling_loc,
                              const scalar_t* data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* data_col)
{
  const int num_kernels = batch_size * num_query * num_heads * channels;
  // Not referenced in the visible text; presumably it fed the (stripped)
  // launch configuration, e.g. GET_BLOCKS(num_actual_kernels, num_threads).
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  const int num_threads = CUDA_NUM_THREADS;
  ms_deformable_im2col_gpu_kernel
      <<>>(
      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }

}

// Host-side launcher for the backward pass. Threads per block are
// min(channels, CUDA_NUM_THREADS). The kernel is chosen by `channels`:
//   channels > 1024, multiple of 1024 -> shm_reduce_v2_multi_blocks
//   channels > 1024, otherwise        -> gm (global-memory atomics)
//   1, 2, 4, 8, 16, 32                -> shm_blocksize_aware_reduce_v1
//   64, 128, 256, 512, 1024           -> shm_blocksize_aware_reduce_v2
//   other values < 64                 -> shm_reduce_v1
//   other values >= 64 (<= 1024)      -> shm_reduce_v2
// As in the forward launcher, errors are only printed.
template void ms_deformable_col2im_cuda(cudaStream_t stream,
                              const scalar_t* grad_col,
                              const scalar_t* data_value,
                              const int64_t * data_spatial_shapes,
                              const int64_t * data_level_start_index,
                              const scalar_t * data_sampling_loc,
                              const scalar_t * data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* grad_value,
                              scalar_t* grad_sampling_loc,
                              scalar_t* grad_attn_weight)
{
  const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  // See the note in ms_deformable_im2col_cuda: presumably used by the stripped launch config.
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  if (channels > 1024)
  {
    if ((channels & 1023) == 0)
    {
      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
          <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
    }
    else
    {
      ms_deformable_col2im_gpu_kernel_gm
      <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
    }
  }
  else{
    switch(channels)
    {
      case 1:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 2:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 4:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 8:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 16:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 32:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
        <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 64:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
        <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 128:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
        <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 256:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
        <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 512:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
        <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 1024:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
        <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      default:
        // Non-power-of-two channel counts: the reduction size comes from blockDim.x at runtime.
        if (channels < 64)
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1
          <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        }
        else
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2
          <<>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        }
    }
  }
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }

}
================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/src/ms_deform_attn.h ================================================
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #pragma once #include "cpu/ms_deform_attn_cpu.h" #ifdef WITH_CUDA #include "cuda/ms_deform_attn_cuda.h" #endif at::Tensor ms_deform_attn_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { if (value.type().is_cuda()) { #ifdef WITH_CUDA return ms_deform_attn_cuda_forward( value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } std::vector ms_deform_attn_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { if (value.type().is_cuda()) { #ifdef WITH_CUDA return ms_deform_attn_cuda_backward( value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/src/vision.cpp ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #include "ms_deform_attn.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); } ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/ops/test.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import print_function from __future__ import division import time import torch import torch.nn as nn from torch.autograd import gradcheck from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch N, M, D = 1, 2, 2 Lq, L, P = 2, 2, 2 shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) S = sum([(H*W).item() for H, W in shapes]) torch.manual_seed(3) @torch.no_grad() def check_forward_equal_with_pytorch_double(): value = torch.rand(N, S, M, D).cuda() * 0.01 sampling_locations = 
torch.rand(N, Lq, M, L, P, 2).cuda() attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) im2col_step = 2 output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() fwdok = torch.allclose(output_cuda, output_pytorch) max_abs_err = (output_cuda - output_pytorch).abs().max() max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') @torch.no_grad() def check_forward_equal_with_pytorch_float(): value = torch.rand(N, S, M, D).cuda() * 0.01 sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) im2col_step = 2 output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) max_abs_err = (output_cuda - output_pytorch).abs().max() max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): value = torch.rand(N, S, M, channels).cuda() * 0.01 sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() attention_weights = torch.rand(N, Lq, M, L, 
P).cuda() + 1e-5 attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) im2col_step = 2 func = MSDeformAttnFunction.apply value.requires_grad = grad_value sampling_locations.requires_grad = grad_sampling_loc attention_weights.requires_grad = grad_attn_weight gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) print(f'* {gradok} check_gradient_numerical(D={channels})') if __name__ == '__main__': check_forward_equal_with_pytorch_double() check_forward_equal_with_pytorch_float() for channels in [30, 32, 64, 71, 1025, 2048, 3096]: check_gradient_numerical(channels, True, True, True) ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/position_encoding.py ================================================ # ------------------------------------------------------------------------ # ED-Pose # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Conditional DETR # Copyright (c) 2021 Microsoft. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Copied from DETR (https://github.com/facebookresearch/detr) # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # ------------------------------------------------------------------------ """ Various positional encodings for the transformer. """ import math import torch from torch import nn from util.misc import NestedTensor class PositionEmbeddingSine(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you need paper, generalized to work on images. 
""" def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): super().__init__() self.num_pos_feats = num_pos_feats self.temperature = temperature self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") if scale is None: scale = 2 * math.pi self.scale = scale def forward(self, tensor_list: NestedTensor): x = tensor_list.tensors mask = tensor_list.mask assert mask is not None not_mask = ~mask y_embed = not_mask.cumsum(1, dtype=torch.float32) x_embed = not_mask.cumsum(2, dtype=torch.float32) if self.normalize: eps = 1e-6 # if os.environ.get("SHILONG_AMP", None) == '1': # eps = 1e-4 # else: # eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) return pos class PositionEmbeddingSineHW(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you need paper, generalized to work on images. 
""" def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None): super().__init__() self.num_pos_feats = num_pos_feats self.temperatureH = temperatureH self.temperatureW = temperatureW self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") if scale is None: scale = 2 * math.pi self.scale = scale def forward(self, tensor_list: NestedTensor): x = tensor_list.tensors mask = tensor_list.mask assert mask is not None not_mask = ~mask y_embed = not_mask.cumsum(1, dtype=torch.float32) x_embed = not_mask.cumsum(2, dtype=torch.float32) # import ipdb; ipdb.set_trace() if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_tx dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats) pos_y = y_embed[:, :, :, None] / dim_ty pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # import ipdb; ipdb.set_trace() return pos class PositionEmbeddingLearned(nn.Module): """ Absolute pos embedding, learned. 
""" def __init__(self, num_pos_feats=256): super().__init__() self.row_embed = nn.Embedding(50, num_pos_feats) self.col_embed = nn.Embedding(50, num_pos_feats) self.reset_parameters() def reset_parameters(self): nn.init.uniform_(self.row_embed.weight) nn.init.uniform_(self.col_embed.weight) def forward(self, tensor_list: NestedTensor): x = tensor_list.tensors h, w = x.shape[-2:] i = torch.arange(w, device=x.device) j = torch.arange(h, device=x.device) x_emb = self.col_embed(i) y_emb = self.row_embed(j) pos = torch.cat([ x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1), ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) return pos def build_position_encoding(args): N_steps = args.hidden_dim // 2 if args.position_embedding in ('v2', 'sine'): # TODO find a better way of exposing other arguments position_embedding = PositionEmbeddingSineHW( N_steps, temperatureH=args.pe_temperatureH, temperatureW=args.pe_temperatureW, normalize=True ) elif args.position_embedding in ('v3', 'learned'): position_embedding = PositionEmbeddingLearned(N_steps) else: raise ValueError(f"not supported {args.position_embedding}") return position_embedding ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/swin_transformer.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint import numpy as np from util.misc import NestedTensor # from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from src.modules.util import DropPath, to_2tuple, trunc_normal_ class Mlp(nn.Module): """ Multilayer perceptron.""" def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = 
nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): """ Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Forward function. 
Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x class SwinTransformerBlock(nn.Module): """ Swin Transformer Block. Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm """ def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.H = None self.W = None def forward(self, x, mask_matrix): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. mask_matrix: Attention mask for cyclic shift. 
""" B, L, C = x.shape H, W = self.H, self.W assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # pad feature maps to multiples of window size pad_l = pad_t = 0 pad_r = (self.window_size - W % self.window_size) % self.window_size pad_b = (self.window_size - H % self.window_size) % self.window_size x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) _, Hp, Wp, _ = x.shape # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) attn_mask = mask_matrix else: shifted_x = x attn_mask = None # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class PatchMerging(nn.Module): """ Patch Merging Layer Args: dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x, H, W): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. 
""" B, L, C = x.shape assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C) # padding pad_input = (H % 2 == 1) or (W % 2 == 1) if pad_input: x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of feature channels depth (int): Depths of this stage. num_heads (int): Number of attention head. window_size (int): Local window size. Default: 7. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" def __init__(self, dim, depth, num_heads, window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() self.window_size = window_size self.shift_size = window_size // 2 self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock( dim=dim, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x, H, W): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. 
""" # calculate attention mask for SW-MSA Hp = int(np.ceil(H / self.window_size)) * self.window_size Wp = int(np.ceil(W / self.window_size)) * self.window_size img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) for blk in self.blocks: blk.H, blk.W = H, W if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, attn_mask) else: x = blk(x, attn_mask) if self.downsample is not None: x_down = self.downsample(x, H, W) Wh, Ww = (H + 1) // 2, (W + 1) // 2 return x, H, W, x_down, Wh, Ww else: return x, H, W, x, H, W class PatchEmbed(nn.Module): """ Image to Patch Embedding Args: patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. 
Default: None """ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): """Forward function.""" # padding _, _, H, W = x.size() if W % self.patch_size[1] != 0: x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) if H % self.patch_size[0] != 0: x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) x = self.proj(x) # B C Wh Ww if self.norm is not None: Wh, Ww = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) x = self.norm(x) x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) return x class SwinTransformer(nn.Module): """ Swin Transformer backbone. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: pretrain_img_size (int): Input image size for training the pretrained model, used in absolute postion embedding. Default 224. patch_size (int | tuple(int)): Patch size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. depths (tuple[int]): Depths of each Swin Transformer stage. num_heads (tuple[int]): Number of attention head of each stage. window_size (int): Window size. Default: 7. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. drop_rate (float): Dropout rate. attn_drop_rate (float): Attention dropout rate. Default: 0. drop_path_rate (float): Stochastic depth rate. Default: 0.2. norm_layer (nn.Module): Normalization layer. 
Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. patch_norm (bool): If True, add normalization after patch embedding. Default: True. out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. dilation (bool): if True, the output size if 16x downsample, ow 32x downsample. """ def __init__(self, pretrain_img_size=224, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, out_indices=(0, 1, 2, 3), frozen_stages=-1, dilation=False, use_checkpoint=False): super().__init__() self.pretrain_img_size = pretrain_img_size self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages self.dilation = dilation # if use_checkpoint: # print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!") # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) # absolute position embedding if self.ape: pretrain_img_size = to_2tuple(pretrain_img_size) patch_size = to_2tuple(patch_size) patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers 
= nn.ModuleList() # prepare downsample list downsamplelist = [PatchMerging for i in range(self.num_layers)] downsamplelist[-1] = None num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] if self.dilation: downsamplelist[-2] = None num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2 for i_layer in range(self.num_layers): layer = BasicLayer( # dim=int(embed_dim * 2 ** i_layer), dim=num_features[i_layer], depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, downsample=downsamplelist[i_layer], use_checkpoint=use_checkpoint) self.layers.append(layer) # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] self.num_features = num_features # add a norm layer for each output for i_layer in out_indices: layer = norm_layer(num_features[i_layer]) layer_name = f'norm{i_layer}' self.add_module(layer_name, layer) self._freeze_stages() def _freeze_stages(self): if self.frozen_stages >= 0: self.patch_embed.eval() for param in self.patch_embed.parameters(): param.requires_grad = False if self.frozen_stages >= 1 and self.ape: self.absolute_pos_embed.requires_grad = False if self.frozen_stages >= 2: self.pos_drop.eval() for i in range(0, self.frozen_stages - 1): m = self.layers[i] m.eval() for param in m.parameters(): param.requires_grad = False def forward_raw(self, x): """Forward function.""" x = self.patch_embed(x) Wh, Ww = x.size(2), x.size(3) if self.ape: # interpolate the position embedding to the corresponding size absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C else: x = x.flatten(2).transpose(1, 2) x = self.pos_drop(x) outs = [] for 
i in range(self.num_layers): layer = self.layers[i] x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) # import ipdb; ipdb.set_trace() if i in self.out_indices: norm_layer = getattr(self, f'norm{i}') x_out = norm_layer(x_out) out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() outs.append(out) # in: # torch.Size([2, 3, 1024, 1024]) # outs: # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])] return tuple(outs) def forward(self, tensor_list: NestedTensor): x = tensor_list.tensors """Forward function.""" x = self.patch_embed(x) Wh, Ww = x.size(2), x.size(3) if self.ape: # interpolate the position embedding to the corresponding size absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C else: x = x.flatten(2).transpose(1, 2) x = self.pos_drop(x) outs = [] for i in range(self.num_layers): layer = self.layers[i] x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) if i in self.out_indices: norm_layer = getattr(self, f'norm{i}') x_out = norm_layer(x_out) out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() outs.append(out) # in: # torch.Size([2, 3, 1024, 1024]) # out: # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])] # collect for nesttensors outs_dict = {} for idx, out_i in enumerate(outs): m = tensor_list.mask assert m is not None mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0] outs_dict[idx] = NestedTensor(out_i, mask) return outs_dict def train(self, mode=True): """Convert the model into training mode while keep layers freezed.""" super(SwinTransformer, self).train(mode) self._freeze_stages() def build_swin_transformer(modelname, pretrain_img_size, **kw): assert modelname in ['swin_T_224_1k', 'swin_B_224_22k', 'swin_B_384_22k', 
'swin_L_224_22k', 'swin_L_384_22k'] model_para_dict = { 'swin_T_224_1k': dict( embed_dim=96, depths=[ 2, 2, 6, 2 ], num_heads=[ 3, 6, 12, 24], window_size=7 ), 'swin_B_224_22k': dict( embed_dim=128, depths=[ 2, 2, 18, 2 ], num_heads=[ 4, 8, 16, 32 ], window_size=7 ), 'swin_B_384_22k': dict( embed_dim=128, depths=[ 2, 2, 18, 2 ], num_heads=[ 4, 8, 16, 32 ], window_size=12 ), 'swin_L_224_22k': dict( embed_dim=192, depths=[ 2, 2, 18, 2 ], num_heads=[ 6, 12, 24, 48 ], window_size=7 ), 'swin_L_384_22k': dict( embed_dim=192, depths=[ 2, 2, 18, 2 ], num_heads=[ 6, 12, 24, 48 ], window_size=12 ), } kw_cgf = model_para_dict[modelname] kw_cgf.update(kw) model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf) return model if __name__ == "__main__": model = build_swin_transformer('swin_L_384_22k', 384, dilation=True) x = torch.rand(2, 3, 1024, 1024) y = model.forward_raw(x) import ipdb; ipdb.set_trace() x = torch.rand(2, 3, 384, 384) y = model.forward_raw(x) ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/transformer_deformable.py ================================================ # ------------------------------------------------------------------------ # ED-Pose # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from DETR (https://github.com/facebookresearch/detr) # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved # ------------------------------------------------------------------------ import copy import math import torch from torch import nn, Tensor from torch.nn.init import xavier_uniform_, constant_, normal_ from typing import Optional from util.misc import inverse_sigmoid from .ops.modules import MSDeformAttn from .utils import MLP, _get_activation_fn, gen_sineembed_for_position class DeformableTransformer(nn.Module): def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1, activation="relu", return_intermediate_dec=False, num_feature_levels=4, dec_n_points=4, enc_n_points=4, two_stage=False, two_stage_num_proposals=300, use_dab=False, high_dim_query_update=False, no_sine_embed=False): super().__init__() self.d_model = d_model self.nhead = nhead self.two_stage = two_stage self.two_stage_num_proposals = two_stage_num_proposals self.use_dab = use_dab encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points) self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, dec_n_points) self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec, use_dab=use_dab, d_model=d_model, high_dim_query_update=high_dim_query_update, no_sine_embed=no_sine_embed) self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) if two_stage: self.enc_output = nn.Linear(d_model, d_model) self.enc_output_norm = nn.LayerNorm(d_model) self.pos_trans = nn.Linear(d_model * 2, d_model * 2) self.pos_trans_norm = nn.LayerNorm(d_model * 2) else: if not self.use_dab: self.reference_points = nn.Linear(d_model, 2) self.high_dim_query_update = high_dim_query_update if high_dim_query_update: assert not self.use_dab, "use_dab must be True" 
self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MSDeformAttn): m._reset_parameters() if not self.two_stage and not self.use_dab: xavier_uniform_(self.reference_points.weight.data, gain=1.0) constant_(self.reference_points.bias.data, 0.) normal_(self.level_embed) def get_proposal_pos_embed(self, proposals): num_pos_feats = 128 temperature = 10000 scale = 2 * math.pi dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) # N, L, 4 proposals = proposals.sigmoid() * scale # N, L, 4, 128 pos = proposals[:, :, :, None] / dim_t # N, L, 4, 64, 2 pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) return pos def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): N_, S_, C_ = memory.shape base_scale = 4.0 proposals = [] _cur = 0 for lvl, (H_, W_) in enumerate(spatial_shapes): mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) proposals.append(proposal) _cur += (H_ * W_) output_proposals = torch.cat(proposals, 1) output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) output_proposals = torch.log(output_proposals / (1 - output_proposals)) 
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) output_memory = memory output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) output_memory = self.enc_output_norm(self.enc_output(output_memory)) return output_memory, output_proposals def get_valid_ratio(self, mask): _, H, W = mask.shape valid_H = torch.sum(~mask[:, :, 0], 1) valid_W = torch.sum(~mask[:, 0, :], 1) valid_ratio_h = valid_H.float() / H valid_ratio_w = valid_W.float() / W valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) return valid_ratio def forward(self, srcs, masks, pos_embeds, query_embed=None): """ Input: - srcs: List([bs, c, h, w]) - masks: List([bs, h, w]) """ assert self.two_stage or query_embed is not None # prepare input for encoder src_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes = [] for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): bs, c, h, w = src.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) src = src.flatten(2).transpose(1, 2) # bs, hw, c mask = mask.flatten(1) # bs, hw pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) lvl_pos_embed_flatten.append(lvl_pos_embed) src_flatten.append(src) mask_flatten.append(mask) src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) # encoder memory = 
self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten) # import ipdb; ipdb.set_trace() # prepare input for decoder bs, _, c = memory.shape if self.two_stage: output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) # hack implementation for two-stage Deformable DETR enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory) enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals topk = self.two_stage_num_proposals topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) topk_coords_unact = topk_coords_unact.detach() reference_points = topk_coords_unact.sigmoid() init_reference_out = reference_points pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) query_embed, tgt = torch.split(pos_trans_out, c, dim=2) elif self.use_dab: reference_points = query_embed[..., self.d_model:].sigmoid() tgt = query_embed[..., :self.d_model] tgt = tgt.unsqueeze(0).expand(bs, -1, -1) init_reference_out = reference_points else: query_embed, tgt = torch.split(query_embed, c, dim=1) query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1) tgt = tgt.unsqueeze(0).expand(bs, -1, -1) reference_points = self.reference_points(query_embed).sigmoid() # bs, num_quires, 2 init_reference_out = reference_points # decoder # import ipdb; ipdb.set_trace() hs, inter_references = self.decoder(tgt, reference_points, memory, spatial_shapes, level_start_index, valid_ratios, query_pos=query_embed if not self.use_dab else None, src_padding_mask=mask_flatten) inter_references_out = inter_references if self.two_stage: return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact return hs, init_reference_out, inter_references_out, 
None, None class DeformableTransformerEncoderLayer(nn.Module): def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_levels=4, n_heads=8, n_points=4, add_channel_attention=False, use_deformable_box_attn=False, box_attn_type='roi_align', ): super().__init__() # self attention if use_deformable_box_attn: self.self_attn = MSDeformableBoxAttention(d_model, n_levels, n_heads, n_boxes=n_points, used_func=box_attn_type) else: self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) self.dropout1 = nn.Dropout(dropout) self.norm1 = nn.LayerNorm(d_model) # ffn self.linear1 = nn.Linear(d_model, d_ffn) self.activation = _get_activation_fn(activation, d_model=d_ffn) self.dropout2 = nn.Dropout(dropout) self.linear2 = nn.Linear(d_ffn, d_model) self.dropout3 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(d_model) # channel attention self.add_channel_attention = add_channel_attention if add_channel_attention: self.activ_channel = _get_activation_fn('dyrelu', d_model=d_model) self.norm_channel = nn.LayerNorm(d_model) @staticmethod def with_pos_embed(tensor, pos): return tensor if pos is None else tensor + pos def forward_ffn(self, src): src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) src = src + self.dropout3(src2) src = self.norm2(src) return src def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None): # self attention # import ipdb; ipdb.set_trace() src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, key_padding_mask) src = src + self.dropout1(src2) src = self.norm1(src) # ffn src = self.forward_ffn(src) # channel attn if self.add_channel_attention: src = self.norm_channel(src + self.activ_channel(src)) return src class DeformableTransformerEncoder(nn.Module): def __init__(self, encoder_layer, num_layers, norm=None): super().__init__() if num_layers > 0: self.layers = _get_clones(encoder_layer, num_layers) else: 
self.layers = [] del encoder_layer self.num_layers = num_layers self.norm = norm @staticmethod def get_reference_points(spatial_shapes, valid_ratios, device): reference_points_list = [] for lvl, (H_, W_) in enumerate(spatial_shapes): ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) ref = torch.stack((ref_x, ref_y), -1) reference_points_list.append(ref) reference_points = torch.cat(reference_points_list, 1) reference_points = reference_points[:, :, None] * valid_ratios[:, None] return reference_points def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): """ Input: - src: [bs, sum(hi*wi), 256] - spatial_shapes: h,w of each level [num_level, 2] - level_start_index: [num_level] start point of level in sum(hi*wi). - valid_ratios: [bs, num_level, 2] - pos: pos embed for src. 
[bs, sum(hi*wi), 256] - padding_mask: [bs, sum(hi*wi)] Intermedia: - reference_points: [bs, sum(hi*wi), num_lebel, 2] """ output = src # bs, sum(hi*wi), 256 # import ipdb; ipdb.set_trace() if self.num_layers > 0: reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) for _, layer in enumerate(self.layers): output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) if self.norm is not None: output = self.norm(output) return output class DeformableTransformerDecoderLayer(nn.Module): def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_levels=4, n_heads=8, n_points=4, use_deformable_box_attn=False, box_attn_type='roi_align', key_aware_type=None, decoder_sa_type='ca', module_seq=['sa', 'ca', 'ffn'], ): super().__init__() self.module_seq = module_seq assert sorted(module_seq) == ['ca', 'ffn', 'sa'] # cross attention # self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) if use_deformable_box_attn: self.cross_attn = MSDeformableBoxAttention(d_model, n_levels, n_heads, n_boxes=n_points, used_func=box_attn_type) else: self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) self.dropout1 = nn.Dropout(dropout) self.norm1 = nn.LayerNorm(d_model) # self attention self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) self.dropout2 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(d_model) # ffn self.linear1 = nn.Linear(d_model, d_ffn) self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1) self.dropout3 = nn.Dropout(dropout) self.linear2 = nn.Linear(d_ffn, d_model) self.dropout4 = nn.Dropout(dropout) self.norm3 = nn.LayerNorm(d_model) self.key_aware_type = key_aware_type self.key_aware_proj = None self.decoder_sa_type = decoder_sa_type assert decoder_sa_type in ['sa', 'ca_label', 'ca_content'] if decoder_sa_type == 'ca_content': self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) def 
rm_self_attn_modules(self): self.self_attn = None self.dropout2 = None self.norm2 = None @staticmethod def with_pos_embed(tensor, pos): return tensor if pos is None else tensor + pos def forward_ffn(self, tgt): tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout4(tgt2) tgt = self.norm3(tgt) return tgt def forward_sa(self, # for tgt tgt: Optional[Tensor], # nq, bs, d_model tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos)) tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos) tgt_key_padding_mask: Optional[Tensor] = None, tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4 # for memory memory: Optional[Tensor] = None, # hw, bs, d_model memory_key_padding_mask: Optional[Tensor] = None, memory_level_start_index: Optional[Tensor] = None, # num_levels memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 memory_pos: Optional[Tensor] = None, # pos for memory # sa self_attn_mask: Optional[Tensor] = None, # mask used for self-attention cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention ): # self attention if self.self_attn is not None: # import ipdb; ipdb.set_trace() if self.decoder_sa_type == 'sa': q = k = self.with_pos_embed(tgt, tgt_query_pos) tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) elif self.decoder_sa_type == 'ca_label': # import ipdb; ipdb.set_trace() # q = self.with_pos_embed(tgt, tgt_query_pos) bs = tgt.shape[1] k = v = self.label_embedding.weight[:, None, :].repeat(1, bs, 1) tgt2 = self.self_attn(tgt, k, v, attn_mask=self_attn_mask)[0] tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) elif self.decoder_sa_type == 'ca_content': tgt2 = self.self_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1), tgt_reference_points.transpose(0, 1).contiguous(), memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, 
memory_key_padding_mask).transpose(0, 1) tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) else: raise NotImplementedError("Unknown decoder_sa_type {}".format(self.decoder_sa_type)) return tgt def forward_ca(self, # for tgt tgt: Optional[Tensor], # nq, bs, d_model tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos)) tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos) tgt_key_padding_mask: Optional[Tensor] = None, tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4 # for memory memory: Optional[Tensor] = None, # hw, bs, d_model memory_key_padding_mask: Optional[Tensor] = None, memory_level_start_index: Optional[Tensor] = None, # num_levels memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 memory_pos: Optional[Tensor] = None, # pos for memory # sa self_attn_mask: Optional[Tensor] = None, # mask used for self-attention cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention ): # cross attention # import ipdb; ipdb.set_trace() if self.key_aware_type is not None: if self.key_aware_type == 'mean': tgt = tgt + memory.mean(0, keepdim=True) elif self.key_aware_type == 'proj_mean': tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True) else: raise NotImplementedError("Unknown key_aware_type: {}".format(self.key_aware_type)) tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1), tgt_reference_points.transpose(0, 1).contiguous(), memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1) tgt = tgt + self.dropout1(tgt2) tgt = self.norm1(tgt) return tgt def forward(self, # for tgt tgt: Optional[Tensor], # nq, bs, d_model tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos)) tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. 
Sine(pos) tgt_key_padding_mask: Optional[Tensor] = None, tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4 # for memory memory: Optional[Tensor] = None, # hw, bs, d_model memory_key_padding_mask: Optional[Tensor] = None, memory_level_start_index: Optional[Tensor] = None, # num_levels memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 memory_pos: Optional[Tensor] = None, # pos for memory # sa self_attn_mask: Optional[Tensor] = None, # mask used for self-attention cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention ): for funcname in self.module_seq: # if os.environ.get('IPDB_DEBUG_SHILONG') == 'INFO': # import ipdb; ipdb.set_trace() if funcname == 'ffn': tgt = self.forward_ffn(tgt) elif funcname == 'ca': tgt = self.forward_ca(tgt, tgt_query_pos, tgt_query_sine_embed, \ tgt_key_padding_mask, tgt_reference_points, \ memory, memory_key_padding_mask, memory_level_start_index, \ memory_spatial_shapes, memory_pos, self_attn_mask, cross_attn_mask) elif funcname == 'sa': tgt = self.forward_sa(tgt, tgt_query_pos, tgt_query_sine_embed, \ tgt_key_padding_mask, tgt_reference_points, \ memory, memory_key_padding_mask, memory_level_start_index, \ memory_spatial_shapes, memory_pos, self_attn_mask, cross_attn_mask) else: raise ValueError('unknown funcname {}'.format(funcname)) return tgt class DeformableTransformerDecoder(nn.Module): def __init__(self, decoder_layer, num_layers, return_intermediate=False, use_dab=False, d_model=256, query_dim=4): super().__init__() self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers self.return_intermediate = return_intermediate assert return_intermediate # hack implementation for iterative bounding box refinement and two-stage Deformable DETR self.bbox_embed = None self.class_embed = None self.use_dab = use_dab self.d_model = d_model self.query_dim = query_dim if use_dab: self.query_scale = MLP(d_model, d_model, d_model, 2) self.ref_point_head = MLP(2 * d_model, d_model, 
d_model, 2) def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios, query_pos=None, src_padding_mask=None): output = tgt if self.use_dab: assert query_pos is None intermediate = [] intermediate_reference_points = [reference_points] for layer_id, layer in enumerate(self.layers): # import ipdb; ipdb.set_trace() if reference_points.shape[-1] == 4: reference_points_input = reference_points[:, :, None] \ * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] # bs, nq, 4, 4 else: assert reference_points.shape[-1] == 2 reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None] if self.use_dab: # import ipdb; ipdb.set_trace() query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # bs, nq, 256*2 raw_query_pos = self.ref_point_head(query_sine_embed) # bs, nq, 256 pos_scale = self.query_scale(output) if layer_id != 0 else 1 query_pos = pos_scale * raw_query_pos output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask) # hack implementation for iterative bounding box refinement if self.bbox_embed is not None: box_holder = self.bbox_embed(output) box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points) new_reference_points = box_holder[..., :self.query_dim].sigmoid() reference_points = new_reference_points.detach() if layer_id != self.num_layers - 1: intermediate_reference_points.append(new_reference_points) intermediate.append(output) return torch.stack(intermediate), torch.stack(intermediate_reference_points) def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) def build_deforamble_transformer(args): return DeformableTransformer( d_model=args.hidden_dim, nhead=args.nheads, num_encoder_layers=args.enc_layers, num_decoder_layers=args.dec_layers, dim_feedforward=args.dim_feedforward, dropout=args.dropout, activation="relu", return_intermediate_dec=True, 
num_feature_levels=args.ddetr_num_feature_levels, dec_n_points=args.ddetr_dec_n_points, enc_n_points=args.ddetr_enc_n_points, two_stage=args.ddetr_two_stage, two_stage_num_proposals=args.num_queries, use_dab=args.ddetr_use_dab, high_dim_query_update=args.ddetr_high_dim_query_update, no_sine_embed=args.ddetr_no_sine_embed) ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/transformer_vanilla.py ================================================ # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ DETR Transformer class. Copy-paste from torch.nn.Transformer with modifications: * positional encodings are passed in MHattention * extra LN at the end of encoder is removed * decoder returns a stack of activations from all decoding layers """ import torch from torch import Tensor, nn from typing import List, Optional from .utils import _get_activation_fn, _get_clones class TextTransformer(nn.Module): def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1): super().__init__() self.num_layers = num_layers self.d_model = d_model self.nheads = nheads self.dim_feedforward = dim_feedforward self.norm = None single_encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout) self.layers = _get_clones(single_encoder_layer, num_layers) def forward(self, memory_text:torch.Tensor, text_attention_mask:torch.Tensor): """ Args: text_attention_mask: bs, num_token memory_text: bs, num_token, d_model Raises: RuntimeError: _description_ Returns: output: bs, num_token, d_model """ output = memory_text.transpose(0, 1) for layer in self.layers: output = layer(output, src_key_padding_mask=text_attention_mask) if self.norm is not None: output = self.norm(output) return output.transpose(0, 1) class 
TransformerEncoderLayer(nn.Module): def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self.nhead = nhead def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward( self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, ): # repeat attn mask if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]: # bs, num_q, num_k src_mask = src_mask.repeat(self.nhead, 1, 1) q = k = self.with_pos_embed(src, pos) src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0] # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] src = src + self.dropout1(src2) src = self.norm1(src) src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) src = src + self.dropout2(src2) src = self.norm2(src) return src ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/unipose.py ================================================ # ------------------------------------------------------------------------ # ED-Pose # Copyright (c) 2023 IDEA. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) # Copyright (c) 2020 SenseTime. All Rights Reserved. # ------------------------------------------------------------------------ import os import copy import torch import torch.nn.functional as F from torch import nn from typing import List from util.keypoint_ops import keypoint_xyzxyz_to_xyxyzz from util.misc import NestedTensor, nested_tensor_from_tensor_list,inverse_sigmoid from .utils import MLP from .backbone import build_backbone from ..registry import MODULE_BUILD_FUNCS from .mask_generate import prepare_for_mask, post_process from .deformable_transformer import build_deformable_transformer class UniPose(nn.Module): """ This is the Cross-Attention Detector module that performs object detection """ def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False, iter_update=False, query_dim=2, random_refpoints_xy=False, fix_refpoints_hw=-1, num_feature_levels=1, nheads=8, # two stage two_stage_type='no', # ['no', 'standard'] two_stage_add_query_num=0, dec_pred_class_embed_share=True, dec_pred_bbox_embed_share=True, two_stage_class_embed_share=True, two_stage_bbox_embed_share=True, decoder_sa_type='sa', num_patterns=0, dn_number=100, dn_box_noise_scale=0.4, dn_label_noise_ratio=0.5, dn_labelbook_size=100, use_label_enc=True, text_encoder_type='bert-base-uncased', binary_query_selection=False, use_cdn=True, sub_sentence_present=True, num_body_points=68, num_box_decoder_layers=2, ): """ Initializes the model. Parameters: backbone: torch module of the backbone to be used. See backbone.py transformer: torch module of the transformer architecture. See transformer.py num_classes: number of object classes num_queries: number of object queries, ie detection slot. 
This is the maximal number of objects Conditional DETR can detect in a single image. For COCO, we recommend 100 queries. aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. fix_refpoints_hw: -1(default): learn w and h for each box seperately >0 : given fixed number -2 : learn a shared w and h """ super().__init__() self.num_queries = num_queries self.transformer = transformer self.num_classes = num_classes self.hidden_dim = hidden_dim = transformer.d_model self.num_feature_levels = num_feature_levels self.nheads = nheads self.use_label_enc = use_label_enc if use_label_enc: self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim) else: raise NotImplementedError self.label_enc = None self.max_text_len = 256 self.binary_query_selection = binary_query_selection self.sub_sentence_present = sub_sentence_present # setting query dim self.query_dim = query_dim assert query_dim == 4 self.random_refpoints_xy = random_refpoints_xy self.fix_refpoints_hw = fix_refpoints_hw # for dn training self.num_patterns = num_patterns self.dn_number = dn_number self.dn_box_noise_scale = dn_box_noise_scale self.dn_label_noise_ratio = dn_label_noise_ratio self.dn_labelbook_size = dn_labelbook_size self.use_cdn = use_cdn self.projection = MLP(512, hidden_dim, hidden_dim, 3) self.projection_kpt = MLP(512, hidden_dim, hidden_dim, 3) device = "cuda" if torch.cuda.is_available() else "cpu" # model, _ = clip.load("ViT-B/32", device=device) # self.clip_model = model # visual_parameters = list(self.clip_model.visual.parameters()) # # # for param in visual_parameters: # param.requires_grad = False self.pos_proj = nn.Linear(hidden_dim, 768) self.padding = nn.Embedding(1, 768) # prepare input projection layers if num_feature_levels > 1: num_backbone_outs = len(backbone.num_channels) input_proj_list = [] for _ in range(num_backbone_outs): in_channels = backbone.num_channels[_] input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, hidden_dim, 
kernel_size=1), nn.GroupNorm(32, hidden_dim), )) for _ in range(num_feature_levels - num_backbone_outs): input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), nn.GroupNorm(32, hidden_dim), )) in_channels = hidden_dim self.input_proj = nn.ModuleList(input_proj_list) else: assert two_stage_type == 'no', "two_stage_type should be no if num_feature_levels=1 !!!" self.input_proj = nn.ModuleList([ nn.Sequential( nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1), nn.GroupNorm(32, hidden_dim), )]) self.backbone = backbone self.aux_loss = aux_loss self.box_pred_damping = box_pred_damping = None self.iter_update = iter_update assert iter_update, "Why not iter_update?" # prepare pred layers self.dec_pred_class_embed_share = dec_pred_class_embed_share self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share # prepare class & box embed _class_embed = ContrastiveAssign() _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) _pose_embed = MLP(hidden_dim, hidden_dim, 2, 3) _pose_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3) nn.init.constant_(_pose_embed.layers[-1].weight.data, 0) nn.init.constant_(_pose_embed.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)] else: box_embed_layerlist = [copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)] if dec_pred_class_embed_share: class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)] else: class_embed_layerlist = [copy.deepcopy(_class_embed) for i in range(transformer.num_decoder_layers)] if dec_pred_bbox_embed_share: pose_embed_layerlist = [_pose_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers + 1)] else: pose_embed_layerlist = [copy.deepcopy(_pose_embed) for i in range(transformer.num_decoder_layers 
- num_box_decoder_layers + 1)] pose_hw_embed_layerlist = [_pose_hw_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers)] self.num_box_decoder_layers = num_box_decoder_layers self.bbox_embed = nn.ModuleList(box_embed_layerlist) self.class_embed = nn.ModuleList(class_embed_layerlist) self.num_body_points = num_body_points self.pose_embed = nn.ModuleList(pose_embed_layerlist) self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist) self.transformer.decoder.bbox_embed = self.bbox_embed self.transformer.decoder.class_embed = self.class_embed self.transformer.decoder.pose_embed = self.pose_embed self.transformer.decoder.pose_hw_embed = self.pose_hw_embed self.transformer.decoder.num_body_points = num_body_points # two stage self.two_stage_type = two_stage_type self.two_stage_add_query_num = two_stage_add_query_num assert two_stage_type in ['no', 'standard'], "unknown param {} of two_stage_type".format(two_stage_type) if two_stage_type != 'no': if two_stage_bbox_embed_share: assert dec_pred_class_embed_share and dec_pred_bbox_embed_share self.transformer.enc_out_bbox_embed = _bbox_embed else: self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed) if two_stage_class_embed_share: assert dec_pred_class_embed_share and dec_pred_bbox_embed_share self.transformer.enc_out_class_embed = _class_embed else: self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed) self.refpoint_embed = None if self.two_stage_add_query_num > 0: self.init_ref_points(two_stage_add_query_num) self.decoder_sa_type = decoder_sa_type assert decoder_sa_type in ['sa', 'ca_label', 'ca_content'] # self.replace_sa_with_double_ca = replace_sa_with_double_ca if decoder_sa_type == 'ca_label': self.label_embedding = nn.Embedding(num_classes, hidden_dim) for layer in self.transformer.decoder.layers: layer.label_embedding = self.label_embedding else: for layer in self.transformer.decoder.layers: layer.label_embedding = None self.label_embedding = None 
self._reset_parameters() def open_set_transfer_init(self): for name, param in self.named_parameters(): if 'fusion_layers' in name: continue if 'ca_text' in name: continue if 'catext_norm' in name: continue if 'catext_dropout' in name: continue if "text_layers" in name: continue if 'bert' in name: continue if 'bbox_embed' in name: continue if 'label_enc.weight' in name: continue if 'feat_map' in name: continue if 'enc_output' in name: continue param.requires_grad_(False) # import ipdb; ipdb.set_trace() def _reset_parameters(self): # init input_proj for proj in self.input_proj: nn.init.xavier_uniform_(proj[0].weight, gain=1) nn.init.constant_(proj[0].bias, 0) def init_ref_points(self, use_num_queries): self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim) if self.random_refpoints_xy: # import ipdb; ipdb.set_trace() self.refpoint_embed.weight.data[:, :2].uniform_(0, 1) self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2]) self.refpoint_embed.weight.data[:, :2].requires_grad = False if self.fix_refpoints_hw > 0: print("fix_refpoints_hw: {}".format(self.fix_refpoints_hw)) assert self.random_refpoints_xy self.refpoint_embed.weight.data[:, 2:] = self.fix_refpoints_hw self.refpoint_embed.weight.data[:, 2:] = inverse_sigmoid(self.refpoint_embed.weight.data[:, 2:]) self.refpoint_embed.weight.data[:, 2:].requires_grad = False elif int(self.fix_refpoints_hw) == -1: pass elif int(self.fix_refpoints_hw) == -2: print('learn a shared h and w') assert self.random_refpoints_xy self.refpoint_embed = nn.Embedding(use_num_queries, 2) self.refpoint_embed.weight.data[:, :2].uniform_(0, 1) self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2]) self.refpoint_embed.weight.data[:, :2].requires_grad = False self.hw_embed = nn.Embedding(1, 1) else: raise NotImplementedError('Unknown fix_refpoints_hw {}'.format(self.fix_refpoints_hw)) def forward(self, samples: NestedTensor, targets: List = None, 
**kw): """ The forward expects a NestedTensor, which consists of: - samples.tensor: batched images, of shape [batch_size x 3 x H x W] - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels It returns a dict with the following elements: - "pred_logits": the classification logits (including no-object) for all queries. Shape= [batch_size x num_queries x num_classes] - "pred_boxes": The normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These values are normalized in [0, 1], relative to the size of each individual image (disregarding possible padding). See PostProcess for information on how to retrieve the unnormalized bounding box. - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of dictionnaries containing the two above keys for each decoder layer. """ captions = [t['instance_text_prompt'] for t in targets] bs=len(captions) tensor_list = [tgt["object_embeddings_text"] for tgt in targets] max_size = 350 padded_tensors = [torch.cat([tensor, torch.zeros(max_size - tensor.size(0), tensor.size(1),device=tensor.device)]) if tensor.size(0) < max_size else tensor for tensor in tensor_list] object_embeddings_text = torch.stack(padded_tensors) kpts_embeddings_text = torch.stack([tgt["kpts_embeddings_text"] for tgt in targets])[:, :self.num_body_points] encoded_text=self.projection(object_embeddings_text) # bs, 81, 101, 256 kpt_embeddings_specific=self.projection_kpt(kpts_embeddings_text) # bs, 81, 101, 256 kpt_vis = torch.stack([tgt["kpt_vis_text"] for tgt in targets])[:, :self.num_body_points] kpt_mask = torch.cat((torch.ones_like(kpt_vis, device=kpt_vis.device)[..., 0].unsqueeze(-1), kpt_vis), dim=-1) num_classes = encoded_text.shape[1] # bs, 81, 101, 256 text_self_attention_masks = torch.eye(num_classes).unsqueeze(0).expand(bs, -1, -1).bool().to(samples.device) text_token_mask = torch.zeros(samples.shape[0],num_classes).to(samples.device)>0 for i in 
range(bs): text_token_mask[i,:len(captions[i])]=True position_ids = torch.zeros(samples.shape[0], num_classes).to(samples.device) for i in range(bs): position_ids[i,:len(captions[i])]= 1 text_dict = { 'encoded_text': encoded_text, # bs, 195, d_model 'text_token_mask': text_token_mask, # bs, 195 'position_ids': position_ids, # bs, 195 'text_self_attention_masks': text_self_attention_masks # bs, 195,195 } # import ipdb; ipdb.set_trace() if isinstance(samples, (list, torch.Tensor)): samples = nested_tensor_from_tensor_list(samples) features, poss = self.backbone(samples) if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1': import ipdb; ipdb.set_trace() srcs = [] masks = [] for l, feat in enumerate(features): src, mask = feat.decompose() srcs.append(self.input_proj[l](src)) masks.append(mask) assert mask is not None if self.num_feature_levels > len(srcs): _len_srcs = len(srcs) for l in range(_len_srcs, self.num_feature_levels): if l == _len_srcs: src = self.input_proj[l](features[-1].tensors) else: src = self.input_proj[l](srcs[-1]) m = samples.mask mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) srcs.append(src) masks.append(mask) poss.append(pos_l) if self.label_enc is not None: label_enc = self.label_enc else: raise NotImplementedError label_enc = encoded_text if self.dn_number > 0 or targets is not None: input_query_label, input_query_bbox, attn_mask, attn_mask2, dn_meta = \ prepare_for_mask(kpt_mask=kpt_mask) else: assert targets is None input_query_bbox = input_query_label = attn_mask = attn_mask2 = dn_meta = None hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, attn_mask2, text_dict, dn_meta,targets,kpt_embeddings_specific) # In case num object=0 if self.label_enc is not None: hs[0] += self.label_enc.weight[0, 0] * 0.0 hs[0] += self.pos_proj.weight[0, 0] * 0.0 hs[0] += 
self.pos_proj.bias[0] * 0.0 hs[0] += self.padding.weight[0, 0] * 0.0 num_group = 50 effective_dn_number = dn_meta['pad_size'] if self.training else 0 outputs_coord_list = [] outputs_class = [] for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_cls_embed, layer_hs) in enumerate( zip(reference[:-1], self.bbox_embed, self.class_embed, hs)): if dec_lid < self.num_box_decoder_layers: layer_delta_unsig = layer_bbox_embed(layer_hs) layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig) layer_outputs_unsig = layer_outputs_unsig.sigmoid() layer_cls = layer_cls_embed(layer_hs, text_dict) outputs_coord_list.append(layer_outputs_unsig) outputs_class.append(layer_cls) else: layer_hs_bbox_dn = layer_hs[:, :effective_dn_number, :] layer_hs_bbox_norm = layer_hs[:, effective_dn_number:, :][:, 0::(self.num_body_points + 1), :] bs = layer_ref_sig.shape[0] reference_before_sigmoid_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] reference_before_sigmoid_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][:, 0::(self.num_body_points + 1), :] layer_delta_unsig_dn = layer_bbox_embed(layer_hs_bbox_dn) layer_delta_unsig_norm = layer_bbox_embed(layer_hs_bbox_norm) layer_outputs_unsig_dn = layer_delta_unsig_dn + inverse_sigmoid(reference_before_sigmoid_bbox_dn) layer_outputs_unsig_dn = layer_outputs_unsig_dn.sigmoid() layer_outputs_unsig_norm = layer_delta_unsig_norm + inverse_sigmoid(reference_before_sigmoid_bbox_norm) layer_outputs_unsig_norm = layer_outputs_unsig_norm.sigmoid() layer_outputs_unsig = torch.cat((layer_outputs_unsig_dn, layer_outputs_unsig_norm), dim=1) layer_cls_dn = layer_cls_embed(layer_hs_bbox_dn, text_dict) layer_cls_norm = layer_cls_embed(layer_hs_bbox_norm, text_dict) layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) outputs_class.append(layer_cls) outputs_coord_list.append(layer_outputs_unsig) # update keypoints outputs_keypoints_list = [] outputs_keypoints_hw = [] kpt_index = [x for x in range(num_group * (self.num_body_points + 
1)) if x % (self.num_body_points + 1) != 0] for dec_lid, (layer_ref_sig, layer_hs) in enumerate(zip(reference[:-1], hs)): if dec_lid < self.num_box_decoder_layers: assert isinstance(layer_hs, torch.Tensor) bs = layer_hs.shape[0] layer_res = layer_hs.new_zeros((bs, self.num_queries, self.num_body_points * 3)) outputs_keypoints_list.append(layer_res) else: bs = layer_ref_sig.shape[0] layer_hs_kpt = layer_hs[:, effective_dn_number:, :].index_select(1, torch.tensor(kpt_index, device=layer_hs.device)) delta_xy_unsig = self.pose_embed[dec_lid - self.num_box_decoder_layers](layer_hs_kpt) layer_ref_sig_kpt = layer_ref_sig[:, effective_dn_number:, :].index_select(1, torch.tensor(kpt_index, device=layer_hs.device)) layer_outputs_unsig_keypoints = delta_xy_unsig + inverse_sigmoid(layer_ref_sig_kpt[..., :2]) vis_xy_unsig = torch.ones_like(layer_outputs_unsig_keypoints, device=layer_outputs_unsig_keypoints.device) xyv = torch.cat((layer_outputs_unsig_keypoints, vis_xy_unsig[:, :, 0].unsqueeze(-1)), dim=-1) xyv = xyv.sigmoid() layer_res = xyv.reshape((bs, num_group, self.num_body_points, 3)).flatten(2, 3) layer_hw = layer_ref_sig_kpt[..., 2:].reshape(bs, num_group, self.num_body_points, 2).flatten(2, 3) layer_res = keypoint_xyzxyz_to_xyxyzz(layer_res) outputs_keypoints_list.append(layer_res) outputs_keypoints_hw.append(layer_hw) if self.dn_number > 0 and dn_meta is not None: outputs_class, outputs_coord_list = \ post_process(outputs_class, outputs_coord_list, dn_meta, self.aux_loss, self._set_aux_loss) out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord_list[-1], 'pred_keypoints': outputs_keypoints_list[-1]} return out @MODULE_BUILD_FUNCS.registe_with_name(module_name='UniPose') def build_unipose(args): num_classes = args.num_classes device = torch.device(args.device) backbone = build_backbone(args) transformer = build_deformable_transformer(args) try: match_unstable_error = args.match_unstable_error dn_labelbook_size = args.dn_labelbook_size except: 
match_unstable_error = True dn_labelbook_size = num_classes try: dec_pred_class_embed_share = args.dec_pred_class_embed_share except: dec_pred_class_embed_share = True try: dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share except: dec_pred_bbox_embed_share = True binary_query_selection = False try: binary_query_selection = args.binary_query_selection except: binary_query_selection = False use_cdn = True try: use_cdn = args.use_cdn except: use_cdn = True sub_sentence_present = True try: sub_sentence_present = args.sub_sentence_present except: sub_sentence_present = True # print('********* sub_sentence_present', sub_sentence_present) model = UniPose( backbone, transformer, num_classes=num_classes, num_queries=args.num_queries, aux_loss=True, iter_update=True, query_dim=4, random_refpoints_xy=args.random_refpoints_xy, fix_refpoints_hw=args.fix_refpoints_hw, num_feature_levels=args.num_feature_levels, nheads=args.nheads, dec_pred_class_embed_share=dec_pred_class_embed_share, dec_pred_bbox_embed_share=dec_pred_bbox_embed_share, # two stage two_stage_type=args.two_stage_type, # box_share two_stage_bbox_embed_share=args.two_stage_bbox_embed_share, two_stage_class_embed_share=args.two_stage_class_embed_share, decoder_sa_type=args.decoder_sa_type, num_patterns=args.num_patterns, dn_number=args.dn_number if args.use_dn else 0, dn_box_noise_scale=args.dn_box_noise_scale, dn_label_noise_ratio=args.dn_label_noise_ratio, dn_labelbook_size=dn_labelbook_size, use_label_enc=args.use_label_enc, text_encoder_type=args.text_encoder_type, binary_query_selection=binary_query_selection, use_cdn=use_cdn, sub_sentence_present=sub_sentence_present ) return model class ContrastiveAssign(nn.Module): def __init__(self, project=False, cal_bias=None, max_text_len=256): """ :param x: query :param y: text embed :param proj: :return: """ super().__init__() self.project = project self.cal_bias = cal_bias self.max_text_len = max_text_len def forward(self, x, text_dict): """_summary_ Args: x 
(_type_): _description_ text_dict (_type_): _description_ { 'encoded_text': encoded_text, # bs, 195, d_model 'text_token_mask': text_token_mask, # bs, 195 # True for used tokens. False for padding tokens } Returns: _type_: _description_ """ assert isinstance(text_dict, dict) y = text_dict['encoded_text'] max_text_len = y.shape[1] text_token_mask = text_dict['text_token_mask'] if self.cal_bias is not None: raise NotImplementedError return x @ y.transpose(-1, -2) + self.cal_bias.weight.repeat(x.shape[0], x.shape[1], 1) res = x @ y.transpose(-1, -2) res.masked_fill_(~text_token_mask[:, None, :], float('-inf')) # padding to max_text_len new_res = torch.full((*res.shape[:-1], max_text_len), float('-inf'), device=res.device) new_res[..., :res.shape[-1]] = res return new_res ================================================ FILE: src/utils/dependencies/XPose/models/UniPose/utils.py ================================================ # ------------------------------------------------------------------------ # ED-Pose # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ import copy import torch import random from torch import nn, Tensor import os import numpy as np import math import torch.nn.functional as F from torch import nn def _get_clones(module, N, layer_share=False): # import ipdb; ipdb.set_trace() if layer_share: return nn.ModuleList([module for i in range(N)]) else: return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) def get_sine_pos_embed( pos_tensor: torch.Tensor, num_pos_feats: int = 128, temperature: int = 10000, exchange_xy: bool = True, ): """generate sine position embedding from a position tensor Args: pos_tensor (torch.Tensor): shape: [..., n]. num_pos_feats (int): projected shape for each float in the tensor. temperature (int): temperature in the sine/cosine function. 
exchange_xy (bool, optional): exchange pos x and pos y. \ For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True. Returns: pos_embed (torch.Tensor): shape: [..., n*num_pos_feats]. """ scale = 2 * math.pi dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device) dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) def sine_func(x: torch.Tensor): sin_x = x * scale / dim_t sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2) return sin_x pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)] if exchange_xy: pos_res[0], pos_res[1] = pos_res[1], pos_res[0] pos_res = torch.cat(pos_res, dim=-1) return pos_res def gen_encoder_output_proposals(memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None): """ Input: - memory: bs, \sum{hw}, d_model - memory_padding_mask: bs, \sum{hw} - spatial_shapes: nlevel, 2 - learnedwh: 2 Output: - output_memory: bs, \sum{hw}, d_model - output_proposals: bs, \sum{hw}, 4 """ N_, S_, C_ = memory.shape base_scale = 4.0 proposals = [] _cur = 0 for lvl, (H_, W_) in enumerate(spatial_shapes): mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) # import ipdb; ipdb.set_trace() grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2 scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale if learnedwh is not None: # import ipdb; ipdb.set_trace() wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0 ** lvl) else: wh = torch.ones_like(grid) * 
0.05 * (2.0 ** lvl) # scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1) # grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale # wh = torch.ones_like(grid) / scale proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) proposals.append(proposal) _cur += (H_ * W_) # import ipdb; ipdb.set_trace() output_proposals = torch.cat(proposals, 1) output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) output_memory = memory output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) # output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) # output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf')) return output_memory, output_proposals class RandomBoxPerturber(): def __init__(self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2) -> None: self.noise_scale = torch.Tensor([x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale]) def __call__(self, refanchors: Tensor) -> Tensor: nq, bs, query_dim = refanchors.shape device = refanchors.device noise_raw = torch.rand_like(refanchors) noise_scale = self.noise_scale.to(device)[:query_dim] new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale) return new_refanchors.clamp_(0, 1) def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False): """ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Args: inputs: A float tensor of arbitrary shape. The predictions for each example. 
targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). alpha: (optional) Weighting factor in range (0,1) to balance positive vs negative examples. Default = -1 (no weighting). gamma: Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. Returns: Loss tensor """ prob = inputs.sigmoid() ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") p_t = prob * targets + (1 - prob) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) if alpha >= 0: alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss if no_reduction: return loss return loss.mean(1).sum() / num_boxes class MLP(nn.Module): """ Very simple multi-layer perceptron (also called FFN)""" def __init__(self, input_dim, hidden_dim, output_dim, num_layers): super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) def forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) return x def _get_activation_fn(activation, d_model=256, batch_dim=0): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu if activation == "prelu": return nn.PReLU() if activation == "selu": return F.selu raise RuntimeError(F"activation should be relu/gelu, not {activation}.") def gen_sineembed_for_position(pos_tensor): # n_query, bs, _ = pos_tensor.size() # sineembed_tensor = torch.zeros(n_query, bs, 256) scale = 2 * math.pi dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) dim_t = 10000 ** (2 * (dim_t // 2) / 128) x_embed = pos_tensor[:, :, 0] * scale y_embed = pos_tensor[:, :, 1] * scale pos_x = x_embed[:, :, None] / dim_t pos_y = 
y_embed[:, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) if pos_tensor.size(-1) == 2: pos = torch.cat((pos_y, pos_x), dim=2) elif pos_tensor.size(-1) == 4: w_embed = pos_tensor[:, :, 2] * scale pos_w = w_embed[:, :, None] / dim_t pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) h_embed = pos_tensor[:, :, 3] * scale pos_h = h_embed[:, :, None] / dim_t pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) else: raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) return pos def oks_overlaps(kpt_preds, kpt_gts, kpt_valids, kpt_areas, sigmas): sigmas = kpt_preds.new_tensor(sigmas) variances = (sigmas * 2) ** 2 assert kpt_preds.size(0) == kpt_gts.size(0) kpt_preds = kpt_preds.reshape(-1, kpt_preds.size(-1) // 2, 2) kpt_gts = kpt_gts.reshape(-1, kpt_gts.size(-1) // 2, 2) squared_distance = (kpt_preds[:, :, 0] - kpt_gts[:, :, 0]) ** 2 + \ (kpt_preds[:, :, 1] - kpt_gts[:, :, 1]) ** 2 # import pdb # pdb.set_trace() # assert (kpt_valids.sum(-1) > 0).all() squared_distance0 = squared_distance / (kpt_areas[:, None] * variances[None, :] * 2) squared_distance1 = torch.exp(-squared_distance0) squared_distance1 = squared_distance1 * kpt_valids oks = squared_distance1.sum(dim=1) / (kpt_valids.sum(dim=1) + 1e-6) return oks def oks_loss(pred, target, valid=None, area=None, linear=False, sigmas=None, eps=1e-6): """Oks loss. Computing the oks loss between a set of predicted poses and target poses. The loss is calculated as negative log of oks. Args: pred (torch.Tensor): Predicted poses of format (x1, y1, x2, y2, ...), shape (n, 2K). target (torch.Tensor): Corresponding gt poses, shape (n, 2K). linear (bool, optional): If True, use linear scale of loss instead of log scale. 
Default: False. eps (float): Eps to avoid log(0). Return: torch.Tensor: Loss tensor. """ oks = oks_overlaps(pred, target, valid, area, sigmas).clamp(min=eps) if linear: loss = 1 - oks else: loss = -oks.log() return loss class OKSLoss(nn.Module): """IoULoss. Computing the oks loss between a set of predicted poses and target poses. Args: linear (bool): If True, use linear scale of loss instead of log scale. Default: False. eps (float): Eps to avoid log(0). reduction (str): Options are "none", "mean" and "sum". loss_weight (float): Weight of loss. """ def __init__(self, linear=False, num_keypoints=17, eps=1e-6, reduction='mean', loss_weight=1.0): super(OKSLoss, self).__init__() self.linear = linear self.eps = eps self.reduction = reduction self.loss_weight = loss_weight if num_keypoints == 68: self.sigmas = np.array([ .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, .25, ], dtype=np.float32) / 10.0 else: raise ValueError(f'Unsupported keypoints number {num_keypoints}') def forward(self, pred, target, valid, area, weight=None, avg_factor=None, reduction_override=None): """Forward function. Args: pred (torch.Tensor): The prediction. target (torch.Tensor): The learning target of the prediction. valid (torch.Tensor): The visible flag of the target pose. area (torch.Tensor): The area of the target pose. weight (torch.Tensor, optional): The weight of loss for each prediction. Defaults to None. avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. reduction_override (str, optional): The reduction method used to override the original reduction method of the loss. Defaults to None. Options are "none", "mean" and "sum". 
""" assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) if (weight is not None) and (not torch.any(weight > 0)) and ( reduction != 'none'): if pred.dim() == weight.dim() + 1: weight = weight.unsqueeze(1) return (pred * weight).sum() # 0 if weight is not None and weight.dim() > 1: # TODO: remove this in the future # reduce the weight of shape (n, 4) to (n,) to match the # iou_loss of shape (n,) assert weight.shape == pred.shape weight = weight.mean(-1) loss = self.loss_weight * oks_loss( pred, target, valid=valid, area=area, linear=self.linear, sigmas=self.sigmas, eps=self.eps) return loss ================================================ FILE: src/utils/dependencies/XPose/models/__init__.py ================================================ # ------------------------------------------------------------------------ # ED-Pose # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved from .UniPose.unipose import build_unipose def build_model(args): # we use register to maintain models from catdet6 on. 
from .registry import MODULE_BUILD_FUNCS assert args.modelname in MODULE_BUILD_FUNCS._module_dict build_func = MODULE_BUILD_FUNCS.get(args.modelname) model = build_func(args) return model ================================================ FILE: src/utils/dependencies/XPose/models/registry.py ================================================ # -*- coding: utf-8 -*- # @Author: Yihao Chen # @Date: 2021-08-16 16:03:17 # @Last Modified by: Shilong Liu # @Last Modified time: 2022-01-23 15:26 # modified from mmcv import inspect from functools import partial class Registry(object): def __init__(self, name): self._name = name self._module_dict = dict() def __repr__(self): format_str = self.__class__.__name__ + '(name={}, items={})'.format( self._name, list(self._module_dict.keys())) return format_str def __len__(self): return len(self._module_dict) @property def name(self): return self._name @property def module_dict(self): return self._module_dict def get(self, key): return self._module_dict.get(key, None) def registe_with_name(self, module_name=None, force=False): return partial(self.register, module_name=module_name, force=force) def register(self, module_build_function, module_name=None, force=False): """Register a module build function. Args: module (:obj:`nn.Module`): Module to be registered. 
""" if not inspect.isfunction(module_build_function): raise TypeError('module_build_function must be a function, but got {}'.format( type(module_build_function))) if module_name is None: module_name = module_build_function.__name__ if not force and module_name in self._module_dict: raise KeyError('{} is already registered in {}'.format( module_name, self.name)) self._module_dict[module_name] = module_build_function return module_build_function MODULE_BUILD_FUNCS = Registry('model build functions') ================================================ FILE: src/utils/dependencies/XPose/predefined_keypoints.py ================================================ person = {"keypoints":['nose', 'left eye', 'right eye', 'left ear', 'right ear', 'left shoulder', 'right shoulder', 'left elbow', 'right elbow', 'left wrist', 'right wrist', 'left hip', 'right hip', 'left knee', 'right knee', 'left ankle', 'right ankle'],"skeleton": [[16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13],[6,7],[6,8],[7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]]} face = {"keypoints": ['right cheekbone 1', 'right cheekbone 2', 'right cheek 1', 'right cheek 2', 'right cheek 3', 'right cheek 4', 'right cheek 5', 'right chin', 'chin center', 'left chin', 'left cheek 5', 'left cheek 4', 'left cheek 3', 'left cheek 2', 'left cheek 1', 'left cheekbone 2', 'left cheekbone 1', 'right eyebrow 1', 'right eyebrow 2', 'right eyebrow 3', 'right eyebrow 4', 'right eyebrow 5', 'left eyebrow 1', 'left eyebrow 2', 'left eyebrow 3', 'left eyebrow 4', 'left eyebrow 5', 'nasal bridge 1', 'nasal bridge 2', 'nasal bridge 3', 'nasal bridge 4', 'right nasal wing 1', 'right nasal wing 2', 'nasal wing center', 'left nasal wing 1', 'left nasal wing 2', 'right eye eye corner 1', 'right eye upper eyelid 1', 'right eye upper eyelid 2', 'right eye eye corner 2', 'right eye lower eyelid 2', 'right eye lower eyelid 1', 'left eye eye corner 1', 'left eye upper eyelid 1', 'left eye upper eyelid 2', 'left eye eye corner 
2', 'left eye lower eyelid 2', 'left eye lower eyelid 1', 'right mouth corner', 'upper lip outer edge 1', 'upper lip outer edge 2', 'upper lip outer edge 3', 'upper lip outer edge 4', 'upper lip outer edge 5', 'left mouth corner', 'lower lip outer edge 5', 'lower lip outer edge 4', 'lower lip outer edge 3', 'lower lip outer edge 2', 'lower lip outer edge 1', 'upper lip inter edge 1', 'upper lip inter edge 2', 'upper lip inter edge 3', 'upper lip inter edge 4', 'upper lip inter edge 5', 'lower lip inter edge 3', 'lower lip inter edge 2', 'lower lip inter edge 1'], "skeleton": []} hand = {"keypoints":['wrist', 'thumb root', "thumb's third knuckle", "thumb's second knuckle", 'thumb’s first knuckle', "forefinger's root", "forefinger's third knuckle", "forefinger's second knuckle", "forefinger's first knuckle", "middle finger's root", "middle finger's third knuckle", "middle finger's second knuckle", "middle finger's first knuckle", "ring finger's root", "ring finger's third knuckle", "ring finger's second knuckle", "ring finger's first knuckle", "pinky finger's root", "pinky finger's third knuckle", "pinky finger's second knuckle", "pinky finger's first knuckle"],"skeleton": []} animal_in_AnimalKindom = {"keypoints":['head mid top', 'eye left', 'eye right', 'mouth front top', 'mouth back left', 'mouth back right', 'mouth front bottom', 'shoulder left', 'shoulder right', 'elbow left', 'elbow right', 'wrist left', 'wrist right', 'torso mid back', 'hip left', 'hip right', 'knee left', 'knee right', 'ankle left ', 'ankle right', 'tail top back', 'tail mid back', 'tail end back'],"skeleton": [[1, 0], [2, 0], [3, 4], [3, 5], [4, 6], [5, 6], [0, 7], [0, 8], [7, 9], [8, 10], [9, 11], [10, 12], [0, 13], [13, 20], [20, 14], [20, 15], [14, 16], [15, 17], [16, 18], [17, 19], [20, 21], [21, 22]]} animal_in_AP10K = {"keypoints": ['left eye', 'right eye', 'nose', 'neck', 'root of tail', 'left shoulder', 'left elbow', 'left front paw', 'right shoulder', 'right elbow', 'right front 
paw', 'left hip', 'left knee', 'left back paw', 'right hip', 'right knee', 'right back paw'], "skeleton": [[1, 2], [1, 3], [2, 3], [3, 4], [4, 5], [4, 6], [6, 7], [7, 8], [4, 9], [9, 10], [10, 11], [5, 12], [12, 13], [13, 14], [5, 15], [15, 16], [16, 17]]} animal= {"keypoints": ['left eye', 'right eye', 'nose', 'neck', 'root of tail', 'left shoulder', 'left elbow', 'left front paw', 'right shoulder', 'right elbow', 'right front paw', 'left hip', 'left knee', 'left back paw', 'right hip', 'right knee', 'right back paw'], "skeleton": [[1, 2], [1, 3], [2, 3], [3, 4], [4, 5], [4, 6], [6, 7], [7, 8], [4, 9], [9, 10], [10, 11], [5, 12], [12, 13], [13, 14], [5, 15], [15, 16], [16, 17]]} animal_face = {"keypoints": ['right eye right', 'right eye left', 'left eye right', 'left eye left', 'nose tip', 'lip right', 'lip left', 'upper lip', 'lower lip'], "skeleton": []} fly = {"keypoints": ['head', 'eye left', 'eye right', 'neck', 'thorax', 'abdomen', 'foreleg right base', 'foreleg right first segment', 'foreleg right second segment', 'foreleg right tip', 'midleg right base', 'midleg right first segment', 'midleg right second segment', 'midleg right tip', 'hindleg right base', 'hindleg right first segment', 'hindleg right second segment', 'hindleg right tip', 'foreleg left base', 'foreleg left first segment', 'foreleg left second segment', 'foreleg left tip', 'midleg left base', 'midleg left first segment', 'midleg left second segment', 'midleg left tip', 'hindleg left base', 'hindleg left first segment', 'hindleg left second segment', 'hindleg left tip', 'wing left', 'wing right'], "skeleton": [[2, 1], [3, 1], [4, 1], [5, 4], [6, 5], [8, 7], [9, 8], [10, 9], [12, 11], [13, 12], [14, 13], [16, 15], [17, 16], [18, 17], [20, 19], [21, 20], [22, 21], [24, 23], [25, 24], [26, 25], [28, 27], [29, 28], [30, 29], [31, 4], [32, 4]]} locust = {"keypoints": ['head', 'neck', 'thorax', 'abdomen1', 'abdomen2', 'anttip left', 'antbase left', 'eye left', 'foreleg left base', 'foreleg left 
first segment', 'foreleg left second segment', 'foreleg left tip', 'midleg left base', 'midleg left first segment', 'midleg left second segment', 'midleg left tip', 'hindleg left base', 'hindleg left first segment', 'hindleg left second segment', 'hindleg left tip', 'anttip right', 'antbase right', 'eye right', 'foreleg right base', 'foreleg right first segment', 'foreleg right second segment', 'foreleg right tip', 'midleg right base', 'midleg right first segment', 'midleg right second segment', 'midleg right tip', 'hindleg right base', 'hindleg right first segment', 'hindleg right second segment', 'hindleg right tip'],"skeleton": [[2, 1], [3, 2], [4, 3], [5, 4], [7, 6], [8, 7], [10, 9], [11, 10], [12, 11], [14, 13], [15, 14],[16, 15], [18, 17], [19, 18], [20, 19], [22, 21], [23, 22], [25, 24], [26, 25], [27, 26],[29, 28], [30, 29], [31, 30], [33, 32], [34, 33], [35, 34]]} car ={"keypoints": ['right front wheel center', 'left front wheel center', 'right rear wheel center', 'left rear wheel center', 'front right', 'front left', 'back right', 'back left', 'none', 'roof front right', 'roof front left', 'roof back right', 'roof back left', 'none'],"skeleton": [[0, 2], [1, 3], [0, 1], [2, 3], [9, 11], [10, 12], [9, 10], [11, 12], [4, 0], [4, 9], [4, 5], [5, 1], [5, 10], [6, 2], [6, 11], [7, 3], [7, 12], [6, 7]]} short_sleeved_shirt = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right cuff outside', 'right cuff inside', 'right sleeve inside 2', 'right sleeve inside 1', 'right side 1', 'right side 2', 'right side 3', 'center hem', 'left side 3', 'left side 2', 'left side 1', 'left sleeve inside 1', 'left sleeve inside 2', 'left cuff inside', 'left cuff outside', 'left sleeve outside 2', 'left sleeve outside 1'], 'skeleton': []} long_sleeved_outwear={'keypoints': ['upper center neckline', 'lower right center 
neckline', 'lower right neckline', 'upper right neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right sleeve outside 3', 'right sleeve outside 4', 'right cuff outside', 'right cuff inside', 'right sleeve inside 1', 'right sleeve inside 2', 'right sleeve inside 3', 'right sleeve inside 4', 'right side outside 1', 'right side outside 2', 'right side outside 3', 'right side inside 3', 'left side outside 3', 'left side outside 2', 'left side outside 1', 'left sleeve inside 4', 'left sleeve inside 3', 'left sleeve inside 2', 'left sleeve inside 1', 'left cuff inside', 'left cuff outside', 'left sleeve outside 4', 'left sleeve outside 3', 'left sleeve outside 2', 'left sleeve outside 1', 'lower left center neckline', 'left side inside 1', 'left side inside 2', 'left side inside 3', 'right side inside 1', 'right side inside 2'], 'skeleton': []} short_sleeved_outwear={'keypoints': ['upper center neckline', 'lower right center neckline', 'lower right neckline', 'upper right neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right cuff outside', 'right cuff inside', 'right sleeve inside 2', 'right sleeve inside 1', 'right side outside 1', 'right side outside 2', 'right side outside 3', 'right side inside 3', 'left side outside 3', 'left side outside 2', 'left side outside 1', 'left sleeve inside 1', 'left sleeve inside 2', 'left cuff inside', 'left cuff outside', 'left sleeve outside 2', 'left sleeve outside 1', 'lower left center neckline', 'left side inside 1', 'left side inside 2', 'left side inside 3', 'right side inside 1', 'right side inside 2'], 'skeleton': []} sling={'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve', 'right side 1', 'right side 2', 'right side 3', 'center hem', 'left side 3', 'left side 2', 'left side 1', 'left 
sleeve'], 'skeleton': []} vest = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve', 'right side 1', 'right side 2', 'right side 3', 'center hem', 'left side 3', 'left side 2', 'left side 1', 'left sleeve'], 'skeleton': []} long_sleeved_dress={'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right sleeve outside 3', 'right sleeve outside 4', 'right cuff outside', 'right cuff inside', 'right sleeve inside 4', 'right sleeve inside 3', 'right sleeve inside 2', 'right sleeve inside 1', 'right side 1', 'right side 2', 'right side 3', 'right side 4', 'right side 5', 'center hem', 'left side 5', 'left side 4', 'left side 3', 'left side 2', 'left side 1', 'left sleeve inside 1', 'left sleeve inside 2', 'left sleeve inside 3', 'left sleeve inside 4', 'left cuff inside', 'left cuff outside', 'left sleeve outside 4', 'left sleeve outside 3', 'left sleeve outside 2', 'left sleeve outside 1'], 'skeleton': []} long_sleeved_shirt = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right sleeve outside 3', 'right sleeve outside 4', 'right cuff outside', 'right cuff inside', 'right sleeve inside 4', 'right sleeve inside 3', 'right sleeve inside 2', 'right sleeve inside 1', 'right side 1', 'right side 2', 'right side 3', 'center hem', 'left side 3', 'left side 2', 'left side 1', 'left sleeve inside 1', 'left sleeve inside 2', 'left sleeve inside 3', 'left sleeve inside 4', 'left cuff inside', 'left cuff outside', 'left sleeve outside 4', 'left sleeve outside 3', 'left sleeve outside 2', 'left sleeve outside 1'], 'skeleton': []} trousers = 
{'keypoints': ['right side outside 1', 'upper center', 'left side outside 1', 'right side outside 2', 'right side outside 3', 'right cuff outside', 'right cuff inside', 'right side inside 1', 'crotch', 'left side inside 1', 'left cuff inside', 'left cuff outside', 'left side outside 3', 'left side outside 2'], 'skeleton': []} sling_dress = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right side 1', 'right side 2', 'right side 3', 'right side 4', 'right side 5', 'right side 6', 'center hem', 'left side 6', 'left side 5', 'left side 4', 'left side 3', 'left side 2', 'left side 1'], 'skeleton': []} vest_dress = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right side 1', 'right side 2', 'right side 3', 'right side 4', 'right side 5', 'right side 6', 'center hem', 'left side 6', 'left side 5', 'left side 4', 'left side 3', 'left side 2', 'left side 1'], 'skeleton': []} skirt = {'keypoints': ['right side 1', 'upper center', 'left side 1', 'right side 2', 'right side 3', 'center hem', 'left side 3', 'left side 2'], 'skeleton': []} short_sleeved_dress = {'keypoints': ['upper center neckline', 'upper right neckline', 'lower right neckline', 'lower center neckline', 'lower left neckline', 'upper left neckline', 'right sleeve outside 1', 'right sleeve outside 2', 'right cuff outside', 'right cuff inside', 'right sleeve inside 1', 'right sleeve inside 2', 'left side 1', 'left side 2', 'left side 3', 'left side 4', 'left side 5', 'center hem', 'right side 5', 'right side 4', 'right side 3', 'right side 2', 'right side 1', 'left sleeve inside 2', 'left sleeve inside 1', 'left cuff inside', 'left cuff outside', 'left sleeve outside 2', 'left sleeve outside 1'], 'skeleton': []} shorts = {'keypoints': ['right side outside 1', 'upper center', 'left side 
outside 1', 'right side outside 2', 'right cuff outside', 'right cuff inside', 'crotch', 'left cuff inside', 'left cuff outside', 'left side outside 2'], 'skeleton': []} table = {'keypoints': ['desktop corner 1', 'desktop corner 2', 'desktop corner 3', 'desktop corner 4', 'table leg 1', 'table leg 2', 'table leg 3', 'table leg 4'], 'skeleton': []} chair = {'keypoints': ['legs righttopcorner', 'legs lefttopcorner', 'legs leftbottomcorner', 'legs rightbottomcorner', 'base righttop', 'base lefttop', 'base leftbottom', 'base rightbottom', 'headboard righttop', 'headboard lefttop'], 'skeleton': []} bed = {'keypoints': ['legs rightbottomcorner', 'legs righttopcorner', 'base rightbottom', 'base righttop', 'backrest righttop', 'legs leftbottomcorner', 'legs lefttopcorner', 'base leftbottom', 'base lefttop', 'backrest lefttop'], 'skeleton': []} sofa = {'keypoints': ['legs rightbottomcorner', 'legs righttopcorner', 'base rightbottom', 'base righttop', 'armrests rightbottomcorner', 'armrests righttopcorner', 'backrest righttop', 'legs leftbottomcorner', 'legs lefttopcorner', 'base leftbottom', 'base lefttop', 'armrests leftbottomcorner', 'armrests lefttopcorner', 'backrest lefttop'], 'skeleton': []} swivelchair = {'keypoints': ['rotatingbase 1', 'rotatingbase 2', 'rotatingbase 3', 'rotatingbase 4', 'rotatingbase 5', 'rotatingbase center', 'base center', 'base righttop', 'base lefttop', 'base leftbottom', 'base rightbottom', 'backrest righttop', 'backrest lefttop'], 'skeleton': []} ================================================ FILE: src/utils/dependencies/XPose/transforms.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Transforms and data augmentation for both image + bbox. 
""" import os import sys import random import PIL import torch import torchvision.transforms as T import torchvision.transforms.functional as F sys.path.append(os.path.dirname(os.path.abspath(__file__))) from util.box_ops import box_xyxy_to_cxcywh from util.misc import interpolate def crop(image, target, region): cropped_image = F.crop(image, *region) if target is not None: target = target.copy() i, j, h, w = region id2catname = target["id2catname"] caption_list = target["caption_list"] target["size"] = torch.tensor([h, w]) fields = ["labels", "area", "iscrowd", "positive_map","keypoints"] if "boxes" in target: boxes = target["boxes"] max_size = torch.as_tensor([w, h], dtype=torch.float32) cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) cropped_boxes = cropped_boxes.clamp(min=0) area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) target["boxes"] = cropped_boxes.reshape(-1, 4) target["area"] = area fields.append("boxes") if "masks" in target: # FIXME should we update the area here if there are no boxes? target['masks'] = target['masks'][:, i:i + h, j:j + w] fields.append("masks") # remove elements for which the boxes or masks that have zero area if "boxes" in target or "masks" in target: # favor boxes selection when defining which elements to keep # this is compatible with previous implementation if "boxes" in target: cropped_boxes = target['boxes'].reshape(-1, 2, 2) keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) else: keep = target['masks'].flatten(1).any(1) for field in fields: if field in target: target[field] = target[field][keep] if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO': # for debug and visualization only. 
if 'strings_positive' in target: target['strings_positive'] = [_i for _i, _j in zip(target['strings_positive'], keep) if _j] if "keypoints" in target: max_size = torch.as_tensor([w, h], dtype=torch.float32) keypoints = target["keypoints"] cropped_keypoints = keypoints.view(-1, 3)[:,:2] - torch.as_tensor([j, i]) cropped_keypoints = torch.min(cropped_keypoints, max_size) cropped_keypoints = cropped_keypoints.clamp(min=0) cropped_keypoints = torch.cat([cropped_keypoints, keypoints.view(-1, 3)[:,2].unsqueeze(1)], dim=1) target["keypoints"] = cropped_keypoints.view(target["keypoints"].shape[0], target["keypoints"].shape[1], 3) target["id2catname"] = id2catname target["caption_list"] = caption_list return cropped_image, target def hflip(image, target): flipped_image = F.hflip(image) w, h = image.size if target is not None: target = target.copy() if "boxes" in target: boxes = target["boxes"] boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) target["boxes"] = boxes if "masks" in target: target['masks'] = target['masks'].flip(-1) if "keypoints" in target: dataset_name=target["dataset_name"] if dataset_name == "coco_person" or dataset_name == "macaque": flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]] elif dataset_name=="animalkindom_ak_P1_animal": flip_pairs = [[1, 2], [4, 5],[7,8],[9,10],[11,12],[14,15],[16,17],[18,19]] elif dataset_name=="animalweb_animal": flip_pairs = [[0, 3], [1, 2], [5, 6]] elif dataset_name=="face": flip_pairs = [ [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], [7, 9], [17, 26], [18, 25], [19, 24], [20, 23], [21, 22], [31, 35], [32, 34], [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46], [48, 54], [49, 53], [50, 52], [55, 59], [56, 58], [60, 64], [61, 63], [65, 67] ] elif dataset_name=="hand": flip_pairs = [] elif dataset_name=="foot": flip_pairs = [] elif dataset_name=="locust": flip_pairs = [[5, 20], [6, 21], [7, 22], [8, 23], [9, 24], [10, 
25], [11, 26], [12, 27], [13, 28], [14, 29], [15, 30], [16, 31], [17, 32], [18, 33], [19, 34]] elif dataset_name=="fly": flip_pairs = [[1, 2], [6, 18], [7, 19], [8, 20], [9, 21], [10, 22], [11, 23], [12, 24], [13, 25], [14, 26], [15, 27], [16, 28], [17, 29], [30, 31]] elif dataset_name == "ap_36k_animal" or dataset_name == "ap_10k_animal": flip_pairs = [[0, 1],[5, 8], [6, 9], [7, 10], [11, 14], [12, 15], [13, 16]] keypoints = target["keypoints"] keypoints[:,:,0] = w - keypoints[:,:, 0]-1 for pair in flip_pairs: keypoints[:,pair[0], :], keypoints[:,pair[1], :] = keypoints[:,pair[1], :], keypoints[:,pair[0], :].clone() target["keypoints"] = keypoints return flipped_image, target def resize(image, target, size, max_size=None): # size can be min_size (scalar) or (w, h) tuple def get_size_with_aspect_ratio(image_size, size, max_size=None): w, h = image_size if max_size is not None: min_original_size = float(min((w, h))) max_original_size = float(max((w, h))) if max_original_size / min_original_size * size > max_size: size = int(round(max_size * min_original_size / max_original_size)) if (w <= h and w == size) or (h <= w and h == size): return (h, w) if w < h: ow = size oh = int(size * h / w) else: oh = size ow = int(size * w / h) return (oh, ow) def get_size(image_size, size, max_size=None): if isinstance(size, (list, tuple)): return size[::-1] else: return get_size_with_aspect_ratio(image_size, size, max_size) size = get_size(image.size, size, max_size) rescaled_image = F.resize(image, size) if target is None: return rescaled_image, None ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) ratio_width, ratio_height = ratios target = target.copy() if "boxes" in target: boxes = target["boxes"] scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) target["boxes"] = scaled_boxes if "area" in target: area = target["area"] scaled_area = area * (ratio_width * ratio_height) target["area"] = 
scaled_area if "keypoints" in target: keypoints = target["keypoints"] scaled_keypoints = keypoints * torch.as_tensor([ratio_width, ratio_height, 1]) target["keypoints"] = scaled_keypoints h, w = size target["size"] = torch.tensor([h, w]) if "masks" in target: target['masks'] = interpolate( target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 return rescaled_image, target def pad(image, target, padding): # assumes that we only pad on the bottom right corners padded_image = F.pad(image, (0, 0, padding[0], padding[1])) if target is None: return padded_image, None target = target.copy() # should we do something wrt the original size? target["size"] = torch.tensor(padded_image.size[::-1]) if "masks" in target: target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) return padded_image, target class ResizeDebug(object): def __init__(self, size): self.size = size def __call__(self, img, target): return resize(img, target, self.size) class RandomCrop(object): def __init__(self, size): self.size = size def __call__(self, img, target): region = T.RandomCrop.get_params(img, self.size) return crop(img, target, region) class RandomSizeCrop(object): def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False): # respect_boxes: True to keep all boxes # False to tolerence box filter self.min_size = min_size self.max_size = max_size self.respect_boxes = respect_boxes def __call__(self, img: PIL.Image.Image, target: dict): init_boxes = len(target["boxes"]) if (target is not None and "boxes" in target) else 0 max_patience = 10 for i in range(max_patience): w = random.randint(self.min_size, min(img.width, self.max_size)) h = random.randint(self.min_size, min(img.height, self.max_size)) region = T.RandomCrop.get_params(img, [h, w]) result_img, result_target = crop(img, target, region) if target is not None: if not self.respect_boxes or len(result_target["boxes"]) == init_boxes or i == max_patience - 1: return 
result_img, result_target return result_img, result_target class CenterCrop(object): def __init__(self, size): self.size = size def __call__(self, img, target): image_width, image_height = img.size crop_height, crop_width = self.size crop_top = int(round((image_height - crop_height) / 2.)) crop_left = int(round((image_width - crop_width) / 2.)) return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) class RandomHorizontalFlip(object): def __init__(self, p=0.5): self.p = p def __call__(self, img, target): if random.random() < self.p: return hflip(img, target) return img, target class RandomResize(object): def __init__(self, sizes, max_size=None): assert isinstance(sizes, (list, tuple)) self.sizes = sizes self.max_size = max_size def __call__(self, img, target=None): size = random.choice(self.sizes) return resize(img, target, size, self.max_size) class RandomPad(object): def __init__(self, max_pad): self.max_pad = max_pad def __call__(self, img, target): pad_x = random.randint(0, self.max_pad) pad_y = random.randint(0, self.max_pad) return pad(img, target, (pad_x, pad_y)) class RandomSelect(object): """ Randomly selects between transforms1 and transforms2, with probability p for transforms1 and (1 - p) for transforms2 """ def __init__(self, transforms1, transforms2, p=0.5): self.transforms1 = transforms1 self.transforms2 = transforms2 self.p = p def __call__(self, img, target): if random.random() < self.p: return self.transforms1(img, target) return self.transforms2(img, target) class ToTensor(object): def __call__(self, img, target): return F.to_tensor(img), target class RandomErasing(object): def __init__(self, *args, **kwargs): self.eraser = T.RandomErasing(*args, **kwargs) def __call__(self, img, target): return self.eraser(img), target class Normalize(object): def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, image, target=None): image = F.normalize(image, mean=self.mean, std=self.std) if target is None: return 
image, None target = target.copy() h, w = image.shape[-2:] if "boxes" in target: boxes = target["boxes"] boxes = box_xyxy_to_cxcywh(boxes) boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) target["boxes"] = boxes if "area" in target: area = target["area"] area = area / (torch.tensor(w, dtype=torch.float32)*torch.tensor(h, dtype=torch.float32)) target["area"] = area if "keypoints" in target: keypoints = target["keypoints"] V = keypoints[:, :, 2] V[V == 2] = 1 Z=keypoints[:, :, :2] Z = Z.contiguous().view(-1, 2 * V.shape[-1]) Z = Z / torch.tensor([w, h] * V.shape[-1], dtype=torch.float32) target["valid_kpt_num"] = V.shape[1] Z_pad = torch.zeros(Z.shape[0],68 * 2 - Z.shape[1]) V_pad = torch.zeros(V.shape[0],68 - V.shape[1]) V=torch.cat([V, V_pad], dim=1) Z=torch.cat([Z, Z_pad], dim=1) all_keypoints = torch.cat([Z, V], dim=1) target["keypoints"] = all_keypoints return image, target class Compose(object): def __init__(self, transforms): self.transforms = transforms def __call__(self, image, target): for t in self.transforms: image, target = t(image, target) return image, target def __repr__(self): format_string = self.__class__.__name__ + "(" for t in self.transforms: format_string += "\n" format_string += " {0}".format(t) format_string += "\n)" return format_string ================================================ FILE: src/utils/dependencies/XPose/util/addict.py ================================================ import copy class Dict(dict): def __init__(__self, *args, **kwargs): object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) object.__setattr__(__self, '__key', kwargs.pop('__key', None)) object.__setattr__(__self, '__frozen', False) for arg in args: if not arg: continue elif isinstance(arg, dict): for key, val in arg.items(): __self[key] = __self._hook(val) elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): __self[arg[0]] = __self._hook(arg[1]) else: for key, val in iter(arg): __self[key] = __self._hook(val) for key, val 
in kwargs.items(): __self[key] = __self._hook(val) def __setattr__(self, name, value): if hasattr(self.__class__, name): raise AttributeError("'Dict' object attribute " "'{0}' is read-only".format(name)) else: self[name] = value def __setitem__(self, name, value): isFrozen = (hasattr(self, '__frozen') and object.__getattribute__(self, '__frozen')) if isFrozen and name not in super(Dict, self).keys(): raise KeyError(name) super(Dict, self).__setitem__(name, value) try: p = object.__getattribute__(self, '__parent') key = object.__getattribute__(self, '__key') except AttributeError: p = None key = None if p is not None: p[key] = self object.__delattr__(self, '__parent') object.__delattr__(self, '__key') def __add__(self, other): if not self.keys(): return other else: self_type = type(self).__name__ other_type = type(other).__name__ msg = "unsupported operand type(s) for +: '{}' and '{}'" raise TypeError(msg.format(self_type, other_type)) @classmethod def _hook(cls, item): if isinstance(item, dict): return cls(item) elif isinstance(item, (list, tuple)): return type(item)(cls._hook(elem) for elem in item) return item def __getattr__(self, item): return self.__getitem__(item) def __missing__(self, name): if object.__getattribute__(self, '__frozen'): raise KeyError(name) return self.__class__(__parent=self, __key=name) def __delattr__(self, name): del self[name] def to_dict(self): base = {} for key, value in self.items(): if isinstance(value, type(self)): base[key] = value.to_dict() elif isinstance(value, (list, tuple)): base[key] = type(value)( item.to_dict() if isinstance(item, type(self)) else item for item in value) else: base[key] = value return base def copy(self): return copy.copy(self) def deepcopy(self): return copy.deepcopy(self) def __deepcopy__(self, memo): other = self.__class__() memo[id(self)] = other for key, value in self.items(): other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) return other def update(self, *args, **kwargs): other = {} if 
args: if len(args) > 1: raise TypeError() other.update(args[0]) other.update(kwargs) for k, v in other.items(): if ((k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict))): self[k] = v else: self[k].update(v) def __getnewargs__(self): return tuple(self.items()) def __getstate__(self): return self def __setstate__(self, state): self.update(state) def __or__(self, other): if not isinstance(other, (Dict, dict)): return NotImplemented new = Dict(self) new.update(other) return new def __ror__(self, other): if not isinstance(other, (Dict, dict)): return NotImplemented new = Dict(other) new.update(self) return new def __ior__(self, other): self.update(other) return self def setdefault(self, key, default=None): if key in self: return self[key] else: self[key] = default return default def freeze(self, shouldFreeze=True): object.__setattr__(self, '__frozen', shouldFreeze) for key, val in self.items(): if isinstance(val, Dict): val.freeze(shouldFreeze) def unfreeze(self): self.freeze(False) ================================================ FILE: src/utils/dependencies/XPose/util/box_ops.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Utilities for bounding box manipulation and GIoU. 
""" import torch, os from torchvision.ops.boxes import box_area def box_cxcywh_to_xyxy(x): x_c, y_c, w, h = x.unbind(-1) b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] return torch.stack(b, dim=-1) def box_xyxy_to_cxcywh(x): x0, y0, x1, y1 = x.unbind(-1) b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] return torch.stack(b, dim=-1) # modified from torchvision to also return the union def box_iou(boxes1, boxes2): area1 = box_area(boxes1) area2 = box_area(boxes2) # import ipdb; ipdb.set_trace() lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] wh = (rb - lt).clamp(min=0) # [N,M,2] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] union = area1[:, None] + area2 - inter iou = inter / (union + 1e-6) return iou, union def generalized_box_iou(boxes1, boxes2): """ Generalized IoU from https://giou.stanford.edu/ The boxes should be in [x0, y0, x1, y1] format Returns a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) """ # degenerate boxes gives inf / nan results # so do an early check assert (boxes1[:, 2:] >= boxes1[:, :2]).all() assert (boxes2[:, 2:] >= boxes2[:, :2]).all() # except: # import ipdb; ipdb.set_trace() iou, union = box_iou(boxes1, boxes2) lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) wh = (rb - lt).clamp(min=0) # [N,M,2] area = wh[:, :, 0] * wh[:, :, 1] return iou - (area - union) / (area + 1e-6) # modified from torchvision to also return the union def box_iou_pairwise(boxes1, boxes2): area1 = box_area(boxes1) area2 = box_area(boxes2) lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] wh = (rb - lt).clamp(min=0) # [N,2] inter = wh[:, 0] * wh[:, 1] # [N] union = area1 + area2 - inter iou = inter / union return iou, union def generalized_box_iou_pairwise(boxes1, boxes2): """ Generalized IoU from https://giou.stanford.edu/ Input: - boxes1, 
boxes2: N,4 Output: - giou: N, 4 """ # degenerate boxes gives inf / nan results # so do an early check assert (boxes1[:, 2:] >= boxes1[:, :2]).all() assert (boxes2[:, 2:] >= boxes2[:, :2]).all() assert boxes1.shape == boxes2.shape iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4 lt = torch.min(boxes1[:, :2], boxes2[:, :2]) rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) wh = (rb - lt).clamp(min=0) # [N,2] area = wh[:, 0] * wh[:, 1] return iou - (area - union) / area def masks_to_boxes(masks): """Compute the bounding boxes around the provided masks The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. Returns a [N, 4] tensors, with the boxes in xyxy format """ if masks.numel() == 0: return torch.zeros((0, 4), device=masks.device) h, w = masks.shape[-2:] y = torch.arange(0, h, dtype=torch.float) x = torch.arange(0, w, dtype=torch.float) y, x = torch.meshgrid(y, x) x_mask = (masks * x.unsqueeze(0)) x_max = x_mask.flatten(1).max(-1)[0] x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] y_mask = (masks * y.unsqueeze(0)) y_max = y_mask.flatten(1).max(-1)[0] y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] return torch.stack([x_min, y_min, x_max, y_max], 1) if __name__ == '__main__': x = torch.rand(5, 4) y = torch.rand(3, 4) iou, union = box_iou(x, y) import ipdb; ipdb.set_trace() ================================================ FILE: src/utils/dependencies/XPose/util/config.py ================================================ # ========================================================== # Modified from mmcv # ========================================================== import sys import os.path as osp import ast import tempfile import shutil from importlib import import_module from argparse import Action from .addict import Dict BASE_KEY = '_base_' DELETE_KEY = '_delete_' RESERVED_KEYS = ['filename', 'text', 'pretty_text', 'get', 'dump', 'merge_from_dict'] def 
check_file_exist(filename, msg_tmpl='file "{}" does not exist'): if not osp.isfile(filename): raise FileNotFoundError(msg_tmpl.format(filename)) class ConfigDict(Dict): def __missing__(self, name): raise KeyError(name) def __getattr__(self, name): try: value = super(ConfigDict, self).__getattr__(name) except KeyError: ex = AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{name}'") except Exception as e: ex = e else: return value raise ex class Config(object): """ config files. only support .py file as config now. ref: mmcv.utils.config Example: >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) >>> cfg.a 1 >>> cfg.b {'b1': [0, 1]} >>> cfg.b.b1 [0, 1] >>> cfg = Config.fromfile('tests/data/config/a.py') >>> cfg.filename "/home/kchen/projects/mmcv/tests/data/config/a.py" >>> cfg.item4 'test' >>> cfg "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" """ @staticmethod def _validate_py_syntax(filename): with open(filename) as f: content = f.read() try: ast.parse(content) except SyntaxError: raise SyntaxError('There are syntax errors in config ' f'file {filename}') @staticmethod def _file2dict(filename): filename = osp.abspath(osp.expanduser(filename)) check_file_exist(filename) if filename.lower().endswith('.py'): with tempfile.TemporaryDirectory() as temp_config_dir: temp_config_file = tempfile.NamedTemporaryFile( dir=temp_config_dir, suffix='.py') temp_config_name = osp.basename(temp_config_file.name) # close temp file before copy temp_config_file.close() shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name)) temp_module_name = osp.splitext(temp_config_name)[0] sys.path.insert(0, temp_config_dir) Config._validate_py_syntax(filename) mod = import_module(temp_module_name) sys.path.pop(0) cfg_dict = { name: value for name, value in mod.__dict__.items() if not name.startswith('__') } # delete imported module del sys.modules[temp_module_name] elif 
filename.lower().endswith(('.yml', '.yaml', '.json')): from .slio import slload cfg_dict = slload(filename) else: raise IOError('Only py/yml/yaml/json type are supported now!') cfg_text = filename + '\n' with open(filename, 'r') as f: cfg_text += f.read() # parse the base file if BASE_KEY in cfg_dict: cfg_dir = osp.dirname(filename) base_filename = cfg_dict.pop(BASE_KEY) base_filename = base_filename if isinstance( base_filename, list) else [base_filename] cfg_dict_list = list() cfg_text_list = list() for f in base_filename: _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) cfg_dict_list.append(_cfg_dict) cfg_text_list.append(_cfg_text) base_cfg_dict = dict() for c in cfg_dict_list: if len(base_cfg_dict.keys() & c.keys()) > 0: raise KeyError('Duplicate key is not allowed among bases') # TODO Allow the duplicate key while warnning user base_cfg_dict.update(c) base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) cfg_dict = base_cfg_dict # merge cfg_text cfg_text_list.append(cfg_text) cfg_text = '\n'.join(cfg_text_list) return cfg_dict, cfg_text @staticmethod def _merge_a_into_b(a, b): """merge dict `a` into dict `b` (non-inplace). values in `a` will overwrite `b`. copy first to avoid inplace modification Args: a ([type]): [description] b ([type]): [description] Returns: [dict]: [description] """ # import ipdb; ipdb.set_trace() if not isinstance(a, dict): return a b = b.copy() for k, v in a.items(): if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False): if not isinstance(b[k], dict) and not isinstance(b[k], list): # if : # import ipdb; ipdb.set_trace() raise TypeError( f'{k}={v} in child config cannot inherit from base ' f'because {k} is a dict in the child config but is of ' f'type {type(b[k])} in base config. 
You may set ' f'`{DELETE_KEY}=True` to ignore the base config') b[k] = Config._merge_a_into_b(v, b[k]) elif isinstance(b, list): try: _ = int(k) except: raise TypeError( f'b is a list, ' f'index {k} should be an int when input but {type(k)}' ) b[int(k)] = Config._merge_a_into_b(v, b[int(k)]) else: b[k] = v return b @staticmethod def fromfile(filename): cfg_dict, cfg_text = Config._file2dict(filename) return Config(cfg_dict, cfg_text=cfg_text, filename=filename) def __init__(self, cfg_dict=None, cfg_text=None, filename=None): if cfg_dict is None: cfg_dict = dict() elif not isinstance(cfg_dict, dict): raise TypeError('cfg_dict must be a dict, but ' f'got {type(cfg_dict)}') for key in cfg_dict: if key in RESERVED_KEYS: raise KeyError(f'{key} is reserved for config file') super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict)) super(Config, self).__setattr__('_filename', filename) if cfg_text: text = cfg_text elif filename: with open(filename, 'r') as f: text = f.read() else: text = '' super(Config, self).__setattr__('_text', text) @property def filename(self): return self._filename @property def text(self): return self._text @property def pretty_text(self): indent = 4 def _indent(s_, num_spaces): s = s_.split('\n') if len(s) == 1: return s_ first = s.pop(0) s = [(num_spaces * ' ') + line for line in s] s = '\n'.join(s) s = first + '\n' + s return s def _format_basic_types(k, v, use_mapping=False): if isinstance(v, str): v_str = f"'{v}'" else: v_str = str(v) if use_mapping: k_str = f"'{k}'" if isinstance(k, str) else str(k) attr_str = f'{k_str}: {v_str}' else: attr_str = f'{str(k)}={v_str}' attr_str = _indent(attr_str, indent) return attr_str def _format_list(k, v, use_mapping=False): # check if all items in the list are dict if all(isinstance(_, dict) for _ in v): v_str = '[\n' v_str += '\n'.join( f'dict({_indent(_format_dict(v_), indent)}),' for v_ in v).rstrip(',') if use_mapping: k_str = f"'{k}'" if isinstance(k, str) else str(k) attr_str = f'{k_str}: 
{v_str}' else: attr_str = f'{str(k)}={v_str}' attr_str = _indent(attr_str, indent) + ']' else: attr_str = _format_basic_types(k, v, use_mapping) return attr_str def _contain_invalid_identifier(dict_str): contain_invalid_identifier = False for key_name in dict_str: contain_invalid_identifier |= \ (not str(key_name).isidentifier()) return contain_invalid_identifier def _format_dict(input_dict, outest_level=False): r = '' s = [] use_mapping = _contain_invalid_identifier(input_dict) if use_mapping: r += '{' for idx, (k, v) in enumerate(input_dict.items()): is_last = idx >= len(input_dict) - 1 end = '' if outest_level or is_last else ',' if isinstance(v, dict): v_str = '\n' + _format_dict(v) if use_mapping: k_str = f"'{k}'" if isinstance(k, str) else str(k) attr_str = f'{k_str}: dict({v_str}' else: attr_str = f'{str(k)}=dict({v_str}' attr_str = _indent(attr_str, indent) + ')' + end elif isinstance(v, list): attr_str = _format_list(k, v, use_mapping) + end else: attr_str = _format_basic_types(k, v, use_mapping) + end s.append(attr_str) r += '\n'.join(s) if use_mapping: r += '}' return r cfg_dict = self._cfg_dict.to_dict() text = _format_dict(cfg_dict, outest_level=True) return text def __repr__(self): return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' def __len__(self): return len(self._cfg_dict) def __getattr__(self, name): # # debug # print('+'*15) # print('name=%s' % name) # print("addr:", id(self)) # # print('type(self):', type(self)) # print(self.__dict__) # print('+'*15) # if self.__dict__ == {}: # raise ValueError return getattr(self._cfg_dict, name) def __getitem__(self, name): return self._cfg_dict.__getitem__(name) def __setattr__(self, name, value): if isinstance(value, dict): value = ConfigDict(value) self._cfg_dict.__setattr__(name, value) def __setitem__(self, name, value): if isinstance(value, dict): value = ConfigDict(value) self._cfg_dict.__setitem__(name, value) def __iter__(self): return iter(self._cfg_dict) def dump(self, 
file=None): # import ipdb; ipdb.set_trace() if file is None: return self.pretty_text else: with open(file, 'w') as f: f.write(self.pretty_text) def merge_from_dict(self, options): """Merge list into cfg_dict Merge the dict parsed by MultipleKVAction into this cfg. Examples: >>> options = {'model.backbone.depth': 50, ... 'model.backbone.with_cp':True} >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) >>> cfg.merge_from_dict(options) >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') >>> assert cfg_dict == dict( ... model=dict(backbone=dict(depth=50, with_cp=True))) Args: options (dict): dict of configs to merge from. """ option_cfg_dict = {} for full_key, v in options.items(): d = option_cfg_dict key_list = full_key.split('.') for subkey in key_list[:-1]: d.setdefault(subkey, ConfigDict()) d = d[subkey] subkey = key_list[-1] d[subkey] = v cfg_dict = super(Config, self).__getattribute__('_cfg_dict') super(Config, self).__setattr__( '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict)) # for multiprocess def __setstate__(self, state): self.__init__(state) def copy(self): return Config(self._cfg_dict.copy()) def deepcopy(self): return Config(self._cfg_dict.deepcopy()) class DictAction(Action): """ argparse action to split an argument into KEY=VALUE form on the first = and append to a dictionary. 
List options should be passed as comma separated values, i.e KEY=V1,V2,V3 """ @staticmethod def _parse_int_float_bool(val): try: return int(val) except ValueError: pass try: return float(val) except ValueError: pass if val.lower() in ['true', 'false']: return True if val.lower() == 'true' else False if val.lower() in ['none', 'null']: return None return val def __call__(self, parser, namespace, values, option_string=None): options = {} for kv in values: key, val = kv.split('=', maxsplit=1) val = [self._parse_int_float_bool(v) for v in val.split(',')] if len(val) == 1: val = val[0] options[key] = val setattr(namespace, self.dest, options) ================================================ FILE: src/utils/dependencies/XPose/util/keypoint_ops.py ================================================ import torch, os def keypoint_xyxyzz_to_xyzxyz(keypoints: torch.Tensor): """_summary_ Args: keypoints (torch.Tensor): ..., 51 """ res = torch.zeros_like(keypoints) num_points = keypoints.shape[-1] // 3 Z = keypoints[..., :2*num_points] V = keypoints[..., 2*num_points:] res[...,0::3] = Z[..., 0::2] res[...,1::3] = Z[..., 1::2] res[...,2::3] = V[...] return res def keypoint_xyzxyz_to_xyxyzz(keypoints: torch.Tensor): """_summary_ Args: keypoints (torch.Tensor): ..., 51 """ res = torch.zeros_like(keypoints) num_points = keypoints.shape[-1] // 3 res[...,0:2*num_points:2] = keypoints[..., 0::3] res[...,1:2*num_points:2] = keypoints[..., 1::3] res[...,2*num_points:] = keypoints[..., 2::3] return res ================================================ FILE: src/utils/dependencies/XPose/util/misc.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Misc functions, including distributed helpers. Mostly copy-paste from torchvision references. 
""" import functools import io import os import random import subprocess import time from collections import OrderedDict, defaultdict, deque import datetime import pickle from typing import Optional, List import json, time import numpy as np import torch import torch.distributed as dist from torch import Tensor import colorsys # needed due to empty tensor bug in pytorch and torchvision 0.5 import torchvision __torchvision_need_compat_flag = float(torchvision.__version__.split('.')[1]) < 7 if __torchvision_need_compat_flag: from torchvision.ops import _new_empty_tensor from torchvision.ops.misc import _output_size class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a window or the global series average. """ def __init__(self, window_size=20, fmt=None): if fmt is None: fmt = "{median:.4f} ({global_avg:.4f})" self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 self.fmt = fmt def update(self, value, n=1): self.deque.append(value) self.count += n self.total += value * n def synchronize_between_processes(self): """ Warning: does not synchronize the deque! 
""" if not is_dist_avail_and_initialized(): return t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') dist.barrier() dist.all_reduce(t) t = t.tolist() self.count = int(t[0]) self.total = t[1] @property def median(self): d = torch.tensor(list(self.deque)) if d.shape[0] == 0: return 0 return d.median().item() @property def avg(self): d = torch.tensor(list(self.deque), dtype=torch.float32) return d.mean().item() @property def global_avg(self): if os.environ.get("SHILONG_AMP", None) == '1': eps = 1e-4 else: eps = 1e-6 return self.total / (self.count + eps) @property def max(self): return max(self.deque) @property def value(self): return self.deque[-1] def __str__(self): return self.fmt.format( median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value) @functools.lru_cache() def _get_global_gloo_group(): """ Return a process group based on gloo backend, containing all the ranks The result is cached. """ if dist.get_backend() == "nccl": return dist.new_group(backend="gloo") return dist.group.WORLD def all_gather_cpu(data): """ Run all_gather on arbitrary picklable data (not necessarily tensors) Args: data: any picklable object Returns: list[data]: list of data gathered from each rank """ world_size = get_world_size() if world_size == 1: return [data] cpu_group = _get_global_gloo_group() buffer = io.BytesIO() torch.save(data, buffer) data_view = buffer.getbuffer() device = "cuda" if cpu_group is None else "cpu" tensor = torch.ByteTensor(data_view).to(device) # obtain Tensor size of each rank local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long) size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)] if cpu_group is None: dist.all_gather(size_list, local_size) else: print("gathering on cpu") dist.all_gather(size_list, local_size, group=cpu_group) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) assert 
isinstance(local_size.item(), int) local_size = int(local_size.item()) # receiving Tensor from all ranks # we pad the tensor because torch all_gather does not support # gathering tensors of different shapes tensor_list = [] for _ in size_list: tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) if local_size != max_size: padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device) tensor = torch.cat((tensor, padding), dim=0) if cpu_group is None: dist.all_gather(tensor_list, tensor) else: dist.all_gather(tensor_list, tensor, group=cpu_group) data_list = [] for size, tensor in zip(size_list, tensor_list): tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] buffer = io.BytesIO(tensor.cpu().numpy()) obj = torch.load(buffer) data_list.append(obj) return data_list def all_gather(data): """ Run all_gather on arbitrary picklable data (not necessarily tensors) Args: data: any picklable object Returns: list[data]: list of data gathered from each rank """ if os.getenv("CPU_REDUCE") == "1": return all_gather_cpu(data) world_size = get_world_size() if world_size == 1: return [data] # serialized to a Tensor buffer = pickle.dumps(data) storage = torch.ByteStorage.from_buffer(buffer) tensor = torch.ByteTensor(storage).to("cuda") # obtain Tensor size of each rank local_size = torch.tensor([tensor.numel()], device="cuda") size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) # receiving Tensor from all ranks # we pad the tensor because torch all_gather does not support # gathering tensors of different shapes tensor_list = [] for _ in size_list: tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) if local_size != max_size: padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") tensor = torch.cat((tensor, padding), dim=0) 
dist.all_gather(tensor_list, tensor) data_list = [] for size, tensor in zip(size_list, tensor_list): buffer = tensor.cpu().numpy().tobytes()[:size] data_list.append(pickle.loads(buffer)) return data_list def reduce_dict(input_dict, average=True): """ Args: input_dict (dict): all the values will be reduced average (bool): whether to do average or sum Reduce the values in the dictionary from all processes so that all processes have the averaged results. Returns a dict with the same fields as input_dict, after reduction. """ world_size = get_world_size() if world_size < 2: return input_dict with torch.no_grad(): names = [] values = [] # sort the keys so that they are consistent across processes for k in sorted(input_dict.keys()): names.append(k) values.append(input_dict[k]) values = torch.stack(values, dim=0) dist.all_reduce(values) if average: values /= world_size reduced_dict = {k: v for k, v in zip(names, values)} return reduced_dict class MetricLogger(object): def __init__(self, delimiter="\t"): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter def update(self, **kwargs): for k, v in kwargs.items(): if isinstance(v, torch.Tensor): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) def __getattr__(self, attr): if attr in self.meters: return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): # print(name, str(meter)) # import ipdb;ipdb.set_trace() if meter.count > 0: loss_str.append( "{}: {}".format(name, str(meter)) ) return self.delimiter.join(loss_str) def synchronize_between_processes(self): for meter in self.meters.values(): meter.synchronize_between_processes() def add_meter(self, name, meter): self.meters[name] = meter def log_every(self, iterable, print_freq, header=None, logger=None): if logger is None: print_func = print else: 
print_func = logger.info i = 0 if not header: header = '' start_time = time.time() end = time.time() iter_time = SmoothedValue(fmt='{avg:.4f}') data_time = SmoothedValue(fmt='{avg:.4f}') space_fmt = ':' + str(len(str(len(iterable)))) + 'd' if torch.cuda.is_available(): log_msg = self.delimiter.join([ header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}', 'max mem: {memory:.0f}' ]) else: log_msg = self.delimiter.join([ header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}' ]) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) yield obj # import ipdb; ipdb.set_trace() iter_time.update(time.time() - end) if i % print_freq == 0 or i == len(iterable) - 1: eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): print_func(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), memory=torch.cuda.max_memory_allocated() / MB)) else: print_func(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time))) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print_func('{} Total time: {} ({:.4f} s / it)'.format( header, total_time_str, total_time / len(iterable))) def get_sha(): cwd = os.path.dirname(os.path.abspath(__file__)) def _run(command): return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() sha = 'N/A' diff = "clean" branch = 'N/A' try: sha = _run(['git', 'rev-parse', 'HEAD']) subprocess.check_output(['git', 'diff'], cwd=cwd) diff = _run(['git', 'diff-index', 'HEAD']) diff = "has uncommited changes" if diff else "clean" branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) except Exception: pass message = f"sha: {sha}, status: {diff}, branch: {branch}" return message 
def collate_fn(batch):
    """Collate a list of per-sample tuples for a DataLoader.

    The samples are transposed into per-field tuples, and the first field is
    packed into a padded :class:`NestedTensor` via
    :func:`nested_tensor_from_tensor_list`.

    Args:
        batch: sequence of equally-sized sample tuples.

    Returns:
        tuple: ``(NestedTensor, field_1, field_2, ...)``.
    """
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    """Return the element-wise maximum over a list of equal-length int lists.

    Works on a copy of the first sublist, so none of the input lists are
    modified.
    """
    maxes = list(the_list[0])  # copy: never write into the caller's sublist
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    """A batch of padded tensors bundled with a padding mask.

    As built by :func:`nested_tensor_from_tensor_list`, ``mask`` is True on
    padded positions and False on valid ones.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        """Wrap ``tensors`` and ``mask``.

        Args:
            tensors: 3-D or 4-D tensor.
            mask: padding mask, ``None``, or the string ``'auto'``. ``'auto'``
                builds an all-False mask from the shape of ``tensors`` by
                reducing over dim 0 (3-D input) or dim 1 (4-D input).

        Raises:
            ValueError: if ``mask == 'auto'`` and ``tensors`` is not 3-D or 4-D.
        """
        self.tensors = tensors
        self.mask = mask
        # Only a real string can request the automatic mask; avoid comparing
        # a Tensor mask against a str.
        if isinstance(mask, str) and mask == 'auto':
            self.mask = torch.zeros_like(tensors).to(tensors.device)
            if self.mask.dim() == 3:
                self.mask = self.mask.sum(0).to(bool)
            elif self.mask.dim() == 4:
                self.mask = self.mask.sum(1).to(bool)
            else:
                raise ValueError("tensors dim must be 3 or 4 but {}({})".format(self.tensors.dim(), self.tensors.shape))

    def imgsize(self):
        """Return, per batch item, a tensor ``[maxH, maxW]`` of the unpadded extent."""
        res = []
        for i in range(self.tensors.shape[0]):
            mask = self.mask[i]
            maxH = (~mask).sum(0).max()
            maxW = (~mask).sum(1).max()
            res.append(torch.Tensor([maxH, maxW]))
        return res

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with tensors (and mask, if any) moved to ``device``."""
        cast_tensor = self.tensors.to(device)
        cast_mask = self.mask.to(device) if self.mask is not None else None
        return NestedTensor(cast_tensor, cast_mask)

    def to_img_list_single(self, tensor, mask):
        """Crop one 3-D ``tensor`` to the valid (non-padded) region given by ``mask``."""
        assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
        maxH = (~mask).sum(0).max()
        maxW = (~mask).sum(1).max()
        img = tensor[:, :maxH, :maxW]
        return img

    def to_img_list(self):
        """Remove the padding and convert to a list of images.

        Returns:
            A single cropped tensor when ``self.tensors`` is 3-D, otherwise a
            list with one cropped tensor per batch item.
        """
        if self.tensors.dim() == 3:
            return self.to_img_list_single(self.tensors, self.mask)
        else:
            res = []
            for i in range(self.tensors.shape[0]):
                tensor_i = self.tensors[i]
                mask_i = self.mask[i]
                res.append(self.to_img_list_single(tensor_i, mask_i))
            return res

    @property
    def device(self):
        return self.tensors.device

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

    @property
    def shape(self):
        """Dict with the shapes of the wrapped tensors and mask."""
        return {
            'tensors.shape': self.tensors.shape,
            'mask.shape': self.mask.shape
        }


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Pad a list of 3-D tensors to a common size and build the padding mask.

    Every tensor is copied into the top-left corner of a zero tensor of the
    per-axis maximum shape. The mask is True everywhere except the copied
    region.

    Raises:
        ValueError: if the tensors are not 3-D.
    """
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)


# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: max_size = [] for i in range(tensor_list[0].dim()): max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) max_size.append(max_size_i) max_size = tuple(max_size) # work around for # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) # m[: img.shape[1], :img.shape[2]] = False # which is not yet supported in onnx padded_imgs = [] padded_masks = [] for img in tensor_list: padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) padded_imgs.append(padded_img) m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) padded_masks.append(padded_mask.to(torch.bool)) tensor = torch.stack(padded_imgs) mask = torch.stack(padded_masks) return NestedTensor(tensor, mask=mask) def setup_for_distributed(is_master): """ This function disables printing when not in master process """ import builtins as __builtin__ builtin_print = __builtin__.print def print(*args, **kwargs): force = kwargs.pop('force', False) if is_master or force: builtin_print(*args, **kwargs) __builtin__.print = print def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_world_size(): if not is_dist_avail_and_initialized(): return 1 return dist.get_world_size() def get_rank(): if not is_dist_avail_and_initialized(): return 0 return dist.get_rank() def is_main_process(): return get_rank() == 0 def save_on_master(*args, **kwargs): if is_main_process(): torch.save(*args, **kwargs) def init_distributed_mode(args): if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '': # 'RANK' in os.environ and args.rank = int(os.environ["RANK"]) args.world_size = 
int(os.environ['WORLD_SIZE']) args.gpu = args.local_rank = int(os.environ['LOCAL_RANK']) # launch by torch.distributed.launch # Single node # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ... # Multi nodes # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... # args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK')) # local_world_size = int(os.environ['GPU_PER_NODE_COUNT']) # args.world_size = args.world_size * local_world_size # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK']) # args.rank = args.rank * local_world_size + args.local_rank print('world size: {}, rank: {}, local rank: {}'.format(args.world_size, args.rank, args.local_rank)) print(json.dumps(dict(os.environ), indent=2)) elif 'SLURM_PROCID' in os.environ: args.rank = int(os.environ['SLURM_PROCID']) args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID']) args.world_size = int(os.environ['SLURM_NPROCS']) if os.environ.get('HAND_DEFINE_DIST_URL', 0) == '1': pass else: import util.hostlist as uh nodenames = uh.parse_nodelist(os.environ['SLURM_JOB_NODELIST']) gpu_ids = [int(node[3:]) for node in nodenames] fixid = int(os.environ.get('FIX_DISTRIBUTED_PORT_NUMBER', 0)) # fixid += random.randint(0, 300) port = str(3137 + int(min(gpu_ids)) + fixid) args.dist_url = "tcp://{ip}:{port}".format(ip=uh.nodename_to_ip(nodenames[0]), port=port) print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(args.world_size, args.rank, args.local_rank, torch.cuda.device_count())) else: print('Not using distributed mode') args.distributed = False args.world_size = 1 args.rank = 0 args.local_rank = 0 return print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank)) args.distributed = True 
torch.cuda.set_device(args.local_rank) args.dist_backend = 'nccl' print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True) torch.distributed.init_process_group( backend=args.dist_backend, world_size=args.world_size, rank=args.rank, init_method=args.dist_url, ) print("Before torch.distributed.barrier()") torch.distributed.barrier() print("End torch.distributed.barrier()") setup_for_distributed(args.rank == 0) @torch.no_grad() def accuracy(output, target, topk=(1,)): """Computes the precision@k for the specified values of k""" if target.numel() == 0: return [torch.zeros([], device=output.device)] maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0) res.append(correct_k.mul_(100.0 / batch_size)) return res @torch.no_grad() def accuracy_onehot(pred, gt): """_summary_ Args: pred (_type_): n, c gt (_type_): n, c """ tp = ((pred - gt).abs().sum(-1) < 1e-4).float().sum() acc = tp / gt.shape[0] * 100 return acc def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor """ Equivalent to nn.functional.interpolate, but with support for empty batch sizes. This will eventually be supported natively by PyTorch, and this class can go away. 
""" if __torchvision_need_compat_flag < 0.7: if input.numel() > 0: return torch.nn.functional.interpolate( input, size, scale_factor, mode, align_corners ) output_shape = _output_size(2, input, size, scale_factor) output_shape = list(input.shape[:-2]) + list(output_shape) return _new_empty_tensor(input, output_shape) else: return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) class color_sys(): def __init__(self, num_colors) -> None: self.num_colors = num_colors colors=[] for i in np.arange(0., 360., 360. / num_colors): hue = i/360. lightness = (50 + np.random.rand() * 10)/100. saturation = (90 + np.random.rand() * 10)/100. colors.append(tuple([int(j*255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)])) self.colors = colors def __call__(self, idx): return self.colors[idx] def inverse_sigmoid(x, eps=1e-3): x = x.clamp(min=0, max=1) x1 = x.clamp(min=eps) x2 = (1 - x).clamp(min=eps) return torch.log(x1/x2) def clean_state_dict(state_dict): new_state_dict = OrderedDict() for k, v in state_dict.items(): if k[:7] == 'module.': k = k[7:] # remove `module.` new_state_dict[k] = v return new_state_dict ================================================ FILE: src/utils/dependencies/insightface/__init__.py ================================================ # coding: utf-8 # pylint: disable=wrong-import-position """InsightFace: A Face Analysis Toolkit.""" from __future__ import absolute_import try: #import mxnet as mx import onnxruntime except ImportError: raise ImportError( "Unable to import dependency onnxruntime. " ) __version__ = '0.7.3' from . import model_zoo from . import utils from . import app from . 
import data ================================================ FILE: src/utils/dependencies/insightface/app/__init__.py ================================================ from .face_analysis import * ================================================ FILE: src/utils/dependencies/insightface/app/common.py ================================================ import numpy as np from numpy.linalg import norm as l2norm #from easydict import EasyDict class Face(dict): def __init__(self, d=None, **kwargs): if d is None: d = {} if kwargs: d.update(**kwargs) for k, v in d.items(): setattr(self, k, v) # Class attributes #for k in self.__class__.__dict__.keys(): # if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): # setattr(self, k, getattr(self, k)) def __setattr__(self, name, value): if isinstance(value, (list, tuple)): value = [self.__class__(x) if isinstance(x, dict) else x for x in value] elif isinstance(value, dict) and not isinstance(value, self.__class__): value = self.__class__(value) super(Face, self).__setattr__(name, value) super(Face, self).__setitem__(name, value) __setitem__ = __setattr__ def __getattr__(self, name): return None @property def embedding_norm(self): if self.embedding is None: return None return l2norm(self.embedding) @property def normed_embedding(self): if self.embedding is None: return None return self.embedding / self.embedding_norm @property def sex(self): if self.gender is None: return None return 'M' if self.gender==1 else 'F' ================================================ FILE: src/utils/dependencies/insightface/app/face_analysis.py ================================================ # -*- coding: utf-8 -*- # @Organization : insightface.ai # @Author : Jia Guo # @Time : 2021-05-04 # @Function : from __future__ import division import glob import os.path as osp import numpy as np import onnxruntime from numpy.linalg import norm from ..model_zoo import model_zoo from ..utils import ensure_available from .common import Face 
DEFAULT_MP_NAME = 'buffalo_l' __all__ = ['FaceAnalysis'] class FaceAnalysis: def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs): onnxruntime.set_default_logger_severity(3) self.models = {} self.model_dir = ensure_available('models', name, root=root) onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx')) onnx_files = sorted(onnx_files) for onnx_file in onnx_files: model = model_zoo.get_model(onnx_file, **kwargs) if model is None: print('model not recognized:', onnx_file) elif allowed_modules is not None and model.taskname not in allowed_modules: print('model ignore:', onnx_file, model.taskname) del model elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules): # print('find model:', onnx_file, model.taskname, model.input_shape, model.input_mean, model.input_std) self.models[model.taskname] = model else: print('duplicated model task type, ignore:', onnx_file, model.taskname) del model assert 'detection' in self.models self.det_model = self.models['detection'] def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)): self.det_thresh = det_thresh assert det_size is not None # print('set det-size:', det_size) self.det_size = det_size for taskname, model in self.models.items(): if taskname=='detection': model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh) else: model.prepare(ctx_id) def get(self, img, max_num=0): bboxes, kpss = self.det_model.detect(img, max_num=max_num, metric='default') if bboxes.shape[0] == 0: return [] ret = [] for i in range(bboxes.shape[0]): bbox = bboxes[i, 0:4] det_score = bboxes[i, 4] kps = None if kpss is not None: kps = kpss[i] face = Face(bbox=bbox, kps=kps, det_score=det_score) for taskname, model in self.models.items(): if taskname=='detection': continue model.get(img, face) ret.append(face) return ret def draw_on(self, img, faces): import cv2 dimg = img.copy() for i in range(len(faces)): face = faces[i] box = 
face.bbox.astype(np.int) color = (0, 0, 255) cv2.rectangle(dimg, (box[0], box[1]), (box[2], box[3]), color, 2) if face.kps is not None: kps = face.kps.astype(np.int) #print(landmark.shape) for l in range(kps.shape[0]): color = (0, 0, 255) if l == 0 or l == 3: color = (0, 255, 0) cv2.circle(dimg, (kps[l][0], kps[l][1]), 1, color, 2) if face.gender is not None and face.age is not None: cv2.putText(dimg,'%s,%d'%(face.sex,face.age), (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1) #for key, value in face.items(): # if key.startswith('landmark_3d'): # print(key, value.shape) # print(value[0:10,:]) # lmk = np.round(value).astype(np.int) # for l in range(lmk.shape[0]): # color = (255, 0, 0) # cv2.circle(dimg, (lmk[l][0], lmk[l][1]), 1, color, # 2) return dimg ================================================ FILE: src/utils/dependencies/insightface/data/__init__.py ================================================ from .image import get_image from .pickle_object import get_object ================================================ FILE: src/utils/dependencies/insightface/data/image.py ================================================ import cv2 import os import os.path as osp from pathlib import Path class ImageCache: data = {} def get_image(name, to_rgb=False): key = (name, to_rgb) if key in ImageCache.data: return ImageCache.data[key] images_dir = osp.join(Path(__file__).parent.absolute(), 'images') ext_names = ['.jpg', '.png', '.jpeg'] image_file = None for ext_name in ext_names: _image_file = osp.join(images_dir, "%s%s"%(name, ext_name)) if osp.exists(_image_file): image_file = _image_file break assert image_file is not None, '%s not found'%name img = cv2.imread(image_file) if to_rgb: img = img[:,:,::-1] ImageCache.data[key] = img return img ================================================ FILE: src/utils/dependencies/insightface/data/pickle_object.py ================================================ import cv2 import os import os.path as osp from pathlib 
import Path import pickle def get_object(name): objects_dir = osp.join(Path(__file__).parent.absolute(), 'objects') if not name.endswith('.pkl'): name = name+".pkl" filepath = osp.join(objects_dir, name) if not osp.exists(filepath): return None with open(filepath, 'rb') as f: obj = pickle.load(f) return obj ================================================ FILE: src/utils/dependencies/insightface/data/rec_builder.py ================================================ import pickle import numpy as np import os import os.path as osp import sys import mxnet as mx class RecBuilder(): def __init__(self, path, image_size=(112, 112)): self.path = path self.image_size = image_size self.widx = 0 self.wlabel = 0 self.max_label = -1 assert not osp.exists(path), '%s exists' % path os.makedirs(path) self.writer = mx.recordio.MXIndexedRecordIO(os.path.join(path, 'train.idx'), os.path.join(path, 'train.rec'), 'w') self.meta = [] def add(self, imgs): #!!! img should be BGR!!!! #assert label >= 0 #assert label > self.last_label assert len(imgs) > 0 label = self.wlabel for img in imgs: idx = self.widx image_meta = {'image_index': idx, 'image_classes': [label]} header = mx.recordio.IRHeader(0, label, idx, 0) if isinstance(img, np.ndarray): s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg') else: s = mx.recordio.pack(header, img) self.writer.write_idx(idx, s) self.meta.append(image_meta) self.widx += 1 self.max_label = label self.wlabel += 1 def add_image(self, img, label): #!!! img should be BGR!!!! 
#assert label >= 0 #assert label > self.last_label idx = self.widx header = mx.recordio.IRHeader(0, label, idx, 0) if isinstance(label, list): idlabel = label[0] else: idlabel = label image_meta = {'image_index': idx, 'image_classes': [idlabel]} if isinstance(img, np.ndarray): s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg') else: s = mx.recordio.pack(header, img) self.writer.write_idx(idx, s) self.meta.append(image_meta) self.widx += 1 self.max_label = max(self.max_label, idlabel) def close(self): with open(osp.join(self.path, 'train.meta'), 'wb') as pfile: pickle.dump(self.meta, pfile, protocol=pickle.HIGHEST_PROTOCOL) print('stat:', self.widx, self.wlabel) with open(os.path.join(self.path, 'property'), 'w') as f: f.write("%d,%d,%d\n" % (self.max_label+1, self.image_size[0], self.image_size[1])) f.write("%d\n" % (self.widx)) ================================================ FILE: src/utils/dependencies/insightface/model_zoo/__init__.py ================================================ from .model_zoo import get_model from .arcface_onnx import ArcFaceONNX from .retinaface import RetinaFace from .scrfd import SCRFD from .landmark import Landmark from .attribute import Attribute ================================================ FILE: src/utils/dependencies/insightface/model_zoo/arcface_onnx.py ================================================ # -*- coding: utf-8 -*- # @Organization : insightface.ai # @Author : Jia Guo # @Time : 2021-05-04 # @Function : from __future__ import division import numpy as np import cv2 import onnx import onnxruntime from ..utils import face_align __all__ = [ 'ArcFaceONNX', ] class ArcFaceONNX: def __init__(self, model_file=None, session=None): assert model_file is not None self.model_file = model_file self.session = session self.taskname = 'recognition' find_sub = False find_mul = False model = onnx.load(self.model_file) graph = model.graph for nid, node in enumerate(graph.node[:8]): #print(nid, node.name) if 
node.name.startswith('Sub') or node.name.startswith('_minus'): find_sub = True if node.name.startswith('Mul') or node.name.startswith('_mul'): find_mul = True if find_sub and find_mul: #mxnet arcface model input_mean = 0.0 input_std = 1.0 else: input_mean = 127.5 input_std = 127.5 self.input_mean = input_mean self.input_std = input_std #print('input mean and std:', self.input_mean, self.input_std) if self.session is None: self.session = onnxruntime.InferenceSession(self.model_file, None) input_cfg = self.session.get_inputs()[0] input_shape = input_cfg.shape input_name = input_cfg.name self.input_size = tuple(input_shape[2:4][::-1]) self.input_shape = input_shape outputs = self.session.get_outputs() output_names = [] for out in outputs: output_names.append(out.name) self.input_name = input_name self.output_names = output_names assert len(self.output_names)==1 self.output_shape = outputs[0].shape def prepare(self, ctx_id, **kwargs): if ctx_id<0: self.session.set_providers(['CPUExecutionProvider']) def get(self, img, face): aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0]) face.embedding = self.get_feat(aimg).flatten() return face.embedding def compute_sim(self, feat1, feat2): from numpy.linalg import norm feat1 = feat1.ravel() feat2 = feat2.ravel() sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2)) return sim def get_feat(self, imgs): if not isinstance(imgs, list): imgs = [imgs] input_size = self.input_size blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) net_out = self.session.run(self.output_names, {self.input_name: blob})[0] return net_out def forward(self, batch_data): blob = (batch_data - self.input_mean) / self.input_std net_out = self.session.run(self.output_names, {self.input_name: blob})[0] return net_out ================================================ FILE: src/utils/dependencies/insightface/model_zoo/attribute.py 
================================================ # -*- coding: utf-8 -*- # @Organization : insightface.ai # @Author : Jia Guo # @Time : 2021-06-19 # @Function : from __future__ import division import numpy as np import cv2 import onnx import onnxruntime from ..utils import face_align __all__ = [ 'Attribute', ] class Attribute: def __init__(self, model_file=None, session=None): assert model_file is not None self.model_file = model_file self.session = session find_sub = False find_mul = False model = onnx.load(self.model_file) graph = model.graph for nid, node in enumerate(graph.node[:8]): #print(nid, node.name) if node.name.startswith('Sub') or node.name.startswith('_minus'): find_sub = True if node.name.startswith('Mul') or node.name.startswith('_mul'): find_mul = True if nid<3 and node.name=='bn_data': find_sub = True find_mul = True if find_sub and find_mul: #mxnet arcface model input_mean = 0.0 input_std = 1.0 else: input_mean = 127.5 input_std = 128.0 self.input_mean = input_mean self.input_std = input_std #print('input mean and std:', model_file, self.input_mean, self.input_std) if self.session is None: self.session = onnxruntime.InferenceSession(self.model_file, None) input_cfg = self.session.get_inputs()[0] input_shape = input_cfg.shape input_name = input_cfg.name self.input_size = tuple(input_shape[2:4][::-1]) self.input_shape = input_shape outputs = self.session.get_outputs() output_names = [] for out in outputs: output_names.append(out.name) self.input_name = input_name self.output_names = output_names assert len(self.output_names)==1 output_shape = outputs[0].shape #print('init output_shape:', output_shape) if output_shape[1]==3: self.taskname = 'genderage' else: self.taskname = 'attribute_%d'%output_shape[1] def prepare(self, ctx_id, **kwargs): if ctx_id<0: self.session.set_providers(['CPUExecutionProvider']) def get(self, img, face): bbox = face.bbox w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1]) center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 
2 rotate = 0 _scale = self.input_size[0] / (max(w, h)*1.5) #print('param:', img.shape, bbox, center, self.input_size, _scale, rotate) aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate) input_size = tuple(aimg.shape[0:2][::-1]) #assert input_size==self.input_size blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) pred = self.session.run(self.output_names, {self.input_name : blob})[0][0] if self.taskname=='genderage': assert len(pred)==3 gender = np.argmax(pred[:2]) age = int(np.round(pred[2]*100)) face['gender'] = gender face['age'] = age return gender, age else: return pred ================================================ FILE: src/utils/dependencies/insightface/model_zoo/inswapper.py ================================================ import time import numpy as np import onnxruntime import cv2 import onnx from onnx import numpy_helper from ..utils import face_align class INSwapper(): def __init__(self, model_file=None, session=None): self.model_file = model_file self.session = session model = onnx.load(self.model_file) graph = model.graph self.emap = numpy_helper.to_array(graph.initializer[-1]) self.input_mean = 0.0 self.input_std = 255.0 #print('input mean and std:', model_file, self.input_mean, self.input_std) if self.session is None: self.session = onnxruntime.InferenceSession(self.model_file, None) inputs = self.session.get_inputs() self.input_names = [] for inp in inputs: self.input_names.append(inp.name) outputs = self.session.get_outputs() output_names = [] for out in outputs: output_names.append(out.name) self.output_names = output_names assert len(self.output_names)==1 output_shape = outputs[0].shape input_cfg = inputs[0] input_shape = input_cfg.shape self.input_shape = input_shape # print('inswapper-shape:', self.input_shape) self.input_size = tuple(input_shape[2:4][::-1]) def forward(self, img, latent): img = (img - self.input_mean) / 
self.input_std pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0] return pred def get(self, img, target_face, source_face, paste_back=True): face_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8) cv2.fillPoly(face_mask, np.array([target_face.landmark_2d_106[[1,9,10,11,12,13,14,15,16,2,3,4,5,6,7,8,0,24,23,22,21,20,19,18,32,31,30,29,28,27,26,25,17,101,105,104,103,51,49,48,43]].astype('int64')]), 1) aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0]) blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) latent = source_face.normed_embedding.reshape((1,-1)) latent = np.dot(latent, self.emap) latent /= np.linalg.norm(latent) pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0] #print(latent.shape, latent.dtype, pred.shape) img_fake = pred.transpose((0,2,3,1))[0] bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1] if not paste_back: return bgr_fake, M else: target_img = img fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32) fake_diff = np.abs(fake_diff).mean(axis=2) fake_diff[:2,:] = 0 fake_diff[-2:,:] = 0 fake_diff[:,:2] = 0 fake_diff[:,-2:] = 0 IM = cv2.invertAffineTransform(M) img_white = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32) bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) img_white[img_white>20] = 255 fthresh = 10 fake_diff[fake_diff=fthresh] = 255 img_mask = img_white mask_h_inds, mask_w_inds = np.where(img_mask==255) mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) mask_size = 
int(np.sqrt(mask_h*mask_w)) k = max(mask_size//10, 10) #k = max(mask_size//20, 6) #k = 6 kernel = np.ones((k,k),np.uint8) img_mask = cv2.erode(img_mask,kernel,iterations = 1) kernel = np.ones((2,2),np.uint8) fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1) face_mask = cv2.erode(face_mask,np.ones((11,11),np.uint8),iterations = 1) fake_diff[face_mask==1] = 255 k = max(mask_size//20, 5) #k = 3 #k = 3 kernel_size = (k, k) blur_size = tuple(2*i+1 for i in kernel_size) img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) k = 5 kernel_size = (k, k) blur_size = tuple(2*i+1 for i in kernel_size) fake_diff = cv2.blur(fake_diff, (11,11), 0) ##fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0) # print('blur_size: ', blur_size) # fake_diff = cv2.blur(fake_diff, (21, 21), 0) # blur_size img_mask /= 255 fake_diff /= 255 # img_mask = fake_diff img_mask = img_mask*fake_diff img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32) fake_merged = fake_merged.astype(np.uint8) return fake_merged ================================================ FILE: src/utils/dependencies/insightface/model_zoo/landmark.py ================================================ # -*- coding: utf-8 -*- # @Organization : insightface.ai # @Author : Jia Guo # @Time : 2021-05-04 # @Function : from __future__ import division import numpy as np import cv2 import onnx import onnxruntime from ..utils import face_align from ..utils import transform from ..data import get_object __all__ = [ 'Landmark', ] class Landmark: def __init__(self, model_file=None, session=None): assert model_file is not None self.model_file = model_file self.session = session find_sub = False find_mul = False model = onnx.load(self.model_file) graph = model.graph for nid, node in enumerate(graph.node[:8]): #print(nid, node.name) if node.name.startswith('Sub') or node.name.startswith('_minus'): find_sub = True if node.name.startswith('Mul') or 
node.name.startswith('_mul'): find_mul = True if nid<3 and node.name=='bn_data': find_sub = True find_mul = True if find_sub and find_mul: #mxnet arcface model input_mean = 0.0 input_std = 1.0 else: input_mean = 127.5 input_std = 128.0 self.input_mean = input_mean self.input_std = input_std #print('input mean and std:', model_file, self.input_mean, self.input_std) if self.session is None: self.session = onnxruntime.InferenceSession(self.model_file, None) input_cfg = self.session.get_inputs()[0] input_shape = input_cfg.shape input_name = input_cfg.name self.input_size = tuple(input_shape[2:4][::-1]) self.input_shape = input_shape outputs = self.session.get_outputs() output_names = [] for out in outputs: output_names.append(out.name) self.input_name = input_name self.output_names = output_names assert len(self.output_names)==1 output_shape = outputs[0].shape self.require_pose = False #print('init output_shape:', output_shape) if output_shape[1]==3309: self.lmk_dim = 3 self.lmk_num = 68 self.mean_lmk = get_object('meanshape_68.pkl') self.require_pose = True else: self.lmk_dim = 2 self.lmk_num = output_shape[1]//self.lmk_dim self.taskname = 'landmark_%dd_%d'%(self.lmk_dim, self.lmk_num) def prepare(self, ctx_id, **kwargs): if ctx_id<0: self.session.set_providers(['CPUExecutionProvider']) def get(self, img, face): bbox = face.bbox w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1]) center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2 rotate = 0 _scale = self.input_size[0] / (max(w, h)*1.5) #print('param:', img.shape, bbox, center, self.input_size, _scale, rotate) aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate) input_size = tuple(aimg.shape[0:2][::-1]) #assert input_size==self.input_size blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) pred = self.session.run(self.output_names, {self.input_name : blob})[0][0] if pred.shape[0] >= 3000: pred = 
pred.reshape((-1, 3)) else: pred = pred.reshape((-1, 2)) if self.lmk_num < pred.shape[0]: pred = pred[self.lmk_num*-1:,:] pred[:, 0:2] += 1 pred[:, 0:2] *= (self.input_size[0] // 2) if pred.shape[1] == 3: pred[:, 2] *= (self.input_size[0] // 2) IM = cv2.invertAffineTransform(M) pred = face_align.trans_points(pred, IM) face[self.taskname] = pred if self.require_pose: P = transform.estimate_affine_matrix_3d23d(self.mean_lmk, pred) s, R, t = transform.P2sRt(P) rx, ry, rz = transform.matrix2angle(R) pose = np.array( [rx, ry, rz], dtype=np.float32 ) face['pose'] = pose #pitch, yaw, roll return pred ================================================ FILE: src/utils/dependencies/insightface/model_zoo/model_store.py ================================================ """ This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py """ from __future__ import print_function __all__ = ['get_model_file'] import os import zipfile import glob from ..utils import download, check_sha1 _model_sha1 = { name: checksum for checksum, name in [ ('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'), ('', 'arcface_mfn_v1'), ('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'), ('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'), ('0db1d07921d005e6c9a5b38e059452fc5645e5a4', 'retinaface_mnet025_v2'), ('7dd8111652b7aac2490c5dcddeb268e53ac643e6', 'genderage_v1'), ] } base_repo_url = 'https://insightface.ai/files/' _url_format = '{repo_url}models/{file_name}.zip' def short_hash(name): if name not in _model_sha1: raise ValueError( 'Pretrained model for {name} is not available.'.format(name=name)) return _model_sha1[name][:8] def find_params_file(dir_path): if not os.path.exists(dir_path): return None paths = glob.glob("%s/*.params" % dir_path) if len(paths) == 0: return None paths = sorted(paths) return paths[-1] def get_model_file(name, root=os.path.join('~', '.insightface', 'models')): r"""Return location 
for the pretrained on local file system. This function will download from online model zoo when model cannot be found or has mismatch. The root directory will be created if it doesn't exist. Parameters ---------- name : str Name of the model. root : str, default '~/.mxnet/models' Location for keeping the model parameters. Returns ------- file_path Path to the requested pretrained model file. """ file_name = name root = os.path.expanduser(root) dir_path = os.path.join(root, name) file_path = find_params_file(dir_path) #file_path = os.path.join(root, file_name + '.params') sha1_hash = _model_sha1[name] if file_path is not None: if check_sha1(file_path, sha1_hash): return file_path else: print( 'Mismatch in the content of model file detected. Downloading again.' ) else: print('Model file is not found. Downloading.') if not os.path.exists(root): os.makedirs(root) if not os.path.exists(dir_path): os.makedirs(dir_path) zip_file_path = os.path.join(root, file_name + '.zip') repo_url = base_repo_url if repo_url[-1] != '/': repo_url = repo_url + '/' download(_url_format.format(repo_url=repo_url, file_name=file_name), path=zip_file_path, overwrite=True) with zipfile.ZipFile(zip_file_path) as zf: zf.extractall(dir_path) os.remove(zip_file_path) file_path = find_params_file(dir_path) if check_sha1(file_path, sha1_hash): return file_path else: raise ValueError( 'Downloaded file has different hash. 
Please try again.') ================================================ FILE: src/utils/dependencies/insightface/model_zoo/model_zoo.py ================================================ # -*- coding: utf-8 -*- # @Organization : insightface.ai # @Author : Jia Guo # @Time : 2021-05-04 # @Function : import os import os.path as osp import glob import onnxruntime from .arcface_onnx import * from .retinaface import * #from .scrfd import * from .landmark import * from .attribute import Attribute from .inswapper import INSwapper from ..utils import download_onnx __all__ = ['get_model'] class PickableInferenceSession(onnxruntime.InferenceSession): # This is a wrapper to make the current InferenceSession class pickable. def __init__(self, model_path, **kwargs): super().__init__(model_path, **kwargs) self.model_path = model_path def __getstate__(self): return {'model_path': self.model_path} def __setstate__(self, values): model_path = values['model_path'] self.__init__(model_path) class ModelRouter: def __init__(self, onnx_file): self.onnx_file = onnx_file def get_model(self, **kwargs): session = PickableInferenceSession(self.onnx_file, **kwargs) # print(f'Applied providers: {session._providers}, with options: {session._provider_options}') inputs = session.get_inputs() input_cfg = inputs[0] input_shape = input_cfg.shape outputs = session.get_outputs() if len(outputs)>=5: return RetinaFace(model_file=self.onnx_file, session=session) elif input_shape[2]==192 and input_shape[3]==192: return Landmark(model_file=self.onnx_file, session=session) elif input_shape[2]==96 and input_shape[3]==96: return Attribute(model_file=self.onnx_file, session=session) elif len(inputs)==2 and input_shape[2]==128 and input_shape[3]==128: return INSwapper(model_file=self.onnx_file, session=session) elif input_shape[2]==input_shape[3] and input_shape[2]>=112 and input_shape[2]%16==0: return ArcFaceONNX(model_file=self.onnx_file, session=session) else: #raise RuntimeError('error on model routing') return 
None def find_onnx_file(dir_path): if not os.path.exists(dir_path): return None paths = glob.glob("%s/*.onnx" % dir_path) if len(paths) == 0: return None paths = sorted(paths) return paths[-1] def get_default_providers(): return ['CUDAExecutionProvider', 'CoreMLExecutionProvider', 'CPUExecutionProvider'] def get_default_provider_options(): return None def get_model(name, **kwargs): root = kwargs.get('root', '~/.insightface') root = os.path.expanduser(root) model_root = osp.join(root, 'models') allow_download = kwargs.get('download', False) download_zip = kwargs.get('download_zip', False) if not name.endswith('.onnx'): model_dir = os.path.join(model_root, name) model_file = find_onnx_file(model_dir) if model_file is None: return None else: model_file = name if not osp.exists(model_file) and allow_download: model_file = download_onnx('models', model_file, root=root, download_zip=download_zip) assert osp.exists(model_file), 'model_file %s should exist'%model_file assert osp.isfile(model_file), 'model_file %s should be a file'%model_file router = ModelRouter(model_file) providers = kwargs.get('providers', get_default_providers()) provider_options = kwargs.get('provider_options', get_default_provider_options()) model = router.get_model(providers=providers, provider_options=provider_options) return model ================================================ FILE: src/utils/dependencies/insightface/model_zoo/retinaface.py ================================================ # -*- coding: utf-8 -*- # @Organization : insightface.ai # @Author : Jia Guo # @Time : 2021-09-18 # @Function : from __future__ import division import datetime import numpy as np import onnx import onnxruntime import os import os.path as osp import cv2 import sys def softmax(z): assert len(z.shape) == 2 s = np.max(z, axis=1) s = s[:, np.newaxis] # necessary step to do broadcasting e_x = np.exp(z - s) div = np.sum(e_x, axis=1) div = div[:, np.newaxis] # dito return e_x / div def distance2bbox(points, distance, 
max_shape=None): """Decode distance prediction to bounding box. Args: points (Tensor): Shape (n, 2), [x, y]. distance (Tensor): Distance from the given point to 4 boundaries (left, top, right, bottom). max_shape (tuple): Shape of the image. Returns: Tensor: Decoded bboxes. """ x1 = points[:, 0] - distance[:, 0] y1 = points[:, 1] - distance[:, 1] x2 = points[:, 0] + distance[:, 2] y2 = points[:, 1] + distance[:, 3] if max_shape is not None: x1 = x1.clamp(min=0, max=max_shape[1]) y1 = y1.clamp(min=0, max=max_shape[0]) x2 = x2.clamp(min=0, max=max_shape[1]) y2 = y2.clamp(min=0, max=max_shape[0]) return np.stack([x1, y1, x2, y2], axis=-1) def distance2kps(points, distance, max_shape=None): """Decode distance prediction to bounding box. Args: points (Tensor): Shape (n, 2), [x, y]. distance (Tensor): Distance from the given point to 4 boundaries (left, top, right, bottom). max_shape (tuple): Shape of the image. Returns: Tensor: Decoded bboxes. """ preds = [] for i in range(0, distance.shape[1], 2): px = points[:, i%2] + distance[:, i] py = points[:, i%2+1] + distance[:, i+1] if max_shape is not None: px = px.clamp(min=0, max=max_shape[1]) py = py.clamp(min=0, max=max_shape[0]) preds.append(px) preds.append(py) return np.stack(preds, axis=-1) class RetinaFace: def __init__(self, model_file=None, session=None): import onnxruntime self.model_file = model_file self.session = session self.taskname = 'detection' if self.session is None: assert self.model_file is not None assert osp.exists(self.model_file) self.session = onnxruntime.InferenceSession(self.model_file, None) self.center_cache = {} self.nms_thresh = 0.4 self.det_thresh = 0.5 self._init_vars() def _init_vars(self): input_cfg = self.session.get_inputs()[0] input_shape = input_cfg.shape #print(input_shape) if isinstance(input_shape[2], str): self.input_size = None else: self.input_size = tuple(input_shape[2:4][::-1]) #print('image_size:', self.image_size) input_name = input_cfg.name self.input_shape = input_shape 
outputs = self.session.get_outputs() output_names = [] for o in outputs: output_names.append(o.name) self.input_name = input_name self.output_names = output_names self.input_mean = 127.5 self.input_std = 128.0 #print(self.output_names) #assert len(outputs)==10 or len(outputs)==15 self.use_kps = False self._anchor_ratio = 1.0 self._num_anchors = 1 if len(outputs)==6: self.fmc = 3 self._feat_stride_fpn = [8, 16, 32] self._num_anchors = 2 elif len(outputs)==9: self.fmc = 3 self._feat_stride_fpn = [8, 16, 32] self._num_anchors = 2 self.use_kps = True elif len(outputs)==10: self.fmc = 5 self._feat_stride_fpn = [8, 16, 32, 64, 128] self._num_anchors = 1 elif len(outputs)==15: self.fmc = 5 self._feat_stride_fpn = [8, 16, 32, 64, 128] self._num_anchors = 1 self.use_kps = True def prepare(self, ctx_id, **kwargs): if ctx_id<0: self.session.set_providers(['CPUExecutionProvider']) nms_thresh = kwargs.get('nms_thresh', None) if nms_thresh is not None: self.nms_thresh = nms_thresh det_thresh = kwargs.get('det_thresh', None) if det_thresh is not None: self.det_thresh = det_thresh input_size = kwargs.get('input_size', None) if input_size is not None: if self.input_size is not None: print('warning: det_size is already set in detection model, ignore') else: self.input_size = input_size def forward(self, img, threshold): scores_list = [] bboxes_list = [] kpss_list = [] input_size = tuple(img.shape[0:2][::-1]) blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) net_outs = self.session.run(self.output_names, {self.input_name : blob}) input_height = blob.shape[2] input_width = blob.shape[3] fmc = self.fmc for idx, stride in enumerate(self._feat_stride_fpn): scores = net_outs[idx] bbox_preds = net_outs[idx+fmc] bbox_preds = bbox_preds * stride if self.use_kps: kps_preds = net_outs[idx+fmc*2] * stride height = input_height // stride width = input_width // stride K = height * width key = (height, width, stride) 
if key in self.center_cache: anchor_centers = self.center_cache[key] else: #solution-1, c style: #anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) #for i in range(height): # anchor_centers[i, :, 1] = i #for i in range(width): # anchor_centers[:, i, 0] = i #solution-2: #ax = np.arange(width, dtype=np.float32) #ay = np.arange(height, dtype=np.float32) #xv, yv = np.meshgrid(np.arange(width), np.arange(height)) #anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) #solution-3: anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) #print(anchor_centers.shape) anchor_centers = (anchor_centers * stride).reshape( (-1, 2) ) if self._num_anchors>1: anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) ) if len(self.center_cache)<100: self.center_cache[key] = anchor_centers pos_inds = np.where(scores>=threshold)[0] bboxes = distance2bbox(anchor_centers, bbox_preds) pos_scores = scores[pos_inds] pos_bboxes = bboxes[pos_inds] scores_list.append(pos_scores) bboxes_list.append(pos_bboxes) if self.use_kps: kpss = distance2kps(anchor_centers, kps_preds) #kpss = kps_preds kpss = kpss.reshape( (kpss.shape[0], -1, 2) ) pos_kpss = kpss[pos_inds] kpss_list.append(pos_kpss) return scores_list, bboxes_list, kpss_list def detect(self, img, input_size = None, max_num=0, metric='default'): assert input_size is not None or self.input_size is not None input_size = self.input_size if input_size is None else input_size im_ratio = float(img.shape[0]) / img.shape[1] model_ratio = float(input_size[1]) / input_size[0] if im_ratio>model_ratio: new_height = input_size[1] new_width = int(new_height / im_ratio) else: new_width = input_size[0] new_height = int(new_width * im_ratio) det_scale = float(new_height) / img.shape[0] resized_img = cv2.resize(img, (new_width, new_height)) det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 ) det_img[:new_height, :new_width, :] = resized_img scores_list, 
bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) scores = np.vstack(scores_list) scores_ravel = scores.ravel() order = scores_ravel.argsort()[::-1] bboxes = np.vstack(bboxes_list) / det_scale if self.use_kps: kpss = np.vstack(kpss_list) / det_scale pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) pre_det = pre_det[order, :] keep = self.nms(pre_det) det = pre_det[keep, :] if self.use_kps: kpss = kpss[order,:,:] kpss = kpss[keep,:,:] else: kpss = None if max_num > 0 and det.shape[0] > max_num: area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1]) img_center = img.shape[0] // 2, img.shape[1] // 2 offsets = np.vstack([ (det[:, 0] + det[:, 2]) / 2 - img_center[1], (det[:, 1] + det[:, 3]) / 2 - img_center[0] ]) offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) if metric=='max': values = area else: values = area - offset_dist_squared * 2.0 # some extra weight on the centering bindex = np.argsort( values)[::-1] # some extra weight on the centering bindex = bindex[0:max_num] det = det[bindex, :] if kpss is not None: kpss = kpss[bindex, :] return det, kpss def nms(self, dets): thresh = self.nms_thresh x1 = dets[:, 0] y1 = dets[:, 1] x2 = dets[:, 2] y2 = dets[:, 3] scores = dets[:, 4] areas = (x2 - x1 + 1) * (y2 - y1 + 1) order = scores.argsort()[::-1] keep = [] while order.size > 0: i = order[0] keep.append(i) xx1 = np.maximum(x1[i], x1[order[1:]]) yy1 = np.maximum(y1[i], y1[order[1:]]) xx2 = np.minimum(x2[i], x2[order[1:]]) yy2 = np.minimum(y2[i], y2[order[1:]]) w = np.maximum(0.0, xx2 - xx1 + 1) h = np.maximum(0.0, yy2 - yy1 + 1) inter = w * h ovr = inter / (areas[i] + areas[order[1:]] - inter) inds = np.where(ovr <= thresh)[0] order = order[inds + 1] return keep def get_retinaface(name, download=False, root='~/.insightface/models', **kwargs): if not download: assert os.path.exists(name) return RetinaFace(name) else: from .model_store import get_model_file _file = get_model_file("retinaface_%s" % name, root=root) return 
retinaface(_file) ================================================ FILE: src/utils/dependencies/insightface/model_zoo/scrfd.py ================================================ # -*- coding: utf-8 -*- # @Organization : insightface.ai # @Author : Jia Guo # @Time : 2021-05-04 # @Function : from __future__ import division import datetime import numpy as np import onnx import onnxruntime import os import os.path as osp import cv2 import sys def softmax(z): assert len(z.shape) == 2 s = np.max(z, axis=1) s = s[:, np.newaxis] # necessary step to do broadcasting e_x = np.exp(z - s) div = np.sum(e_x, axis=1) div = div[:, np.newaxis] # dito return e_x / div def distance2bbox(points, distance, max_shape=None): """Decode distance prediction to bounding box. Args: points (Tensor): Shape (n, 2), [x, y]. distance (Tensor): Distance from the given point to 4 boundaries (left, top, right, bottom). max_shape (tuple): Shape of the image. Returns: Tensor: Decoded bboxes. """ x1 = points[:, 0] - distance[:, 0] y1 = points[:, 1] - distance[:, 1] x2 = points[:, 0] + distance[:, 2] y2 = points[:, 1] + distance[:, 3] if max_shape is not None: x1 = x1.clamp(min=0, max=max_shape[1]) y1 = y1.clamp(min=0, max=max_shape[0]) x2 = x2.clamp(min=0, max=max_shape[1]) y2 = y2.clamp(min=0, max=max_shape[0]) return np.stack([x1, y1, x2, y2], axis=-1) def distance2kps(points, distance, max_shape=None): """Decode distance prediction to bounding box. Args: points (Tensor): Shape (n, 2), [x, y]. distance (Tensor): Distance from the given point to 4 boundaries (left, top, right, bottom). max_shape (tuple): Shape of the image. Returns: Tensor: Decoded bboxes. 
""" preds = [] for i in range(0, distance.shape[1], 2): px = points[:, i%2] + distance[:, i] py = points[:, i%2+1] + distance[:, i+1] if max_shape is not None: px = px.clamp(min=0, max=max_shape[1]) py = py.clamp(min=0, max=max_shape[0]) preds.append(px) preds.append(py) return np.stack(preds, axis=-1) class SCRFD: def __init__(self, model_file=None, session=None): import onnxruntime self.model_file = model_file self.session = session self.taskname = 'detection' self.batched = False if self.session is None: assert self.model_file is not None assert osp.exists(self.model_file) self.session = onnxruntime.InferenceSession(self.model_file, None) self.center_cache = {} self.nms_thresh = 0.4 self.det_thresh = 0.5 self._init_vars() def _init_vars(self): input_cfg = self.session.get_inputs()[0] input_shape = input_cfg.shape #print(input_shape) if isinstance(input_shape[2], str): self.input_size = None else: self.input_size = tuple(input_shape[2:4][::-1]) #print('image_size:', self.image_size) input_name = input_cfg.name self.input_shape = input_shape outputs = self.session.get_outputs() if len(outputs[0].shape) == 3: self.batched = True output_names = [] for o in outputs: output_names.append(o.name) self.input_name = input_name self.output_names = output_names self.input_mean = 127.5 self.input_std = 128.0 #print(self.output_names) #assert len(outputs)==10 or len(outputs)==15 self.use_kps = False self._anchor_ratio = 1.0 self._num_anchors = 1 if len(outputs)==6: self.fmc = 3 self._feat_stride_fpn = [8, 16, 32] self._num_anchors = 2 elif len(outputs)==9: self.fmc = 3 self._feat_stride_fpn = [8, 16, 32] self._num_anchors = 2 self.use_kps = True elif len(outputs)==10: self.fmc = 5 self._feat_stride_fpn = [8, 16, 32, 64, 128] self._num_anchors = 1 elif len(outputs)==15: self.fmc = 5 self._feat_stride_fpn = [8, 16, 32, 64, 128] self._num_anchors = 1 self.use_kps = True def prepare(self, ctx_id, **kwargs): if ctx_id<0: self.session.set_providers(['CPUExecutionProvider']) 
nms_thresh = kwargs.get('nms_thresh', None) if nms_thresh is not None: self.nms_thresh = nms_thresh det_thresh = kwargs.get('det_thresh', None) if det_thresh is not None: self.det_thresh = det_thresh input_size = kwargs.get('input_size', None) if input_size is not None: if self.input_size is not None: print('warning: det_size is already set in scrfd model, ignore') else: self.input_size = input_size def forward(self, img, threshold): scores_list = [] bboxes_list = [] kpss_list = [] input_size = tuple(img.shape[0:2][::-1]) blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) net_outs = self.session.run(self.output_names, {self.input_name : blob}) input_height = blob.shape[2] input_width = blob.shape[3] fmc = self.fmc for idx, stride in enumerate(self._feat_stride_fpn): # If model support batch dim, take first output if self.batched: scores = net_outs[idx][0] bbox_preds = net_outs[idx + fmc][0] bbox_preds = bbox_preds * stride if self.use_kps: kps_preds = net_outs[idx + fmc * 2][0] * stride # If model doesn't support batching take output as is else: scores = net_outs[idx] bbox_preds = net_outs[idx + fmc] bbox_preds = bbox_preds * stride if self.use_kps: kps_preds = net_outs[idx + fmc * 2] * stride height = input_height // stride width = input_width // stride K = height * width key = (height, width, stride) if key in self.center_cache: anchor_centers = self.center_cache[key] else: #solution-1, c style: #anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) #for i in range(height): # anchor_centers[i, :, 1] = i #for i in range(width): # anchor_centers[:, i, 0] = i #solution-2: #ax = np.arange(width, dtype=np.float32) #ay = np.arange(height, dtype=np.float32) #xv, yv = np.meshgrid(np.arange(width), np.arange(height)) #anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) #solution-3: anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) 
#print(anchor_centers.shape) anchor_centers = (anchor_centers * stride).reshape( (-1, 2) ) if self._num_anchors>1: anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) ) if len(self.center_cache)<100: self.center_cache[key] = anchor_centers pos_inds = np.where(scores>=threshold)[0] bboxes = distance2bbox(anchor_centers, bbox_preds) pos_scores = scores[pos_inds] pos_bboxes = bboxes[pos_inds] scores_list.append(pos_scores) bboxes_list.append(pos_bboxes) if self.use_kps: kpss = distance2kps(anchor_centers, kps_preds) #kpss = kps_preds kpss = kpss.reshape( (kpss.shape[0], -1, 2) ) pos_kpss = kpss[pos_inds] kpss_list.append(pos_kpss) return scores_list, bboxes_list, kpss_list def detect(self, img, input_size = None, max_num=0, metric='default'): assert input_size is not None or self.input_size is not None input_size = self.input_size if input_size is None else input_size im_ratio = float(img.shape[0]) / img.shape[1] model_ratio = float(input_size[1]) / input_size[0] if im_ratio>model_ratio: new_height = input_size[1] new_width = int(new_height / im_ratio) else: new_width = input_size[0] new_height = int(new_width * im_ratio) det_scale = float(new_height) / img.shape[0] resized_img = cv2.resize(img, (new_width, new_height)) det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 ) det_img[:new_height, :new_width, :] = resized_img scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) scores = np.vstack(scores_list) scores_ravel = scores.ravel() order = scores_ravel.argsort()[::-1] bboxes = np.vstack(bboxes_list) / det_scale if self.use_kps: kpss = np.vstack(kpss_list) / det_scale pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) pre_det = pre_det[order, :] keep = self.nms(pre_det) det = pre_det[keep, :] if self.use_kps: kpss = kpss[order,:,:] kpss = kpss[keep,:,:] else: kpss = None if max_num > 0 and det.shape[0] > max_num: area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1]) 
img_center = img.shape[0] // 2, img.shape[1] // 2 offsets = np.vstack([ (det[:, 0] + det[:, 2]) / 2 - img_center[1], (det[:, 1] + det[:, 3]) / 2 - img_center[0] ]) offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) if metric=='max': values = area else: values = area - offset_dist_squared * 2.0 # some extra weight on the centering bindex = np.argsort( values)[::-1] # some extra weight on the centering bindex = bindex[0:max_num] det = det[bindex, :] if kpss is not None: kpss = kpss[bindex, :] return det, kpss def nms(self, dets): thresh = self.nms_thresh x1 = dets[:, 0] y1 = dets[:, 1] x2 = dets[:, 2] y2 = dets[:, 3] scores = dets[:, 4] areas = (x2 - x1 + 1) * (y2 - y1 + 1) order = scores.argsort()[::-1] keep = [] while order.size > 0: i = order[0] keep.append(i) xx1 = np.maximum(x1[i], x1[order[1:]]) yy1 = np.maximum(y1[i], y1[order[1:]]) xx2 = np.minimum(x2[i], x2[order[1:]]) yy2 = np.minimum(y2[i], y2[order[1:]]) w = np.maximum(0.0, xx2 - xx1 + 1) h = np.maximum(0.0, yy2 - yy1 + 1) inter = w * h ovr = inter / (areas[i] + areas[order[1:]] - inter) inds = np.where(ovr <= thresh)[0] order = order[inds + 1] return keep def get_scrfd(name, download=False, root='~/.insightface/models', **kwargs): if not download: assert os.path.exists(name) return SCRFD(name) else: from .model_store import get_model_file _file = get_model_file("scrfd_%s" % name, root=root) return SCRFD(_file) def scrfd_2p5gkps(**kwargs): return get_scrfd("2p5gkps", download=True, **kwargs) if __name__ == '__main__': import glob detector = SCRFD(model_file='./det.onnx') detector.prepare(-1) img_paths = ['tests/data/t1.jpg'] for img_path in img_paths: img = cv2.imread(img_path) for _ in range(1): ta = datetime.datetime.now() #bboxes, kpss = detector.detect(img, 0.5, input_size = (640, 640)) bboxes, kpss = detector.detect(img, 0.5) tb = datetime.datetime.now() print('all cost:', (tb-ta).total_seconds()*1000) print(img_path, bboxes.shape) if kpss is not None: print(kpss.shape) for i in 
range(bboxes.shape[0]): bbox = bboxes[i] x1,y1,x2,y2,score = bbox.astype(np.int) cv2.rectangle(img, (x1,y1) , (x2,y2) , (255,0,0) , 2) if kpss is not None: kps = kpss[i] for kp in kps: kp = kp.astype(np.int) cv2.circle(img, tuple(kp) , 1, (0,0,255) , 2) filename = img_path.split('/')[-1] print('output:', filename) cv2.imwrite('./outputs/%s'%filename, img) ================================================ FILE: src/utils/dependencies/insightface/utils/__init__.py ================================================ from __future__ import absolute_import from .storage import download, ensure_available, download_onnx from .filesystem import get_model_dir from .filesystem import makedirs, try_import_dali from .constant import * ================================================ FILE: src/utils/dependencies/insightface/utils/constant.py ================================================ DEFAULT_MP_NAME = 'buffalo_l' ================================================ FILE: src/utils/dependencies/insightface/utils/download.py ================================================ """ This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/download.py """ import os import hashlib import requests from tqdm import tqdm def check_sha1(filename, sha1_hash): """Check whether the sha1 hash of the file content matches the expected hash. Parameters ---------- filename : str Path to the file. sha1_hash : str Expected sha1 hash in hexadecimal digits. Returns ------- bool Whether the file content matches the expected hash. """ sha1 = hashlib.sha1() with open(filename, 'rb') as f: while True: data = f.read(1048576) if not data: break sha1.update(data) sha1_file = sha1.hexdigest() l = min(len(sha1_file), len(sha1_hash)) return sha1.hexdigest()[0:l] == sha1_hash[0:l] def download_file(url, path=None, overwrite=False, sha1_hash=None): """Download an given URL Parameters ---------- url : str URL to download path : str, optional Destination path to store downloaded file. 
By default stores to the current directory with same name as in url. overwrite : bool, optional Whether to overwrite destination file if already exists. sha1_hash : str, optional Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified but doesn't match. Returns ------- str The file path of the downloaded file. """ if path is None: fname = url.split('/')[-1] else: path = os.path.expanduser(path) if os.path.isdir(path): fname = os.path.join(path, url.split('/')[-1]) else: fname = path if overwrite or not os.path.exists(fname) or ( sha1_hash and not check_sha1(fname, sha1_hash)): dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) if not os.path.exists(dirname): os.makedirs(dirname) print('Downloading %s from %s...' % (fname, url)) r = requests.get(url, stream=True) if r.status_code != 200: raise RuntimeError("Failed downloading url %s" % url) total_length = r.headers.get('content-length') with open(fname, 'wb') as f: if total_length is None: # no content length header for chunk in r.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks f.write(chunk) else: total_length = int(total_length) for chunk in tqdm(r.iter_content(chunk_size=1024), total=int(total_length / 1024. + 0.5), unit='KB', unit_scale=False, dynamic_ncols=True): f.write(chunk) if sha1_hash and not check_sha1(fname, sha1_hash): raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 'The repo may be outdated or download may be incomplete. 
' \ 'If the "repo_url" is overridden, consider switching to ' \ 'the default repo.'.format(fname)) return fname ================================================ FILE: src/utils/dependencies/insightface/utils/face_align.py ================================================ import cv2 import numpy as np from skimage import transform as trans arcface_dst = np.array( [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32) def estimate_norm(lmk, image_size=112,mode='arcface'): assert lmk.shape == (5, 2) assert image_size%112==0 or image_size%128==0 if image_size%112==0: ratio = float(image_size)/112.0 diff_x = 0 else: ratio = float(image_size)/128.0 diff_x = 8.0*ratio dst = arcface_dst * ratio dst[:,0] += diff_x tform = trans.SimilarityTransform() tform.estimate(lmk, dst) M = tform.params[0:2, :] return M def norm_crop(img, landmark, image_size=112, mode='arcface'): M = estimate_norm(landmark, image_size, mode) warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) return warped def norm_crop2(img, landmark, image_size=112, mode='arcface'): M = estimate_norm(landmark, image_size, mode) warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) return warped, M def square_crop(im, S): if im.shape[0] > im.shape[1]: height = S width = int(float(im.shape[1]) / im.shape[0] * S) scale = float(S) / im.shape[0] else: width = S height = int(float(im.shape[0]) / im.shape[1] * S) scale = float(S) / im.shape[1] resized_im = cv2.resize(im, (width, height)) det_im = np.zeros((S, S, 3), dtype=np.uint8) det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im return det_im, scale def transform(data, center, output_size, scale, rotation): scale_ratio = scale rot = float(rotation) * np.pi / 180.0 #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) t1 = trans.SimilarityTransform(scale=scale_ratio) cx = center[0] * scale_ratio cy = center[1] * 
scale_ratio t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) t3 = trans.SimilarityTransform(rotation=rot) t4 = trans.SimilarityTransform(translation=(output_size / 2, output_size / 2)) t = t1 + t2 + t3 + t4 M = t.params[0:2] cropped = cv2.warpAffine(data, M, (output_size, output_size), borderValue=0.0) return cropped, M def trans_points2d(pts, M): new_pts = np.zeros(shape=pts.shape, dtype=np.float32) for i in range(pts.shape[0]): pt = pts[i] new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) new_pt = np.dot(M, new_pt) #print('new_pt', new_pt.shape, new_pt) new_pts[i] = new_pt[0:2] return new_pts def trans_points3d(pts, M): scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) #print(scale) new_pts = np.zeros(shape=pts.shape, dtype=np.float32) for i in range(pts.shape[0]): pt = pts[i] new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) new_pt = np.dot(M, new_pt) #print('new_pt', new_pt.shape, new_pt) new_pts[i][0:2] = new_pt[0:2] new_pts[i][2] = pts[i][2] * scale return new_pts def trans_points(pts, M): if pts.shape[1] == 2: return trans_points2d(pts, M) else: return trans_points3d(pts, M) ================================================ FILE: src/utils/dependencies/insightface/utils/filesystem.py ================================================ """ This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/filesystem.py """ import os import os.path as osp import errno def get_model_dir(name, root='~/.insightface'): root = os.path.expanduser(root) model_dir = osp.join(root, 'models', name) return model_dir def makedirs(path): """Create directory recursively if not exists. Similar to `makedir -p`, you can skip checking existence before this function. Parameters ---------- path : str Path of the desired dir """ try: os.makedirs(path) except OSError as exc: if exc.errno != errno.EEXIST: raise def try_import(package, message=None): """Try import specified package, with custom message support. 
Parameters ---------- package : str The name of the targeting package. message : str, default is None If not None, this function will raise customized error message when import error is found. Returns ------- module if found, raise ImportError otherwise """ try: return __import__(package) except ImportError as e: if not message: raise e raise ImportError(message) def try_import_cv2(): """Try import cv2 at runtime. Returns ------- cv2 module if found. Raise ImportError otherwise """ msg = "cv2 is required, you can install by package manager, e.g. 'apt-get', \ or `pip install opencv-python --user` (note that this is unofficial PYPI package)." return try_import('cv2', msg) def try_import_mmcv(): """Try import mmcv at runtime. Returns ------- mmcv module if found. Raise ImportError otherwise """ msg = "mmcv is required, you can install by first `pip install Cython --user` \ and then `pip install mmcv --user` (note that this is unofficial PYPI package)." return try_import('mmcv', msg) def try_import_rarfile(): """Try import rarfile at runtime. Returns ------- rarfile module if found. Raise ImportError otherwise """ msg = "rarfile is required, you can install by first `sudo apt-get install unrar` \ and then `pip install rarfile --user` (note that this is unofficial PYPI package)." return try_import('rarfile', msg) def import_try_install(package, extern_url=None): """Try import the specified package. If the package not installed, try use pip to install and import if success. Parameters ---------- package : str The name of the package trying to import. extern_url : str or None, optional The external url if package is not hosted on PyPI. For example, you can install a package using: "pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx". In this case, you can pass the url to the extern_url. Returns ------- The imported python module. 
""" try: return __import__(package) except ImportError: try: from pip import main as pipmain except ImportError: from pip._internal import main as pipmain # trying to install package url = package if extern_url is None else extern_url pipmain(['install', '--user', url]) # will raise SystemExit Error if fails # trying to load again try: return __import__(package) except ImportError: import sys import site user_site = site.getusersitepackages() if user_site not in sys.path: sys.path.append(user_site) return __import__(package) return __import__(package) def try_import_dali(): """Try import NVIDIA DALI at runtime. """ try: dali = __import__('nvidia.dali', fromlist=['pipeline', 'ops', 'types']) dali.Pipeline = dali.pipeline.Pipeline except ImportError: class dali: class Pipeline: def __init__(self): raise NotImplementedError( "DALI not found, please check if you installed it correctly." ) return dali ================================================ FILE: src/utils/dependencies/insightface/utils/storage.py ================================================ import os import os.path as osp import zipfile from .download import download_file BASE_REPO_URL = 'https://github.com/deepinsight/insightface/releases/download/v0.7' def download(sub_dir, name, force=False, root='~/.insightface'): _root = os.path.expanduser(root) dir_path = os.path.join(_root, sub_dir, name) if osp.exists(dir_path) and not force: return dir_path print('download_path:', dir_path) zip_file_path = os.path.join(_root, sub_dir, name + '.zip') model_url = "%s/%s.zip"%(BASE_REPO_URL, name) download_file(model_url, path=zip_file_path, overwrite=True) if not os.path.exists(dir_path): os.makedirs(dir_path) with zipfile.ZipFile(zip_file_path) as zf: zf.extractall(dir_path) #os.remove(zip_file_path) return dir_path def ensure_available(sub_dir, name, root='~/.insightface'): return download(sub_dir, name, force=False, root=root) def download_onnx(sub_dir, model_file, force=False, root='~/.insightface', 
download_zip=False): _root = os.path.expanduser(root) model_root = osp.join(_root, sub_dir) new_model_file = osp.join(model_root, model_file) if osp.exists(new_model_file) and not force: return new_model_file if not osp.exists(model_root): os.makedirs(model_root) print('download_path:', new_model_file) if not download_zip: model_url = "%s/%s"%(BASE_REPO_URL, model_file) download_file(model_url, path=new_model_file, overwrite=True) else: model_url = "%s/%s.zip"%(BASE_REPO_URL, model_file) zip_file_path = new_model_file+".zip" download_file(model_url, path=zip_file_path, overwrite=True) with zipfile.ZipFile(zip_file_path) as zf: zf.extractall(model_root) return new_model_file ================================================ FILE: src/utils/dependencies/insightface/utils/transform.py ================================================ import cv2 import math import numpy as np from skimage import transform as trans def transform(data, center, output_size, scale, rotation): scale_ratio = scale rot = float(rotation) * np.pi / 180.0 #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) t1 = trans.SimilarityTransform(scale=scale_ratio) cx = center[0] * scale_ratio cy = center[1] * scale_ratio t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) t3 = trans.SimilarityTransform(rotation=rot) t4 = trans.SimilarityTransform(translation=(output_size / 2, output_size / 2)) t = t1 + t2 + t3 + t4 M = t.params[0:2] cropped = cv2.warpAffine(data, M, (output_size, output_size), borderValue=0.0) return cropped, M def trans_points2d(pts, M): new_pts = np.zeros(shape=pts.shape, dtype=np.float32) for i in range(pts.shape[0]): pt = pts[i] new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) new_pt = np.dot(M, new_pt) #print('new_pt', new_pt.shape, new_pt) new_pts[i] = new_pt[0:2] return new_pts def trans_points3d(pts, M): scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) #print(scale) new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 
for i in range(pts.shape[0]): pt = pts[i] new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) new_pt = np.dot(M, new_pt) #print('new_pt', new_pt.shape, new_pt) new_pts[i][0:2] = new_pt[0:2] new_pts[i][2] = pts[i][2] * scale return new_pts def trans_points(pts, M): if pts.shape[1] == 2: return trans_points2d(pts, M) else: return trans_points3d(pts, M) def estimate_affine_matrix_3d23d(X, Y): ''' Using least-squares solution Args: X: [n, 3]. 3d points(fixed) Y: [n, 3]. corresponding 3d points(moving). Y = PX Returns: P_Affine: (3, 4). Affine camera matrix (the third row is [0, 0, 0, 1]). ''' X_homo = np.hstack((X, np.ones([X.shape[0],1]))) #n x 4 P = np.linalg.lstsq(X_homo, Y)[0].T # Affine matrix. 3 x 4 return P def P2sRt(P): ''' decompositing camera matrix P Args: P: (3, 4). Affine Camera Matrix. Returns: s: scale factor. R: (3, 3). rotation matrix. t: (3,). translation. ''' t = P[:, 3] R1 = P[0:1, :3] R2 = P[1:2, :3] s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2.0 r1 = R1/np.linalg.norm(R1) r2 = R2/np.linalg.norm(R2) r3 = np.cross(r1, r2) R = np.concatenate((r1, r2, r3), 0) return s, R, t def matrix2angle(R): ''' get three Euler angles from Rotation Matrix Args: R: (3,3). 
rotation matrix Returns: x: pitch y: yaw z: roll ''' sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0]) singular = sy < 1e-6 if not singular : x = math.atan2(R[2,1] , R[2,2]) y = math.atan2(-R[2,0], sy) z = math.atan2(R[1,0], R[0,0]) else : x = math.atan2(-R[1,2], R[1,1]) y = math.atan2(-R[2,0], sy) z = 0 # rx, ry, rz = np.rad2deg(x), np.rad2deg(y), np.rad2deg(z) rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi return rx, ry, rz ================================================ FILE: src/utils/face_analysis_diy.py ================================================ # coding: utf-8 """ face detectoin and alignment using InsightFace """ import numpy as np from .rprint import rlog as log from .dependencies.insightface.app import FaceAnalysis from .dependencies.insightface.app.common import Face from .timer import Timer def sort_by_direction(faces, direction: str = 'large-small', face_center=None): if len(faces) <= 0: return faces if direction == 'left-right': return sorted(faces, key=lambda face: face['bbox'][0]) if direction == 'right-left': return sorted(faces, key=lambda face: face['bbox'][0], reverse=True) if direction == 'top-bottom': return sorted(faces, key=lambda face: face['bbox'][1]) if direction == 'bottom-top': return sorted(faces, key=lambda face: face['bbox'][1], reverse=True) if direction == 'small-large': return sorted(faces, key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1])) if direction == 'large-small': return sorted(faces, key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1]), reverse=True) if direction == 'distance-from-retarget-face': return sorted(faces, key=lambda face: (((face['bbox'][2]+face['bbox'][0])/2-face_center[0])**2+((face['bbox'][3]+face['bbox'][1])/2-face_center[1])**2)**0.5) return faces class FaceAnalysisDIY(FaceAnalysis): def __init__(self, name='buffalo_l', root='~/.insightface', allowed_modules=None, **kwargs): super().__init__(name=name, root=root, 
allowed_modules=allowed_modules, **kwargs) self.timer = Timer() def get(self, img_bgr, **kwargs): max_num = kwargs.get('max_face_num', 0) # the number of the detected faces, 0 means no limit flag_do_landmark_2d_106 = kwargs.get('flag_do_landmark_2d_106', True) # whether to do 106-point detection direction = kwargs.get('direction', 'large-small') # sorting direction face_center = None bboxes, kpss = self.det_model.detect(img_bgr, max_num=max_num, metric='default') if bboxes.shape[0] == 0: return [] ret = [] for i in range(bboxes.shape[0]): bbox = bboxes[i, 0:4] det_score = bboxes[i, 4] kps = None if kpss is not None: kps = kpss[i] face = Face(bbox=bbox, kps=kps, det_score=det_score) for taskname, model in self.models.items(): if taskname == 'detection': continue if (not flag_do_landmark_2d_106) and taskname == 'landmark_2d_106': continue # print(f'taskname: {taskname}') model.get(img_bgr, face) ret.append(face) ret = sort_by_direction(ret, direction, face_center) return ret def warmup(self): self.timer.tic() img_bgr = np.zeros((512, 512, 3), dtype=np.uint8) self.get(img_bgr) elapse = self.timer.toc() log(f'FaceAnalysisDIY warmup time: {elapse:.3f}s') ================================================ FILE: src/utils/filter.py ================================================ # coding: utf-8 import torch import numpy as np from pykalman import KalmanFilter def smooth(x_d_lst, shape, device, observation_variance=3e-7, process_variance=1e-5): x_d_lst_reshape = [x.reshape(-1) for x in x_d_lst] x_d_stacked = np.vstack(x_d_lst_reshape) kf = KalmanFilter( initial_state_mean=x_d_stacked[0], n_dim_obs=x_d_stacked.shape[1], transition_covariance=process_variance * np.eye(x_d_stacked.shape[1]), observation_covariance=observation_variance * np.eye(x_d_stacked.shape[1]) ) smoothed_state_means, _ = kf.smooth(x_d_stacked) x_d_lst_smooth = [torch.tensor(state_mean.reshape(shape[-2:]), dtype=torch.float32, device=device) for state_mean in smoothed_state_means] return x_d_lst_smooth 
================================================ FILE: src/utils/helper.py ================================================ # coding: utf-8 """ utility functions and classes to handle feature extraction and model loading """ import os import os.path as osp import torch from collections import OrderedDict import numpy as np from scipy.spatial import ConvexHull # pylint: disable=E0401,E0611 from typing import Union import cv2 from ..modules.spade_generator import SPADEDecoder from ..modules.warping_network import WarpingNetwork from ..modules.motion_extractor import MotionExtractor from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork def tensor_to_numpy(data: Union[np.ndarray, torch.Tensor]) -> np.ndarray: """transform torch.Tensor into numpy.ndarray""" if isinstance(data, torch.Tensor): return data.data.cpu().numpy() return data def calc_motion_multiplier( kp_source: Union[np.ndarray, torch.Tensor], kp_driving_initial: Union[np.ndarray, torch.Tensor] ) -> float: """calculate motion_multiplier based on the source image and the first driving frame""" kp_source_np = tensor_to_numpy(kp_source) kp_driving_initial_np = tensor_to_numpy(kp_driving_initial) source_area = ConvexHull(kp_source_np.squeeze(0)).volume driving_area = ConvexHull(kp_driving_initial_np.squeeze(0)).volume motion_multiplier = np.sqrt(source_area) / np.sqrt(driving_area) # motion_multiplier = np.cbrt(source_area) / np.cbrt(driving_area) return motion_multiplier def suffix(filename): """a.jpg -> jpg""" pos = filename.rfind(".") if pos == -1: return "" return filename[pos + 1:] def prefix(filename): """a.jpg -> a""" pos = filename.rfind(".") if pos == -1: return filename return filename[:pos] def basename(filename): """a/b/c.jpg -> c""" return prefix(osp.basename(filename)) def remove_suffix(filepath): """a/b/c.jpg -> a/b/c""" return osp.join(osp.dirname(filepath), basename(filepath)) def 
is_image(file_path): image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp') return file_path.lower().endswith(image_extensions) def is_video(file_path): if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path): return True return False def is_template(file_path): if file_path.endswith(".pkl"): return True return False def mkdir(d, log=False): # return self-assined `d`, for one line code if not osp.exists(d): os.makedirs(d, exist_ok=True) if log: print(f"Make dir: {d}") return d def squeeze_tensor_to_numpy(tensor): out = tensor.data.squeeze(0).cpu().numpy() return out def dct2device(dct: dict, device): for key in dct: if isinstance(dct[key], torch.Tensor): dct[key] = dct[key].to(device) else: dct[key] = torch.tensor(dct[key]).to(device) return dct def concat_feat(kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: """ kp_source: (bs, k, 3) kp_driving: (bs, k, 3) Return: (bs, 2k*3) """ bs_src = kp_source.shape[0] bs_dri = kp_driving.shape[0] assert bs_src == bs_dri, 'batch size must be equal' feat = torch.cat([kp_source.view(bs_src, -1), kp_driving.view(bs_dri, -1)], dim=1) return feat def remove_ddp_dumplicate_key(state_dict): state_dict_new = OrderedDict() for key in state_dict.keys(): state_dict_new[key.replace('module.', '')] = state_dict[key] return state_dict_new def load_model(ckpt_path, model_config, device, model_type): model_params = model_config['model_params'][f'{model_type}_params'] if model_type == 'appearance_feature_extractor': model = AppearanceFeatureExtractor(**model_params).to(device) elif model_type == 'motion_extractor': model = MotionExtractor(**model_params).to(device) elif model_type == 'warping_module': model = WarpingNetwork(**model_params).to(device) elif model_type == 'spade_generator': model = SPADEDecoder(**model_params).to(device) elif model_type == 'stitching_retargeting_module': # Special handling for stitching and retargeting module config = 
model_config['model_params']['stitching_retargeting_module_params'] checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage) stitcher = StitchingRetargetingNetwork(**config.get('stitching')) stitcher.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_shoulder'])) stitcher = stitcher.to(device) stitcher.eval() retargetor_lip = StitchingRetargetingNetwork(**config.get('lip')) retargetor_lip.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_mouth'])) retargetor_lip = retargetor_lip.to(device) retargetor_lip.eval() retargetor_eye = StitchingRetargetingNetwork(**config.get('eye')) retargetor_eye.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_eye'])) retargetor_eye = retargetor_eye.to(device) retargetor_eye.eval() return { 'stitching': stitcher, 'lip': retargetor_lip, 'eye': retargetor_eye } else: raise ValueError(f"Unknown model type: {model_type}") model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage)) model.eval() return model def load_description(fp): with open(fp, 'r', encoding='utf-8') as f: content = f.read() return content def is_square_video(video_path): video = cv2.VideoCapture(video_path) width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) video.release() # if width != height: # gr.Info(f"Uploaded video is not square, force do crop (driving) to be True") return width == height def clean_state_dict(state_dict): new_state_dict = OrderedDict() for k, v in state_dict.items(): if k[:7] == 'module.': k = k[7:] # remove `module.` new_state_dict[k] = v return new_state_dict ================================================ FILE: src/utils/human_landmark_runner.py ================================================ # coding: utf-8 import os.path as osp import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) import torch import numpy as np import onnxruntime from .timer import Timer from .rprint import rlog from .crop import 
crop_image, _transform_pts def make_abs_path(fn): return osp.join(osp.dirname(osp.realpath(__file__)), fn) def to_ndarray(obj): if isinstance(obj, torch.Tensor): return obj.cpu().numpy() elif isinstance(obj, np.ndarray): return obj else: return np.array(obj) class LandmarkRunner(object): """landmark runner""" def __init__(self, **kwargs): ckpt_path = kwargs.get('ckpt_path') onnx_provider = kwargs.get('onnx_provider', 'cuda') # 默认用cuda device_id = kwargs.get('device_id', 0) self.dsize = kwargs.get('dsize', 224) self.timer = Timer() if onnx_provider.lower() == 'cuda': self.session = onnxruntime.InferenceSession( ckpt_path, providers=[ ('CUDAExecutionProvider', {'device_id': device_id}) ] ) elif onnx_provider.lower() == 'mps': self.session = onnxruntime.InferenceSession( ckpt_path, providers=[ 'CoreMLExecutionProvider' ] ) else: opts = onnxruntime.SessionOptions() opts.intra_op_num_threads = 4 # 默认线程数为 4 self.session = onnxruntime.InferenceSession( ckpt_path, providers=['CPUExecutionProvider'], sess_options=opts ) def _run(self, inp): out = self.session.run(None, {'input': inp}) return out def run(self, img_rgb: np.ndarray, lmk=None): if lmk is not None: crop_dct = crop_image(img_rgb, lmk, dsize=self.dsize, scale=1.5, vy_ratio=-0.1) img_crop_rgb = crop_dct['img_crop'] else: # NOTE: force resize to 224x224, NOT RECOMMEND! img_crop_rgb = cv2.resize(img_rgb, (self.dsize, self.dsize)) scale = max(img_rgb.shape[:2]) / self.dsize crop_dct = { 'M_c2o': np.array([ [scale, 0., 0.], [0., scale, 0.], [0., 0., 1.], ], dtype=np.float32), } inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...] # HxWx3 (BGR) -> 1x3xHxW (RGB!) 
out_lst = self._run(inp) out_pts = out_lst[2] # 2d landmarks 203 points lmk = to_ndarray(out_pts[0]).reshape(-1, 2) * self.dsize # scale to 0-224 lmk = _transform_pts(lmk, M=crop_dct['M_c2o']) return lmk def warmup(self): self.timer.tic() dummy_image = np.zeros((1, 3, self.dsize, self.dsize), dtype=np.float32) _ = self._run(dummy_image) elapse = self.timer.toc() rlog(f'LandmarkRunner warmup time: {elapse:.3f}s') ================================================ FILE: src/utils/io.py ================================================ # coding: utf-8 import os.path as osp import imageio import numpy as np import pickle import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) from .helper import mkdir, suffix def load_image_rgb(image_path: str): if not osp.exists(image_path): raise FileNotFoundError(f"Image not found: {image_path}") img = cv2.imread(image_path, cv2.IMREAD_COLOR) return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) def load_video(video_info, n_frames=-1): reader = imageio.get_reader(video_info, "ffmpeg") ret = [] for idx, frame_rgb in enumerate(reader): if n_frames > 0 and idx >= n_frames: break ret.append(frame_rgb) reader.close() return ret def contiguous(obj): if not obj.flags.c_contiguous: obj = obj.copy(order="C") return obj def resize_to_limit(img: np.ndarray, max_dim=1920, division=2): """ ajust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n. :param img: the image to be processed. :param max_dim: the maximum dimension constraint. :param n: the number that needs to be multiples of. :return: the adjusted image. 
""" h, w = img.shape[:2] # ajust the size of the image according to the maximum dimension if max_dim > 0 and max(h, w) > max_dim: if h > w: new_h = max_dim new_w = int(w * (max_dim / h)) else: new_w = max_dim new_h = int(h * (max_dim / w)) img = cv2.resize(img, (new_w, new_h)) # ensure that the image dimensions are multiples of n division = max(division, 1) new_h = img.shape[0] - (img.shape[0] % division) new_w = img.shape[1] - (img.shape[1] % division) if new_h == 0 or new_w == 0: # when the width or height is less than n, no need to process return img if new_h != img.shape[0] or new_w != img.shape[1]: img = img[:new_h, :new_w] return img def load_img_online(obj, mode="bgr", **kwargs): max_dim = kwargs.get("max_dim", 1920) n = kwargs.get("n", 2) if isinstance(obj, str): if mode.lower() == "gray": img = cv2.imread(obj, cv2.IMREAD_GRAYSCALE) else: img = cv2.imread(obj, cv2.IMREAD_COLOR) else: img = obj # Resize image to satisfy constraints img = resize_to_limit(img, max_dim=max_dim, division=n) if mode.lower() == "bgr": return contiguous(img) elif mode.lower() == "rgb": return contiguous(img[..., ::-1]) else: raise Exception(f"Unknown mode {mode}") def load(fp): suffix_ = suffix(fp) if suffix_ == "npy": return np.load(fp) elif suffix_ == "pkl": return pickle.load(open(fp, "rb")) else: raise Exception(f"Unknown type: {suffix}") def dump(wfp, obj): wd = osp.split(wfp)[0] if wd != "" and not osp.exists(wd): mkdir(wd) _suffix = suffix(wfp) if _suffix == "npy": np.save(wfp, obj) elif _suffix == "pkl": pickle.dump(obj, open(wfp, "wb")) else: raise Exception("Unknown type: {}".format(_suffix)) ================================================ FILE: src/utils/retargeting_utils.py ================================================ """ Functions to compute distance ratios between specific pairs of facial landmarks """ import numpy as np def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray: return 
(np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) / (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps)) def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray: lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12) righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36) if target_eye_ratio is not None: return np.concatenate([lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1) else: return np.concatenate([lefteye_close_ratio, righteye_close_ratio], axis=1) def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray: return calculate_distance_ratio(lmk, 90, 102, 48, 66) ================================================ FILE: src/utils/rprint.py ================================================ # coding: utf-8 """ custom print and log functions """ __all__ = ['rprint', 'rlog'] try: from rich.console import Console console = Console() rprint = console.print rlog = console.log except: rprint = print rlog = print ================================================ FILE: src/utils/timer.py ================================================ # coding: utf-8 """ tools to measure elapsed time """ import time class Timer(object): """A simple timer.""" def __init__(self): self.total_time = 0. self.calls = 0 self.start_time = 0. self.diff = 0. def tic(self): # using time.time instead of time.clock because time time.clock # does not normalize for multithreading self.start_time = time.time() def toc(self, average=True): self.diff = time.time() - self.start_time return self.diff def clear(self): self.start_time = 0. self.diff = 0. ================================================ FILE: src/utils/video.py ================================================ # coding: utf-8 """ Functions for processing video ATTENTION: you need to install ffmpeg and ffprobe in your env! 
""" import os.path as osp import numpy as np import subprocess import imageio import cv2 from rich.progress import track from .rprint import rlog as log from .rprint import rprint as print from .helper import prefix def exec_cmd(cmd): return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) def images2video(images, wfp, **kwargs): fps = kwargs.get('fps', 30) video_format = kwargs.get('format', 'mp4') # default is mp4 format codec = kwargs.get('codec', 'libx264') # default is libx264 encoding quality = kwargs.get('quality') # video quality pixelformat = kwargs.get('pixelformat', 'yuv420p') # video pixel format image_mode = kwargs.get('image_mode', 'rgb') macro_block_size = kwargs.get('macro_block_size', 2) ffmpeg_params = ['-crf', str(kwargs.get('crf', 18))] writer = imageio.get_writer( wfp, fps=fps, format=video_format, codec=codec, quality=quality, ffmpeg_params=ffmpeg_params, pixelformat=pixelformat, macro_block_size=macro_block_size ) n = len(images) for i in track(range(n), description='Writing', transient=True): if image_mode.lower() == 'bgr': writer.append_data(images[i][..., ::-1]) else: writer.append_data(images[i]) writer.close() def video2gif(video_fp, fps=30, size=256): if osp.exists(video_fp): d = osp.split(video_fp)[0] fn = prefix(osp.basename(video_fp)) palette_wfp = osp.join(d, 'palette.png') gif_wfp = osp.join(d, f'{fn}.gif') # generate the palette cmd = f'ffmpeg -i "{video_fp}" -vf "fps={fps},scale={size}:-1:flags=lanczos,palettegen" "{palette_wfp}" -y' exec_cmd(cmd) # use the palette to generate the gif cmd = f'ffmpeg -i "{video_fp}" -i "{palette_wfp}" -filter_complex "fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" "{gif_wfp}" -y' exec_cmd(cmd) return gif_wfp else: raise FileNotFoundError(f"video_fp: {video_fp} not exists!") def merge_audio_video(video_fp, audio_fp, wfp): if osp.exists(video_fp) and osp.exists(audio_fp): cmd = f'ffmpeg -i "{video_fp}" -i "{audio_fp}" -c:v copy -c:a aac 
"{wfp}" -y' exec_cmd(cmd) print(f'merge {video_fp} and {audio_fp} to {wfp}') else: print(f'video_fp: {video_fp} or audio_fp: {audio_fp} not exists!') def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)): mask_float = mask.astype(np.float32) / 255. background_color = np.array(background_color).reshape([1, 1, 3]) bg = np.ones_like(img) * background_color img = np.clip(mask_float * img + (1 - mask_float) * bg, 0, 255).astype(np.uint8) return img def concat_frames(driving_image_lst, source_image_lst, I_p_lst): # TODO: add more concat style, e.g., left-down corner driving out_lst = [] h, w, _ = I_p_lst[0].shape source_image_resized_lst = [cv2.resize(img, (w, h)) for img in source_image_lst] for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'): I_p = I_p_lst[idx] source_image_resized = source_image_resized_lst[idx] if len(source_image_lst) > 1 else source_image_resized_lst[0] if driving_image_lst is None: out = np.hstack((source_image_resized, I_p)) else: driving_image = driving_image_lst[idx] driving_image_resized = cv2.resize(driving_image, (w, h)) out = np.hstack((driving_image_resized, source_image_resized, I_p)) out_lst.append(out) return out_lst class VideoWriter: def __init__(self, **kwargs): self.fps = kwargs.get('fps', 30) self.wfp = kwargs.get('wfp', 'video.mp4') self.video_format = kwargs.get('format', 'mp4') self.codec = kwargs.get('codec', 'libx264') self.quality = kwargs.get('quality') self.pixelformat = kwargs.get('pixelformat', 'yuv420p') self.image_mode = kwargs.get('image_mode', 'rgb') self.ffmpeg_params = kwargs.get('ffmpeg_params') self.writer = imageio.get_writer( self.wfp, fps=self.fps, format=self.video_format, codec=self.codec, quality=self.quality, ffmpeg_params=self.ffmpeg_params, pixelformat=self.pixelformat ) def write(self, image): if self.image_mode.lower() == 'bgr': self.writer.append_data(image[..., ::-1]) else: self.writer.append_data(image) def close(self): if 
self.writer is not None: self.writer.close() def change_video_fps(input_file, output_file, fps=20, codec='libx264', crf=12): cmd = f'ffmpeg -i "{input_file}" -c:v {codec} -crf {crf} -r {fps} "{output_file}" -y' exec_cmd(cmd) def get_fps(filepath, default_fps=25): try: fps = cv2.VideoCapture(filepath).get(cv2.CAP_PROP_FPS) if fps in (0, None): fps = default_fps except Exception as e: log(e) fps = default_fps return fps def has_audio_stream(video_path: str) -> bool: """ Check if the video file contains an audio stream. :param video_path: Path to the video file :return: True if the video contains an audio stream, False otherwise """ if osp.isdir(video_path): return False cmd = [ 'ffprobe', '-v', 'error', '-select_streams', 'a', '-show_entries', 'stream=codec_type', '-of', 'default=noprint_wrappers=1:nokey=1', f'"{video_path}"' ] try: # result = subprocess.run(cmd, capture_output=True, text=True) result = exec_cmd(' '.join(cmd)) if result.returncode != 0: log(f"Error occurred while probing video: {result.stderr}") return False # Check if there is any output from ffprobe command return bool(result.stdout.strip()) except Exception as e: log( f"Error occurred while probing video: {video_path}, " "you may need to install ffprobe! 
(https://ffmpeg.org/download.html) " "Now set audio to false!", style="bold red" ) return False def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str): cmd = [ 'ffmpeg', '-y', '-i', f'"{silent_video_path}"', '-i', f'"{audio_video_path}"', '-map', '0:v', '-map', '1:a', '-c:v', 'copy', '-shortest', f'"{output_video_path}"' ] try: exec_cmd(' '.join(cmd)) log(f"Video with audio generated successfully: {output_video_path}") except subprocess.CalledProcessError as e: log(f"Error occurred: {e}") def bb_intersection_over_union(boxA, boxB): xA = max(boxA[0], boxB[0]) yA = max(boxA[1], boxB[1]) xB = min(boxA[2], boxB[2]) yB = min(boxA[3], boxB[3]) interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) iou = interArea / float(boxAArea + boxBArea - interArea) return iou ================================================ FILE: src/utils/viz.py ================================================ # coding: utf-8 import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) def viz_lmk(img_, vps, **kwargs): """可视化点""" lineType = kwargs.get("lineType", cv2.LINE_8) # cv2.LINE_AA img_for_viz = img_.copy() for pt in vps: cv2.circle( img_for_viz, (int(pt[0]), int(pt[1])), radius=kwargs.get("radius", 1), color=(0, 255, 0), thickness=kwargs.get("thickness", 1), lineType=lineType, ) return img_for_viz