Repository: KwaiVGI/3DTrajMaster
Branch: main
Commit: 0af2932ffe51
Files: 276
Total size: 6.9 MB
Directory structure:
gitextract_mycvvgja/
├── CogVideo/
│ ├── .github/
│ │ ├── ISSUE_TEMPLATE/
│ │ │ ├── bug_report.yaml
│ │ │ └── feature-request.yaml
│ │ └── PULL_REQUEST_TEMPLATE/
│ │ └── pr_template.md
│ ├── .gitignore
│ ├── LICENSE
│ ├── MODEL_LICENSE
│ ├── README.md
│ ├── README_ja.md
│ ├── README_zh.md
│ ├── download.sh
│ ├── finetune/
│ │ ├── README.md
│ │ ├── README_ja.md
│ │ ├── README_zh.md
│ │ ├── accelerate_config_machine_single.yaml
│ │ ├── accelerate_config_machine_single_debug.yaml
│ │ ├── finetune_single_rank_injector.sh
│ │ ├── finetune_single_rank_lora.sh
│ │ ├── hostfile.txt
│ │ ├── models/
│ │ │ ├── attention.py
│ │ │ ├── attention_processor.py
│ │ │ ├── cogvideox_transformer_3d.py
│ │ │ ├── embeddings.py
│ │ │ ├── pipeline_cogvideox.py
│ │ │ ├── pipeline_output.py
│ │ │ └── utils.py
│ │ ├── train_cogvideox_injector.py
│ │ └── train_cogvideox_lora.py
│ ├── inference/
│ │ ├── 3dtrajmaster_inference.py
│ │ ├── entity_zoo.txt
│ │ └── location_zoo.txt
│ ├── pyproject.toml
│ ├── requirements.txt
│ ├── tools/
│ │ ├── caption/
│ │ │ ├── README.md
│ │ │ ├── README_ja.md
│ │ │ ├── README_zh.md
│ │ │ ├── requirements.txt
│ │ │ └── video_caption.py
│ │ ├── convert_weight_sat2hf.py
│ │ ├── export_sat_lora_weight.py
│ │ ├── llm_flux_cogvideox/
│ │ │ ├── generate.sh
│ │ │ ├── gradio_page.py
│ │ │ └── llm_flux_cogvideox.py
│ │ ├── load_cogvideox_lora.py
│ │ ├── parallel_inference/
│ │ │ ├── parallel_inference_xdit.py
│ │ │ └── run.sh
│ │ ├── replicate/
│ │ │ ├── cog.yaml
│ │ │ ├── predict_i2v.py
│ │ │ └── predict_t2v.py
│ │ └── venhancer/
│ │ ├── README.md
│ │ ├── README_ja.md
│ │ └── README_zh.md
│ └── weights/
│ └── put weights here.txt
├── README.md
├── dataset/
│ ├── load_dataset.py
│ ├── traj_vis/
│ │ ├── D_loc1_61_t3n13_003d_Hemi12_1.json
│ │ ├── Hemi12_transforms.json
│ │ └── location_data_desert.json
│ ├── utils.py
│ └── vis_trajectory.py
└── eval/
├── GVHMR/
│ ├── .gitignore
│ ├── .gitmodules
│ ├── LICENSE
│ ├── README.md
│ ├── docs/
│ │ └── INSTALL.md
│ ├── download_eval_pose.sh
│ ├── eval.sh
│ ├── hmr4d/
│ │ ├── __init__.py
│ │ ├── build_gvhmr.py
│ │ ├── configs/
│ │ │ ├── __init__.py
│ │ │ ├── data/
│ │ │ │ └── mocap/
│ │ │ │ ├── testY.yaml
│ │ │ │ └── trainX_testY.yaml
│ │ │ ├── demo.yaml
│ │ │ ├── exp/
│ │ │ │ └── gvhmr/
│ │ │ │ └── mixed/
│ │ │ │ └── mixed.yaml
│ │ │ ├── global/
│ │ │ │ ├── debug/
│ │ │ │ │ ├── debug_train.yaml
│ │ │ │ │ └── debug_train_limit_data.yaml
│ │ │ │ └── task/
│ │ │ │ └── gvhmr/
│ │ │ │ ├── test_3dpw.yaml
│ │ │ │ ├── test_3dpw_emdb_rich.yaml
│ │ │ │ ├── test_emdb.yaml
│ │ │ │ └── test_rich.yaml
│ │ │ ├── hydra/
│ │ │ │ └── default.yaml
│ │ │ ├── siga24_release.yaml
│ │ │ ├── store_gvhmr.py
│ │ │ └── train.yaml
│ │ ├── datamodule/
│ │ │ └── mocap_trainX_testY.py
│ │ ├── dataset/
│ │ │ ├── bedlam/
│ │ │ │ ├── bedlam.py
│ │ │ │ ├── resource/
│ │ │ │ │ └── vname2lwh.pt
│ │ │ │ └── utils.py
│ │ │ ├── emdb/
│ │ │ │ ├── emdb_motion_test.py
│ │ │ │ └── utils.py
│ │ │ ├── h36m/
│ │ │ │ ├── camera-parameters.json
│ │ │ │ ├── h36m.py
│ │ │ │ └── utils.py
│ │ │ ├── imgfeat_motion/
│ │ │ │ └── base_dataset.py
│ │ │ ├── pure_motion/
│ │ │ │ ├── amass.py
│ │ │ │ ├── base_dataset.py
│ │ │ │ ├── cam_traj_utils.py
│ │ │ │ └── utils.py
│ │ │ ├── rich/
│ │ │ │ ├── resource/
│ │ │ │ │ ├── cam2params.pt
│ │ │ │ │ ├── seqname2imgrange.json
│ │ │ │ │ ├── test.txt
│ │ │ │ │ ├── train.txt
│ │ │ │ │ ├── val.txt
│ │ │ │ │ └── w2az_sahmr.json
│ │ │ │ ├── rich_motion_test.py
│ │ │ │ └── rich_utils.py
│ │ │ └── threedpw/
│ │ │ ├── threedpw_motion_test.py
│ │ │ ├── threedpw_motion_train.py
│ │ │ └── utils.py
│ │ ├── model/
│ │ │ ├── common_utils/
│ │ │ │ ├── optimizer.py
│ │ │ │ ├── scheduler.py
│ │ │ │ └── scheduler_cfg.py
│ │ │ └── gvhmr/
│ │ │ ├── callbacks/
│ │ │ │ ├── metric_3dpw.py
│ │ │ │ ├── metric_emdb.py
│ │ │ │ └── metric_rich.py
│ │ │ ├── gvhmr_pl.py
│ │ │ ├── gvhmr_pl_demo.py
│ │ │ ├── pipeline/
│ │ │ │ └── gvhmr_pipeline.py
│ │ │ └── utils/
│ │ │ ├── endecoder.py
│ │ │ ├── postprocess.py
│ │ │ └── stats_compose.py
│ │ ├── network/
│ │ │ ├── base_arch/
│ │ │ │ ├── embeddings/
│ │ │ │ │ └── rotary_embedding.py
│ │ │ │ └── transformer/
│ │ │ │ ├── encoder_rope.py
│ │ │ │ └── layer.py
│ │ │ ├── gvhmr/
│ │ │ │ └── relative_transformer.py
│ │ │ └── hmr2/
│ │ │ ├── __init__.py
│ │ │ ├── components/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── pose_transformer.py
│ │ │ │ └── t_cond_mlp.py
│ │ │ ├── configs/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── model_config.yaml
│ │ │ │ └── smpl_mean_params.npz
│ │ │ ├── hmr2.py
│ │ │ ├── smpl_head.py
│ │ │ ├── utils/
│ │ │ │ ├── geometry.py
│ │ │ │ ├── preproc.py
│ │ │ │ └── smpl_wrapper.py
│ │ │ └── vit.py
│ │ └── utils/
│ │ ├── body_model/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── body_model.py
│ │ │ ├── body_model_smplh.py
│ │ │ ├── body_model_smplx.py
│ │ │ ├── coco_aug_dict.pth
│ │ │ ├── min_lbs.py
│ │ │ ├── seg_part_info.npy
│ │ │ ├── smpl_3dpw14_J_regressor_sparse.pt
│ │ │ ├── smpl_coco17_J_regressor.pt
│ │ │ ├── smpl_lite.py
│ │ │ ├── smpl_neutral_J_regressor.pt
│ │ │ ├── smpl_vert_segmentation.json
│ │ │ ├── smplx2smpl_sparse.pt
│ │ │ ├── smplx_lite.py
│ │ │ ├── smplx_verts437.pt
│ │ │ └── utils.py
│ │ ├── callbacks/
│ │ │ ├── lr_monitor.py
│ │ │ ├── prog_bar.py
│ │ │ ├── simple_ckpt_saver.py
│ │ │ └── train_speed_timer.py
│ │ ├── comm/
│ │ │ └── gather.py
│ │ ├── eval/
│ │ │ └── eval_utils.py
│ │ ├── geo/
│ │ │ ├── augment_noisy_pose.py
│ │ │ ├── flip_utils.py
│ │ │ ├── hmr_cam.py
│ │ │ ├── hmr_global.py
│ │ │ ├── quaternion.py
│ │ │ └── transforms.py
│ │ ├── geo_transform.py
│ │ ├── ik/
│ │ │ └── ccd_ik.py
│ │ ├── kpts/
│ │ │ └── kp2d_utils.py
│ │ ├── matrix.py
│ │ ├── net_utils.py
│ │ ├── preproc/
│ │ │ ├── __init__.py
│ │ │ ├── slam.py
│ │ │ ├── tracker.py
│ │ │ ├── vitfeat_extractor.py
│ │ │ ├── vitpose.py
│ │ │ └── vitpose_pytorch/
│ │ │ ├── __init__.py
│ │ │ └── src/
│ │ │ └── vitpose_infer/
│ │ │ ├── __init__.py
│ │ │ ├── builder/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── backbones/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── alexnet.py
│ │ │ │ │ ├── cpm.py
│ │ │ │ │ ├── hourglass.py
│ │ │ │ │ ├── hourglass_ae.py
│ │ │ │ │ ├── hrformer.py
│ │ │ │ │ ├── litehrnet.py
│ │ │ │ │ ├── mobilenet_v2.py
│ │ │ │ │ ├── mobilenet_v3.py
│ │ │ │ │ ├── mspn.py
│ │ │ │ │ ├── regnet.py
│ │ │ │ │ ├── resnest.py
│ │ │ │ │ ├── resnext.py
│ │ │ │ │ ├── rsn.py
│ │ │ │ │ ├── scnet.py
│ │ │ │ │ ├── seresnet.py
│ │ │ │ │ ├── seresnext.py
│ │ │ │ │ ├── shufflenet_v1.py
│ │ │ │ │ ├── shufflenet_v2.py
│ │ │ │ │ ├── tcn.py
│ │ │ │ │ ├── test_torch.py
│ │ │ │ │ ├── utils/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── channel_shuffle.py
│ │ │ │ │ │ ├── inverted_residual.py
│ │ │ │ │ │ ├── make_divisible.py
│ │ │ │ │ │ ├── se_layer.py
│ │ │ │ │ │ └── utils.py
│ │ │ │ │ ├── vgg.py
│ │ │ │ │ ├── vipnas_mbv3.py
│ │ │ │ │ ├── vipnas_resnet.py
│ │ │ │ │ └── vit.py
│ │ │ │ ├── configs/
│ │ │ │ │ └── coco/
│ │ │ │ │ ├── ViTPose_base_coco_256x192.py
│ │ │ │ │ ├── ViTPose_base_simple_coco_256x192.py
│ │ │ │ │ ├── ViTPose_huge_coco_256x192.py
│ │ │ │ │ ├── ViTPose_huge_simple_coco_256x192.py
│ │ │ │ │ ├── ViTPose_large_coco_256x192.py
│ │ │ │ │ ├── ViTPose_large_simple_coco_256x192.py
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── heads/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── deconv_head.py
│ │ │ │ │ ├── deeppose_regression_head.py
│ │ │ │ │ ├── hmr_head.py
│ │ │ │ │ ├── interhand_3d_head.py
│ │ │ │ │ ├── temporal_regression_head.py
│ │ │ │ │ ├── topdown_heatmap_base_head.py
│ │ │ │ │ ├── topdown_heatmap_multi_stage_head.py
│ │ │ │ │ ├── topdown_heatmap_simple_head.py
│ │ │ │ │ ├── vipnas_heatmap_simple_head.py
│ │ │ │ │ └── voxelpose_head.py
│ │ │ │ └── model_builder.py
│ │ │ ├── model_builder.py
│ │ │ └── pose_utils/
│ │ │ ├── ViTPose_trt.py
│ │ │ ├── __init__.py
│ │ │ ├── convert_to_trt.py
│ │ │ ├── general_utils.py
│ │ │ ├── inference_test.py
│ │ │ ├── logger_helper.py
│ │ │ ├── pose_utils.py
│ │ │ ├── pose_viz.py
│ │ │ ├── timerr.py
│ │ │ └── visualizer.py
│ │ ├── pylogger.py
│ │ ├── seq_utils.py
│ │ ├── smplx_utils.py
│ │ ├── video_io_utils.py
│ │ ├── vis/
│ │ │ ├── README.md
│ │ │ ├── cv2_utils.py
│ │ │ ├── renderer.py
│ │ │ ├── renderer_tools.py
│ │ │ ├── renderer_utils.py
│ │ │ └── rich_logger.py
│ │ └── wis3d_utils.py
│ ├── pyproject.toml
│ ├── pyrightconfig.json
│ ├── requirements.txt
│ ├── setup.py
│ └── tools/
│ ├── demo/
│ │ ├── colab_demo.ipynb
│ │ ├── demo.py
│ │ └── demo_folder.py
│ ├── eval_pose.py
│ ├── train.py
│ ├── unitest/
│ │ ├── make_hydra_cfg.py
│ │ └── run_dataset.py
│ └── video/
│ ├── merge_folder.py
│ ├── merge_horizontal.py
│ └── merge_vertical.py
└── common_metrics_on_video_quality/
├── .gitignore
├── README.md
├── calculate_clip.py
├── calculate_fvd.py
├── calculate_fvd_styleganv.py
├── calculate_lpips.py
├── calculate_psnr.py
├── calculate_ssim.py
├── download_eval_visual.sh
├── eval_prompts.json
└── eval_visual.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: CogVideo/.github/ISSUE_TEMPLATE/bug_report.yaml
================================================
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve CogVideoX / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX 开源模型
body:
- type: textarea
id: system-info
attributes:
label: System Info / 系统信息
description: Your operating environment / 您的运行环境信息
placeholder: Includes Cuda version, Diffusers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Diffusers,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
validations:
required: true
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information / 问题信息
description: 'The problem arises when using: / 问题出现在'
options:
- label: "The official example scripts / 官方的示例脚本"
- label: "My own modified scripts / 我自己修改的脚本和任务"
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction / 复现过程
description: |
Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
If you have code snippets, error messages, stack traces, please provide them here as well.
Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
placeholder: |
Steps to reproduce the behavior/复现Bug的步骤:
1.
2.
3.
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior / 期待表现
description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
================================================
FILE: CogVideo/.github/ISSUE_TEMPLATE/feature-request.yaml
================================================
name: "\U0001F680 Feature request"
description: Submit a request for a new CogVideoX feature / 提交一个新的 CogVideoX开源模型的功能建议
labels: [ "feature" ]
body:
- type: textarea
id: feature-request
validations:
required: true
attributes:
label: Feature request / 功能建议
description: |
A brief description of the functional proposal. Links to corresponding papers and code are desirable.
对功能建议的简述。最好提供对应的论文和代码链接。
- type: textarea
id: motivation
validations:
required: true
attributes:
label: Motivation / 动机
description: |
Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
- type: textarea
id: contribution
validations:
required: true
attributes:
label: Your contribution / 您的贡献
description: |
Your PR link or any other link you can help with.
您的PR链接或者其他您能提供帮助的链接。
================================================
FILE: CogVideo/.github/PULL_REQUEST_TEMPLATE/pr_template.md
================================================
# Raise valuable PR / 提出有价值的PR
## Caution / 注意事项:
Users should keep the following points in mind when submitting PRs:
1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
2. The proposed PR should be focused; if there are multiple ideas and optimizations, they should be split into separate PRs.
用户在提交PR时候应该注意以下几点:
1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。
## 不应该提出的PR / PRs that should not be proposed
If a developer proposes a PR that falls into any of the following categories, it may be closed or rejected.
1. PRs that do not describe the proposed improvement.
2. multiple issues of different types combined in one PR.
3. The proposed PR is highly duplicative of already existing PRs.
如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。
1. 没有说明改进方案的。
2. 多个不同类型的问题合并在一个PR中的。
3. 提出的PR与已经存在的PR高度重复的。
# 检查您的PR
- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
- [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
- [ ] Are your PRs for only one issue / 您的PR是否仅针对一个问题
================================================
FILE: CogVideo/.gitignore
================================================
*__pycache__/
samples*/
runs/
checkpoints/
master_ip
logs/
*.DS_Store
.idea
output*
test*
================================================
FILE: CogVideo/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2024 CogVideo Model Team @ Zhipu AI
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: CogVideo/MODEL_LICENSE
================================================
The CogVideoX License
1. Definitions
“Licensor” means the CogVideoX Model Team that distributes its Software.
“Software” means the CogVideoX model parameters made available under this license.
2. License Grant
Under the terms and conditions of this license, the licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. The intellectual property rights of the generated content belong to the user to the extent permitted by applicable local laws.
This license allows you to freely use all open-source models in this repository for academic research. Users who wish to use the models for commercial purposes must register and obtain a basic commercial license in https://open.bigmodel.cn/mla/form .
Users who have registered and obtained the basic commercial license can use the models for commercial activities for free, but must comply with all terms and conditions of this license. Additionally, the number of service users (visits) for your commercial activities must not exceed 1 million visits per month.
If the number of service users (visits) for your commercial activities exceeds 1 million visits per month, you need to contact our business team to obtain more commercial licenses.
The above copyright statement and this license statement should be included in all copies or significant portions of this software.
3. Restriction
You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.
You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
4. Disclaimer
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
5. Limitation of Liability
EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
6. Dispute Resolution
This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn.
1. 定义
“许可方”是指分发其软件的 CogVideoX 模型团队。
“软件”是指根据本许可提供的 CogVideoX 模型参数。
2. 许可授予
根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。生成内容的知识产权所属,可根据适用当地法律的规定,在法律允许的范围内由用户享有生成内容的知识产权或其他权利。
本许可允许您免费使用本仓库中的所有开源模型进行学术研究。对于希望将模型用于商业目的的用户,需在 https://open.bigmodel.cn/mla/form 完成登记并获得基础商用授权。
经过登记并获得基础商用授权的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。
在本许可证下,您的商业活动的服务用户数量(访问量)不得超过100万人次访问 / 每月。如果超过,您需要与我们的商业团队联系以获得更多的商业许可。
上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
3.限制
您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。
4.免责声明
本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。
在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。
5. 责任限制
除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
6.争议解决
本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。
================================================
FILE: CogVideo/README.md
================================================
# CogVideo & CogVideoX
[中文阅读](./README_zh.md)
[日本語で読む](./README_ja.md)
📍 Visit QingYing and API Platform to experience larger-scale commercial video generation models.
## Project Updates
- 🔥🔥 **News**: ```2024/10/13```: A more cost-effective fine-tuning framework for `CogVideoX-5B` that works with a single
4090 GPU, [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory), has been released. It supports
fine-tuning with multiple resolutions. Feel free to use it!
- 🔥 **News**: ```2024/10/10```: We have updated our technical report. Please
click [here](https://arxiv.org/pdf/2408.06072) to view it. More training details and a demo have been added. To see
the demo, click [here](https://yzy-thu.github.io/CogVideoX-demo/).
- 🔥 **News**: ```2024/10/09```: We have publicly
released the [technical documentation](https://zhipu-ai.feishu.cn/wiki/DHCjw1TrJiTyeukfc9RceoSRnCh) for CogVideoX
fine-tuning on Feishu, further increasing distribution flexibility. All examples in the public documentation can be
fully reproduced.
- 🔥 **News**: ```2024/9/19```: We have open-sourced the CogVideoX series image-to-video model **CogVideoX-5B-I2V**.
This model can take an image as a background input and generate a video combined with prompt words, offering greater
controllability. With this, the CogVideoX series models now support three tasks: text-to-video generation, video
continuation, and image-to-video generation. Welcome to try it online
at [Experience](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space).
- 🔥 ```2024/9/19```: The Caption
model [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), used in the training process of
CogVideoX to convert video data into text descriptions, has been open-sourced. Welcome to download and use it.
- 🔥 ```2024/8/27```: We have open-sourced a larger model in the CogVideoX series, **CogVideoX-5B**. We have
significantly optimized the model's inference performance, greatly lowering the inference threshold. You can run
**CogVideoX-2B** on older GPUs like `GTX 1080TI`, and **CogVideoX-5B** on desktop GPUs like `RTX 3060`. Please strictly
follow the [requirements](requirements.txt) to update and install dependencies, and refer
to [cli_demo](inference/cli_demo.py) for inference code. Additionally, the open-source license for the
**CogVideoX-2B** model has been changed to the **Apache 2.0 License**.
- 🔥 ```2024/8/6```: We have open-sourced **3D Causal VAE**, used for **CogVideoX-2B**, which can reconstruct videos with
almost no loss.
- 🔥 ```2024/8/6```: We have open-sourced the first model of the CogVideoX series video generation models,
**CogVideoX-2B**.
- 🌱 **Source**: ```2022/5/19```: We have open-sourced the CogVideo video generation model (now you can see it in
the `CogVideo` branch). This is the first open-source large Transformer-based text-to-video generation model. You can
access the [ICLR'23 paper](https://arxiv.org/abs/2205.15868) for technical details.
## Table of Contents
Jump to a specific section:
- [Quick Start](#Quick-Start)
- [SAT](#sat)
- [Diffusers](#Diffusers)
- [CogVideoX-2B Video Works](#cogvideox-2b-gallery)
- [Introduction to the CogVideoX Model](#Model-Introduction)
- [Full Project Structure](#project-structure)
- [Inference](#inference)
- [SAT](#sat)
- [Tools](#tools)
- [Introduction to CogVideo(ICLR'23) Model](#cogvideoiclr23)
- [Citations](#Citation)
- [Open Source Project Plan](#Open-Source-Project-Plan)
- [Model License](#Model-License)
## Quick Start
### Prompt Optimization
Before running the model, please refer to [this guide](inference/convert_demo.py) to see how we use large models like
GLM-4 (or other comparable products, such as GPT-4) to optimize the prompt. This is crucial because the model is trained
with long prompts, and a good prompt directly impacts the quality of the video generation.
### SAT
**Please make sure your Python version is between 3.10 and 3.12, inclusive of both 3.10 and 3.12.**
Follow the instructions in [sat_demo](sat/README.md): it contains the inference code and fine-tuning code for SAT weights.
We recommend building improvements on top of the CogVideoX model structure; researchers can use this code for
rapid prototyping and development.
### Diffusers
**Please make sure your Python version is between 3.10 and 3.12, inclusive of both 3.10 and 3.12.**
```
pip install -r requirements.txt
```
Then follow [diffusers_demo](inference/cli_demo.py): A more detailed explanation of the inference code, mentioning the
significance of common parameters.
For more details on quantized inference, please refer
to [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao/). With Diffusers and TorchAO, quantized inference
is also possible leading to memory-efficient inference as well as speedup in some cases when compiled. A full list of
memory and time benchmarks with various settings on A100 and H100 has been published
at [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao).
## Gallery
### CogVideoX-5B
### CogVideoX-2B
To view the corresponding prompt words for the gallery, please click [here](resources/galary_prompt.md)
## Model Introduction
CogVideoX is an open-source version of the video generation model originating
from [QingYing](https://chatglm.cn/video?lang=en&fr=osm_cogvideo). The table below displays the list of video generation
models we currently offer, along with their foundational information.
Model Name
CogVideoX-2B
CogVideoX-5B
CogVideoX-5B-I2V
Model Description
Entry-level model, balancing compatibility. Low cost for running and secondary development.
Larger model with higher video generation quality and better visual effects.
CogVideoX-5B image-to-video version.
Inference Precision
FP16*(recommended), BF16, FP32, FP8*, INT8, not supported: INT4
BF16 (recommended), FP16, FP32, FP8*, INT8, not supported: INT4
Single GPU Memory Usage
SAT FP16: 18GB diffusers FP16: from 4GB* diffusers INT8 (torchao): from 3.6GB*
SAT BF16: 26GB diffusers BF16: from 5GB* diffusers INT8 (torchao): from 4.4GB*
Multi-GPU Inference Memory Usage
FP16: 10GB* using diffusers
BF16: 15GB* using diffusers
Inference Speed (Step = 50, FP/BF16)
Single A100: ~90 seconds Single H100: ~45 seconds
Single A100: ~180 seconds Single H100: ~90 seconds
**Data Explanation**
+ While testing using the diffusers library, all optimizations included in the diffusers library were enabled. This
scheme has not been tested for actual memory usage on devices outside of **NVIDIA A100 / H100** architectures.
Generally, this scheme can be adapted to all **NVIDIA Ampere architecture** and above devices. If optimizations are
disabled, memory consumption will multiply, with peak memory usage being about 3 times the value in the table.
However, speed will increase by about 3-4 times. You can selectively disable some optimizations, including:
```
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
```
+ For multi-GPU inference, the `enable_sequential_cpu_offload()` optimization needs to be disabled.
+ Using INT8 models significantly slows down inference. This trade-off accommodates lower-memory GPUs while keeping
video quality loss minimal.
+ The CogVideoX-2B model was trained in `FP16` precision, and all CogVideoX-5B models were trained in `BF16` precision.
We recommend using the precision in which the model was trained for inference.
+ [PytorchAO](https://github.com/pytorch/ao) and [Optimum-quanto](https://github.com/huggingface/optimum-quanto/) can be
used to quantize the text encoder, transformer, and VAE modules to reduce the memory requirements of CogVideoX. This
allows the model to run on free T4 Colabs or GPUs with smaller memory! Also, note that TorchAO quantization is fully
compatible with `torch.compile`, which can significantly improve inference speed. FP8 precision must be used on
devices with NVIDIA H100 and above, requiring source installation of `torch`, `torchao`, `diffusers`, and `accelerate`
Python packages. CUDA 12.4 is recommended.
+ The inference speed tests also used the above memory optimization scheme. Without memory optimization, inference speed
increases by about 10%. Only the `diffusers` version of the model supports quantization.
+ The model only supports English input; other languages can be translated into English for use via large model
refinement.
+ The memory usage of model fine-tuning is tested in an `8 * H100` environment, and the program automatically
uses `Zero 2` optimization. If a specific number of GPUs is marked in the table, that number or more GPUs must be used
for fine-tuning.
## Friendly Links
We highly welcome contributions from the community and actively contribute to the open-source community. The following
works have already been adapted for CogVideoX, and we invite everyone to use them:
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the
CogVideoX architecture, supporting flexible resolutions and multiple launch methods.
+ [CogStudio](https://github.com/pinokiofactory/cogstudio): A separate repository for CogVideo's Gradio Web UI, which
supports more functional Web UIs.
+ [Xorbits Inference](https://github.com/xorbitsai/inference): A powerful and comprehensive distributed inference
framework, allowing you to easily deploy your own models or the latest cutting-edge open-source models with just one
click.
+ [ComfyUI-CogVideoXWrapper](https://github.com/kijai/ComfyUI-CogVideoXWrapper) Use the ComfyUI framework to integrate
CogVideoX into your workflow.
+ [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys): VideoSys provides a user-friendly, high-performance
infrastructure for video generation, with full pipeline support and continuous integration of the latest models and
techniques.
+ [AutoDL Space](https://www.codewithgpu.com/i/THUDM/CogVideo/CogVideoX-5b-demo): A one-click deployment Huggingface
Space image provided by community members.
+ [Interior Design Fine-Tuning Model](https://huggingface.co/collections/bertjiazheng/koolcogvideox-66e4762f53287b7f39f8f3ba):
is a fine-tuned model based on CogVideoX, specifically designed for interior design.
+ [xDiT](https://github.com/xdit-project/xDiT): xDiT is a scalable inference engine for Diffusion Transformers (DiTs)
on multiple GPU Clusters. xDiT supports real-time image and video generations services.
+ [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory): A cost-effective
fine-tuning framework for CogVideoX, compatible with the `diffusers` version model. Supports more resolutions, and
fine-tuning CogVideoX-5B can be done with a single 4090 GPU.
+ [CogVideoX-Interpolation](https://github.com/feizc/CogvideX-Interpolation): A pipeline based on the modified CogVideoX
structure, aimed at providing greater flexibility for keyframe interpolation generation.
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): DiffSynth Studio is a diffusion engine. It has
restructured the architecture, including text encoders, UNet, VAE, etc., enhancing computational performance while
maintaining compatibility with open-source community models. The framework has been adapted for CogVideoX.
## Project Structure
This open-source repository will guide developers to quickly get started with the basic usage and fine-tuning examples
of the **CogVideoX** open-source model.
### Quick Start with Colab
Here are four projects that can be run directly on free Colab T4 instances:
+ [CogVideoX-5B-T2V-Colab.ipynb](https://colab.research.google.com/drive/1pCe5s0bC_xuXbBlpvIH1z0kfdTLQPzCS?usp=sharing):
CogVideoX-5B Text-to-Video Colab code.
+ [CogVideoX-5B-T2V-Int8-Colab.ipynb](https://colab.research.google.com/drive/1DUffhcjrU-uz7_cpuJO3E_D4BaJT7OPa?usp=sharing):
CogVideoX-5B Quantized Text-to-Video Inference Colab code, which takes about 30 minutes per run.
+ [CogVideoX-5B-I2V-Colab.ipynb](https://colab.research.google.com/drive/17CqYCqSwz39nZAX2YyonDxosVKUZGzcX?usp=sharing):
CogVideoX-5B Image-to-Video Colab code.
+ [CogVideoX-5B-V2V-Colab.ipynb](https://colab.research.google.com/drive/1comfGAUJnChl5NwPuO8Ox5_6WCy4kbNN?usp=sharing):
CogVideoX-5B Video-to-Video Colab code.
### Inference
+ [cli_demo](inference/cli_demo.py): A more detailed inference code explanation, including the significance of
common parameters. All of this is covered here.
+ [cli_demo_quantization](inference/cli_demo_quantization.py):
Quantized model inference code that can run on devices with lower memory. You can also modify this code to support
running CogVideoX models in FP8 precision.
+ [diffusers_vae_demo](inference/cli_vae_demo.py): Code for running VAE inference separately.
+ [space demo](inference/gradio_composite_demo): The same GUI code as used in the Huggingface Space, with frame
interpolation and super-resolution tools integrated.
+ [convert_demo](inference/convert_demo.py): How to convert user input into long-form input suitable for CogVideoX.
Since CogVideoX is trained on long texts, we need to transform the input text distribution to match the training data
using an LLM. The script defaults to using GLM-4, but it can be replaced with GPT, Gemini, or any other large language
model.
+ [gradio_web_demo](inference/gradio_composite_demo): A simple Gradio web application demonstrating how to use the
CogVideoX-2B / 5B model to generate videos. Similar to our Huggingface Space, you can use this script to run a simple
web application for video generation.
### finetune
+ [finetune_demo](finetune/README.md): Fine-tuning scheme and details of the diffusers version of the CogVideoX model.
### sat
+ [sat_demo](sat/README.md): Contains the inference code and fine-tuning code for the SAT weights. It is the recommended
starting point for building on the CogVideoX model structure, and researchers can use this code for rapid prototyping
and development.
### Tools
This folder contains some tools for model conversion / caption generation, etc.
+ [convert_weight_sat2hf](tools/convert_weight_sat2hf.py): Converts SAT model weights to Huggingface model weights.
+ [caption_demo](tools/caption/README.md): Caption tool, a model that understands videos and outputs descriptions in
text.
+ [export_sat_lora_weight](tools/export_sat_lora_weight.py): SAT fine-tuning model export tool, exports the SAT Lora
Adapter in diffusers format.
+ [load_cogvideox_lora](tools/load_cogvideox_lora.py): Tool code for loading the diffusers version of fine-tuned Lora
Adapter.
+ [llm_flux_cogvideox](tools/llm_flux_cogvideox/llm_flux_cogvideox.py): Automatically generate videos using an
open-source local large language model + Flux + CogVideoX.
+ [parallel_inference_xdit](tools/parallel_inference/parallel_inference_xdit.py):
Supported by [xDiT](https://github.com/xdit-project/xDiT), parallelize the
video generation process on multiple GPUs.
## CogVideo(ICLR'23)
The official repo for the
paper: [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](https://arxiv.org/abs/2205.15868)
is on the [CogVideo branch](https://github.com/THUDM/CogVideo/tree/CogVideo)
**CogVideo is able to generate relatively high-frame-rate videos.**
A 4-second clip of 32 frames is shown below.


The demo for CogVideo is at [https://models.aminer.cn/cogvideo](https://models.aminer.cn/cogvideo/), where you can get
hands-on practice on text-to-video generation. *The original input is in Chinese.*
## Citation
🌟 If you find our work helpful, please leave us a star and cite our paper.
```
@article{yang2024cogvideox,
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
journal={arXiv preprint arXiv:2408.06072},
year={2024}
}
@article{hong2022cogvideo,
title={CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers},
author={Hong, Wenyi and Ding, Ming and Zheng, Wendi and Liu, Xinghan and Tang, Jie},
journal={arXiv preprint arXiv:2205.15868},
year={2022}
}
```
We welcome your contributions! You can click [here](resources/contribute.md) for more information.
## License Agreement
The code in this repository is released under the [Apache 2.0 License](LICENSE).
The CogVideoX-2B model (including its corresponding Transformers module and VAE module) is released under
the [Apache 2.0 License](LICENSE).
The CogVideoX-5B model (Transformers module, include I2V and T2V) is released under
the [CogVideoX LICENSE](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE).
================================================
FILE: CogVideo/README_ja.md
================================================
# CogVideo & CogVideoX
[Read this in English](./README.md)
[中文阅读](./README_zh.md)
CogVideo的demo网站在[https://models.aminer.cn/cogvideo](https://models.aminer.cn/cogvideo/)。您可以在这里体验文本到视频生成。
*原始输入为中文。*
## 引用
🌟 如果您发现我们的工作有所帮助,欢迎引用我们的文章,留下宝贵的stars
```
@article{yang2024cogvideox,
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
journal={arXiv preprint arXiv:2408.06072},
year={2024}
}
@article{hong2022cogvideo,
title={CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers},
author={Hong, Wenyi and Ding, Ming and Zheng, Wendi and Liu, Xinghan and Tang, Jie},
journal={arXiv preprint arXiv:2205.15868},
year={2022}
}
```
我们欢迎您的贡献,您可以点击[这里](resources/contribute_zh.md)查看更多信息。
## 模型协议
本仓库代码使用 [Apache 2.0 协议](LICENSE) 发布。
CogVideoX-2B 模型 (包括其对应的Transformers模块,VAE模块) 根据 [Apache 2.0 协议](LICENSE) 许可证发布。
CogVideoX-5B 模型 (Transformers 模块,包括图生视频,文生视频版本)
根据 [CogVideoX LICENSE](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE)
许可证发布。
================================================
FILE: CogVideo/download.sh
================================================
#!/bin/bash
# Fetches vae.zip and transformer.zip into ./CogVideoX-2b-sat and unpacks them there.
#
# Stop at the first failing command so a failed `cd` or download does not lead to
# unpacking in the wrong directory or unpacking a missing/partial archive.
set -eu

# -p: do not fail when the directory already exists (e.g. when re-running the script).
mkdir -p CogVideoX-2b-sat
cd CogVideoX-2b-sat

# The URLs contain `?`, which is a glob character, so they are quoted.
# -O names the output file directly instead of relying on wget's derived
# `index.html?dl=1` name and renaming it afterwards.
wget -O vae.zip "https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1"
unzip vae.zip

wget -O transformer.zip "https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1"
unzip transformer.zip
================================================
FILE: CogVideo/finetune/README.md
================================================
# CogVideoX diffusers Fine-tuning Guide
[中文阅读](./README_zh.md)
[日本語で読む](./README_ja.md)
This feature is not fully complete yet. If you want to check the fine-tuning for the SAT version, please
see [here](../sat/README.md). Its dataset format is different from the one used by this version.
## Hardware Requirements
+ CogVideoX-2B / 5B LoRA: 1 * A100 (5B needs `--use_8bit_adam`)
+ CogVideoX-2B SFT: 8 * A100 (work in progress)
+ CogVideoX-5B-I2V is not supported yet.
## Install Dependencies
Since the related code has not yet been merged into a diffusers release, you need to install diffusers from source
(the `main` branch) for fine-tuning. Please follow the steps below to install the dependencies:
```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers # Now in Main branch
pip install -e .
```
## Prepare the Dataset
First, you need to prepare the dataset. The dataset format should be as follows, with `videos.txt` containing the list
of videos in the `videos` directory:
```
.
├── prompts.txt
├── videos
└── videos.txt
```
You can download
the [Disney Steamboat Willie](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset) dataset from
here.
This video fine-tuning dataset is used as a test for fine-tuning.
## Configuration Files and Execution
The `accelerate` configuration files are as follows:
+ `accelerate_config_machine_multi.yaml`: Suitable for multi-GPU use
+ `accelerate_config_machine_single.yaml`: Suitable for single-GPU use
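The configuration for the `finetune` script is as follows. The inline `#` comments are explanatory only; remove them
before running the command, because text after a trailing `\` breaks the shell line continuation: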
```
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # Use accelerate to launch multi-GPU training with the config file accelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # Training script train_cogvideox_lora.py for LoRA fine-tuning on CogVideoX model
--gradient_checkpointing \ # Enable gradient checkpointing to reduce memory usage
--pretrained_model_name_or_path $MODEL_PATH \ # Path to the pretrained model, specified by $MODEL_PATH
--cache_dir $CACHE_PATH \ # Cache directory for model files, specified by $CACHE_PATH
--enable_tiling \ # Enable tiling technique to process videos in chunks, saving memory
--enable_slicing \ # Enable slicing to further optimize memory by slicing inputs
--instance_data_root $DATASET_PATH \ # Dataset path specified by $DATASET_PATH
--caption_column prompts.txt \ # Specify the file prompts.txt for video descriptions used in training
--video_column videos.txt \ # Specify the file videos.txt for video paths used in training
--validation_prompt "" \ # Prompt used for generating validation videos during training
--validation_prompt_separator ::: \ # Set ::: as the separator for validation prompts
--num_validation_videos 1 \ # Generate 1 validation video per validation round
--validation_epochs 100 \ # Perform validation every 100 training epochs
--seed 42 \ # Set random seed to 42 for reproducibility
--rank 128 \ # Set the rank for LoRA parameters to 128
--lora_alpha 64 \ # Set the alpha parameter for LoRA to 64, adjusting LoRA learning rate
--mixed_precision bf16 \ # Use bf16 mixed precision for training to save memory
--output_dir $OUTPUT_PATH \ # Specify the output directory for the model, defined by $OUTPUT_PATH
--height 480 \ # Set video height to 480 pixels
--width 720 \ # Set video width to 720 pixels
--fps 8 \ # Set video frame rate to 8 frames per second
--max_num_frames 49 \ # Set the maximum number of frames per video to 49
--skip_frames_start 0 \ # Skip 0 frames at the start of the video
--skip_frames_end 0 \ # Skip 0 frames at the end of the video
--train_batch_size 4 \ # Set training batch size to 4
--num_train_epochs 30 \ # Total number of training epochs set to 30
--checkpointing_steps 1000 \ # Save model checkpoint every 1000 steps
--gradient_accumulation_steps 1 \ # Accumulate gradients for 1 step, updating after each batch
--learning_rate 1e-3 \ # Set learning rate to 0.001
--lr_scheduler cosine_with_restarts \ # Use cosine learning rate scheduler with restarts
--lr_warmup_steps 200 \ # Warm up the learning rate for the first 200 steps
--lr_num_cycles 1 \ # Set the number of learning rate cycles to 1
--optimizer AdamW \ # Use the AdamW optimizer
--adam_beta1 0.9 \ # Set Adam optimizer beta1 parameter to 0.9
--adam_beta2 0.95 \ # Set Adam optimizer beta2 parameter to 0.95
--max_grad_norm 1.0 \ # Set maximum gradient clipping value to 1.0
--allow_tf32 \ # Enable TF32 to speed up training
--report_to wandb # Use Weights and Biases (wandb) for logging and monitoring the training
```
## Running the Script to Start Fine-tuning
Single Node (One GPU or Multi GPU) fine-tuning:
```shell
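bash finetune_single_rank_lora.sh      # LoRA fine-tuning (train_cogvideox_lora.py)
bash finetune_single_rank_injector.sh  # injector fine-tuning (train_cogvideox_injector.py), uses the trained LoRA weights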
```
Multi-Node fine-tuning:
```shell
bash finetune_multi_rank.sh # Needs to be run on each node
```
## Loading the Fine-tuned Model
+ Please refer to [cli_demo.py](../inference/cli_demo.py) for how to load the fine-tuned model.
## Best Practices
+ Includes 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width). By skipping frames in
the data preprocessing, we created two smaller datasets with 49 and 16 frames to speed up experimentation, as the
maximum frame limit recommended by the CogVideoX team is 49 frames. We split the 70 videos into three groups of 10,
25, and 50 videos, with similar conceptual nature.
+ Using 25 or more videos works best when training new concepts and styles.
+ It works better to train using identifier tokens specified with `--id_token`. This is similar to Dreambooth training,
but regular fine-tuning without such tokens also works.
+ The original repository used `lora_alpha` set to 1. We found this value ineffective across multiple runs, likely due
to differences in the backend and training setup. Our recommendation is to set `lora_alpha` equal to rank or rank //
2.
+ We recommend using a rank of 64 or higher.
================================================
FILE: CogVideo/finetune/README_ja.md
================================================
# CogVideoX diffusers 微調整方法
[Read this in English.](./README.md)
[中文阅读](./README_zh.md)
この機能はまだ完全に完成していません。SATバージョンの微調整を確認したい場合は、[こちら](../sat/README_ja.md)を参照してください。本バージョンとは異なるデータセット形式を使用しています。
## ハードウェア要件
+ CogVideoX-2B / 5B T2V LORA: 1 * A100 (5B need to use `--use_8bit_adam`)
+ CogVideoX-2B SFT: 8 * A100 (動作確認済み)
+ CogVideoX-5B-I2V まだサポートしていません
## 依存関係のインストール
関連コードはまだdiffusersのリリース版に統合されていないため、diffusersブランチを使用して微調整を行う必要があります。以下の手順に従って依存関係をインストールしてください:
```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers # Now in Main branch
pip install -e .
```
## データセットの準備
まず、データセットを準備する必要があります。データセットの形式は以下のようになります。
```
.
├── prompts.txt
├── videos
└── videos.txt
```
[ディズニースチームボートウィリー](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)をここからダウンロードできます。
ビデオ微調整データセットはテスト用として使用されます。
## 設定ファイルと実行
`accelerate` 設定ファイルは以下の通りです:
+ accelerate_config_machine_multi.yaml 複数GPU向け
+ accelerate_config_machine_single.yaml 単一GPU向け
`finetune` スクリプト設定ファイルの例:
```
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # accelerateを使用してmulti-GPUトレーニングを起動、設定ファイルはaccelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # LoRAの微調整用のトレーニングスクリプトtrain_cogvideox_lora.pyを実行
--gradient_checkpointing \ # メモリ使用量を減らすためにgradient checkpointingを有効化
--pretrained_model_name_or_path $MODEL_PATH \ # 事前学習済みモデルのパスを$MODEL_PATHで指定
--cache_dir $CACHE_PATH \ # モデルファイルのキャッシュディレクトリを$CACHE_PATHで指定
--enable_tiling \ # メモリ節約のためにタイル処理を有効化し、動画をチャンク分けして処理
--enable_slicing \ # 入力をスライスしてさらにメモリ最適化
--instance_data_root $DATASET_PATH \ # データセットのパスを$DATASET_PATHで指定
--caption_column prompts.txt \ # トレーニングで使用する動画の説明ファイルをprompts.txtで指定
--video_column videos.txt \ # トレーニングで使用する動画のパスファイルをvideos.txtで指定
--validation_prompt "" \ # トレーニング中に検証用の動画を生成する際のプロンプト
--validation_prompt_separator ::: \ # 検証プロンプトの区切り文字を:::に設定
--num_validation_videos 1 \ # 各検証ラウンドで1本の動画を生成
--validation_epochs 100 \ # 100エポックごとに検証を実施
--seed 42 \ # 再現性を保証するためにランダムシードを42に設定
--rank 128 \ # LoRAのパラメータのランクを128に設定
--lora_alpha 64 \ # LoRAのalphaパラメータを64に設定し、LoRAの学習率を調整
--mixed_precision bf16 \ # bf16混合精度でトレーニングし、メモリを節約
--output_dir $OUTPUT_PATH \ # モデルの出力ディレクトリを$OUTPUT_PATHで指定
--height 480 \ # 動画の高さを480ピクセルに設定
--width 720 \ # 動画の幅を720ピクセルに設定
--fps 8 \ # 動画のフレームレートを1秒あたり8フレームに設定
--max_num_frames 49 \ # 各動画の最大フレーム数を49に設定
--skip_frames_start 0 \ # 動画の最初のフレームを0スキップ
--skip_frames_end 0 \ # 動画の最後のフレームを0スキップ
--train_batch_size 4 \ # トレーニングのバッチサイズを4に設定
--num_train_epochs 30 \ # 総トレーニングエポック数を30に設定
--checkpointing_steps 1000 \ # 1000ステップごとにモデルのチェックポイントを保存
--gradient_accumulation_steps 1 \ # 1ステップの勾配累積を行い、各バッチ後に更新
--learning_rate 1e-3 \ # 学習率を0.001に設定
--lr_scheduler cosine_with_restarts \ # リスタート付きのコサイン学習率スケジューラを使用
--lr_warmup_steps 200 \ # トレーニングの最初の200ステップで学習率をウォームアップ
--lr_num_cycles 1 \ # 学習率のサイクル数を1に設定
--optimizer AdamW \ # AdamWオプティマイザーを使用
--adam_beta1 0.9 \ # Adamオプティマイザーのbeta1パラメータを0.9に設定
--adam_beta2 0.95 \ # Adamオプティマイザーのbeta2パラメータを0.95に設定
--max_grad_norm 1.0 \ # 勾配クリッピングの最大値を1.0に設定
--allow_tf32 \ # トレーニングを高速化するためにTF32を有効化
--report_to wandb # Weights and Biasesを使用してトレーニングの記録とモニタリングを行う
```
## 微調整を開始
単一マシン (シングルGPU、マルチGPU) での微調整:
```shell
bash finetune_single_rank_lora.sh      # LoRA
bash finetune_single_rank_injector.sh  # injector
```
複数マシン・マルチGPUでの微調整:
```shell
bash finetune_multi_rank.sh # 各ノードで実行する必要があります。
```
## 微調整済みモデルのロード
+ 微調整済みのモデルをロードする方法については、[cli_demo.py](../inference/cli_demo.py) を参照してください。
## ベストプラクティス
+ 解像度が `200 x 480 x 720`(フレーム数 x 高さ x 幅)のトレーニングビデオが70本含まれています。データ前処理でフレームをスキップすることで、49フレームと16フレームの小さなデータセットを作成しました。これは実験を加速するためのもので、CogVideoXチームが推奨する最大フレーム数制限は49フレームです。
+ 25本以上のビデオが新しい概念やスタイルのトレーニングに最適です。
+ 現在、`--id_token` を指定して識別トークンを使用してトレーニングする方が効果的です。これはDreamboothトレーニングに似ていますが、通常の微調整でも機能します。
+ 元のリポジトリでは `lora_alpha` を1に設定していましたが、複数の実行でこの値が効果的でないことがわかりました。モデルのバックエンドやトレーニング設定によるかもしれません。私たちの提案は、lora_alphaをrankと同じか、rank // 2に設定することです。
+ Rank 64以上の設定を推奨します。
================================================
FILE: CogVideo/finetune/README_zh.md
================================================
# CogVideoX diffusers 微调方案
[Read this in English](./README.md)
[日本語で読む](./README_ja.md)
本功能尚未完全完善,如果您想查看SAT版本微调,请查看[这里](../sat/README_zh.md)。其数据集格式与本版本不同。
## 硬件要求
+ CogVideoX-2B / 5B T2V LORA: 1 * A100 (5B need to use `--use_8bit_adam`)
+ CogVideoX-2B SFT: 8 * A100 (制作中)
+ CogVideoX-5B-I2V 暂未支持
## 安装依赖
由于相关代码还没有被合并到diffusers发行版,你需要基于diffusers分支进行微调。请按照以下步骤安装依赖:
```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers # Now in Main branch
pip install -e .
```
## 准备数据集
首先,你需要准备数据集,数据集格式如下,其中,videos.txt 存放 videos 中的视频。
```
.
├── prompts.txt
├── videos
└── videos.txt
```
你可以从这里下载 [迪士尼汽船威利号](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)
视频微调数据集作为测试微调。
## 配置文件和运行
`accelerate` 配置文件如下:
+ accelerate_config_machine_multi.yaml 适合多GPU使用
+ accelerate_config_machine_single.yaml 适合单GPU使用
`finetune` 脚本配置文件如下:
```shell
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # 使用 accelerate 启动多GPU训练,配置文件为 accelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # 运行的训练脚本为 train_cogvideox_lora.py,用于在 CogVideoX 模型上进行 LoRA 微调
--gradient_checkpointing \ # 启用梯度检查点功能,以减少显存使用
--pretrained_model_name_or_path $MODEL_PATH \ # 预训练模型路径,通过 $MODEL_PATH 指定
--cache_dir $CACHE_PATH \ # 模型缓存路径,由 $CACHE_PATH 指定
--enable_tiling \ # 启用tiling技术,以分片处理视频,节省显存
--enable_slicing \ # 启用slicing技术,将输入切片,以进一步优化内存
--instance_data_root $DATASET_PATH \ # 数据集路径,由 $DATASET_PATH 指定
--caption_column prompts.txt \ # 指定用于训练的视频描述文件,文件名为 prompts.txt
--video_column videos.txt \ # 指定用于训练的视频路径文件,文件名为 videos.txt
--validation_prompt "" \ # 验证集的提示语 (prompt),用于在训练期间生成验证视频
--validation_prompt_separator ::: \ # 设置验证提示语的分隔符为 :::
--num_validation_videos 1 \ # 每个验证回合生成 1 个视频
--validation_epochs 100 \ # 每 100 个训练epoch进行一次验证
--seed 42 \ # 设置随机种子为 42,以保证结果的可复现性
--rank 128 \ # 设置 LoRA 参数的秩 (rank) 为 128
--lora_alpha 64 \ # 设置 LoRA 的 alpha 参数为 64,用于调整LoRA的学习率
--mixed_precision bf16 \ # 使用 bf16 混合精度进行训练,减少显存使用
--output_dir $OUTPUT_PATH \ # 指定模型输出目录,由 $OUTPUT_PATH 定义
--height 480 \ # 视频高度为 480 像素
--width 720 \ # 视频宽度为 720 像素
--fps 8 \ # 视频帧率设置为 8 帧每秒
--max_num_frames 49 \ # 每个视频的最大帧数为 49 帧
--skip_frames_start 0 \ # 跳过视频开头的帧数为 0
--skip_frames_end 0 \ # 跳过视频结尾的帧数为 0
--train_batch_size 4 \ # 训练时的 batch size 设置为 4
--num_train_epochs 30 \ # 总训练epoch数为 30
--checkpointing_steps 1000 \ # 每 1000 步保存一次模型检查点
--gradient_accumulation_steps 1 \ # 梯度累计步数为 1,即每个 batch 后都会更新梯度
--learning_rate 1e-3 \ # 学习率设置为 0.001
--lr_scheduler cosine_with_restarts \ # 使用带重启的余弦学习率调度器
--lr_warmup_steps 200 \ # 在训练的前 200 步进行学习率预热
--lr_num_cycles 1 \ # 学习率周期设置为 1
--optimizer AdamW \ # 使用 AdamW 优化器
--adam_beta1 0.9 \ # 设置 Adam 优化器的 beta1 参数为 0.9
--adam_beta2 0.95 \ # 设置 Adam 优化器的 beta2 参数为 0.95
--max_grad_norm 1.0 \ # 最大梯度裁剪值设置为 1.0
--allow_tf32 \ # 启用 TF32 以加速训练
--report_to wandb # 使用 Weights and Biases 进行训练记录与监控
```
## 运行脚本,开始微调
单机(单卡,多卡)微调:
```shell
bash finetune_single_rank_lora.sh      # LoRA
bash finetune_single_rank_injector.sh  # injector
```
多机多卡微调:
```shell
bash finetune_multi_rank.sh #需要在每个节点运行
```
## 载入微调的模型
+ 请关注[cli_demo.py](../inference/cli_demo.py) 以了解如何加载微调的模型。
## 最佳实践
+ 包含70个分辨率为 `200 x 480 x 720`(帧数 x 高 x
宽)的训练视频。通过数据预处理中的帧跳过,我们创建了两个较小的49帧和16帧数据集,以加快实验速度,因为CogVideoX团队建议的最大帧数限制是49帧。我们将70个视频分成三组,分别为10、25和50个视频。这些视频的概念性质相似。
+ 25个及以上的视频在训练新概念和风格时效果最佳。
+ 现使用可以通过 `--id_token` 指定的标识符token进行训练效果更好。这类似于 Dreambooth 训练,但不使用这种token的常规微调也可以工作。
+ 原始仓库使用 `lora_alpha` 设置为 1。我们发现这个值在多次运行中效果不佳,可能是因为模型后端和训练设置的不同。我们的建议是将
lora_alpha 设置为与 rank 相同或 rank // 2。
+ 建议使用 rank 为 64 及以上的设置。
================================================
FILE: CogVideo/finetune/accelerate_config_machine_single.yaml
================================================
# Accelerate launch configuration: one machine (num_machines: 1), 8 processes, DeepSpeed ZeRO stage 2.
compute_environment: LOCAL_MACHINE
debug: false
# DeepSpeed settings: ZeRO stage 2, optimizer and parameter offload devices set to none.
deepspeed_config:
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
dynamo_backend: 'no'
# NOTE(review): the fine-tuning scripts pass `--mixed_precision bf16` to the training script; confirm which setting takes effect.
mixed_precision: 'no'
num_machines: 1
# Number of processes to launch; presumably one per GPU, so adjust it to your GPU count.
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
================================================
FILE: CogVideo/finetune/accelerate_config_machine_single_debug.yaml
================================================
# Accelerate launch configuration for a single process on one machine with DeepSpeed ZeRO stage 2.
# Identical to accelerate_config_machine_single.yaml except that num_processes is 1.
compute_environment: LOCAL_MACHINE
debug: false
# DeepSpeed settings: ZeRO stage 2, optimizer and parameter offload devices set to none.
deepspeed_config:
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
dynamo_backend: 'no'
# NOTE(review): the fine-tuning scripts pass `--mixed_precision bf16` to the training script; confirm which setting takes effect.
mixed_precision: 'no'
num_machines: 1
# Single process; presumably intended for one-GPU debugging runs.
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
================================================
FILE: CogVideo/finetune/finetune_single_rank_injector.sh
================================================
#!/bin/bash
# Launches injector fine-tuning (train_cogvideox_injector.py) through `accelerate`,
# using the accelerate_config_machine_single.yaml launch configuration.

export MODEL_PATH="/m2v_intern/fuxiao/CogVideo-release/weights/cogvideox-5b" # Change it to CogVideoX-5B path
export TRANSFORMER_PATH="" # Resume from pretrained injector checkpoint
export LORA_PATH="/m2v_intern/fuxiao/CogVideo-release/weights/lora" # Change it to pretrained lora path
# Use $HOME instead of a quoted "~": the shell does not expand a tilde inside quotes,
# so "~/.cache" would be passed to --cache_dir as a literal directory named `~`.
export CACHE_PATH="$HOME/.cache"
export DATASET_PATH="/ytech_m2v2_hdd/fuxiao/360Motion-Dataset" # Change it to 360-Motion Dataset path
export OUTPUT_PATH="injector"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"

# If you are not using 8 GPUs, change `num_processes` in accelerate_config_machine_single.yaml
# (the config passed below) and CUDA_VISIBLE_DEVICES above to match your GPU count.
# accelerate_config_machine_single_debug.yaml is the same config with num_processes set to 1.
# Variable expansions are quoted so that paths containing spaces are passed as a single argument.
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
  train_cogvideox_injector.py \
  --gradient_checkpointing \
  --pretrained_model_name_or_path "$MODEL_PATH" \
  --lora_path "$LORA_PATH" \
  --cache_dir "$CACHE_PATH" \
  --enable_tiling \
  --enable_slicing \
  --finetune_init \
  --instance_data_root "$DATASET_PATH" \
  --validation_prompt "a woman with short black wavy hair, lean figure, a green and yellow plaid shirt, dark brown pants, and black suede shoes and a robotic gazelle with a sturdy aluminum frame, an agile build, articulated legs and curved, metallic horns are moving in the city" \
  --validation_prompt_separator ::: \
  --num_validation_videos 1 \
  --validation_epochs 1 \
  --block_interval 2 \
  --seed 42 \
  --lora_scale 1.0 \
  --mixed_precision bf16 \
  --output_dir "$OUTPUT_PATH" \
  --height 480 \
  --width 720 \
  --fps 8 \
  --max_num_frames 49 \
  --skip_frames_start 0 \
  --skip_frames_end 0 \
  --train_batch_size 1 \
  --num_train_epochs 1000 \
  --checkpointing_steps 4000 \
  --gradient_accumulation_steps 1 \
  --learning_rate 1e-4 \
  --lr_scheduler cosine_with_restarts \
  --lr_warmup_steps 200 \
  --lr_num_cycles 1 \
  --optimizer AdamW \
  --adam_beta1 0.9 \
  --adam_beta2 0.95 \
  --max_grad_norm 1.0 \
  --allow_tf32 \
  --report_to wandb
# To resume from a pretrained injector checkpoint, set TRANSFORMER_PATH above and add
# the following option to the command:
#   --resume_from_checkpoint "$TRANSFORMER_PATH"
================================================
FILE: CogVideo/finetune/finetune_single_rank_lora.sh
================================================
#!/bin/bash
# Launches LoRA fine-tuning (train_cogvideox_lora.py) through `accelerate`,
# using the accelerate_config_machine_single.yaml launch configuration.

export MODEL_PATH="/m2v_intern/fuxiao/CogVideo-release/weights/cogvideox-5b" # Change it to CogVideoX-5B path
# Use $HOME instead of a quoted "~": the shell does not expand a tilde inside quotes,
# so "~/.cache" would be passed to --cache_dir as a literal directory named `~`.
export CACHE_PATH="$HOME/.cache"
export DATASET_PATH="/ytech_m2v2_hdd/fuxiao/360Motion-Dataset" # Change it to 360-Motion Dataset path
export OUTPUT_PATH="lora"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"

# If you are not using 8 GPUs, change `num_processes` in accelerate_config_machine_single.yaml
# (the config passed below) and CUDA_VISIBLE_DEVICES above to match your GPU count.
# accelerate_config_machine_single_debug.yaml is the same config with num_processes set to 1.
# Variable expansions are quoted so that paths containing spaces are passed as a single argument.
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
  train_cogvideox_lora.py \
  --gradient_checkpointing \
  --pretrained_model_name_or_path "$MODEL_PATH" \
  --cache_dir "$CACHE_PATH" \
  --enable_tiling \
  --enable_slicing \
  --instance_data_root "$DATASET_PATH" \
  --validation_prompt "a woman with short black wavy hair, lean figure, a green and yellow plaid shirt, dark brown pants, and black suede shoes and a robotic gazelle with a sturdy aluminum frame, an agile build, articulated legs and curved, metallic horns are moving in the city" \
  --validation_prompt_separator ::: \
  --num_validation_videos 1 \
  --validation_epochs 1 \
  --seed 42 \
  --rank 32 \
  --lora_alpha 32 \
  --mixed_precision bf16 \
  --output_dir "$OUTPUT_PATH" \
  --height 480 \
  --width 720 \
  --fps 8 \
  --max_num_frames 49 \
  --skip_frames_start 0 \
  --skip_frames_end 0 \
  --train_batch_size 2 \
  --num_train_epochs 1000 \
  --checkpointing_steps 1000 \
  --gradient_accumulation_steps 1 \
  --learning_rate 3e-4 \
  --lr_scheduler cosine_with_restarts \
  --lr_warmup_steps 200 \
  --lr_num_cycles 1 \
  --optimizer AdamW \
  --adam_beta1 0.9 \
  --adam_beta2 0.95 \
  --max_grad_norm 1.0 \
  --allow_tf32 \
  --report_to wandb
================================================
FILE: CogVideo/finetune/hostfile.txt
================================================
node1 slots=8
node2 slots=8
================================================
FILE: CogVideo/finetune/models/attention.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.utils import deprecate, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
from models.attention_processor import Attention, JointAttnProcessor2_0
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
logger = logging.get_logger(__name__)
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
    r"""
    Gated fusion layer that lets visual tokens attend to (projected) object tokens.

    Both residual branches are scaled by `tanh` of a learnable scalar that starts at zero, so a freshly
    constructed layer returns its input unchanged.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        context_dim (`int`): The number of channels in the context.
        n_heads (`int`): The number of heads to use for attention.
        d_head (`int`): The number of channels in each head.
    """

    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        super().__init__()

        # Object features are projected to the visual width so the two token sets can be concatenated.
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # Learnable scalar gates for the attention and feed-forward branches, initialised to zero.
        self.alpha_attn = nn.Parameter(torch.tensor(0.0))
        self.alpha_dense = nn.Parameter(torch.tensor(0.0))

        # When False, `forward` is a no-op that returns `x` itself.
        self.enabled = True

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        if self.enabled:
            visual_token_count = x.shape[1]
            projected_objs = self.linear(objs)

            # Attend over the concatenated sequence, then keep only the visual positions.
            joint_tokens = torch.cat([x, projected_objs], dim=1)
            attended = self.attn(self.norm1(joint_tokens))
            x = x + self.alpha_attn.tanh() * attended[:, :visual_token_count, :]

            x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
        return x
@maybe_allow_in_graph
class JointTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
            processing of `context` conditions. When `True`, no context feed-forward branch is built and
            `forward` returns `None` for the context stream.
        qk_norm (`str`, *optional*): Forwarded unchanged to the `Attention` layers as `qk_norm`.
        use_dual_attention (`bool`, defaults to `False`): When `True`, `norm1` is an `SD35AdaLayerNormZeroX`
            and a second attention layer (`attn2`, built without `added_kv_proj_dim`) is applied to the
            `hidden_states` stream only.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        context_pre_only: bool = False,
        qk_norm: Optional[str] = None,
        use_dual_attention: bool = False,
    ):
        super().__init__()

        self.use_dual_attention = use_dual_attention
        self.context_pre_only = context_pre_only
        # NOTE: "continous" (sic) is the spelling used by the string comparisons below; keep it as is.
        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

        # Adaptive norm for the `hidden_states` stream. The dual-attention variant is unpacked into
        # seven values in `forward`, the default one into five.
        if use_dual_attention:
            self.norm1 = SD35AdaLayerNormZeroX(dim)
        else:
            self.norm1 = AdaLayerNormZero(dim)

        # Adaptive norm for the context (`encoder_hidden_states`) stream.
        if context_norm_type == "ada_norm_continous":
            self.norm1_context = AdaLayerNormContinuous(
                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
            )
        elif context_norm_type == "ada_norm_zero":
            self.norm1_context = AdaLayerNormZero(dim)
        else:
            raise ValueError(
                f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
            )

        # The joint processor is only instantiated when `F.scaled_dot_product_attention` exists.
        if hasattr(F, "scaled_dot_product_attention"):
            processor = JointAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )

        # Joint attention over both streams (`added_kv_proj_dim=dim` adds the context projections).
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=context_pre_only,
            bias=True,
            processor=processor,
            qk_norm=qk_norm,
            eps=1e-6,
        )

        # Optional second attention layer; note it shares the same `processor` instance as `self.attn`.
        if use_dual_attention:
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=None,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                bias=True,
                processor=processor,
                qk_norm=qk_norm,
                eps=1e-6,
            )
        else:
            self.attn2 = None

        # Feed-forward branch for `hidden_states`; the norm has no learnable affine because scale/shift
        # come from the modulation values returned by `norm1`.
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        # Feed-forward branch for the context stream, omitted when the context is only used as input.
        if not context_pre_only:
            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        else:
            self.norm2_context = None
            self.ff_context = None

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward: a non-None `chunk_size` makes `forward` run both feed-forward
        # branches through `_chunked_feed_forward` along dimension `dim`.
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
    ):
        # Returns a tuple `(encoder_hidden_states, hidden_states)` -- context stream first. The context
        # entry is `None` when `context_pre_only` is set.

        # Modulated norm of `hidden_states`, plus the gates/shift/scale used later in this block.
        if self.use_dual_attention:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                hidden_states, emb=temb
            )
        else:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        # Same for the context stream; the pre-only variant yields just the normalized tensor.
        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
                encoder_hidden_states, emb=temb
            )

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        # Second, gated attention branch over `norm_hidden_states2` (dual-attention variant only).
        if self.use_dual_attention:
            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
            hidden_states = hidden_states + attn_output2

        # Feed-forward on `hidden_states`: plain LayerNorm followed by the explicit scale/shift modulation.
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        if self.context_pre_only:
            encoder_hidden_states = None
        else:
            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
            encoder_hidden_states = encoder_hidden_states + context_attn_output

            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            if self._chunk_size is not None:
                # "feed_forward_chunk_size" can be used to save memory
                context_ff_output = _chunked_feed_forward(
                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
                )
            else:
                context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: self-attention, optional second attention, then feed-forward, each with a
    residual connection and a normalization chosen by `norm_type`.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`. Required when
            `norm_type` is `"ada_norm"` or `"ada_norm_zero"`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. `forward` accepts `"layer_norm"`, `"ada_norm"`, `"ada_norm_zero"`,
            `"ada_norm_single"`, `"ada_norm_continuous"` and `"layer_norm_i2vgen"`.
        norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon passed to the normalization layers.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to. Only `"sinusoidal"` creates an embedding layer.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
        ada_norm_continous_conditioning_embedding_dim (`int`, *optional*):
            Conditioning width passed to `AdaLayerNormContinuous` when `norm_type="ada_norm_continuous"`.
        ada_norm_bias (*optional*): Passed to `AdaLayerNormContinuous` as its bias argument.
        ff_inner_dim (`int`, *optional*): Passed to `FeedForward` as `inner_dim`.
        ff_bias (`bool`, defaults to `True`): Passed to `FeedForward` as `bias`.
        attention_out_bias (`bool`, defaults to `True`): Passed to the attention layers as `out_bias`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
        ada_norm_bias: Optional[int] = None,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        # Keep the constructor configuration available as plain attributes.
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.dropout = dropout
        self.cross_attention_dim = cross_attention_dim
        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.double_self_attention = double_self_attention
        self.norm_elementwise_affine = norm_elementwise_affine
        self.positional_embeddings = positional_embeddings
        self.num_positional_embeddings = num_positional_embeddings
        self.only_cross_attention = only_cross_attention

        # We keep these boolean flags for backward-compatibility.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"
        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        self.norm_type = norm_type
        self.num_embeds_ada_norm = num_embeds_ada_norm

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
            )

        # Any value other than "sinusoidal" results in no positional embedding layer.
        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if norm_type == "ada_norm":
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif norm_type == "ada_norm_zero":
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        elif norm_type == "ada_norm_continuous":
            self.norm1 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "rms_norm",
            )
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        # `attn1` only receives a cross-attention width when the block is cross-attention-only.
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            if norm_type == "ada_norm":
                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
            elif norm_type == "ada_norm_continuous":
                self.norm2 = AdaLayerNormContinuous(
                    dim,
                    ada_norm_continous_conditioning_embedding_dim,
                    norm_elementwise_affine,
                    norm_eps,
                    ada_norm_bias,
                    "rms_norm",
                )
            else:
                # Positional arguments of nn.LayerNorm are (normalized_shape, eps, elementwise_affine).
                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
            )  # is self-attn if encoder_hidden_states is none
        else:
            # No second attention layer; "ada_norm_single" still needs `norm2` because `forward` uses it
            # in front of the feed-forward.
            if norm_type == "ada_norm_single":  # For Latte
                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
            else:
                self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        # NOTE(review): no branch below matches "ada_norm_single", so `self.norm3` is never assigned for
        # that norm type (forward does not read it in that case).
        if norm_type == "ada_norm_continuous":
            self.norm3 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "layer_norm",
            )

        elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
            self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
        elif norm_type == "layer_norm_i2vgen":
            self.norm3 = None

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # 4. Fuser
        # `self.fuser` only exists for the gated attention types; `forward` uses it when a "gligen"
        # entry is present in `cross_attention_kwargs`.
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

        # 5. Scale-shift for PixArt-Alpha.
        if norm_type == "ada_norm_single":
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward: a non-None `chunk_size` makes `forward` run the feed-forward through
        # `_chunked_feed_forward` along dimension `dim`.
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        # Runs self-attention, optional GLIGEN fusion, optional second attention and feed-forward, each
        # added residually, and returns the updated `hidden_states`.
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.norm_type == "ada_norm_zero":
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif self.norm_type == "ada_norm_single":
            # Here `timestep` is reshaped to (batch, 6, -1), added to the learned table and split into the
            # six modulation tensors.
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Prepare GLIGEN inputs
        # Work on a copy so popping "gligen" does not mutate the caller's dict.
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        # Gate the attention output for the norm types that produce a gate.
        if self.norm_type == "ada_norm_zero":
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.norm_type == "ada_norm_single":
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 1.2 GLIGEN Control
        # NOTE(review): `self.fuser` is only created for gated `attention_type`s; passing "gligen" kwargs
        # to a block built with the default attention type would fail on this attribute lookup.
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.norm_type == "ada_norm":
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.norm_type == "ada_norm_single":
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            elif self.norm_type == "ada_norm_continuous":
                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        # i2vgen doesn't have this norm.
        # NOTE(review): for norm_type == "layer_norm_i2vgen" `self.norm3` is None (see __init__), yet the
        # `elif` below still calls it -- confirm this path is never exercised or guard it.
        if self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif not self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm3(hidden_states)

        if self.norm_type == "ada_norm_zero":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        # "ada_norm_single" reuses `norm2` (not `norm3`) in front of the feed-forward.
        if self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.norm_type == "ada_norm_zero":
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.norm_type == "ada_norm_single":
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
class LuminaFeedForward(nn.Module):
    r"""
    A bias-free gated feed-forward layer computing `linear_2(silu(linear_1(x)) * linear_3(x))`.

    Parameters:
        dim (`int`):
            The number of channels of the input and of the output.
        inner_dim (`int`):
            The requested intermediate width. The width actually used is `int(2 * inner_dim / 3)`,
            optionally scaled by `ffn_dim_multiplier`, then rounded up to a multiple of `multiple_of`.
        multiple_of (`int`, *optional*, defaults to 256):
            The intermediate width is rounded up to a multiple of this value.
        ffn_dim_multiplier (`float`, *optional*):
            Custom multiplier applied to the intermediate width before rounding. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()

        hidden_width = int(2 * inner_dim / 3)
        # custom hidden_size factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_width = int(ffn_dim_multiplier * hidden_width)

        # Ceil-divide by `multiple_of`, then scale back up to get the rounded width.
        num_multiples = (hidden_width + multiple_of - 1) // multiple_of
        hidden_width = multiple_of * num_multiples

        self.linear_1 = nn.Linear(dim, hidden_width, bias=False)
        self.linear_2 = nn.Linear(hidden_width, dim, bias=False)
        self.linear_3 = nn.Linear(dim, hidden_width, bias=False)
        self.silu = FP32SiLU()

    def forward(self, x):
        gate = self.silu(self.linear_1(x))
        value = self.linear_3(x)
        return self.linear_2(gate * value)
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block for video like data.

    The block regroups its input so that attention runs along the frame axis: every spatial token
    position becomes one sequence of `num_frames` elements.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        time_mix_inner_dim (`int`): The number of channels for temporal attention. Residual connections
            around the input and final feed-forward are only used when this equals `dim`.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
    """

    def __init__(
        self,
        dim: int,
        time_mix_inner_dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        # Residuals around `ff_in` / `ff` are only shape-compatible when the widths match.
        self.is_res = dim == time_mix_inner_dim

        self.norm_in = nn.LayerNorm(dim)

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.ff_in = FeedForward(
            dim,
            dim_out=time_mix_inner_dim,
            activation_fn="geglu",
        )

        self.norm1 = nn.LayerNorm(time_mix_inner_dim)
        self.attn1 = Attention(
            query_dim=time_mix_inner_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            cross_attention_dim=None,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = nn.LayerNorm(time_mix_inner_dim)
            self.attn2 = Attention(
                query_dim=time_mix_inner_dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(time_mix_inner_dim)
        self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = None

    def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
        self._chunk_dim = 1

    def forward(
        self,
        hidden_states: torch.Tensor,
        num_frames: int,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply temporal attention and feed-forward to a stack of frames.

        Args:
            hidden_states: Tensor of shape `(batch_size * num_frames, seq_length, channels)`.
            num_frames: Number of frames folded into the leading dimension of `hidden_states`.
            encoder_hidden_states: Optional context passed to the second attention layer, if present.

        Returns:
            Tensor of shape `(batch_size * num_frames, seq_length, out_channels)`, where `out_channels`
            is the width produced by the final feed-forward (equal to `channels` in the residual case).
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_frames, seq_length, channels = hidden_states.shape
        batch_size = batch_frames // num_frames

        # (batch * frames, seq, C) -> (batch * seq, frames, C): attention then mixes information over time.
        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)

        residual = hidden_states
        hidden_states = self.norm_in(hidden_states)

        if self._chunk_size is not None:
            hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
        else:
            hidden_states = self.ff_in(hidden_states)

        if self.is_res:
            hidden_states = hidden_states + residual

        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
        hidden_states = attn_output + hidden_states

        # 3. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)
            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.is_res:
            hidden_states = ff_output + hidden_states
        else:
            hidden_states = ff_output

        # Use the width actually produced above: when `time_mix_inner_dim != dim` it differs from the
        # input `channels`, and reshaping with the input width would fail.
        out_channels = hidden_states.shape[-1]

        # (batch * seq, frames, C') -> (batch * frames, seq, C'): restore the caller's layout.
        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, out_channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, out_channels)

        return hidden_states
class SkipFFTransformerBlock(nn.Module):
    r"""
    A Transformer block made of two normalized attention layers with residual connections and no
    feed-forward layer.

    When `kv_input_dim` differs from `dim`, the context is passed through SiLU and a linear projection
    to `dim` channels before being handed to both attention layers.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        kv_input_dim: int,
        kv_input_dim_proj_use_bias: bool,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # Projection for the context; only needed when its width does not already match `dim`.
        self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) if kv_input_dim != dim else None

        # Both attention layers are configured identically.
        attention_config = dict(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim,
            out_bias=attention_out_bias,
        )

        self.norm1 = RMSNorm(dim, 1e-06)
        self.attn1 = Attention(**attention_config)

        self.norm2 = RMSNorm(dim, 1e-06)
        self.attn2 = Attention(**attention_config)

    def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
        # Copy so the caller's dict is never shared with the attention calls.
        extra_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs.copy()

        if self.kv_mapper is not None:
            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))

        # Two identical pre-norm attention stages, each followed by a residual add.
        for norm, attn in ((self.norm1, self.attn1), (self.norm2, self.attn2)):
            attended = attn(
                norm(hidden_states),
                encoder_hidden_states=encoder_hidden_states,
                **extra_kwargs,
            )
            hidden_states = attended + hidden_states

        return hidden_states
@maybe_allow_in_graph
class FreeNoiseTransformerBlock(nn.Module):
r"""
A FreeNoise Transformer block.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
cross_attention_dim (`int`, *optional*):
The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to be used in feed-forward.
num_embeds_ada_norm (`int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (`bool`, defaults to `False`):
Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, defaults to `False`):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, defaults to `False`):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, defaults to `False`):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
ff_inner_dim (`int`, *optional*):
Hidden dimension of feed-forward MLP.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in feed-forward MLP.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in attention output project layer.
context_length (`int`, defaults to `16`):
The maximum number of frames that the FreeNoise block processes at once.
context_stride (`int`, defaults to `4`):
The number of frames to be skipped before starting to process a new batch of `context_length` frames.
weighting_scheme (`str`, defaults to `"pyramid"`):
The weighting scheme to use for weighting averaging of processed latent frames. As described in the
Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
used.
"""
def __init__(
    self,
    dim: int,
    num_attention_heads: int,
    attention_head_dim: int,
    dropout: float = 0.0,
    cross_attention_dim: Optional[int] = None,
    activation_fn: str = "geglu",
    num_embeds_ada_norm: Optional[int] = None,
    attention_bias: bool = False,
    only_cross_attention: bool = False,
    double_self_attention: bool = False,
    upcast_attention: bool = False,
    norm_elementwise_affine: bool = True,
    norm_type: str = "layer_norm",
    norm_eps: float = 1e-5,
    final_dropout: bool = False,
    positional_embeddings: Optional[str] = None,
    num_positional_embeddings: Optional[int] = None,
    ff_inner_dim: Optional[int] = None,
    ff_bias: bool = True,
    attention_out_bias: bool = True,
    context_length: int = 16,
    context_stride: int = 4,
    weighting_scheme: str = "pyramid",
):
    """Build the self-attention, optional second attention and feed-forward sub-layers.

    All normalization layers created here are plain `nn.LayerNorm`; `norm_type` is validated and
    stored but does not select a different layer in this constructor.

    Raises:
        ValueError: If `norm_type` is `"ada_norm"`/`"ada_norm_zero"` without `num_embeds_ada_norm`, or
            if `positional_embeddings` is set without `num_positional_embeddings`.
    """
    super().__init__()
    # Keep the constructor configuration available as plain attributes.
    self.dim = dim
    self.num_attention_heads = num_attention_heads
    self.attention_head_dim = attention_head_dim
    self.dropout = dropout
    self.cross_attention_dim = cross_attention_dim
    self.activation_fn = activation_fn
    self.attention_bias = attention_bias
    self.double_self_attention = double_self_attention
    self.norm_elementwise_affine = norm_elementwise_affine
    self.positional_embeddings = positional_embeddings
    self.num_positional_embeddings = num_positional_embeddings
    self.only_cross_attention = only_cross_attention

    self.set_free_noise_properties(context_length, context_stride, weighting_scheme)

    # We keep these boolean flags for backward-compatibility.
    self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
    self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
    self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
    self.use_layer_norm = norm_type == "layer_norm"
    self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"

    if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
        raise ValueError(
            f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
            f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
        )

    self.norm_type = norm_type
    self.num_embeds_ada_norm = num_embeds_ada_norm

    if positional_embeddings and (num_positional_embeddings is None):
        raise ValueError(
            "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
        )

    if positional_embeddings == "sinusoidal":
        self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
    else:
        self.pos_embed = None

    # Define 3 blocks. Each block has its own normalization layer.
    # 1. Self-Attn
    self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

    self.attn1 = Attention(
        query_dim=dim,
        heads=num_attention_heads,
        dim_head=attention_head_dim,
        dropout=dropout,
        bias=attention_bias,
        cross_attention_dim=cross_attention_dim if only_cross_attention else None,
        upcast_attention=upcast_attention,
        out_bias=attention_out_bias,
    )

    # 2. Cross-Attn
    if cross_attention_dim is not None or double_self_attention:
        self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim if not double_self_attention else None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )  # is self-attn if encoder_hidden_states is none
    else:
        # Always define both attributes (as the other transformer blocks in this file do) so that a
        # `self.attn2 is not None` style check cannot fail with AttributeError.
        self.norm2 = None
        self.attn2 = None

    # 3. Feed-forward
    self.ff = FeedForward(
        dim,
        dropout=dropout,
        activation_fn=activation_fn,
        final_dropout=final_dropout,
        inner_dim=ff_inner_dim,
        bias=ff_bias,
    )

    self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

    # let chunk size default to None
    self._chunk_size = None
    self._chunk_dim = 0
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
frame_indices = []
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
window_start = i
window_end = min(num_frames, i + self.context_length)
frame_indices.append((window_start, window_end))
return frame_indices
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
if weighting_scheme == "flat":
weights = [1.0] * num_frames
elif weighting_scheme == "pyramid":
if num_frames % 2 == 0:
# num_frames = 4 => [1, 2, 2, 1]
mid = num_frames // 2
weights = list(range(1, mid + 1))
weights = weights + weights[::-1]
else:
# num_frames = 5 => [1, 2, 3, 2, 1]
mid = (num_frames + 1) // 2
weights = list(range(1, mid))
weights = weights + [mid] + weights[::-1]
elif weighting_scheme == "delayed_reverse_sawtooth":
if num_frames % 2 == 0:
# num_frames = 4 => [0.01, 2, 2, 1]
mid = num_frames // 2
weights = [0.01] * (mid - 1) + [mid]
weights = weights + list(range(mid, 0, -1))
else:
# num_frames = 5 => [0.01, 0.01, 3, 2, 1]
mid = (num_frames + 1) // 2
weights = [0.01] * mid
weights = weights + list(range(mid, 0, -1))
else:
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
return weights
def set_free_noise_properties(
    self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
) -> None:
    """
    Record the sliding-window configuration on this instance.

    Args:
        context_length (`int`): Stored as `self.context_length`.
        context_stride (`int`): Stored as `self.context_stride`.
        weighting_scheme (`str`, defaults to `"pyramid"`): Stored as `self.weighting_scheme`.
    """
    settings = (
        ("context_length", context_length),
        ("context_stride", context_stride),
        ("weighting_scheme", weighting_scheme),
    )
    for attribute, value in settings:
        setattr(self, attribute, value)
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
    """
    Configure chunked execution of the feed-forward layer.

    Args:
        chunk_size (`int`, *optional*): Stored as `self._chunk_size`; `None` disables chunking in `forward`.
        dim (`int`, defaults to 0): Stored as `self._chunk_dim`.
    """
    self._chunk_size, self._chunk_dim = chunk_size, dim
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """
    Apply attention over overlapping frame windows, blend the windows, then apply the feed-forward layer.

    The frame axis (dim 1 of `hidden_states`) is cut into windows given by `self._get_frame_indices`. Each window
    goes through self-attention (and cross-attention when `attn2` exists); window outputs are accumulated with
    the weights from `self._get_frame_weights` and divided by the accumulated weight per frame.

    Args:
        hidden_states (`torch.Tensor`): Input whose dim 1 is the frame axis.
        attention_mask (`torch.Tensor`, *optional*): Forwarded to `attn1`.
        encoder_hidden_states (`torch.Tensor`, *optional*):
            Forwarded to `attn2`, and to `attn1` only when `self.only_cross_attention` is set.
        encoder_attention_mask (`torch.Tensor`, *optional*): Forwarded to `attn2` as its attention mask.
        cross_attention_kwargs (`dict`, *optional*): Extra keyword arguments forwarded to both attention layers.

    Returns:
        `torch.Tensor`: Tensor with the same dtype as `hidden_states`.

    Raises:
        ValueError: If the number of frames is smaller than `self.context_length`.
    """
    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}

    # hidden_states: [B x H x W, F, C]
    device = hidden_states.device
    dtype = hidden_states.dtype

    num_frames = hidden_states.size(1)
    # Validate before building windows: with too few frames the window list is empty and indexing its last
    # element would raise an unrelated IndexError instead of this message.
    if num_frames < self.context_length:
        raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")

    frame_indices = self._get_frame_indices(num_frames)
    frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
    frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
    is_last_frame_batch_complete = frame_indices[-1][1] == num_frames

    # Handle the case where the strided windows do not reach the final frame.
    # For example, num_frames=25, context_length=16, context_stride=4 gives the strided windows
    # [(0, 16), (4, 20), (8, 24)]; an extra window (9, 25) is appended so the last frame is covered, and only
    # its trailing `last_frame_batch_length` frames are accumulated below.
    if not is_last_frame_batch_complete:
        last_frame_batch_length = num_frames - frame_indices[-1][1]
        frame_indices.append((num_frames - self.context_length, num_frames))

    num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
    accumulated_values = torch.zeros_like(hidden_states)

    # `attn2` is only created by the constructor when cross-attention is configured.
    attn2 = getattr(self, "attn2", None)

    for i, (frame_start, frame_end) in enumerate(frame_indices):
        # Per-frame blending weights for this window, shaped to broadcast over batch and channel.
        weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
        weights *= frame_weights

        hidden_states_chunk = hidden_states[:, frame_start:frame_end]

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states_chunk)

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        hidden_states_chunk = attn_output + hidden_states_chunk
        if hidden_states_chunk.ndim == 4:
            hidden_states_chunk = hidden_states_chunk.squeeze(1)

        # 2. Cross-Attention
        if attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states_chunk)

            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states_chunk = attn_output + hidden_states_chunk

        if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
            # Only the frames not covered by the strided windows are taken from the extra window.
            tail_weights = weights[:, -last_frame_batch_length:]
            accumulated_values[:, -last_frame_batch_length:] += (
                hidden_states_chunk[:, -last_frame_batch_length:] * tail_weights
            )
            # Must be the same slice as above (not a single index), otherwise the normaliser does not match
            # the weights that were applied to the values.
            num_times_accumulated[:, -last_frame_batch_length:] += tail_weights
        else:
            accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
            num_times_accumulated[:, frame_start:frame_end] += weights

    # TODO(aryan): Maybe this could be done in a better way.
    #
    # Previously, this was:
    # hidden_states = torch.where(
    #    num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
    # )
    #
    # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
    # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
    # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
    # looked into this deeply because other memory optimizations led to more pronounced reductions.
    hidden_states = torch.cat(
        [
            torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
            for accumulated_split, num_times_split in zip(
                accumulated_values.split(self.context_length, dim=1),
                num_times_accumulated.split(self.context_length, dim=1),
            )
        ],
        dim=1,
    ).to(dtype)

    # 3. Feed-forward
    norm_hidden_states = self.norm3(hidden_states)

    if self._chunk_size is not None:
        ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
    else:
        ff_output = self.ff(norm_hidden_states)

    hidden_states = ff_output + hidden_states
    if hidden_states.ndim == 4:
        hidden_states = hidden_states.squeeze(1)

    return hidden_states
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"` or `"swiglu"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        inner_dim (`int`, *optional*): Hidden width. If not given, defaults to `int(dim * mult)`.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Single if/elif chain with an explicit failure branch, so an unknown name is reported clearly
        # instead of surfacing later as an unbound `act_fn`.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        elif activation_fn == "swiglu":
            act_fn = SwiGLU(dim, inner_dim, bias=bias)
        else:
            raise ValueError(
                f"Unsupported activation_fn={activation_fn!r}. Expected one of 'gelu', 'gelu-approximate',"
                " 'geglu', 'geglu-approximate' or 'swiglu'."
            )

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Apply every module in `self.net` in order.

        Extra positional arguments or a non-`None` `scale` keyword only trigger a deprecation notice; they do
        not affect the computation.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
================================================
FILE: CogVideo/finetune/models/attention_processor.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from typing import Callable, List, Optional, Tuple, Union
from einops import rearrange, repeat
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_torch_npu_available():
import torch_npu
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
@maybe_allow_in_graph
class Attention(nn.Module):
r"""
A cross attention layer.
Parameters:
query_dim (`int`):
The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8):
The number of heads to use for multi-head attention.
kv_heads (`int`, *optional*, defaults to `None`):
The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
`kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
Query Attention (MQA) otherwise GQA is used.
dim_head (`int`, *optional*, defaults to 64):
The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
upcast_attention (`bool`, *optional*, defaults to False):
Set to `True` to upcast the attention computation to `float32`.
upcast_softmax (`bool`, *optional*, defaults to False):
Set to `True` to upcast the softmax computation to `float32`.
cross_attention_norm (`str`, *optional*, defaults to `None`):
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the group norm in the cross attention.
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the added key and value projections. If `None`, no projection is used.
norm_num_groups (`int`, *optional*, defaults to `None`):
The number of groups to use for the group norm in the attention.
spatial_norm_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the spatial normalization.
out_bias (`bool`, *optional*, defaults to `True`):
Set to `True` to use a bias in the output linear layer.
scale_qk (`bool`, *optional*, defaults to `True`):
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
only_cross_attention (`bool`, *optional*, defaults to `False`):
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
`added_kv_proj_dim` is not `None`.
eps (`float`, *optional*, defaults to 1e-5):
An additional value added to the denominator in group normalization that is used for numerical stability.
rescale_output_factor (`float`, *optional*, defaults to 1.0):
A factor to rescale the output by dividing it with this value.
residual_connection (`bool`, *optional*, defaults to `False`):
Set to `True` to add the residual connection to the output.
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
Set to `True` if the attention block is loaded from a deprecated state dict.
processor (`AttnProcessor`, *optional*, defaults to `None`):
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
`AttnProcessor` otherwise.
"""
def __init__(
    self,
    query_dim: int,
    cross_attention_dim: Optional[int] = None,
    heads: int = 8,
    kv_heads: Optional[int] = None,
    dim_head: int = 64,
    dropout: float = 0.0,
    bias: bool = False,
    upcast_attention: bool = False,
    upcast_softmax: bool = False,
    cross_attention_norm: Optional[str] = None,
    cross_attention_norm_num_groups: int = 32,
    qk_norm: Optional[str] = None,
    added_kv_proj_dim: Optional[int] = None,
    added_proj_bias: Optional[bool] = True,
    norm_num_groups: Optional[int] = None,
    spatial_norm_dim: Optional[int] = None,
    out_bias: bool = True,
    scale_qk: bool = True,
    only_cross_attention: bool = False,
    eps: float = 1e-5,
    rescale_output_factor: float = 1.0,
    residual_connection: bool = False,
    _from_deprecated_attn_block: bool = False,
    processor: Optional["AttnProcessor"] = None,
    out_dim: int = None,
    context_pre_only=None,
    pre_only=False,
    elementwise_affine: bool = True,
):
    """
    Build the projections, optional normalization layers and the default attention processor.

    Raises:
        ValueError: If `only_cross_attention` is set without `added_kv_proj_dim`, or if `qk_norm` /
            `cross_attention_norm` name an unsupported normalization.
    """
    super().__init__()

    # To prevent circular import.
    from diffusers.models.normalization import FP32LayerNorm, RMSNorm

    # Width of the query projection and of the key/value projections (they differ when `kv_heads` is given).
    self.inner_dim = out_dim if out_dim is not None else dim_head * heads
    self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
    self.query_dim = query_dim
    self.use_bias = bias
    self.is_cross_attention = cross_attention_dim is not None
    self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
    self.upcast_attention = upcast_attention
    self.upcast_softmax = upcast_softmax
    self.rescale_output_factor = rescale_output_factor
    self.residual_connection = residual_connection
    self.dropout = dropout
    self.fused_projections = False
    self.out_dim = out_dim if out_dim is not None else query_dim
    self.context_pre_only = context_pre_only
    self.pre_only = pre_only

    # we make use of this private variable to know whether this class is loaded
    # with an deprecated state dict so that we can convert it on the fly
    self._from_deprecated_attn_block = _from_deprecated_attn_block

    self.scale_qk = scale_qk
    self.scale = dim_head**-0.5 if self.scale_qk else 1.0

    self.heads = out_dim // dim_head if out_dim is not None else heads
    # for slice_size > 0 the attention score computation
    # is split across the batch axis to save memory
    # You can set slice_size with `set_attention_slice`
    self.sliceable_head_dim = heads

    self.added_kv_proj_dim = added_kv_proj_dim
    self.only_cross_attention = only_cross_attention

    if self.added_kv_proj_dim is None and self.only_cross_attention:
        raise ValueError(
            "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
        )

    if norm_num_groups is not None:
        self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
    else:
        self.group_norm = None

    if spatial_norm_dim is not None:
        self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
    else:
        self.spatial_norm = None

    if qk_norm is None:
        self.norm_q = None
        self.norm_k = None
    elif qk_norm == "layer_norm":
        self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
    elif qk_norm == "fp32_layer_norm":
        self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
        self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
    elif qk_norm == "layer_norm_across_heads":
        # Lumina applies qk norm across all heads
        self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
        # `inner_kv_dim` equals `dim_head * kv_heads` when `kv_heads` is given and falls back to the query
        # width when `kv_heads` is None, where `dim_head * kv_heads` would raise a TypeError.
        self.norm_k = nn.LayerNorm(self.inner_kv_dim, eps=eps)
    elif qk_norm == "rms_norm":
        self.norm_q = RMSNorm(dim_head, eps=eps)
        self.norm_k = RMSNorm(dim_head, eps=eps)
    else:
        raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")

    if cross_attention_norm is None:
        self.norm_cross = None
    elif cross_attention_norm == "layer_norm":
        self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
    elif cross_attention_norm == "group_norm":
        if self.added_kv_proj_dim is not None:
            # The given `encoder_hidden_states` are initially of shape
            # (batch_size, seq_len, added_kv_proj_dim) before being projected
            # to (batch_size, seq_len, cross_attention_dim). The norm is applied
            # before the projection, so we need to use `added_kv_proj_dim` as
            # the number of channels for the group norm.
            norm_cross_num_channels = added_kv_proj_dim
        else:
            norm_cross_num_channels = self.cross_attention_dim

        self.norm_cross = nn.GroupNorm(
            num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
        )
    else:
        raise ValueError(
            f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
        )

    self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)

    if not self.only_cross_attention:
        # only relevant for the `AddedKVProcessor` classes
        self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
    else:
        self.to_k = None
        self.to_v = None

    self.added_proj_bias = added_proj_bias
    if self.added_kv_proj_dim is not None:
        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
        if self.context_pre_only is not None:
            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)

    if not self.pre_only:
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

    if self.context_pre_only is not None and not self.context_pre_only:
        self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)

    if qk_norm is not None and added_kv_proj_dim is not None:
        # Only these two norm types are implemented for the added (context) query/key projections.
        if qk_norm == "fp32_layer_norm":
            self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
            self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
        elif qk_norm == "rms_norm":
            self.norm_added_q = RMSNorm(dim_head, eps=eps)
            self.norm_added_k = RMSNorm(dim_head, eps=eps)
        else:
            raise ValueError(
                f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
            )
    else:
        self.norm_added_q = None
        self.norm_added_k = None

    # set attention processor
    # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
    # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
    # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
    if processor is None:
        processor = (
            AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
        )
    self.set_processor(processor)
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
    r"""
    Switch between the NPU flash-attention processor and the default processor.

    Args:
        use_npu_flash_attention (`bool`):
            When true, install `AttnProcessorNPU`; otherwise install the default processor.
    """
    if not use_npu_flash_attention:
        # Default choice: the SDPA-based processor when torch provides
        # `torch.nn.functional.scaled_dot_product_attention` and the default `scale` is in use,
        # the plain processor otherwise. TODO remove scale_qk check when we move to torch 2.1
        sdpa_usable = hasattr(F, "scaled_dot_product_attention") and self.scale_qk
        selected = AttnProcessor2_0() if sdpa_usable else AttnProcessor()
    else:
        selected = AttnProcessorNPU()
    self.set_processor(selected)
def set_use_memory_efficient_attention_xformers(
    self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
) -> None:
    r"""
    Enable or disable the `xformers` memory-efficient attention processors.

    Args:
        use_memory_efficient_attention_xformers (`bool`):
            `True` installs an xformers-backed processor, `False` restores a non-xformers processor.
        attention_op (`Callable`, *optional*):
            Attention operation handed to the xformers processors. `None` lets `xformers` pick its default.

    Raises:
        NotImplementedError: If the current processor is flagged as both added-KV and custom-diffusion.
        ModuleNotFoundError: If `xformers` is requested but not installed.
        ValueError: If `xformers` is requested but CUDA is not available.
    """
    processor_present = hasattr(self, "processor")
    is_custom_diffusion = processor_present and isinstance(
        self.processor,
        (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
    )
    is_added_kv_processor = processor_present and isinstance(
        self.processor,
        (
            AttnAddedKVProcessor,
            AttnAddedKVProcessor2_0,
            SlicedAttnAddedKVProcessor,
            XFormersAttnAddedKVProcessor,
        ),
    )

    if not use_memory_efficient_attention_xformers:
        # Turning xformers off: rebuild the matching non-xformers processor.
        if is_custom_diffusion:
            if hasattr(F, "scaled_dot_product_attention"):
                attn_processor_class = CustomDiffusionAttnProcessor2_0
            else:
                attn_processor_class = CustomDiffusionAttnProcessor
            processor = attn_processor_class(
                train_kv=self.processor.train_kv,
                train_q_out=self.processor.train_q_out,
                hidden_size=self.processor.hidden_size,
                cross_attention_dim=self.processor.cross_attention_dim,
            )
            # Carry the trained custom-diffusion weights over to the new processor.
            processor.load_state_dict(self.processor.state_dict())
            if hasattr(self.processor, "to_k_custom_diffusion"):
                processor.to(self.processor.to_k_custom_diffusion.weight.device)
        else:
            # Default processor: SDPA-based when torch provides it and the default `scale` is in use.
            # TODO remove scale_qk check when we move to torch 2.1
            sdpa_usable = hasattr(F, "scaled_dot_product_attention") and self.scale_qk
            processor = AttnProcessor2_0() if sdpa_usable else AttnProcessor()
        self.set_processor(processor)
        return

    # Turning xformers on: validate the environment first.
    if is_added_kv_processor and is_custom_diffusion:
        raise NotImplementedError(
            f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
        )
    if not is_xformers_available():
        raise ModuleNotFoundError(
            (
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers"
            ),
            name="xformers",
        )
    if not torch.cuda.is_available():
        raise ValueError(
            "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
            " only available for GPU "
        )

    # Make sure we can run the memory efficient attention; any failure propagates to the caller unchanged.
    probe_q, probe_k, probe_v = (torch.randn((1, 2, 40), device="cuda") for _ in range(3))
    _ = xformers.ops.memory_efficient_attention(probe_q, probe_k, probe_v)

    if is_custom_diffusion:
        processor = CustomDiffusionXFormersAttnProcessor(
            train_kv=self.processor.train_kv,
            train_q_out=self.processor.train_q_out,
            hidden_size=self.processor.hidden_size,
            cross_attention_dim=self.processor.cross_attention_dim,
            attention_op=attention_op,
        )
        processor.load_state_dict(self.processor.state_dict())
        if hasattr(self.processor, "to_k_custom_diffusion"):
            processor.to(self.processor.to_k_custom_diffusion.weight.device)
    elif is_added_kv_processor:
        # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
        # which uses this type of cross attention ONLY because the attention mask of format
        # [0, ..., -10.000, ..., 0, ...,] is not supported
        # throw warning
        logger.info(
            "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
        )
        processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
    else:
        processor = XFormersAttnProcessor(attention_op=attention_op)

    self.set_processor(processor)
def set_attention_slice(self, slice_size: int) -> None:
    r"""
    Choose the attention processor matching the requested slice size.

    Args:
        slice_size (`int`):
            Slice size for sliced attention. `None` selects a non-sliced processor.

    Raises:
        ValueError: If `slice_size` exceeds `self.sliceable_head_dim`.
    """
    slicing_requested = slice_size is not None
    if slicing_requested and slice_size > self.sliceable_head_dim:
        raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

    has_added_kv = self.added_kv_proj_dim is not None
    if slicing_requested:
        processor = SlicedAttnAddedKVProcessor(slice_size) if has_added_kv else SlicedAttnProcessor(slice_size)
    elif has_added_kv:
        processor = AttnAddedKVProcessor()
    else:
        # Default processor: SDPA-based when torch provides
        # `torch.nn.functional.scaled_dot_product_attention` and the default `scale` is in use.
        # TODO remove scale_qk check when we move to torch 2.1
        sdpa_usable = hasattr(F, "scaled_dot_product_attention") and self.scale_qk
        processor = AttnProcessor2_0() if sdpa_usable else AttnProcessor()

    self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor") -> None:
    r"""
    Install `processor` as the attention processor of this layer.

    Args:
        processor (`AttnProcessor`):
            The attention processor to use.
    """
    # When a module-type processor is replaced by a non-module one, the old entry has to be removed from
    # `self._modules` explicitly before the plain attribute assignment below.
    previous_is_module = hasattr(self, "processor") and isinstance(self.processor, torch.nn.Module)
    if previous_is_module and not isinstance(processor, torch.nn.Module):
        logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
        self._modules.pop("processor")

    self.processor = processor
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
    r"""
    Get the attention processor in use.

    Args:
        return_deprecated_lora (`bool`, *optional*, defaults to `False`):
            Kept for interface compatibility. This implementation has no deprecated LoRA processor to build,
            so the active processor is returned regardless of this flag.

    Returns:
        "AttentionProcessor": The attention processor in use.
    """
    # Previously the `return_deprecated_lora=True` path fell off the end of the function and returned None,
    # contradicting the declared return type; always hand back the active processor instead.
    return self.processor
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    **cross_attention_kwargs,
) -> torch.Tensor:
    r"""
    Delegate the attention computation to the configured processor.

    Args:
        hidden_states (`torch.Tensor`):
            The hidden states of the query.
        encoder_hidden_states (`torch.Tensor`, *optional*):
            The hidden states of the encoder.
        attention_mask (`torch.Tensor`, *optional*):
            The attention mask to use. If `None`, no mask is applied.
        **cross_attention_kwargs:
            Extra keyword arguments. Only those named in the processor's `__call__` signature are forwarded;
            the rest are dropped, with a warning unless the name is `ip_adapter_masks`.

    Returns:
        `torch.Tensor`: The output of the attention layer.
    """
    # Names the processor's `__call__` is able to receive.
    accepted = set(inspect.signature(self.processor.__call__).parameters.keys())
    # Names that may be dropped without a warning.
    silenced = {"ip_adapter_masks"}

    unexpected = [name for name in cross_attention_kwargs if name not in accepted and name not in silenced]
    if unexpected:
        logger.warning(
            f"cross_attention_kwargs {unexpected} are not expected by {self.processor.__class__.__name__} and will be ignored."
        )

    forwarded = {name: value for name, value in cross_attention_kwargs.items() if name in accepted}
    return self.processor(
        self,
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        **forwarded,
    )
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
    r"""
    Fold the head axis out of the batch axis and into the channel axis:
    `[batch_size, seq_len, dim]` -> `[batch_size // heads, seq_len, dim * heads]`, where `heads` is `self.heads`.

    Args:
        tensor (`torch.Tensor`): The 3-D tensor to reshape.

    Returns:
        `torch.Tensor`: The reshaped tensor.
    """
    num_heads = self.heads
    folded_batch, length, width = tensor.shape
    true_batch = folded_batch // num_heads
    # Split the leading axis into (batch, head), move head next to the channel axis, then merge them.
    per_head = tensor.reshape(true_batch, num_heads, length, width)
    return per_head.permute(0, 2, 1, 3).reshape(true_batch, length, width * num_heads)
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
    r"""
    Split the channel axis into `self.heads` heads and move the head axis in front of the sequence axis.

    Args:
        tensor (`torch.Tensor`):
            Either `[batch_size, seq_len, dim]` or `[batch_size, extra_dim, seq_len, dim]`; in the 4-D case
            `extra_dim` is merged into the sequence axis.
        out_dim (`int`, *optional*, defaults to `3`):
            With `3` the result is `[batch_size * heads, seq_len, dim // heads]`; any other value leaves the
            result as `[batch_size, heads, seq_len, dim // heads]`.

    Returns:
        `torch.Tensor`: The reshaped tensor.
    """
    num_heads = self.heads
    if tensor.ndim != 3:
        batch, extra, length, width = tensor.shape
    else:
        (batch, length, width), extra = tensor.shape, 1

    tokens = length * extra
    head_width = width // num_heads

    tensor = tensor.reshape(batch, tokens, num_heads, head_width).permute(0, 2, 1, 3)
    if out_dim == 3:
        tensor = tensor.reshape(batch * num_heads, tokens, head_width)
    return tensor
def get_attention_scores(
    self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    r"""
    Compute softmax-normalised attention probabilities from `query` and `key`.

    Args:
        query (`torch.Tensor`): The query tensor.
        key (`torch.Tensor`): The key tensor.
        attention_mask (`torch.Tensor`, *optional*):
            Additive mask added to the scaled scores. If `None`, no mask is applied.

    Returns:
        `torch.Tensor`: The attention probabilities, cast back to the dtype `query` had on entry.
    """
    result_dtype = query.dtype
    if self.upcast_attention:
        query, key = query.float(), key.float()

    if attention_mask is not None:
        additive, beta = attention_mask, 1
    else:
        # With beta=0 the contents of this buffer are ignored by `baddbmm`; it only supplies shape/dtype/device.
        additive = torch.empty(
            query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
        )
        beta = 0

    # scores = beta * additive + self.scale * (query @ key^T)
    scores = torch.baddbmm(
        additive,
        query,
        key.transpose(-1, -2),
        beta=beta,
        alpha=self.scale,
    )
    del additive

    if self.upcast_softmax:
        scores = scores.float()

    probs = scores.softmax(dim=-1)
    del scores

    return probs.to(result_dtype)
def prepare_attention_mask(
    self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
) -> torch.Tensor:
    r"""
    Pad and repeat an attention mask so it lines up with the per-head attention scores.

    Args:
        attention_mask (`torch.Tensor`):
            The attention mask to prepare. `None` is returned unchanged.
        target_length (`int`):
            Expected size of the last axis. When the mask's last axis differs, `target_length` zeros are
            appended to it (see the TODO in the body about padding by the remaining length instead).
        batch_size (`int`):
            The batch size, used to decide whether the mask still has to be repeated per head.
        out_dim (`int`, *optional*, defaults to `3`):
            `3` repeats the mask along dim 0; `4` inserts a head axis at dim 1 and repeats along it.

    Returns:
        `torch.Tensor`: The prepared attention mask.
    """
    num_heads = self.heads
    if attention_mask is None:
        return attention_mask

    if attention_mask.shape[-1] != target_length:
        if attention_mask.device.type != "mps":
            # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
            #       we want to instead pad by (0, remaining_length), where remaining_length is:
            #       remaining_length: int = target_length - current_length
            # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
        else:
            # HACK: MPS: Does not support padding by greater than dimension of input tensor.
            # Instead, we can manually construct the padding tensor.
            zeros = torch.zeros(
                (attention_mask.shape[0], attention_mask.shape[1], target_length),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            attention_mask = torch.cat([attention_mask, zeros], dim=2)

    if out_dim == 3:
        # Repeat per head only if the mask has not already been expanded to batch_size * heads rows.
        if attention_mask.shape[0] < batch_size * num_heads:
            attention_mask = attention_mask.repeat_interleave(num_heads, dim=0)
    elif out_dim == 4:
        attention_mask = attention_mask.unsqueeze(1).repeat_interleave(num_heads, dim=1)

    return attention_mask
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
    r"""
    Apply `self.norm_cross` to the encoder hidden states. `self.norm_cross` must have been set when the
    `Attention` module was constructed.

    Args:
        encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.

    Returns:
        `torch.Tensor`: The normalized encoder hidden states.
    """
    assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

    cross_norm = self.norm_cross
    if isinstance(cross_norm, nn.LayerNorm):
        encoder_hidden_states = cross_norm(encoder_hidden_states)
    elif isinstance(cross_norm, nn.GroupNorm):
        # GroupNorm normalizes over the channel axis of an (N, C, *) input. The hidden axis is the
        # one to normalize, so swap axes 1 and 2 before the norm and swap them back afterwards.
        channels_first = encoder_hidden_states.transpose(1, 2)
        encoder_hidden_states = cross_norm(channels_first).transpose(1, 2)
    else:
        # only LayerNorm and GroupNorm are handled
        assert False

    return encoder_hidden_states
@torch.no_grad()
def fuse_projections(self, fuse=True):
    r"""
    Build fused projection layers by concatenating the weights (and biases) of the separate projections.

    For self-attention (`self.is_cross_attention` is falsy) a single `self.to_qkv` layer is created from
    `to_q`, `to_k` and `to_v`; otherwise `self.to_kv` is created from `to_k` and `to_v`. When all of
    `add_q_proj`, `add_k_proj` and `add_v_proj` are present and not `None`, `self.to_added_qkv` is created from
    them as well. The source layers are left in place.

    Args:
        fuse (`bool`, defaults to `True`):
            Value stored in `self.fused_projections`. The fused layers are built regardless of this value.
    """
    device = self.to_q.weight.data.device
    dtype = self.to_q.weight.data.dtype

    def _fused_linear(layers, use_bias):
        # Stack the weights row-wise so the fused output is [layer0 | layer1 | ...] on the last axis.
        weight = torch.cat([layer.weight.data for layer in layers])
        out_features, in_features = weight.shape[0], weight.shape[1]
        fused = nn.Linear(in_features, out_features, bias=use_bias, device=device, dtype=dtype)
        fused.weight.copy_(weight)
        if use_bias:
            fused.bias.copy_(torch.cat([layer.bias.data for layer in layers]))
        return fused

    if not self.is_cross_attention:
        self.to_qkv = _fused_linear([self.to_q, self.to_k, self.to_v], self.use_bias)
    else:
        self.to_kv = _fused_linear([self.to_k, self.to_v], self.use_bias)

    # handle added projections for SD3 and others.
    # An attribute that exists but is `None` must be treated like a missing one; a plain `hasattr`
    # check would go on to dereference `None.weight`.
    added_layers = [getattr(self, name, None) for name in ("add_q_proj", "add_k_proj", "add_v_proj")]
    if all(layer is not None for layer in added_layers):
        self.to_added_qkv = _fused_linear(added_layers, self.added_proj_bias)

    self.fused_projections = fuse
class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.

    Attention probabilities come from `attn.get_attention_scores` and are applied to the values with
    `torch.bmm`.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Run attention for `attn` on `hidden_states`.

        Args:
            attn (`Attention`): Module providing the projections, norms and head reshaping helpers.
            hidden_states (`torch.Tensor`): Query-side input; a 4-D input is flattened over its last two axes.
            encoder_hidden_states (`torch.Tensor`, *optional*): Key/value source; `hidden_states` when `None`.
            attention_mask (`torch.Tensor`, *optional*): Passed through `attn.prepare_attention_mask`.
            temb (`torch.Tensor`, *optional*): Forwarded to `attn.spatial_norm` when that norm is set.

        Returns:
            `torch.Tensor`: The attention output, restored to the 4-D input layout when the input was 4-D.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # the mask is sized after the key/value source
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if is_spatial:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
class CustomDiffusionAttnProcessor(nn.Module):
    r"""
    Processor for implementing attention for the Custom Diffusion method.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to newly train the key and value matrices corresponding to the text features.
        train_q_out (`bool`, defaults to `True`):
            Whether to newly train query matrices corresponding to the latent image features.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`.
        out_bias (`bool`, defaults to `True`):
            Whether to include the bias parameter in `train_q_out`.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
    """

    def __init__(
        self,
        train_kv: bool = True,
        train_q_out: bool = True,
        hidden_size: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        out_bias: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        # The `_custom_diffusion` suffix identifies these layers for serialization and loading.
        if self.train_kv:
            kv_in_features = cross_attention_dim or hidden_size
            self.to_k_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList(
                [nn.Linear(hidden_size, hidden_size, bias=out_bias), nn.Dropout(dropout)]
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Run attention, substituting this processor's own projections for those of `attn` wherever
        `train_kv` / `train_q_out` are enabled.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # dtype of the wrapped module's query weights; query/key/value are cast to it
        base_dtype = attn.to_q.weight.dtype

        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states).to(base_dtype)
        else:
            query = attn.to_q(hidden_states.to(base_dtype))

        is_cross = encoder_hidden_states is not None
        if not is_cross:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            k_layer, v_layer = self.to_k_custom_diffusion, self.to_v_custom_diffusion
            key = k_layer(encoder_hidden_states.to(k_layer.weight.dtype)).to(base_dtype)
            value = v_layer(encoder_hidden_states.to(v_layer.weight.dtype)).to(base_dtype)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if is_cross:
            # Gate that is 0 on the first token and 1 elsewhere: the first token's key/value are
            # taken from the detached tensors, so no gradient flows through them.
            grad_gate = torch.ones_like(key)
            grad_gate[:, :1, :] = grad_gate[:, :1, :] * 0.0
            key = grad_gate * key + (1 - grad_gate) * key.detach()
            value = grad_gate * value + (1 - grad_gate) * value.detach()

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(probs, value))

        # index 0 is the linear projection, index 1 the dropout
        out_layers = self.to_out_custom_diffusion if self.train_q_out else attn.to_out
        hidden_states = out_layers[0](hidden_states)
        hidden_states = out_layers[1](hidden_states)

        return hidden_states
class AttnAddedKVProcessor:
    r"""
    Processor for performing attention-related computations with extra learnable key and value matrices for the text
    encoder.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Attend from `hidden_states` to the added key/value projections of `encoder_hidden_states`, and,
        unless `attn.only_cross_attention` is set, to the self-attention keys/values as well. The result is
        reshaped to the input's shape and added to the input.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        shortcut = hidden_states

        # flatten every axis after the first two, then move that flattened axis to position 1
        lead, second = hidden_states.shape[0], hidden_states.shape[1]
        hidden_states = hidden_states.view(lead, second, -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))

        added_key = attn.add_k_proj(encoder_hidden_states)
        added_value = attn.add_v_proj(encoder_hidden_states)
        added_key = attn.head_to_batch_dim(added_key)
        added_value = attn.head_to_batch_dim(added_value)

        if attn.only_cross_attention:
            key, value = added_key, added_value
        else:
            self_key = attn.to_k(hidden_states)
            self_value = attn.to_v(hidden_states)
            self_key = attn.head_to_batch_dim(self_key)
            self_value = attn.head_to_batch_dim(self_value)
            # added (encoder) tokens come first on the sequence axis
            key = torch.cat([added_key, self_key], dim=1)
            value = torch.cat([added_value, self_value], dim=1)

        probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(shortcut.shape)
        return hidden_states + shortcut
class AttnAddedKVProcessor2_0:
    r"""
    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
    learnable key and value matrices for the text encoder.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Same computation as the added-KV attention, but evaluated with `F.scaled_dot_product_attention` on
        tensors that carry an explicit head axis (`out_dim=4`). The result is reshaped to the input's shape
        and added to the input.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        shortcut = hidden_states

        # flatten every axis after the first two, then move that flattened axis to position 1
        lead, second = hidden_states.shape[0], hidden_states.shape[1]
        hidden_states = hidden_states.view(lead, second, -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states), out_dim=4)

        added_key = attn.add_k_proj(encoder_hidden_states)
        added_value = attn.add_v_proj(encoder_hidden_states)
        added_key = attn.head_to_batch_dim(added_key, out_dim=4)
        added_value = attn.head_to_batch_dim(added_value, out_dim=4)

        if attn.only_cross_attention:
            key, value = added_key, added_value
        else:
            self_key = attn.to_k(hidden_states)
            self_value = attn.to_v(hidden_states)
            self_key = attn.head_to_batch_dim(self_key, out_dim=4)
            self_value = attn.head_to_batch_dim(self_value, out_dim=4)
            # with out_dim=4 the concatenation axis is 2; added (encoder) tokens come first
            key = torch.cat([added_key, self_key], dim=2)
            value = torch.cat([added_value, self_value], dim=2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = attended.transpose(1, 2).reshape(batch_size, -1, shortcut.shape[1])

        # to_out[0] is the linear projection, to_out[1] the dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(shortcut.shape)
        return hidden_states + shortcut
class JointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections.

    Sample tokens and (optional) context tokens are projected separately, concatenated along the sequence
    axis (sample first), attended jointly with `F.scaled_dot_product_attention`, and split again.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            # Name this class in the message so the failing processor can be identified.
            raise ImportError(
                "JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Args:
            attn (`Attention`): Module providing the projections and optional q/k norms.
            hidden_states (`torch.FloatTensor`): Sample tokens.
            encoder_hidden_states (`torch.FloatTensor`, *optional*): Context tokens; when `None` only the
                sample tokens attend to themselves.
            attention_mask (`torch.FloatTensor`, *optional*): Accepted for interface compatibility; this
                processor does not use it.

        Returns:
            `hidden_states` alone when `encoder_hidden_states` is `None`, otherwise the tuple
            `(hidden_states, encoder_hidden_states)`.
        """
        residual = hidden_states
        batch_size = hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        def _split_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query = _split_heads(query)
        key = _split_heads(key)
        value = _split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # `context` projections.
        if encoder_hidden_states is not None:
            context_query = _split_heads(attn.add_q_proj(encoder_hidden_states))
            context_key = _split_heads(attn.add_k_proj(encoder_hidden_states))
            context_value = _split_heads(attn.add_v_proj(encoder_hidden_states))

            if attn.norm_added_q is not None:
                context_query = attn.norm_added_q(context_query)
            if attn.norm_added_k is not None:
                context_key = attn.norm_added_k(context_key)

            # sample tokens first, context tokens after, on the sequence axis
            query = torch.cat([query, context_query], dim=2)
            key = torch.cat([key, context_key], dim=2)
            value = torch.cat([value, context_value], dim=2)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Split the attention outputs back into sample and context parts.
            sample_len = residual.shape[1]
            hidden_states, encoder_hidden_states = (
                hidden_states[:, :sample_len],
                hidden_states[:, sample_len:],
            )
            if not attn.context_pre_only:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class PAGJointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections.

    The batch is split into two halves. The first half goes through regular joint attention; the second
    ("perturbed") half uses a mask that prevents each sample token from attending to any other sample token.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def _joint_attention(self, attn, hidden_states, encoder_hidden_states, sample_len, perturb):
        """Joint sample+context attention for one batch chunk; returns `(sample_out, context_out)`.

        `sample_len` is the number of sample tokens (they come first on the sequence axis). When `perturb`
        is true, attention between different sample tokens is masked out.
        """
        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        context_query = attn.add_q_proj(encoder_hidden_states)
        context_key = attn.add_k_proj(encoder_hidden_states)
        context_value = attn.add_v_proj(encoder_hidden_states)

        # sample tokens first, context tokens after
        query = torch.cat([query, context_query], dim=1)
        key = torch.cat([key, context_key], dim=1)
        value = torch.cat([value, context_value], dim=1)

        head_dim = key.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        attn_mask = None
        if perturb:
            seq_len = query.size(2)
            # additive mask, all zeros to begin with
            attn_mask = torch.zeros((seq_len, seq_len), device=query.device, dtype=query.dtype)
            # block sample -> sample attention ...
            attn_mask[:sample_len, :sample_len] = float("-inf")
            # ... except each sample token attending to itself
            attn_mask[:sample_len, :sample_len].fill_diagonal_(0)
            # add broadcastable batch and head axes
            attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)

        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )
        out = out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        out = out.to(query.dtype)

        # Split at the number of sample tokens. (Using the residual's axis 1 is only correct for 3-D
        # inputs; for 4-D inputs that axis holds channels, not tokens.)
        sample_out, context_out = out[:, :sample_len], out[:, sample_len:]

        # linear proj
        sample_out = attn.to_out[0](sample_out)
        # dropout
        sample_out = attn.to_out[1](sample_out)
        if not attn.context_pre_only:
            context_out = attn.to_add_out(context_out)

        return sample_out, context_out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """
        Args:
            attn (`Attention`): Module providing the sample and added (context) projections.
            hidden_states (`torch.FloatTensor`): Sample tokens, 3-D or 4-D; the batch axis holds the original
                half followed by the perturbed half.
            encoder_hidden_states (`torch.FloatTensor`): Context tokens, 3-D or 4-D, laid out the same way.
                Despite the `None` default it is required: its `ndim` is read unconditionally.

        Returns:
            Tuple `(hidden_states, encoder_hidden_states)` with both halves concatenated on the batch axis.
        """
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            full_batch, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(full_batch, channel, height * width).transpose(1, 2)

        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            # Kept in separate names so the sample's spatial shape is not overwritten.
            ctx_batch, ctx_channel, ctx_height, ctx_width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(
                ctx_batch, ctx_channel, ctx_height * ctx_width
            ).transpose(1, 2)

        # store the length of image patch sequences to create a mask that prevents interaction between patches
        # similar to making the self-attention map an identity matrix
        identity_block_size = hidden_states.shape[1]

        # chunk: first half is the original path, second half the perturbed path
        sample_chunks = hidden_states.chunk(2)
        context_chunks = encoder_hidden_states.chunk(2)

        sample_outputs = []
        context_outputs = []
        for sample_chunk, context_chunk, perturb in zip(sample_chunks, context_chunks, (False, True)):
            sample_out, context_out = self._joint_attention(
                attn, sample_chunk, context_chunk, identity_block_size, perturb
            )
            chunk_batch = context_chunk.shape[0]
            if input_ndim == 4:
                sample_out = sample_out.transpose(-1, -2).reshape(chunk_batch, channel, height, width)
            if context_input_ndim == 4:
                context_out = context_out.transpose(-1, -2).reshape(
                    chunk_batch, ctx_channel, ctx_height, ctx_width
                )
            sample_outputs.append(sample_out)
            context_outputs.append(context_out)

        ################ concat ###############
        hidden_states = torch.cat(sample_outputs)
        encoder_hidden_states = torch.cat(context_outputs)

        return hidden_states, encoder_hidden_states
class PAGCFGJointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections.

    The batch is split into three chunks (unconditional, original, perturbed). The first two are processed
    together with regular joint attention; the perturbed chunk uses a mask that prevents each sample token
    from attending to any other sample token.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def _joint_attention(self, attn, hidden_states, encoder_hidden_states, sample_len, perturb):
        """Joint sample+context attention for one batch group; returns `(sample_out, context_out)`.

        `sample_len` is the number of sample tokens (they come first on the sequence axis). When `perturb`
        is true, attention between different sample tokens is masked out.
        """
        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        context_query = attn.add_q_proj(encoder_hidden_states)
        context_key = attn.add_k_proj(encoder_hidden_states)
        context_value = attn.add_v_proj(encoder_hidden_states)

        # sample tokens first, context tokens after
        query = torch.cat([query, context_query], dim=1)
        key = torch.cat([key, context_key], dim=1)
        value = torch.cat([value, context_value], dim=1)

        head_dim = key.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        attn_mask = None
        if perturb:
            seq_len = query.size(2)
            # additive mask, all zeros to begin with
            attn_mask = torch.zeros((seq_len, seq_len), device=query.device, dtype=query.dtype)
            # block sample -> sample attention ...
            attn_mask[:sample_len, :sample_len] = float("-inf")
            # ... except each sample token attending to itself
            attn_mask[:sample_len, :sample_len].fill_diagonal_(0)
            # add broadcastable batch and head axes
            attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)

        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )
        out = out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        out = out.to(query.dtype)

        # Split at the number of sample tokens. (Using the residual's axis 1 is only correct for 3-D
        # inputs; for 4-D inputs that axis holds channels, not tokens.)
        sample_out, context_out = out[:, :sample_len], out[:, sample_len:]

        # linear proj
        sample_out = attn.to_out[0](sample_out)
        # dropout
        sample_out = attn.to_out[1](sample_out)
        if not attn.context_pre_only:
            context_out = attn.to_add_out(context_out)

        return sample_out, context_out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Args:
            attn (`Attention`): Module providing the sample and added (context) projections.
            hidden_states (`torch.FloatTensor`): Sample tokens, 3-D or 4-D; the batch axis holds the
                unconditional, original and perturbed chunks in that order.
            encoder_hidden_states (`torch.FloatTensor`): Context tokens, 3-D or 4-D, laid out the same way.
                Despite the `None` default it is required: its `ndim` is read unconditionally.
            attention_mask (`torch.FloatTensor`, *optional*): Accepted for interface compatibility; this
                processor does not use it.

        Returns:
            Tuple `(hidden_states, encoder_hidden_states)` with all chunks concatenated on the batch axis.
        """
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            full_batch, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(full_batch, channel, height * width).transpose(1, 2)

        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            # Kept in separate names so the sample's spatial shape is not overwritten.
            ctx_batch, ctx_channel, ctx_height, ctx_width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(
                ctx_batch, ctx_channel, ctx_height * ctx_width
            ).transpose(1, 2)

        # patch embeddings width * height (correspond to self-attention map width or height)
        identity_block_size = hidden_states.shape[1]

        # chunk: the unconditional and original chunks share the unperturbed path
        hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

        (
            encoder_hidden_states_uncond,
            encoder_hidden_states_org,
            encoder_hidden_states_ptb,
        ) = encoder_hidden_states.chunk(3)
        encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])

        groups = (
            (hidden_states_org, encoder_hidden_states_org, False),
            (hidden_states_ptb, encoder_hidden_states_ptb, True),
        )

        sample_outputs = []
        context_outputs = []
        for sample_group, context_group, perturb in groups:
            sample_out, context_out = self._joint_attention(
                attn, sample_group, context_group, identity_block_size, perturb
            )
            group_batch = context_group.shape[0]
            if input_ndim == 4:
                sample_out = sample_out.transpose(-1, -2).reshape(group_batch, channel, height, width)
            if context_input_ndim == 4:
                context_out = context_out.transpose(-1, -2).reshape(
                    group_batch, ctx_channel, ctx_height, ctx_width
                )
            sample_outputs.append(sample_out)
            context_outputs.append(context_out)

        ################ concat ###############
        hidden_states = torch.cat(sample_outputs)
        encoder_hidden_states = torch.cat(context_outputs)

        return hidden_states, encoder_hidden_states
class FusedJointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections.

    Uses the fused `attn.to_qkv` and `attn.to_added_qkv` layers, each producing query, key and value in a
    single projection that is split into three equal parts along the last axis.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            # Name this class in the message so the failing processor can be identified.
            raise ImportError(
                "FusedJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Args:
            attn (`Attention`): Module providing `to_qkv`, `to_added_qkv` and the output projections.
            hidden_states (`torch.FloatTensor`): Sample tokens, 3-D or 4-D.
            encoder_hidden_states (`torch.FloatTensor`): Context tokens, 3-D or 4-D. Despite the `None`
                default it is required: its `ndim` is read unconditionally.
            attention_mask (`torch.FloatTensor`, *optional*): Accepted for interface compatibility; this
                processor does not use it.

        Returns:
            Tuple `(hidden_states, encoder_hidden_states)`.
        """
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            # Kept in separate names so the sample's spatial shape is not overwritten.
            batch_size, ctx_channel, ctx_height, ctx_width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(
                batch_size, ctx_channel, ctx_height * ctx_width
            ).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]
        # Number of sample tokens; the joint output is split here. (The input's axis 1 only equals this
        # for 3-D inputs; for 4-D inputs it holds channels.)
        sample_len = hidden_states.shape[1]

        # `sample` projections.
        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        # `context` projections.
        encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
        split_size = encoder_qkv.shape[-1] // 3
        context_query, context_key, context_value = torch.split(encoder_qkv, split_size, dim=-1)

        # attention: sample tokens first, context tokens after
        query = torch.cat([query, context_query], dim=1)
        key = torch.cat([key, context_key], dim=1)
        value = torch.cat([value, context_value], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, :sample_len],
            hidden_states[:, sample_len:],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
                batch_size, ctx_channel, ctx_height, ctx_width
            )

        return hidden_states, encoder_hidden_states
class AuraFlowAttnProcessor2_0:
    """Attention processor used typically in processing Aura Flow.

    Projects the `sample` tokens (`hidden_states`) and, when provided, the `context` tokens
    (`encoder_hidden_states`) to query/key/value, optionally normalizes the per-head queries and
    keys, and runs a single scaled-dot-product attention over the joint sequence in which the
    context tokens come first.
    """

    def __init__(self):
        # `scale` is passed to `F.scaled_dot_product_attention()` below, so a PyTorch build that
        # lacks SDPA *or* is older than 2.1 must be rejected up front (either condition suffices).
        if not hasattr(F, "scaled_dot_product_attention") or is_torch_version("<", "2.1"):
            raise ImportError(
                "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run (joint) attention.

        Args:
            attn: Attention module providing the projection, norm and output layers used here.
            hidden_states: `sample` tokens; the first dimension is the batch.
            encoder_hidden_states: optional `context` tokens attended jointly with the sample.
            *args, **kwargs: accepted and ignored.

        Returns:
            `hidden_states` alone when `encoder_hidden_states` is None, otherwise the tuple
            `(hidden_states, encoder_hidden_states)` after their respective output projections.
        """
        batch_size = hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # Reshape to (batch, tokens, heads, head_dim).
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)

        # Apply QK norm.
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Concatenate the projections (context tokens first).
        if encoder_hidden_states is not None:
            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            )
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            )

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                # The context key must go through the key norm, not the query norm.
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

        # SDPA expects (batch, heads, tokens, head_dim).
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Attention.
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs: the leading tokens belong to the context stream.
        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, encoder_hidden_states.shape[1] :],
                hidden_states[:, : encoder_hidden_states.shape[1]],
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if encoder_hidden_states is not None:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class FusedAuraFlowAttnProcessor2_0:
    """Attention processor used typically in processing Aura Flow with fused projections.

    Same joint attention as the non-fused Aura Flow processor, but query/key/value come from a
    single fused projection (`attn.to_qkv` / `attn.to_added_qkv`) that is split into three equal
    parts along the last dimension.
    """

    def __init__(self):
        # `scale` is passed to `F.scaled_dot_product_attention()` below, so a PyTorch build that
        # lacks SDPA *or* is older than 2.1 must be rejected up front (either condition suffices).
        if not hasattr(F, "scaled_dot_product_attention") or is_torch_version("<", "2.1"):
            raise ImportError(
                "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run (joint) attention using fused QKV projections.

        Args:
            attn: Attention module providing the fused projections, norms and output layers.
            hidden_states: `sample` tokens; the first dimension is the batch.
            encoder_hidden_states: optional `context` tokens attended jointly with the sample.
            *args, **kwargs: accepted and ignored.

        Returns:
            `hidden_states` alone when `encoder_hidden_states` is None, otherwise the tuple
            `(hidden_states, encoder_hidden_states)` after their respective output projections.
        """
        batch_size = hidden_states.shape[0]

        # `sample` projections.
        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
            split_size = encoder_qkv.shape[-1] // 3
            (
                encoder_hidden_states_query_proj,
                encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj,
            ) = torch.split(encoder_qkv, split_size, dim=-1)

        # Reshape to (batch, tokens, heads, head_dim).
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)

        # Apply QK norm.
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Concatenate the projections (context tokens first).
        if encoder_hidden_states is not None:
            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            )
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            )

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                # The context key must go through the key norm, not the query norm.
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

        # SDPA expects (batch, heads, tokens, head_dim).
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Attention.
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs: the leading tokens belong to the context stream.
        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, encoder_hidden_states.shape[1] :],
                hidden_states[:, : encoder_hidden_states.shape[1]],
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if encoder_hidden_states is not None:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class FluxAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections.

    With `encoder_hidden_states` the context tokens are prepended to the sample tokens and both
    streams are attended jointly; without it only the sample tokens are attended.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Attend over the sample tokens, jointly with the context tokens when given.

        `attention_mask` is accepted but not used by this processor.

        Returns:
            `(hidden_states, encoder_hidden_states)` when context tokens were supplied,
            otherwise the attention output for `hidden_states` (no output projection applied).
        """
        has_context = encoder_hidden_states is not None

        # The batch size is read from the context tensor whenever one is supplied.
        shape_source = encoder_hidden_states if has_context else hidden_states
        batch_size, _, _ = shape_source.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if has_context:
            # `context` projections.
            context_query = split_heads(attn.add_q_proj(encoder_hidden_states))
            context_key = split_heads(attn.add_k_proj(encoder_hidden_states))
            context_value = split_heads(attn.add_v_proj(encoder_hidden_states))

            if attn.norm_added_q is not None:
                context_query = attn.norm_added_q(context_query)
            if attn.norm_added_k is not None:
                context_key = attn.norm_added_k(context_key)

            # Joint sequence: context tokens first, then sample tokens.
            query = torch.cat([context_query, query], dim=2)
            key = torch.cat([context_key, key], dim=2)
            value = torch.cat([context_value, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        attended = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        attended = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        attended = attended.to(query.dtype)

        if not has_context:
            return attended

        # Undo the concatenation: leading tokens are context, the rest are sample tokens.
        context_length = encoder_hidden_states.shape[1]
        context_out = attended[:, :context_length]
        sample_out = attended[:, context_length:]

        # linear proj
        sample_out = attn.to_out[0](sample_out)
        # dropout
        sample_out = attn.to_out[1](sample_out)
        context_out = attn.to_add_out(context_out)

        return sample_out, context_out
class FusedFluxAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections.

    Variant that obtains query/key/value from fused projections (`attn.to_qkv` and
    `attn.to_added_qkv`), each split into three equal chunks along the last dimension.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Attend over the sample tokens, jointly with the context tokens when given.

        `attention_mask` is accepted but not used by this processor.

        Returns:
            `(hidden_states, encoder_hidden_states)` when context tokens were supplied,
            otherwise the attention output for `hidden_states` (no output projection applied).
        """
        has_context = encoder_hidden_states is not None

        # The batch size is read from the context tensor whenever one is supplied.
        shape_source = encoder_hidden_states if has_context else hidden_states
        batch_size, _, _ = shape_source.shape

        # `sample` projections.
        fused = attn.to_qkv(hidden_states)
        query, key, value = torch.split(fused, fused.shape[-1] // 3, dim=-1)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        # `context` projections.
        if has_context:
            fused_context = attn.to_added_qkv(encoder_hidden_states)
            context_query, context_key, context_value = torch.split(
                fused_context, fused_context.shape[-1] // 3, dim=-1
            )

            context_query = split_heads(context_query)
            context_key = split_heads(context_key)
            context_value = split_heads(context_value)

            if attn.norm_added_q is not None:
                context_query = attn.norm_added_q(context_query)
            if attn.norm_added_k is not None:
                context_key = attn.norm_added_k(context_key)

            # Joint sequence: context tokens first, then sample tokens.
            query = torch.cat([context_query, query], dim=2)
            key = torch.cat([context_key, key], dim=2)
            value = torch.cat([context_value, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        attended = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        attended = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        attended = attended.to(query.dtype)

        if not has_context:
            return attended

        # Undo the concatenation: leading tokens are context, the rest are sample tokens.
        context_length = encoder_hidden_states.shape[1]
        context_out = attended[:, :context_length]
        sample_out = attended[:, context_length:]

        # linear proj
        sample_out = attn.to_out[0](sample_out)
        # dropout
        sample_out = attn.to_out[1](sample_out)
        context_out = attn.to_add_out(context_out)

        return sample_out, context_out
class CogVideoXAttnProcessor2_0:
    r"""
    Scaled dot-product attention processor for the CogVideoX model. Text tokens are prepended to the
    remaining tokens, rotary embeddings are applied to query/key (never to the text tokens), and no
    spatial normalization is performed.

    When the joint sequence has exactly 19726 tokens a special path is taken: the non-text tokens are
    split into 13 equal groups and the last 150 tokens of every group are left without rotary embedding.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over `[encoder_hidden_states, hidden_states]` concatenated along the token axis.

        Returns:
            `(hidden_states, encoder_hidden_states)`: the projected attention output split back into
            the non-text part and the leading `text_seq_length` text part.
        """
        text_seq_length = encoder_hidden_states.size(1)
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # `encoder_hidden_states` was already dereferenced above, so it cannot be None here and the
        # batch/sequence sizes are always taken from it.
        batch_size, sequence_length, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        num_heads = attn.heads
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            if query.shape[-2] == 19726:
                # Presumably 13 latent frames, each carrying 150 trailing trajectory/entity tokens
                # that must not be rotated — TODO confirm against the caller that builds the sequence.
                num_frames = 13

                def rope_excluding_group_tails(tensor):
                    # Keep text as-is, rotate everything except the last 150 tokens of each group.
                    text_part = tensor[:, :, :text_seq_length]
                    grouped = rearrange(
                        tensor[:, :, text_seq_length:], "b l (n t) d -> b l n t d", n=num_frames
                    )
                    rotated_part = grouped[:, :, :, :-150, :].flatten(2, 3)
                    tail_part = grouped[:, :, :, -150:, :].flatten(2, 3)
                    rotated_part = apply_rotary_emb(rotated_part, image_rotary_emb)
                    return torch.cat((text_part, rotated_part, tail_part), dim=-2)

                # NOTE(review): the result is ordered [text, all rotated tokens, all group tails],
                # whereas `value` keeps the original per-group token order — confirm this regrouping
                # is intended before changing it (trained weights may depend on it).
                query = rope_excluding_group_tails(query)
                if not attn.is_cross_attention:
                    key = rope_excluding_group_tails(key)
            else:
                # In-place update of the non-text slice only.
                query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
                if not attn.is_cross_attention:
                    key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)

        # linear proj
        attended = attn.to_out[0](attended)
        # dropout
        attended = attn.to_out[1](attended)

        remaining_length = attended.size(1) - text_seq_length
        encoder_hidden_states, hidden_states = attended.split([text_seq_length, remaining_length], dim=1)
        return hidden_states, encoder_hidden_states
class FusedCogVideoXAttnProcessor2_0:
    r"""
    Scaled dot-product attention processor for the CogVideoX model that uses the fused `attn.to_qkv`
    projection. Rotary embeddings are applied to the non-text part of query/key; no spatial
    normalization is performed.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over `[encoder_hidden_states, hidden_states]` concatenated along the token axis.

        Returns:
            `(hidden_states, encoder_hidden_states)`: the projected attention output split back into
            the non-text part and the leading `text_seq_length` text part.
        """
        text_seq_length = encoder_hidden_states.size(1)
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # `encoder_hidden_states` was already dereferenced above, so it cannot be None here and the
        # batch/sequence sizes are always taken from it.
        batch_size, sequence_length, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        # One fused projection, split into three equal chunks.
        fused = attn.to_qkv(hidden_states)
        query, key, value = torch.split(fused, fused.shape[-1] // 3, dim=-1)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            # In-place update of the non-text slice only.
            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)

        # linear proj
        attended = attn.to_out[0](attended)
        # dropout
        attended = attn.to_out[1](attended)

        remaining_length = attended.size(1) - text_seq_length
        encoder_hidden_states, hidden_states = attended.split([text_seq_length, remaining_length], dim=1)
        return hidden_states, encoder_hidden_states
class XFormersAttnAddedKVProcessor:
    r"""
    Memory efficient attention (xFormers) with additional key/value projections computed from
    `encoder_hidden_states`.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend and return the result reshaped to the input's shape with the input added back."""
        residual = hidden_states

        # Collapse every dimension after the second into one token axis, then put tokens before channels.
        leading, second = hidden_states.shape[0], hidden_states.shape[1]
        hidden_states = hidden_states.view(leading, second, -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # Without an explicit context, the (not yet group-normed) tokens serve as their own context.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))

        added_key = attn.add_k_proj(encoder_hidden_states)
        added_value = attn.add_v_proj(encoder_hidden_states)
        added_key = attn.head_to_batch_dim(added_key)
        added_value = attn.head_to_batch_dim(added_value)

        if attn.only_cross_attention:
            key = added_key
            value = added_value
        else:
            # Self-attention keys/values are appended after the added (context) ones.
            self_key = attn.head_to_batch_dim(attn.to_k(hidden_states))
            self_value = attn.head_to_batch_dim(attn.to_v(hidden_states))
            key = torch.cat([added_key, self_key], dim=1)
            value = torch.cat([added_value, self_value], dim=1)

        attended = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        attended = attended.to(query.dtype)
        attended = attn.batch_to_head_dim(attended)

        # linear proj
        attended = attn.to_out[0](attended)
        # dropout
        attended = attn.to_out[1](attended)

        attended = attended.transpose(-1, -2).reshape(residual.shape)
        return attended + residual
class XFormersAttnProcessor:
    r"""
    Attention processor backed by xFormers' memory efficient attention.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Self- or cross-attention; 4-D inputs are flattened to tokens and restored afterwards."""
        if args or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        is_spatial = input_ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Key length comes from the context when one is given, otherwise from the tokens themselves.
        key_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, key_tokens, _ = key_source.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Fold heads into the batch dimension; contiguous copies are made before the xFormers call.
        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attended = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        attended = attended.to(query.dtype)
        attended = attn.batch_to_head_dim(attended)

        # linear proj
        attended = attn.to_out[0](attended)
        # dropout
        attended = attn.to_out[1](attended)

        if is_spatial:
            attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            attended = attended + residual

        return attended / attn.rescale_output_factor
class AttnProcessorNPU:
    r"""
    Flash attention processor built on torch_npu. Torch_npu supports only fp16 and bf16 data types; for any
    other query dtype (e.g. fp32) the computation falls back to `F.scaled_dot_product_attention`, for which
    the acceleration effect on NPU is not significant.
    """

    def __init__(self):
        if is_torch_npu_available():
            return
        raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Self- or cross-attention; 4-D inputs are flattened to tokens and restored afterwards."""
        if args or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        is_spatial = input_ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Sequence length comes from the context when one is given, otherwise from the tokens themselves.
        length_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = length_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        use_fused_kernel = query.dtype in (torch.float16, torch.bfloat16)
        if use_fused_kernel:
            fused_outputs = torch_npu.npu_fusion_attention(
                query,
                key,
                value,
                attn.heads,
                input_layout="BNSD",
                pse=None,
                atten_mask=attention_mask,
                scale=1.0 / math.sqrt(query.shape[-1]),
                pre_tockens=65536,
                next_tockens=65536,
                keep_prob=1.0,
                sync=False,
                inner_precise=0,
            )
            # Only the first returned element is used as the attention output.
            attended = fused_outputs[0]
        else:
            # TODO: add support for attn.scale when we move to Torch 2.1
            attended = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        attended = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        attended = attended.to(query.dtype)

        # linear proj
        attended = attn.to_out[0](attended)
        # dropout
        attended = attn.to_out[1](attended)

        if is_spatial:
            attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            attended = attended + residual

        return attended / attn.rescale_output_factor
class AttnProcessor2_0:
    r"""
    Attention processor built on `F.scaled_dot_product_attention` (enabled by default if you're using
    PyTorch 2.0).
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Self- or cross-attention; 4-D inputs are flattened to tokens and restored afterwards."""
        if args or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        is_spatial = input_ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Sequence length comes from the context when one is given, otherwise from the tokens themselves.
        length_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = length_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        attended = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        attended = attended.to(query.dtype)

        # linear proj
        attended = attn.to_out[0](attended)
        # dropout
        attended = attn.to_out[1](attended)

        if is_spatial:
            attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            attended = attended + residual

        return attended / attn.rescale_output_factor
class StableAudioAttnProcessor2_0:
    r"""
    Attention processor for the Stable Audio model, built on PyTorch 2.0 scaled dot-product attention.

    Rotary embedding is applied only to the leading `rot_dim` channels of every query/key head (the remaining channels
    pass through untouched). Key/value projections may expose fewer heads than the query projection (GQA or MQA); in
    that case the key/value heads are repeated until they match the number of query heads.
    """

    def __init__(self):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def apply_partial_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Tuple[torch.Tensor],
    ) -> torch.Tensor:
        """
        Rotate the first `freqs_cis[0].shape[-1]` channels of `x` along its last axis and return them concatenated
        with the untouched remainder.
        """
        from diffusers.models.embeddings import apply_rotary_emb

        # The width of the rotary table decides how many channels are rotated.
        n_rot = freqs_cis[0].shape[-1]
        head_part = x[..., :n_rot]
        tail_part = x[..., n_rot:]
        head_part = apply_rotary_emb(head_part, freqs_cis, use_real=True, use_real_unbind_dim=-2)
        return torch.cat((head_part, tail_part), dim=-1)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run attention for `attn` on `hidden_states`.

        When `encoder_hidden_states` is `None` keys and values are projected from `hidden_states` itself. A 4D input
        is flattened to a token sequence first and restored to `(batch, channel, height, width)` before returning.
        """
        skip = hidden_states
        ndim_in = hidden_states.ndim

        if ndim_in == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Batch size and mask target length come from whichever sequence feeds the key/value projections.
        kv_source_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = kv_source_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # SDPA takes the mask as (batch, heads, source_length, target_length).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # The per-head width is fixed by the query projection; the key width then gives the number of kv heads.
        head_dim = query.shape[-1] // attn.heads
        kv_heads = key.shape[-1] // head_dim

        def split_heads(tensor: torch.Tensor, n_heads: int) -> torch.Tensor:
            return tensor.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)

        query = split_heads(query, attn.heads)
        key = split_heads(key, kv_heads)
        value = split_heads(value, kv_heads)

        if kv_heads != attn.heads:
            # GQA / MQA: duplicate each key/value head so every query head has a partner.
            repeats = attn.heads // kv_heads
            key = torch.repeat_interleave(key, repeats, dim=1)
            value = torch.repeat_interleave(value, repeats, dim=1)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if rotary_emb is not None:
            # The rotation is carried out in float32 and the results are cast back afterwards.
            q_dtype, k_dtype = query.dtype, key.dtype
            query = query.to(torch.float32)
            key = key.to(torch.float32)

            query = self.apply_partial_rotary_emb(query, rotary_emb)
            if not attn.is_cross_attention:
                key = self.apply_partial_rotary_emb(key, rotary_emb)

            query = query.to(q_dtype)
            key = key.to(k_dtype)

        # SDPA returns (batch, num_heads, seq_len, head_dim).
        # TODO: add support for attn.scale when we move to Torch 2.1
        attn_out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attn_out = attn_out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        attn_out = attn_out.to(query.dtype)

        # Output projection followed by dropout.
        hidden_states = attn.to_out[0](attn_out)
        hidden_states = attn.to_out[1](hidden_states)

        if ndim_in == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor
class HunyuanAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the HunyuanDiT model. It applies an optional normalization layer and rotary embedding on the query and key
    vectors.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            # Name this processor in the message (it previously reported the generic "AttnProcessor2_0").
            raise ImportError(
                "HunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run attention for `attn` on `hidden_states`.

        Args:
            attn: The attention module providing projections, norms and configuration.
            hidden_states: Query-side input; a 4D input is flattened to a token sequence and restored on return.
            encoder_hidden_states: Key/value-side input; when `None`, `hidden_states` is used instead.
            attention_mask: Optional mask, passed through `attn.prepare_attention_mask` before use.
            temb: Forwarded to `attn.spatial_norm` when that layer is set.
            image_rotary_emb: Rotary embedding applied to the query, and to the key unless
                `attn.is_cross_attention` is set.

        Returns:
            The attention output after output projection, dropout, optional residual add and rescaling.
        """
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Batch size and mask target length come from the sequence that feeds the key/value projections.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class FusedHunyuanAttnProcessor2_0:
    r"""
    HunyuanDiT attention processor on top of PyTorch 2.0 scaled dot-product attention that reads its projections from
    fused layers: `attn.to_qkv` for self-attention, and `attn.to_q` together with `attn.to_kv` when
    `encoder_hidden_states` is supplied. Optional query/key normalization and rotary embedding are applied before the
    attention itself.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run attention for `attn` on `hidden_states` using the fused projection layers and return the projected,
        optionally residual-added and rescaled result.
        """
        from diffusers.models.embeddings import apply_rotary_emb

        skip = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        ndim_in = hidden_states.ndim
        if ndim_in == 4:
            # Flatten the spatial grid into a token axis.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        kv_source_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = kv_source_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # SDPA takes the mask as (batch, heads, source_length, target_length).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            # Self-attention: a single projection yields query, key and value side by side.
            fused = attn.to_qkv(hidden_states)
            query, key, value = torch.split(fused, fused.shape[-1] // 3, dim=-1)
        else:
            # Cross-attention: separate query projection, fused key/value projection.
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
            query = attn.to_q(hidden_states)
            fused = attn.to_kv(encoder_hidden_states)
            key, value = torch.split(fused, fused.shape[-1] // 2, dim=-1)

        head_dim = key.shape[-1] // attn.heads

        def split_heads(tensor: torch.Tensor) -> torch.Tensor:
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            # The key is rotated only when the module is not flagged as cross-attention.
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # SDPA returns (batch, num_heads, seq_len, head_dim).
        # TODO: add support for attn.scale when we move to Torch 2.1
        attn_out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attn_out = attn_out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        attn_out = attn_out.to(query.dtype)

        # Output projection followed by dropout.
        hidden_states = attn.to_out[0](attn_out)
        hidden_states = attn.to_out[1](hidden_states)

        if ndim_in == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor
class PAGHunyuanAttnProcessor2_0:
    r"""
    HunyuanDiT attention processor on top of PyTorch 2.0 scaled dot-product attention, with optional query/key
    normalization and rotary embedding. This variant implements
    [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377): the batch is split in two halves, the first half
    goes through regular attention, and the second half skips the attention product and is passed through the value
    and output projections only.
    """

    def __init__(self):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Process both batch halves and return them concatenated along the batch axis in the original order
        (regular path first, perturbed path second).
        """
        from diffusers.models.embeddings import apply_rotary_emb

        skip = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        ndim_in = hidden_states.ndim
        if ndim_in == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Split along the batch axis: regular half first, perturbed half second.
        states_org, states_ptb = hidden_states.chunk(2)

        # ------------------------------------------------------------------
        # 1. Regular attention path
        # ------------------------------------------------------------------
        kv_source_shape = states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = kv_source_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # SDPA takes the mask as (batch, heads, source_length, target_length).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            states_org = attn.group_norm(states_org.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(states_org)

        if encoder_hidden_states is None:
            encoder_hidden_states = states_org
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = key.shape[-1] // attn.heads

        def split_heads(tensor: torch.Tensor) -> torch.Tensor:
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # SDPA returns (batch, num_heads, seq_len, head_dim).
        # TODO: add support for attn.scale when we move to Torch 2.1
        states_org = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        states_org = states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        states_org = states_org.to(query.dtype)

        # Output projection followed by dropout.
        states_org = attn.to_out[0](states_org)
        states_org = attn.to_out[1](states_org)

        if ndim_in == 4:
            states_org = states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # ------------------------------------------------------------------
        # 2. Perturbed path: value projection only, no attention product
        # ------------------------------------------------------------------
        if attn.group_norm is not None:
            states_ptb = attn.group_norm(states_ptb.transpose(1, 2)).transpose(1, 2)

        states_ptb = attn.to_v(states_ptb)
        states_ptb = states_ptb.to(query.dtype)

        # Output projection followed by dropout.
        states_ptb = attn.to_out[0](states_ptb)
        states_ptb = attn.to_out[1](states_ptb)

        if ndim_in == 4:
            states_ptb = states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # Re-assemble the batch in its original order.
        hidden_states = torch.cat([states_org, states_ptb])

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor
class PAGCFGHunyuanAttnProcessor2_0:
    r"""
    HunyuanDiT attention processor on top of PyTorch 2.0 scaled dot-product attention, with optional query/key
    normalization and rotary embedding. This variant implements
    [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377) for a batch made of three equal parts
    (unconditional, regular, perturbed): the first two parts go through regular attention together, while the third
    skips the attention product and is passed through the value and output projections only.
    """

    def __init__(self):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Process all three batch parts and return them concatenated along the batch axis in the original order
        (unconditional, regular, perturbed).
        """
        from diffusers.models.embeddings import apply_rotary_emb

        skip = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        ndim_in = hidden_states.ndim
        if ndim_in == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Split the batch in three; the first two thirds share the regular attention path.
        states_uncond, states_cond, states_ptb = hidden_states.chunk(3)
        states_org = torch.cat([states_uncond, states_cond])

        # ------------------------------------------------------------------
        # 1. Regular attention path (unconditional + regular parts)
        # ------------------------------------------------------------------
        kv_source_shape = states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = kv_source_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # SDPA takes the mask as (batch, heads, source_length, target_length).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            states_org = attn.group_norm(states_org.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(states_org)

        if encoder_hidden_states is None:
            encoder_hidden_states = states_org
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = key.shape[-1] // attn.heads

        def split_heads(tensor: torch.Tensor) -> torch.Tensor:
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # SDPA returns (batch, num_heads, seq_len, head_dim).
        # TODO: add support for attn.scale when we move to Torch 2.1
        states_org = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        states_org = states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        states_org = states_org.to(query.dtype)

        # Output projection followed by dropout.
        states_org = attn.to_out[0](states_org)
        states_org = attn.to_out[1](states_org)

        if ndim_in == 4:
            states_org = states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # ------------------------------------------------------------------
        # 2. Perturbed path: value projection only, no attention product
        # ------------------------------------------------------------------
        if attn.group_norm is not None:
            states_ptb = attn.group_norm(states_ptb.transpose(1, 2)).transpose(1, 2)

        states_ptb = attn.to_v(states_ptb)
        states_ptb = states_ptb.to(query.dtype)

        # Output projection followed by dropout.
        states_ptb = attn.to_out[0](states_ptb)
        states_ptb = attn.to_out[1](states_ptb)

        if ndim_in == 4:
            states_ptb = states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # Re-assemble the batch in its original order.
        hidden_states = torch.cat([states_org, states_ptb])

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor
class LuminaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
query_rotary_emb: Optional[torch.Tensor] = None,
key_rotary_emb: Optional[torch.Tensor] = None,
base_sequence_length: Optional[int] = None,
) -> torch.Tensor:
from diffusers.models.embeddings import apply_rotary_emb
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
# Get Query-Key-Value Pair
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query_dim = query.shape[-1]
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype
# Get key-value heads
kv_heads = inner_dim // head_dim
# Apply Query-Key Norm if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, kv_heads, head_dim)
value = value.view(batch_size, -1, kv_heads, head_dim)
# Apply RoPE if needed
if query_rotary_emb is not None:
query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
if key_rotary_emb is not None:
key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
query, key = query.to(dtype), key.to(dtype)
# Apply proportional attention if true
if key_rotary_emb is None:
softmax_scale = None
else:
if base_sequence_length is not None:
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
else:
softmax_scale = attn.scale
# perform Grouped-qurey Attention (GQA)
n_rep = attn.heads // kv_heads
if n_rep >= 1:
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, scale=softmax_scale
)
hidden_states = hidden_states.transpose(1, 2).to(dtype)
return hidden_states
class FusedAttnProcessor2_0:
    r"""
    Scaled dot-product attention processor (PyTorch 2.0+) that works with fused projection layers. Self-attention
    reads query, key and value from a single fused projection (`attn.to_qkv`); cross-attention keeps a separate query
    projection and reads key and value from a fused `attn.to_kv`.

    This API is currently 🧪 experimental in nature and can change in future.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run attention for `attn` on `hidden_states` using the fused projection layers. Extra positional arguments or
        a non-`None` `scale` keyword only trigger a deprecation notice; they do not influence the computation.
        """
        if args or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        skip = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        ndim_in = hidden_states.ndim
        if ndim_in == 4:
            # Flatten the spatial grid into a token axis.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        kv_source_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = kv_source_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # SDPA takes the mask as (batch, heads, source_length, target_length).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            # Self-attention: a single projection yields query, key and value side by side.
            fused = attn.to_qkv(hidden_states)
            query, key, value = torch.split(fused, fused.shape[-1] // 3, dim=-1)
        else:
            # Cross-attention: separate query projection, fused key/value projection.
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
            query = attn.to_q(hidden_states)
            fused = attn.to_kv(encoder_hidden_states)
            key, value = torch.split(fused, fused.shape[-1] // 2, dim=-1)

        head_dim = key.shape[-1] // attn.heads

        def split_heads(tensor: torch.Tensor) -> torch.Tensor:
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # SDPA returns (batch, num_heads, seq_len, head_dim).
        # TODO: add support for attn.scale when we move to Torch 2.1
        attn_out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attn_out = attn_out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        attn_out = attn_out.to(query.dtype)

        # Output projection followed by dropout.
        hidden_states = attn.to_out[0](attn_out)
        hidden_states = attn.to_out[1](hidden_states)

        if ndim_in == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor
class CustomDiffusionXFormersAttnProcessor(nn.Module):
    r"""
    Attention processor for the Custom Diffusion method that computes attention with xFormers memory-efficient
    attention.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to create and use dedicated key and value projections (`to_k_custom_diffusion`,
            `to_v_custom_diffusion`) instead of the ones owned by the attention module.
        train_q_out (`bool`, defaults to `False`):
            Whether to create and use dedicated query and output projections (`to_q_custom_diffusion`,
            `to_out_custom_diffusion`) instead of the ones owned by the attention module.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`. Falls back to `hidden_size` when not given.
        out_bias (`bool`, defaults to `True`):
            Whether the dedicated output projection has a bias. Only used when `train_q_out` is set.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability after the dedicated output projection. Only used when `train_q_out` is set.
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
            as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
    """

    def __init__(
        self,
        train_kv: bool = True,
        train_q_out: bool = False,
        hidden_size: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        out_bias: bool = True,
        dropout: float = 0.0,
        attention_op: Optional[Callable] = None,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.attention_op = attention_op

        # The `_custom_diffusion` suffix makes these layers easy to pick out when serializing and loading.
        if self.train_kv:
            kv_in_features = cross_attention_dim or hidden_size
            self.to_k_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList(
                [nn.Linear(hidden_size, hidden_size, bias=out_bias), nn.Dropout(dropout)]
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run attention for `attn`, substituting the dedicated Custom Diffusion projections where they are enabled,
        and return the projected output.
        """
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # Everything handed to the attention kernel is brought to the dtype of the module's own query weights.
        base_dtype = attn.to_q.weight.dtype

        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states).to(base_dtype)
        else:
            query = attn.to_q(hidden_states.to(base_dtype))

        is_cross = encoder_hidden_states is not None
        if not is_cross:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            k_proj = self.to_k_custom_diffusion
            v_proj = self.to_v_custom_diffusion
            key = k_proj(encoder_hidden_states.to(k_proj.weight.dtype))
            value = v_proj(encoder_hidden_states.to(v_proj.weight.dtype))
            key = key.to(base_dtype)
            value = value.to(base_dtype)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if is_cross:
            # Gradient gate: 0 at the first key/value token, 1 elsewhere. The first token therefore contributes
            # only its detached value, i.e. no gradient flows through it.
            grad_gate = torch.ones_like(key)
            grad_gate[:, :1, :] = grad_gate[:, :1, :] * 0.0
            key = grad_gate * key + (1 - grad_gate) * key.detach()
            value = grad_gate * value + (1 - grad_gate) * value.detach()

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attn_out = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        attn_out = attn_out.to(query.dtype)
        attn_out = attn.batch_to_head_dim(attn_out)

        # Output projection followed by dropout, from the dedicated layers when they exist.
        out_layers = self.to_out_custom_diffusion if self.train_q_out else attn.to_out
        attn_out = out_layers[0](attn_out)
        attn_out = out_layers[1](attn_out)

        return attn_out
class CustomDiffusionAttnProcessor2_0(nn.Module):
    r"""
    Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0's memory-efficient scaled
    dot-product attention.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to newly train the key and value matrices corresponding to the text features.
        train_q_out (`bool`, defaults to `True`):
            Whether to newly train the query and output matrices corresponding to the latent image features.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`. Falls back to `hidden_size` when not given.
        out_bias (`bool`, defaults to `True`):
            Whether to include the bias parameter in the output projection created for `train_q_out`.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
    """

    def __init__(
        self,
        train_kv: bool = True,
        train_q_out: bool = True,
        hidden_size: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        out_bias: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        # `_custom_diffusion` id for easy serialization and loading.
        if self.train_kv:
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList([])
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run attention for `attn`, substituting the dedicated Custom Diffusion projections where they are enabled.

        Args:
            attn: The attention module providing the default projections, `heads` and mask preparation.
            hidden_states: Query-side input of shape `(batch, sequence, channels)`.
            encoder_hidden_states: Key/value-side input; when `None`, `hidden_states` is used (self-attention).
            attention_mask: Optional mask over the key/value tokens.

        Returns:
            The attention output after the output projection and dropout.
        """
        # The mask spans the key/value tokens, so its target length must come from the sequence that feeds the
        # key/value projections (the encoder states for cross-attention), not from the query sequence.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length); like the other SDPA processors in this file, this
            # assumes `prepare_attention_mask` returns a mask whose leading dimension is batch * heads.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states)
        else:
            query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            crossattn = False
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
            key = key.to(attn.to_q.weight.dtype)
            value = value.to(attn.to_q.weight.dtype)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if crossattn:
            # Gradient gate: 0 at the first key/value token, 1 elsewhere, so the first token only contributes
            # its detached value and receives no gradient.
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        # Derive the per-head width from the projected key (as the sibling processors do) rather than from the
        # width of the raw input, which only coincides with it when both are equal.
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if self.train_q_out:
            # linear proj
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
            # dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
        else:
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
class SlicedAttnProcessor:
    r"""
    Attention processor that evaluates attention a few rows at a time instead of in one batched product.

    Args:
        slice_size (`int`):
            Number of rows (along the first axis of the tensors returned by `attn.head_to_batch_dim`) that are
            attended to per loop iteration.
    """

    def __init__(self, slice_size: int):
        self.slice_size = slice_size

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run (self- or cross-) attention for `attn`, computing the attention product slice by slice.

        A 4-D input is flattened to a token sequence first and restored to its 4-D shape before returning.
        When `encoder_hidden_states` is `None`, keys and values are projected from `hidden_states`.
        """
        skip_input = hidden_states
        is_image_like = hidden_states.ndim == 4

        if is_image_like:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized against whichever tensor will provide the keys.
        context_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = context_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        projected_query = attn.to_q(hidden_states)
        inner_dim = projected_query.shape[-1]
        query = attn.head_to_batch_dim(projected_query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states))

        rows, query_tokens, _ = query.shape
        attended = torch.zeros(
            (rows, query_tokens, inner_dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Ceil-division of `rows` by `slice_size`; the last window may be shorter than the others.
        num_slices = (rows - 1) // self.slice_size + 1
        for slice_index in range(num_slices):
            window = slice(slice_index * self.slice_size, (slice_index + 1) * self.slice_size)
            mask_window = None if attention_mask is None else attention_mask[window]
            probs = attn.get_attention_scores(query[window], key[window], mask_window)
            attended[window] = torch.bmm(probs, value[window])

        hidden_states = attn.batch_to_head_dim(attended)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))

        if is_image_like:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip_input

        return hidden_states / attn.rescale_output_factor
class SlicedAttnAddedKVProcessor:
    r"""
    Sliced attention processor that additionally projects `encoder_hidden_states` through the attention module's
    `add_k_proj` / `add_v_proj` layers.

    Args:
        slice_size (`int`):
            Number of rows (along the first axis of the tensors returned by `attn.head_to_batch_dim`) that are
            attended to per loop iteration.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Attend from the flattened `hidden_states` to the added key/value projections of `encoder_hidden_states`
        (concatenated in front of the regular self keys/values unless `attn.only_cross_attention` is set),
        slice by slice, and add the untouched input back at the end.
        """
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # Collapse everything after the first two axes into one token axis, then put tokens before channels.
        lead, channels = hidden_states.shape[0], hidden_states.shape[1]
        hidden_states = hidden_states.view(lead, channels, -1).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # Captured before group norm: the fallback context is the un-normalized token sequence.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        projected_query = attn.to_q(hidden_states)
        inner_dim = projected_query.shape[-1]
        query = attn.head_to_batch_dim(projected_query)

        added_key = attn.head_to_batch_dim(attn.add_k_proj(encoder_hidden_states))
        added_value = attn.head_to_batch_dim(attn.add_v_proj(encoder_hidden_states))

        if attn.only_cross_attention:
            key, value = added_key, added_value
        else:
            self_key = attn.head_to_batch_dim(attn.to_k(hidden_states))
            self_value = attn.head_to_batch_dim(attn.to_v(hidden_states))
            key = torch.cat([added_key, self_key], dim=1)
            value = torch.cat([added_value, self_value], dim=1)

        rows, query_tokens, _ = query.shape
        attended = torch.zeros(
            (rows, query_tokens, inner_dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Ceil-division of `rows` by `slice_size`; the last window may be shorter than the others.
        num_slices = (rows - 1) // self.slice_size + 1
        for slice_index in range(num_slices):
            window = slice(slice_index * self.slice_size, (slice_index + 1) * self.slice_size)
            mask_window = None if attention_mask is None else attention_mask[window]
            probs = attn.get_attention_scores(query[window], key[window], mask_window)
            attended[window] = torch.bmm(probs, value[window])

        hidden_states = attn.batch_to_head_dim(attended)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))

        hidden_states = hidden_states.transpose(-1, -2).reshape(skip_input.shape)
        return hidden_states + skip_input
class SpatialNorm(nn.Module):
    """
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.

    The input feature map is group-normalized and then modulated by a per-pixel scale and shift, both produced
    from `zq` with 1x1 convolutions.

    Args:
        f_channels (`int`):
            Channel count of the feature map that is normalized; also the channel count of the output.
        zq_channels (`int`):
            Channel count of the conditioning tensor `zq` (the quantized vector in the paper).
    """

    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
    ):
        super().__init__()
        pointwise = {"kernel_size": 1, "stride": 1, "padding": 0}
        self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=f_channels, eps=1e-6, affine=True)
        # conv_y yields the multiplicative term, conv_b the additive term.
        self.conv_y = nn.Conv2d(zq_channels, f_channels, **pointwise)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, **pointwise)

    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        # Resize the conditioning tensor to the trailing two dimensions of `f` (nearest neighbour).
        spatial_size = f.shape[-2:]
        conditioning = F.interpolate(zq, size=spatial_size, mode="nearest")
        normalized = self.norm_layer(f)
        return normalized * self.conv_y(conditioning) + self.conv_b(conditioning)
class IPAdapterAttnProcessor(nn.Module):
    r"""
    Attention processor for Multiple IP-Adapters.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
            The context length of the image features.
        scale (`float` or List[`float`], defaults to 1.0):
            the weight scale of image prompt.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        self.num_tokens = num_tokens

        # A scalar scale is broadcast to one entry per adapter.
        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale

        # One key projection and one value projection per adapter.
        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )

    @staticmethod
    def _ip_attention(attn, query, ip_key, ip_value):
        """Attend from `query` to already-projected image-prompt keys/values (no attention mask)."""
        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)
        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        return attn.batch_to_head_dim(torch.bmm(ip_attention_probs, ip_value))

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        scale: float = 1.0,
        ip_adapter_masks: Optional[torch.Tensor] = None,
    ):
        """
        Run the regular attention of `attn` and add the scaled image-prompt attention of every adapter.

        Args:
            attn: Attention module whose projections and helper methods are used.
            hidden_states: Query source; a 4-D input is flattened to tokens and restored before returning.
            encoder_hidden_states: Either a tuple `(text_states, ip_hidden_states)` or, deprecated, a single
                tensor whose trailing `self.num_tokens[0]` tokens are split off as the only image prompt.
                With `None`, keys/values come from `hidden_states` and no image prompt is applied.
            attention_mask: Mask for the regular attention only; image-prompt attention is unmasked.
            temb: Forwarded to `attn.spatial_norm` when that module is set.
            scale: Not read by this processor; per-adapter scales come from `self.scale`.
            ip_adapter_masks: Optional per-adapter masks (a list of 4-D tensors, or one tensor that is split
                along its first axis).

        Raises:
            ValueError: If the masks do not line up with `self.scale` / the image prompts.
        """
        residual = hidden_states

        # Without `encoder_hidden_states` there are no image-prompt embeddings. Defaulting to an empty list
        # keeps the self-attention path (handled below) working instead of failing on an unbound local.
        ip_hidden_states = []

        # separate ip_hidden_states from encoder_hidden_states
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, tuple):
                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
            else:
                deprecation_message = (
                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
                )
                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    [encoder_hidden_states[:, end_pos:, :]],
                )

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_adapter_masks is not None:
            if not isinstance(ip_adapter_masks, List):
                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                raise ValueError(
                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
                    f"({len(ip_hidden_states)})"
                )
            for index, (mask, mask_scale, ip_state) in enumerate(
                zip(ip_adapter_masks, self.scale, ip_hidden_states)
            ):
                if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                    raise ValueError(
                        "Each element of the ip_adapter_masks array should be a tensor with shape "
                        "[1, num_images_for_ip_adapter, height, width]."
                        " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                    )
                if mask.shape[1] != ip_state.shape[1]:
                    raise ValueError(
                        f"Number of masks ({mask.shape[1]}) does not match "
                        f"number of ip images ({ip_state.shape[1]}) at index {index}"
                    )
                if isinstance(mask_scale, list) and not len(mask_scale) == mask.shape[1]:
                    raise ValueError(
                        f"Number of masks ({mask.shape[1]}) does not match "
                        f"number of scales ({len(mask_scale)}) at index {index}"
                    )
        else:
            ip_adapter_masks = [None] * len(self.scale)

        # for ip-adapter
        for current_ip_hidden_states, ip_scale, to_k_ip, to_v_ip, mask in zip(
            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
        ):
            # An adapter whose scale(s) are all zero contributes nothing, so skip its attention entirely.
            if isinstance(ip_scale, list):
                disabled = all(s == 0 for s in ip_scale)
            else:
                disabled = ip_scale == 0
            if disabled:
                continue

            if mask is not None:
                if not isinstance(ip_scale, list):
                    ip_scale = [ip_scale] * mask.shape[1]

                # One attention pass per image of this adapter, each weighted by its own mask and scale.
                for i in range(mask.shape[1]):
                    image_states = current_ip_hidden_states[:, i, :, :]
                    attended = self._ip_attention(attn, query, to_k_ip(image_states), to_v_ip(image_states))

                    mask_downsample = IPAdapterMaskProcessor.downsample(
                        mask[:, i, :, :],
                        batch_size,
                        attended.shape[1],
                        attended.shape[2],
                    )
                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

                    hidden_states = hidden_states + ip_scale[i] * (attended * mask_downsample)
            else:
                attended = self._ip_attention(
                    attn, query, to_k_ip(current_ip_hidden_states), to_v_ip(current_ip_hidden_states)
                )
                hidden_states = hidden_states + ip_scale * attended

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class IPAdapterAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
            The context length of the image features.
        scale (`float` or `List[float]`, defaults to 1.0):
            the weight scale of image prompt.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        self.num_tokens = num_tokens

        # A scalar scale is broadcast to one entry per adapter.
        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale

        # One key projection and one value projection per adapter.
        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )

    @staticmethod
    def _ip_attention(attn, query, ip_key, ip_value, batch_size, head_dim):
        """Attend from the per-head `query` to already-projected image-prompt keys/values (no mask)."""
        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        return attended.to(query.dtype)

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        scale: float = 1.0,
        ip_adapter_masks: Optional[torch.Tensor] = None,
    ):
        """
        Run the regular attention of `attn` with `F.scaled_dot_product_attention` and add the scaled
        image-prompt attention of every adapter.

        Args:
            attn: Attention module whose projections and helper methods are used.
            hidden_states: Query source; a 4-D input is flattened to tokens and restored before returning.
            encoder_hidden_states: Either a tuple `(text_states, ip_hidden_states)` or, deprecated, a single
                tensor whose trailing `self.num_tokens[0]` tokens are split off as the only image prompt.
                With `None`, keys/values come from `hidden_states` and no image prompt is applied.
            attention_mask: Mask for the regular attention only; image-prompt attention is unmasked.
            temb: Forwarded to `attn.spatial_norm` when that module is set.
            scale: Not read by this processor; per-adapter scales come from `self.scale`.
            ip_adapter_masks: Optional per-adapter masks (a list of 4-D tensors, or one tensor that is split
                along its first axis).

        Raises:
            ValueError: If the masks do not line up with `self.scale` / the image prompts.
        """
        residual = hidden_states

        # Without `encoder_hidden_states` there are no image-prompt embeddings. Defaulting to an empty list
        # keeps the self-attention path (handled below) working instead of failing on an unbound local.
        ip_hidden_states = []

        # separate ip_hidden_states from encoder_hidden_states
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, tuple):
                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
            else:
                deprecation_message = (
                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
                )
                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    [encoder_hidden_states[:, end_pos:, :]],
                )

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if ip_adapter_masks is not None:
            if not isinstance(ip_adapter_masks, List):
                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                raise ValueError(
                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
                    f"({len(ip_hidden_states)})"
                )
            for index, (mask, mask_scale, ip_state) in enumerate(
                zip(ip_adapter_masks, self.scale, ip_hidden_states)
            ):
                if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                    raise ValueError(
                        "Each element of the ip_adapter_masks array should be a tensor with shape "
                        "[1, num_images_for_ip_adapter, height, width]."
                        " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                    )
                if mask.shape[1] != ip_state.shape[1]:
                    raise ValueError(
                        f"Number of masks ({mask.shape[1]}) does not match "
                        f"number of ip images ({ip_state.shape[1]}) at index {index}"
                    )
                if isinstance(mask_scale, list) and not len(mask_scale) == mask.shape[1]:
                    raise ValueError(
                        f"Number of masks ({mask.shape[1]}) does not match "
                        f"number of scales ({len(mask_scale)}) at index {index}"
                    )
        else:
            ip_adapter_masks = [None] * len(self.scale)

        # for ip-adapter
        for current_ip_hidden_states, ip_scale, to_k_ip, to_v_ip, mask in zip(
            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
        ):
            # An adapter whose scale(s) are all zero contributes nothing, so skip its attention entirely.
            if isinstance(ip_scale, list):
                disabled = all(s == 0 for s in ip_scale)
            else:
                disabled = ip_scale == 0
            if disabled:
                continue

            if mask is not None:
                if not isinstance(ip_scale, list):
                    ip_scale = [ip_scale] * mask.shape[1]

                # One attention pass per image of this adapter, each weighted by its own mask and scale.
                for i in range(mask.shape[1]):
                    image_states = current_ip_hidden_states[:, i, :, :]
                    attended = self._ip_attention(
                        attn, query, to_k_ip(image_states), to_v_ip(image_states), batch_size, head_dim
                    )

                    mask_downsample = IPAdapterMaskProcessor.downsample(
                        mask[:, i, :, :],
                        batch_size,
                        attended.shape[1],
                        attended.shape[2],
                    )
                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

                    hidden_states = hidden_states + ip_scale[i] * (attended * mask_downsample)
            else:
                attended = self._ip_attention(
                    attn,
                    query,
                    to_k_ip(current_ip_hidden_states),
                    to_v_ip(current_ip_hidden_states),
                    batch_size,
                    head_dim,
                )
                hidden_states = hidden_states + ip_scale * attended

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class PAGIdentitySelfAttnProcessor2_0:
    r"""
    Self-attention processor for PAG (https://arxiv.org/abs/2403.17377) built on
    `F.scaled_dot_product_attention`.

    The incoming batch is split into two halves: the first goes through regular self-attention, the second
    bypasses the attention product and only passes through the value and output projections.
    """

    def __init__(self):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_image_like = hidden_states.ndim == 4
        if is_image_like:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        def project_out(states, batch):
            # Linear projection, then dropout, then (for 4-D inputs) back to the image layout.
            states = attn.to_out[1](attn.to_out[0](states))
            if is_image_like:
                states = states.transpose(-1, -2).reshape(batch, channel, height, width)
            return states

        original_half, perturbed_half = hidden_states.chunk(2)

        # ---- original path: regular self-attention ----
        batch_size, sequence_length, _ = original_half.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects the mask as (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            original_half = attn.group_norm(original_half.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(original_half)
        key = attn.to_k(original_half)
        value = attn.to_v(original_half)

        head_dim = key.shape[-1] // attn.heads
        per_head = (batch_size, -1, attn.heads, head_dim)
        query = query.view(*per_head).transpose(1, 2)
        key = key.view(*per_head).transpose(1, 2)
        value = value.view(*per_head).transpose(1, 2)

        # sdp output is (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        original_half = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        original_half = original_half.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        original_half = project_out(original_half.to(query.dtype), batch_size)

        # ---- perturbed path: identity attention (value projection only) ----
        batch_size, _, _ = perturbed_half.shape

        if attn.group_norm is not None:
            perturbed_half = attn.group_norm(perturbed_half.transpose(1, 2)).transpose(1, 2)

        perturbed_half = attn.to_v(perturbed_half).to(query.dtype)
        perturbed_half = project_out(perturbed_half, batch_size)

        # Stitch both halves back together in their original order.
        hidden_states = torch.cat([original_half, perturbed_half])

        if attn.residual_connection:
            hidden_states = hidden_states + skip_input

        return hidden_states / attn.rescale_output_factor
class PAGCFGIdentitySelfAttnProcessor2_0:
    r"""
    Self-attention processor for PAG (https://arxiv.org/abs/2403.17377) built on
    `F.scaled_dot_product_attention`, for batches that are split into three chunks.

    The first two chunks are re-joined and go through regular self-attention; the third bypasses the attention
    product and only passes through the value and output projections.
    """

    def __init__(self):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_image_like = hidden_states.ndim == 4
        if is_image_like:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        def project_out(states, batch):
            # Linear projection, then dropout, then (for 4-D inputs) back to the image layout.
            states = attn.to_out[1](attn.to_out[0](states))
            if is_image_like:
                states = states.transpose(-1, -2).reshape(batch, channel, height, width)
            return states

        # Three chunks along the batch axis; the first two share the regular attention path.
        uncond_part, cond_part, perturbed_part = hidden_states.chunk(3)
        original_part = torch.cat([uncond_part, cond_part])

        # ---- original path: regular self-attention ----
        batch_size, sequence_length, _ = original_part.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects the mask as (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            original_part = attn.group_norm(original_part.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(original_part)
        key = attn.to_k(original_part)
        value = attn.to_v(original_part)

        head_dim = key.shape[-1] // attn.heads
        per_head = (batch_size, -1, attn.heads, head_dim)
        query = query.view(*per_head).transpose(1, 2)
        key = key.view(*per_head).transpose(1, 2)
        value = value.view(*per_head).transpose(1, 2)

        # sdp output is (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        original_part = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        original_part = original_part.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        original_part = project_out(original_part.to(query.dtype), batch_size)

        # ---- perturbed path: identity attention (value projection only) ----
        batch_size, _, _ = perturbed_part.shape

        if attn.group_norm is not None:
            perturbed_part = attn.group_norm(perturbed_part.transpose(1, 2)).transpose(1, 2)

        perturbed_part = attn.to_v(perturbed_part).to(query.dtype)
        perturbed_part = project_out(perturbed_part, batch_size)

        # Stitch the parts back together in their original order.
        hidden_states = torch.cat([original_part, perturbed_part])

        if attn.residual_connection:
            hidden_states = hidden_states + skip_input

        return hidden_states / attn.rescale_output_factor
class LoRAAttnProcessor:
    """Empty placeholder class: constructing it stores no state and performs no work."""

    # NOTE(review): looks like a name kept only so existing imports keep resolving — confirm.
    def __init__(self):
        return None
class LoRAAttnProcessor2_0:
    """Empty placeholder class: constructing it stores no state and performs no work."""

    # NOTE(review): looks like a name kept only so existing imports keep resolving — confirm.
    def __init__(self):
        return None
class LoRAXFormersAttnProcessor:
    """Empty placeholder class: constructing it stores no state and performs no work."""

    # NOTE(review): looks like a name kept only so existing imports keep resolving — confirm.
    def __init__(self):
        return None
class LoRAAttnAddedKVProcessor:
    """Empty placeholder class: constructing it stores no state and performs no work."""

    # NOTE(review): looks like a name kept only so existing imports keep resolving — confirm.
    def __init__(self):
        return None
class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
    r"""
    Deprecated subclass of `FluxAttnProcessor2_0`: construction reports the deprecation through `deprecate`
    and then defers to the parent initializer.
    """

    def __init__(self):
        deprecate(
            "FluxSingleAttnProcessor2_0",
            "0.32.0",
            "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead.",
        )
        super().__init__()
# Tuple of the attention-processor classes whose names carry the `AddedKV` marker.
# NOTE(review): presumably consumed by isinstance/membership checks elsewhere — confirm against callers.
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
XFormersAttnAddedKVProcessor,
)
# Tuple of the plain, sliced, xFormers and IP-Adapter attention-processor classes.
# NOTE(review): presumably consumed by isinstance/membership checks elsewhere — confirm against callers.
CROSS_ATTENTION_PROCESSORS = (
AttnProcessor,
AttnProcessor2_0,
XFormersAttnProcessor,
SlicedAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
# Type alias: a `Union` over the attention-processor classes listed below.
# NOTE(review): not every processor class defined in this file is listed (e.g. the IP-Adapter ones);
# presumably this alias is only used in annotations — confirm before relying on it for runtime checks.
AttentionProcessor = Union[
AttnProcessor,
AttnProcessor2_0,
FusedAttnProcessor2_0,
XFormersAttnProcessor,
SlicedAttnProcessor,
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
XFormersAttnAddedKVProcessor,
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
PAGHunyuanAttnProcessor2_0,
]
================================================
FILE: CogVideo/finetune/models/cogvideox_transformer_3d.py
================================================
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
from einops import rearrange, repeat
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from models.attention import Attention, FeedForward
from models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
# Module-level logger, obtained from the `logging` helper imported from `diffusers.utils` and named after this module.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model, optionally extended with an
    extra attention branch (`attn_injector`) that mixes per-entity pose embeddings and entity prompt embeddings
    into the visual tokens.

    Parameters:
        block_idx (`int`):
            Index of this block inside the transformer. The injector branch is only created when
            `block_idx % block_interval == 0`.
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        block_interval (`int`, defaults to `2`):
            Spacing between blocks that own an injector branch.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool` defaults to `True`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
        finetune_init (`bool`, defaults to `False`):
            When `True`, no injector branch (`attn_injector`, `pose_fuse_layer`, `attn_null_feature`) is created
            on any block.
    """

    def __init__(
        self,
        block_idx: int,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        block_interval: int = 2,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        finetune_init: bool = False,
    ):
        super().__init__()

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # Injector branch: present only on every `block_interval`-th block, and never when `finetune_init`
        # is set (the original comment describes that flag as "for finetuning stage w/o loading pretrained
        # checkpoints").
        if not finetune_init and (block_idx % block_interval == 0):
            self.attn_injector = Attention(
                query_dim=dim,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                qk_norm="layer_norm" if qk_norm else None,
                eps=1e-6,
                bias=attention_bias,
                out_bias=attention_out_bias,
                processor=CogVideoXAttnProcessor2_0(),
            )
            # Projects the 12-channel pose embedding to the model width.
            # NOTE(review): 12 presumably corresponds to a flattened 3x4 pose matrix -- confirm with the caller.
            self.pose_fuse_layer = nn.Linear(12, dim)
            # Learned filler used for the entity slots that are not occupied by a real entity.
            self.attn_null_feature = nn.Parameter(torch.zeros([dim]))

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        empty_encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        pose_embeds: Optional[torch.FloatTensor] = None,
        prompt_entities_embeds: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Run self-attention, the optional injector attention, and the feed-forward layer.

        Args:
            hidden_states (`torch.Tensor`):
                Visual tokens, unpacked below as `(batch, num_visual_tokens, dim)`.
            encoder_hidden_states (`torch.Tensor`):
                Text tokens; dimension 1 is used as the text sequence length.
            empty_encoder_hidden_states (`torch.Tensor`):
                Text tokens fed to the injector attention in place of `encoder_hidden_states`.
            temb (`torch.Tensor`):
                Timestep embedding consumed by the modulated layer norms.
            pose_embeds (`torch.FloatTensor`, *optional*):
                4-D tensor unpacked as `(batch, entity_num, num_frames, 12)`. The injector branch only runs
                when this is given and the block owns an `attn_injector`.
            prompt_entities_embeds (`torch.FloatTensor`, *optional*):
                Entity prompt embeddings; required whenever the injector branch runs.
                Assumed to be `(batch, entity_num, 50, dim)` -- TODO confirm against the patch embedding.
            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Rotary embedding forwarded unchanged to both attention layers.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: the updated `(hidden_states, encoder_hidden_states)`.

        Raises:
            ValueError: If the injector branch runs without `prompt_entities_embeds`, or with more entities than
                the fixed number of entity slots.
        """
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        if (pose_embeds is not None) and hasattr(self, "attn_injector"):
            if prompt_entities_embeds is None:
                raise ValueError("`prompt_entities_embeds` must be provided together with `pose_embeds`.")

            # Fixed layout of the entity tokens appended to every frame.
            max_entity_num = 3
            entity_token_num = 50

            # 1. norm & modulate (re-uses `norm1`, this time paired with the empty-prompt text tokens)
            norm_hidden_states, norm_empty_encoder_hidden_states, injector_gate_msa, _ = self.norm1(
                hidden_states, empty_encoder_hidden_states, temb
            )
            bz, N_visual, dim = norm_hidden_states.shape
            _, entity_num, num_frames, _ = pose_embeds.shape
            if entity_num > max_entity_num:
                raise ValueError(
                    f"`pose_embeds` describes {entity_num} entities but at most {max_entity_num} are supported."
                )

            # 2. pair-wise fusion of trajectory and entity
            # Every slot starts as the learned null feature; real entities overwrite the leading slots.
            attn_input = self.attn_null_feature.repeat(bz, max_entity_num, entity_token_num, num_frames, 1)
            pose_embeds = self.pose_fuse_layer(pose_embeds)
            attn_input[:, :entity_num, :, :, :] = pose_embeds.unsqueeze(-3) + prompt_entities_embeds.unsqueeze(-2)
            # Append the entity tokens to the visual tokens of the same frame, then flatten frames back
            # into a single token axis.
            attn_input = torch.cat(
                (
                    rearrange(norm_hidden_states, "b (n t) d -> b n t d", n=num_frames),
                    rearrange(attn_input, "b n t f d -> b f (n t) d"),
                ),
                dim=2,
            ).flatten(1, 2)

            # 3. gated self-attention
            attn_hidden_states, _ = self.attn_injector(
                hidden_states=attn_input,
                encoder_hidden_states=norm_empty_encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
            )
            # Only the first `N_visual` output tokens are added back to the residual stream.
            attn_hidden_states = attn_hidden_states[:, :N_visual, :]
            hidden_states = hidden_states + injector_gate_msa * attn_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward (text tokens first, visual tokens after, then split again by `text_seq_length`)
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """
    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).

    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to `0`):
            Frequency shift forwarded to the `Timesteps` projection.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
        use_rotary_positional_embeddings (`bool`, defaults to `False`):
            When `True`, the patch embedding is built with `use_positional_embeddings=False` and the final layer
            norm is applied to the concatenated text + video sequence.
        use_learned_positional_embeddings (`bool`, defaults to `False`):
            Forwarded to the patch embedding. Enabling it without rotary embeddings raises a `ValueError`.
        finetune_init (`bool`, defaults to `False`):
            Forwarded to every `CogVideoXBlock`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        finetune_init: bool = False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
                "issue at https://github.com/huggingface/diffusers/issues."
            )

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 3. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    block_idx=idx,
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    finetune_init=finetune_init,
                )
                for idx in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        """Store the gradient-checkpointing flag on the model; `module` is not used."""
        self.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # NOTE: entries are popped, so a dict passed by the caller is consumed.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        This API is 🧪 experimental.

        Raises:
            ValueError: If any attention processor class name contains "Added".
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # Remember the current processors so `unfuse_qkv_projections` can restore them.
        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        This API is 🧪 experimental.
        """
        # `original_attn_processors` only exists after `fuse_qkv_projections` has been called; treat a missing
        # attribute as "nothing to restore" instead of raising AttributeError.
        original_attn_processors = getattr(self, "original_attn_processors", None)
        if original_attn_processors is not None:
            self.set_attn_processor(original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        empty_encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        pose_embeds: Optional[torch.FloatTensor] = None,
        prompt_entities_embeds: Optional[torch.FloatTensor] = None,
        prompt_entities_attention_mask: Optional[torch.FloatTensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        r"""
        Run the transformer on a batch of video latents.

        Args:
            hidden_states (`torch.Tensor`):
                Video latents, unpacked as `(batch_size, num_frames, channels, height, width)`.
            encoder_hidden_states (`torch.Tensor`):
                Prompt text embeddings; dimension 1 is used as the text sequence length.
            empty_encoder_hidden_states (`torch.Tensor`):
                Text embeddings passed through the patch embedding and on to every block.
                NOTE(review): presumably the embedding of an empty prompt -- confirm with the pipeline.
            timestep (`int`, `float` or `torch.LongTensor`):
                Diffusion timestep(s), passed directly to the timestep projection.
            pose_embeds (`torch.FloatTensor`, *optional*):
                Pose embeddings forwarded unchanged to every block.
            prompt_entities_embeds (`torch.FloatTensor`, *optional*):
                Entity prompt embeddings. They are always handed to the patch embedding, whose output replaces
                this value before it is forwarded to the blocks.
            prompt_entities_attention_mask (`torch.FloatTensor`, *optional*):
                Mask handed to the patch embedding together with `prompt_entities_embeds`.
            timestep_cond (`torch.Tensor`, *optional*):
                Extra conditioning forwarded to the timestep embedding.
            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Rotary embedding forwarded unchanged to every block.
            attention_kwargs (`dict`, *optional*):
                Only the `"scale"` entry is read; it is used as the LoRA scale when the PEFT backend is active.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` if `return_dict` is `True`, otherwise a 1-tuple holding the output tensor.
        """
        # Record whether the caller supplied a scale *before* it is popped below; otherwise the warning for
        # the non-PEFT backend can never fire.
        scale_was_passed = attention_kwargs is not None and attention_kwargs.get("scale", None) is not None

        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_passed:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

        batch_size, num_frames, _, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding
        hidden_states, prompt_entities_embeds, empty_encoder_hidden_states = self.patch_embed(
            empty_encoder_hidden_states,
            encoder_hidden_states,
            hidden_states,
            prompt_entities_embeds,
            prompt_entities_attention_mask,
        )
        hidden_states = self.embedding_dropout(hidden_states)

        # The patch embedding returns text and video tokens as one sequence (text first); split them again.
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 3. Transformer blocks
        use_checkpointing = self.training and self.gradient_checkpointing
        if use_checkpointing:
            # Both the wrapper factory and the checkpoint kwargs are loop-invariant, so build them once.
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for block in self.transformer_blocks:
            if use_checkpointing:
                # Arguments are positional and must follow the order of `CogVideoXBlock.forward`.
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    empty_encoder_hidden_states,
                    emb,
                    pose_embeds,
                    prompt_entities_embeds,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    empty_encoder_hidden_states=empty_encoder_hidden_states,
                    temb=emb,
                    pose_embeds=pose_embeds,
                    prompt_entities_embeds=prompt_entities_embeds,
                    image_rotary_emb=image_rotary_emb,
                )

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        # Note: we use `-1` instead of `channels`:
        #   - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
        #   - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
        p = self.config.patch_size
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
================================================
FILE: CogVideo/finetune/models/embeddings.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.utils import deprecate
from diffusers.models.activations import FP32SiLU, get_activation
from diffusers.models.attention_processor import Attention
from einops import rearrange, repeat
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Build sinusoidal timestep embeddings (the scheme used in Denoising Diffusion Probabilistic Models).

    Args
        timesteps (torch.Tensor):
            1-D tensor with one (possibly fractional) timestep per batch element.
        embedding_dim (int):
            Width of the returned embedding. An odd width gets one trailing zero column.
        flip_sin_to_cos (bool):
            `True` lays the halves out as `cos, sin`; `False` as `sin, cos`.
        downscale_freq_shift (float):
            Subtracted from the number of frequencies in the exponent's denominator, which changes the spacing
            between consecutive frequencies.
        scale (float):
            Multiplier applied to the angles before taking sine and cosine.
        max_period (int):
            Base of the geometric frequency progression; the exponents are scaled by `-log(max_period)`.

    Returns
        torch.Tensor: an [N x embedding_dim] tensor of embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    num_freqs = embedding_dim // 2
    freq_index = torch.arange(start=0, end=num_freqs, dtype=torch.float32, device=timesteps.device)
    log_freqs = (-math.log(max_period) * freq_index) / (num_freqs - downscale_freq_shift)
    frequencies = torch.exp(log_freqs)

    # Outer product of timesteps and frequencies, then the global scale.
    angles = scale * (timesteps[:, None].float() * frequencies[None, :])

    sin_half = torch.sin(angles)
    cos_half = torch.cos(angles)
    ordered_halves = [cos_half, sin_half] if flip_sin_to_cos else [sin_half, cos_half]
    embedding = torch.cat(ordered_halves, dim=-1)

    # An odd target width leaves one column unfilled: pad it with zeros on the right.
    if embedding_dim % 2 == 1:
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))
    return embedding
def get_3d_sincos_pos_embed(
    embed_dim: int,
    spatial_size: Union[int, Tuple[int, int]],
    temporal_size: int,
    spatial_interpolation_scale: float = 1.0,
    temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
    r"""
    Build a fixed sin/cos positional table for a (time, height x width) token grid.

    One quarter of the channels encodes the temporal position and three quarters encode the spatial position;
    the temporal part comes first along the channel axis.

    Args:
        embed_dim (`int`):
            Total channel count; must be divisible by 4.
        spatial_size (`int` or `Tuple[int, int]`):
            Spatial grid size, indexed as `(width, height)`. A single int is used for both.
        temporal_size (`int`):
            Number of temporal positions.
        spatial_interpolation_scale (`float`, defaults to 1.0):
            Divisor applied to the spatial coordinates.
        temporal_interpolation_scale (`float`, defaults to 1.0):
            Divisor applied to the temporal coordinates.

    Returns:
        `np.ndarray` of shape `[temporal_size, height * width, embed_dim]`.

    Raises:
        ValueError: If `embed_dim` is not divisible by 4.
    """
    if embed_dim % 4 != 0:
        raise ValueError("`embed_dim` must be divisible by 4")
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)

    grid_width = spatial_size[0]
    grid_height = spatial_size[1]
    temporal_channels = embed_dim // 4
    spatial_channels = 3 * embed_dim // 4

    # 1. Spatial table, shared by every frame
    w_coords = np.arange(grid_width, dtype=np.float32) / spatial_interpolation_scale
    h_coords = np.arange(grid_height, dtype=np.float32) / spatial_interpolation_scale
    mesh = np.stack(np.meshgrid(w_coords, h_coords), axis=0)  # w goes first
    mesh = mesh.reshape([2, 1, grid_height, grid_width])
    spatial_table = get_2d_sincos_pos_embed_from_grid(spatial_channels, mesh)

    # 2. Temporal table, shared by every spatial location
    t_coords = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
    temporal_table = get_1d_sincos_pos_embed_from_grid(temporal_channels, t_coords)

    # 3. Tile both tables to [T, H*W, *] and join them on the channel axis
    spatial_tiled = np.repeat(spatial_table[np.newaxis, :, :], temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
    temporal_tiled = np.repeat(temporal_table[:, np.newaxis, :], grid_width * grid_height, axis=1)  # [T, H*W, D // 4]
    return np.concatenate([temporal_tiled, spatial_tiled], axis=-1)  # [T, H*W, D]
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Build a fixed 2-D sin/cos positional table.

    grid_size: an int (used for both axes) or a pair of ints. Coordinates along each axis are rescaled by
    `grid_size / base_size` and then by `interpolation_scale`.

    return: pos_embed of shape [grid_size[0]*grid_size[1], embed_dim], or with `extra_tokens` all-zero rows
    prepended when `cls_token` is truthy and `extra_tokens > 0`.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    size_0 = grid_size[0]
    size_1 = grid_size[1]
    h_coords = np.arange(size_0, dtype=np.float32) / (size_0 / base_size) / interpolation_scale
    w_coords = np.arange(size_1, dtype=np.float32) / (size_1 / base_size) / interpolation_scale

    mesh = np.stack(np.meshgrid(w_coords, h_coords), axis=0)  # w goes first
    mesh = mesh.reshape([2, 1, size_1, size_0])
    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)

    if not (cls_token and extra_tokens > 0):
        return table
    # Reserve zero-filled rows in front of the grid positions for the extra tokens.
    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, table], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a two-channel coordinate grid: `grid[0]` and `grid[1]` are each turned into a 1-D sin/cos table of
    `embed_dim // 2` channels, and the two tables are joined along the channel axis (`grid[0]` first).

    Raises ValueError when `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    channels_per_axis = embed_dim // 2
    per_axis_tables = [
        get_1d_sincos_pos_embed_from_grid(channels_per_axis, grid[axis])  # (H*W, D/2)
        for axis in (0, 1)
    ]
    return np.concatenate(per_axis_tables, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode scalar positions with sines and cosines at geometrically spaced frequencies.

    embed_dim: output width per position (must be even).
    pos: array of positions; it is flattened to size (M,) before encoding.
    out: array of shape (M, embed_dim) holding all sines first, then all cosines.

    Raises ValueError when `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    num_freqs = embed_dim // 2
    exponents = np.arange(num_freqs, dtype=np.float64) / (embed_dim / 2.0)
    frequencies = 1.0 / 10000**exponents  # (D/2,)

    flat_positions = pos.reshape(-1)  # (M,)
    angles = flat_positions[:, None] * frequencies[None, :]  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
class PatchEmbed(nn.Module):
    """
    2D image-to-patch embedding with optional sin/cos positional embeddings and SD3-style cropping.

    The input is patchified with a strided convolution. When `pos_embed_type` is `"sincos"` a fixed positional
    table is registered as the `pos_embed` buffer (persistent only when `pos_embed_max_size` is set); when it is
    `None` no positional embedding is added.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
        pos_embed_type="sincos",
        pos_embed_max_size=None,  # For SD3 cropping
    ):
        super().__init__()

        self.flatten = flatten
        self.layer_norm = layer_norm
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) if layer_norm else None

        self.patch_size = patch_size
        self.height = height // patch_size
        self.width = width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale

        # Side length of the positional grid: the explicit maximum when given, otherwise the square root of
        # the default patch count.
        if pos_embed_max_size:
            grid_size = pos_embed_max_size
        else:
            patch_count = (height // patch_size) * (width // patch_size)
            grid_size = int(patch_count**0.5)

        if pos_embed_type is None:
            self.pos_embed = None
            return
        if pos_embed_type != "sincos":
            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

        table = get_2d_sincos_pos_embed(
            embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
        )
        self.register_buffer(
            "pos_embed",
            torch.from_numpy(table).float().unsqueeze(0),
            persistent=bool(pos_embed_max_size),
        )

    def cropped_pos_embed(self, height, width):
        """
        Return the centre crop of the positional table matching a `height` x `width` input (given in input
        units, i.e. before division by the patch size).

        Raises ValueError when `pos_embed_max_size` is unset or when either patched extent exceeds it.
        """
        max_size = self.pos_embed_max_size
        if max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        row_start = (max_size - height) // 2
        col_start = (max_size - width) // 2
        table = self.pos_embed.reshape(1, max_size, max_size, -1)
        window = table[:, row_start : row_start + height, col_start : col_start + width, :]
        return window.reshape(1, -1, window.shape[-1])

    def forward(self, latent):
        """Patchify `latent`, optionally flatten and normalise it, and add the positional embedding if any."""
        # With cropping enabled the raw extents are kept (the crop helper divides by the patch size itself);
        # otherwise the extents are converted to patch units here.
        if self.pos_embed_max_size is None:
            height = latent.shape[-2] // self.patch_size
            width = latent.shape[-1] // self.patch_size
        else:
            height, width = latent.shape[-2:]

        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)

        if self.pos_embed is None:
            return latent.to(latent.dtype)

        # Crop, reuse, or recompute the positional table depending on configuration and input size
        if self.pos_embed_max_size:
            pos_embed = self.cropped_pos_embed(height, width)
        elif self.height == height and self.width == width:
            pos_embed = self.pos_embed
        else:
            recomputed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(recomputed).float().unsqueeze(0).to(latent.device)

        return (latent + pos_embed).to(latent.dtype)
class LuminaPatchEmbed(nn.Module):
    """2D image-to-patch embedding for Lumina-T2X: non-overlapping patches projected by a linear layer."""

    def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
        super().__init__()
        self.patch_size = patch_size
        patch_features = patch_size * patch_size * in_channels
        self.proj = nn.Linear(in_features=patch_features, out_features=embed_dim, bias=bias)

    def forward(self, x, freqs_cis):
        """
        Patchify and embed the input tensor.

        Args:
            x (torch.Tensor): 4-D input unpacked as (batch, channel, height, width).
            freqs_cis (torch.Tensor): Frequency tensor; it is moved to the device of `x` and its two leading
                axes are cropped to the patch grid.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: the embedded patches of
            shape (batch, num_patches, embed_dim), an all-ones int32 mask of shape (batch, num_patches), the
            input `(height, width)` repeated once per batch element, and the cropped frequency tensor with
            its two grid axes flattened and a leading axis of size 1.
        """
        freqs_cis = freqs_cis.to(x[0].device)

        patch = self.patch_size
        batch_size, channel, height, width = x.size()
        rows = height // patch
        cols = width // patch

        # (B, C, rows, p, cols, p) -> (B, rows, cols, C, p, p) -> (B, rows, cols, C*p*p)
        patches = x.view(batch_size, channel, rows, patch, cols, patch).permute(0, 2, 4, 1, 3, 5)
        tokens = self.proj(patches.flatten(3)).flatten(1, 2)

        mask = torch.ones(tokens.shape[0], tokens.shape[1], dtype=torch.int32, device=tokens.device)
        image_sizes = [(height, width)] * batch_size
        grid_freqs = freqs_cis[:rows, :cols].flatten(0, 1).unsqueeze(0)

        return (tokens, mask, image_sizes, grid_freqs)
class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_positional_embeddings: bool = True,
use_learned_positional_embeddings: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.embed_dim = embed_dim
self.sample_height = sample_height
self.sample_width = sample_width
self.sample_frames = sample_frames
self.temporal_compression_ratio = temporal_compression_ratio
self.max_text_seq_length = max_text_seq_length
self.spatial_interpolation_scale = spatial_interpolation_scale
self.temporal_interpolation_scale = temporal_interpolation_scale
self.use_positional_embeddings = use_positional_embeddings
self.use_learned_positional_embeddings = use_learned_positional_embeddings
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
if use_positional_embeddings or use_learned_positional_embeddings:
persistent = use_learned_positional_embeddings
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
pos_embedding = get_3d_sincos_pos_embed(
self.embed_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
)
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
joint_pos_embedding = torch.zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
return joint_pos_embedding
def forward(self, empty_text_embeds: torch.Tensor, text_embeds: torch.Tensor, image_embeds: torch.Tensor, prompt_entities_embeds: torch.Tensor, prompt_entities_attention_mask: torch.Tensor):
    r"""
    Project text, entity-prompt and video inputs into the transformer embedding space.

    Args:
        empty_text_embeds (`torch.Tensor`):
            Embeddings passed through `self.text_proj` and returned unchanged otherwise.
        text_embeds (`torch.Tensor`):
            Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
        image_embeds (`torch.Tensor`):
            Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
        prompt_entities_embeds (`torch.Tensor`):
            Per-entity prompt embeddings. Indexed with four axes below, so presumably
            (batch_size, num_entities, seq_length, embedding_dim) -- TODO confirm with caller.
        prompt_entities_attention_mask (`torch.Tensor`):
            Mask used to index `prompt_entities_embeds`; positions whose mask value is below 1 are zeroed.

    Returns:
        Tuple `(embeds, prompt_entities_embeds, empty_text_embeds)` where `embeds` is the concatenation of the
        projected text and flattened video patches (plus positional embeddings when enabled), and
        `prompt_entities_embeds` is truncated to the first 50 entries along its third axis.
    """
    batch, num_frames, channels, height, width = image_embeds.shape

    # All text-like inputs share the same linear projection.
    empty_text_embeds = self.text_proj(empty_text_embeds)
    text_embeds = self.text_proj(text_embeds)
    prompt_entities_embeds = self.text_proj(prompt_entities_embeds)

    # Zero the projected embedding wherever the mask is below 1, then keep 50 tokens.
    masked_positions = prompt_entities_attention_mask < 1.
    prompt_entities_embeds[masked_positions] = 0.
    prompt_entities_embeds = prompt_entities_embeds[:, :, :50, :]

    # Patchify every frame with the 2D convolution, then flatten to a token sequence.
    patches = self.proj(image_embeds.reshape(-1, channels, height, width))
    patches = patches.view(batch, num_frames, *patches.shape[1:])
    patches = patches.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
    patches = patches.flatten(1, 2)  # [batch, num_frames x height x width, channels]

    # [batch, seq_length + num_frames x height x width, channels]
    embeds = torch.cat([text_embeds, patches], dim=1).contiguous()

    if not (self.use_positional_embeddings or self.use_learned_positional_embeddings):
        return embeds, prompt_entities_embeds, empty_text_embeds

    resolution_changed = self.sample_width != width or self.sample_height != height
    if self.use_learned_positional_embeddings and resolution_changed:
        raise ValueError(
            "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
            "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
        )

    pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
    if resolution_changed or self.sample_frames != pre_time_compression_frames:
        # The cached table does not match this input size; rebuild it on the fly.
        table = self._get_positional_embeddings(height, width, pre_time_compression_frames)
        table = table.to(embeds.device, dtype=embeds.dtype)
    else:
        table = self.pos_embedding

    embeds = embeds + table
    return embeds, prompt_entities_embeds, empty_text_embeds
class CogView3PlusPatchEmbed(nn.Module):
    """Patchify an image latent, project patches and text to a shared width, and add positions.

    Image patches receive a slice of a precomputed 2D sin-cos table; text tokens receive a
    zero positional embedding.
    """

    def __init__(
        self,
        in_channels: int = 16,
        hidden_size: int = 2560,
        patch_size: int = 2,
        text_hidden_size: int = 4096,
        pos_embed_max_size: int = 128,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.text_hidden_size = text_hidden_size
        self.pos_embed_max_size = pos_embed_max_size

        # One linear layer for flattened image patches, one for text embeddings.
        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
        self.text_proj = nn.Linear(text_hidden_size, hidden_size)

        # Precompute the full (max_size x max_size) table once; forward() slices it.
        table = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
        table = table.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
        self.register_buffer("pos_embed", torch.from_numpy(table).float(), persistent=False)

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the `[text tokens, image patch tokens]` sequence with positional embeddings added.

        Args:
            hidden_states: Image latent of shape `(batch, channel, height, width)`.
            encoder_hidden_states: Text embeddings; the sequence length is read from axis 1.

        Raises:
            ValueError: If height or width is not a multiple of `patch_size`.
        """
        batch_size, channel, height, width = hidden_states.shape
        p = self.patch_size
        if height % p != 0 or width % p != 0:
            raise ValueError("Height and width must be divisible by patch size")

        grid_h = height // p
        grid_w = width // p

        # (B, C, H, W) -> (B, gh, gw, C, p, p) -> (B, gh * gw, C * p * p)
        patches = hidden_states.view(batch_size, channel, grid_h, p, grid_w, p)
        patches = patches.permute(0, 2, 4, 1, 3, 5).contiguous()
        patches = patches.view(batch_size, grid_h * grid_w, channel * p * p)

        image_tokens = self.proj(patches)
        text_tokens = self.text_proj(encoder_hidden_states)
        tokens = torch.cat([text_tokens, image_tokens], dim=1)

        num_text_tokens = text_tokens.shape[1]
        image_pos = self.pos_embed[:grid_h, :grid_w].reshape(grid_h * grid_w, -1)
        text_pos = torch.zeros(
            (num_text_tokens, self.hidden_size), dtype=image_pos.dtype, device=image_pos.device
        )
        positions = torch.cat([text_pos, image_pos], dim=0)[None, ...]

        return (tokens + positions).to(tokens.dtype)
def get_3d_rotary_pos_embed(
    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`float`):
            Base used for the rotary frequency computation. It is forwarded to
            `get_1d_rotary_pos_embed` for all three axes.
        use_real (`bool`):
            Must be `True`; the complex form is not supported here.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: the `(cos, sin)` tables, each with shape
        `(temporal_size * grid_size[0] * grid_size[1], dim_t + dim_h + dim_w)`.

    Raises:
        ValueError: If `use_real` is not `True`.
    """
    if use_real is not True:
        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
    start, stop = crops_coords
    grid_size_h, grid_size_w = grid_size
    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)

    # Split the head dimension between the axes: 1/4 for time, 3/8 each for height and width.
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Previously `theta` was accepted but never forwarded, so a non-default value was
    # silently ignored. Passing it through leaves the default behaviour unchanged.
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)

    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor.
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w
        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin have shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin have shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin have shape: grid_size_w, dim_w
    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
    """
    RoPE for image tokens with 2D structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            Number of sample points along the first and second crop axis.
        use_real (`bool`):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.

    Returns:
        Whatever `get_2d_rotary_pos_embed_from_grid` returns for the sampled coordinate grid:
        a `(cos, sin)` tuple when `use_real` is True, otherwise a single complex tensor.
    """
    top_left, bottom_right = crops_coords
    coords_h = np.linspace(top_left[0], bottom_right[0], grid_size[0], endpoint=False, dtype=np.float32)
    coords_w = np.linspace(top_left[1], bottom_right[1], grid_size[1], endpoint=False, dtype=np.float32)

    # Width is passed first, so plane 0 of the stacked mesh carries the width coordinates.
    mesh = np.stack(np.meshgrid(coords_w, coords_h), axis=0)
    mesh = mesh.reshape([2, 1, *mesh.shape[1:]])

    return get_2d_rotary_pos_embed_from_grid(embed_dim, mesh, use_real=use_real)
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    """Build 2D rotary embeddings from a two-plane coordinate grid.

    Each plane of `grid` is flattened and encoded with half of `embed_dim`; the two
    encodings are concatenated along the feature axis.

    Args:
        embed_dim: Total embedding dimension; must be divisible by 4.
        grid: Indexable with `grid[0]` and `grid[1]`, each reshaped to 1D positions.
        use_real: When True return `(cos, sin)`, otherwise one complex tensor.
    """
    assert embed_dim % 4 == 0

    per_axis_dim = embed_dim // 2
    # Each call yields (N, D/2) real tables if use_real else an (N, D/4) complex table.
    first = get_1d_rotary_pos_embed(per_axis_dim, grid[0].reshape(-1), use_real=use_real)
    second = get_1d_rotary_pos_embed(per_axis_dim, grid[1].reshape(-1), use_real=use_real)

    if not use_real:
        return torch.cat([first, second], dim=1)  # (N, D/2)

    cos = torch.cat([first[0], second[0]], dim=1)  # (N, D)
    sin = torch.cat([first[1], second[1]], dim=1)  # (N, D)
    return cos, sin
def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
    """Build the complex 2D rotary table of shape `(len_h, len_w, embed_dim // 2)`.

    Args:
        embed_dim: Total embedding dimension; must be divisible by 4.
        len_h: Number of positions along the first axis.
        len_w: Number of positions along the second axis.
        linear_factor: Forwarded to `get_1d_rotary_pos_embed`.
        ntk_factor: Forwarded to `get_1d_rotary_pos_embed`.
    """
    assert embed_dim % 4 == 0

    axis_dim = embed_dim // 2
    quarter = embed_dim // 4
    scaling = dict(linear_factor=linear_factor, ntk_factor=ntk_factor)

    rows = get_1d_rotary_pos_embed(axis_dim, len_h, **scaling)  # (H, D/4)
    cols = get_1d_rotary_pos_embed(axis_dim, len_w, **scaling)  # (W, D/4)

    # Tile each axis over the other so both become (H, W, D/4, 1).
    rows = rows.view(len_h, 1, quarter, 1).repeat(1, len_w, 1, 1)
    cols = cols.view(1, len_w, quarter, 1).repeat(len_h, 1, 1, 1)

    return torch.cat([rows, cols], dim=-1).flatten(2)  # (H, W, D/2)
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  #  torch.float32, torch.float64 (flux)
):
    """
    Precompute 1D rotary position frequencies for the given positions.

    Args:
        dim (`int`): Dimension of the frequency tensor; must be even.
        pos (`np.ndarray` or `int`): Position indices [S], or a scalar count that is expanded
            with `torch.arange`. Any other value is used as-is and must expose `.device`.
        theta (`float`, *optional*, defaults to 10000.0):
            Base of the frequency computation.
        use_real (`bool`, *optional*):
            If True, return cos and sin tables separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Frequencies are divided by this factor.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            `theta` is multiplied by this factor.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, cos and sin are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            The dtype of the frequency tensor.

    Returns:
        `(cos, sin)` float tensors of shape [S, D] when `use_real`, otherwise one complex
        tensor of shape [S, D/2].
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    elif isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    base = theta * ntk_factor
    exponents = torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim
    inv_freq = 1.0 / (base**exponents) / linear_factor  # [D/2]
    angles = torch.outer(pos, inv_freq)  # type: ignore   # [S, D/2]

    if not use_real:
        # lumina
        return torch.polar(torch.ones_like(angles), angles)  # complex64     # [S, D/2]

    if repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        cos_table = angles.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        sin_table = angles.sin().repeat_interleave(2, dim=1).float()  # [S, D]
    else:
        # stable audio
        cos_table = torch.cat([angles.cos(), angles.cos()], dim=-1).float()  # [S, D]
        sin_table = torch.cat([angles.sin(), angles.sin()], dim=-1).float()  # [S, D]
    return cos_table, sin_table
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to a query or key tensor `x` using the precomputed table `freqs_cis`.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to.
        freqs_cis (`Tuple[torch.Tensor]` or `torch.Tensor`):
            With `use_real=True`, a `(cos, sin)` pair of [S, D] tables that are broadcast over two
            leading axes. With `use_real=False`, a complex tensor that is unsqueezed at axis 2 and
            multiplied with the complex view of `x`.
        use_real (`bool`): Selects the real (cos/sin) or complex code path.
        use_real_unbind_dim (`int`): `-1` pairs adjacent channels, `-2` pairs the two channel halves.

    Returns:
        A single tensor with the same dtype as `x` holding the rotated values.

    Raises:
        ValueError: If `use_real` is set and `use_real_unbind_dim` is neither -1 nor -2.
    """
    if not use_real:
        # used for lumina
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        rotated = torch.view_as_real(x_complex * freqs_cis).flatten(3)
        return rotated.type_as(x)

    cos, sin = freqs_cis  # [S, D]
    cos = cos[None, None].to(x.device)
    sin = sin[None, None].to(x.device)

    if use_real_unbind_dim == -1:
        # Used for flux, cogvideox, hunyuan-dit
        first, second = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        quarter_turn = torch.stack([-second, first], dim=-1).flatten(3)
    elif use_real_unbind_dim == -2:
        # Used for Stable Audio
        first, second = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
        quarter_turn = torch.cat([-second, first], dim=-1)
    else:
        raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

    # Do the arithmetic in float32, then cast back to the input dtype.
    return (x.float() * cos + quarter_turn.float() * sin).to(x.dtype)
class FluxPosEmbed(nn.Module):
    """Rotary cos/sin tables for multi-axis position ids.

    Each column of `ids` is encoded with `get_1d_rotary_pos_embed` using the matching entry of
    `axes_dim`, and the per-axis tables are concatenated along the feature axis.
    """

    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Return the `(freqs_cos, freqs_sin)` pair for position ids whose last axis indexes the axes."""
        positions = ids.float()
        # float32 is selected on the "mps" device, float64 everywhere else.
        freqs_dtype = torch.float32 if ids.device.type == "mps" else torch.float64

        per_axis = [
            get_1d_rotary_pos_embed(
                self.axes_dim[axis],
                positions[:, axis],
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            for axis in range(ids.shape[-1])
        ]

        freqs_cos = torch.cat([cos for cos, _ in per_axis], dim=-1).to(ids.device)
        freqs_sin = torch.cat([sin for _, sin in per_axis], dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
class TimestepEmbedding(nn.Module):
    """Two-layer MLP applied to a timestep projection, with an optional additive condition.

    Args:
        in_channels: Width of the incoming projection.
        time_embed_dim: Hidden width (and output width unless `out_dim` is given).
        act_fn: Name passed to `get_activation` for the activation between the two layers.
        out_dim: Optional output width overriding `time_embed_dim`.
        post_act_fn: Optional activation name applied after the second layer.
        cond_proj_dim: When set, a bias-free linear layer maps a condition of this width to
            `in_channels` so it can be added to the input.
        sample_proj_bias: Whether both linear layers use a bias.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        self.cond_proj = None if cond_proj_dim is None else nn.Linear(cond_proj_dim, in_channels, bias=False)
        self.act = get_activation(act_fn)

        output_width = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, output_width, sample_proj_bias)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """Embed `sample`, first adding the projected `condition` when one is supplied."""
        if condition is not None:
            sample = sample + self.cond_proj(condition)

        hidden = self.linear_1(sample)
        if self.act is not None:
            hidden = self.act(hidden)

        hidden = self.linear_2(hidden)
        if self.post_act is not None:
            hidden = self.post_act(hidden)
        return hidden
class Timesteps(nn.Module):
    """Parameter-free module that stores the settings for `get_timestep_embedding` and applies it."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        """Return `get_timestep_embedding(timesteps, ...)` using the stored configuration."""
        options = dict(
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return get_timestep_embedding(timesteps, self.num_channels, **options)
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if not set_W_to_weight:
            return

        # to delete later
        # Re-draws the frozen weights under the temporary name `W`, then leaves them
        # registered only as `weight`.
        del self.weight
        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.weight = self.W
        del self.W

    def forward(self, x):
        """Map a 1D batch of values to `[sin, cos]` (or `[cos, sin]`) features of width `2 * embedding_size`."""
        if self.log:
            x = torch.log(x)

        angles = x[:, None] * self.weight[None, :] * 2 * np.pi
        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)

        ordered = [cos_part, sin_part] if self.flip_sin_to_cos else [sin_part, cos_part]
        return torch.cat(ordered, dim=-1)
class SinusoidalPositionalEmbedding(nn.Module):
    """Add fixed sinusoidal position information to a sequence of embeddings.

    Input has shape (batch_size, seq_length, embed_dim); the first `seq_length` rows of a
    precomputed table are added to it.

    Args:
        embed_dim: (int): Dimension of the positional embedding.
        max_seq_length: Number of positions in the precomputed table.
    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        positions = torch.arange(max_seq_length).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))

        table = torch.zeros(1, max_seq_length, embed_dim)
        # Even channels carry the sine, odd channels the cosine of the same angle.
        table[0, :, 0::2] = torch.sin(positions * inv_freq)
        table[0, :, 1::2] = torch.cos(positions * inv_freq)
        self.register_buffer("pe", table)

    def forward(self, x):
        """Return `x` plus the positional rows matching its sequence length."""
        seq_length = x.shape[1]
        # The three-way unpack in the original also enforces a rank-3 input.
        _, _, _ = x.shape
        return x + self.pe[:, :seq_length]
class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings and adds learned height/width positional embeddings.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()
        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        """Embed `index` and add the row-major (height x width) positional table, cropped to its length."""
        token_emb = self.emb(index)

        device = index.device
        row_ids = torch.arange(self.height, device=device).view(1, self.height)
        col_ids = torch.arange(self.width, device=device).view(1, self.width)

        # 1 x H x D -> 1 x H x 1 x D
        row_emb = self.height_emb(row_ids).unsqueeze(2)
        # 1 x W x D -> 1 x 1 x W x D
        col_emb = self.width_emb(col_ids).unsqueeze(1)

        # 1 x H x W x D -> 1 x L x D
        position_table = (row_emb + col_emb).view(1, self.height * self.width, -1)

        return token_emb + position_table[:, : token_emb.shape[1], :]
class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    When `dropout_prob > 0` the embedding table gets one extra row (index `num_classes`) that
    stands for the dropped / unconditional label.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + int(use_cfg_embedding), hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        Args:
            labels: Integer label tensor; axis 0 is the batch axis.
            force_drop_ids: Optional per-sample flags (tensor, array, list or tuple). Entries equal
                to 1 force the corresponding label to be replaced; when `None`, labels are dropped
                at random with probability `dropout_prob`.

        Returns:
            A tensor like `labels` in which dropped entries are replaced by `num_classes`.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            # Convert first, then compare: `torch.tensor(force_drop_ids == 1)` collapses to a scalar
            # `False` for Python lists/tuples (silently disabling the forced drop) and copy-constructs
            # from tensors. `as_tensor` also places the mask on the labels' device.
            drop_ids = torch.as_tensor(force_drop_ids, device=labels.device) == 1
        null_labels = torch.full_like(labels, self.num_classes)
        return torch.where(drop_ids, null_labels, labels)

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        """Embed `labels`, applying label dropout in training mode or whenever `force_drop_ids` is given."""
        use_dropout = self.dropout_prob > 0
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings
class TextImageProjection(nn.Module):
    """Project an image embedding into several tokens and prepend them to the projected text tokens.

    Args:
        text_embed_dim: Width of the incoming text embeddings.
        image_embed_dim: Width of the incoming image embedding.
        cross_attention_dim: Width of every output token.
        num_image_text_embeds: Number of tokens produced from one image embedding.
    """

    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        """Return `[image tokens, text tokens]` concatenated along the sequence axis."""
        num_samples = text_embeds.shape[0]

        # image: one wide vector per sample, split into `num_image_text_embeds` tokens
        image_tokens = self.image_embeds(image_embeds).reshape(num_samples, self.num_image_text_embeds, -1)

        # text
        text_tokens = self.text_proj(text_embeds)

        return torch.cat([image_tokens, text_tokens], dim=1)
class ImageProjection(nn.Module):
    """Expand one image embedding into `num_image_text_embeds` layer-normalised tokens.

    Args:
        image_embed_dim: Width of the incoming image embedding.
        cross_attention_dim: Width of every output token.
        num_image_text_embeds: Number of tokens produced per sample.
    """

    def __init__(
        self,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 32,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        """Return tokens of shape `(batch, num_image_text_embeds, cross_attention_dim)`."""
        num_samples = image_embeds.shape[0]

        # image
        tokens = self.image_embeds(image_embeds)
        tokens = tokens.reshape(num_samples, self.num_image_text_embeds, -1)
        return self.norm(tokens)
class IPAdapterFullImageProjection(nn.Module):
    """Feed-forward projection of image embeddings followed by layer normalisation."""

    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
        super().__init__()
        # Imported lazily inside the constructor, as in the rest of this file.
        from .attention import FeedForward

        self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        """Return `norm(ff(image_embeds))`."""
        projected = self.ff(image_embeds)
        return self.norm(projected)
class IPAdapterFaceIDImageProjection(nn.Module):
    """Feed-forward projection that turns each image embedding into `num_tokens` normalised tokens."""

    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
        super().__init__()
        # Imported lazily inside the constructor, as in the rest of this file.
        from .attention import FeedForward

        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        """Return tokens of shape `(-1, num_tokens, cross_attention_dim)` after layer normalisation."""
        tokens = self.ff(image_embeds).reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
class CombinedTimestepLabelEmbeddings(nn.Module):
    """Sum of a timestep embedding and a class-label embedding.

    Args:
        num_classes: Number of classes for the `LabelEmbedding`.
        embedding_dim: Width of both embeddings.
        class_dropout_prob: Label dropout probability passed to `LabelEmbedding`.
    """

    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        """Return `timestep_embedder(time_proj(timestep)) + class_embedder(class_labels)`."""
        time_features = self.time_proj(timestep).to(dtype=hidden_dtype)
        time_emb = self.timestep_embedder(time_features)  # (N, D)

        label_emb = self.class_embedder(class_labels)  # (N, D)

        return time_emb + label_emb  # (N, D)
class CombinedTimestepTextProjEmbeddings(nn.Module):
    """Sum of a timestep embedding and a projected pooled text embedding.

    Args:
        embedding_dim: Output width of both branches.
        pooled_projection_dim: Width of the pooled text projection fed to the text embedder.
    """

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        """Return the conditioning vector; the timestep features are cast to `pooled_projection.dtype`."""
        time_features = self.time_proj(timestep).to(dtype=pooled_projection.dtype)
        time_emb = self.timestep_embedder(time_features)  # (N, D)

        text_emb = self.text_embedder(pooled_projection)

        return time_emb + text_emb
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    """Sum of timestep, guidance and pooled-text embeddings.

    The timestep and the guidance value share one `Timesteps` projection but use separate
    `TimestepEmbedding` MLPs.

    Args:
        embedding_dim: Output width of every branch.
        pooled_projection_dim: Width of the pooled text projection fed to the text embedder.
    """

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, guidance, pooled_projection):
        """Return `(timestep_emb + guidance_emb) + text_emb`, computed in `pooled_projection.dtype`."""
        target_dtype = pooled_projection.dtype

        time_features = self.time_proj(timestep)
        time_emb = self.timestep_embedder(time_features.to(dtype=target_dtype))  # (N, D)

        guidance_features = self.time_proj(guidance)
        guidance_emb = self.guidance_embedder(guidance_features.to(dtype=target_dtype))  # (N, D)

        combined = time_emb + guidance_emb
        text_emb = self.text_embedder(pooled_projection)
        return combined + text_emb
class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
    """Sum of a timestep embedding and an embedding of size / crop conditions.

    Args:
        embedding_dim: Output width of both branches.
        condition_dim: Channel count of the `Timesteps` projection applied to every condition scalar.
        pooled_projection_dim: Input width of the condition embedder; it must match the width of the
            concatenated condition projections.
        timesteps_dim: Channel count of the timestep projection.
    """

    def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
        super().__init__()

        self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
        self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(
        self,
        timestep: torch.Tensor,
        original_size: torch.Tensor,
        target_size: torch.Tensor,
        crop_coords: torch.Tensor,
        hidden_dtype: torch.dtype,
    ) -> torch.Tensor:
        """Return the `(B, embedding_dim)` conditioning vector.

        Each condition tensor is flattened, projected element-wise and folded back to one row per
        sample. The rows are concatenated in the order original size, crop coordinates, target size.
        """
        time_features = self.time_proj(timestep)

        def project(values: torch.Tensor) -> torch.Tensor:
            # One projection per scalar, regrouped so that axis 0 is the batch again.
            return self.condition_proj(values.flatten()).view(values.size(0), -1)

        original_features = project(original_size)
        crop_features = project(crop_coords)
        target_features = project(target_size)

        condition_features = torch.cat([original_features, crop_features, target_features], dim=1)

        time_emb = self.timestep_embedder(time_features.to(dtype=hidden_dtype))  # (B, embedding_dim)
        condition_emb = self.condition_embedder(condition_features.to(dtype=hidden_dtype))  # (B, embedding_dim)

        return time_emb + condition_emb
class HunyuanDiTAttentionPool(nn.Module):
    """Attention pooling: the mean token queries the full sequence and yields one vector per sample.

    Args:
        spacial_dim: Sequence length the learned positional embedding is sized for (plus one mean token).
        embed_dim: Width of the input tokens.
        num_heads: Number of attention heads.
        output_dim: Width of the pooled output; falls back to `embed_dim` when falsy.
    """

    # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Pool `x` of shape (N, L, C) into a tensor of shape (N, output width)."""
        tokens = x.permute(1, 0, 2)  # NLC -> LNC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (L+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)  # (L+1)NC

        # The functional API takes the q/k/v biases as one stacked vector.
        stacked_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])

        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=stacked_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return pooled.squeeze(0)
class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
    """Timestep embedding plus an embedding of pooled text and, optionally, image-size and style conditions.

    Args:
        embedding_dim: Output width of the conditioning vector.
        pooled_projection_dim: Output width of the attention pooler over the text states.
        seq_len: Sequence length passed to `HunyuanDiTAttentionPool`.
        cross_attention_dim: Width of the text states fed to the pooler.
        use_style_cond_and_image_meta_size: Whether the size and style conditions are used.
    """

    def __init__(
        self,
        embedding_dim,
        pooled_projection_dim=1024,
        seq_len=256,
        cross_attention_dim=2048,
        use_style_cond_and_image_meta_size=True,
    ):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)

        self.pooler = HunyuanDiTAttentionPool(
            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
        )

        # Here we use a default learned embedder layer for future extension.
        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
        extra_in_dim = pooled_projection_dim
        if use_style_cond_and_image_meta_size:
            self.style_embedder = nn.Embedding(1, embedding_dim)
            # Six size scalars with 256 channels each, plus style and pooled text.
            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim

        self.extra_embedder = PixArtAlphaTextProjection(
            in_features=extra_in_dim,
            hidden_size=embedding_dim * 4,
            out_features=embedding_dim,
            act_fn="silu_fp32",
        )

    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
        """Return `timestep_emb + extra_embedder(concatenated extra conditions)`."""
        time_features = self.time_proj(timestep)
        time_emb = self.timestep_embedder(time_features.to(dtype=hidden_dtype))

        # extra condition1: text
        extras = [self.pooler(encoder_hidden_states)]

        if self.use_style_cond_and_image_meta_size:
            # extra condition2: image meta size embedding
            size_features = self.size_proj(image_meta_size.view(-1))
            size_features = size_features.to(dtype=hidden_dtype)
            extras.append(size_features.view(-1, 6 * 256))  # (N, 1536)

            # extra condition3: style embedding
            extras.append(self.style_embedder(style))  # (N, embedding_dim)

        # Concatenate all extra vectors
        extra_cond = torch.cat(extras, dim=1)
        return time_emb + self.extra_embedder(extra_cond)  # [B, D]
class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
    """Sum of a timestep embedding and an embedding of the mask-averaged caption features.

    Args:
        hidden_size: Output width of both branches.
        cross_attention_dim: Width of the caption features.
        frequency_embedding_size: Channel count of the timestep projection.
    """

    def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
        super().__init__()
        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
        )

        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)

        caption_layers = [
            nn.LayerNorm(cross_attention_dim),
            nn.Linear(cross_attention_dim, hidden_size, bias=True),
        ]
        self.caption_embedder = nn.Sequential(*caption_layers)

    def forward(self, timestep, caption_feat, caption_mask):
        """Return the conditioning vector for `timestep` and the masked mean of `caption_feat`."""
        # timestep embedding: cast the projection to the dtype of the embedder's first layer
        embedder_dtype = self.timestep_embedder.linear_1.weight.dtype
        time_embed = self.timestep_embedder(self.time_proj(timestep).to(dtype=embedder_dtype))

        # caption condition embedding: mean over the sequence axis, weighted by the mask
        weights = caption_mask.float().unsqueeze(-1)
        pooled = (caption_feat * weights).sum(dim=1) / weights.sum(dim=1)
        pooled = pooled.to(caption_feat)
        caption_embed = self.caption_embedder(pooled)

        return time_embed + caption_embed
class TextTimeEmbedding(nn.Module):
    """Normalise, attention-pool and project encoder states into a time-embedding-sized vector.

    Args:
        encoder_dim: Width of the incoming hidden states.
        time_embed_dim: Width of the output.
        num_heads: Head count passed to `AttentionPooling`.
    """

    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        """Apply `norm1 -> pool -> proj -> norm2` in sequence."""
        for stage in (self.norm1, self.pool, self.proj, self.norm2):
            hidden_states = stage(hidden_states)
        return hidden_states
class TextImageTimeEmbedding(nn.Module):
    """Projects text and image embeddings to a common width and adds them.

    Args:
        text_embed_dim: Width of the text embeddings.
        image_embed_dim: Width of the image embeddings.
        time_embed_dim: Width both inputs are projected to.
    """

    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        """Return `image_proj(image_embeds) + text_norm(text_proj(text_embeds))`."""
        # Only the text branch is layer-normalised; the image branch is a plain projection.
        text_part = self.text_norm(self.text_proj(text_embeds))
        image_part = self.image_proj(image_embeds)
        return image_part + text_part
class ImageTimeEmbedding(nn.Module):
    """Linear projection of image embeddings followed by LayerNorm.

    Args:
        image_embed_dim: Width of the incoming image embeddings.
        time_embed_dim: Width of the returned embedding.
    """

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

    def forward(self, image_embeds: torch.Tensor):
        """Return the projected and normalised image embeddings."""
        return self.image_norm(self.image_proj(image_embeds))
class ImageHintTimeEmbedding(nn.Module):
    """Embeds image features and encodes a 3-channel hint image with a small conv stack.

    The hint encoder is eight 3x3 convolutions (padding 1) with SiLU between them; three of the
    convolutions use stride 2 and the last one outputs 4 channels.

    Args:
        image_embed_dim: Width of the incoming image embeddings.
        time_embed_dim: Width of the returned image embedding.
    """

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

        # (in_channels, out_channels, stride) for every conv that is followed by a SiLU.
        conv_specs = (
            (3, 16, 1),
            (16, 16, 1),
            (16, 32, 2),
            (32, 32, 1),
            (32, 96, 2),
            (96, 96, 1),
            (96, 256, 2),
        )
        hint_layers = []
        for in_channels, out_channels, stride in conv_specs:
            hint_layers.append(nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=stride))
            hint_layers.append(nn.SiLU())
        # Final conv has no activation after it.
        hint_layers.append(nn.Conv2d(256, 4, 3, padding=1))
        self.input_hint_block = nn.Sequential(*hint_layers)

    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
        """Return `(normalised image embedding, encoded hint)`."""
        embedded = self.image_norm(self.image_proj(image_embeds))
        encoded_hint = self.input_hint_block(hint)
        return embedded, encoded_hint
class AttentionPooling(nn.Module):
    """Multi-head attention pooling that reduces a sequence to one vector per sample.

    A class token is built as the sequence mean plus a learned positional embedding. It is
    prepended to the sequence and used as the only query; the attended result is returned.
    """

    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        self.dtype = dtype
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        """Pool `x` of shape (bs, length, width) into a tensor of shape (bs, width)."""
        bs, length, width = x.size()
        heads, head_dim = self.num_heads, self.dim_per_head

        def split_heads(t):
            # (bs, tokens, width) -> (bs, heads, tokens, head_dim)
            t = t.view(bs, -1, heads, head_dim).transpose(1, 2)
            # -> (bs*heads, tokens, head_dim) -> (bs*heads, head_dim, tokens)
            return t.reshape(bs * heads, -1, head_dim).transpose(1, 2)

        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        tokens = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        q = split_heads(self.q_proj(class_token))  # (bs*heads, head_dim, 1)
        k = split_heads(self.k_proj(tokens))  # (bs*heads, head_dim, length+1)
        v = split_heads(self.v_proj(tokens))  # (bs*heads, head_dim, length+1)

        # Scaling q and k separately is more stable with f16 than dividing the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        scores = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # (bs*heads, 1, length+1)
        probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)

        pooled = torch.einsum("bts,bcs->bct", probs, v)  # (bs*heads, head_dim, 1)
        pooled = pooled.reshape(bs, -1, 1).transpose(1, 2)  # (bs, 1, width)
        return pooled[:, 0, :]
def get_fourier_embeds_from_boundingbox(embed_dim, box):
    """
    Build sin/cos Fourier features for bounding boxes (GLIGEN pipeline).

    Args:
        embed_dim: int, number of frequencies; frequency `i` is `100 ** (i / embed_dim)`.
        box: a 3-D tensor [B x N x 4] holding the bounding boxes.
    Returns:
        [B x N x (embed_dim * 2 * 4)] tensor. For each frequency the layout is the four sines
        followed by the four cosines.
    """
    batch_size, num_boxes = box.shape[:2]

    freqs = 100 ** (torch.arange(embed_dim) / embed_dim)
    freqs = freqs[None, None, None].to(device=box.device, dtype=box.dtype)

    angles = freqs * box.unsqueeze(-1)  # (B, N, 4, embed_dim)
    sin_cos = torch.stack((angles.sin(), angles.cos()), dim=-1)  # (B, N, 4, embed_dim, 2)

    # Reorder to (B, N, embed_dim, 2, 4) before flattening the last three axes.
    flat = sin_cos.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
    return flat
class GLIGENTextBoundingboxProjection(nn.Module):
    """Projects grounding features (text and optionally image) together with box Fourier features.

    Args:
        positive_len: Width of the text/image feature vectors.
        out_dim: Output width; when a tuple is given, its first element is used for the layers.
        feature_type: `"text-only"` creates one MLP (`linears`); `"text-image"` creates two
            (`linears_text`, `linears_image`).
        fourier_freqs: Number of Fourier frequencies used for the box coordinates.
    """

    def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim

        self.fourier_embedder_dim = fourier_freqs
        self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy

        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

        mlp_in_features = self.positive_len + self.position_dim

        def build_mlp():
            # Three-layer MLP shared by every feature type.
            return nn.Sequential(
                nn.Linear(mlp_in_features, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )

        if feature_type == "text-only":
            self.linears = build_mlp()
            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
        elif feature_type == "text-image":
            self.linears_text = build_mlp()
            self.linears_image = build_mlp()
            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

    def forward(
        self,
        boxes,
        masks,
        positive_embeddings=None,
        phrases_masks=None,
        image_masks=None,
        phrases_embeddings=None,
        image_embeddings=None,
    ):
        """Return the projected grounding tokens.

        When `positive_embeddings` is given the text-only path (`self.linears`) is used. Otherwise
        the phrase and image embeddings go through their own MLPs and the two results are
        concatenated along dim 1.
        """

        def fill_padding(values, keep, null_feature):
            # Entries where `keep` is 0 are replaced by the learnable null embedding.
            return values * keep + (1 - keep) * null_feature.view(1, 1, -1)

        masks = masks.unsqueeze(-1)

        # Box position embedding (it may include padding as placeholder): B*N*4 -> B*N*C
        box_features = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)
        xyxy_embedding = fill_padding(box_features, masks, self.null_position_feature)

        # Text-only grounding.
        if positive_embeddings is not None:
            positive_embeddings = fill_padding(positive_embeddings, masks, self.null_positive_feature)
            return self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))

        # Text + image grounding.
        phrases_masks = phrases_masks.unsqueeze(-1)
        image_masks = image_masks.unsqueeze(-1)

        phrases_embeddings = fill_padding(phrases_embeddings, phrases_masks, self.null_text_feature)
        image_embeddings = fill_padding(image_embeddings, image_masks, self.null_image_feature)

        objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
        objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
        return torch.cat([objs_text, objs_image], dim=1)
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    Timestep embedding with optional resolution / aspect-ratio conditioning, for PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29

    Args:
        embedding_dim: Width of the timestep embedding.
        size_emb_dim: Width produced by the resolution and aspect-ratio embedders.
        use_additional_conditions: When True, resolution and aspect-ratio embeddings are
            concatenated along dim 1 and added to the timestep embedding.
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if not use_additional_conditions:
            return

        # The size-related sub-modules only exist when additional conditions are enabled.
        self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
        self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        """Return the conditioning tensor for `timestep` (plus size conditions when enabled)."""
        time_freq = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(time_freq.to(dtype=hidden_dtype))  # (N, D)

        if not self.use_additional_conditions:
            return timesteps_emb

        # Each condition is flattened, embedded, then reshaped to one row per sample.
        resolution_freq = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
        resolution_emb = self.resolution_embedder(resolution_freq).reshape(batch_size, -1)

        aspect_ratio_freq = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
        aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_freq).reshape(batch_size, -1)

        return timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
class PixArtAlphaTextProjection(nn.Module):
    """
    Two-layer MLP that projects caption embeddings.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py

    Args:
        in_features: Width of the incoming caption embeddings.
        hidden_size: Width of the hidden layer.
        out_features: Output width; falls back to `hidden_size` when `None`.
        act_fn: One of `"gelu_tanh"`, `"silu"` or `"silu_fp32"`.

    Raises:
        ValueError: If `act_fn` is not one of the supported names.
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        if out_features is None:
            out_features = hidden_size

        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)

        # Factories are called lazily so only the requested activation is instantiated.
        activation_factories = (
            ("gelu_tanh", lambda: nn.GELU(approximate="tanh")),
            ("silu", lambda: nn.SiLU()),
            ("silu_fp32", lambda: FP32SiLU()),
        )
        for known_name, make_activation in activation_factories:
            if act_fn == known_name:
                self.act_1 = make_activation()
                break
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")

        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    def forward(self, caption):
        """Apply linear -> activation -> linear to `caption`."""
        return self.linear_2(self.act_1(self.linear_1(caption)))
class IPAdapterPlusImageProjectionBlock(nn.Module):
    """One resampler block: latents attend to `[x, latents]`, followed by a residual feed-forward.

    Args:
        embed_dims: Feature width of both `x` and the latents.
        dim_head: Channels per attention head.
        heads: Number of attention heads.
        ffn_ratio: Expansion factor (`mult`) of the feed-forward layer.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        dim_head: int = 64,
        heads: int = 16,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        from .attention import FeedForward

        self.ln0 = nn.LayerNorm(embed_dims)
        self.ln1 = nn.LayerNorm(embed_dims)
        self.attn = Attention(query_dim=embed_dims, dim_head=dim_head, heads=heads, out_bias=False)

        ff_norm = nn.LayerNorm(embed_dims)
        feed_forward = FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False)
        self.ff = nn.Sequential(ff_norm, feed_forward)

    def forward(self, x, latents, residual):
        """Update `latents` from `x`; `residual` is added to the attention output."""
        normed_x = self.ln0(x)
        normed_latents = self.ln1(latents)

        # The normalised latents are both the queries and part of the key/value context.
        context = torch.cat([normed_x, normed_latents], dim=-2)
        attended = self.attn(normed_latents, context) + residual

        return self.ff(attended) + attended
class IPAdapterPlusImageProjection(nn.Module):
    """Resampler of IP-Adapter Plus.

    Args:
        embed_dims (int): The feature dimension of the input. Defaults to 768.
        output_dims (int): The number of output channels (the original notes say this equals
            `unet.config.cross_attention_dim`). Defaults to 1024.
        hidden_dims (int): The number of hidden channels. Defaults to 1280.
        depth (int): The number of blocks. Defaults to 4.
        dim_head (int): The number of head channels. Defaults to 64.
        heads (int): Parallel attention heads. Defaults to 16.
        num_queries (int): The number of learned query latents. Defaults to 8.
        ffn_ratio (float): The expansion ratio of the feed-forward hidden layer. Defaults to 4.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 1024,
        hidden_dims: int = 1280,
        depth: int = 4,
        dim_head: int = 64,
        heads: int = 16,
        num_queries: int = 8,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        # Learned query latents, shared across the batch and repeated in `forward`.
        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)

        self.proj_in = nn.Linear(embed_dims, hidden_dims)
        self.proj_out = nn.Linear(hidden_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        blocks = []
        for _ in range(depth):
            blocks.append(IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x (torch.Tensor): Input Tensor.
        Returns:
            torch.Tensor: Output Tensor.
        """
        queries = self.latents.repeat(x.size(0), 1, 1)
        features = self.proj_in(x)

        for layer in self.layers:
            # Each block receives its own input latents as the residual.
            queries = layer(features, queries, queries)

        return self.norm_out(self.proj_out(queries))
class IPAdapterFaceIDPlusImageProjection(nn.Module):
    """FacePerceiverResampler of IP-Adapter Plus.

    `clip_embeds` starts as `None` and is read by `forward`, so it has to be assigned on the
    module before calling it.

    Args:
        embed_dims (int): The feature dimension. Defaults to 768.
        output_dims (int): The number of output channels. Defaults to 768.
        hidden_dims (int): Width of the CLIP embeddings fed to `proj_in`. Defaults to 1280.
        id_embeddings_dim (int): Width of the incoming ID embeddings. Defaults to 512.
        depth (int): The number of blocks. Defaults to 4.
        dim_head (int): The number of head channels. Defaults to 64.
        heads (int): Parallel attention heads. Defaults to 16.
        num_tokens (int): Number of tokens the ID embedding is expanded to. Defaults to 4.
        num_queries (int): The number of queries. Defaults to 8. Not used inside this class.
        ffn_ratio (float): The expansion ratio of the feed-forward hidden layer in each block.
            Defaults to 4.
        ffproj_ratio (int): The expansion ratio of the feed-forward hidden layer used for the ID
            embeddings. Defaults to 2.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 768,
        hidden_dims: int = 1280,
        id_embeddings_dim: int = 512,
        depth: int = 4,
        dim_head: int = 64,
        heads: int = 16,
        num_tokens: int = 4,
        num_queries: int = 8,
        ffn_ratio: float = 4,
        ffproj_ratio: int = 2,
    ) -> None:
        super().__init__()
        from .attention import FeedForward

        self.num_tokens = num_tokens
        self.embed_dim = embed_dims

        # Runtime state set from outside the module.
        self.clip_embeds = None
        self.shortcut = False
        self.shortcut_scale = 1.0

        self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
        self.norm = nn.LayerNorm(embed_dims)

        self.proj_in = nn.Linear(hidden_dims, embed_dims)

        self.proj_out = nn.Linear(embed_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        blocks = []
        for _ in range(depth):
            blocks.append(IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio))
        self.layers = nn.ModuleList(blocks)

    def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            id_embeds (torch.Tensor): Input Tensor (ID embeds).
        Returns:
            torch.Tensor: Output Tensor.
        """
        # ID branch: cast to the CLIP dtype, expand to `num_tokens` tokens and normalise.
        id_tokens = self.proj(id_embeds.to(self.clip_embeds.dtype))
        id_tokens = self.norm(id_tokens.reshape(-1, self.num_tokens, self.embed_dim))

        # CLIP branch: project, then merge the two leading axes of the 4-D tensor.
        clip_features = self.proj_in(self.clip_embeds)
        context = clip_features.reshape(-1, clip_features.shape[2], clip_features.shape[3])

        latents = id_tokens
        for layer in self.layers:
            latents = layer(context, latents, latents)

        out = self.norm_out(self.proj_out(latents))
        if self.shortcut:
            out = id_tokens + self.shortcut_scale * out
        return out
class MultiIPAdapterImageProjection(nn.Module):
    """Applies one image-projection layer per IP-Adapter to the matching image embedding."""

    def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
        super().__init__()
        self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)

    def forward(self, image_embeds: List[torch.Tensor]):
        """Project every entry of `image_embeds` with its own layer.

        Accepted inputs:
          1. a single tensor (deprecated) with shape [batch_size, embed_dim] or
             [batch_size, sequence_length, embed_dim];
          2. a list of `n` tensors where `n` is the number of ip-adapters, each tensor having
             shape [batch_size, num_images, embed_dim] or
             [batch_size, num_images, sequence_length, embed_dim].

        Returns:
            A list with one projected tensor per layer, each with leading axes
            (batch_size, num_images).

        Raises:
            ValueError: If the number of embeddings differs from the number of layers.
        """
        if not isinstance(image_embeds, list):
            deprecation_message = (
                "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
                " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
            )
            deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
            # Insert a `num_images` axis of size 1 and wrap in a list.
            image_embeds = [image_embeds.unsqueeze(1)]

        if len(image_embeds) != len(self.image_projection_layers):
            raise ValueError(
                f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
            )

        projected_image_embeds = []
        for embed, layer in zip(image_embeds, self.image_projection_layers):
            batch_size, num_images = embed.shape[0], embed.shape[1]
            # Fold the image axis into the batch axis for the layer, then restore it.
            folded = embed.reshape((batch_size * num_images,) + embed.shape[2:])
            projected = layer(folded)
            projected_image_embeds.append(projected.reshape((batch_size, num_images) + projected.shape[1:]))

        return projected_image_embeds
================================================
FILE: CogVideo/finetune/models/pipeline_cogvideox.py
================================================
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import CogVideoXLoraLoaderMixin
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from models.pipeline_output import CogVideoXPipelineOutput
from models.cogvideox_transformer_3d import CogVideoXTransformer3DModel
from einops import rearrange
import diffusers
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Doctest-style usage example for the text-to-video pipeline.
# NOTE(review): presumably injected into a pipeline docstring via `replace_example_docstring`
# (imported above); the call site is outside this view — confirm.
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import CogVideoXPipeline
>>> from diffusers.utils import export_to_video
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
>>> prompt = (
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
... "atmosphere of this unique musical performance."
... )
>>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
>>> export_to_video(video, "output.mp4", fps=8)
```
"""
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Fit `src` inside the target grid keeping its aspect ratio and centre it.

    Args:
        src: `(height, width)` of the source grid.
        tgt_width: Width of the target grid.
        tgt_height: Height of the target grid.

    Returns:
        `((top, left), (bottom, right))` of the centred, resized region inside the target grid.
    """
    src_height, src_width = src

    # A source that is relatively taller than the target is limited by height, otherwise by width.
    if src_height / src_width > (tgt_height / tgt_width):
        resize_height = tgt_height
        resize_width = int(round(tgt_height / src_height * src_width))
    else:
        resize_width = tgt_width
        resize_height = int(round(tgt_width / src_width * src_height))

    crop_top = int(round((tgt_height - resize_height) / 2.0))
    crop_left = int(round((tgt_width - resize_width) / 2.0))

    top_left = (crop_top, crop_left)
    bottom_right = (crop_top + resize_height, crop_left + resize_width)
    return top_left, bottom_right
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure the scheduler through `set_timesteps` and return the resulting schedule. Custom `timesteps` or
    `sigmas` are supported when the scheduler accepts them. Extra kwargs are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was passed.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def set_timesteps_accepts(param_name):
        # Inspected lazily: only needed when a custom schedule is requested.
        return param_name in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    # Default path: let the scheduler pick the spacing for `num_inference_steps`.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        if not set_timesteps_accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if not set_timesteps_accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule the step count is whatever the scheduler ended up with.
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using CogVideoX.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. CogVideoX uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CogVideoXTransformer3DModel`]):
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
]
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    vae: AutoencoderKLCogVideoX,
    transformer: CogVideoXTransformer3DModel,
    scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
    """Register the pipeline components and derive the VAE scale factors.

    When no VAE is available the fallbacks are 8 (spatial), 4 (temporal) and 0.7 (latent
    scaling factor).
    """
    super().__init__()

    self.register_modules(
        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
    )

    has_vae = hasattr(self, "vae") and self.vae is not None
    if has_vae:
        vae_config = self.vae.config
        # Spatial factor is 2 ** (number of VAE blocks - 1).
        self.vae_scale_factor_spatial = 2 ** (len(vae_config.block_out_channels) - 1)
        self.vae_scale_factor_temporal = vae_config.temporal_compression_ratio
        self.vae_scaling_factor_image = vae_config.scaling_factor
    else:
        self.vae_scale_factor_spatial = 8
        self.vae_scale_factor_temporal = 4
        self.vae_scaling_factor_image = 0.7

    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Tokenize `prompt`, run the text encoder and return `(prompt_embeds, prompt_attention_mask)`.

    `prompt_embeds` is duplicated `num_videos_per_prompt` times along the batch axis. The
    attention mask is returned as produced by the tokenizer (cast to `dtype`) and is NOT
    duplicated, so its batch size stays `len(prompt)`.

    Args:
        prompt: A single prompt or a list of prompts.
        num_videos_per_prompt: Number of copies of each prompt embedding.
        max_sequence_length: Padding/truncation length given to the tokenizer.
        device: Target device; defaults to `self._execution_device`.
        dtype: Target dtype; defaults to `self.text_encoder.dtype`.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt)

    tokenized = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    token_ids = tokenized.input_ids

    # Tokenize again without truncation to detect (and report) text that was cut off.
    full_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
    was_truncated = full_ids.shape[-1] >= token_ids.shape[-1] and not torch.equal(token_ids, full_ids)
    if was_truncated:
        removed_text = self.tokenizer.batch_decode(full_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    encoder_output = self.text_encoder(token_ids.to(device))[0]
    prompt_embeds = encoder_output.to(dtype=dtype, device=device)
    prompt_attention_mask = tokenized.attention_mask.to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    repeated = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = repeated.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds, prompt_attention_mask
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Encodes the prompt (and, for classifier-free guidance, the negative prompt) into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the generation. If not defined, an empty string is used unless
            `negative_prompt_embeds` is passed. Only used when `do_classifier_free_guidance` is `True`.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether to use classifier free guidance or not.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos that should be generated per prompt.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they will be generated from
            `negative_prompt` when classifier-free guidance is enabled.
        max_sequence_length (`int`, *optional*, defaults to 226):
            Sequence length forwarded to `_get_t5_prompt_embeds`.
        device: (`torch.device`, *optional*):
            torch device
        dtype: (`torch.dtype`, *optional*):
            torch dtype

    Returns:
        Tuple `(prompt_embeds, negative_prompt_embeds)`; the second entry stays `None` when guidance is disabled
        and no negative embeddings were passed.
    """
    device = device or self._execution_device

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = prompt_embeds.shape[0] if prompt is None else len(prompt)

    if prompt_embeds is None:
        prompt_embeds, _ = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    # Nothing more to do unless guidance needs negative embeddings that were not supplied.
    needs_negative = do_classifier_free_guidance and negative_prompt_embeds is None
    if not needs_negative:
        return prompt_embeds, negative_prompt_embeds

    negative_prompt = negative_prompt or ""
    if isinstance(negative_prompt, str):
        negative_prompt = batch_size * [negative_prompt]

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    if batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds, _ = self._get_t5_prompt_embeds(
        prompt=negative_prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )

    return prompt_embeds, negative_prompt_embeds
def prepare_latents(
    self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
    """Create (or move) the initial latents and scale them by the scheduler's `init_noise_sigma`.

    The latent layout is (batch, latent_frames, channels, latent_height, latent_width), where
    `latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1` and the spatial sizes are
    divided by `vae_scale_factor_spatial`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    latent_height = height // self.vae_scale_factor_spatial
    latent_width = width // self.vae_scale_factor_spatial
    shape = (batch_size, latent_frames, num_channels_latents, latent_height, latent_width)

    if latents is not None:
        # Caller-supplied latents are only moved to the device; shape and dtype are left as given.
        latents = latents.to(device)
    else:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return latents * self.scheduler.init_noise_sigma
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
    """Decode latents into frames with the VAE.

    Axes 1 and 2 are swapped to give [batch_size, num_channels, num_frames, height, width], the
    result is multiplied by `1 / self.vae_scaling_factor_image`, and `self.vae.decode(...).sample`
    is returned.
    """
    channels_first = latents.permute(0, 2, 1, 3, 4)
    unscaled = 1 / self.vae_scaling_factor_image * channels_first
    return self.vae.decode(unscaled).sample
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments supported by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only
    forwarded when `step` declares a parameter with that name. eta (η) corresponds to η in the
    DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        Dict with a subset of the keys `"eta"` and `"generator"`.
    """
    step_parameters = set(inspect.signature(self.scheduler.step).parameters.keys())

    extra_step_kwargs = {}
    for name, value in (("eta", eta), ("generator", generator)):
        if name in step_parameters:
            extra_step_kwargs[name] = value
    return extra_step_kwargs
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt,
    callback_on_step_end_tensor_inputs,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the user-facing arguments of the pipeline call.

    Raises:
        ValueError: if `height`/`width` are not multiples of 8, if a requested callback tensor
            name is not listed in `self._callback_tensor_inputs`, if `prompt` and `prompt_embeds`
            are both given or both missing, if `prompt` is neither `str` nor `list`, if
            `negative_prompt_embeds` is combined with `prompt` or `negative_prompt`, or if
            `prompt_embeds` and `negative_prompt_embeds` have different shapes.
    """
    # --- spatial size -----------------------------------------------------------------
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # --- callback tensor names --------------------------------------------------------
    if callback_on_step_end_tensor_inputs is not None:
        unsupported = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unsupported:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unsupported}"
            )

    # --- positive prompt: exactly one of `prompt` / `prompt_embeds` ---------------------
    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # --- negative prompt --------------------------------------------------------------
    has_negative_embeds = negative_prompt_embeds is not None
    if has_prompt and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # --- embedding shapes must agree when both are supplied directly --------------------
    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def fuse_qkv_projections(self) -> None:
    r"""Turn on fused QKV projections in the transformer and remember that fusion is active."""
    # The flag is recorded before delegating so it is already set while the transformer fuses.
    self.fusing_transformer = True
    transformer = self.transformer
    transformer.fuse_qkv_projections()
def unfuse_qkv_projections(self) -> None:
    r"""Undo QKV projection fusion; logs a warning and does nothing if fusion was never enabled."""
    if self.fusing_transformer:
        self.transformer.unfuse_qkv_projections()
        self.fusing_transformer = False
        return
    logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
def _prepare_rotary_positional_embeddings(
    self,
    height: int,
    width: int,
    num_frames: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the 3D rotary positional embedding tables for the requested resolution.

    Args:
        height: Height divided by (VAE spatial scale factor * transformer patch size) gives the grid height.
        width: Width divided by the same factor gives the grid width.
        num_frames: Forwarded as `temporal_size` to `get_3d_rotary_pos_embed`.
        device: Device the two returned tensors are moved to.

    Returns:
        `(freqs_cos, freqs_sin)` as produced by `get_3d_rotary_pos_embed`, moved to `device`.
    """
    # Combined spatial divisor: VAE spatial scale factor times the transformer patch size.
    spatial_divisor = self.vae_scale_factor_spatial * self.transformer.config.patch_size
    grid_size = (height // spatial_divisor, width // spatial_divisor)

    # The reference grid is derived from a fixed 720x480 (width x height) resolution.
    base_size_width = 720 // spatial_divisor
    base_size_height = 480 // spatial_divisor
    crop_region = get_resize_crop_region_for_grid(grid_size, base_size_width, base_size_height)

    cos_table, sin_table = get_3d_rotary_pos_embed(
        embed_dim=self.transformer.config.attention_head_dim,
        crops_coords=crop_region,
        grid_size=grid_size,
        temporal_size=num_frames,
    )
    return cos_table.to(device=device), sin_table.to(device=device)
@property
def guidance_scale(self):
    """Guidance scale currently stored in `self._guidance_scale` (written by `__call__`)."""
    return self._guidance_scale
@property
def num_timesteps(self):
    """Number of timesteps stored in `self._num_timesteps` (written by `__call__`)."""
    return self._num_timesteps
@property
def attention_kwargs(self):
    """The `attention_kwargs` stored in `self._attention_kwargs` (written by `__call__`)."""
    return self._attention_kwargs
@property
def interrupt(self):
    """Flag stored in `self._interrupt`; while it is truthy the denoising loop skips its iterations."""
    return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompts_list: List[str] = None,
    pose_embeds: Optional[torch.FloatTensor] = None,
    height: int = 480,
    width: int = 720,
    num_frames: int = 49,
    annealed_sample_step: int = 15,
    num_inference_steps: int = 50,
    timesteps: Optional[List[int]] = None,
    guidance_scale: float = 6,
    use_dynamic_cfg: bool = False,
    num_videos_per_prompt: int = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: str = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 226,
) -> Union[CogVideoXPipelineOutput, Tuple]:
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the video generation. If not defined, one has to pass
            `prompt_embeds` instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the video generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
            `guidance_scale` is less than `1`).
        prompts_list (`List[str]`):
            One text prompt per entity. Although the signature default is `None`, a non-empty list is
            required. Every entry is encoded with `_get_t5_prompt_embeds`; the stacked embeddings and
            attention masks are passed to the transformer as `prompt_entities_embeds` and
            `prompt_entities_attention_mask`.
        pose_embeds (`torch.FloatTensor`, *optional*):
            Pose conditioning forwarded unchanged as `pose_embeds` to `self.transformer` during the
            annealed steps. Its layout is defined by the transformer and is not validated here.
        height (`int`, defaults to 480):
            The height in pixels of the generated video. Must be divisible by 8.
        width (`int`, defaults to 720):
            The width in pixels of the generated video. Must be divisible by 8.
        num_frames (`int`, defaults to `49`):
            Number of frames to generate. Values above 49 are rejected. The number of latent frames is
            `(num_frames - 1) // self.vae_scale_factor_temporal + 1`.
        annealed_sample_step (`int`, *optional*, defaults to 15):
            Last denoising step index that still uses the injector: steps with index
            `i <= annealed_sample_step` run `self.transformer` with `pose_embeds`, all later steps run
            `self.transformer_ori` with `pose_embeds=None`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality video at
            the expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps`
            argument in their `set_timesteps` method. If not defined, the default behavior when
            `num_inference_steps` is passed will be used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 6):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting
            `guidance_scale > 1`.
        use_dynamic_cfg (`bool`, *optional*, defaults to `False`):
            If `True`, the guidance scale applied at each step is recomputed from `guidance_scale` with a
            cosine schedule that depends on the current timestep.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Currently ignored: the value is overwritten with `1` inside this method.
        eta (`float`, *optional*, defaults to 0.0):
            Forwarded to the scheduler's `step` only if that method accepts an `eta` argument.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for
            generation. If not provided, a latents tensor will be generated by sampling using the supplied
            random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from the
            `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they will be generated from the
            `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            Output format passed to `self.video_processor.postprocess_video`. With `"latent"` the final
            latents are returned without decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`CogVideoXPipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that is passed along to the transformer as `attention_kwargs`.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step with the arguments
            `callback_on_step_end(self, step, timestep, callback_kwargs)`. `callback_kwargs` contains the
            tensors named in `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. Only names listed in the
            `._callback_tensor_inputs` attribute of the pipeline class are accepted.
        max_sequence_length (`int`, defaults to `226`):
            Maximum sequence length in encoded prompt. Must be consistent with
            `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.

    Examples:

    Returns:
        [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
        [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element holds the generated video(s).

    Raises:
        ValueError: if `num_frames` exceeds 49, if `prompts_list` is `None` or empty, or if
            `check_inputs` rejects the remaining arguments.
    """
    if num_frames > 49:
        raise ValueError(
            "The number of frames must not exceed 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
        )
    # Entity prompts are mandatory for this pipeline; fail early with a clear message instead of the
    # opaque `TypeError` (None) / `RuntimeError` (empty stack) raised further down otherwise.
    if not prompts_list:
        raise ValueError("`prompts_list` must be a non-empty list of entity prompts (one prompt per entity).")
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
    # NOTE: the user-supplied value is deliberately overridden; only one video per prompt is generated.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds,
        negative_prompt_embeds,
    )
    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._interrupt = False

    # 2. Default call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        negative_prompt,
        do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

    # Encode every entity prompt separately, then stack them along a new entity axis.
    entity_num = len(prompts_list)
    prompt_entities_embeds = []
    prompt_entities_attention_mask = []
    for idx in range(entity_num):
        prompt_entity_embeds, prompt_entity_attention_mask = self._get_t5_prompt_embeds(
            prompt=prompts_list[idx],
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        prompt_entities_embeds.append(prompt_entity_embeds)
        prompt_entities_attention_mask.append(prompt_entity_attention_mask)
    prompt_entities_embeds = rearrange(torch.stack(prompt_entities_embeds), "n b t d -> b n t d")
    prompt_entities_attention_mask = rearrange(torch.stack(prompt_entities_attention_mask), "n b t -> b n t")

    # Embedding of the empty prompt, handed to the transformer as `empty_encoder_hidden_states`.
    empty_encoder_hidden_states, _ = self._get_t5_prompt_embeds(
        prompt="",
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
    )

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latents.
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        latent_channels,
        num_frames,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Create rotary embeds if required
    image_rotary_emb = (
        self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
        if self.transformer.config.use_rotary_positional_embeddings
        else None
    )

    # 8. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        # for DPM-solver++
        old_pred_original_sample = None
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            # predict noise model_output w/ annealed sampling:
            # step indices 0..annealed_sample_step (inclusive) use the pose-conditioned transformer,
            # the remaining steps use `transformer_ori` without pose conditioning.
            if i <= annealed_sample_step:
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    empty_encoder_hidden_states=empty_encoder_hidden_states,
                    prompt_entities_embeds=prompt_entities_embeds,
                    prompt_entities_attention_mask=prompt_entities_attention_mask,
                    pose_embeds=pose_embeds,
                    timestep=timestep,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
            else:
                noise_pred = self.transformer_ori(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    empty_encoder_hidden_states=empty_encoder_hidden_states,
                    prompt_entities_embeds=prompt_entities_embeds,
                    prompt_entities_attention_mask=prompt_entities_attention_mask,
                    pose_embeds=None,
                    timestep=timestep,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
            noise_pred = noise_pred.float()

            # perform guidance
            if use_dynamic_cfg:
                self._guidance_scale = 1 + guidance_scale * (
                    (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                )
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
            else:
                latents, old_pred_original_sample = self.scheduler.step(
                    noise_pred,
                    old_pred_original_sample,
                    t,
                    timesteps[i - 1] if i > 0 else None,
                    latents,
                    **extra_step_kwargs,
                    return_dict=False,
                )
            latents = latents.to(prompt_embeds.dtype)

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    # tensors are looked up by name among this function's local variables
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    if not output_type == "latent":
        video = self.decode_latents(latents)
        video = self.video_processor.postprocess_video(video=video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)
    return CogVideoXPipelineOutput(frames=video)
================================================
FILE: CogVideo/finetune/models/pipeline_output.py
================================================
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class CogVideoXPipelineOutput(BaseOutput):
    r"""
    Output class for CogVideo pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """
    # NOTE(review): annotated as `torch.Tensor` only, although the docstring above also lists
    # NumPy arrays and nested PIL lists as possible values.
    frames: torch.Tensor
================================================
FILE: CogVideo/finetune/models/utils.py
================================================
import os
from typing import Callable, Dict, List, Optional, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
from diffusers.utils import (
USE_PEFT_BACKEND,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_peft_available,
is_peft_version,
is_torch_version,
is_transformers_available,
is_transformers_version,
logging,
scale_lora_layers,
)
from diffusers.loaders.lora_base import LoraBaseMixin
from diffusers.loaders.lora_conversion_utils import (
_convert_kohya_flux_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
)
# Default for low-CPU-memory LoRA loading: switched on only when torch, peft and transformers
# are all available in the versions checked below (checks are evaluated left to right and stop
# at the first one that fails).
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
if (
    is_torch_version(">=", "1.9.0")
    and is_peft_available()
    and is_peft_version(">=", "0.13.1")
    and is_transformers_available()
    and is_transformers_version(">", "4.45.2")
):
    _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True

# These helpers are only imported when transformers is installed.
if is_transformers_available():
    from diffusers.models.lora import text_encoder_attn_modules, text_encoder_mlp_modules

logger = logging.get_logger(__name__)

# Component name constants.
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"

# Default LoRA weight file names (pickle and safetensors variants).
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
def lora_state_dict(
    cls,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    return_alphas: bool = False,
    **kwargs,
):
    """Load a LoRA state dict and optionally split out its alpha entries.

    Args:
        cls: Object providing `_fetch_state_dict(...)`, which performs the actual loading.
        pretrained_model_name_or_path_or_dict: Forwarded unchanged to `cls._fetch_state_dict`.
        return_alphas: When `True`, return `(state_dict, network_alphas)` instead of just the state dict.
        **kwargs: `cache_dir`, `force_download`, `proxies`, `local_files_only`, `token`, `revision`,
            `subfolder`, `weight_name` and `use_safetensors` are consumed; anything else is ignored.

    Returns:
        The (possibly converted) state dict, or a tuple of it and the alpha mapping. For Kohya- and
        XLabs-style checkpoints the alpha element of the tuple is `None`.

    Raises:
        ValueError: if a key containing "alpha" maps to something that is neither a floating point
            tensor nor a Python `float`.
    """
    # Options that are handed to `cls._fetch_state_dict` exactly as received.
    fetch_options = {
        "cache_dir": kwargs.pop("cache_dir", None),
        "force_download": kwargs.pop("force_download", False),
        "proxies": kwargs.pop("proxies", None),
        "local_files_only": kwargs.pop("local_files_only", None),
        "token": kwargs.pop("token", None),
        "revision": kwargs.pop("revision", None),
        "subfolder": kwargs.pop("subfolder", None),
        "weight_name": kwargs.pop("weight_name", None),
    }

    # When the caller expresses no preference, safetensors is requested and pickle is allowed.
    use_safetensors = kwargs.pop("use_safetensors", None)
    allow_pickle = use_safetensors is None
    if allow_pickle:
        use_safetensors = True

    state_dict = cls._fetch_state_dict(
        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
        use_safetensors=use_safetensors,
        user_agent={
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        },
        allow_pickle=allow_pickle,
        **fetch_options,
    )

    # DoRA scales are not supported: warn and drop them.
    if any("dora_scale" in name for name in state_dict):
        warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
        logger.warning(warn_msg)
        state_dict = {name: tensor for name, tensor in state_dict.items() if "dora_scale" not in name}

    # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
    if any(".lora_down.weight" in name for name in state_dict):
        # Kohya already takes care of scaling the LoRA parameters with alpha.
        converted = _convert_kohya_flux_lora_to_diffusers(state_dict)
        return (converted, None) if return_alphas else converted

    if any("processor" in name for name in state_dict):
        # xlabs doesn't use `alpha`.
        converted = _convert_xlabs_flux_lora_to_diffusers(state_dict)
        return (converted, None) if return_alphas else converted

    # For state dicts like
    # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
    network_alphas = {}
    for name in list(state_dict.keys()):
        if "alpha" not in name:
            continue
        candidate = state_dict.get(name)
        is_float_tensor = torch.is_tensor(candidate) and torch.is_floating_point(candidate)
        if not (is_float_tensor or isinstance(candidate, float)):
            raise ValueError(
                f"The alpha key ({name}) seems to be incorrect. If you think this error is unexpected, please open as issue."
            )
        # Move the alpha entry out of the weights and into the alpha mapping.
        network_alphas[name] = state_dict.pop(name)

    if return_alphas:
        return state_dict, network_alphas
    return state_dict
if __name__ == "__main__":
    # Manual smoke test for `lora_state_dict`. It is guarded so that importing this module no longer
    # executes it (previously the import failed on the undefined `kwargs` and an `ipdb` breakpoint).
    import sys

    # Optional first CLI argument overrides the developer's default checkpoint path.
    pretrained_model_name_or_path_or_dict = (
        sys.argv[1]
        if len(sys.argv) > 1
        else '/ytech_m2v2_hdd/fuxiao/CogVideo/finetune/cogvideox5b-lora-single-node-r32/checkpoint-2000/pytorch_lora_weights.safetensors'
    )
    # `lora_state_dict` takes the loader class as its first argument because it calls
    # `cls._fetch_state_dict`. NOTE(review): assumes the installed diffusers version exposes
    # `_fetch_state_dict` on `LoraBaseMixin` - confirm against the pinned version.
    state_dict, network_alphas = lora_state_dict(
        LoraBaseMixin, pretrained_model_name_or_path_or_dict, return_alphas=True
    )
    print(f"Loaded {len(state_dict)} LoRA tensors and {len(network_alphas or {})} alpha entries.")
================================================
FILE: CogVideo/finetune/train_cogvideox_injector.py
================================================
"""
Adapted from CogVideoX-5B: https://github.com/THUDM/CogVideo by Xiao Fu (CUHK)
"""
import argparse
import logging
import math
import os
import shutil
from pathlib import Path
from typing import List, Optional, Tuple, Union
from torch import nn
import copy
import torch
import json
import numpy as np
import random
import cv2
import decord
from einops import rearrange
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
from safetensors.torch import save_file
import diffusers
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
from models.pipeline_cogvideox import CogVideoXPipeline
from models.cogvideox_transformer_3d import CogVideoXTransformer3DModel
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.training_utils import (
cast_training_params,
free_memory,
)
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
# wandb is optional: it is imported only when diffusers reports it as available.
if is_wandb_available():
    import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
# Module-level logger created through accelerate's `get_logger`.
logger = get_logger(__name__)
def get_args(argv: Optional[List[str]] = None):
    """Build the command-line parser for the injector training script and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When `None` (the default) the arguments
            are taken from `sys.argv[1:]`, which is argparse's standard behavior.

    Returns:
        `argparse.Namespace` holding every option defined below.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")

    # Model information
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )

    # Dataset information
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--instance_data_root",
        type=str,
        default=None,
        help=("A folder containing the training data."),
    )
    parser.add_argument(
        "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )

    # Validation
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
    )
    parser.add_argument(
        "--validation_prompt_separator",
        type=str,
        default=":::",
        help="String that separates multiple validation prompts",
    )
    parser.add_argument(
        "--num_validation_videos",
        type=int,
        default=1,
        help="Number of videos that should be generated during validation per `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`."
        ),
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=6,
        help="The guidance scale to use while sampling validation videos.",
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        default=False,
        help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
    )

    # Training information
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--lora_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--lora_scale",
        type=float,
        default=1.0,
        help=("The scaling factor to scale LoRA weight update."),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="cogvideox-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="All input videos are resized to this height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=720,
        help="All input videos are resized to this width.",
    )
    parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
    parser.add_argument(
        "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
    )
    parser.add_argument(
        "--skip_frames_start",
        type=int,
        default=0,
        help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
    )
    parser.add_argument(
        "--skip_frames_end",
        type=int,
        default=0,
        help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip videos horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=3072,
        help="dimension of each basic transformer block.",
    )
    parser.add_argument(
        "--block_interval",
        type=int,
        default=3,
        help="the injector at intervals in transformer blocks to reduce training parameters and improve inference speed.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--enable_slicing",
        action="store_true",
        default=False,
        help="Whether or not to use VAE slicing for saving memory.",
    )
    parser.add_argument(
        "--enable_tiling",
        action="store_true",
        default=False,
        help="Whether or not to use VAE tiling for saving memory.",
    )

    # Optimizer
    parser.add_argument(
        "--optimizer",
        # lower-case the value before it is checked against `choices`
        type=lambda s: s.lower(),
        default="adam",
        choices=["adam", "adamw", "prodigy"],
        help=("The optimizer type to use."),
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
    )
    parser.add_argument(
        "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
    )
    parser.add_argument(
        "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers."
    )
    parser.add_argument(
        "--prodigy_beta3",
        type=float,
        default=None,
        help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
    )
    parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.")
    parser.add_argument(
        "--prodigy_safeguard_warmup",
        action="store_true",
        help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
    )
    parser.add_argument(
        "--finetune_init",
        action="store_true",
        help="Remove the injector in the first finetune stage w/o loading pretrained ckpt.",
    )

    # Other information
    parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help="Directory where logs are stored.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default=None,
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )

    # `argv=None` makes argparse fall back to `sys.argv[1:]`, preserving the CLI behavior.
    return parser.parse_args(argv)
def parse_matrix(matrix_str):
    """Convert a bracketed matrix string into a 2-D numpy float array.

    The text is expected to look like ``"[a b c d] [e f g h] ..."``: rows are
    separated by the literal ``"] ["`` and the values inside a row are
    whitespace separated.

    Args:
        matrix_str: textual matrix in the format described above.

    Returns:
        ``np.ndarray`` with one row per bracketed group.
    """
    parsed_rows = [
        [float(token) for token in chunk.replace('[', '').replace(']', '').split()]
        for chunk in matrix_str.strip().split('] [')
    ]
    return np.array(parsed_rows)
class VideoDataset(Dataset):
    """Multi-camera video dataset yielding frames, entity prompts and per-entity pose sequences.

    Expected layout below ``instance_data_root`` (as read by this class):

    * ``480_720/<scene>/location_data.json``: list of location records with ``name``
      and ``coordinates.CameraTarget.position``.
    * ``480_720/<scene>/<video_name>/<video_name>.json``: per-frame entity records;
      key ``'0'`` lists the entities, each with ``index`` and ``matrix``.
    * ``480_720/<scene>/<video_name>/videos/<video_name>_C_XX_35mm.mp4``: one clip per camera.
    * ``CharacterInfo.json``: entity captions (``index`` -> ``eng``).
    * ``Hemi12_transforms.json``: camera matrices under keys containing ``"C_"``.

    Each item is a dict with:

    * ``"prompt"``: a single sentence naming every entity and the location.
    * ``"prompts_list"``: the cleaned caption of each entity.
    * ``"pose_embeds"``: tensor of flattened 3x4 entity poses in the sampled camera's
      frame, one per entity and per selected pose frame.
    * ``"video"``: ``sample_n_frames`` frames as ``(f, c, h, w)``, resized to
      ``(height, width)`` and normalised with mean/std 0.5.
    """

    def __init__(
        self,
        instance_data_root: Optional[str] = None,
        dataset_name: Optional[str] = None,
        dataset_config_name: Optional[str] = None,
        height: int = 480,
        width: int = 720,
        fps: int = 8,
        max_num_frames: int = 49,
        skip_frames_start: int = 0,
        skip_frames_end: int = 0,
        cache_dir: Optional[str] = None,
        id_token: Optional[str] = None,
    ) -> None:
        """Index every ``*Hemi12_1`` clip and load caption / camera metadata.

        Args:
            instance_data_root: dataset root directory (see class docstring for layout).
            dataset_name, dataset_config_name, fps, max_num_frames, skip_frames_start,
                skip_frames_end, cache_dir: stored on the instance; not otherwise used
                by the methods of this class.
            height, width: target frame size used by the resize transform.
            id_token: stored on the instance (``""`` when not given).
        """
        super().__init__()
        self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
        self.dataset_name = dataset_name
        self.dataset_config_name = dataset_config_name
        self.height = height
        self.width = width
        self.sample_size = (self.height, self.width)
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frames_start = skip_frames_start
        self.skip_frames_end = skip_frames_end
        self.cache_dir = cache_dir
        self.id_token = id_token or ""
        self.pixel_transforms = [
            transforms.Resize(self.sample_size),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True
            ),
        ]
        self.video_names = []

        # ----------------------------------- Released Dataset -----------------------------------
        # The internal (unreleased) variant of this dataset used scene folders named
        # '<Scene>_480_720' placed directly under instance_data_root, with the extra scenes
        # AsianTown ('asian town'), Forest ('crossland') and MatrixCity ('city').
        scenes = ['Desert', 'HDRI']
        scene_location_pair = {
            'Desert': 'desert',
            'HDRI': {
                'loc1': 'snowy street',
                'loc2': 'park',
                'loc3': 'indoor open space',
                'loc11': 'gymnastics room',
                'loc13': 'autumn forest',
            },
        }
        for scene in scenes:
            video_path = os.path.join(self.instance_data_root, '480_720', scene)
            video_names = os.listdir(video_path)
            locations_path = os.path.join(video_path, "location_data.json")
            with open(locations_path, 'r') as f:
                locations = json.load(f)
            locations_info = {entry['name']: entry for entry in locations}
            for video_name in video_names:
                if not video_name.endswith('Hemi12_1'):
                    continue
                if scene != 'HDRI':
                    location = scene_location_pair[scene]
                else:
                    # HDRI clips encode their location id as the second '_' separated token.
                    location = scene_location_pair['HDRI'][video_name.split('_')[1]]
                self.video_names.append((scene, video_name, location, locations_info))

        self.cam_num = 12
        self.max_objs_num = 3
        self.length = len(self.video_names)

        self.captions_path = os.path.join(self.instance_data_root, "CharacterInfo.json")
        with open(self.captions_path, 'r') as f:
            captions = json.load(f)['CharacterInfo']
        self.captions_info = {int(entry['index']): entry['eng'] for entry in captions}

        self.cams_path = os.path.join(self.instance_data_root, "Hemi12_transforms.json")
        with open(self.cams_path, 'r') as f:
            self.cams_info = json.load(f)
        cam_poses = [parse_matrix(self.cams_info[key]) for key in self.cams_info if "C_" in key]
        cam_poses = np.stack(cam_poses)
        cam_poses = np.transpose(cam_poses, (0, 2, 1))
        # Same column permutation as applied to the entity poses in load_sceneposes.
        cam_poses = cam_poses[:, :, [1, 2, 0, 3]]
        # Translation is divided by 100 (presumably centimetres -> metres; confirm against the data).
        cam_poses[:, :3, 3] /= 100.
        self.cam_poses = cam_poses

        # NOTE(review): fixed at 49 regardless of `max_num_frames`.
        self.sample_n_frames = 49

    def __len__(self):
        """Number of indexed clips."""
        return self.length

    def save_images2video(self, images, video_name):
        """Encode ``images`` as an H.264 mp4 (8 fps, crf 12) and write it to ``<video_name>.mp4``.

        Args:
            images: indexable sequence of frames accepted by the imageio writer.
            video_name: output path without the ``.mp4`` extension.
        """
        fps = 8
        container_format = "mp4"
        codec = "libx264"
        ffmpeg_params = ["-crf", str(12)]
        pixelformat = "yuv420p"
        video_stream = BytesIO()
        with imageio.get_writer(
            video_stream,
            fps=fps,
            format=container_format,
            codec=codec,
            ffmpeg_params=ffmpeg_params,
            pixelformat=pixelformat,
        ) as writer:
            for idx in range(len(images)):
                writer.append_data(images[idx])
        video_data = video_stream.getvalue()
        output_path = video_name + ".mp4"
        with open(output_path, "wb") as f:
            f.write(video_data)

    @staticmethod
    def _clean_caption(caption):
        """Drop one leading space and one trailing period, then lowercase."""
        if caption.startswith(" "):
            caption = caption[1:]
        if caption.endswith("."):
            caption = caption[:-1]
        return caption.lower()

    @staticmethod
    def _compose_prompt(captions, location):
        """Join entity captions with ' and ' and append '<is|are> moving in the <location>'."""
        if not captions:
            return ''
        verb = ' is moving in the ' if len(captions) == 1 else ' are moving in the '
        return ' and '.join(captions) + verb + location

    @staticmethod
    def _pose_frame_indices(frame_indices):
        """Pick one pose frame for the first video frame and one per following group of four.

        For each group the index is the truncated mean of the group's first and last
        frame indices.
        """
        trunc_frame_indices = np.zeros_like(frame_indices[::4])
        trunc_frame_indices[0] = frame_indices[0]
        trunc_frame_indices[1:] = ((frame_indices[1:][::4] + frame_indices[4:][::4]) / 2).astype(np.int64)
        return trunc_frame_indices

    def __getitem__(self, idx):
        """Load one training sample; on any failure log the clip and retry with a random index.

        Failed clips are appended to ``invalid_scene.txt`` in the working directory.
        The retry loop has no upper bound, so it will not terminate if no clip can be loaded.
        """
        while True:
            try:
                (scene, video_name, location, locations_info) = self.video_names[idx]
                clip_dir = os.path.join(self.instance_data_root, '480_720', scene, video_name)
                with open(os.path.join(clip_dir, video_name + '.json'), 'r') as f:
                    objs_file = json.load(f)
                objs_num = len(objs_file['0'])

                # NOTE(review): random.randint is inclusive on both ends, so this draws
                # 1..cam_num-1 and the last camera is never sampled. Left unchanged because
                # it is unclear whether that camera is held out on purpose.
                video_index = random.randint(1, self.cam_num - 1)
                location_name = video_name.split('_')[1]
                location_info = locations_info[location_name]
                cam_pose = self.cam_poses[video_index - 1]
                obj_transl = location_info['coordinates']['CameraTarget']['position']

                video_caption_list = []
                obj_poses_list = []
                for obj_idx in range(objs_num):
                    obj_name_index = objs_file['0'][obj_idx]['index']
                    video_caption_list.append(self._clean_caption(self.captions_info[obj_name_index]))
                    obj_poses = self.load_sceneposes(objs_file, obj_idx, obj_transl)
                    # Express the entity poses in the sampled camera's frame.
                    obj_poses = np.linalg.inv(cam_pose) @ obj_poses
                    obj_poses_list.append(obj_poses)
                video_caption_concat = self._compose_prompt(video_caption_list, location)
                obj_poses_all = torch.from_numpy(np.array(obj_poses_list))

                total_frames = 99
                current_sample_stride = 1.75
                cropped_length = int(self.sample_n_frames * current_sample_stride)
                start_frame_ind = random.randint(10, max(10, total_frames - cropped_length - 1))
                end_frame_ind = min(start_frame_ind + cropped_length, total_frames)
                frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)

                video_frames_path = os.path.join(
                    clip_dir, 'videos', video_name + f'_C_{video_index:02d}_35mm.mp4'
                )
                # Probe the native frame size, and always release the capture handle so
                # long-running dataloader workers do not leak file descriptors.
                cap = cv2.VideoCapture(video_frames_path)
                try:
                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                finally:
                    cap.release()

                reader = decord.VideoReader(video_frames_path, ctx=decord.cpu(0), height=height, width=width)
                if len(reader) not in (total_frames, total_frames + 1):
                    # Explicit raise (instead of assert) so the check survives `python -O`.
                    raise ValueError(f"unexpected frame count {len(reader)} in {video_frames_path}")
                frame_batch = reader.get_batch(list(range(total_frames)))
                # The batch exposes `.asnumpy()` or `.numpy()` depending on its type;
                # decode once and convert with whichever is available.
                if hasattr(frame_batch, "asnumpy"):
                    video_chunk = frame_batch.asnumpy()
                else:
                    video_chunk = frame_batch.numpy()

                # Randomly play the clip backwards; the pose indices below follow the same order.
                frame_inverse = torch.rand(1).item() > 0.5
                if frame_inverse:
                    frame_indices = frame_indices[::-1]

                pixel_values = video_chunk[frame_indices]
                pixel_values = rearrange(torch.from_numpy(pixel_values) / 255.0, "f h w c -> f c h w")
                pixel_values = self.pixel_transforms[0](pixel_values)
                pixel_values = self.pixel_transforms[1](pixel_values)

                trunc_frame_indices = self._pose_frame_indices(frame_indices)
                obj_poses_all = obj_poses_all[:, trunc_frame_indices]
                pose_embeds = rearrange(obj_poses_all[:, :, :3, :], "b f p q -> b f (q p)").contiguous()
                break
            except Exception:
                (scene, video_name, location, locations_info) = self.video_names[idx]
                with open('invalid_scene.txt', 'a+') as f:
                    f.write(f'{scene} {video_name} {location}')
                    f.write('\n')
                idx = random.randint(0, self.length - 1)
        return {
            "prompt": video_caption_concat,
            "prompts_list": video_caption_list,
            "pose_embeds": pose_embeds,
            "video": pixel_values,
        }

    def load_sceneposes(self, objs_file, obj_idx, obj_transl):
        """Stack the per-frame pose matrices of entity ``obj_idx``.

        Args:
            objs_file: mapping frame key -> list of entity records holding a ``matrix`` string.
            obj_idx: position of the entity inside each frame's list.
            obj_transl: 3-vector subtracted from every translation before scaling.

        Returns:
            Array with one matrix per frame key: transposed, translation re-centred on
            ``obj_transl`` and divided by 100, columns reordered to ``[1, 2, 0, 3]``.
        """
        ext_poses = np.stack(
            [parse_matrix(objs_file[frame_key][obj_idx]['matrix']) for frame_key in objs_file]
        )
        ext_poses = np.transpose(ext_poses, (0, 2, 1))
        ext_poses[:, :3, 3] -= obj_transl
        ext_poses[:, :3, 3] /= 100.
        ext_poses = ext_poses[:, :, [1, 2, 0, 3]]
        return ext_poses
def save_model_card(
    repo_id: str,
    videos=None,
    base_model: str = None,
    validation_prompt=None,
    repo_folder=None,
    fps=8,
):
    """Export the final videos and write a ``README.md`` model card into ``repo_folder``.

    Args:
        repo_id: Hub repository id, interpolated into the card text.
        videos: optional iterable of videos; each is exported to
            ``<repo_folder>/final_video_<i>.mp4`` and referenced by a card widget.
        base_model: name of the base model, interpolated into the card text.
        validation_prompt: prompt shown in the widgets and the usage snippet.
        repo_folder: directory receiving the videos and the ``README.md``.
        fps: frame rate passed to ``export_to_video``.
    """
    widget_dict = []
    if videos is not None:
        for i, video in enumerate(videos):
            video_filename = f"final_video_{i}.mp4"
            # `fps` belongs to export_to_video; it used to be passed to os.path.join,
            # which raised TypeError as soon as any video was supplied.
            export_to_video(video, os.path.join(repo_folder, video_filename), fps=fps)
            # The widget must point at the file that was actually written.
            widget_dict.append(
                {"text": validation_prompt if validation_prompt else " ", "output": {"url": video_filename}}
            )
    model_description = f"""
# CogVideoX LoRA - {repo_id}
## Model description
These are {repo_id} LoRA weights for {base_model}.
The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
Was LoRA for the text encoder enabled? No.
## Download model
[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import CogVideoXPipeline
import torch
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"])
# The LoRA adapter weights are determined by what was used for training.
# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
# It can be made lower or higher from what was used in training to decrease or amplify the effect
# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
pipe.set_adapters(["cogvideox-lora"], [32 / 64])
video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
```
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
## License
Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="other",
        base_model=base_model,
        prompt=validation_prompt,
        model_description=model_description,
        widget=widget_dict,
    )
    tags = [
        "text-to-video",
        "diffusers-training",
        "diffusers",
        "lora",
        "cogvideox",
        "cogvideox-diffusers",
        "template:sd-lora",
    ]
    model_card = populate_model_card(model_card, tags=tags)
    model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(
    pipe,
    args,
    accelerator,
    pipeline_args,
    epoch,
    is_final_validation: bool = False,
):
    """Generate validation videos with ``pipe`` and log them to any wandb tracker.

    Args:
        pipe: pipeline called as ``pipe(**pipeline_args, generator=..., output_type="np")``.
        args: namespace providing ``num_validation_videos``, ``seed`` and ``output_dir``.
        accelerator: provides ``device`` and ``trackers``.
        pipeline_args: keyword arguments for the pipeline; must contain ``"prompt"``.
        epoch: accepted for interface compatibility; not used here.
        is_final_validation: log under ``"test"`` instead of ``"validation"``.

    Returns:
        List with the first ``frames`` entry of each pipeline call.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
    )
    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
    scheduler_args = {}
    if "variance_type" in pipe.scheduler.config:
        variance_type = pipe.scheduler.config.variance_type
        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"
        scheduler_args["variance_type"] = variance_type
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
    pipe = pipe.to(accelerator.device)
    # pipe.set_progress_bar_config(disable=True)

    # run inference
    # Compare against None explicitly: seed 0 is a valid seed and must still
    # produce a seeded generator.
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    videos = []
    for _ in range(args.num_validation_videos):
        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
        videos.append(video)

    phase_name = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            video_filenames = []
            for i, video in enumerate(videos):
                # Make the (truncated) prompt safe to embed in a file name.
                prompt = (
                    pipeline_args["prompt"][:25]
                    .replace(" ", "_")
                    .replace("'", "_")
                    .replace('"', "_")
                    .replace("/", "_")
                )
                filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
                export_to_video(video, filename, fps=8)
                video_filenames.append(filename)
            tracker.log(
                {
                    phase_name: [
                        wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
                        for i, filename in enumerate(video_filenames)
                    ]
                }
            )
    free_memory()
    return videos
def _get_t5_prompt_embeds(
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str]],
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
text_input_ids=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
prompt_attention_mask = text_inputs.attention_mask
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_attention_mask = prompt_attention_mask.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, prompt_attention_mask
def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    """Wrap a single prompt into a list and delegate to ``_get_t5_prompt_embeds``.

    All remaining arguments are forwarded unchanged.

    Returns:
        The ``(prompt_embeds, prompt_attention_mask)`` pair produced by the helper.
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    embeds, attention_mask = _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )
    return embeds, attention_mask
def compute_prompt_embeddings(
    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
    """Encode ``prompt`` via ``encode_prompt``, under ``torch.no_grad()`` unless ``requires_grad``.

    Returns:
        The ``(prompt_embeds, prompt_attention_mask)`` pair from ``encode_prompt``.
    """

    def _run_encoder():
        # One video per prompt; everything else is forwarded as given.
        return encode_prompt(
            tokenizer,
            text_encoder,
            prompt,
            num_videos_per_prompt=1,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    if not requires_grad:
        with torch.no_grad():
            embeds, attention_mask = _run_encoder()
    else:
        embeds, attention_mask = _run_encoder()
    return embeds, attention_mask
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the 3D rotary embedding tables for a ``height`` x ``width`` x ``num_frames`` input.

    Spatial sizes are converted to patch-grid units by integer-dividing by
    ``vae_scale_factor_spatial * patch_size``; the crop region is obtained from
    ``get_resize_crop_region_for_grid`` and passed to ``get_3d_rotary_pos_embed``.

    Returns:
        ``(freqs_cos, freqs_sin)`` moved to ``device``.
    """
    # Number of pixels covered by one transformer patch along each spatial axis.
    pixels_per_patch = vae_scale_factor_spatial * patch_size
    grid_size = (height // pixels_per_patch, width // pixels_per_patch)
    base_grid_width = base_width // pixels_per_patch
    base_grid_height = base_height // pixels_per_patch

    crop_region = get_resize_crop_region_for_grid(grid_size, base_grid_width, base_grid_height)
    cos_table, sin_table = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=crop_region,
        grid_size=grid_size,
        temporal_size=num_frames,
    )
    return cos_table.to(device=device), sin_table.to(device=device)
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    """Create the optimizer selected by ``args.optimizer``.

    Args:
        args: namespace providing ``optimizer``, ``use_8bit_adam``, ``learning_rate``,
            ``adam_beta1``, ``adam_beta2``, ``adam_epsilon``, ``adam_weight_decay`` and
            the ``prodigy_*`` options.
        params_to_optimize: parameter groups handed to the optimizer. The Adam/AdamW
            branches pass no ``lr``, so it is taken from the groups (or the class default).
        use_deepspeed: return an ``accelerate`` ``DummyOptim`` placeholder instead.

    Returns:
        The optimizer instance. The name is matched case-insensitively; an
        unsupported name falls back to AdamW with a warning and rewrites
        ``args.optimizer`` to ``"adamw"``.

    Raises:
        ImportError: if 8-bit Adam or Prodigy is requested but the library is missing.
    """
    # Use DeepSpeed optimzer
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy"]
    # Normalise once so validation and dispatch agree on casing ("Adam", "Prodigy", ...).
    optimizer_name = args.optimizer.lower()
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
        )
        args.optimizer = "adamw"
        optimizer_name = "adamw"

    is_adam_family = optimizer_name in ("adam", "adamw")
    # Warn only when the 8-bit flag really is ignored (the old condition was inverted).
    if args.use_8bit_adam and not is_adam_family:
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
            f"set to {args.optimizer.lower()}"
        )

    # Import bitsandbytes only when it will actually be used.
    use_8bit = bool(args.use_8bit_adam) and is_adam_family
    if use_8bit:
        try:
            import bitsandbytes as bnb
        except ImportError as exc:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            ) from exc

    if optimizer_name == "adamw":
        optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif optimizer_name == "adam":
        optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    else:  # "prodigy"
        try:
            import prodigyopt
        except ImportError as exc:
            raise ImportError(
                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
            ) from exc
        optimizer_class = prodigyopt.Prodigy
        if args.learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
            )
        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
            use_bias_correction=args.prodigy_use_bias_correction,
            safeguard_warmup=args.prodigy_safeguard_warmup,
        )
    return optimizer
def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
kwargs_handlers=[kwargs],
)
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
).repo_id
# Prepare models and scheduler
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = T5EncoderModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
# CogVideoX-2b weights are stored in float16
# CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
pipe = CogVideoXPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch.bfloat16,
finetune_init=args.finetune_init,
)
if args.resume_from_checkpoint:
transformer = CogVideoXTransformer3DModel.from_pretrained(
args.resume_from_checkpoint,
torch_dtype=load_dtype,
finetune_init=args.finetune_init,
)
else:
transformer = CogVideoXTransformer3DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=load_dtype,
revision=args.revision,
variant=args.variant,
finetune_init=args.finetune_init,
)
dim = args.dim
for idx in range(len(transformer.transformer_blocks)):
if idx%args.block_interval == 0:
transformer.transformer_blocks[idx].attn_injector = copy.deepcopy(transformer.transformer_blocks[idx].attn1)
transformer.transformer_blocks[idx].pose_fuse_layer = nn.Linear(12, dim)
transformer.transformer_blocks[idx].attn_null_feature = nn.Parameter(torch.zeros([dim]))
pipe.transformer = transformer
if args.lora_path:
pipe.load_lora_weights(args.lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="default")
pipe.fuse_lora(components=['transformer'] ,lora_scale=args.lora_scale)
transformer = pipe.transformer
vae = AutoencoderKLCogVideoX.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
if args.enable_slicing:
vae.enable_slicing()
if args.enable_tiling:
vae.enable_tiling()
# We only train the additional injector layers
text_encoder.requires_grad_(False)
transformer.requires_grad_(False)
vae.requires_grad_(False)
for idx in range(len(transformer.transformer_blocks)):
if hasattr(transformer.transformer_blocks[idx], "attn_injector"):
for name, param_module in transformer.transformer_blocks[idx].attn_injector.named_modules():
for params in param_module.parameters():
params.requires_grad_(True)
if hasattr(transformer.transformer_blocks[idx], "pose_fuse_layer"):
for name, param_module in transformer.transformer_blocks[idx].pose_fuse_layer.named_modules():
for params in param_module.parameters():
params.requires_grad_(True)
if hasattr(transformer.transformer_blocks[idx], "attn_null_feature"):
transformer.transformer_blocks[idx].attn_null_feature.requires_grad_(True)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.state.deepspeed_plugin:
# DeepSpeed is handling precision, use what's in the DeepSpeed config
if (
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
):
weight_dtype = torch.float16
if (
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
):
weight_dtype = torch.float16
else:
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
text_encoder.to(accelerator.device, dtype=weight_dtype)
transformer.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
from diffusers import CogVideoXTransformer3DModel
model_base = CogVideoXTransformer3DModel.from_pretrained(os.path.join(args.pretrained_model_name_or_path, 'transformer'), torch_dtype=torch.bfloat16)
model_save = copy.deepcopy(model)
for idx in range(len(model_save.transformer_blocks)):
if hasattr(model_save.transformer_blocks[idx], 'attn_injector'):
model_base.transformer_blocks[idx].attn_injector = model_save.transformer_blocks[idx].attn_injector
model_base.transformer_blocks[idx].pose_fuse_layer = model_save.transformer_blocks[idx].pose_fuse_layer
model_base.transformer_blocks[idx].attn_null_feature = model_save.transformer_blocks[idx].attn_null_feature
model_base.save_pretrained(os.path.join(output_dir))
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
def load_model_hook(models, input_dir):
transformer_ = None
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"Unexpected save model: {model.__class__}")
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params([transformer_])
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32 and torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params([transformer], dtype=torch.float32)
transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
# Optimization parameters
transformer_parameters_with_lr = {"params": transformer_trainable_parameters, "lr": args.learning_rate}
params_to_optimize = [transformer_parameters_with_lr]
use_deepspeed_optimizer = (
accelerator.state.deepspeed_plugin is not None
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
)
use_deepspeed_scheduler = (
accelerator.state.deepspeed_plugin is not None
and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
)
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
# Dataset and DataLoader
train_dataset = VideoDataset(
instance_data_root=args.instance_data_root,
dataset_name=args.dataset_name,
dataset_config_name=args.dataset_config_name,
height=args.height,
width=args.width,
fps=args.fps,
max_num_frames=args.max_num_frames,
skip_frames_start=args.skip_frames_start,
skip_frames_end=args.skip_frames_end,
cache_dir=args.cache_dir,
id_token=args.id_token,
)
def encode_video(video):
video = video.to(accelerator.device, dtype=vae.dtype)
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(video).latent_dist.sample()
latent_dist = latent_dist * vae.config.scaling_factor
return latent_dist
train_dataloader = DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
if use_deepspeed_scheduler:
from accelerate.utils import DummyScheduler
lr_scheduler = DummyScheduler(
name=args.lr_scheduler,
optimizer=optimizer,
total_num_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
)
else:
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_name = args.tracker_name or "cogvideox-injector"
accelerator.init_trackers(tracker_name, config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
logger.info("***** Running training *****")
logger.info(f" Num trainable parameters = {num_trainable_parameters}")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if not args.resume_from_checkpoint:
initial_global_step = 0
else:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
            # Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
# accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
# For DeepSpeed training
model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
empty_prompt_embeds, _ = compute_prompt_embeddings(
tokenizer,
text_encoder,
"",
model_config.max_text_seq_length,
accelerator.device,
weight_dtype,
requires_grad=False,
)
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
with accelerator.accumulate(models_to_accumulate):
model_input = encode_video(batch["video"]).permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W]
batch_size, num_frames, num_channels, height, width = model_input.shape
prompts = batch["prompt"]
prompts_list = batch["prompts_list"]
pose_embeds = batch["pose_embeds"].to(dtype=weight_dtype)
# encode prompts
prompt_embeds, _ = compute_prompt_embeddings(
tokenizer,
text_encoder,
prompts,
model_config.max_text_seq_length,
accelerator.device,
weight_dtype,
requires_grad=False,
)
# entity prompts
entity_num = len(prompts_list)
prompt_entities_embeds = []
prompt_entities_attention_mask = []
for idx in range(entity_num):
prompt_entity_embeds, prompt_entity_attention_mask = compute_prompt_embeddings(
tokenizer,
text_encoder,
prompts_list[idx],
model_config.max_text_seq_length,
accelerator.device,
weight_dtype,
requires_grad=False,
)
prompt_entities_embeds.append(prompt_entity_embeds.to(weight_dtype))
prompt_entities_attention_mask.append(prompt_entity_attention_mask.to(weight_dtype))
prompt_entities_embeds = rearrange(torch.stack(prompt_entities_embeds), "n b t d -> b n t d")
prompt_entities_attention_mask = rearrange(torch.stack(prompt_entities_attention_mask), "n b t -> b n t")
# empty prompt
empty_prompt_embeds = empty_prompt_embeds.repeat(batch_size, 1, 1)
# Sample noise that will be added to the latents
noise = torch.randn_like(model_input)
# Sample a random timestep for each image
timesteps = torch.randint(
0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
)
timesteps = timesteps.long()
# Prepare rotary embeds
image_rotary_emb = (
prepare_rotary_positional_embeddings(
height=args.height,
width=args.width,
num_frames=num_frames,
vae_scale_factor_spatial=vae_scale_factor_spatial,
patch_size=model_config.patch_size,
attention_head_dim=model_config.attention_head_dim,
device=accelerator.device,
)
if model_config.use_rotary_positional_embeddings
else None
)
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
# Predict the noise residual
model_output = transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=prompt_embeds,
empty_encoder_hidden_states=empty_prompt_embeds,
prompt_entities_embeds=prompt_entities_embeds,
prompt_entities_attention_mask=prompt_entities_attention_mask,
pose_embeds=pose_embeds,
timestep=timesteps,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
alphas_cumprod = scheduler.alphas_cumprod[timesteps]
weights = 1 / (1 - alphas_cumprod)
while len(weights.shape) < len(model_pred.shape):
weights = weights.unsqueeze(-1)
target = model_input
loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
loss = loss.mean()
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = transformer.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if accelerator.state.deepspeed_plugin is None:
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
# Save the transformer layers
accelerator.wait_for_everyone()
accelerator.end_training()
if __name__ == "__main__":
    # Script entry point: build the run configuration with get_args() and
    # hand it to main(). Nothing runs when the module is merely imported.
    args = get_args()
    main(args)
================================================
FILE: CogVideo/finetune/train_cogvideox_lora.py
================================================
"""
Adapted from CogVideoX-5B: https://github.com/THUDM/CogVideo by Xiao Fu (CUHK)
"""
import argparse
import logging
import math
import os
import shutil
from pathlib import Path
from typing import List, Optional, Tuple, Union
import torch
import json
import numpy as np
import random
import cv2
import decord
from einops import rearrange
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
from io import BytesIO
import imageio.v2 as imageio
import diffusers
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.training_utils import (
cast_training_params,
free_memory,
)
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
# wandb is an optional dependency: import it only when diffusers reports it as available.
if is_wandb_available():
    import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
# Module-level logger obtained from accelerate.logging.get_logger.
logger = get_logger(__name__)
def get_args(argv: Optional[List[str]] = None):
    """Build the command-line parser for the CogVideoX LoRA trainer and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default, and the behavior every existing caller relies on), argparse
            reads ``sys.argv[1:]``. Passing a list allows programmatic use.

    Returns:
        argparse.Namespace holding every option declared below.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")

    # Model information
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )

    # Dataset information
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--instance_data_root",
        type=str,
        default=None,
        help=("A folder containing the training data."),
    )
    parser.add_argument(
        "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )

    # Validation
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        # The help text previously named a non-existent '--validation_prompt_seperator' option.
        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
    )
    parser.add_argument(
        "--validation_prompt_separator",
        type=str,
        default=":::",
        help="String that separates multiple validation prompts",
    )
    parser.add_argument(
        "--num_validation_videos",
        type=int,
        default=1,
        help="Number of videos that should be generated during validation per `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`."
        ),
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=6,
        help="The guidance scale to use while sampling validation videos.",
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        default=False,
        help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
    )

    # Training information
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--rank",
        type=int,
        default=128,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--lora_alpha",
        type=float,
        default=128,
        help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="cogvideox-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="All input videos are resized to this height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=720,
        help="All input videos are resized to this width.",
    )
    parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
    parser.add_argument(
        "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
    )
    parser.add_argument(
        "--skip_frames_start",
        type=int,
        default=0,
        help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
    )
    parser.add_argument(
        "--skip_frames_end",
        type=int,
        default=0,
        help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip videos horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--enable_slicing",
        action="store_true",
        default=False,
        help="Whether or not to use VAE slicing for saving memory.",
    )
    parser.add_argument(
        "--enable_tiling",
        action="store_true",
        default=False,
        help="Whether or not to use VAE tiling for saving memory.",
    )

    # Optimizer
    parser.add_argument(
        "--optimizer",
        # Lower-case the value before the `choices` check so e.g. "AdamW" is accepted.
        type=lambda s: s.lower(),
        default="adam",
        choices=["adam", "adamw", "prodigy"],
        help=("The optimizer type to use."),
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
    )
    parser.add_argument(
        "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
    )
    parser.add_argument(
        "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers."
    )
    parser.add_argument(
        "--prodigy_beta3",
        type=float,
        default=None,
        help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
    )
    parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.")
    parser.add_argument(
        "--prodigy_safeguard_warmup",
        action="store_true",
        help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
    )

    # Other information
    parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help="Directory where logs are stored.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default=None,
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )

    # argv=None makes argparse fall back to sys.argv[1:], preserving the old behavior.
    return parser.parse_args(argv)
def parse_matrix(matrix_str):
    """Parse a bracketed, whitespace-separated matrix string into a NumPy array.

    The expected text form is one bracketed group per row, for example
    ``"[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]"``. Rows may be separated by any
    amount of whitespace (including newlines) or by none at all; previously only
    a single space between ``]`` and ``[`` was recognised and any other
    separator silently merged neighbouring rows into one long row.

    Args:
        matrix_str: Text containing the matrix rows.

    Returns:
        ``np.ndarray`` built from the list of parsed rows (one inner list of
        floats per row).

    Raises:
        ValueError: if a token inside a row cannot be converted to ``float``.
    """
    import re  # local import: `re` is not imported at module level in this file

    # Split on every "]<optional whitespace>[" row boundary.
    rows = re.split(r'\]\s*\[', matrix_str.strip())
    matrix = []
    for row in rows:
        # Drop the brackets left on the first/last row (and any outer wrapping pair).
        row = row.replace('[', '').replace(']', '')
        matrix.append(list(map(float, row.split())))
    return np.array(matrix)
class VideoDataset(Dataset):
    """Multi-camera video dataset used to fine-tune CogVideoX with LoRA.

    Files read relative to ``instance_data_root``:

    * ``480_720/<scene>/location_data.json``: list of location records, indexed by their ``name``.
    * ``480_720/<scene>/<video_name>/<video_name>.json``: per-frame object records (``index``, ``matrix``).
    * ``480_720/<scene>/<video_name>/videos/<video_name>_C_XX_35mm.mp4``: one clip per camera.
    * ``CharacterInfo.json``: maps an object ``index`` to its English caption (``eng``).
    * ``Hemi12_transforms.json``: camera matrices stored under keys containing ``"C_"``.

    Only folders whose name ends with ``Hemi12_1`` are used as samples.

    ``dataset_name``, ``dataset_config_name``, ``fps``, ``max_num_frames``,
    ``skip_frames_start``, ``skip_frames_end``, ``cache_dir`` and ``id_token`` are
    stored on the instance but are not consulted by the methods of this class.
    """

    def __init__(
        self,
        instance_data_root: Optional[str] = None,
        dataset_name: Optional[str] = None,
        dataset_config_name: Optional[str] = None,
        height: int = 480,
        width: int = 720,
        fps: int = 8,
        max_num_frames: int = 49,
        skip_frames_start: int = 0,
        skip_frames_end: int = 0,
        cache_dir: Optional[str] = None,
        id_token: Optional[str] = None,
    ) -> None:
        """Index all usable clips and load caption / camera metadata.

        Raises:
            ValueError: if ``instance_data_root`` is ``None``; every file this class
                reads lives under that folder.
        """
        super().__init__()

        # Fail early with a clear message instead of an opaque TypeError from os.path.join.
        if instance_data_root is None:
            raise ValueError("`instance_data_root` must be provided: it is the folder containing the training data.")

        self.instance_data_root = Path(instance_data_root)
        self.dataset_name = dataset_name
        self.dataset_config_name = dataset_config_name
        self.height = height
        self.width = width
        self.sample_size = (self.height, self.width)
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frames_start = skip_frames_start
        self.skip_frames_end = skip_frames_end
        self.cache_dir = cache_dir
        self.id_token = id_token or ""

        # [0] resizes frames to (height, width); [1] maps values from [0, 1] to [-1, 1].
        self.pixel_transforms = [
            transforms.Resize(self.sample_size),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True
            ),
        ]

        # Each entry: (scene, video_name, location description, locations_info of that scene).
        self.video_names = []

        # ----------------------------------- Released Dataset -----------------------------------
        scenes = ['Desert', 'HDRI']
        # Text used for the location in the prompt; HDRI clips are keyed by the
        # second "_"-separated token of the clip name.
        scene_location_pair = {
            'Desert' : 'desert',
            'HDRI' :
                {
                    'loc1' : 'snowy street',
                    'loc2' : 'park',
                    'loc3' : 'indoor open space',
                    'loc11' : 'gymnastics room',
                    'loc13' : 'autumn forest',
                }
        }

        for scene in scenes:
            video_path = os.path.join(self.instance_data_root, '480_720', scene)
            video_names = os.listdir(video_path)
            locations_path = os.path.join(video_path, "location_data.json")
            with open(locations_path, 'r') as f:
                locations = json.load(f)
            locations_info = {entry['name']: entry for entry in locations}
            for video_name in video_names:
                if video_name.endswith('Hemi12_1'):
                    if scene != 'HDRI':
                        location = scene_location_pair[scene]
                    else:
                        location = scene_location_pair['HDRI'][video_name.split('_')[1]]
                    self.video_names.append((scene, video_name, location, locations_info))

        # ----------------------------------- Internal Dataset -----------------------------------
        # scenes = ['AsianTown_480_720', 'Desert_480_720', 'HDRI_480_720', 'Forest_480_720']
        # scene_location_pair = {
        #     'AsianTown_480_720' : 'asian town',
        #     'Desert_480_720' : 'desert',
        #     'Forest_480_720' : 'crossland',
        #     'MatrixCity' : 'city',
        #     'HDRI_480_720' :
        #         {
        #             'loc1' : 'snowy street',
        #             'loc2' : 'park',
        #             'loc3' : 'indoor open space',
        #             'loc11' : 'gymnastics room',
        #             'loc13' : 'autumn forest',
        #         }
        # }
        # for scene in scenes:
        #     video_path = os.path.join(self.instance_data_root, scene)
        #     video_names = os.listdir(video_path)
        #     locations_path = os.path.join(video_path, "location_data.json")
        #     with open(locations_path, 'r') as f: locations = json.load(f)
        #     locations_info = {locations[idx]['name']:locations[idx] for idx in range(len(locations))}
        #     for video_name in video_names:
        #         if video_name.endswith('Hemi12_1') == True:
        #             if scene != 'HDRI_480_720':
        #                 location = scene_location_pair[scene]
        #             else:
        #                 location = scene_location_pair['HDRI_480_720'][video_name.split('_')[1]]
        #             self.video_names.append((scene, video_name, location, locations_info))

        self.cam_num = 12
        self.max_objs_num = 3
        self.length = len(self.video_names)

        self.captions_path = os.path.join(self.instance_data_root, "CharacterInfo.json")
        with open(self.captions_path, 'r') as f:
            captions = json.load(f)['CharacterInfo']
        self.captions_info = {int(entry['index']): entry['eng'] for entry in captions}

        self.cams_path = os.path.join(self.instance_data_root, "Hemi12_transforms.json")
        with open(self.cams_path, 'r') as f:
            self.cams_info = json.load(f)

        cam_poses = []
        for key in self.cams_info.keys():
            if "C_" in key:
                cam_poses.append(parse_matrix(self.cams_info[key]))
        cam_poses = np.stack(cam_poses)
        # Same convention as load_sceneposes: transpose each matrix, permute its
        # columns to [1, 2, 0, 3] and divide the translation column by 100
        # (presumably a unit conversion; confirm against the data export).
        cam_poses = np.transpose(cam_poses, (0,2,1))
        cam_poses = cam_poses[:,:,[1,2,0,3]]
        cam_poses[:,:3,3] /= 100.
        self.cam_poses = cam_poses

        # Number of frames returned per sample (fixed; independent of `max_num_frames`).
        self.sample_n_frames = 49

    def __len__(self):
        """Return the number of indexed clips."""
        return self.length

    def save_images2video(self, images, video_name):
        """Encode ``images`` as an H.264 MP4 at 8 fps and write it to ``<video_name>.mp4``.

        Args:
            images: Indexable sequence of frames accepted by imageio's ``append_data``.
            video_name: Output path without the ``.mp4`` extension.
        """
        fps = 8
        format = "mp4"
        codec = "libx264"
        ffmpeg_params = ["-crf", str(12)]
        pixelformat = "yuv420p"

        # Encode into memory first; the writer must be closed before the bytes are complete.
        video_stream = BytesIO()
        with imageio.get_writer(
            video_stream,
            fps=fps,
            format=format,
            codec=codec,
            ffmpeg_params=ffmpeg_params,
            pixelformat=pixelformat,
        ) as writer:
            for idx in range(len(images)):
                writer.append_data(images[idx])

        video_data = video_stream.getvalue()
        output_path = video_name + ".mp4"
        with open(output_path, "wb") as f:
            f.write(video_data)

    def __getitem__(self, idx):
        """Load one training sample, retrying with a random index on any failure.

        A camera is chosen at random, ``sample_n_frames`` frames are sampled from the
        clip (played backwards with probability 0.5) and a caption of the form
        ``"<obj> and <obj> are moving in the <location>"`` is built.

        When loading fails, the offending sample is appended to ``invalid_scene.txt``
        in the current working directory and another random sample is tried.

        Returns:
            dict with keys
            ``"prompt"`` (str), ``"video"`` (float tensor laid out as ``f c h w``,
            resized and normalised by ``pixel_transforms``) and ``"video_name"``
            (clip name, with ``"_inv"`` appended when the clip was reversed).
        """
        while True:
            try:
                (scene, video_name, location, locations_info) = self.video_names[idx]

                objs_json_path = os.path.join(self.instance_data_root, '480_720', scene, video_name, video_name+'.json')
                with open(objs_json_path, 'r') as f:
                    objs_file = json.load(f)
                objs_num = len(objs_file['0'])

                # random.randint is inclusive on both ends, so this draws 1..cam_num-1.
                # NOTE(review): camera `cam_num` is therefore never sampled; confirm this is intended.
                video_index = random.randint(1, self.cam_num-1)

                location_name = video_name.split('_')[1]
                location_info = locations_info[location_name]
                cam_pose = self.cam_poses[video_index-1]
                obj_transl = location_info['coordinates']['CameraTarget']['position']

                # Loop-invariant: the same inverse camera matrix is applied to every object.
                cam_pose_inv = np.linalg.inv(cam_pose)

                video_caption_concat = ''
                video_caption_list = []
                obj_poses_list = []
                for obj_idx in range(objs_num):
                    obj_name_index = objs_file['0'][obj_idx]['index']
                    video_caption = self.captions_info[obj_name_index]

                    # Normalise the caption: drop one leading space / trailing period, lower-case.
                    if video_caption.startswith(" "):
                        video_caption = video_caption[1:]
                    if video_caption.endswith("."):
                        video_caption = video_caption[:-1]
                    video_caption = video_caption.lower()
                    video_caption_list.append(video_caption)

                    obj_poses = self.load_sceneposes(objs_file, obj_idx, obj_transl)
                    obj_poses = cam_pose_inv @ obj_poses
                    obj_poses_list.append(obj_poses)

                # Join captions with " and " and finish with the location phrase.
                for obj_idx in range(objs_num):
                    video_caption = video_caption_list[obj_idx]
                    if obj_idx == objs_num - 1:
                        if objs_num == 1:
                            video_caption_concat += video_caption + ' is moving in the ' + location
                        else:
                            video_caption_concat += video_caption + ' are moving in the ' + location
                    else:
                        video_caption_concat += video_caption + ' and '

                obj_poses_all = torch.from_numpy(np.array(obj_poses_list))

                # Pick a window of `cropped_length` source frames and sample
                # `sample_n_frames` evenly spaced indices from it.
                total_frames = 99
                current_sample_stride = 1.75
                cropped_length = int(self.sample_n_frames * current_sample_stride)
                start_frame_ind = random.randint(10, max(10, total_frames - cropped_length - 1))
                end_frame_ind = min(start_frame_ind + cropped_length, total_frames)
                frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)

                video_frames_path = os.path.join(self.instance_data_root, '480_720', scene, video_name, 'videos', video_name+ f'_C_{video_index:02d}_35mm.mp4')

                # OpenCV is only used to query the native frame size; release the
                # handle right away instead of leaving it open.
                cap = cv2.VideoCapture(video_frames_path)
                try:
                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                finally:
                    cap.release()

                # Decode on CPU with decord.
                ctx = decord.cpu(0)
                reader = decord.VideoReader(video_frames_path, ctx=ctx, height=height, width=width)
                assert len(reader) == total_frames or len(reader) == total_frames+1
                frame_indexes = list(range(total_frames))
                try:
                    video_chunk = reader.get_batch(frame_indexes).asnumpy()
                except AttributeError:
                    # The returned batch has no `asnumpy` (e.g. it is a torch tensor); use `.numpy()`.
                    video_chunk = reader.get_batch(frame_indexes).numpy()

                # Temporal augmentation: reverse the clip half of the time.
                frame_inverse = torch.rand(1).item() > 0.5
                if frame_inverse:
                    frame_indices = frame_indices[::-1]
                    video_name += '_inv'

                pixel_values = video_chunk[frame_indices]
                pixel_values = rearrange(torch.from_numpy(pixel_values) / 255.0, "f h w c -> f c h w")
                pixel_values = self.pixel_transforms[0](pixel_values)
                pixel_values = self.pixel_transforms[1](pixel_values)

                # interpolation
                # Keep the first frame index, then one index per group of four frames
                # (the integer midpoint of frames 4k+1 and 4k+4).
                trunc_frame_indices = np.zeros_like(frame_indices[::4])
                trunc_frame_indices[0] = frame_indices[0]
                trunc_frame_indices[1:] = ((frame_indices[1:][::4] + frame_indices[4:][::4]) / 2).astype(np.int64)
                obj_poses_all = obj_poses_all[:, trunc_frame_indices]
                # Computed but not part of the returned dict in this script; kept so
                # that samples with malformed pose data still fail and are retried.
                pose_embeds = rearrange(obj_poses_all[:, :, :3, :], "b f p q -> b f (q p)").contiguous()
                break

            except Exception:
                # Record the failing sample and retry with a random index.
                (scene, video_name, location, locations_info) = self.video_names[idx]
                with open('invalid_scene.txt', 'a+') as f:
                    f.write(f'{scene} {video_name} {location}')
                    f.write('\n')
                idx = random.randint(0, self.length - 1)

        return {
            "prompt": video_caption_concat,
            "video": pixel_values,
            "video_name": video_name,
        }

    def load_sceneposes(self, objs_file, obj_idx, obj_transl):
        """Build the stacked pose matrices of one object over all frames.

        Args:
            objs_file: Mapping from frame key to the list of per-object records,
                each carrying a ``'matrix'`` string understood by ``parse_matrix``.
            obj_idx: Position of the object inside each frame's record list.
            obj_transl: Translation subtracted from every pose (broadcast against
                the first three entries of the last column).

        Returns:
            ``np.ndarray`` with one matrix per frame: transposed, translated by
            ``-obj_transl``, translation divided by 100, columns permuted to
            ``[1, 2, 0, 3]``.
        """
        ext_poses = []
        for key in objs_file.keys():
            ext_poses.append(parse_matrix(objs_file[key][obj_idx]['matrix']))
        ext_poses = np.stack(ext_poses)
        ext_poses = np.transpose(ext_poses, (0,2,1))
        ext_poses[:,:3,3] -= obj_transl
        ext_poses[:,:3,3] /= 100.
        ext_poses = ext_poses[:, :, [1,2,0,3]]
        return ext_poses
def save_model_card(
    repo_id: str,
    videos=None,
    base_model: str = None,
    validation_prompt=None,
    repo_folder=None,
    fps=8,
):
    """Export validation videos and write the model card ``README.md`` into ``repo_folder``.

    Args:
        repo_id: Repository identifier used in the card text and passed to
            ``load_or_create_model_card``.
        videos: Optional iterable of videos; each one is exported to
            ``<repo_folder>/final_video_<i>.mp4`` and listed as a card widget.
        base_model: Name of the base model the LoRA weights were trained for.
        validation_prompt: Prompt shown next to each widget video and in the usage snippet.
        repo_folder: Folder that receives the videos and ``README.md``.
        fps: Frame rate used when exporting the videos.
    """
    widget_dict = []
    if videos is not None:
        for i, video in enumerate(videos):
            video_filename = f"final_video_{i}.mp4"
            # `fps` belongs to export_to_video; it used to be passed to os.path.join,
            # which raised a TypeError whenever videos were provided.
            export_to_video(video, os.path.join(repo_folder, video_filename), fps=fps)
            widget_dict.append(
                {
                    "text": validation_prompt if validation_prompt else " ",
                    # Must match the file exported above so the widget link resolves.
                    "output": {"url": video_filename},
                }
            )

    model_description = f"""
# CogVideoX LoRA - {repo_id}
## Model description
These are {repo_id} LoRA weights for {base_model}.
The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
Was LoRA for the text encoder enabled? No.
## Download model
[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import CogVideoXPipeline
import torch
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"])
# The LoRA adapter weights are determined by what was used for training.
# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
# It can be made lower or higher from what was used in training to decrease or amplify the effect
# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
pipe.set_adapters(["cogvideox-lora"], [32 / 64])
video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
```
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
## License
Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="other",
        base_model=base_model,
        prompt=validation_prompt,
        model_description=model_description,
        widget=widget_dict,
    )
    tags = [
        "text-to-video",
        "diffusers-training",
        "diffusers",
        "lora",
        "cogvideox",
        "cogvideox-diffusers",
        "template:sd-lora",
    ]

    model_card = populate_model_card(model_card, tags=tags)
    model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(
    pipe,
    args,
    accelerator,
    pipeline_args,
    epoch,
    is_final_validation: bool = False,
):
    """Generate validation videos with ``pipe`` and log them to any wandb tracker.

    Args:
        pipe: Pipeline to sample from. Its scheduler is replaced by a
            ``CogVideoXDPMScheduler`` built from the current scheduler config, and the
            pipeline is moved to ``accelerator.device``.
        args: Parsed training arguments; reads ``num_validation_videos``, ``seed`` and
            ``output_dir``.
        accelerator: Provides the target device and the list of trackers.
        pipeline_args: Keyword arguments forwarded to ``pipe``; must contain ``"prompt"``.
        epoch: Accepted for call-site compatibility; not used by this function.
        is_final_validation: Selects the log key / filename prefix: ``"test"`` when
            true, ``"validation"`` otherwise.

    Returns:
        The list of generated videos (one entry per ``args.num_validation_videos``).
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
    )
    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
    scheduler_args = {}

    if "variance_type" in pipe.scheduler.config:
        variance_type = pipe.scheduler.config.variance_type

        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"

        scheduler_args["variance_type"] = variance_type

    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
    pipe = pipe.to(accelerator.device)
    # pipe.set_progress_bar_config(disable=True)

    # run inference
    # Compare against None explicitly: a seed of 0 is a valid seed and must still
    # produce a seeded generator.
    generator = (
        torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
    )

    videos = []
    for _ in range(args.num_validation_videos):
        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
        videos.append(video)

    phase_name = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            # Build a filesystem-friendly tag from the first 25 characters of the prompt.
            prompt = pipeline_args["prompt"][:25]
            for unsafe_char in (" ", "'", '"', "/"):
                prompt = prompt.replace(unsafe_char, "_")

            video_filenames = []
            for i, video in enumerate(videos):
                filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
                export_to_video(video, filename, fps=8)
                video_filenames.append(filename)

            tracker.log(
                {
                    phase_name: [
                        wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
                        for i, filename in enumerate(video_filenames)
                    ]
                }
            )

    free_memory()

    return videos
def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    """Encode prompt(s) with the text encoder and tile the result per requested video.

    When ``tokenizer`` is given, the prompts are tokenized (padded / truncated to
    ``max_sequence_length``) and any ``text_input_ids`` argument is overwritten.
    When ``tokenizer`` is ``None``, pre-computed ``text_input_ids`` are required.

    Returns:
        The first output of ``text_encoder`` cast to ``dtype`` on ``device``, repeated
        and reshaped to ``(len(prompt) * num_videos_per_prompt, seq_len, -1)``.

    Raises:
        ValueError: if both ``tokenizer`` and ``text_input_ids`` are ``None``.
    """
    prompt_list = [prompt] if isinstance(prompt, str) else prompt
    num_prompts = len(prompt_list)

    if tokenizer is None:
        if text_input_ids is None:
            raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
    else:
        tokenized = tokenizer(
            prompt_list,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = tokenized.input_ids

    encoder_output = text_encoder(text_input_ids.to(device))[0]
    embeds = encoder_output.to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _batch, seq_len, _width = embeds.shape
    tiled = embeds.repeat(1, num_videos_per_prompt, 1)
    return tiled.view(num_prompts * num_videos_per_prompt, seq_len, -1)
def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    """Normalize ``prompt`` to a list and delegate to ``_get_t5_prompt_embeds``.

    All other arguments are forwarded unchanged; the return value is whatever
    ``_get_t5_prompt_embeds`` returns.
    """
    prompt_list = prompt if not isinstance(prompt, str) else [prompt]
    return _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt_list,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )
def compute_prompt_embeddings(
    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
    """Encode ``prompt`` via ``encode_prompt`` with one video per prompt.

    When ``requires_grad`` is false the encoding runs under ``torch.no_grad()``;
    when true it runs in whatever gradient mode is currently active.
    """

    def _run_encoder():
        return encode_prompt(
            tokenizer,
            text_encoder,
            prompt,
            num_videos_per_prompt=1,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    if not requires_grad:
        with torch.no_grad():
            return _run_encoder()
    return _run_encoder()
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the (cos, sin) 3D rotary embedding pair for a given video size.

    ``height``/``width`` and ``base_height``/``base_width`` are each floor-divided by
    ``vae_scale_factor_spatial * patch_size`` to obtain grid sizes. The crop region is
    computed by ``get_resize_crop_region_for_grid`` and passed, together with the grid
    size and ``num_frames``, to ``get_3d_rotary_pos_embed``. Both returned tensors are
    moved to ``device``.
    """
    downscale = vae_scale_factor_spatial * patch_size
    grid_hw = (height // downscale, width // downscale)
    base_grid_w = base_width // downscale
    base_grid_h = base_height // downscale

    crop_region = get_resize_crop_region_for_grid(grid_hw, base_grid_w, base_grid_h)
    cos_table, sin_table = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=crop_region,
        grid_size=grid_hw,
        temporal_size=num_frames,
    )

    return cos_table.to(device=device), sin_table.to(device=device)
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    """Create the optimizer selected by ``args.optimizer``.

    Args:
        args: Parsed training arguments. Reads ``optimizer``, ``use_8bit_adam``,
            ``learning_rate``, ``adam_beta1``, ``adam_beta2``, ``adam_epsilon``,
            ``adam_weight_decay`` and, for Prodigy, the ``prodigy_*`` options.
            ``args.optimizer`` is reset to ``"adamw"`` when it names an unsupported
            optimizer.
        params_to_optimize: Parameters or parameter groups handed to the optimizer.
            A per-group ``"lr"`` takes precedence over ``args.learning_rate``.
        use_deepspeed: When true, return an ``accelerate.utils.DummyOptim`` placeholder
            instead of a real optimizer.

    Returns:
        The constructed optimizer (or ``DummyOptim``).

    Raises:
        ImportError: if the optional ``bitsandbytes`` / ``prodigyopt`` package needed
            by the selected configuration is not installed.
    """
    # Use DeepSpeed optimizer
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy"]
    # Compare case-insensitively, matching the `.lower()` comparisons used below.
    if args.optimizer.lower() not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
        )
        args.optimizer = "adamw"

    optimizer_name = args.optimizer.lower()
    is_adam_family = optimizer_name in ["adam", "adamw"]

    # 8-bit Adam only applies to Adam/AdamW: warn when it is requested for anything else.
    if args.use_8bit_adam and not is_adam_family:
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
            f"set to {optimizer_name}"
        )

    # Only require bitsandbytes when the 8-bit variant will actually be used.
    use_8bit = bool(args.use_8bit_adam) and is_adam_family
    if use_8bit:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if optimizer_name == "adamw":
        optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW

        # `lr` is the default for groups that do not carry their own "lr" entry,
        # consistent with the DummyOptim and Prodigy branches.
        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif optimizer_name == "adam":
        optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam

        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif optimizer_name == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")

        optimizer_class = prodigyopt.Prodigy

        if args.learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
            )

        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
            use_bias_correction=args.prodigy_use_bias_correction,
            safeguard_warmup=args.prodigy_safeguard_warmup,
        )

    return optimizer
def main(args):
    """Run LoRA fine-tuning of the CogVideoX transformer end to end.

    Steps, in order: validate arguments, create the ``Accelerator``, load the
    tokenizer / text encoder / transformer / VAE / scheduler, attach a LoRA adapter to
    the transformer's attention projections, register save/load hooks, build the
    optimizer, dataset, dataloader and LR scheduler, optionally resume from a
    checkpoint, run the training loop with periodic checkpointing and validation, and
    finally save the LoRA weights, run a final validation pass and optionally push
    the result to the Hub.

    Args:
        args: Parsed command-line arguments. ``learning_rate``, ``max_train_steps``,
            ``num_train_epochs`` and ``resume_from_checkpoint`` may be modified in place.

    Raises:
        ValueError: for unsupported argument combinations (wandb together with
            ``hub_token``; bfloat16 on MPS).
        ImportError: if wandb reporting is requested but wandb is not installed.
    """
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
            ).repo_id

    # Prepare models and scheduler
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )

    text_encoder = T5EncoderModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )

    # CogVideoX-2b weights are stored in float16
    # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
    load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=args.variant,
    )

    vae = AutoencoderKLCogVideoX.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
    )

    scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    if args.enable_slicing:
        vae.enable_slicing()
    if args.enable_tiling:
        vae.enable_tiling()

    # We only train the additional adapter LoRA layers
    text_encoder.requires_grad_(False)
    transformer.requires_grad_(False)
    vae.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.state.deepspeed_plugin:
        # DeepSpeed is handling precision, use what's in the DeepSpeed config
        if (
            "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
        ):
            weight_dtype = torch.float16
        if (
            "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
        ):
            # A bf16 DeepSpeed config must map to bfloat16, not float16.
            weight_dtype = torch.bfloat16
    else:
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    text_encoder.to(accelerator.device, dtype=weight_dtype)
    transformer.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.lora_alpha,
        init_lora_weights=True,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    transformer.add_adapter(transformer_lora_config)

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None

            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

            CogVideoXPipeline.save_lora_weights(
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
            )

    def load_model_hook(models, input_dir):
        transformer_ = None

        while len(models) > 0:
            model = models.pop()

            if isinstance(model, type(unwrap_model(transformer))):
                transformer_ = model
            else:
                raise ValueError(f"Unexpected save model: {model.__class__}")

        lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)

        transformer_state_dict = {
            f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

        # Make sure the trainable params are in float32. This is again needed since the base models
        # are in `weight_dtype`. More details:
        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
        if args.mixed_precision == "fp16":
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params([transformer_])

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Make sure the trainable params are in float32.
    if args.mixed_precision == "fp16":
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params([transformer], dtype=torch.float32)

    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))

    # Optimization parameters
    transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
    params_to_optimize = [transformer_parameters_with_lr]

    use_deepspeed_optimizer = (
        accelerator.state.deepspeed_plugin is not None
        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
    )
    # The DummyScheduler placeholder is only valid when the DeepSpeed config itself
    # defines a scheduler; otherwise a real scheduler must be created below.
    use_deepspeed_scheduler = (
        accelerator.state.deepspeed_plugin is not None
        and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
    )

    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)

    # Dataset and DataLoader
    train_dataset = VideoDataset(
        instance_data_root=args.instance_data_root,
        dataset_name=args.dataset_name,
        dataset_config_name=args.dataset_config_name,
        height=args.height,
        width=args.width,
        fps=args.fps,
        max_num_frames=args.max_num_frames,
        skip_frames_start=args.skip_frames_start,
        skip_frames_end=args.skip_frames_end,
        cache_dir=args.cache_dir,
        id_token=args.id_token,
    )

    def encode_video(video):
        video = video.to(accelerator.device, dtype=vae.dtype)
        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
        latent_dist = vae.encode(video).latent_dist.sample()
        latent_dist = latent_dist * vae.config.scaling_factor
        return latent_dist

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        # collate_fn=collate_fn,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    if use_deepspeed_scheduler:
        from accelerate.utils import DummyScheduler

        lr_scheduler = DummyScheduler(
            name=args.lr_scheduler,
            optimizer=optimizer,
            total_num_steps=args.max_train_steps * accelerator.num_processes,
            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        )
    else:
        lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
            num_training_steps=args.max_train_steps * accelerator.num_processes,
            num_cycles=args.lr_num_cycles,
            power=args.lr_power,
        )

    # Prepare everything with our `accelerator`.
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_name = args.tracker_name or "cogvideox-lora"
        accelerator.init_trackers(tracker_name, config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])

    logger.info("***** Running training *****")
    logger.info(f" Num trainable parameters = {num_trainable_parameters}")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if not args.resume_from_checkpoint:
        initial_global_step = 0
    else:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)

    # For DeepSpeed training
    model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config

    # Keep `epoch` bound for the final validation call even if the loop below runs
    # zero iterations (e.g. when resuming from an already-finished run).
    epoch = first_epoch
    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()

        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]

            with accelerator.accumulate(models_to_accumulate):
                model_input = encode_video(batch["video"]).permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
                prompts = batch["prompt"]

                # encode prompts
                prompt_embeds = compute_prompt_embeddings(
                    tokenizer,
                    text_encoder,
                    prompts,
                    model_config.max_text_seq_length,
                    accelerator.device,
                    weight_dtype,
                    requires_grad=False,
                )

                # Sample noise that will be added to the latents
                noise = torch.randn_like(model_input)
                batch_size, num_frames, num_channels, height, width = model_input.shape

                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
                )
                timesteps = timesteps.long()

                # Prepare rotary embeds
                image_rotary_emb = (
                    prepare_rotary_positional_embeddings(
                        height=args.height,
                        width=args.width,
                        num_frames=num_frames,
                        vae_scale_factor_spatial=vae_scale_factor_spatial,
                        patch_size=model_config.patch_size,
                        attention_head_dim=model_config.attention_head_dim,
                        device=accelerator.device,
                    )
                    if model_config.use_rotary_positional_embeddings
                    else None
                )

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)

                # Predict the noise residual
                model_output = transformer(
                    hidden_states=noisy_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timesteps,
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False,
                )[0]
                model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)

                alphas_cumprod = scheduler.alphas_cumprod[timesteps]
                weights = 1 / (1 - alphas_cumprod)
                while len(weights.shape) < len(model_pred.shape):
                    weights = weights.unsqueeze(-1)

                target = model_input

                loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
                loss = loss.mean()
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = transformer.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                if accelerator.state.deepspeed_plugin is None:
                    optimizer.step()
                    optimizer.zero_grad()

                lr_scheduler.step()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
                # Create pipeline
                pipe = CogVideoXPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    transformer=unwrap_model(transformer),
                    text_encoder=unwrap_model(text_encoder),
                    vae=unwrap_model(vae),
                    scheduler=scheduler,
                    revision=args.revision,
                    variant=args.variant,
                    torch_dtype=weight_dtype,
                )

                validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
                for validation_prompt in validation_prompts:
                    pipeline_args = {
                        "prompt": validation_prompt,
                        "guidance_scale": args.guidance_scale,
                        "use_dynamic_cfg": args.use_dynamic_cfg,
                        "height": args.height,
                        "width": args.width,
                    }

                    validation_outputs = log_validation(
                        pipe=pipe,
                        args=args,
                        accelerator=accelerator,
                        pipeline_args=pipeline_args,
                        epoch=epoch,
                    )

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)
        dtype = (
            torch.float16
            if args.mixed_precision == "fp16"
            else torch.bfloat16
            if args.mixed_precision == "bf16"
            else torch.float32
        )
        transformer = transformer.to(dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)

        CogVideoXPipeline.save_lora_weights(
            save_directory=args.output_dir,
            transformer_lora_layers=transformer_lora_layers,
        )

        # Final test inference
        pipe = CogVideoXPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )
        pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)

        if args.enable_slicing:
            pipe.vae.enable_slicing()
        if args.enable_tiling:
            pipe.vae.enable_tiling()

        # Load LoRA weights
        lora_scaling = args.lora_alpha / args.rank
        pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
        pipe.set_adapters(["cogvideox-lora"], [lora_scaling])

        # Run inference
        validation_outputs = []
        if args.validation_prompt and args.num_validation_videos > 0:
            validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
            for validation_prompt in validation_prompts:
                pipeline_args = {
                    "prompt": validation_prompt,
                    "guidance_scale": args.guidance_scale,
                    "use_dynamic_cfg": args.use_dynamic_cfg,
                    "height": args.height,
                    "width": args.width,
                }

                video = log_validation(
                    pipe=pipe,
                    args=args,
                    accelerator=accelerator,
                    pipeline_args=pipeline_args,
                    epoch=epoch,
                    is_final_validation=True,
                )
                validation_outputs.extend(video)

        if args.push_to_hub:
            save_model_card(
                repo_id,
                videos=validation_outputs,
                base_model=args.pretrained_model_name_or_path,
                validation_prompt=args.validation_prompt,
                repo_folder=args.output_dir,
                fps=args.fps,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
# Script entry point: parse the command-line arguments and start training.
if __name__ == "__main__":
    args = get_args()
    main(args)
================================================
FILE: CogVideo/inference/3dtrajmaster_inference.py
================================================
"""
Adapted from CogVideoX-5B: https://github.com/THUDM/CogVideo by Xiao Fu (CUHK)
"""
import argparse
from typing import Literal
import copy
import torch
from diffusers import (
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXVideoToVideoPipeline,
)
import os
import json
import datetime
import sys
sys.path.append('../finetune')
from models.pipeline_cogvideox import CogVideoXPipeline
from models.cogvideox_transformer_3d import CogVideoXTransformer3DModel
from diffusers.utils import export_to_video, load_image, load_video
import json
import numpy as np
import random
from einops import rearrange
def parse_matrix(matrix_str):
    """Parse a string such as ``"[a b c] [d e f]"`` into a 2-D float array.

    The string is split into rows on the literal separator ``'] ['``; remaining
    brackets are stripped and each row is split on whitespace and converted to floats.
    """
    row_texts = matrix_str.strip().split('] [')
    parsed_rows = [
        [float(token) for token in text.replace('[', '').replace(']', '').split()]
        for text in row_texts
    ]
    return np.array(parsed_rows)
def load_sceneposes(objs_file, obj_idx, obj_transl):
    """Collect one object's pose matrix for every entry of ``objs_file``.

    For each key of ``objs_file`` the ``'matrix'`` string of entry ``obj_idx`` is
    parsed with ``parse_matrix``. The stacked matrices are transposed per entry,
    ``obj_transl`` is subtracted from rows 0-2 of column 3 and the result divided by
    100, and finally the columns are reordered as ``[1, 2, 0, 3]``.
    """
    parsed = [parse_matrix(entry[obj_idx]['matrix']) for entry in objs_file.values()]
    poses = np.stack(parsed).transpose(0, 2, 1)

    # Re-centre and rescale the last-column entries of the first three rows.
    poses[:, :3, 3] = (poses[:, :3, 3] - obj_transl) / 100.

    column_order = [1, 2, 0, 3]
    return poses[:, :, column_order]
def get_pose_embeds(scene, video_name, instance_data_root, locations_info, cam_poses):
    """Load per-object poses for one video and flatten them into pose embeddings.

    Reads ``<instance_data_root>/480_720/<scene>/<video_name>/<video_name>.json``,
    loads every object's poses with ``load_sceneposes``, left-multiplies them by the
    inverse of ``cam_poses[11]`` (the hard-coded ``video_index`` 12), selects 13 frame
    indices derived from a fixed 49-frame sampling schedule, and flattens the top three
    rows of each selected pose.

    Returns:
        A contiguous ``torch.bfloat16`` tensor produced by
        ``rearrange(..., "n f p q -> n f (q p)")`` over the selected poses.
    """
    json_path = os.path.join(instance_data_root, "480_720", scene, video_name, video_name + '.json')
    with open(json_path, 'r') as f:
        objs_file = json.load(f)

    objs_num = len(objs_file['0'])
    video_index = 12

    # The location key is the second underscore-separated token of the video name.
    location_name = video_name.split('_')[1]
    location_info = locations_info[location_name]
    cam_pose = cam_poses[video_index - 1]
    obj_transl = location_info['coordinates']['CameraTarget']['position']

    per_object_poses = [
        np.linalg.inv(cam_pose) @ load_sceneposes(objs_file, obj_idx, obj_transl)
        for obj_idx in range(objs_num)
    ]
    obj_poses_all = torch.from_numpy(np.array(per_object_poses))

    # Fixed sampling schedule used to pick frames.
    total_frames, sample_n_frames = 99, 49
    current_sample_stride = 1.75
    start_frame_ind = 10

    cropped_length = int(sample_n_frames * current_sample_stride)
    end_frame_ind = min(start_frame_ind + cropped_length, total_frames)
    frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, sample_n_frames, dtype=int)

    # interpolation: keep the first index, then the integer midpoint of index pairs
    # taken four positions apart.
    trunc_frame_indices = np.zeros_like(frame_indices[::4])
    trunc_frame_indices[0] = frame_indices[0]
    pair_midpoints = (frame_indices[1:][::4] + frame_indices[4:][::4]) / 2
    trunc_frame_indices[1:] = pair_midpoints.astype(np.int64)

    selected = obj_poses_all[:, trunc_frame_indices]
    flattened = rearrange(selected[:, :, :3, :], "n f p q -> n f (q p)")
    return flattened.contiguous().to(torch.bfloat16)
def init_cam_poses(instance_data_root):
    """Load camera matrices from ``<instance_data_root>/Hemi12_transforms.json``.

    Every entry whose key contains ``"C_"`` is parsed with ``parse_matrix``. The
    stacked matrices are transposed per entry, their columns reordered as
    ``[1, 2, 0, 3]``, and rows 0-2 of the last column divided by 100.
    """
    cam_num = 12
    transforms_path = os.path.join(instance_data_root, "Hemi12_transforms.json")
    with open(transforms_path, 'r') as f:
        cams_info = json.load(f)

    parsed = [parse_matrix(value) for key, value in cams_info.items() if "C_" in key]

    poses = np.stack(parsed).transpose(0, 2, 1)
    poses = poses[:, :, [1, 2, 0, 3]]
    poses[:, :3, 3] /= 100.
    return poses
def _compose_prompt(entity_prompts, location):
    """Join the entity captions into the single text prompt passed to the pipeline."""
    if not entity_prompts:
        return ""
    verb = ' is' if len(entity_prompts) == 1 else ' are'
    return ' and '.join(entity_prompts) + verb + ' moving in the ' + location


def _compose_video_filename(video_name, location, entity_prompts):
    """Build the output ``.mp4`` file name from the sample's metadata."""
    parts = [str(len(entity_prompts)), video_name, location]
    # Captions are truncated to 30 characters and made filename-friendly.
    parts.extend(caption[:30].replace(' ', '_') for caption in entity_prompts)
    return '_'.join(parts) + '.mp4'


def generate_video(
    model_path: str,
    ckpt_path: str,
    lora_path: str = None,
    lora_scale: float = 1.0,
    output_path: str = "./output.mp4",
    image_or_video_path: str = "",
    annealed_sample_step: int = 15,
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    dtype: torch.dtype = torch.bfloat16,
    generate_type: Literal["t2v", "i2v", "v2v"] = "t2v",  # i2v: image to video, v2v: video to video
    seed: int = 42,
    instance_data_root: str = "/m2v_intern/fuxiao/360Motion-Dataset",
    scene: str = "Desert",
    test_sets_path: str = "./test_sets.json",
):
    """
    Generates one video per entry of the test-set file and saves each of them,
    together with a ``.txt`` sidecar listing its prompts, into ``output_path``.

    Parameters:
    - model_path (str): The path of the pre-trained model to be used.
    - ckpt_path (str): The path of the pre-trained transformer to be used.
    - lora_path (str): The path of the LoRA weights to be used (optional).
    - lora_scale (float): Scale passed to ``fuse_lora`` when ``lora_path`` is given.
    - output_path (str): Directory where the generated videos will be saved.
    - image_or_video_path (str): Unused by this function.
    - annealed_sample_step (int): Forwarded to the pipeline as ``annealed_sample_step``.
    - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
    - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
    - num_videos_per_prompt (int): Number of videos to generate per prompt.
    - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
    - generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v'). Unused by this function.
    - seed (int): The seed for reproducibility.
    - instance_data_root (str): Root directory of the pose dataset.
    - scene (str): Scene sub-directory under ``<instance_data_root>/480_720``.
    - test_sets_path (str): JSON file mapping ``"0"``, ``"1"``, ... to evaluation entries with
      the keys ``entity_prompts``, ``loc_prompt`` and ``video_name``.
    """
    # 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
    # add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload()
    # function to use Multi GPUs.
    transformer = CogVideoXTransformer3DModel.from_pretrained(ckpt_path, torch_dtype=dtype)
    pipe = CogVideoXPipeline.from_pretrained(model_path,
        transformer=transformer,
        torch_dtype=dtype
    )
    # Keep a copy of the transformer taken before any LoRA weights are fused into it.
    pipe.transformer_ori = copy.deepcopy(pipe.transformer).to("cuda")

    # If you're using with lora, add this code
    if lora_path:
        pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="default")
        pipe.fuse_lora(components=['transformer'] ,lora_scale=lora_scale)

    # 2. Set Scheduler.
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    # 3. Enable CPU offload for the model.
    pipe.to("cuda")
    pipe.enable_sequential_cpu_offload()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    # 4. Load object poses, scene and object prompts
    locations_path = os.path.join(instance_data_root, "480_720", scene, "location_data.json")
    with open(locations_path, 'r') as f:
        locations = json.load(f)
    locations_info = {entry['name']: entry for entry in locations}
    cam_poses = init_cam_poses(instance_data_root)

    with open(test_sets_path, 'r') as f:
        test_sets = json.load(f)

    # Make sure the target directory exists even when called without the CLI wrapper.
    os.makedirs(output_path, exist_ok=True)

    for idx in range(len(test_sets)):
        eval_set = test_sets[str(idx)]
        video_caption_list = eval_set['entity_prompts']
        location = eval_set['loc_prompt']
        video_name = eval_set['video_name']

        pose_embeds = get_pose_embeds(scene, video_name, instance_data_root, locations_info, cam_poses)
        prompt = _compose_prompt(video_caption_list, location)

        # 5. Generate the video frames based on the prompt.
        video_generate = pipe(
            prompt=prompt,
            prompts_list=video_caption_list,
            pose_embeds=pose_embeds[None],
            num_videos_per_prompt=num_videos_per_prompt,
            annealed_sample_step=annealed_sample_step,
            num_inference_steps=num_inference_steps,
            num_frames=49,
            use_dynamic_cfg=True,
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
        ).frames[0]

        # 6. Export the generated frames to a video file. fps must be 8 for original video.
        save_video_name = _compose_video_filename(video_name, location, video_caption_list)
        save_video_path = os.path.join(output_path, save_video_name)
        export_to_video(video_generate, save_video_path, fps=8)

        # Swap only the file extension: a plain str.replace('.mp4', '.txt') would also
        # rewrite an output directory whose name contains '.mp4' (e.g. the default).
        sidecar_path = os.path.splitext(save_video_path)[0] + '.txt'
        with open(sidecar_path, 'a+') as f:
            f.write(video_name)
            f.write('\n')
            for video_caption in video_caption_list:
                f.write(video_caption)
                f.write('\n')
            f.write(location)
            f.write('\n')
# Command-line entry point: parse the options, make sure the output directory exists
# and run generation for every entry of the test set.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
    parser.add_argument(
        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
    )
    parser.add_argument(
        "--ckpt_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained transformer to be used"
    )
    parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
    parser.add_argument("--lora_scale", type=float, default=1.0)
    # The value is used as a directory (see os.makedirs below), so describe it as one.
    parser.add_argument(
        "--output_path", type=str, default="./output.mp4", help="The directory where the generated videos will be saved"
    )
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
    parser.add_argument(
        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
    )
    parser.add_argument(
        "--annealed_sample_step", type=int, default=15, help="Annealed sampling step forwarded to the pipeline"
    )
    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
    parser.add_argument(
        "--generate_type", type=str, default="t2v", help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')"
    )
    parser.add_argument(
        "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
    )
    parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")

    args = parser.parse_args()
    # Any value other than "float16" selects bfloat16.
    dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16

    os.makedirs(args.output_path, exist_ok=True)
    generate_video(
        model_path=args.model_path,
        ckpt_path=args.ckpt_path,
        lora_path=args.lora_path,
        lora_scale=args.lora_scale,
        output_path=args.output_path,
        annealed_sample_step=args.annealed_sample_step,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        num_videos_per_prompt=args.num_videos_per_prompt,
        dtype=dtype,
        generate_type=args.generate_type,
        seed=args.seed,
    )
================================================
FILE: CogVideo/inference/entity_zoo.txt
================================================
[
"a fire spirit with long, twisting flames resembling flowing red and orange hair, a bright yellow core",
"a gentle breeze with soft tendrils of pale blue mist resembling flowing fabric, delicate streaks of white vapor",
"an ember haze with flickering orange and red flames giving off a warm glow, smoldering dark red wisps",
"a luminous storm cloud with layers of deep gray and blue churning together, intermittent flashes of lightning",
"a storm entity with dark swirling clouds as a body, streaks of electric blue lightning shooting across it",
"a cloud creature with billowing white and gray plumes forming a soft, rounded body, wisps of darker fog",
"a foggy apparition with pale gray wisps drifting together in a soft, undefined form, tiny white sparkles",
"a man with short spiky brown hair, athletic build, a navy blue jacket, beige cargo pants, and black sneakers",
"a woman with long wavy blonde hair, petite figure, a red floral dress, white sandals, and a yellow shoulder bag",
"a man with a shaved head, broad shoulders, a gray graphic t-shirt, dark jeans, and brown leather boots",
"a woman with shoulder-length straight auburn hair, a slender figure, a green button-up blouse, black leggings, white sneakers",
"a man with messy black hair, tall frame, a plaid red and black shirt, faded blue jeans, and tan hiking boots",
"a man with medium-length straight brown hair, tall and slender, a gray crew-neck t-shirt, beige trousers, dark green sneakers",
"a woman with short curly black hair, slender build, a pink hoodie, light gray joggers, and blue sneakers",
"a man with short black wavy hair, lean figure, a green and yellow plaid shirt, dark brown pants, and black suede shoes",
"a man with curly black hair, muscular build, a dark green hoodie, gray joggers, and white running shoes",
"a woman with short blonde hair, slim athletic build, a red leather jacket, dark blue jeans, and white sneakers",
"a man with medium-length wavy brown hair, lean build, a black bomber jacket, olive green cargo pants, and brown hiking boots",
"a man with buzz-cut blonde hair, stocky build, a gray zip-up sweater, black shorts, and red basketball shoes",
"a woman with long straight black hair, toned build, a blue denim jacket, light gray leggings, and black slip-on shoes",
"a man with short curly red hair, average build, a black leather jacket, dark blue cargo pants, and white sneakers",
"a woman with shoulder-length wavy brown hair, slim build, a green parka, black leggings, and gray hiking boots",
"a man with short straight black hair, tall and lean build, a navy blue sweater, khaki shorts, and brown sandals",
"a woman with pixie-cut blonde hair, athletic build, a red windbreaker, blue ripped jeans, and black combat boots",
"a man with medium-length wavy gray hair, muscular build, a maroon t-shirt, beige chinos, and brown loafers",
"a woman with long curly black hair, average build, a purple hoodie, black athletic shorts, and white running shoes",
"a man with short spiky blonde hair, slim build, a black trench coat, blue jeans, and brown hiking shoes",
"a dog with a fluffy coat, wagging tail, and warm golden-brown fur, exuding a gentle and friendly charm",
"a tiger with vibrant orange and black stripes, piercing yellow eyes, and a powerful stance, exuding strength and grace",
"a giraffe with golden-yellow fur, long legs, a tall slender neck, and patches of brown spots, exuding elegance and calm",
"an alpaca with soft white wool, short legs, a thick neck, and a fluffy head of fur, radiating gentle charm",
"a zebra with black and white stripes, sturdy legs, a short neck, and a sleek mane running down its back",
"a deer with sleek tan fur, long slender legs, a graceful neck, and tiny antlers atop its head",
"a gazelle with light golden fur, long slender legs, a thin neck, and short, sharp horns, embodying elegance and agility",
"a horse with chestnut brown fur, muscular legs, a slim neck, and a flowing mane, exuding strength and grace",
"a sleek black panther with a smooth, glossy coat, emerald green eyes, and a powerful stance",
"a cheetah with golden fur covered in black spots, intense amber eyes, and a slender, agile body",
"a regal lion with a thick, flowing golden mane, sharp brown eyes, and a powerful muscular frame",
"a snow leopard with pale gray fur adorned with dark rosettes, icy blue eyes, and a stealthy, poised posture",
"a jaguar with a golden-yellow coat dotted with intricate black rosettes, deep green eyes, and a muscular build",
"a wolf with thick silver-gray fur, alert golden eyes, and a lean yet strong body, exuding confidence and boldness",
"a tiger with a pristine white coat marked by bold black stripes, bright blue eyes, and a graceful, poised form",
"a lynx with tufted ears, soft reddish-brown fur with faint spots, and intense yellow-green eyes",
"a bear with dark brown fur, small but fierce black eyes, and a broad and muscular build, radiating power",
"a swift fox with reddish-orange fur, a bushy tail tipped with white, and sharp, intelligent amber eyes",
"a falcon with blue-gray feathers, sharp talons, and keen yellow eyes fixed on its prey below",
"a fox with sleek russet fur, a bushy tail tipped with black, and bright green and cunning eyes",
"a kangaroo with brown fur, powerful hind legs, and a muscular tail, showcasing its strength and agility",
"a polar bear with thick white fur, strong paws, and a black nose, embodying the essence of the Arctic",
"a cheetah with a slender build, spotted golden fur, and sharp eyes, epitomizing speed and agility",
"a dolphin with sleek grey skin, a curved dorsal fin, and intelligent, playful eyes, reflecting its nature",
"a wolf with a body covered in thick silver fur, sharp ears, and piercing yellow eyes, showcasing its alertness",
"a leopard with a body covered in golden fur, dark rosettes, and a long muscular tail, emphasizing its strength",
"a penguin with a body covered in smooth black-and-white feathers, short wings, and webbed feet",
"a gazelle with a body covered in sleek tan fur, long legs, and elegant curved horns, showcasing its grace",
"a rabbit with a body covered in soft fur, quick hops, and a playful demeanor, showcasing its energy",
"a koala with a body covered in soft grey fur, large round ears, and a black nose, radiating cuteness",
"a rhinoceros with a body covered in thick grey skin, a massive horn on its snout, and sturdy legs",
"a flamingo with a body covered in pink feathers, long slender legs, and a gracefully curved neck",
"a parrot with bright red, blue, and yellow feathers, a curved beak, and sharp intelligent eyes",
"a hippopotamus with a body covered in thick grey-brown skin, massive jaws, and a large body",
"a crocodile with a body covered in scaly green skin, a powerful tail, and sharp teeth",
"a moose with a body covered in thick brown fur, massive antlers, and a bulky frame",
"a chameleon with a body covered in vibrant green scales, bulging eyes, and a curled tail, showcasing its unique charm",
"a lemur with a body covered in soft grey fur, a ringed tail, and wide yellow eyes, and curious expression",
"a squirrel with a body covered in bushy red fur, large eyes, and a fluffy tail",
"a panda with a body covered in fluffy black-and-white fur, a round face, and gentle eyes, radiating warmth",
"a porcupine with a body covered in spiky brown quills, a small nose, and curious eyes",
"a sedan with a sleek metallic silver body, long wheelbase, a low-profile hood, and a small rear spoiler",
"a private jet with a shiny silver body, elongated wings, a slim nose, and a compact rear stabilizer",
"an SUV with a matte black exterior, elevated suspension, a tall roofline, and a compact rear roof rack",
"a pickup truck with rugged dark green paint, extended cab, raised suspension, and a modest cargo bed cover",
"a vintage convertible with a body covered in shiny red paint, chrome bumpers, and a stylish design",
"a futuristic electric car with a minimalist silver design, slim LED lights, and smooth curves",
"a family minivan with a spacious interior, sliding doors, and a metallic blue exterior",
"a compact electric vehicle with a silver finish, aerodynamic profile, and efficient battery",
"a sporty roadster with a convertible top, silver trim, and a powerful engine",
"a retro coupe with a body covered in teal paint, round headlights, and a shiny chrome grille",
"a firefighting robot with a water cannon arm, heat sensors, and durable red-and-silver exterior",
"a companion robot with a friendly digital face, a smooth white exterior, and social interaction algorithms",
"an industrial welding robot with articulated arms, a laser precision welder, and heat-resistant shields",
"a surveillance drone robot with extendable camera arms, thermal vision, and a stealth black body",
"a disaster rescue robot with reinforced limbs, advanced AI, and a rugged body designed to navigate",
"an exploration rover robot with solar panels, durable wheels, and advanced sensors for planetary exploration",
"a fluttering butterfly with intricate wing patterns, vivid colors, and graceful flight",
]
================================================
FILE: CogVideo/inference/location_zoo.txt
================================================
[
'fjord',
'sunset beach',
'cave',
'snowy tundra',
'prairie',
'asian town',
'rainforest',
'canyon',
'savanna',
'urban rooftop garden',
'swamp',
'riverbank',
'coral reef',
'volcanic landscape',
'wind farm',
'town street',
'night city square',
'mall lobby',
'glacier',
'seaside street',
'gymnastics room',
'abandoned factory',
'autumn forest',
'mountain village',
'coastal harbor',
'ancient ruins',
'modern metropolis',
'desert',
'forest',
'city',
'snowy street',
'park',
]
================================================
FILE: CogVideo/pyproject.toml
================================================
[tool.ruff]
line-length = 119
[tool.ruff.lint]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "F402", "F823"]
select = ["C", "E", "F", "I", "W"]
# Ignore import violations in all `__init__.py` files.
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]
[tool.ruff.lint.isort]
lines-after-imports = 2
[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"
# Like Black, indent with spaces, rather than tabs.
indent-style = "space"
# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
================================================
FILE: CogVideo/requirements.txt
================================================
diffusers==0.31.0
accelerate==1.1.1
transformers==4.46.2
numpy==1.26.0
# torch==2.5.0
# torchvision==0.20.0
sentencepiece==0.2.0
SwissArmyTransformer==0.4.12
gradio==5.5.0
imageio==2.35.1
imageio-ffmpeg==0.5.1
openai==1.54.0
moviepy==1.0.3
scikit-video==1.1.11
opencv-python
peft==0.12.0
decord
wandb
================================================
FILE: CogVideo/tools/caption/README.md
================================================
# Video Caption
Typically, most video data does not come with corresponding descriptive text, so it is necessary to convert the video
data into textual descriptions to provide the essential training data for text-to-video models.
## Update and News
- 🔥🔥 **News**: ```2024/9/19```: The caption model used in the CogVideoX training process to convert video data into text
descriptions, [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), is now open-source. Feel
free to download and use it.
## Video Caption via CogVLM2-Caption
🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/)
CogVLM2-Caption is a video captioning model used to generate training data for the CogVideoX model.
### Install
```shell
pip install -r requirements.txt
```
### Usage
```shell
python video_caption.py
```
Example:
## Video Caption via CogVLM2-Video
[Code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
CogVLM2-Video is a versatile video understanding model equipped with timestamp-based question answering capabilities.
Users can input prompts such as `Please describe this video in detail.` to the model to obtain a detailed video caption:
Users can use the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) to load the model or configure a RESTful API to generate video captions.
## Citation
🌟 If you find our work helpful, please leave us a star and cite our paper.
CogVLM2-Caption:
```
@article{yang2024cogvideox,
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
journal={arXiv preprint arXiv:2408.06072},
year={2024}
}
```
CogVLM2-Video:
```
@article{hong2024cogvlm2,
title={CogVLM2: Visual Language Models for Image and Video Understanding},
author={Hong, Wenyi and Wang, Weihan and Ding, Ming and Yu, Wenmeng and Lv, Qingsong and Wang, Yan and Cheng, Yean and Huang, Shiyu and Ji, Junhui and Xue, Zhao and others},
journal={arXiv preprint arXiv:2408.16500},
year={2024}
}
```
================================================
FILE: CogVideo/tools/caption/README_ja.md
================================================
# ビデオキャプション
通常、ほとんどのビデオデータには対応する説明文が付いていないため、ビデオデータをテキストの説明に変換して、テキストからビデオへのモデルに必要なトレーニングデータを提供する必要があります。
## 更新とニュース
- 🔥🔥 **ニュース**: ```2024/9/19```:CogVideoX
のトレーニングプロセスで、ビデオデータをテキストに変換するためのキャプションモデル [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption)
がオープンソース化されました。ぜひダウンロードしてご利用ください。
## CogVLM2-Captionによるビデオキャプション
🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/)
CogVLM2-Captionは、CogVideoXモデルのトレーニングデータを生成するために使用されるビデオキャプションモデルです。
### インストール
```shell
pip install -r requirements.txt
```
### 使用方法
```shell
python video_caption.py
```
例:
ユーザーは提供された[コード](https://github.com/THUDM/CogVLM2/tree/main/video_demo)を使用してモデルをロードするか、RESTful API を構成してビデオキャプションを生成できます。
## Citation
🌟 If you find our work helpful, please leave us a star and cite our paper.
CogVLM2-Caption:
```
@article{yang2024cogvideox,
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
journal={arXiv preprint arXiv:2408.06072},
year={2024}
}
```
CogVLM2-Video:
```
@article{hong2024cogvlm2,
title={CogVLM2: Visual Language Models for Image and Video Understanding},
author={Hong, Wenyi and Wang, Weihan and Ding, Ming and Yu, Wenmeng and Lv, Qingsong and Wang, Yan and Cheng, Yean and Huang, Shiyu and Ji, Junhui and Xue, Zhao and others},
journal={arXiv preprint arXiv:2408.16500},
year={2024}
}
```
================================================
FILE: CogVideo/tools/caption/README_zh.md
================================================
# 视频Caption
通常,大多数视频数据不带有相应的描述性文本,因此需要将视频数据转换为文本描述,以提供必要的训练数据用于文本到视频模型。
## 项目更新
- 🔥🔥 **News**: ```2024/9/19```: CogVideoX 训练过程中用于将视频数据转换为文本描述的 Caption
模型 [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption)
已经开源。欢迎前往下载并使用。
## 通过 CogVLM2-Caption 模型生成视频Caption
🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/)
CogVLM2-Caption是用于生成CogVideoX模型训练数据的视频caption模型。
### 安装依赖
```shell
pip install -r requirements.txt
```
### 运行caption模型
```shell
python video_caption.py
```
示例:
## 通过 CogVLM2-Video 模型生成视频Caption
[Code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
CogVLM2-Video 是一个多功能的视频理解模型,具备基于时间戳的问题回答能力。用户可以输入诸如 `Describe this video in detail.` 的提示语给模型,以获得详细的视频Caption:
用户可以使用提供的[代码](https://github.com/THUDM/CogVLM2/tree/main/video_demo)加载模型或配置 RESTful API 来生成视频Caption。
## Citation
🌟 If you find our work helpful, please leave us a star and cite our paper.
CogVLM2-Caption:
```
@article{yang2024cogvideox,
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
author={Yang, Zhuoyi and Teng, Jiayan and Zheng, Wendi and Ding, Ming and Huang, Shiyu and Xu, Jiazheng and Yang, Yuanming and Hong, Wenyi and Zhang, Xiaohan and Feng, Guanyu and others},
journal={arXiv preprint arXiv:2408.06072},
year={2024}
}
```
CogVLM2-Video:
```
@article{hong2024cogvlm2,
title={CogVLM2: Visual Language Models for Image and Video Understanding},
author={Hong, Wenyi and Wang, Weihan and Ding, Ming and Yu, Wenmeng and Lv, Qingsong and Wang, Yan and Cheng, Yean and Huang, Shiyu and Ji, Junhui and Xue, Zhao and others},
journal={arXiv preprint arXiv:2408.16500},
year={2024}
}
```
================================================
FILE: CogVideo/tools/caption/requirements.txt
================================================
decord>=0.6.0
# Per https://download.pytorch.org/whl/torch/, the supported Python versions are 3.8 through 3.11
torch==2.1.0
torchvision== 0.16.0
pytorchvideo==0.1.5
xformers
transformers==4.42.4
#git+https://github.com/huggingface/transformers.git
huggingface-hub>=0.23.0
pillow
chainlit>=1.0
pydantic>=2.7.1
timm>=0.9.16
openai>=1.30.1
loguru>=0.7.2
pydantic>=2.7.1
einops
sse-starlette>=2.1.0
flask
gunicorn
gevent
requests
gradio
================================================
FILE: CogVideo/tools/caption/video_caption.py
================================================
import io
import argparse
import numpy as np
import torch
from decord import cpu, VideoReader, bridge
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face identifier of the caption model.
MODEL_PATH = "THUDM/cogvlm2-llama3-caption"

# Run on the GPU when one is available; bfloat16 is only selected on CUDA devices
# whose major compute capability is at least 8, otherwise float16 is used.
if torch.cuda.is_available():
    DEVICE = 'cuda'
    TORCH_TYPE = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
    DEVICE = 'cpu'
    TORCH_TYPE = torch.float16

parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
parser.add_argument(
    '--quant',
    type=int,
    choices=[4, 8],
    help='Enable 4-bit or 8-bit precision loading',
    default=0,
)
# An explicit empty argv is parsed, so real command-line flags are ignored and
# every option keeps its default.
args = parser.parse_args([])
def load_video(video_data, strategy='chat'):
    """Decode an in-memory video and return a tensor of sampled frames.

    Parameters:
    - video_data (bytes): Raw bytes of the encoded video file.
    - strategy (str): ``'base'`` samples 24 evenly spaced frames from the first
      60 seconds; ``'chat'`` takes, for each whole second, the frame whose
      timestamp is closest to it, stopping after 24 frames.

    Returns:
    - The selected frame batch with its axes permuted by ``(3, 0, 1, 2)``.
      NOTE(review): presumably (T, H, W, C) -> (C, T, H, W) — confirm against decord.

    Raises:
    - ValueError: if ``strategy`` is neither ``'base'`` nor ``'chat'``.
    """
    # Reject unknown strategies up front; otherwise no frame list would be built
    # and the failure would only surface inside decord.
    if strategy not in ('base', 'chat'):
        raise ValueError(f"Unsupported strategy {strategy!r}; expected 'base' or 'chat'")

    bridge.set_bridge('torch')
    num_frames = 24
    decord_vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
    total_frames = len(decord_vr)

    if strategy == 'base':
        clip_end_sec = 60
        clip_start_sec = 0
        start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
        end_frame = min(total_frames, int(clip_end_sec * decord_vr.get_avg_fps()))
        frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
    else:
        # Only the first element of each per-frame timestamp entry is used.
        timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
        timestamps = [i[0] for i in timestamps]
        max_second = round(max(timestamps)) + 1
        frame_id_list = []
        for second in range(max_second):
            closest_num = min(timestamps, key=lambda x: abs(x - second))
            index = timestamps.index(closest_num)
            frame_id_list.append(index)
            if len(frame_id_list) >= num_frames:
                break

    video_data = decord_vr.get_batch(frame_id_list)
    video_data = video_data.permute(3, 0, 1, 2)
    return video_data
# Tokenizer and model are loaded once, at import time, from MODEL_PATH.
# trust_remote_code=True lets transformers execute code shipped with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
)
# The model is created with dtype TORCH_TYPE, switched to eval mode and moved to DEVICE.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True
).eval().to(DEVICE)
def predict(prompt, video_data, temperature):
    """Generate a text response for ``prompt`` about the given video.

    Parameters:
    - prompt (str): The query sent to the model.
    - video_data (bytes): Raw bytes of the encoded video, decoded via ``load_video``.
    - temperature (float): Passed to ``model.generate`` as ``temperature``.

    Returns:
    - str: The decoded text of the newly generated tokens (the prompt tokens are cut off).
    """
    strategy = 'chat'

    video = load_video(video_data, strategy=strategy)

    history = []
    query = prompt
    inputs = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=query,
        images=[video],
        history=history,
        template_version=strategy
    )
    # Move the inputs to the device the model was placed on (DEVICE) rather than a
    # hard-coded 'cuda', so the two cannot end up on different devices.
    inputs = {
        'input_ids': inputs['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': inputs['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[inputs['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        "pad_token_id": 128002,
        "top_k": 1,
        "do_sample": False,
        "top_p": 0.1,
        "temperature": temperature,
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # Keep only the tokens produced after the prompt.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
def test():
    """Caption the local file ``test.mp4`` and print the model's response."""
    prompt = "Please describe this video in detail."
    temperature = 0.1
    # Use a context manager so the file handle is closed once the bytes are read.
    with open('test.mp4', 'rb') as video_file:
        video_data = video_file.read()
    response = predict(prompt, video_data, temperature)
    print(response)
# Run the single-video check above when this file is executed as a script.
if __name__ == '__main__':
    test()
================================================
FILE: CogVideo/tools/convert_weight_sat2hf.py
================================================
"""
This script demonstrates how to convert and generate video from a text prompt
using CogVideoX with 🤗Huggingface Diffusers Pipeline.
This script requires the `diffusers>=0.30.2` library to be installed.
Functions:
- reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
- reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place.
- reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place.
- remove_keys_inplace: Removes specified keys from the state_dict in-place.
- replace_up_keys_inplace: Replaces keys in the "up" block in-place.
- get_state_dict: Extracts the state_dict from a saved checkpoint.
- update_state_dict_inplace: Updates the state_dict with new key assignments in-place.
- convert_transformer: Converts a transformer checkpoint to the CogVideoX format.
- convert_vae: Converts a VAE checkpoint to the CogVideoX format.
- get_args: Parses command-line arguments for the script.
- generate_video: Generates a video from a text prompt using the CogVideoX pipeline.
"""
import argparse
from typing import Any, Dict
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDDIMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
)
def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
    """Split a fused ``query_key_value`` entry into separate ``to_q``/``to_k``/``to_v`` entries.

    The tensor stored under ``key`` is cut into three chunks along dim 0; each chunk
    is stored under ``key`` with ``query_key_value`` replaced by the matching name,
    and the fused entry is then removed. ``state_dict`` is modified in place.
    """
    q_part, k_part, v_part = torch.chunk(state_dict[key], chunks=3, dim=0)
    for target_name, tensor in (("to_q", q_part), ("to_k", k_part), ("to_v", v_part)):
        state_dict[key.replace("query_key_value", target_name)] = tensor
    del state_dict[key]
def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
    """Move a query/key layer-norm entry to its ``attn1.norm_q`` / ``attn1.norm_k`` name.

    The last two dot-separated parts of ``key`` are taken as the layer id and the
    ``weight``/``bias`` suffix. ``state_dict`` is modified in place.

    Raises:
    - ValueError: if ``key`` contains neither ``"query"`` nor ``"key"``; the dict is
      left untouched in that case.
    """
    layer_id, weight_or_bias = key.split(".")[-2:]
    if "query" in key:
        norm_name = "norm_q"
    elif "key" in key:
        norm_name = "norm_k"
    else:
        # Fail before popping so an unexpected key does not silently drop its tensor.
        raise ValueError(f"Key {key!r} is neither a query nor a key layer-norm entry")
    new_key = f"transformer_blocks.{layer_id}.attn1.{norm_name}.{weight_or_bias}"
    state_dict[new_key] = state_dict.pop(key)
def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
    """Split a fused adaLN modulation entry into ``norm1.linear`` and ``norm2.linear`` entries.

    The tensor under ``key`` is cut into 12 chunks along dim 0. Chunks 0-2 and 6-8
    are concatenated for ``norm1``, chunks 3-5 and 9-11 for ``norm2``. The layer id
    is the third-from-last dot-separated part of ``key`` and the suffix is the last
    part. ``state_dict`` is modified in place and the fused entry is removed.
    """
    layer_id, _, suffix = key.split(".")[-3:]
    pieces = state_dict[key].chunk(12, dim=0)

    norm1_tensor = torch.cat(pieces[:3] + pieces[6:9])
    norm2_tensor = torch.cat(pieces[3:6] + pieces[9:12])

    block_prefix = f"transformer_blocks.{layer_id}"
    state_dict[f"{block_prefix}.norm1.linear.{suffix}"] = norm1_tensor
    state_dict[f"{block_prefix}.norm2.linear.{suffix}"] = norm2_tensor
    del state_dict[key]
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
    """Delete the entry stored under ``key`` from ``state_dict`` (in place)."""
    del state_dict[key]
def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
    """Rename an ``up`` block entry to ``up_blocks`` with a mirrored layer index.

    The second dot-separated part of ``key`` becomes ``up_blocks`` and the third
    part, an integer index ``i``, becomes ``3 - i``. ``state_dict`` is modified in place.
    """
    segments = key.split(".")
    mirrored_index = 3 - int(segments[2])
    renamed = ".".join(segments[:1] + ["up_blocks", str(mirrored_index)] + segments[3:])
    state_dict[renamed] = state_dict.pop(key)
# Substring -> replacement pairs that convert_transformer applies to every checkpoint
# key, one after another in insertion order. The order matters: a later pattern sees
# the result of the earlier ones (e.g. "transformer.final_layernorm" must come before
# the broader "transformer").
TRANSFORMER_KEYS_RENAME_DICT = {
    "transformer.final_layernorm": "norm_final",
    "transformer": "transformer_blocks",
    "attention": "attn1",
    "mlp": "ff.net",
    "dense_h_to_4h": "0.proj",
    "dense_4h_to_h": "2",
    ".layers": "",
    "dense": "to_out.0",
    "input_layernorm": "norm1.norm",
    # NOTE(review): presumably spelled with "attn1" because "attention" has already
    # been rewritten by the entry above — confirm.
    "post_attn1_layernorm": "norm2.norm",
    "time_embed.0": "time_embedding.linear_1",
    "time_embed.2": "time_embedding.linear_2",
    "mixins.patch_embed": "patch_embed",
    "mixins.final_layer.norm_final": "norm_out.norm",
    "mixins.final_layer.linear": "proj_out",
    "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",  # Specific to CogVideoX-5b-I2V
}
# Substring -> handler pairs for transformer keys. convert_transformer calls the
# handler as handler(key, state_dict) for every (already renamed) key that contains
# the substring.
TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "query_key_value": reassign_query_key_value_inplace,
    "query_layernorm_list": reassign_query_key_layernorm_inplace,
    "key_layernorm_list": reassign_query_key_layernorm_inplace,
    "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
    "embed_tokens": remove_keys_inplace,
    "freqs_sin": remove_keys_inplace,
    "freqs_cos": remove_keys_inplace,
    "position_embedding": remove_keys_inplace,
}
# Substring -> replacement pairs that convert_vae applies to every VAE checkpoint
# key, in insertion order.
VAE_KEYS_RENAME_DICT = {
    "block.": "resnets.",
    "down.": "down_blocks.",
    "downsample": "downsamplers.0",
    "upsample": "upsamplers.0",
    "nin_shortcut": "conv_shortcut",
    "encoder.mid.block_1": "encoder.mid_block.resnets.0",
    "encoder.mid.block_2": "encoder.mid_block.resnets.1",
    "decoder.mid.block_1": "decoder.mid_block.resnets.0",
    "decoder.mid.block_2": "decoder.mid_block.resnets.1",
}
# Substring -> handler pairs for VAE keys.
# NOTE(review): presumably consumed by convert_vae the same way the transformer table
# is consumed by convert_transformer — confirm.
VAE_SPECIAL_KEYS_REMAP = {
    "loss": remove_keys_inplace,
    "up.": replace_up_keys_inplace,
}
# NOTE(review): presumably the maximum token length for the text tokenizer; its use
# is not visible in this part of the file.
TOKENIZER_MAX_LENGTH = 226
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return the innermost parameter mapping of a loaded checkpoint.

    Checkpoints may wrap the real state dict under "model", "module" and/or
    "state_dict". Each wrapper is peeled off in that order, looking at the level that
    has been reached so far rather than only at the top-level dict. This unwraps nested
    wrappers (``{"model": {"module": ...}}``) and no longer raises ``KeyError`` when two
    wrapper keys sit side by side at the top level.

    Args:
        saved_dict: object returned by ``torch.load`` for the checkpoint.

    Returns:
        The unwrapped mapping (``saved_dict`` itself when no wrapper key is present).
    """
    state_dict = saved_dict
    for wrapper_key in ("model", "module", "state_dict"):
        # Only dict-like levels can be inspected; anything else is returned as-is.
        if isinstance(state_dict, dict) and wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
    return state_dict
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    """Move the value stored under ``old_key`` to ``new_key`` (mutates ``state_dict``).

    Nothing is returned; the previous ``Dict`` return annotation did not match the code.

    Raises:
        KeyError: if ``old_key`` is not present.
    """
    state_dict[new_key] = state_dict.pop(old_key)
def convert_transformer(
    ckpt_path: str,
    num_layers: int,
    num_attention_heads: int,
    use_rotary_positional_embeddings: bool,
    i2v: bool,
    dtype: torch.dtype,
):
    """Load an original transformer checkpoint into a ``CogVideoXTransformer3DModel``.

    The checkpoint keys are rewritten in two passes: first the leading
    ``len("model.diffusion_model.")`` characters are sliced off and the substring renames
    from ``TRANSFORMER_KEYS_RENAME_DICT`` are applied in order, then every key containing
    a marker from ``TRANSFORMER_SPECIAL_KEYS_REMAP`` is passed to its in-place handler.
    The result is loaded with ``strict=True``.

    Args:
        ckpt_path: path handed to ``torch.load``.
        num_layers, num_attention_heads, use_rotary_positional_embeddings: forwarded to
            the model constructor.
        i2v: selects 32 input channels (16 otherwise) and learned positional embeddings.
        dtype: dtype the freshly built model is cast to before loading.

    Returns:
        The populated transformer model.
    """
    PREFIX_KEY = "model.diffusion_model."
    prefix_length = len(PREFIX_KEY)

    checkpoint = torch.load(ckpt_path, map_location="cpu", mmap=True)
    original_state_dict = get_state_dict(checkpoint)

    model_kwargs = dict(
        in_channels=32 if i2v else 16,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
        use_learned_positional_embeddings=i2v,
    )
    transformer = CogVideoXTransformer3DModel(**model_kwargs).to(dtype=dtype)

    # Pass 1: plain renames. Iterate over a snapshot because the dict is mutated.
    for old_name in list(original_state_dict.keys()):
        renamed = old_name[prefix_length:]
        for needle, replacement in TRANSFORMER_KEYS_RENAME_DICT.items():
            renamed = renamed.replace(needle, replacement)
        original_state_dict[renamed] = original_state_dict.pop(old_name)

    # Pass 2: keys that need a dedicated handler (split, re-home or delete).
    for name in list(original_state_dict.keys()):
        for marker, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if marker in name:
                handler_fn_inplace(name, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer
def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
    """Load an original VAE checkpoint into an ``AutoencoderKLCogVideoX``.

    Keys are first rewritten with the ordered substring renames in
    ``VAE_KEYS_RENAME_DICT``; afterwards every key containing a marker from
    ``VAE_SPECIAL_KEYS_REMAP`` is passed to its in-place handler. The result is loaded
    with ``strict=True``.

    Args:
        ckpt_path: path handed to ``torch.load``.
        scaling_factor: forwarded to the VAE constructor.
        dtype: dtype the freshly built VAE is cast to before loading.

    Returns:
        The populated VAE model.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu", mmap=True)
    original_state_dict = get_state_dict(checkpoint)
    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)

    # Pass 1: plain renames. Iterate over a snapshot because the dict is mutated.
    for old_name in list(original_state_dict.keys()):
        renamed = old_name
        for needle, replacement in VAE_KEYS_RENAME_DICT.items():
            renamed = renamed.replace(needle, replacement)
        original_state_dict[renamed] = original_state_dict.pop(old_name)

    # Pass 2: keys that need a dedicated handler.
    for name in list(original_state_dict.keys()):
        for marker, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if marker in name:
                handler_fn_inplace(name, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True)
    return vae
def get_args():
    """Parse the command-line options of the checkpoint conversion script.

    Returns:
        argparse.Namespace with the checkpoint paths, output path, dtype switches and
        the model hyper-parameters that differ between CogVideoX variants.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
    parser.add_argument(
        "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
    )
    parser.add_argument(
        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
    )
    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
    parser.add_argument(
        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
    )
    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
    # (help text previously repeated the VAE scaling-factor description)
    parser.add_argument(
        "--snr_shift_scale", type=float, default=3.0, help="SNR shift scale passed to the scheduler config"
    )
    # (help text previously repeated the --fp16 description)
    parser.add_argument(
        "--i2v",
        action="store_true",
        default=False,
        help="Whether to convert an image-to-video (I2V) checkpoint instead of a text-to-video one",
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Convert the requested components, assemble a diffusers pipeline and save it.
    args = get_args()

    # Either component may be skipped; the pipeline is then built with None in its place.
    transformer = None
    vae = None

    if args.fp16 and args.bf16:
        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")

    if args.fp16:
        dtype = torch.float16
    elif args.bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(
            ckpt_path=args.transformer_ckpt_path,
            num_layers=args.num_layers,
            num_attention_heads=args.num_attention_heads,
            use_rotary_positional_embeddings=args.use_rotary_positional_embeddings,
            i2v=args.i2v,
            dtype=dtype,
        )
    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)

    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # Apparently, the conversion does not work anymore without this :shrug:
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

    scheduler_config = {
        "snr_shift_scale": args.snr_shift_scale,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "clip_sample": False,
        "num_train_timesteps": 1000,
        "prediction_type": "v_prediction",
        "rescale_betas_zero_snr": True,
        "set_alpha_to_one": True,
        "timestep_spacing": "trailing",
    }
    scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)

    pipeline_cls = CogVideoXImageToVideoPipeline if args.i2v else CogVideoXPipeline
    pipe = pipeline_cls(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )

    # Cast the whole pipeline only when a reduced-precision flag was given.
    if args.fp16 or args.bf16:
        pipe = pipe.to(dtype=dtype)

    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
    # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
    # is either fp16/bf16 here).
    pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
================================================
FILE: CogVideo/tools/export_sat_lora_weight.py
================================================
from typing import Any, Dict
import torch
import argparse
from diffusers.loaders.lora_base import LoraBaseMixin
from diffusers.models.modeling_utils import load_state_dict
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return the innermost parameter mapping of a loaded checkpoint.

    Checkpoints may wrap the real state dict under "model", "module" and/or
    "state_dict". Each wrapper is peeled off in that order, looking at the level that
    has been reached so far rather than only at the top-level dict. This unwraps nested
    wrappers (``{"model": {"module": ...}}``) and no longer raises ``KeyError`` when two
    wrapper keys sit side by side at the top level.

    Args:
        saved_dict: object returned by ``torch.load`` for the checkpoint.

    Returns:
        The unwrapped mapping (``saved_dict`` itself when no wrapper key is present).
    """
    state_dict = saved_dict
    for wrapper_key in ("model", "module", "state_dict"):
        # Only dict-like levels can be inspected; anything else is returned as-is.
        if isinstance(state_dict, dict) and wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
    return state_dict
# Suffix -> replacement table used by export_lora_weight(): a checkpoint key is exported
# only if (after prefix removal) it ends with one of these suffixes, which is then
# replaced by the mapped name.
LORA_KEYS_RENAME = {
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
}
# len(PREFIX_KEY) leading characters are sliced off every checkpoint key.
PREFIX_KEY = "model.diffusion_model."
# Every occurrence of SAT_UNIT_KEY in an exported key is replaced by LORA_PREFIX_KEY.
SAT_UNIT_KEY = "layers"
LORA_PREFIX_KEY = "transformer_blocks"
def export_lora_weight(ckpt_path, lora_save_directory, expected_num_keys=240):
    """Extract the LoRA matrices from a SAT checkpoint and write them in diffusers format.

    Keys whose name (after slicing off ``len(PREFIX_KEY)`` characters) ends with one of
    the suffixes in ``LORA_KEYS_RENAME`` are renamed and collected; all other keys are
    ignored. The collected tensors are written with ``LoraBaseMixin.write_lora_layers``.

    Args:
        ckpt_path: path of the SAT checkpoint, loaded with ``torch.load``.
        lora_save_directory: directory the converted LoRA weights are written to.
        expected_num_keys: number of LoRA tensors the checkpoint must yield. Defaults to
            240, the value that used to be hard-coded; pass another count for a
            checkpoint with a different number of layers, or ``None`` to skip the check.

    Raises:
        ValueError: if the number of collected tensors differs from ``expected_num_keys``.
    """
    merge_original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))

    lora_state_dict = {}
    for key, value in merge_original_state_dict.items():
        new_key = key[len(PREFIX_KEY):]
        for special_key, lora_keys in LORA_KEYS_RENAME.items():
            if new_key.endswith(special_key):
                new_key = new_key.replace(special_key, lora_keys)
                new_key = new_key.replace(SAT_UNIT_KEY, LORA_PREFIX_KEY)
                lora_state_dict[new_key] = value
                # The renamed key ends in ".weight", so no other suffix can match.
                break

    if expected_num_keys is not None and len(lora_state_dict) != expected_num_keys:
        raise ValueError(f"lora_state_dict length is not {expected_num_keys}")

    LoraBaseMixin.write_lora_layers(
        state_dict=lora_state_dict,
        save_directory=lora_save_directory,
        is_main_process=True,
        weight_name=None,
        save_function=None,
        safe_serialization=True,
    )
def get_args():
    """Parse the two required command-line paths of the LoRA export script."""
    cli = argparse.ArgumentParser()
    required_paths = (
        ("--sat_pt_path", "Path to original sat transformer checkpoint"),
        ("--lora_save_directory", "Path where converted lora should be saved"),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=str, required=True, help=description)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: read the CLI paths and run the export.
    args = get_args()
    export_lora_weight(ckpt_path=args.sat_pt_path, lora_save_directory=args.lora_save_directory)
================================================
FILE: CogVideo/tools/llm_flux_cogvideox/generate.sh
================================================
#!/bin/bash
# Launch one background llm_flux_cogvideox.py job per GPU listed in CUDA_VISIBLE_DEVICES
# (defaults to GPU 0). Each job writes to its own output directory and log file.

NUM_VIDEOS=10
INFERENCE_STEPS=50
GUIDANCE_SCALE=7.0
OUTPUT_DIR_PREFIX="outputs/gpu_"
LOG_DIR_PREFIX="logs/gpu_"

VIDEO_MODEL_PATH="/share/official_pretrains/hf_home/CogVideoX-5b-I2V"
LLM_MODEL_PATH="/share/home/zyx/Models/Meta-Llama-3.1-8B-Instruct"
# A shell assignment must not have spaces around "="; with them the line is run as a
# command and the variable stays empty. The leading "/" matches the sibling paths.
IMAGE_MODEL_PATH="/share/home/zyx/Models/FLUX.1-dev"

#VIDEO_MODEL_PATH="THUDM/CogVideoX-5B-I2V"
#LLM_MODEL_PATH="THUDM/glm-4-9b-chat"
#IMAGE_MODEL_PATH="black-forest-labs/FLUX.1-dev"

CUDA_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"}

IFS=',' read -r -a GPU_ARRAY <<< "$CUDA_DEVICES"

# The log redirection below fails if the log directory does not exist yet.
mkdir -p "$(dirname "${LOG_DIR_PREFIX}")"

for i in "${!GPU_ARRAY[@]}"
do
    GPU=${GPU_ARRAY[$i]}
    echo "Starting task on GPU $GPU..."
    CUDA_VISIBLE_DEVICES=$GPU nohup python3 llm_flux_cogvideox.py \
        --caption_generator_model_id "$LLM_MODEL_PATH" \
        --image_generator_model_id "$IMAGE_MODEL_PATH" \
        --model_path "$VIDEO_MODEL_PATH" \
        --num_videos "$NUM_VIDEOS" \
        --image_generator_num_inference_steps "$INFERENCE_STEPS" \
        --guidance_scale "$GUIDANCE_SCALE" \
        --use_dynamic_cfg \
        --output_dir "${OUTPUT_DIR_PREFIX}${GPU}" \
        > "${LOG_DIR_PREFIX}${GPU}.log" 2>&1 &
done
================================================
FILE: CogVideo/tools/llm_flux_cogvideox/gradio_page.py
================================================
import os
import gradio as gr
import gc
import random
import torch
import numpy as np
from PIL import Image
import transformers
from diffusers import CogVideoXImageToVideoPipeline, CogVideoXDPMScheduler, DiffusionPipeline
from diffusers.utils import export_to_video
from transformers import AutoTokenizer
from datetime import datetime, timedelta
import threading
import time
import moviepy.editor as mp
# Module-level setup: everything below runs at import time and loads all three models
# (caption LLM, image generator, video generator) before the UI is built.
torch.set_float32_matmul_precision("high")
# Set default values
# These are site-specific local paths; the caption model is loaded with
# local_files_only=True, so its path must exist on disk.
caption_generator_model_id = "/share/home/zyx/Models/Meta-Llama-3.1-8B-Instruct"
image_generator_model_id = "/share/home/zyx/Models/FLUX.1-dev"
video_generator_model_id = "/share/official_pretrains/hf_home/CogVideoX-5b-I2V"
# Fixed seed used by generate_video() for every request.
seed = 1337
# Directories written by save_video()/convert_to_gif() and scanned by delete_old_files().
os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)
tokenizer = AutoTokenizer.from_pretrained(caption_generator_model_id, trust_remote_code=True)
# Text-generation pipeline used by generate_caption().
caption_generator = transformers.pipeline(
"text-generation",
model=caption_generator_model_id,
device_map="balanced",
model_kwargs={
"local_files_only": True,
"torch_dtype": torch.bfloat16,
},
trust_remote_code=True,
tokenizer=tokenizer
)
# Image pipeline used by generate_image().
image_generator = DiffusionPipeline.from_pretrained(
image_generator_model_id,
torch_dtype=torch.bfloat16,
device_map="balanced"
)
# image_generator.to("cuda")
# Image-to-video pipeline used by generate_video().
video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
video_generator_model_id,
torch_dtype=torch.bfloat16,
device_map="balanced"
)
video_generator.vae.enable_slicing()
video_generator.vae.enable_tiling()
# Replace the pipeline's scheduler with a DPM scheduler built from the same config,
# overriding timestep_spacing.
video_generator.scheduler = CogVideoXDPMScheduler.from_config(
video_generator.scheduler.config, timestep_spacing="trailing"
)
# Define prompts
# SYSTEM_PROMPT is sent as the "system" message by generate_caption(); surrounding
# whitespace is stripped once here.
SYSTEM_PROMPT = """
You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe.
For example, if you respond with "A beautiful morning in the woods with the sun peaking through the trees", the video generation model will create a video of exactly as described. Your task is to summarize the descriptions of videos provided by users and create detailed prompts to feed into the generative model.
There are a few rules to follow:
- You will only ever output a single video description per request.
- If the user mentions to summarize the prompt in [X] words, make sure not to exceed the limit.
Your responses should just be the video generation prompt. Here are examples:
- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart of the city, holding a can of spray paint, spray-painting a colorful bird on a mottled wall."
""".strip()
# Template appended to the user's text by generate_caption(); {0} receives the word limit.
USER_PROMPT = """
Could you generate a prompt for a video generation model? Please limit the prompt to [{0}] words.
""".strip()
def generate_caption(prompt):
    """Ask the caption LLM to turn the user's text into a video-generation prompt.

    A word limit is drawn at random from 25/50/75/100 and appended to the user text via
    ``USER_PROMPT``. If the generated text is wrapped in double quotes, they are removed.

    Args:
        prompt: free-form text entered by the user.

    Returns:
        The generated caption string.
    """
    word_limit = random.choice([25, 50, 75, 100])
    length_instruction = USER_PROMPT.format(word_limit)
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt + "\n" + length_instruction},
    ]
    generations = caption_generator(conversation, max_new_tokens=226, return_full_text=False)
    text = generations[0]["generated_text"]
    is_quoted = text.startswith("\"") and text.endswith("\"")
    return text[1:-1] if is_quoted else text
def generate_image(caption, progress=gr.Progress(track_tqdm=True)):
    """Render one 720x480 image for ``caption`` with the image pipeline.

    Returns:
        The same image twice: one copy for the visible output, one for the State
        component that feeds video generation.
    """
    result = image_generator(
        prompt=caption,
        height=480,
        width=720,
        num_inference_steps=30,
        guidance_scale=3.5,
    )
    first_image = result.images[0]
    # One for output One for State
    return first_image, first_image
def generate_video(
    caption,
    image,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate a 49-frame video from ``image`` and ``caption`` and save it as MP4 and GIF.

    A fresh generator seeded with the module-level ``seed`` is used on every call.

    Returns:
        Tuple ``(video_path, gif_path)``.
    """
    rng = torch.Generator().manual_seed(seed)
    sampling_options = dict(
        image=image,
        prompt=caption,
        height=480,
        width=720,
        num_frames=49,
        num_inference_steps=50,
        guidance_scale=6,
        use_dynamic_cfg=True,
        generator=rng,
    )
    video_frames = video_generator(**sampling_options).frames[0]
    video_path = save_video(video_frames)
    return video_path, convert_to_gif(video_path)
def save_video(tensor):
    """Write ``tensor`` (the generated frames) to ``./output/<timestamp>.mp4`` at 8 fps.

    Returns:
        Path of the written video file.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{stamp}.mp4"
    target_dir = os.path.dirname(video_path)
    os.makedirs(target_dir, exist_ok=True)
    export_to_video(tensor, video_path, fps=8)
    return video_path
def convert_to_gif(video_path):
    """Convert a video file to an 8 fps GIF of height 240 stored next to it.

    The GIF path is derived by swapping the file extension only, so directory names
    containing ".mp4" are left alone and a non-".mp4" input can never resolve to the
    source path itself. The source clip is always closed so its reader is released even
    when the conversion fails.

    Args:
        video_path: path of the source video.

    Returns:
        Path of the written GIF.
    """
    gif_path = os.path.splitext(video_path)[0] + ".gif"
    clip = mp.VideoFileClip(video_path)
    try:
        resized = clip.set_fps(8).resize(height=240)
        resized.write_gif(gif_path, fps=8)
    finally:
        # Derived clips share the source reader, so closing the source is sufficient.
        clip.close()
    return gif_path
def delete_old_files():
    """Run forever: every 600 seconds delete files older than 10 minutes.

    Scans ``./output`` and ``./gradio_tmp`` (regular files only, non-recursive). Intended
    to run in a background thread, so filesystem errors must not escape: a missing
    directory or a file that disappears mid-scan is skipped, and any other ``OSError``
    is reported and skipped instead of terminating the loop.
    """
    directories = ["./output", "./gradio_tmp"]
    while True:
        cutoff = datetime.now() - timedelta(minutes=10)
        for directory in directories:
            try:
                filenames = os.listdir(directory)
            except FileNotFoundError:
                # Directory not created (yet) or removed: nothing to clean.
                continue
            except OSError as error:
                print(f"delete_old_files: cannot list {directory}: {error}")
                continue
            for filename in filenames:
                file_path = os.path.join(directory, filename)
                try:
                    if not os.path.isfile(file_path):
                        continue
                    file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                    if file_mtime < cutoff:
                        os.remove(file_path)
                except FileNotFoundError:
                    # Removed by someone else between listing and deletion.
                    continue
                except OSError as error:
                    print(f"delete_old_files: cannot remove {file_path}: {error}")
        time.sleep(600)
# Start the periodic cleanup in a daemon thread so it does not block interpreter exit.
threading.Thread(target=delete_old_files, daemon=True).start()
# UI definition. NOTE(review): indentation is not visible in this view, so the nesting
# of the Row/Column containers below is presumed from the order of statements — confirm.
with gr.Blocks() as demo:
gr.Markdown("""
LLM + FLUX + CogVideoX-I2V Space 🤗
""")
with gr.Row():
with gr.Column():
# Prompt -> caption -> image controls.
prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=5)
generate_caption_button = gr.Button("Generate Caption")
caption = gr.Textbox(label="Caption", placeholder="Caption will appear here", lines=5)
generate_image_button = gr.Button("Generate Image")
image_output = gr.Image(label="Generated Image")
# Holds the generated image so generate_video() receives it as its second input.
state_image = gr.State()
generate_caption_button.click(fn=generate_caption, inputs=prompt, outputs=caption)
generate_image_button.click(fn=generate_image, inputs=caption, outputs=[image_output, state_image])
with gr.Column():
# Video controls.
video_output = gr.Video(label="Generated Video", width=720, height=480)
# NOTE(review): download_video_button is created but never used as an output below,
# and both File components are created with visible=False — confirm this is intended.
download_video_button = gr.File(label="📥 Download Video", visible=False)
download_gif_button = gr.File(label="📥 Download GIF", visible=False)
generate_video_button = gr.Button("Generate Video from Image")
# generate_video() returns (video_path, gif_path), mapped to the two outputs in order.
generate_video_button.click(fn=generate_video, inputs=[caption, state_image],
outputs=[video_output, download_gif_button])
# Launch the web UI only when the file is executed as a script.
if __name__ == "__main__":
demo.launch()
================================================
FILE: CogVideo/tools/llm_flux_cogvideox/llm_flux_cogvideox.py
================================================
"""
The original experimental code for this project can be found at:
https://gist.github.com/a-r-r-o-w/d070cce059ab4ceab3a9f289ff83c69c
By using this code, description prompts will be generated through a local large language model, and images will be
generated using the black-forest-labs/FLUX.1-dev model, followed by video generation via CogVideoX.
The entire process utilizes open-source solutions, without the need for any API keys.
You can use the generate.sh file in the same folder to automate running this code
for batch generation of videos and images.
bash generate.sh
"""
import argparse
import gc
import json
import os
import pathlib
import random
from typing import Any, Dict
from transformers import AutoTokenizer
os.environ["TORCH_LOGS"] = "+dynamo,recompiles,graph_breaks"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
import numpy as np
import torch
import transformers
from diffusers import CogVideoXImageToVideoPipeline, CogVideoXDPMScheduler, DiffusionPipeline
from diffusers.utils.logging import get_logger
from diffusers.utils import export_to_video
torch.set_float32_matmul_precision("high")
logger = get_logger(__name__)
# System message for the caption LLM. The wording fixes several typos that were being
# sent to the model verbatim ("You task", "provided to by users", "details prompts",
# "You responses", "colorful banana", "in the heart,").
SYSTEM_PROMPT = """
You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe.
For example, if you respond with "A beautiful morning in the woods with the sun peaking through the trees", the video generation model will create a video of exactly as described. Your task is to summarize the descriptions of videos provided by users, and create detailed prompts to feed into the generative model.
There are a few rules to follow:
- You will only ever output a single video description per request.
- If the user mentions to summarize the prompt in [X] words, make sure to not exceed the limit.
Your responses should just be the video generation prompt. Here are examples:
- “A lone figure stands on a city rooftop at night, gazing up at the full moon. The moon glows brightly, casting a gentle light over the quiet cityscape. Below, the windows of countless homes shine with warm lights, creating a contrast between the bustling life below and the peaceful solitude above. The scene captures the essence of the Mid-Autumn Festival, where despite the distance, the figure feels connected to loved ones through the shared beauty of the moonlit sky.”
- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart of the city, holding a can of spray paint, spray-painting a colorful bird on a mottled wall"
""".strip()
# User message template; {0} receives the word limit chosen in main().
USER_PROMPT = """
Could you generate a prompt for a video generation model?
Please limit the prompt to [{0}] words.
""".strip()
def get_args():
    """Parse the command-line options of the caption -> image -> video batch script.

    Returns:
        argparse.Namespace with model ids/paths, cache directories, sampling settings,
        output directory, optimisation switches and the random seed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_videos", type=int, default=5, help="Number of unique videos you would like to generate."
    )
    # NOTE(review): main() loads this path with the image-to-video pipeline, while the
    # default id does not name an I2V checkpoint — confirm the default is intended.
    parser.add_argument(
        "--model_path",
        type=str,
        default="THUDM/CogVideoX-5B",
        help="The path of Image2Video CogVideoX-5B",
    )
    parser.add_argument(
        "--caption_generator_model_id",
        type=str,
        default="THUDM/glm-4-9b-chat",
        help="Caption generation model. default GLM-4-9B",
    )
    parser.add_argument(
        "--caption_generator_cache_dir",
        type=str,
        default=None,
        help="Cache directory for caption generation model.",
    )
    parser.add_argument(
        "--image_generator_model_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Image generation model.",
    )
    parser.add_argument(
        "--image_generator_cache_dir",
        type=str,
        default=None,
        help="Cache directory for image generation model.",
    )
    # (help text previously repeated the caption-model description)
    parser.add_argument(
        "--image_generator_num_inference_steps",
        type=int,
        default=50,
        help="Number of inference steps used by the image generation model.",
    )
    parser.add_argument(
        "--guidance_scale", type=float, default=7, help="Guidance scale to be used for generation."
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        help="Whether or not to use cosine dynamic guidance for generation [Recommended].",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/",
        help="Location where generated images and videos should be stored.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Whether or not to compile the transformer of image and video generators.",
    )
    parser.add_argument(
        "--enable_vae_tiling",
        action="store_true",
        help="Whether or not to use VAE tiling when encoding/decoding.",
    )
    parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.")
    return parser.parse_args()
def reset_memory():
    """Run the garbage collector, then clear the CUDA cache and its memory statistics."""
    gc.collect()
    cuda = torch.cuda
    cuda.empty_cache()
    cuda.reset_peak_memory_stats()
    cuda.reset_accumulated_memory_stats()
def _caption_to_filename(caption: str, max_length: int = 25) -> str:
    """Turn the start of a caption into text that is safe to embed in a file name."""
    import re  # local import keeps this block self-contained

    # Superset of the characters replaced before (. ' " ,): also path separators,
    # characters reserved on common filesystems and control characters such as newlines.
    return re.sub(r"[\\/:*?\"<>|.',\x00-\x1f]", "_", caption[:max_length])


@torch.no_grad()
def main(args: argparse.Namespace) -> None:
    """Generate captions, then images, then videos, releasing each model before the next.

    Stage 1 asks the caption LLM for ``args.num_videos`` prompts and writes them to
    ``captions.json``. Stage 2 renders one 720x480 image per caption. Stage 3 turns each
    image/caption pair into a 49-frame video. All artefacts go to ``args.output_dir``.

    Args:
        args: namespace produced by ``get_args()``.
    """
    output_dir = pathlib.Path(args.output_dir)
    os.makedirs(output_dir.as_posix(), exist_ok=True)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    reset_memory()

    # ---- Stage 1: captions -------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(args.caption_generator_model_id, trust_remote_code=True)
    caption_generator = transformers.pipeline(
        "text-generation",
        model=args.caption_generator_model_id,
        device_map="auto",
        model_kwargs={
            "local_files_only": True,
            "cache_dir": args.caption_generator_cache_dir,
            "torch_dtype": torch.bfloat16,
        },
        trust_remote_code=True,
        tokenizer=tokenizer
    )

    captions = []
    for i in range(args.num_videos):
        num_words = random.choice([50, 75, 100])
        user_prompt = USER_PROMPT.format(num_words)

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]

        outputs = caption_generator(messages, max_new_tokens=226)
        # The pipeline returns the whole chat; the last message is the model's reply.
        caption = outputs[0]["generated_text"][-1]["content"]
        if caption.startswith("\"") and caption.endswith("\""):
            caption = caption[1:-1]
        captions.append(caption)
        logger.info(f"Generated caption: {caption}")

    with open(output_dir / "captions.json", "w") as file:
        json.dump(captions, file)

    del caption_generator
    reset_memory()

    # ---- Stage 2: images ---------------------------------------------------------
    image_generator = DiffusionPipeline.from_pretrained(
        args.image_generator_model_id,
        cache_dir=args.image_generator_cache_dir,
        torch_dtype=torch.bfloat16
    )
    image_generator.to("cuda")

    if args.compile:
        image_generator.transformer = torch.compile(image_generator.transformer, mode="max-autotune", fullgraph=True)

    if args.enable_vae_tiling:
        image_generator.vae.enable_tiling()

    images = []
    for index, caption in enumerate(captions):
        image = image_generator(
            prompt=caption,
            height=480,
            width=720,
            num_inference_steps=args.image_generator_num_inference_steps,
            guidance_scale=3.5,
        ).images[0]
        filename = _caption_to_filename(caption)
        image.save(output_dir / f"{index}_{filename}.png")
        images.append(image)

    del image_generator
    reset_memory()

    # ---- Stage 3: videos ---------------------------------------------------------
    video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16).to("cuda")
    video_generator.scheduler = CogVideoXDPMScheduler.from_config(
        video_generator.scheduler.config,
        timestep_spacing="trailing")

    if args.compile:
        video_generator.transformer = torch.compile(video_generator.transformer, mode="max-autotune", fullgraph=True)

    if args.enable_vae_tiling:
        video_generator.vae.enable_tiling()

    generator = torch.Generator().manual_seed(args.seed)
    for index, (caption, image) in enumerate(zip(captions, images)):
        video = video_generator(
            image=image,
            prompt=caption,
            height=480,
            width=720,
            num_frames=49,
            num_inference_steps=50,
            guidance_scale=args.guidance_scale,
            use_dynamic_cfg=args.use_dynamic_cfg,
            generator=generator,
        ).frames[0]
        filename = _caption_to_filename(caption)
        # Hand a plain string path to the exporter rather than a pathlib.Path.
        video_path = output_dir / f"{index}_{filename}.mp4"
        export_to_video(video, video_path.as_posix(), fps=8)
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the three-stage pipeline.
    main(get_args())
================================================
FILE: CogVideo/tools/load_cogvideox_lora.py
================================================
# Copyright 2024 The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
import time
from diffusers.utils import export_to_video
from diffusers.image_processor import VaeImageProcessor
from datetime import datetime, timedelta
from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler
import os
import torch
import argparse
# Target device for the pipeline: CUDA when torch reports it as available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def get_args():
    """Parse the command-line options of the LoRA inference script.

    ``--prompt`` is now required: the script passes it straight to the pipeline and has
    no other source of conditioning, so omitting it could never produce a video.

    Returns:
        argparse.Namespace with the model path, LoRA path, LoRA rank/alpha, prompt and
        output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--lora_weights_path",
        type=str,
        required=True,
        help="Path to lora weights.",
    )
    parser.add_argument(
        "--lora_r",
        type=int,
        default=128,
        help="""Rank of the LoRA weights, with the default for 2B trans set at 128 and 5B trans set at 256.
        The higher the rank, the better the expressive capability, but it requires more memory and
        training time. Increasing this number blindly isn't always better.
        The adapter scale applied by this script is: lora_alpha / lora_r.
        """,
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=1,
        help="""Alpha value of the LoRA weights, used for stable learning and to prevent underflow.
        In the SAT training framework, alpha is set to 1 by default.
        The adapter scale applied by this script is: lora_alpha / lora_r.
        """,
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="prompt",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Load the base pipeline, attach the LoRA adapter, sample one video and save it.
    args = get_args()

    pipe = CogVideoXPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
    pipe = pipe.to(device)
    pipe.load_lora_weights(
        args.lora_weights_path,
        weight_name="pytorch_lora_weights.safetensors",
        adapter_name="cogvideox-lora",
    )
    # pipe.fuse_lora(lora_scale=args.lora_alpha/args.lora_r, ['transformer'])
    lora_scaling = args.lora_alpha / args.lora_r
    pipe.set_adapters(["cogvideox-lora"], [lora_scaling])

    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    os.makedirs(args.output_dir, exist_ok=True)

    sampling_kwargs = dict(
        prompt=args.prompt,
        num_videos_per_prompt=1,
        num_inference_steps=50,
        num_frames=49,
        use_dynamic_cfg=True,
        output_type="pt",
        guidance_scale=3.0,
        generator=torch.Generator(device="cpu").manual_seed(42),
    )
    latents = pipe(**sampling_kwargs).frames

    # Convert every sample of the batch from a tensor into a list of PIL frames.
    batch_size = latents.shape[0]
    batch_video_frames = []
    for batch_idx in range(batch_size):
        sample = latents[batch_idx]
        pt_image = torch.stack([sample[frame_idx] for frame_idx in range(sample.shape[0])])
        image_np = VaeImageProcessor.pt_to_numpy(pt_image)
        batch_video_frames.append(VaeImageProcessor.numpy_to_pil(image_np))

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"{args.output_dir}/{timestamp}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)

    # Only the first sample is exported; fps is derived from its frame count.
    tensor = batch_video_frames[0]
    fps = math.ceil((len(tensor) - 1) / 6)
    export_to_video(tensor, video_path, fps=fps)
================================================
FILE: CogVideo/tools/parallel_inference/parallel_inference_xdit.py
================================================
"""
This is a parallel inference script for CogVideo. The original script
can be found from the xDiT project at
https://github.com/xdit-project/xDiT/blob/main/examples/cogvideox_example.py
By using this code, the inference process is parallelized on multiple GPUs,
and thus sped up.
Usage:
1. pip install xfuser
2. mkdir results
3. run the following command to generate video
torchrun --nproc_per_node=4 parallel_inference_xdit.py \
--model <model-path-or-hf-id> --ulysses_degree 1 --ring_degree 2 \
--use_cfg_parallel --height 480 --width 720 --num_frames 9 \
--prompt 'A small dog.'
You can also use the run.sh file in the same folder to automate running this
code for batch generation of videos, by running:
bash ./run.sh
"""
import time
import torch
import torch.distributed
from diffusers import AutoencoderKLTemporalDecoder
from xfuser import xFuserCogVideoXPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_runtime_state,
is_dp_last_group,
)
from diffusers.utils import export_to_video
def main():
    """Run one CogVideoX generation with xDiT parallelism and report time and memory.

    Parses the xFuser CLI arguments, builds the pipeline on this process's local rank,
    generates a single video, saves it under ``results/`` from the last data-parallel
    group, prints timing and peak memory from the last rank, and finally tears down the
    distributed environment.

    Raises:
        ValueError: if ``ulysses_degree`` is positive and does not divide the number of
            attention heads (fixed at 30 here).
    """
    cli_parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(cli_parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)

    # Check if ulysses_degree is valid
    num_heads = 30
    degree_is_set = engine_args.ulysses_degree > 0
    if degree_is_set and num_heads % engine_args.ulysses_degree != 0:
        raise ValueError(
            f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor of the number of heads ({num_heads})"
        )

    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank

    pipe = xFuserCogVideoXPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
    )
    if not args.enable_sequential_cpu_offload:
        device = torch.device(f"cuda:{local_rank}")
        pipe = pipe.to(device)
    else:
        pipe.enable_model_cpu_offload(gpu_id=local_rank)
        pipe.vae.enable_tiling()

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    sampling_kwargs = dict(
        height=input_config.height,
        width=input_config.width,
        num_frames=input_config.num_frames,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
        guidance_scale=6,
    )
    output = pipe(**sampling_kwargs).frames[0]

    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    # Encodes the parallel layout into the output file name.
    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )

    if is_dp_last_group():
        world_size = get_data_parallel_world_size()
        resolution = f"{input_config.width}x{input_config.height}"
        output_filename = f"results/cogvideox_{parallel_info}_{resolution}.mp4"
        export_to_video(output, output_filename, fps=8)
        print(f"output saved to {output_filename}")

    world_group = get_world_group()
    if world_group.rank == world_group.world_size - 1:
        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")

    get_runtime_state().destory_distributed_env()
# Script entry point: run main() only when the file is executed directly, not on import.
if __name__ == "__main__":
main()
================================================
FILE: CogVideo/tools/parallel_inference/run.sh
================================================
#!/bin/bash
# Launches parallel_inference_xdit.py through torchrun.
# This script relies on bash-only features (associative arrays, [[ -v ]]), so it must be
# run with bash rather than a plain POSIX sh.
set -x

export PYTHONPATH=$PWD:$PYTHONPATH

# Select the model type
# The model is downloaded to a specified location on disk,
# or you can simply use the model's ID on Hugging Face,
# which will then be downloaded to the default cache path on Hugging Face.

export MODEL_TYPE="CogVideoX"
# Configuration for different model types
# script, model_id, inference_step
declare -A MODEL_CONFIGS=(
    ["CogVideoX"]="parallel_inference_xdit.py /cfs/dit/CogVideoX-2b 20"
)

if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
    IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}"
    export SCRIPT MODEL_ID INFERENCE_STEP
else
    echo "Invalid MODEL_TYPE: $MODEL_TYPE"
    exit 1
fi

mkdir -p ./results

# task args
if [ "$MODEL_TYPE" = "CogVideoX" ]; then
    TASK_ARGS="--height 480 --width 720 --num_frames 9"
fi

# CogVideoX asserts sp_degree == ulysses_degree*ring_degree <= 2. Also, do not set the pipefusion degree.
if [ "$MODEL_TYPE" = "CogVideoX" ]; then
    N_GPUS=4
    PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 1"
    CFG_ARGS="--use_cfg_parallel"
fi

# Optional extra flags: PIPEFUSION_ARGS, OUTPUT_ARGS, PARALLEL_VAE and COMPILE_FLAG are
# empty unless the caller exports them. The historical misspelling PARALLLEL_VAE is still
# honoured so existing environments keep working.
PARALLEL_VAE=${PARALLEL_VAE:-$PARALLLEL_VAE}

# The *_ARGS variables are intentionally unquoted: each holds several words.
torchrun --nproc_per_node="$N_GPUS" "./$SCRIPT" \
    --model "$MODEL_ID" \
    $PARALLEL_ARGS \
    $TASK_ARGS \
    $PIPEFUSION_ARGS \
    $OUTPUT_ARGS \
    --num_inference_steps "$INFERENCE_STEP" \
    --warmup_steps 0 \
    --prompt "A small dog." \
    $CFG_ARGS \
    $PARALLEL_VAE \
    $COMPILE_FLAG
================================================
FILE: CogVideo/tools/replicate/cog.yaml
================================================
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml
build:
# set to true if your model requires a GPU
gpu: true
# a list of ubuntu apt packages to install
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
# python version in the form '3.11' or '3.11.4'
python_version: "3.11"
# a list of packages in the format <package-name>==<version>
python_packages:
- diffusers>=0.30.3
- accelerate>=0.34.2
- transformers>=4.44.2
- numpy==1.26.0
- torch>=2.4.0
- torchvision>=0.19.0
- sentencepiece>=0.2.0
- SwissArmyTransformer>=0.4.12
- imageio>=2.35.1
- imageio-ffmpeg>=0.5.1
- openai>=1.45.0
- moviepy>=1.0.3
- pillow==9.5.0
- pydantic==1.10.7
run:
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
# predict.py defines how predictions are run on your model
predict: "predict_t2v.py:Predictor"
# predict: "predict_i2v.py:Predictor"
================================================
FILE: CogVideo/tools/replicate/predict_i2v.py
================================================
# Prediction interface for Cog ⚙️
# https://cog.run/python
import os
import subprocess
import time
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from cog import BasePredictor, Input, Path
# Local directory that holds the image-to-video weights; also used as the
# archive name on the weights host.
MODEL_CACHE = "model_cache_i2v"
MODEL_URL = (
    "https://weights.replicate.delivery/default/THUDM/CogVideo/" + MODEL_CACHE + ".tar"
)

# Set the two offline flags to "1" and point every cache-location variable
# at MODEL_CACHE.
os.environ.update(
    {
        "HF_DATASETS_OFFLINE": "1",
        "TRANSFORMERS_OFFLINE": "1",
        "HF_HOME": MODEL_CACHE,
        "TORCH_HOME": MODEL_CACHE,
        "HF_DATASETS_CACHE": MODEL_CACHE,
        "TRANSFORMERS_CACHE": MODEL_CACHE,
        "HUGGINGFACE_HUB_CACHE": MODEL_CACHE,
    }
)
def download_weights(url, dest):
    """Download the archive at ``url`` into ``dest`` by running ``pget -x``.

    Prints the URL, the destination and the elapsed wall-clock time.

    Raises:
        subprocess.CalledProcessError: if ``pget`` exits with a non-zero status.
    """
    began_at = time.time()
    for label, value in (("downloading url: ", url), ("downloading to: ", dest)):
        print(label, value)
    pget_command = ["pget", "-x", url, dest]
    subprocess.check_call(pget_command, close_fds=False)
    elapsed = time.time() - began_at
    print("downloading took: ", elapsed)
class Predictor(BasePredictor):
    """Cog predictor that generates a video from a prompt and an input image."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)
        # model_id: THUDM/CogVideoX-5b-I2V
        # The pipeline is deliberately NOT moved to CUDA here:
        # enable_model_cpu_offload() manages device placement itself, so an
        # explicit .to("cuda") beforehand only adds a redundant transfer of all
        # weights and a needless peak in GPU memory.
        self.pipe = CogVideoXImageToVideoPipeline.from_pretrained(
            MODEL_CACHE, torch_dtype=torch.bfloat16
        )
        self.pipe.enable_model_cpu_offload()
        self.pipe.vae.enable_tiling()

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt", default="Starry sky slowly rotating."
        ),
        image: Path = Input(description="Input image"),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=6
        ),
        num_frames: int = Input(
            description="Number of frames for the output video", default=49
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model

        Args:
            prompt: Text prompt passed to the pipeline.
            image: Path of the input image, loaded with ``load_image``.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Scale for classifier-free guidance.
            num_frames: Number of frames for the output video.
            seed: Seed for the CUDA generator; drawn from ``os.urandom`` when None.

        Returns:
            Path of the exported video file (``/tmp/out.mp4``, written at 8 fps).
        """
        if seed is None:
            # Two random bytes -> a seed in the range [0, 65535].
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        img = load_image(image=str(image))
        video = self.pipe(
            prompt=prompt,
            image=img,
            num_videos_per_prompt=1,
            num_inference_steps=num_inference_steps,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
        ).frames[0]  # one video is requested, so take the first entry

        out_path = "/tmp/out.mp4"
        export_to_video(video, out_path, fps=8)
        return Path(out_path)
================================================
FILE: CogVideo/tools/replicate/predict_t2v.py
================================================
# Prediction interface for Cog ⚙️
# https://cog.run/python
import os
import subprocess
import time
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from cog import BasePredictor, Input, Path
# Local directory that holds the text-to-video weights; also used as the
# archive name on the weights host.
MODEL_CACHE = "model_cache"
MODEL_URL = (
    "https://weights.replicate.delivery/default/THUDM/CogVideo/" + MODEL_CACHE + ".tar"
)

# Set the two offline flags to "1" and point every cache-location variable
# at MODEL_CACHE.
os.environ.update(
    {
        "HF_DATASETS_OFFLINE": "1",
        "TRANSFORMERS_OFFLINE": "1",
        "HF_HOME": MODEL_CACHE,
        "TORCH_HOME": MODEL_CACHE,
        "HF_DATASETS_CACHE": MODEL_CACHE,
        "TRANSFORMERS_CACHE": MODEL_CACHE,
        "HUGGINGFACE_HUB_CACHE": MODEL_CACHE,
    }
)
def download_weights(url, dest):
    """Download the archive at ``url`` into ``dest`` by running ``pget -x``.

    Prints the URL, the destination and the elapsed wall-clock time.

    Raises:
        subprocess.CalledProcessError: if ``pget`` exits with a non-zero status.
    """
    began_at = time.time()
    for label, value in (("downloading url: ", url), ("downloading to: ", dest)):
        print(label, value)
    pget_command = ["pget", "-x", url, dest]
    subprocess.check_call(pget_command, close_fds=False)
    elapsed = time.time() - began_at
    print("downloading took: ", elapsed)
class Predictor(BasePredictor):
    """Cog predictor that generates a video from a text prompt."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)
        # model_id: THUDM/CogVideoX-5b
        # The pipeline is deliberately NOT moved to CUDA here:
        # enable_model_cpu_offload() manages device placement itself, so an
        # explicit .to("cuda") beforehand only adds a redundant transfer of all
        # weights and a needless peak in GPU memory.
        self.pipe = CogVideoXPipeline.from_pretrained(
            MODEL_CACHE,
            torch_dtype=torch.bfloat16,
        )
        self.pipe.enable_model_cpu_offload()
        self.pipe.vae.enable_tiling()

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt",
            default="A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance.",
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=6
        ),
        num_frames: int = Input(
            description="Number of frames for the output video", default=49
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model

        Args:
            prompt: Text prompt passed to the pipeline.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Scale for classifier-free guidance.
            num_frames: Number of frames for the output video.
            seed: Seed for the CUDA generator; drawn from ``os.urandom`` when None.

        Returns:
            Path of the exported video file (``/tmp/out.mp4``, written at 8 fps).
        """
        if seed is None:
            # Two random bytes -> a seed in the range [0, 65535].
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        video = self.pipe(
            prompt=prompt,
            num_videos_per_prompt=1,
            num_inference_steps=num_inference_steps,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
        ).frames[0]  # one video is requested, so take the first entry

        out_path = "/tmp/out.mp4"
        export_to_video(video, out_path, fps=8)
        return Path(out_path)
================================================
FILE: CogVideo/tools/venhancer/README.md
================================================
# Enhance CogVideoX Generated Videos with VEnhancer
This tutorial will guide you through using the VEnhancer tool to enhance videos generated by CogVideoX, including
achieving higher frame rates and higher resolutions.
## Model Introduction
VEnhancer implements spatial super-resolution, temporal super-resolution (frame interpolation), and video refinement in
a unified framework. It can flexibly adapt to different upsampling factors (e.g., 1x~8x) for spatial or temporal
super-resolution. Additionally, it provides flexible control to modify the refinement strength, enabling it to handle
diverse video artifacts.
VEnhancer follows the design of ControlNet, copying the architecture and weights of the multi-frame encoder and middle
block from a pre-trained video diffusion model to build a trainable conditional network. This video ControlNet accepts
low-resolution keyframes and noisy full-frame latents as inputs. In addition to the time step t and prompt, our proposed
video-aware conditioning also includes noise augmentation level σ and downscaling factor s as additional network
conditioning inputs.
## Hardware Requirements
+ Operating System: Linux (requires xformers dependency)
+ Hardware: NVIDIA GPU with at least 60GB of VRAM per card. Machines such as H100, A100 are recommended.
## Quick Start
1. Clone the repository and install dependencies as per the official instructions:
```shell
git clone https://github.com/Vchitect/VEnhancer.git
cd VEnhancer
## Torch and other dependencies can use those from CogVideoX. If you need to create a new environment, use the following commands:
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
## Install required dependencies
pip install -r requirements.txt
```
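2. Run the code:
```shell
python enhance_a_video.py \
  --up_scale 4 --target_fps 24 --noise_aug 250 \
  --solver_mode 'fast' --steps 15 \
  --input_path inputs/000000.mp4 \
  --prompt 'Wide-angle aerial shot at dawn, soft morning light casting long shadows, an elderly man walking his dog through a quiet, foggy park, trees and benches in the background, peaceful and serene atmosphere' \
  --save_dir 'results/'
```
Where: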
- `input_path` is the path to the input video
- `prompt` is the description of the video content. The prompt used by this tool should be shorter, not exceeding 77
words. You may need to simplify the prompt used for generating the CogVideoX video.
- `target_fps` is the target frame rate for the video. Typically, 16 fps is already smooth, with 24 fps as the default
value.
- `up_scale` is recommended to be set to 2, 3, or 4. The target resolution is limited to around 2k and below.
- `noise_aug` depends on the input video quality. Lower-quality videos need higher noise levels, which correspond to
stronger refinement. Use 250~300 for very low-quality videos; for good-quality videos, use a value <= 200.
- `steps`: if you want fewer steps, first change `solver_mode` to "normal", then decrease the number of steps. The
"fast" `solver_mode` uses a fixed number of steps (15).
The code will automatically download the required models from Hugging Face during execution.
Typical runtime logs are as follows:
```shell
/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
@torch.library.impl_abstract("xformers_flash::flash_fwd")
/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
@torch.library.impl_abstract("xformers_flash::flash_bwd")
2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt
/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
2024-08-20 13:25:37,486 - video_to_video - INFO - Build encoder with FrozenOpenCLIPEmbedder
/share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
load_dict = torch.load(cfg.model_path, map_location='cpu')
2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status
2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion
2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4
2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere
2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49
2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0
2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0
2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720)
2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982)
2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250
2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8
2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982])
/share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with amp.autocast(enabled=True):
2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0
2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1
2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2
2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3
2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4
2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5
2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6
2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7
2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8
2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9
2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10
2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11
2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12
2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13
2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished.
```
Running on a single A100 GPU, enhancing each 6-second CogVideoX generated video with default settings will consume 60GB
of VRAM and take 40-50 minutes.
================================================
FILE: CogVideo/tools/venhancer/README_ja.md
================================================
# VEnhancer で CogVideoX によって生成されたビデオを強化する
このチュートリアルでは、VEnhancer ツールを使用して、CogVideoX で生成されたビデオを強化し、より高いフレームレートと高い解像度を実現する方法を説明します。
## モデルの紹介
VEnhancer は、空間超解像、時間超解像(フレーム補間)、およびビデオのリファインメントを統一されたフレームワークで実現します。空間または時間の超解像のために、さまざまなアップサンプリング係数(例:1x〜8x)に柔軟に対応できます。さらに、多様なビデオアーティファクトを処理するために、リファインメント強度を変更する柔軟な制御を提供します。
VEnhancer は ControlNet の設計に従い、事前訓練されたビデオ拡散モデルのマルチフレームエンコーダーとミドルブロックのアーキテクチャとウェイトをコピーして、トレーニング可能な条件ネットワークを構築します。このビデオ ControlNet は、低解像度のキーフレームとノイズを含む完全なフレームを入力として受け取ります。さらに、タイムステップ t とプロンプトに加えて、提案されたビデオ対応条件により、ノイズ増幅レベル σ およびダウンスケーリングファクター s が追加のネットワーク条件として使用されます。
## ハードウェア要件
+ オペレーティングシステム: Linux (xformers 依存関係が必要)
+ ハードウェア: 単一カードあたり少なくとも 60GB の VRAM を持つ NVIDIA GPU。H100、A100 などのマシンを推奨します。
## クイックスタート
1. 公式の指示に従ってリポジトリをクローンし、依存関係をインストールします。
```shell
git clone https://github.com/Vchitect/VEnhancer.git
cd VEnhancer
## Torch などの依存関係は CogVideoX の依存関係を使用できます。新しい環境を作成する必要がある場合は、以下のコマンドを使用してください。
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
## 必須の依存関係をインストールします。
pip install -r requirements.txt
```
2. コードを実行します。
```shell
python enhance_a_video.py --up_scale 4 --target_fps 24 --noise_aug 250 --solver_mode 'fast' --steps 15 --input_path inputs/000000.mp4 --prompt 'Wide-angle aerial shot at dawn, soft morning light casting long shadows, an elderly man walking his dog through a quiet, foggy park, trees and benches in the background, peaceful and serene atmosphere' --save_dir 'results/'
```
次の設定を行います:
- `input_path` は入力ビデオのパスです。
- `prompt` はビデオ内容の説明です。このツールで使用するプロンプトは短くする必要があり、77 語を超えないようにしてください。CogVideoX ビデオの生成に使用したプロンプトを簡略化する必要があるかもしれません。
- `target_fps` はビデオの目標フレームレートです。通常は 16 fps で十分に滑らかで、デフォルト値は 24 fps です。
- `up_scale` は 2、3、4 のいずれかに設定することを推奨します。目標解像度は 2k 程度以下に制限されます。
- `noise_aug` の値は入力ビデオの品質によって決まります。品質が低いビデオほど高いノイズレベルが必要で、これはより強いリファインメントに対応します。250〜300 は非常に低品質のビデオ向けです。高品質のビデオの場合は 200 以下に設定してください。
- `steps` ステップ数を減らしたい場合は、まず solver_mode を "normal" に変更してから、ステップ数を減らしてください。"fast" モードのステップ数は固定(15 ステップ)です。
コードの実行中に、必要なモデルは Hugging Face から自動的にダウンロードされます。
通常の実行ログは次のとおりです:
```shell
/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
@torch.library.impl_abstract("xformers_flash::flash_fwd")
/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
@torch.library.impl_abstract("xformers_flash::flash_bwd")
2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt
/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
2024-08-20 13:25:37,486 - video_to_video - INFO - Build encoder with FrozenOpenCLIPEmbedder
/share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
load_dict = torch.load(cfg.model_path, map_location='cpu')
2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status
2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion
2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4
2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere
2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49
2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0
2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0
2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720)
2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982)
2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250
2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8
2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982])
/share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with amp.autocast(enabled=True):
2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0
2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1
2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2
2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3
2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4
2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5
2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6
2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7
2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8
2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9
2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10
2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11
2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12
2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13
2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished.
```
A100 GPU を単一で使用している場合、CogVideoX によって生成された 6 秒間のビデオを強化するには、デフォルト設定で 60GB の VRAM を消費し、40〜50 分かかります。
================================================
FILE: CogVideo/tools/venhancer/README_zh.md
================================================
# 使用 VEnhancer 对 CogVideoX 生成视频进行增强
本教程将介绍如何使用 VEnhancer 工具对 CogVideoX 生成的视频进行增强,包括更高的帧率和更高的分辨率。
## 模型介绍
VEnhancer 在一个统一的框架中实现了空间超分辨率、时间超分辨率(帧插值)和视频优化。它可以灵活地适应不同的上采样因子(例如,1x~
8x)用于空间或时间超分辨率。此外,它提供了灵活的控制,以修改优化强度,从而处理多样化的视频伪影。
VEnhancer 遵循 ControlNet 的设计,复制了预训练的视频扩散模型的多帧编码器和中间块的架构和权重,构建了一个可训练的条件网络。这个视频
ControlNet 接受低分辨率关键帧和包含噪声的完整帧作为输入。此外,除了时间步 t 和提示词外,我们提出的视频感知条件还将噪声增强的噪声级别
σ 和降尺度因子 s 作为附加的网络条件输入。
## 硬件需求
+ 操作系统: Linux (需要依赖xformers)
+ 硬件: NVIDIA GPU 并至少保证单卡显存超过60G,推荐使用 H100,A100等机器。
## 快速上手
1. 按照官方指引克隆仓库并安装依赖
```shell
git clone https://github.com/Vchitect/VEnhancer.git
cd VEnhancer
## torch等依赖可以使用CogVideoX的依赖,如果你需要创建一个新的环境,可以使用以下命令
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
## 安装必须的依赖
pip install -r requirements.txt
```
2. 运行代码
```shell
python enhance_a_video.py \
--up_scale 4 --target_fps 24 --noise_aug 250 \
--solver_mode 'fast' --steps 15 \
--input_path inputs/000000.mp4 \
--prompt 'Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere' \
--save_dir 'results/'
```
其中:
- `input_path` 是输入视频的路径
- `prompt` 是视频内容的描述。此工具使用的提示词应更短,不超过77个字。您可能需要简化用于生成CogVideoX视频的提示词。
- `target_fps` 是视频的目标帧率。通常,16 fps已经很流畅,默认值为24 fps。
- `up_scale` 推荐设置为2、3或4。目标分辨率限制在2k左右及以下。
- `noise_aug` 的值取决于输入视频的质量。质量较低的视频需要更高的噪声级别,这对应于更强的优化。250~300适用于非常低质量的视频。对于高质量视频,设置为≤200。
- `steps` 如果想减少步数,请先将solver_mode改为“normal”,然后减少步数。“fast”模式的步数是固定的(15步)。
代码在执行过程中会自动从 Hugging Face 下载所需的模型。
运行日志通常如下:
```shell
/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
@torch.library.impl_abstract("xformers_flash::flash_fwd")
/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
@torch.library.impl_abstract("xformers_flash::flash_bwd")
2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt
/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
2024-08-20 13:25:37,486 - video_to_video - INFO - Build encoder with FrozenOpenCLIPEmbedder
/share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
load_dict = torch.load(cfg.model_path, map_location='cpu')
2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status
2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion
2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4
2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere
2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49
2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0
2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0
2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720)
2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982)
2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250
2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8
2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982])
/share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with amp.autocast(enabled=True):
2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0
2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1
2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2
2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3
2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4
2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5
2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6
2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7
2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8
2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9
2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10
2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11
2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12
2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13
2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished.
```
使用A100单卡运行,对于每个CogVideoX生成的6秒视频,按照默认配置,会消耗60G显存,并用时40-50分钟。
================================================
FILE: CogVideo/weights/put weights here.txt
================================================
================================================
FILE: README.md
================================================
## ___***3DTrajMaster: Mastering 3D Trajectory for Multi-Entity Motion in Video Generation***___

**[Xiao Fu1](https://fuxiao0719.github.io/),
[Xian Liu1](https://alvinliu0.github.io/),
[Xintao Wang2 ✉](https://xinntao.github.io/),
[Sida Peng3](https://pengsida.net/),
[Menghan Xia2](https://menghanxia.github.io/),
[Xiaoyu Shi2](https://xiaoyushi97.github.io/),
[Ziyang Yuan2](https://scholar.google.ru/citations?user=fWxWEzsAAAAJ&hl=en),
[Pengfei Wan2](https://scholar.google.com/citations?user=P6MraaYAAAAJ&hl=en),
[Di Zhang2](https://openreview.net/profile?id=~Di_ZHANG3),
[Dahua Lin1✉](http://dahua.site/)**
1The Chinese University of Hong Kong
2Kuaishou Technology
3Zhejiang University
✉: Corresponding Authors
**ICLR 2025**
## 🌟 Introduction
🔥 3DTrajMaster controls **one or multiple entity motions in 3D space with entity-specific 3D trajectories** for text-to-video (T2V) generation. It has the following features:
- **6 Degrees of Freedom (DoF)**: control 3D entity location and orientation.
- **Diverse Entities**: human, animal, robot, car, even abstract fire, breeze, etc.
- **Diverse Background**: city, forest, desert, gym, sunset beach, glacier, hall, night city, etc.
- **Complex 3D trajectories**: 3D occlusion, rotating in place, 180°/continuous 90° turnings, etc.
- **Fine-grained Entity Prompt**: change human hair, clothing, gender, figure size, accessory, etc.
https://github.com/user-attachments/assets/efe1870f-4168-4aff-98b8-dbd9e3802928
🔥 **Release News**
- `[2025/01/23]` 3DTrajMaster is accepted to ICLR 2025.
- `[2025/01/22]` Release inference and training codes based on CogVideoX-5B.
- `[2024/12/10]` Release [paper](https://arxiv.org/pdf/2412.07759), [project page](http://fuxiao0719.github.io/projects/3dtrajmaster), [dataset](https://huggingface.co/datasets/KwaiVGI/360Motion-Dataset), and [eval code](https://github.com/KwaiVGI/3DTrajMaster).
## ⚙️ Quick Start
> **(1) Access to Our Internal Video Model**
As per company policy, we may not release the proprietary trained model at this time. However, if you wish to access our internal model, please submit your request via (1) [a shared document](https://docs.google.com/spreadsheets/d/1HL96IS33fyzrDeXTt3hJ80ZsnfRBzDoKh8wparoBAGI/edit?pli=1&gid=0#gid=0) or (2) directly via email (`lemonaddie0909@gmail.com`, recommended); we will respond to requests with the generated video as quickly as possible.
Please ensure your request includes the following:
1. Entity prompts (1–3, with a maximum of 42 tokens, approximately 20 words per entity)
2. Location prompt
3. Trajectory template (you can refer to the trajectory template in our released 360°-Motion Dataset, or simply describe new ones via text)
> **(2) Access to Publicly Available Codebase**
We open-source a model based on CogVideoX-5B. Below is a comparison between CogVideoX and our internal video model as of 2025.01.15.
https://github.com/user-attachments/assets/a49e46d3-92d0-42ec-a89f-a9d43919f620
#### Inference
1. **[Environment Set Up]** Our environment setup is identical to [CogVideoX](https://github.com/THUDM/CogVideo). You can refer to their configuration to complete the environment setup.
```bash
conda create -n 3dtrajmaster python=3.10
conda activate 3dtrajmaster
pip install -r requirements.txt
```
2. **[Download Weights and Dataset]** Download the pretrained checkpoints (CogVideoX-5B, LoRA, and injector) from [here](https://huggingface.co/KwaiVGI/3DTrajMaster) and place them in the `CogVideo/weights` directory. Then, download the dataset from [here](https://huggingface.co/datasets/KwaiVGI/360Motion-Dataset). Please note that in both training stages, we use only 11 camera poses and exclude the last camera pose as the novel pose setting.
3. **[Inference on Generalizable Prompts]** Change root path to `CogVideo/inference`. Note a higher LoRA scale and more annealed steps can improve accuracy in prompt generation but may result in lower visual quality. You can modify `test_sets.json` to add novel entity&location prompts. For entity input, you can use GPT to enhance the description to an appropriate length, such as "Generate a detailed description of approximately 20 words".
```bash
python 3dtrajmaster_inference.py \
--model_path ../weights/cogvideox-5b \
--ckpt_path ../weights/injector \
--lora_path ../weights/lora \
--lora_scale 0.6 \
--annealed_sample_step 20 \
--seed 24 \
--output_path output_example
```
| Argument | Description |
|-------------------------|-------------|
| `--lora_scale` | LoRA alpha weight. Options: 0-1, float. Default: 0.6. |
| `--annealed_sample_step` | annealed sampling steps during inference. Options: 0-50, int. Default: 20. |
| Generalizable Robustness | prompt entity number: 1>2>3 |
| Entity Length | 15-24 words, ~24-40 tokens after T5 embeddings |
The following code snapshot showcases the core components of 3DTrajMaster, namely the plug-and-play 3D-motion grounded object injector.
```python
# 1. norm & modulate
norm_hidden_states, norm_empty_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(hidden_states, empty_encoder_hidden_states, temb)
bz, N_visual, dim = norm_hidden_states.shape
max_entity_num = 3
_, entity_num, num_frames, _ = pose_embeds.shape
# 2. pair-wise fusion of trajectory and entity
attn_input = self.attn_null_feature.repeat(bz, max_entity_num, 50, num_frames, 1)
pose_embeds = self.pose_fuse_layer(pose_embeds)
attn_input[:,:entity_num,:,:,:] = pose_embeds.unsqueeze(-3) + prompt_entities_embeds.unsqueeze(-2)
attn_input = torch.cat((
rearrange(norm_hidden_states, "b (n t) d -> b n t d",n=num_frames),
rearrange(attn_input, "b n t f d -> b f (n t) d")),
dim=2
).flatten(1,2)
# 3. gated self-attention
attn_hidden_states, attn_encoder_hidden_states = self.attn1_injector(
hidden_states=attn_input,
encoder_hidden_states=norm_empty_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
attn_hidden_states = attn_hidden_states[:,:N_visual,:]
hidden_states = hidden_states + gate_msa * attn_hidden_states
```
#### Training
1. Change root path to `CogVideo/finetune`. First, train the LoRA module to fit the synthetic data domain.
```bash
bash finetune_single_rank_lora.sh
```
2. Then, train injector module to learn the entity motion controller. Here we set `--block_interval` to 2 to insert the injector every 2 transformer blocks. You can increase this value for a lighter model, but note that it will require a longer training time. For the initial fine-tuning stage, use `--finetune_init`. If resuming from a pre-trained checkpoint, omit `--finetune_init` and specify `--resume_from_checkpoint $TRANSFORMER_PATH` instead. Note that in both training stages, we use only 11 camera poses and exclude the last camera pose as the novel pose setting.
```bash
bash finetune_single_rank_injector.sh
```
## 📦 360°-Motion Dataset ([Download 🤗](https://huggingface.co/datasets/KwaiVGI/360Motion-Dataset))
```
├── 360Motion-Dataset Video Number Cam-Obj Distance (m)
├── 480_720/384_672
├── Desert (desert) 18,000 [3.06, 13.39]
├── location_data.json
├── HDRI
├── loc1 (snowy street) 3,600 [3.43, 13.02]
├── loc2 (park) 3,600 [4.16, 12.22]
├── loc3 (indoor open space) 3,600 [3.62, 12.79]
├── loc11 (gymnastics room) 3,600 [4.06, 12.32]
├── loc13 (autumn forest) 3,600 [4.49, 11.92]
├── location_data.json
├── RefPic
├── CharacterInfo.json
├── Hemi12_transforms.json
```
> **(1) Released Dataset Information (V1.0.0)**
| Argument | Description |Argument | Description |
|-------------------------|-------------|-------------------------|-------------|
| **Video Resolution** | (1) 480×720 (2) 384×672 | **Frames/Duration/FPS** | 99/3.3s/30 |
| **UE Scenes** | 6 (1 desert+5 HDRIs) | **Video Samples** | (1) 36,000 (2) 36,000 |
| **Camera Intrinsics (fx,fy)** | (1) 1060.606 (2) 989.899 | **Sensor Width/Height (mm)** | (1) 23.76/15.84 (2) 23.76/13.365 |
| **Hemi12_transforms.json** | 12 surrounding cameras | **CharacterInfo.json** | entity prompts |
| **RefPic** | 50 animals | **1/2/3 Trajectory Templates** | 36/60/35 (121 in total) |
| **{D/N}_{locX}** | {Day/Night}_{LocationX} | **{C}_ {XX}_{35mm}** | {Close-Up Shot}_{Cam. Index(1-12)} _{Focal Length}|
**Note that** the resolution of 384×672 refers to our internal video diffusion resolution. In fact, we render the video at a resolution of 378×672 (aspect ratio 9:16), with a 3-pixel black border added to both the top and bottom.
> **(2) Difference with the Dataset to Train on Our Internal Video Diffusion Model**
The release of the full dataset regarding more entities and UE scenes is still under our internal license check.
| Argument | Released Dataset | Our Internal Dataset|
|-------------------------|-------------|-------------------------|
| **Video Resolution** | (1) 480×720 (2) 384×672 | 384×672 |
| **Entities** | 50 (all animals) | 70 (20 humans+50 animals) |
| **Video Samples** | (1) 36,000 (2) 36,000 | 54,000 |
| **Scenes** | 6 | 9 (+city, forest, asian town) |
| **Trajectory Templates** | 121 | 96 |
> **(3) Load Dataset Sample**
1. Change root path to `dataset`. We provide a script to load our dataset (video & entity & pose sequence) as follows. It will generate the sampled video for visualization in the same folder path.
```bash
python load_dataset.py
```
2. Visualize the 6DoF pose sequence via Open3D as follows:
```bash
python vis_trajectory.py
```
After running the visualization script, you will get an interactive window like this. Note that we have converted the right-handed coordinate system (Open3D) to the left-handed coordinate system in order to better align with the motion trajectory of the video:
## 🚀 Benchmark Evaluation (Reproduce Paper Results)
```
├── eval
├── GVHMR
├── common_metrics_on_video_quality
```
> **(1) Evaluation on 3D Trajectory**
1. Change root path to `eval/GVHMR`. Then follow the [GVHMR](https://github.com/zju3dv/GVHMR/blob/main/docs/INSTALL.md) installation guide to prepare the setup (we recommend using a separate Conda environment to avoid package conflicts). Our evaluation input is available [here](https://drive.google.com/file/d/1DLWioJtvv9u4snybu5DrteVWma12JXq3/view?usp=drive_link). Please note that the 3D trajectories have been downsampled from 77 frames to 20 frames to match the RGB latent space of the 3D VAE.
2. Download the [inference videos](https://drive.google.com/file/d/1jMH2-ZC0ZBgtqej5Sp-E5ebBIX7mk3Xz/view?usp=drive_link) generated by our internal video diffusion model and corresponding [evaluation GT poses](https://drive.google.com/file/d/1iFcPSlcKb_rDNJ85UPoThdl22BqR2Xgh/view?usp=drive_link) by using this command (you can check the 3D evaluated trajectory via our provided visualization script):
```bash
bash download_eval_pose.sh
```
3. Estimation of human poses on evaluation sets:
```bash
python tools/demo/demo_folder.py -f eval_sets -d outputs/eval_sets_gvhmr -s
```
4. Evaluation of all human samples (note that you need to convert between the left-handed and right-handed coordinate systems):
```bash
python tools/eval_pose.py -f outputs/eval_sets_gvhmr
```
> **(2) Evaluation on Visual Quality**
1. Change root path to `eval/common_metrics_on_video_quality`. Then download [fvd](https://drive.google.com/file/d/1U2hd6qvwKLfp7c8yGgcTqdqrP_lKJElB/view?usp=drive_link), [inference videos](https://drive.google.com/file/d/1jMH2-ZC0ZBgtqej5Sp-E5ebBIX7mk3Xz/view?usp=drive_link) and [base T2V inference videos](https://drive.google.com/file/d/1kfdCDA5koYh9g3IkCCHb4XPch2CJAwek/view?usp=drive_link) using the download script:
```bash
bash download_eval_visual.sh
```
2. Evaluation of FVD, FID, and CLIP-SIM metrics.
```bash
pip install pytorch-fid clip
bash eval_visual.sh
```
## 📚 Related Work
- [MotionCtrl](https://github.com/TencentARC/MotionCtrl): the first to control 3D camera motion and 2D object motion in video generation
- [TC4D](https://sherwinbahmani.github.io/tc4d/): compositional text-to-4D scene generation with 3D trajectory conditions
- [Tora](https://ali-videoai.github.io/tora_video/): control 2D motions in trajectory-oriented diffusion transformer for video generation
- [SynCamMaster](https://jianhongbai.github.io/SynCamMaster/): multi-camera synchronized video generation from diverse viewpoints
- [StyleMaster](https://zixuan-ye.github.io/stylemaster): enable artistic video generation and translation with reference style image
####
## 🔗 Citation
If you find this work helpful, please consider citing:
```bibtex
@inproceedings{fu20243dtrajmaster,
title={3DTrajMaster: Mastering 3D Trajectory for Multi-Entity Motion in Video Generation},
author={Fu, Xiao and Liu, Xian and Wang, Xintao and Peng, Sida and Xia, Menghan and Shi, Xiaoyu and Yuan, Ziyang and Wan, Pengfei and Zhang, Di and Lin, Dahua},
booktitle={ICLR},
year={2025}
}
```
================================================
FILE: dataset/load_dataset.py
================================================
# Copyright 2024 Xiao Fu, CUHK, Kuaishou Tech. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# More information about the method can be found at http://fuxiao0719.github.io/projects/3dtrajmaster
# --------------------------------------------------------------------------
import os
import numpy as np
import json
import torch
import random
import cv2
import decord
from einops import rearrange
from utils import *
# --------------------------------------------------------------------------
# 1. Load scene information
# --------------------------------------------------------------------------
# Root directory of the 360Motion-Dataset; replace 'root_path' with the real location.
dataset_root = 'root_path/360Motion-Dataset'
video_res = '480_720'
# Filled below with one (video_res, scene, video_name, location, locations_info) tuple per video.
video_names = []
scenes = ['Desert', 'HDRI']
# Location text used when building the prompt. 'Desert' maps to a single string,
# while 'HDRI' maps a location id (second '_'-separated token of the video name)
# to its description.
scene_location_pair = {
    'Desert' : 'desert',
    'HDRI' :
        {
            'loc1' : 'snowy street',
            'loc2' : 'park',
            'loc3' : 'indoor open space',
            'loc11' : 'gymnastics room',
            'loc13' : 'autumn forest',
        }
}
for scene in scenes:
    video_path = os.path.join(dataset_root, video_res, scene)
    locations_path = os.path.join(video_path, "location_data.json")
    with open(locations_path, 'r') as f:
        locations = json.load(f)
    # Index the location records by their 'name' field for direct lookup.
    locations_info = {entry['name']: entry for entry in locations}
    # os.listdir returns entries in arbitrary order; sort them so that indexing
    # into video_names selects the same video on every machine.
    for video_name in sorted(os.listdir(video_path)):
        # Only entries whose name ends with 'Hemi12_1' are treated as videos.
        if not video_name.endswith('Hemi12_1'):
            continue
        if scene != 'HDRI':
            location = scene_location_pair[scene]
        else:
            location = scene_location_pair['HDRI'][video_name.split('_')[1]]
        video_names.append((video_res, scene, video_name, location, locations_info))
# --------------------------------------------------------------------------
# 2. Load 12 surrounding cameras
# --------------------------------------------------------------------------
cam_num = 12
max_objs_num = 3
length = len(video_names)
# Map each character's integer index to its English caption.
captions_path = os.path.join(dataset_root, "CharacterInfo.json")
with open(captions_path, 'r') as f:
    captions = json.load(f)['CharacterInfo']
captions_info = {int(entry['index']): entry['eng'] for entry in captions}
# Parse every entry whose key contains "C_" into a pose matrix and stack them.
cams_path = os.path.join(dataset_root, "Hemi12_transforms.json")
with open(cams_path, 'r') as f:
    cams_info = json.load(f)
cam_poses = np.stack([parse_matrix(cams_info[key]) for key in cams_info if "C_" in key])
# Transpose each matrix (swap the last two axes).
# NOTE(review): presumably converts a row-vector layout (translation in the last
# row) into a column-vector layout — confirm against parse_matrix.
cam_poses = np.transpose(cam_poses, (0,2,1))
# Reorder the columns of every matrix to (1, 2, 0, 3); fancy indexing returns a
# copy, so the in-place division below does not touch the stacked input.
# NOTE(review): looks like an axis-convention change — confirm the target convention.
cam_poses = cam_poses[:,:,[1,2,0,3]]
# Scale the translation column by 1/100.
# NOTE(review): presumably centimetres -> metres — confirm the source units.
cam_poses[:,:3,3] /= 100.
sample_n_frames = 49
# --------------------------------------------------------------------------
# 3. Load a sample of video & object poses
# --------------------------------------------------------------------------
# Pick one video entry and load its per-frame object pose file.
(video_res, scene, video_name, location, locations_info) = video_names[20]
with open(os.path.join(dataset_root, video_res, scene, video_name, video_name+'.json'), 'r') as f:
    objs_file = json.load(f)
# Number of objects is taken from the first frame's record list.
objs_num = len(objs_file['0'])
# NOTE(review): random.randint is inclusive on both ends, so this draws from
# 1..cam_num-1 and never selects the last camera — confirm this is intended.
video_index = random.randint(1, cam_num-1)
location_name = video_name.split('_')[1]
location_info = locations_info[location_name]
# Camera indices in file names are 1-based; cam_poses is 0-based.
cam_pose = cam_poses[video_index-1]
obj_transl = location_info['coordinates']['CameraTarget']['position']
video_caption_list = []
obj_poses_list = []
# The camera pose is the same for every object, so invert it once.
cam_pose_inv = np.linalg.inv(cam_pose)
for obj_idx in range(objs_num):
    obj_name_index = objs_file['0'][obj_idx]['index']
    video_caption = captions_info[obj_name_index]
    # Normalise the caption: drop one leading space, one trailing period, lowercase.
    if video_caption.startswith(" "):
        video_caption = video_caption[1:]
    if video_caption.endswith("."):
        video_caption = video_caption[:-1]
    video_caption = video_caption.lower()
    video_caption_list.append(video_caption)
    obj_poses = load_sceneposes(objs_file, obj_idx, obj_transl)
    # Left-multiply the object poses by the inverse of the selected camera pose.
    obj_poses = cam_pose_inv @ obj_poses
    obj_poses_list.append(obj_poses)
# Build "<a> and <b> ... is/are moving in the <location>"; stays empty with no objects.
prompt = ''
if video_caption_list:
    verb = ' is moving in the ' if objs_num == 1 else ' are moving in the '
    prompt = ' and '.join(video_caption_list) + verb + location
obj_poses_all = torch.from_numpy(np.array(obj_poses_list))
# Choose a random temporal window and sample sample_n_frames evenly spaced indices in it.
total_frames = 99
current_sample_stride = 1.75
cropped_length = int(sample_n_frames * current_sample_stride)
start_frame_ind = random.randint(10, max(10, total_frames - cropped_length - 1))
end_frame_ind = min(start_frame_ind + cropped_length, total_frames)
frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, sample_n_frames, dtype=int)
video_frames_path = os.path.join(dataset_root, video_res, scene, video_name, 'videos', video_name+ f'_C_{video_index:02d}_35mm.mp4')
# OpenCV is used only to read the frame size; release the handle as soon as it is known.
cap = cv2.VideoCapture(video_frames_path)
try:
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
finally:
    cap.release()
# Decode on CPU device 0 at the video's native resolution.
ctx = decord.cpu(0)
reader = decord.VideoReader(video_frames_path, ctx=ctx, height=height, width=width)
assert len(reader) in (total_frames, total_frames+1)
frame_indexes = list(range(total_frames))
# Decode once, then convert: the batch object exposes either .asnumpy() or
# .numpy(), so fall back only when the first method is missing.
frame_batch = reader.get_batch(frame_indexes)
try:
    video_chunk = frame_batch.asnumpy()
except AttributeError:
    video_chunk = frame_batch.numpy()
# Keep only the sampled frames, scale to [0, 1] and move channels ahead of height/width.
pixel_values = video_chunk[frame_indices]
pixel_values = rearrange(torch.from_numpy(pixel_values) / 255.0, "f h w c -> f c h w")
save_video = True
if save_video:
    # Round before casting: astype(uint8) truncates, so a value that lands just
    # below an integer after the /255 * 255 round-trip would lose one level.
    video_data = (pixel_values.cpu().to(torch.float32).numpy() * 255).round().astype(np.uint8)
    video_data = rearrange(video_data, "f c h w -> f h w c")
    # NOTE(review): the third argument is presumably the output frame rate — confirm in utils.
    save_images2video(video_data, video_name, 12)
================================================
FILE: dataset/traj_vis/D_loc1_61_t3n13_003d_Hemi12_1.json
================================================
{"0":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.8 -22935.3 147.103 1] "},{"index":141,"matrix":"[-0.613263 -0.789878 0 0] [0.789878 -0.613263 -0 0] [0 0 1 0] [18849.5 -22481.6 147.103 1] "},{"index":2,"matrix":"[0.951994 0.306116 0 0] [-0.306116 0.951994 0 0] [0 -0 1 0] [19030.7 -22801.6 147.103 1] "}],"1":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.8 -22931.3 147.103 1] "},{"index":141,"matrix":"[-0.341759 -0.939788 0 0] [0.939788 -0.341759 -0 0] [0 0 1 0] [18843.7 -22489 147.103 1] "},{"index":2,"matrix":"[0.777737 0.628589 0 0] [-0.628589 0.777737 0 0] [0 -0 1 0] [19037.2 -22799.5 147.103 1] "}],"2":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.7 -22927.4 147.103 1] "},{"index":141,"matrix":"[-0.317189 -0.948362 0 0] [0.948362 -0.317189 -0 0] [0 0 1 0] [18840.5 -22497.9 147.103 1] "},{"index":2,"matrix":"[0.734696 0.678396 0 0] [-0.678396 0.734696 0 0] [0 -0 1 0] [19042.5 -22795.2 147.103 1] "}],"3":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.7 -22923.5 147.103 1] "},{"index":141,"matrix":"[-0.304602 -0.95248 0 0] [0.95248 -0.304602 -0 0] [0 0 1 0] [18837.5 -22506.8 147.103 1] "},{"index":2,"matrix":"[0.717886 0.696161 0 0] [-0.696161 0.717886 0 0] [0 -0 1 0] [19047.5 -22790.6 147.103 1] "}],"4":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.7 -22919.6 147.103 1] "},{"index":141,"matrix":"[-0.291852 -0.956463 0 0] [0.956463 -0.291852 -0 0] [0 0 1 0] [18834.6 -22515.8 147.103 1] "},{"index":2,"matrix":"[0.703315 0.710879 0 0] [-0.710879 0.703315 0 0] [0 -0 1 0] [19052.4 -22785.8 147.103 1] "}],"5":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.7 -22915.7 147.103 1] "},{"index":141,"matrix":"[-0.278641 -0.960395 0 0] [0.960395 -0.278641 -0 0] [0 0 1 0] 
[18831.9 -22524.8 147.103 1] "},{"index":2,"matrix":"[0.688207 0.725515 0 0] [-0.725515 0.688207 0 0] [0 -0 1 0] [19057.2 -22781 147.103 1] "}],"6":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.7 -22911.7 147.103 1] "},{"index":141,"matrix":"[-0.264988 -0.964252 0 0] [0.964252 -0.264988 -0 0] [0 0 1 0] [18829.3 -22533.8 147.103 1] "},{"index":2,"matrix":"[0.672214 0.740357 0 0] [-0.740357 0.672214 0 0] [0 -0 1 0] [19061.9 -22776.1 147.103 1] "}],"7":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.6 -22907.8 147.103 1] "},{"index":141,"matrix":"[-0.250908 -0.968011 0 0] [0.968011 -0.250908 -0 0] [0 0 1 0] [18826.8 -22542.9 147.103 1] "},{"index":2,"matrix":"[0.655244 0.755417 0 0] [-0.755417 0.655244 0 0] [0 -0 1 0] [19066.4 -22771 147.103 1] "}],"8":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.6 -22903.9 147.103 1] "},{"index":141,"matrix":"[-0.23641 -0.971653 0 0] [0.971653 -0.23641 -0 0] [0 0 1 0] [18824.4 -22552 147.103 1] "},{"index":2,"matrix":"[0.637232 0.770672 0 0] [-0.770672 0.637232 0 0] [0 -0 1 0] [19070.9 -22765.9 147.103 1] "}],"9":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.6 -22900 147.103 1] "},{"index":141,"matrix":"[-0.221488 -0.975163 0 0] [0.975163 -0.221488 -0 0] [0 0 1 0] [18822.2 -22561.1 147.103 1] "},{"index":2,"matrix":"[0.618071 0.786122 0 0] [-0.786122 0.618071 0 0] [0 -0 1 0] [19075.2 -22760.7 147.103 1] "}],"10":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.6 -22896.1 147.103 1] "},{"index":141,"matrix":"[-0.206134 -0.978524 0 0] [0.978524 -0.206134 -0 0] [0 0 1 0] [18820.1 -22570.3 147.103 1] "},{"index":2,"matrix":"[0.597642 0.801763 0 0] [-0.801763 0.597642 0 0] [0 -0 1 0] [19079.4 -22755.3 147.103 1] "}],"11":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 
-0.00531441 0 0] [0 -0 1 0] [18999.5 -22892.1 147.103 1] "},{"index":141,"matrix":"[-0.190335 -0.981719 0 0] [0.981719 -0.190335 -0 0] [0 0 1 0] [18818.2 -22579.5 147.103 1] "},{"index":2,"matrix":"[0.575785 0.817601 0 0] [-0.817601 0.575785 0 0] [0 -0 1 0] [19083.5 -22749.9 147.103 1] "}],"12":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.5 -22888.2 147.103 1] "},{"index":141,"matrix":"[-0.174056 -0.984736 0 0] [0.984736 -0.174056 -0 0] [0 0 1 0] [18816.4 -22588.7 147.103 1] "},{"index":2,"matrix":"[0.552297 0.833647 0 0] [-0.833647 0.552297 0 0] [0 -0 1 0] [19087.4 -22744.3 147.103 1] "}],"13":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.5 -22884.3 147.103 1] "},{"index":141,"matrix":"[-0.157284 -0.987553 0 0] [0.987553 -0.157284 -0 0] [0 0 1 0] [18814.7 -22598 147.103 1] "},{"index":2,"matrix":"[0.526901 0.849927 0 0] [-0.849927 0.526901 0 0] [0 -0 1 0] [19091.2 -22738.6 147.103 1] "}],"14":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.5 -22880.4 147.103 1] "},{"index":141,"matrix":"[-0.139988 -0.990153 0 0] [0.990153 -0.139988 -0 0] [0 0 1 0] [18813.3 -22607.3 147.103 1] "},{"index":2,"matrix":"[0.499271 0.866446 0 0] [-0.866446 0.499271 0 0] [0 -0 1 0] [19094.8 -22732.8 147.103 1] "}],"15":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.5 -22876.5 147.103 1] "},{"index":141,"matrix":"[-0.122116 -0.992516 0 0] [0.992516 -0.122116 -0 0] [0 0 1 0] [18811.9 -22616.6 147.103 1] "},{"index":2,"matrix":"[0.468991 0.883203 0 0] [-0.883203 0.468991 0 0] [0 -0 1 0] [19098.2 -22726.9 147.103 1] "}],"16":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.4 -22872.5 147.103 1] "},{"index":141,"matrix":"[-0.103621 -0.994617 0 0] [0.994617 -0.103621 -0 0] [0 0 1 0] [18810.8 -22625.9 147.103 1] "},{"index":2,"matrix":"[0.435515 
0.900182 0 0] [-0.900182 0.435515 0 0] [0 -0 1 0] [19101.4 -22720.9 147.103 1] "}],"17":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.4 -22868.6 147.103 1] "},{"index":141,"matrix":"[-0.0844501 -0.996428 0 0] [0.996428 -0.0844501 -0 0] [0 0 1 0] [18809.8 -22635.3 147.103 1] "},{"index":2,"matrix":"[0.398234 0.917284 0 0] [-0.917284 0.398234 0 0] [0 -0 1 0] [19104.3 -22714.8 147.103 1] "}],"18":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.4 -22864.7 147.103 1] "},{"index":141,"matrix":"[-0.0645412 -0.997915 0 0] [0.997915 -0.0645412 -0 0] [0 0 1 0] [18809 -22644.7 147.103 1] "},{"index":2,"matrix":"[0.356564 0.934271 0 0] [-0.934271 0.356564 0 0] [0 -0 1 0] [19107 -22708.6 147.103 1] "}],"19":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.4 -22860.8 147.103 1] "},{"index":141,"matrix":"[-0.0438255 -0.999039 0 0] [0.999039 -0.0438255 -0 0] [0 0 1 0] [18808.4 -22654.1 147.103 1] "},{"index":2,"matrix":"[0.310051 0.95072 0 0] [-0.95072 0.310051 0 0] [0 -0 1 0] [19109.5 -22702.2 147.103 1] "}],"20":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.4 -22856.9 147.103 1] "},{"index":141,"matrix":"[-0.0222269 -0.999753 0 0] [0.999753 -0.0222269 -0 0] [0 0 1 0] [18808 -22663.4 147.103 1] "},{"index":2,"matrix":"[0.258735 0.965948 0 0] [-0.965948 0.258735 0 0] [0 -0 1 0] [19111.6 -22695.7 147.103 1] "}],"21":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.3 -22852.9 147.103 1] "},{"index":141,"matrix":"[0.00032029 -1 0 0] [1 0.00032029 -0 0] [0 0 1 0] [18807.8 -22672.9 147.103 1] "},{"index":2,"matrix":"[0.200412 0.979712 0 0] [-0.979712 0.200412 0 0] [0 -0 1 0] [19113.3 -22689.2 147.103 1] "}],"22":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.3 -22849 147.103 1] 
"},{"index":141,"matrix":"[0.0239348 -0.999714 0 0] [0.999714 0.0239348 -0 0] [0 0 1 0] [18807.8 -22682.3 147.103 1] "},{"index":2,"matrix":"[0.122083 0.99252 0 0] [-0.99252 0.122083 0 0] [0 -0 1 0] [19114.7 -22682.5 147.103 1] "}],"23":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.3 -22845.1 147.103 1] "},{"index":141,"matrix":"[0.0487004 -0.998813 0 0] [0.998813 0.0487004 -0 0] [0 0 1 0] [18808 -22691.7 147.103 1] "},{"index":2,"matrix":"[0.0327411 0.999464 0 0] [-0.999464 0.0327411 0 0] [0 -0 1 0] [19115.5 -22675.7 147.103 1] "}],"24":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.3 -22841.2 147.103 1] "},{"index":141,"matrix":"[0.0747293 -0.997204 0 0] [0.997204 0.0747293 -0 0] [0 0 1 0] [18808.5 -22701.1 147.103 1] "},{"index":2,"matrix":"[-0.0557011 0.998447 0 0] [-0.998447 -0.0557011 0 0] [0 -0 1 0] [19115.7 -22668.9 147.103 1] "}],"25":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.3 -22837.3 147.103 1] "},{"index":141,"matrix":"[0.102095 -0.994775 0 0] [0.994775 0.102095 -0 0] [0 0 1 0] [18809.2 -22710.4 147.103 1] "},{"index":2,"matrix":"[-0.13492 0.990856 0 0] [-0.990856 -0.13492 0 0] [0 -0 1 0] [19115.4 -22662.2 147.103 1] "}],"26":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.2 -22833.3 147.103 1] "},{"index":141,"matrix":"[0.13093 -0.991392 0 0] [0.991392 0.13093 -0 0] [0 0 1 0] [18810.2 -22719.8 147.103 1] "},{"index":2,"matrix":"[-0.202816 0.979217 0 0] [-0.979217 -0.202816 0 0] [0 -0 1 0] [19114.4 -22655.4 147.103 1] "}],"27":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.2 -22829.4 147.103 1] "},{"index":141,"matrix":"[0.161334 -0.9869 0 0] [0.9869 0.161334 -0 0] [0 0 1 0] [18811.4 -22729.1 147.103 1] "},{"index":2,"matrix":"[-0.261224 0.965278 0 0] [-0.965278 -0.261224 0 0] [0 -0 1 0] [19113.1 
-22648.7 147.103 1] "}],"28":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.2 -22825.5 147.103 1] "},{"index":141,"matrix":"[0.193408 -0.981118 0 0] [0.981118 0.193408 -0 0] [0 0 1 0] [18812.9 -22738.4 147.103 1] "},{"index":2,"matrix":"[-0.312923 0.949779 0 0] [-0.949779 -0.312923 0 0] [0 -0 1 0] [19111.3 -22642.2 147.103 1] "}],"29":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.2 -22821.6 147.103 1] "},{"index":141,"matrix":"[0.227223 -0.973843 0 0] [0.973843 0.227223 -0 0] [0 0 1 0] [18814.7 -22747.6 147.103 1] "},{"index":2,"matrix":"[-0.360193 0.932878 0 0] [-0.932878 -0.360193 0 0] [0 -0 1 0] [19109.2 -22635.7 147.103 1] "}],"30":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.2 -22817.7 147.103 1] "},{"index":141,"matrix":"[0.262798 -0.964851 0 0] [0.964851 0.262798 -0 0] [0 0 1 0] [18816.9 -22756.8 147.103 1] "},{"index":2,"matrix":"[-0.404567 0.914508 0 0] [-0.914508 -0.404567 0 0] [0 -0 1 0] [19106.7 -22629.4 147.103 1] "}],"31":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.1 -22813.7 147.103 1] "},{"index":141,"matrix":"[0.300213 -0.953872 0 0] [0.953872 0.300213 -0 0] [0 0 1 0] [18819.3 -22765.9 147.103 1] "},{"index":2,"matrix":"[-0.447013 0.894527 0 0] [-0.894527 -0.447013 0 0] [0 -0 1 0] [19104 -22623.2 147.103 1] "}],"32":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.1 -22809.8 147.103 1] "},{"index":141,"matrix":"[0.339367 -0.940654 0 0] [0.940654 0.339367 -0 0] [0 0 1 0] [18822.2 -22774.9 147.103 1] "},{"index":2,"matrix":"[-0.488136 0.872768 0 0] [-0.872768 -0.488136 0 0] [0 -0 1 0] [19100.9 -22617.1 147.103 1] "}],"33":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.1 -22805.9 147.103 1] "},{"index":141,"matrix":"[0.380199 -0.924905 0 0] [0.924905 
0.380199 -0 0] [0 0 1 0] [18825.3 -22783.7 147.103 1] "},{"index":2,"matrix":"[-0.52829 0.849064 0 0] [-0.849064 -0.52829 0 0] [0 -0 1 0] [19097.6 -22611.1 147.103 1] "}],"34":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999.1 -22802 147.103 1] "},{"index":141,"matrix":"[0.422527 -0.90635 0 0] [0.90635 0.422527 -0 0] [0 0 1 0] [18828.9 -22792.4 147.103 1] "},{"index":2,"matrix":"[-0.56771 0.823229 0 0] [-0.823229 -0.56771 0 0] [0 -0 1 0] [19094 -22605.3 147.103 1] "}],"35":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999 -22798.1 147.103 1] "},{"index":141,"matrix":"[0.466092 -0.884736 0 0] [0.884736 0.466092 -0 0] [0 0 1 0] [18832.9 -22800.9 147.103 1] "},{"index":2,"matrix":"[-0.606509 0.795076 0 0] [-0.795076 -0.606509 0 0] [0 -0 1 0] [19090.1 -22599.7 147.103 1] "}],"36":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999 -22794.1 147.103 1] "},{"index":141,"matrix":"[0.510543 -0.859852 0 0] [0.859852 0.510543 -0 0] [0 0 1 0] [18837.3 -22809.3 147.103 1] "},{"index":2,"matrix":"[-0.64471 0.764428 0 0] [-0.764428 -0.64471 0 0] [0 -0 1 0] [19086 -22594.3 147.103 1] "}],"37":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999 -22790.2 147.103 1] "},{"index":141,"matrix":"[0.55544 -0.831557 0 0] [0.831557 0.55544 -0 0] [0 0 1 0] [18842.1 -22817.3 147.103 1] "},{"index":2,"matrix":"[-0.682253 0.731116 0 0] [-0.731116 -0.682253 0 0] [0 -0 1 0] [19081.6 -22589.1 147.103 1] "}],"38":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999 -22786.3 147.103 1] "},{"index":141,"matrix":"[0.60026 -0.799805 0 0] [0.799805 0.60026 -0 0] [0 0 1 0] [18847.3 -22825.2 147.103 1] "},{"index":2,"matrix":"[-0.719002 0.695008 0 0] [-0.695008 -0.719002 0 0] [0 -0 1 0] [19077 -22584.2 147.103 1] "}],"39":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] 
[-0.999986 -0.00531441 0 0] [0 -0 1 0] [18999 -22782.4 147.103 1] "},{"index":141,"matrix":"[0.644427 -0.764665 0 0] [0.764665 0.644427 -0 0] [0 0 1 0] [18853 -22832.7 147.103 1] "},{"index":2,"matrix":"[-0.754745 0.656018 0 0] [-0.656018 -0.754745 0 0] [0 -0 1 0] [19072.1 -22579.4 147.103 1] "}],"40":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.9 -22778.5 147.103 1] "},{"index":141,"matrix":"[0.68732 -0.726355 0 0] [0.726355 0.68732 -0 0] [0 0 1 0] [18859 -22839.9 147.103 1] "},{"index":2,"matrix":"[-0.789206 0.614128 0 0] [-0.614128 -0.789206 0 0] [0 -0 1 0] [19066.9 -22575 147.103 1] "}],"41":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.9 -22774.5 147.103 1] "},{"index":141,"matrix":"[0.728345 -0.685211 0 0] [0.685211 0.728345 -0 0] [0 0 1 0] [18865.5 -22846.7 147.103 1] "},{"index":2,"matrix":"[-0.822049 0.569417 0 0] [-0.569417 -0.822049 0 0] [0 -0 1 0] [19061.6 -22570.8 147.103 1] "}],"42":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.9 -22770.6 147.103 1] "},{"index":141,"matrix":"[0.766948 -0.641709 0 0] [0.641709 0.766948 -0 0] [0 0 1 0] [18872.3 -22853.2 147.103 1] "},{"index":2,"matrix":"[-0.852907 0.522062 0 0] [-0.522062 -0.852907 0 0] [0 -0 1 0] [19056 -22566.9 147.103 1] "}],"43":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.9 -22766.7 147.103 1] "},{"index":141,"matrix":"[0.802643 -0.596459 0 0] [0.596459 0.802643 -0 0] [0 0 1 0] [18879.6 -22859.2 147.103 1] "},{"index":2,"matrix":"[-0.8814 0.472371 0 0] [-0.472371 -0.8814 0 0] [0 -0 1 0] [19050.2 -22563.4 147.103 1] "}],"44":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.9 -22762.8 147.103 1] "},{"index":141,"matrix":"[0.835163 -0.550002 0 0] [0.550002 0.835163 -0 0] [0 0 1 0] [18887.1 -22864.8 147.103 1] "},{"index":2,"matrix":"[-0.907165 
0.420775 0 0] [-0.420775 -0.907165 0 0] [0 -0 1 0] [19044.2 -22560.2 147.103 1] "}],"45":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.8 -22758.9 147.103 1] "},{"index":141,"matrix":"[0.864292 -0.502991 0 0] [0.502991 0.864292 -0 0] [0 0 1 0] [18895 -22870 147.103 1] "},{"index":2,"matrix":"[-0.929899 0.367815 0 0] [-0.367815 -0.929899 0 0] [0 -0 1 0] [19038 -22557.3 147.103 1] "}],"46":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.8 -22754.9 147.103 1] "},{"index":141,"matrix":"[0.889922 -0.456113 0 0] [0.456113 0.889922 -0 0] [0 0 1 0] [18903.1 -22874.7 147.103 1] "},{"index":2,"matrix":"[-0.949391 0.314096 0 0] [-0.314096 -0.949391 0 0] [0 -0 1 0] [19031.7 -22554.8 147.103 1] "}],"47":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.8 -22751 147.103 1] "},{"index":141,"matrix":"[0.912239 -0.409659 0 0] [0.409659 0.912239 -0 0] [0 0 1 0] [18911.5 -22879 147.103 1] "},{"index":2,"matrix":"[-0.965524 0.260313 0 0] [-0.260313 -0.965524 0 0] [0 -0 1 0] [19025.2 -22552.6 147.103 1] "}],"48":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.8 -22747.1 147.103 1] "},{"index":141,"matrix":"[0.931546 -0.363624 0 0] [0.363624 0.931546 -0 0] [0 0 1 0] [18920 -22882.9 147.103 1] "},{"index":2,"matrix":"[-0.978323 0.207085 0 0] [-0.207085 -0.978323 0 0] [0 -0 1 0] [19018.7 -22550.9 147.103 1] "}],"49":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.8 -22743.2 147.103 1] "},{"index":141,"matrix":"[0.94818 -0.317734 0 0] [0.317734 0.94818 -0 0] [0 0 1 0] [18928.8 -22886.3 147.103 1] "},{"index":2,"matrix":"[-0.987885 0.155185 0 0] [-0.155185 -0.987885 0 0] [0 -0 1 0] [19012 -22549.5 147.103 1] "}],"50":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.7 -22739.3 147.103 1] 
"},{"index":141,"matrix":"[0.962364 -0.271763 0 0] [0.271763 0.962364 -0 0] [0 0 1 0] [18937.7 -22889.3 147.103 1] "},{"index":2,"matrix":"[-0.994528 0.104469 0 0] [-0.104469 -0.994528 0 0] [0 -0 1 0] [19005.3 -22548.4 147.103 1] "}],"51":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.7 -22735.3 147.103 1] "},{"index":141,"matrix":"[0.974232 -0.225548 0 0] [0.225548 0.974232 -0 0] [0 0 1 0] [18946.8 -22891.8 147.103 1] "},{"index":2,"matrix":"[-0.998523 0.0543235 0 0] [-0.0543235 -0.998523 0 0] [0 -0 1 0] [18998.5 -22547.7 147.103 1] "}],"52":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.7 -22731.4 147.103 1] "},{"index":141,"matrix":"[0.983864 -0.17892 0 0] [0.17892 0.983864 -0 0] [0 0 1 0] [18955.9 -22893.9 147.103 1] "},{"index":2,"matrix":"[-0.999991 0.00428697 0 0] [-0.00428697 -0.999991 0 0] [0 -0 1 0] [18991.7 -22547.3 147.103 1] "}],"53":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.7 -22727.5 147.103 1] "},{"index":141,"matrix":"[0.99128 -0.131775 0 0] [0.131775 0.99128 -0 0] [0 0 1 0] [18965.2 -22895.6 147.103 1] "},{"index":2,"matrix":"[-0.998943 -0.0459677 0 0] [0.0459677 -0.998943 -0 0] [0 0 1 0] [18984.9 -22547.3 147.103 1] "}],"54":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.7 -22723.6 147.103 1] "},{"index":141,"matrix":"[0.996469 -0.0839633 0 0] [0.0839633 0.996469 -0 0] [0 0 1 0] [18974.5 -22896.9 147.103 1] "},{"index":2,"matrix":"[-0.995314 -0.0966938 0 0] [0.0966938 -0.995314 -0 0] [0 0 1 0] [18978.1 -22547.6 147.103 1] "}],"55":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.6 -22719.7 147.103 1] "},{"index":141,"matrix":"[0.999369 -0.0355101 0 0] [0.0355101 0.999369 -0 0] [0 0 1 0] [18983.9 -22897.7 147.103 1] "},{"index":2,"matrix":"[-0.988976 -0.148073 0 0] [0.148073 -0.988976 -0 0] [0 0 1 
0] [18971.3 -22548.3 147.103 1] "}],"56":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.6 -22715.7 147.103 1] "},{"index":141,"matrix":"[0.999908 0.0135796 0 0] [-0.0135796 0.999908 0 0] [0 -0 1 0] [18993.3 -22898 147.103 1] "},{"index":2,"matrix":"[-0.979751 -0.20022 0 0] [0.20022 -0.979751 -0 0] [0 0 1 0] [18964.6 -22549.3 147.103 1] "}],"57":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.6 -22711.8 147.103 1] "},{"index":141,"matrix":"[0.997998 0.063239 0 0] [-0.063239 0.997998 0 0] [0 -0 1 0] [19002.7 -22897.9 147.103 1] "},{"index":2,"matrix":"[-0.967418 -0.253185 0 0] [0.253185 -0.967418 -0 0] [0 0 1 0] [18958 -22550.6 147.103 1] "}],"58":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.6 -22707.9 147.103 1] "},{"index":141,"matrix":"[0.993556 0.113345 0 0] [-0.113345 0.993556 0 0] [0 -0 1 0] [19012.1 -22897.3 147.103 1] "},{"index":2,"matrix":"[-0.95174 -0.306905 0 0] [0.306905 -0.95174 -0 0] [0 0 1 0] [18951.4 -22552.4 147.103 1] "}],"59":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.5 -22704 147.103 1] "},{"index":141,"matrix":"[0.986506 0.163722 0 0] [-0.163722 0.986506 0 0] [0 -0 1 0] [19021.4 -22896.2 147.103 1] "},{"index":2,"matrix":"[-0.932457 -0.36128 0 0] [0.36128 -0.932457 -0 0] [0 0 1 0] [18944.9 -22554.5 147.103 1] "}],"60":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.5 -22700.1 147.103 1] "},{"index":141,"matrix":"[0.976803 0.214142 0 0] [-0.214142 0.976803 0 0] [0 -0 1 0] [19030.7 -22894.7 147.103 1] "},{"index":2,"matrix":"[-0.909323 -0.416092 0 0] [0.416092 -0.909323 -0 0] [0 0 1 0] [18938.6 -22556.9 147.103 1] "}],"61":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.5 -22696.1 147.103 1] "},{"index":141,"matrix":"[0.964431 0.264335 0 0] 
[-0.264335 0.964431 0 0] [0 -0 1 0] [19039.9 -22892.6 147.103 1] "},{"index":2,"matrix":"[-0.882132 -0.471002 0 0] [0.471002 -0.882132 -0 0] [0 0 1 0] [18932.4 -22559.7 147.103 1] "}],"62":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.5 -22692.2 147.103 1] "},{"index":141,"matrix":"[0.949422 0.314004 0 0] [-0.314004 0.949422 0 0] [0 -0 1 0] [19049 -22890.2 147.103 1] "},{"index":2,"matrix":"[-0.850733 -0.525599 0 0] [0.525599 -0.850733 -0 0] [0 0 1 0] [18926.4 -22562.9 147.103 1] "}],"63":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.5 -22688.3 147.103 1] "},{"index":141,"matrix":"[0.931857 0.362827 0 0] [-0.362827 0.931857 0 0] [0 -0 1 0] [19057.9 -22887.2 147.103 1] "},{"index":2,"matrix":"[-0.815099 -0.579321 0 0] [0.579321 -0.815099 -0 0] [0 0 1 0] [18920.6 -22566.5 147.103 1] "}],"64":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.4 -22684.4 147.103 1] "},{"index":141,"matrix":"[0.911861 0.4105 0 0] [-0.4105 0.911861 0 0] [0 -0 1 0] [19066.7 -22883.8 147.103 1] "},{"index":2,"matrix":"[-0.775337 -0.631548 0 0] [0.631548 -0.775337 -0 0] [0 0 1 0] [18915 -22570.5 147.103 1] "}],"65":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.4 -22680.5 147.103 1] "},{"index":141,"matrix":"[0.889615 0.456711 0 0] [-0.456711 0.889615 0 0] [0 -0 1 0] [19075.3 -22879.9 147.103 1] "},{"index":2,"matrix":"[-0.731725 -0.6816 0 0] [0.6816 -0.731725 -0 0] [0 0 1 0] [18909.8 -22574.8 147.103 1] "}],"66":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.4 -22676.5 147.103 1] "},{"index":141,"matrix":"[0.865334 0.501196 0 0] [-0.501196 0.865334 0 0] [0 -0 1 0] [19083.6 -22875.6 147.103 1] "},{"index":2,"matrix":"[-0.684729 -0.728797 0 0] [0.728797 -0.684729 -0 0] [0 0 1 0] [18904.8 -22579.4 147.103 1] 
"}],"67":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.4 -22672.6 147.103 1] "},{"index":141,"matrix":"[0.839263 0.543725 0 0] [-0.543725 0.839263 0 0] [0 -0 1 0] [19091.8 -22870.9 147.103 1] "},{"index":2,"matrix":"[-0.634996 -0.772516 0 0] [0.772516 -0.634996 -0 0] [0 0 1 0] [18900.1 -22584.4 147.103 1] "}],"68":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.4 -22668.7 147.103 1] "},{"index":141,"matrix":"[0.811667 0.584121 0 0] [-0.584121 0.811667 0 0] [0 -0 1 0] [19099.7 -22865.8 147.103 1] "},{"index":2,"matrix":"[-0.583283 -0.812269 0 0] [0.812269 -0.583283 -0 0] [0 0 1 0] [18895.8 -22589.6 147.103 1] "}],"69":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.3 -22664.8 147.103 1] "},{"index":141,"matrix":"[0.782812 0.622258 0 0] [-0.622258 0.782812 0 0] [0 -0 1 0] [19107.3 -22860.3 147.103 1] "},{"index":2,"matrix":"[-0.530414 -0.847739 0 0] [0.847739 -0.530414 -0 0] [0 0 1 0] [18891.8 -22595.1 147.103 1] "}],"70":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.3 -22660.9 147.103 1] "},{"index":141,"matrix":"[0.752969 0.658056 0 0] [-0.658056 0.752969 0 0] [0 -0 1 0] [19114.7 -22854.5 147.103 1] "},{"index":2,"matrix":"[-0.477171 -0.878811 0 0] [0.878811 -0.477171 -0 0] [0 0 1 0] [18888.2 -22600.9 147.103 1] "}],"71":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.3 -22656.9 147.103 1] "},{"index":141,"matrix":"[0.722357 0.691521 0 0] [-0.691521 0.722357 0 0] [0 -0 1 0] [19121.7 -22848.3 147.103 1] "},{"index":2,"matrix":"[-0.424227 -0.905556 0 0] [0.905556 -0.424227 -0 0] [0 0 1 0] [18885 -22606.9 147.103 1] "}],"72":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.3 -22653 147.103 1] "},{"index":141,"matrix":"[0.691244 0.722621 0 0] [-0.722621 0.691244 0 0] [0 -0 1 0] 
[19128.5 -22841.8 147.103 1] "},{"index":2,"matrix":"[-0.372103 -0.928191 0 0] [0.928191 -0.372103 -0 0] [0 0 1 0] [18882.1 -22613.1 147.103 1] "}],"73":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.3 -22649.1 147.103 1] "},{"index":141,"matrix":"[0.659808 0.751434 0 0] [-0.751434 0.659808 0 0] [0 -0 1 0] [19135 -22835 147.103 1] "},{"index":2,"matrix":"[-0.321151 -0.947028 0 0] [0.947028 -0.321151 -0 0] [0 0 1 0] [18879.6 -22619.4 147.103 1] "}],"74":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.2 -22645.2 147.103 1] "},{"index":141,"matrix":"[0.62823 0.778027 0 0] [-0.778027 0.62823 0 0] [0 -0 1 0] [19141.3 -22827.9 147.103 1] "},{"index":2,"matrix":"[-0.271566 -0.96242 0 0] [0.96242 -0.271566 -0 0] [0 0 1 0] [18877.4 -22625.8 147.103 1] "}],"75":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.2 -22641.3 147.103 1] "},{"index":141,"matrix":"[0.596664 0.802491 0 0] [-0.802491 0.596664 0 0] [0 -0 1 0] [19147.2 -22820.6 147.103 1] "},{"index":2,"matrix":"[-0.22342 -0.974722 0 0] [0.974722 -0.22342 -0 0] [0 0 1 0] [18875.5 -22632.4 147.103 1] "}],"76":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.2 -22637.3 147.103 1] "},{"index":141,"matrix":"[0.565227 0.824936 0 0] [-0.824936 0.565227 0 0] [0 -0 1 0] [19152.8 -22813 147.103 1] "},{"index":2,"matrix":"[-0.176687 -0.984267 0 0] [0.984267 -0.176687 -0 0] [0 0 1 0] [18874 -22639 147.103 1] "}],"77":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.2 -22633.4 147.103 1] "},{"index":141,"matrix":"[0.534027 0.845468 0 0] [-0.845468 0.534027 0 0] [0 -0 1 0] [19158.1 -22805.3 147.103 1] "},{"index":2,"matrix":"[-0.131276 -0.991346 0 0] [0.991346 -0.131276 -0 0] [0 0 1 0] [18872.8 -22645.7 147.103 1] "}],"78":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 
-0.00531441 0 0] [0 -0 1 0] [18998.2 -22629.5 147.103 1] "},{"index":141,"matrix":"[0.503141 0.864204 0 0] [-0.864204 0.503141 0 0] [0 -0 1 0] [19163.1 -22797.3 147.103 1] "},{"index":2,"matrix":"[-0.0870326 -0.996205 0 0] [0.996205 -0.0870326 -0 0] [0 0 1 0] [18871.9 -22652.4 147.103 1] "}],"79":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.1 -22625.6 147.103 1] "},{"index":141,"matrix":"[0.472633 0.88126 0 0] [-0.88126 0.472633 0 0] [0 -0 1 0] [19167.8 -22789.2 147.103 1] "},{"index":2,"matrix":"[-0.0437215 -0.999044 0 0] [0.999044 -0.0437215 -0 0] [0 0 1 0] [18871.3 -22659.2 147.103 1] "}],"80":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.1 -22621.7 147.103 1] "},{"index":141,"matrix":"[0.442541 0.896748 0 0] [-0.896748 0.442541 0 0] [0 -0 1 0] [19172.3 -22780.9 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18871 -22666 147.103 1] "}],"81":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.1 -22617.7 147.103 1] "},{"index":141,"matrix":"[0.412894 0.910779 0 0] [-0.910779 0.412894 0 0] [0 -0 1 0] [19176.5 -22772.5 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18870.9 -22672.8 147.103 1] "}],"82":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998.1 -22613.8 147.103 1] "},{"index":141,"matrix":"[0.383706 0.923455 0 0] [-0.923455 0.383706 0 0] [0 -0 1 0] [19180.3 -22763.9 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18870.7 -22679.6 147.103 1] "}],"83":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998 -22609.9 147.103 1] "},{"index":141,"matrix":"[0.354978 0.934875 0 0] [-0.934875 0.354978 0 0] [0 -0 1 0] [19184 -22755.2 147.103 1] "},{"index":2,"matrix":"[-0.0214401 
-0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18870.6 -22686.4 147.103 1] "}],"84":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998 -22606 147.103 1] "},{"index":141,"matrix":"[0.326699 0.945128 0 0] [-0.945128 0.326699 0 0] [0 -0 1 0] [19187.3 -22746.4 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18870.4 -22693.2 147.103 1] "}],"85":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998 -22602.1 147.103 1] "},{"index":141,"matrix":"[0.298848 0.954301 0 0] [-0.954301 0.298848 0 0] [0 -0 1 0] [19190.4 -22737.5 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18870.3 -22700 147.103 1] "}],"86":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998 -22598.1 147.103 1] "},{"index":141,"matrix":"[0.271393 0.962469 0 0] [-0.962469 0.271393 0 0] [0 -0 1 0] [19193.2 -22728.5 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18870.1 -22706.8 147.103 1] "}],"87":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18998 -22594.2 147.103 1] "},{"index":141,"matrix":"[0.244292 0.969702 0 0] [-0.969702 0.244292 0 0] [0 -0 1 0] [19195.7 -22719.5 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18870 -22713.6 147.103 1] "}],"88":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18997.9 -22590.3 147.103 1] "},{"index":141,"matrix":"[0.21749 0.976063 0 0] [-0.976063 0.21749 0 0] [0 -0 1 0] [19198 -22710.4 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18869.8 -22720.4 147.103 1] "}],"89":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18997.9 -22586.4 147.103 1] 
"},{"index":141,"matrix":"[0.190923 0.981605 0 0] [-0.981605 0.190923 0 0] [0 -0 1 0] [19200.1 -22701.2 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18869.7 -22727.2 147.103 1] "}],"90":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18997.9 -22582.5 147.103 1] "},{"index":141,"matrix":"[0.164513 0.986375 0 0] [-0.986375 0.164513 0 0] [0 -0 1 0] [19201.9 -22691.9 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18869.6 -22734 147.103 1] "}],"91":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18997.9 -22578.5 147.103 1] "},{"index":141,"matrix":"[0.138169 0.990409 0 0] [-0.990409 0.138169 0 0] [0 -0 1 0] [19203.4 -22682.7 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18869.4 -22740.8 147.103 1] "}],"92":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18997.9 -22574.6 147.103 1] "},{"index":141,"matrix":"[0.111776 0.993733 0 0] [-0.993733 0.111776 0 0] [0 -0 1 0] [19204.7 -22673.4 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18869.3 -22747.6 147.103 1] "}],"93":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18997.8 -22570.7 147.103 1] "},{"index":141,"matrix":"[0.0851806 0.996366 0 0] [-0.996366 0.0851806 0 0] [0 -0 1 0] [19205.8 -22664 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18869.1 -22754.4 147.103 1] "}],"94":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18997.8 -22566.8 147.103 1] "},{"index":141,"matrix":"[0.0582103 0.998304 0 0] [-0.998304 0.0582103 0 0] [0 -0 1 0] [19206.6 -22654.6 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18869 
-22761.2 147.103 1] "}],"95":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18997.8 -22562.9 147.103 1] "},{"index":141,"matrix":"[0.0306518 0.99953 0 0] [-0.99953 0.0306518 0 0] [0 -0 1 0] [19207.1 -22645.2 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18868.8 -22768 147.103 1] "}],"96":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18997.8 -22558.9 147.103 1] "},{"index":141,"matrix":"[0.00217772 0.999998 0 0] [-0.999998 0.00217772 0 0] [0 -0 1 0] [19207.4 -22635.8 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18868.8 -22768 147.103 1] "}],"97":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18997.8 -22555 147.103 1] "},{"index":141,"matrix":"[-0.0275859 0.999619 0 0] [-0.999619 -0.0275859 0 0] [0 -0 1 0] [19207.4 -22626.4 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18868.8 -22768 147.103 1] "}],"98":[{"index":113,"matrix":"[-0.00531441 0.999986 0 0] [-0.999986 -0.00531441 0 0] [0 -0 1 0] [18997.7 -22551.1 147.103 1] "},{"index":141,"matrix":"[-0.0591986 0.998246 0 0] [-0.998246 -0.0591986 0 0] [0 -0 1 0] [19207.2 -22617 147.103 1] "},{"index":2,"matrix":"[-0.0214401 -0.99977 0 0] [0.99977 -0.0214401 -0 0] [0 0 1 0] [18868.8 -22768 147.103 1] "}]}
================================================
FILE: dataset/traj_vis/Hemi12_transforms.json
================================================
{
"C_01_35mm": "[-0.8622445326446021 -0.497817113029644 -0.09334070869305826 0] [0.49999999999999994 -0.8660254037844387 0.0 0] [-0.08083542493543144 -0.04667035434652912 0.9956342260592881 0] [692.820323027551 399.99999999999994 0.0 1]",
"C_02_35mm": "[-0.49781711302964426 -0.862244532644602 -0.09334070869305827 0] [0.8660254037844386 -0.5000000000000002 0.0 0] [-0.04667035434652916 -0.08083542493543144 0.9956342260592881 0] [400.0000000000001 692.8203230275509 0.0 1]",
"C_03_35mm": "[-1.6011019497192098e-16 -0.9956342260592881 -0.09334070869305827 0] [1.0 -1.6081226496766366e-16 0.0 0] [-1.5010330778617594e-17 -0.09334070869305827 0.9956342260592881 0] [4.898587196589413e-14 800.0 0.0 1]",
"C_04_35mm": "[0.49781711302964377 -0.8622445326446022 -0.09334070869305827 0] [0.8660254037844388 0.4999999999999997 0.0 0] [0.04667035434652911 -0.08083542493543147 0.9956342260592881 0] [-399.99999999999983 692.820323027551 0.0 1]",
"C_05_35mm": "[0.8622445326446021 -0.4978171130296439 -0.09334070869305826 0] [0.49999999999999983 0.8660254037844387 0.0 0] [0.08083542493543144 -0.046670354346529115 0.9956342260592881 0] [-692.820323027551 399.99999999999994 0.0 1]",
"C_06_35mm": "[0.9956342260592881 -1.2193002680650596e-16 -0.09334070869305827 0] [1.2246467991473532e-16 1.0 0.0 0] [0.09334070869305827 -1.1430940013109933e-17 0.9956342260592881 0] [-800.0 9.797174393178826e-14 0.0 1]",
"C_07_35mm": "[0.862244532644602 0.49781711302964415 -0.09334070869305827 0] [-0.5000000000000001 0.8660254037844386 0.0 0] [0.08083542493543144 0.04667035434652914 0.9956342260592881 0] [-692.8203230275509 -400.0000000000001 0.0 1]",
"C_08_35mm": "[0.4978171130296444 0.8622445326446019 -0.09334070869305827 0] [-0.8660254037844385 0.5000000000000003 0.0 0] [0.046670354346529164 0.08083542493543144 0.9956342260592881 0] [-400.00000000000034 -692.8203230275508 0.0 1]",
"C_09_35mm": "[2.820402217784269e-16 0.9956342260592881 -0.09334070869305827 0] [-1.0 2.83276944882399e-16 0.0 0] [2.6441270791727528e-17 0.09334070869305827 0.9956342260592881 0] [-1.4695761589768238e-13 -800.0 0.0 1]",
"C_10_35mm": "[-0.49781711302964426 0.862244532644602 -0.09334070869305827 0] [-0.8660254037844386 -0.5000000000000002 0.0 0] [-0.04667035434652916 0.08083542493543144 0.9956342260592881 0] [400.0000000000001 -692.8203230275509 0.0 1]",
"C_11_35mm": "[-0.8622445326446019 0.4978171130296444 -0.09334070869305827 0] [-0.5000000000000003 -0.8660254037844385 0.0 0] [-0.08083542493543144 0.046670354346529164 0.9956342260592881 0] [692.8203230275507 -400.00000000000034 0.0 1]",
"C_12_35mm": "[-0.9956342260592881 1.2193002680650596e-16 -0.09334070869305827 0] [-1.2246467991473532e-16 -1.0 0.0 0] [-0.09334070869305827 1.1430940013109933e-17 0.9956342260592881 0] [800.0 -1.9594348786357651e-13 0.0 1]"
}
================================================
FILE: dataset/traj_vis/location_data_desert.json
================================================
[
{
"name": "loc1",
"coordinates": {
"CameraRig_Rail": {
"position": [
0,
0,
0
],
"rotation": [
0,
0,
0
],
"scale": [
1,
1,
1
]
},
"CameraTarget": {
"position": [
19000.0,
-22700.0,
0.0
],
"rotation": [
0.0,
0.0,
0.0
],
"scale": [
1.0,
1.0,
1.0
]
},
"CameraComponent": {
"position": [
0.0,
0.0,
0.0
],
"rotation": [
0.0,
0.0,
0.0
],
"scale": [
1.0,
1.0,
1.0
]
}
},
"Height": "H10"
},
{
"name": "loc2",
"coordinates": {
"CameraRig_Rail": {
"position": [
0,
0,
0
],
"rotation": [
0,
0,
0
],
"scale": [
1,
1,
1
]
},
"CameraTarget": {
"position": [
-200.0,
-11500.0,
130.0
],
"rotation": [
0.0,
0.0,
0.0
],
"scale": [
1.0,
1.0,
1.0
]
},
"CameraComponent": {
"position": [
0.0,
0.0,
0.0
],
"rotation": [
0.0,
0.0,
0.0
],
"scale": [
1.0,
1.0,
1.0
]
}
},
"Height": "H10"
},
{
"name": "loc3",
"coordinates": {
"CameraRig_Rail": {
"position": [
0,
0,
0
],
"rotation": [
0,
0,
0
],
"scale": [
1,
1,
1
]
},
"CameraTarget": {
"position": [
-22500.0,
-12900.0,
20.0
],
"rotation": [
0.0,
0.0,
0.0
],
"scale": [
1.0,
1.0,
1.0
]
},
"CameraComponent": {
"position": [
0.0,
0.0,
0.0
],
"rotation": [
0.0,
0.0,
0.0
],
"scale": [
1.0,
1.0,
1.0
]
}
},
"Height": "H10"
},
{
"name": "loc4",
"coordinates": {
"CameraRig_Rail": {
"position": [
0,
0,
0
],
"rotation": [
0,
0,
0
],
"scale": [
1,
1,
1
]
},
"CameraTarget": {
"position": [
-22000.0,
6600.0,
40.0
],
"rotation": [
0.0,
0.0,
0.0
],
"scale": [
1.0,
1.0,
1.0
]
},
"CameraComponent": {
"position": [
0.0,
0.0,
0.0
],
"rotation": [
0.0,
0.0,
0.0
],
"scale": [
1.0,
1.0,
1.0
]
}
},
"Height": "H10"
},
{
"name": "loc5",
"coordinates": {
"CameraRig_Rail": {
"position": [
0,
0,
0
],
"rotation": [
0,
0,
0
],
"scale": [
1,
1,
1
]
},
"CameraTarget": {
"position": [
1300.0,
28000.0,
20.0
],
"rotation": [
0.0,
0.0,
0.0
],
"scale": [
1.0,
1.0,
1.0
]
},
"CameraComponent": {
"position": [
0.0,
0.0,
0.0
],
"rotation": [
0.0,
0.0,
0.0
],
"scale": [
1.0,
1.0,
1.0
]
}
},
"Height": "H10"
}
]
================================================
FILE: dataset/utils.py
================================================
# Copyright 2024 Xiao Fu, CUHK, Kuaishou Tech. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# More information about the method can be found at http://fuxiao0719.github.io/projects/3dtrajmaster
# --------------------------------------------------------------------------
import os
import numpy as np
from io import BytesIO
import imageio.v2 as imageio
import open3d as o3d
import math
import trimesh
import json
def get_camera_frustum(img_size, K, W2C, frustum_length=0.5, color=[0., 1., 0.]):
    """Build the wireframe of a camera frustum in world coordinates.

    Args:
        img_size: ``(W, H)`` image size in pixels.
        K: intrinsics array; only ``K[0, 0]`` and ``K[1, 1]`` are read, as the
            horizontal and vertical focal lengths.
        W2C: 4x4 world-to-camera matrix; it is inverted to place the frustum.
        frustum_length: distance from the apex to the image plane of the frustum.
        color: RGB triple repeated for every edge of the frustum.

    Returns:
        ``(points, lines, colors)`` where ``points`` is (5, 3) (apex followed
        by the four image-plane corners), ``lines`` is (8, 2) vertex-index
        pairs and ``colors`` is (8, 3).
    """
    width, height = img_size

    # Field of view along each image axis, in degrees.
    fov_x = np.rad2deg(np.arctan(width / 2. / K[0, 0]) * 2.)
    fov_y = np.rad2deg(np.arctan(height / 2. / K[1, 1]) * 2.)

    # Half extents of the image plane placed at distance `frustum_length`.
    dx = frustum_length * np.tan(np.deg2rad(fov_x / 2.))
    dy = frustum_length * np.tan(np.deg2rad(fov_y / 2.))

    # Frustum of the canonical camera (identity rotation, zero translation).
    local_points = np.array([
        [0., 0., 0.],              # apex
        [-dx, -dy, frustum_length],  # top-left image corner
        [dx, -dy, frustum_length],   # top-right image corner
        [dx, dy, frustum_length],    # bottom-right image corner
        [-dx, dy, frustum_length],   # bottom-left image corner
    ])

    # Four edges from the apex to each corner, then the four image-plane edges.
    apex_edges = [[0, corner] for corner in range(1, 5)]
    rim_edges = [[corner, corner + 1] for corner in range(1, 4)] + [[4, 1]]
    frustum_lines = np.array(apex_edges + rim_edges)

    # One copy of `color` per edge.
    frustum_colors = np.tile(np.array(color).reshape((1, 3)), (frustum_lines.shape[0], 1))

    # Move the canonical frustum into the world frame via the camera-to-world matrix.
    cam_to_world = np.linalg.inv(W2C)
    homogeneous = np.hstack((local_points, np.ones_like(local_points[:, 0:1])))
    world_h = np.dot(homogeneous, cam_to_world.T)
    frustum_points = world_h[:, :3] / world_h[:, 3:4]

    return frustum_points, frustum_lines, frustum_colors
def frustums2lineset(frustums):
    """Merge several frustums into a single Open3D ``LineSet``.

    Args:
        frustums: sequence of ``(points, lines, colors)`` tuples as returned
            by ``get_camera_frustum`` (5 points and 8 lines per frustum).

    Returns:
        An ``o3d.geometry.LineSet`` containing every frustum.
    """
    VERTS_PER_FRUSTUM = 5
    LINES_PER_FRUSTUM = 8
    count = len(frustums)

    all_points = np.zeros((count * VERTS_PER_FRUSTUM, 3))
    all_lines = np.zeros((count * LINES_PER_FRUSTUM, 2))
    all_colors = np.zeros((count * LINES_PER_FRUSTUM, 3))  # one color per line

    for k, (pts, segs, cols) in enumerate(frustums):
        v0 = k * VERTS_PER_FRUSTUM
        l0 = k * LINES_PER_FRUSTUM
        all_points[v0:v0 + VERTS_PER_FRUSTUM, :] = pts
        # Line indices are local to a frustum; shift them into the merged vertex array.
        all_lines[l0:l0 + LINES_PER_FRUSTUM, :] = segs + v0
        all_colors[l0:l0 + LINES_PER_FRUSTUM, :] = cols

    lineset = o3d.geometry.LineSet()
    lineset.points = o3d.utility.Vector3dVector(all_points)
    lineset.lines = o3d.utility.Vector2iVector(all_lines)
    lineset.colors = o3d.utility.Vector3dVector(all_colors)
    return lineset
def visualize_cameras(colored_camera_dicts, sphere_radius, camera_size=0.1, geometry_file=None, geometry_type='mesh'):
    """Open an Open3D window showing camera frustums, a reference sphere and axes.

    Args:
        colored_camera_dicts: iterable of ``(color, camera_dict)`` pairs. Each
            ``camera_dict`` maps an image name to a dict with keys ``'K'`` and
            ``'W2C'`` (each reshaped to 4x4 here) and ``'img_size'``.
        sphere_radius: radius of the red wireframe sphere drawn at the origin.
        camera_size: frustum length passed to ``get_camera_frustum``.
        geometry_file: optional path of extra geometry to draw.
        geometry_type: ``'mesh'`` or ``'pointcloud'``; selects how
            ``geometry_file`` is loaded.

    Raises:
        ValueError: if ``geometry_file`` is given and ``geometry_type`` is
            neither ``'mesh'`` nor ``'pointcloud'``.
    """
    # Reject an unsupported geometry type before doing any other work.
    if geometry_file is not None and geometry_type not in ('mesh', 'pointcloud'):
        raise ValueError('Unknown geometry_type: {}'.format(geometry_type))

    sphere = o3d.geometry.TriangleMesh.create_sphere(radius=sphere_radius, resolution=10)
    sphere = o3d.geometry.LineSet.create_from_triangle_mesh(sphere)
    sphere.paint_uniform_color((1, 0, 0))

    coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0., 0., 0.])
    things_to_draw = [sphere, coord_frame]

    # One merged line set per (color, camera_dict) group.
    for color, camera_dict in colored_camera_dicts:
        frustums = []
        for img_name in sorted(camera_dict):
            camera = camera_dict[img_name]
            K = np.array(camera['K']).reshape((4, 4))
            W2C = np.array(camera['W2C']).reshape((4, 4))
            frustums.append(get_camera_frustum(camera['img_size'], K, W2C,
                                               frustum_length=camera_size, color=color))
        things_to_draw.append(frustums2lineset(frustums))

    if geometry_file is not None:
        if geometry_type == 'mesh':
            geometry = o3d.io.read_triangle_mesh(geometry_file)
            geometry.compute_vertex_normals()
        else:  # 'pointcloud' (validated above)
            geometry = o3d.io.read_point_cloud(geometry_file)
        things_to_draw.append(geometry)

    o3d.visualization.draw_geometries(things_to_draw)
def parse_matrix(matrix_str):
    """Parse a bracketed matrix string into a 2-D float array.

    The expected layout is whitespace-separated numbers grouped in square
    brackets, one group per row, e.g. ``"[1 0 0 0] [0 1 0 0] ..."``.

    Args:
        matrix_str: the textual matrix; surrounding whitespace is ignored.

    Returns:
        ``np.ndarray`` with one row per bracket group.
    """
    row_chunks = matrix_str.strip().split('] [')
    return np.array([
        [float(token) for token in chunk.replace('[', '').replace(']', '').split()]
        for chunk in row_chunks
    ])
def load_sceneposes(objs_file, obj_idx, obj_transl):
    """Load the per-frame pose matrices of one object.

    Args:
        objs_file: mapping of frame key -> list of object entries, each entry
            being a dict with a ``'matrix'`` string understood by ``parse_matrix``.
        obj_idx: position of the object inside every frame's list.
        obj_transl: offset subtracted from the translation of every pose.

    Returns:
        Array of shape (num_frames, 4, 4), in the iteration order of
        ``objs_file``.
    """
    per_frame = [
        parse_matrix(objs_file[frame_key][obj_idx]['matrix'])
        for frame_key in objs_file.keys()
    ]
    poses = np.stack(per_frame)
    # Transpose each matrix so the translation row becomes the last column.
    poses = np.transpose(poses, (0, 2, 1))
    # Re-center on `obj_transl`, then rescale by 1/100
    # (presumably centimeters -> meters; confirm against the data source).
    poses[:, :3, 3] -= obj_transl
    poses[:, :3, 3] /= 100.
    # Reorder the first three columns (axes) as 1, 2, 0; translation stays last.
    poses = poses[:, :, [1, 2, 0, 3]]
    return poses
def save_images2video(images, video_name, fps):
    """Encode a sequence of frames as H.264 and write ``<video_name>.mp4``.

    Args:
        images: indexable sequence of frames accepted by the imageio writer.
        video_name: output path without the ``.mp4`` extension.
        fps: frame rate of the resulting video.
    """
    encoder_options = dict(
        fps=fps,
        format="mp4",
        codec="libx264",
        ffmpeg_params=["-crf", str(12)],
        pixelformat="yuv420p",
    )

    # Encode into memory first, then dump the bytes to disk in one write.
    buffer = BytesIO()
    with imageio.get_writer(buffer, **encoder_options) as writer:
        for frame_no in range(len(images)):
            writer.append_data(images[frame_no])

    encoded = buffer.getvalue()
    target = os.path.join(video_name + ".mp4")
    with open(target, "wb") as out_file:
        out_file.write(encoded)
def normalize(x):
    """Return ``x`` divided by its norm as computed by ``np.linalg.norm``."""
    length = np.linalg.norm(x)
    return x / length
def viewmatrix(z, up, pos):
    """Assemble a 3x4 matrix ``[x | y | z | pos]`` from a forward axis and an up hint.

    Args:
        z: direction used (after normalization) as the third axis.
        up: approximate up vector; only used to derive the first axis.
        pos: vector placed in the last column.

    Returns:
        Array whose columns are the three orthonormal axes followed by ``pos``.
    """
    axis_z = normalize(z)
    # First axis is perpendicular to both the up hint and the forward axis.
    axis_x = normalize(np.cross(up, axis_z))
    # Recompute the second axis so the three axes are mutually orthogonal.
    axis_y = normalize(np.cross(axis_z, axis_x))
    return np.stack([axis_x, axis_y, axis_z, pos], 1)
def matrix_to_euler_angles(matrix):
    """Extract three Euler angles, in degrees, from a rotation matrix.

    Only the upper-left 3x3 entries of ``matrix`` are read, so a 4x4 matrix
    may be passed as well.

    Args:
        matrix: nested sequence / array indexable as ``matrix[i][j]``.

    Returns:
        Tuple ``(x, y, z)`` of angles in degrees.
    """
    m00 = matrix[0][0]
    m10 = matrix[1][0]
    sy = math.sqrt(m00 * m00 + m10 * m10)

    # The middle angle uses the same formula in both branches.
    angle_y = math.atan2(-matrix[2][0], sy)

    if sy < 1e-6:
        # Near-singular case: the z angle is fixed to zero.
        angle_x = math.atan2(-matrix[1][2], matrix[1][1])
        angle_z = 0
    else:
        angle_x = math.atan2(matrix[2][1], matrix[2][2])
        angle_z = math.atan2(matrix[1][0], matrix[0][0])

    return math.degrees(angle_x), math.degrees(angle_y), math.degrees(angle_z)
def eul2rot(theta):
    """Build a 3x3 rotation matrix from three Euler angles and return its transpose.

    Args:
        theta: sequence of three angles in radians (they are passed directly
            to ``np.sin`` / ``np.cos``), indexed as ``theta[0..2]``.

    Returns:
        The transpose of the matrix assembled from the angles.
    """
    s0, c0 = np.sin(theta[0]), np.cos(theta[0])
    s1, c1 = np.sin(theta[1]), np.cos(theta[1])
    s2, c2 = np.sin(theta[2]), np.cos(theta[2])

    row0 = [c1 * c2, s0 * s1 * c2 - s2 * c0, s1 * c0 * c2 + s0 * s2]
    row1 = [s2 * c1, s0 * s1 * s2 + c0 * c2, s1 * s2 * c0 - s0 * c2]
    row2 = [-s1, s0 * c1, c0 * c1]

    R = np.array([row0, row1, row2])
    return R.T
def extract_location_rotation(data):
    """Convert textual matrices into 4x4 arrays with rotation and location split out.

    For every entry, the string is parsed with ``parse_matrix``; the location
    is read from the first three values of the fourth row, and the rotation
    is rebuilt through Euler angles.

    Args:
        data: mapping of name -> matrix string understood by ``parse_matrix``.

    Returns:
        Dict with the same keys; each value is a 4x4 array holding the
        rotation in ``[:3, :3]`` and the location in ``[:3, 3]``.
    """
    results = {}
    for key, value in data.items():
        matrix = parse_matrix(value)
        location = np.array([matrix[3][0], matrix[3][1], matrix[3][2]])

        # matrix_to_euler_angles() returns DEGREES while eul2rot() evaluates
        # np.sin/np.cos on its input, i.e. expects RADIANS. Convert in between;
        # passing the degrees through unchanged yields an unrelated rotation.
        euler_deg = matrix_to_euler_angles(matrix)
        rotation = eul2rot(np.radians(euler_deg))

        transformed_matrix = np.identity(4)
        transformed_matrix[:3, 3] = location
        transformed_matrix[:3, :3] = rotation
        results[key] = transformed_matrix
    return results
def get_cam_points_vis(W, H, intrinsics, ext_pose, color,frustum_length):
    """Return an Open3D point cloud that traces a camera frustum with dense points.

    Args:
        W, H: image width and height in pixels.
        intrinsics: array whose ``[0, 0]`` and ``[1, 1]`` entries are read by
            ``get_camera_frustum``.
        ext_pose: 4x4 pose; its inverse is handed to ``get_camera_frustum`` as ``W2C``.
        color: uniform RGB color painted on the resulting point cloud.
        frustum_length: frustum length passed to ``get_camera_frustum``.

    Returns:
        ``o3d.geometry.PointCloud`` with the frustum vertices plus 1000
        samples along each frustum edge.
    """
    world_to_cam = np.linalg.inv(ext_pose)
    frustum = get_camera_frustum((W, H), intrinsics, world_to_cam,
                                 frustum_length=frustum_length, color=[0., 0., 1.])
    vertices, edges = frustum[0], frustum[1]

    # Vertices first, then each edge sampled (endpoints included), in edge order.
    pieces = [vertices]
    for edge in edges:
        pieces.append(np.linspace(vertices[edge[0]], vertices[edge[1]], num=1000,
                                  endpoint=True, retstep=False, dtype=None))
    cam_points = np.concatenate(pieces)

    # Mirror along the first axis.
    cam_points[:, 0] *= -1

    # The points are wrapped in a trimesh point cloud before being given to Open3D.
    cam_points = trimesh.points.PointCloud(vertices = cam_points, colors=[0, 255, 0, 255])
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(cam_points)
    cloud.paint_uniform_color(color)
    return cloud
def batch_axis_angle_to_rotation_matrix(r_batch):
    """Convert a batch of axis-angle vectors to rotation matrices (Rodrigues' formula).

    Args:
        r_batch: array whose first axis indexes the batch; each item is a
            3-vector whose norm is the rotation angle in radians and whose
            direction is the rotation axis.

    Returns:
        ``np.ndarray`` stacking one 3x3 matrix per input vector.
    """
    def _single(rotvec):
        angle = np.linalg.norm(rotvec)
        if angle == 0:
            # A zero vector has no defined axis; it stands for the identity.
            return np.eye(3)
        ax, ay, az = rotvec / angle
        # Skew-symmetric cross-product matrix of the unit axis.
        skew = np.array([
            [0, -az, ay],
            [az, 0, -ax],
            [-ay, ax, 0]
        ])
        return np.eye(3) + np.sin(angle) * skew + (1 - np.cos(angle)) * np.dot(skew, skew)

    return np.array([_single(r_batch[n]) for n in range(r_batch.shape[0])])
================================================
FILE: dataset/vis_trajectory.py
================================================
import trimesh
import numpy as np
import imageio
import copy
import cv2
import os
from glob import glob
import open3d
from multiprocessing import Pool
import json
from utils import *
def _flip_for_display(poses):
    """Apply, in place, the sign flips this script uses to go from the
    right-handed to the left-handed convention, and return ``poses``.

    The second translation component and the first two columns of every
    (N, 4, 4) pose are negated.
    """
    poses[:, :3, 3][:, 1] *= -1.
    poses[:, :, :2] *= -1.
    return poses


if __name__ == '__main__':
    # Visualize the 12 cameras of the rig and up to three object
    # trajectories of one sample as frustum point clouds.

    H = 480
    W = 720
    # Only the two diagonal entries are read by get_camera_frustum.
    intrinsics = np.array([[1060.606, 0.],
                           [0., 1060.606]])

    # Resolve the data files next to this script so it can be launched
    # from any working directory.
    traj_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "traj_vis")
    cam_path = os.path.join(traj_dir, "Hemi12_transforms.json")
    location_path = os.path.join(traj_dir, "location_data_desert.json")
    video_name = "D_loc1_61_t3n13_003d_Hemi12_1.json"

    with open(location_path, 'r') as f:
        locations = json.load(f)
    locations_info = {location['name']: location for location in locations}
    # The location name is the second '_'-separated token of the file name.
    location_name = video_name.split('_')[1]
    location_info = locations_info[location_name]
    translation = location_info['coordinates']['CameraTarget']['position']

    vis_all = []

    # vis cam
    with open(cam_path, 'r') as f:
        cam_data = json.load(f)
    cam_poses = np.stack([parse_matrix(matrix_str)
                          for cam_name, matrix_str in cam_data.items()
                          if "C_" in cam_name])
    cam_poses = np.transpose(cam_poses, (0, 2, 1))
    cam_poses[:, :3, 3] /= 100.
    cam_poses = cam_poses[:, :, [1, 2, 0, 3]]
    # Express everything relative to the first camera.
    relative_pose = np.linalg.inv(cam_poses[0])
    cam_poses = _flip_for_display(relative_pose @ cam_poses)
    for cam_pose in cam_poses:
        vis_all.append(get_cam_points_vis(W, H, intrinsics, cam_pose,
                                          [0.4, 0.4, 0.4], frustum_length=1.))

    # vis gt obj poses
    start_frame_ind = 10
    sample_n_frames = 77
    frame_indices = np.linspace(start_frame_ind, start_frame_ind + sample_n_frames - 1,
                                sample_n_frames, dtype=int)
    with open(os.path.join(traj_dir, video_name), 'r') as f:
        obj_data = json.load(f)

    # One color per object slot; at most three objects are drawn.
    obj_colors = [[0.8, 0., 0.], [0., 0.8, 0.], [0., 0., 0.8]]
    # Only draw object slots that exist in every frame.
    num_objs = min((len(frame_objs) for frame_objs in obj_data.values()), default=0)
    for obj_idx, obj_color in zip(range(num_objs), obj_colors):
        # load_sceneposes performs the parse / transpose / re-center / rescale /
        # axis-reorder steps that were previously repeated inline per object.
        obj_poses = load_sceneposes(obj_data, obj_idx, translation)
        obj_poses = relative_pose @ obj_poses
        obj_poses = _flip_for_display(obj_poses[frame_indices])
        # Draw every fifth sampled frame to keep the plot readable.
        for obj_pose in obj_poses[::5]:
            vis_all.append(get_cam_points_vis(W, H, intrinsics, obj_pose,
                                              obj_color, frustum_length=0.5))

    # vis coordinates: the original script built this frame but never drew it;
    # flip the switch to include it.
    show_axis = False
    if show_axis:
        vis_all.append(open3d.geometry.TriangleMesh.create_coordinate_frame(size=2, origin=[0, 0, 0]))

    open3d.visualization.draw_geometries(vis_all)
================================================
FILE: eval/GVHMR/.gitignore
================================================
.vscode
.hydra
inputs
outputs
# All file or folders start with tmp will be ignored
tmp*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
#
.DS_Store/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# torchsparse
torchsparse
# tensorboard
tensorboard
# glove
glove
================================================
FILE: eval/GVHMR/.gitmodules
================================================
[submodule "third-party/DPVO"]
path = third-party/DPVO
url = https://github.com/princeton-vl/DPVO.git
================================================
FILE: eval/GVHMR/LICENSE
================================================
Copyright 2022-2023 3D Vision Group at the State Key Lab of CAD&CG,
Zhejiang University. All Rights Reserved.
For more information see <https://github.com/zju3dv/GVHMR>
If you use this software, please cite the corresponding publications
listed on the above website.
Permission to use, copy, modify and distribute this software and its
documentation for educational, research and non-profit purposes only.
Any modification based on this work must be open-source and prohibited
for commercial use.
You must retain, in the source form of any derivative works that you
distribute, all copyright, patent, trademark, and attribution notices
from the source form of this work.
For commercial uses of this software, please send email to xwzhou@zju.edu.cn
================================================
FILE: eval/GVHMR/README.md
================================================
# GVHMR: World-Grounded Human Motion Recovery via Gravity-View Coordinates
### [Project Page](https://zju3dv.github.io/gvhmr) | [Paper](https://arxiv.org/abs/2409.06662)
> World-Grounded Human Motion Recovery via Gravity-View Coordinates
> [Zehong Shen](https://zehongs.github.io/)\*,
[Huaijin Pi](https://phj128.github.io/)\*,
[Yan Xia](https://isshikihugh.github.io/scholar),
[Zhi Cen](https://scholar.google.com/citations?user=Xyy-uFMAAAAJ),
[Sida Peng](https://pengsida.net/)†,
[Zechen Hu](https://zju3dv.github.io/gvhmr),
[Hujun Bao](http://www.cad.zju.edu.cn/home/bao/),
[Ruizhen Hu](https://csse.szu.edu.cn/staff/ruizhenhu/),
[Xiaowei Zhou](https://xzhou.me/)
> SIGGRAPH Asia 2024
## Setup
Please see [installation](docs/INSTALL.md) for details.
## Quick Start
### [Google Colab demo for GVHMR](https://colab.research.google.com/drive/1N9WSchizHv2bfQqkE9Wuiegw_OT7mtGj?usp=sharing)
### [HuggingFace demo for GVHMR](https://huggingface.co/spaces/LittleFrog/GVHMR)
### Demo
Demo entries are provided in `tools/demo`. Use `-s` to skip visual odometry if you know the camera is static, otherwise the camera will be estimated by DPVO.
We also provide a script, `demo_folder.py`, to run inference on an entire folder.
```shell
python tools/demo/demo.py --video=docs/example_video/tennis.mp4 -s
python tools/demo/demo_folder.py -f inputs/demo/folder_in -d outputs/demo/folder_out -s
```
### Reproduce
1. **Test**:
To reproduce the 3DPW, RICH, and EMDB results in a single run, use the following command:
```shell
python tools/train.py global/task=gvhmr/test_3dpw_emdb_rich exp=gvhmr/mixed/mixed ckpt_path=inputs/checkpoints/gvhmr/gvhmr_siga24_release.ckpt
```
To test individual datasets, change `global/task` to `gvhmr/test_3dpw`, `gvhmr/test_rich`, or `gvhmr/test_emdb`.
2. **Train**:
To train the model, use the following command:
```shell
# The gvhmr_siga24_release.ckpt is trained with 2x4090 for 420 epochs, note that different GPU settings may lead to different results.
python tools/train.py exp=gvhmr/mixed/mixed
```
During training, note that we do not employ post-processing as in the test script, so the global metrics results will differ (but should still be good for comparison with baseline methods).
# Citation
If you find this code useful for your research, please use the following BibTeX entry.
```
@inproceedings{shen2024gvhmr,
title={World-Grounded Human Motion Recovery via Gravity-View Coordinates},
author={Shen, Zehong and Pi, Huaijin and Xia, Yan and Cen, Zhi and Peng, Sida and Hu, Zechen and Bao, Hujun and Hu, Ruizhen and Zhou, Xiaowei},
booktitle={SIGGRAPH Asia Conference Proceedings},
year={2024}
}
```
# Acknowledgement
We thank the authors of
[WHAM](https://github.com/yohanshin/WHAM),
[4D-Humans](https://github.com/shubham-goel/4D-Humans),
and [ViTPose-Pytorch](https://github.com/gpastal24/ViTPose-Pytorch) for their great works, without which our project/code would not be possible.
================================================
FILE: eval/GVHMR/docs/INSTALL.md
================================================
# Install
## Environment
```bash
git clone https://github.com/zju3dv/GVHMR --recursive
cd GVHMR
conda create -y -n gvhmr python=3.10
conda activate gvhmr
pip install -r requirements.txt
pip install -e .
# to install gvhmr in other repo as editable, try adding "python.analysis.extraPaths": ["path/to/your/package"] to settings.json
# DPVO
cd third-party/DPVO
wget https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.zip
unzip eigen-3.4.0.zip -d thirdparty && rm -rf eigen-3.4.0.zip
pip install torch-scatter -f "https://data.pyg.org/whl/torch-2.3.0+cu121.html"
pip install numba pypose
export CUDA_HOME=/usr/local/cuda-12.1/
export PATH=$PATH:/usr/local/cuda-12.1/bin/
pip install -e .
```
## Inputs & Outputs
```bash
mkdir inputs
mkdir outputs
```
**Weights**
```bash
mkdir -p inputs/checkpoints
# 1. You need to sign up for downloading [SMPL](https://smpl.is.tue.mpg.de/) and [SMPLX](https://smpl-x.is.tue.mpg.de/). And the checkpoints should be placed in the following structure:
inputs/checkpoints/
├── body_models/smplx/
│ └── SMPLX_{GENDER}.npz # SMPLX (We predict SMPLX params + evaluation)
└── body_models/smpl/
└── SMPL_{GENDER}.pkl # SMPL (rendering and evaluation)
# 2. Download other pretrained models from Google-Drive (By downloading, you agree to the corresponding licences): https://drive.google.com/drive/folders/1eebJ13FUEXrKBawHpJroW0sNSxLjh9xD?usp=drive_link
inputs/checkpoints/
├── dpvo/
│ └── dpvo.pth
├── gvhmr/
│ └── gvhmr_siga24_release.ckpt
├── hmr2/
│ └── epoch=10-step=25000.ckpt
├── vitpose/
│ └── vitpose-h-multi-coco.pth
└── yolo/
└── yolov8x.pt
```
**Data**
We provide preprocessed data for training and evaluation.
Note that we do not intend to distribute the original datasets, and you need to download them (annotation, videos, etc.) from the original websites.
*We're unable to provide the original data due to the license restrictions.*
By downloading the preprocessed data, you agree to the original dataset's terms of use and use the data for research purposes only.
You can download them from [Google-Drive](https://drive.google.com/drive/folders/10sEef1V_tULzddFxzCmDUpsIqfv7eP-P?usp=drive_link). Please place them in the "inputs" folder and execute the following commands:
```bash
cd inputs
# Train
tar -xzvf AMASS_hmr4d_support.tar.gz
tar -xzvf BEDLAM_hmr4d_support.tar.gz
tar -xzvf H36M_hmr4d_support.tar.gz
# Test
tar -xzvf 3DPW_hmr4d_support.tar.gz
tar -xzvf EMDB_hmr4d_support.tar.gz
tar -xzvf RICH_hmr4d_support.tar.gz
# The folder structure should be like this:
inputs/
├── AMASS/hmr4d_support/
├── BEDLAM/hmr4d_support/
├── H36M/hmr4d_support/
├── 3DPW/hmr4d_support/
├── EMDB/hmr4d_support/
└── RICH/hmr4d_support/
```
================================================
FILE: eval/GVHMR/download_eval_pose.sh
================================================
# Fetch the evaluation assets from Google Drive (requires the `gdown` CLI),
# then unpack the archive and remove it.
# NOTE(review): two files are downloaded but only eval_sets.zip is unpacked below;
# presumably the other download is used as-is — confirm.
gdown https://drive.google.com/uc\?id\=1jMH2-ZC0ZBgtqej5Sp-E5ebBIX7mk3Xz
gdown https://drive.google.com/uc\?id\=1iFcPSlcKb_rDNJ85UPoThdl22BqR2Xgh
unzip eval_sets.zip
rm -rf eval_sets.zip
================================================
FILE: eval/GVHMR/eval.sh
================================================
# Step 1: run the folder demo on every video under eval_sets (-f) and write the
# results to outputs/eval_sets_gvhmr (-d); -s skips visual odometry (static camera).
python tools/demo/demo_folder.py -f eval_sets -d outputs/eval_sets_gvhmr -s
# Step 2: evaluate the poses found in the given folder.
# NOTE(review): this reads outputs/eval_sets_gvhmr_v2 while step 1 writes to
# outputs/eval_sets_gvhmr — confirm the two directories are meant to differ.
python tools/eval_pose.py -f outputs/eval_sets_gvhmr_v2
================================================
FILE: eval/GVHMR/hmr4d/__init__.py
================================================
import os
from pathlib import Path
# Absolute path of the directory two levels above this file, i.e. the directory that contains the `hmr4d` package.
PROJ_ROOT = Path(__file__).resolve().parents[1]
def os_chdir_to_proj_root():
    """Switch the process working directory to ``PROJ_ROOT``.

    Handy when a notebook is started from some other directory but the code
    expects paths relative to the project root.
    """
    os.chdir(PROJ_ROOT)
================================================
FILE: eval/GVHMR/hmr4d/build_gvhmr.py
================================================
from omegaconf import OmegaConf
from hmr4d import PROJ_ROOT
from hydra.utils import instantiate
from hmr4d.model.gvhmr.gvhmr_pl_demo import DemoPL
def build_gvhmr_demo():
    """Build the GVHMR demo module and load the released checkpoint into it.

    Returns:
        The value returned by ``.eval()`` on the ``DemoPL`` instance that was
        instantiated from the release config and loaded with the pretrained weights.
    """
    cfg_dir = PROJ_ROOT / "hmr4d/configs"
    # Try the historical location first (backward compatible); fall back to the
    # config that lives directly under hmr4d/configs when the former is absent.
    cfg_path = cfg_dir / "demo_gvhmr_model/siga24_release.yaml"
    if not cfg_path.exists():
        cfg_path = cfg_dir / "siga24_release.yaml"
    cfg = OmegaConf.load(cfg_path)
    # _recursive_=False: nested config nodes are handed to DemoPL as-is instead of being instantiated here
    gvhmr_demo_pl: DemoPL = instantiate(cfg.model, _recursive_=False)
    gvhmr_demo_pl.load_pretrained_model(PROJ_ROOT / "inputs/checkpoints/gvhmr/gvhmr_siga24_release.ckpt")
    return gvhmr_demo_pl.eval()
================================================
FILE: eval/GVHMR/hmr4d/configs/__init__.py
================================================
from dataclasses import dataclass
from hydra.core.config_store import ConfigStore
from hydra_zen import builds
import argparse
from hydra import compose, initialize_config_module
import os
# Ask Hydra to show complete stack traces when an error occurs.
os.environ["HYDRA_FULL_ERROR"] = "1"
# Shared Hydra ConfigStore instance; other modules register their config nodes on it.
MainStore = ConfigStore.instance()
def register_store_gvhmr():
    """Register the GVHMR group options on ``MainStore``.

    The ``store_gvhmr`` module is imported purely for its side effects: the
    modules it imports perform the registration when they are loaded.
    """
    from . import store_gvhmr  # noqa: F401  (imported for side effects only)
def parse_args_to_cfg():
    """Parse the command line with a minimal Hydra API and return the composed cfg.

    This function does not go through ``_run_hydra``, so the log-file hierarchy
    that Hydra would normally create is not produced.

    Returns:
        The config composed from ``hmr4d.configs`` using the selected config name
        (``--config-name`` / ``-cn``, default ``"train"``) and the positional overrides.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config-name", "-cn", default="train")
    cli.add_argument(
        "overrides",
        nargs="*",
        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
    )
    parsed = cli.parse_args()

    # Compose the config from the hmr4d.configs config module
    with initialize_config_module(version_base="1.3", config_module="hmr4d.configs"):
        return compose(config_name=parsed.config_name, overrides=parsed.overrides)
================================================
FILE: eval/GVHMR/hmr4d/configs/data/mocap/testY.yaml
================================================
# definition of lightning datamodule (dataset + dataloader)
_target_: hmr4d.datamodule.mocap_trainX_testY.DataModule
dataset_opts:
test: ${test_datasets}
loader_opts:
test:
batch_size: 1
num_workers: 0
================================================
FILE: eval/GVHMR/hmr4d/configs/data/mocap/trainX_testY.yaml
================================================
# definition of lightning datamodule (dataset + dataloader)
_target_: hmr4d.datamodule.mocap_trainX_testY.DataModule
dataset_opts:
train: ${train_datasets}
val: ${test_datasets}
loader_opts:
train:
batch_size: 32
num_workers: 8
val:
batch_size: 1
num_workers: 1
limit_each_trainset: null
================================================
FILE: eval/GVHMR/hmr4d/configs/demo.yaml
================================================
defaults:
- _self_
- model: gvhmr/gvhmr_pl_demo
- network: gvhmr/relative_transformer
- endecoder: gvhmr/v1_amass_local_bedlam_cam
pipeline:
_target_: hmr4d.model.gvhmr.pipeline.gvhmr_pipeline.Pipeline
args_denoiser3d: ${network}
args:
endecoder_opt: ${endecoder}
normalize_cam_angvel: True
weights: null
static_conf: null
ckpt_path: inputs/checkpoints/gvhmr/gvhmr_siga24_release.ckpt
# ================================ #
# global setting #
# ================================ #
video_name: ???
output_root: outputs/demo
output_dir: "${output_root}/${video_name}"
preprocess_dir: ${output_dir}/preprocess
video_path: "${output_dir}/0_input_video.mp4"
# Options
static_cam: False
verbose: False
paths:
bbx: ${preprocess_dir}/bbx.pt
bbx_xyxy_video_overlay: ${preprocess_dir}/bbx_xyxy_video_overlay.mp4
vit_features: ${preprocess_dir}/vit_features.pt
vitpose: ${preprocess_dir}/vitpose.pt
vitpose_video_overlay: ${preprocess_dir}/vitpose_video_overlay.mp4
hmr4d_results: ${output_dir}/hmr4d_results.pt
incam_video: ${output_dir}/1_incam.mp4
global_video: ${output_dir}/2_global.mp4
incam_global_horiz_video: ${output_dir}/${video_name}_3_incam_global_horiz.mp4
slam: ${preprocess_dir}/slam_results.pt
================================================
FILE: eval/GVHMR/hmr4d/configs/exp/gvhmr/mixed/mixed.yaml
================================================
# @package _global_
defaults:
- override /data: mocap/trainX_testY
- override /model: gvhmr/gvhmr_pl
- override /endecoder: gvhmr/v1_amass_local_bedlam_cam
- override /optimizer: adamw_2e-4
- override /scheduler_cfg: epoch_half_200_350
- override /train_datasets:
- pure_motion_amass/v11
- imgfeat_bedlam/v2
- imgfeat_h36m/v1
- imgfeat_3dpw/v1
- override /test_datasets:
- emdb1/v1_fliptest
- emdb2/v1_fliptest
- rich/all
- 3dpw/fliptest
- override /callbacks:
- simple_ckpt_saver/every10e_top100
- prog_bar/prog_reporter_every0.1
- train_speed_timer/base
- lr_monitor/pl
- metric_emdb1
- metric_emdb2
- metric_rich
- metric_3dpw
- override /network: gvhmr/relative_transformer
exp_name_base: mixed
exp_name_var: ""
exp_name: ${exp_name_base}${exp_name_var}
data_name: mocap_mixed_v1
pipeline:
_target_: hmr4d.model.gvhmr.pipeline.gvhmr_pipeline.Pipeline
args_denoiser3d: ${network}
args:
endecoder_opt: ${endecoder}
normalize_cam_angvel: True
weights:
cr_j3d: 500.
transl_c: 1.
cr_verts: 500.
j2d: 1000.
verts2d: 1000.
transl_w: 1.
static_conf_bce: 1.
static_conf:
vel_thr: 0.15
data:
loader_opts:
train:
batch_size: 128
num_workers: 12
pl_trainer:
precision: 16-mixed
log_every_n_steps: 50
gradient_clip_val: 0.5
max_epochs: 500
check_val_every_n_epoch: 10
devices: 2
logger:
_target_: pytorch_lightning.loggers.TensorBoardLogger
save_dir: ${output_dir} # /save_dir/name/version/sub_dir
name: ""
version: "tb" # merge name and version
================================================
FILE: eval/GVHMR/hmr4d/configs/global/debug/debug_train.yaml
================================================
# @package _global_
data_name: debug
exp_name: debug
# data:
# limit_each_trainset: 40
# loader_opts:
# train:
# batch_size: 4
# num_workers: 0
# val:
# batch_size: 1
# num_workers: 0
pl_trainer:
limit_train_batches: 32
limit_val_batches: 2
check_val_every_n_epoch: 3
enable_checkpointing: False
devices: 1
callbacks:
model_checkpoint: null
================================================
FILE: eval/GVHMR/hmr4d/configs/global/debug/debug_train_limit_data.yaml
================================================
# @package _global_
data_name: debug
exp_name: debug
data:
limit_each_trainset: 40
loader_opts:
train:
batch_size: 4
num_workers: 0
val:
batch_size: 1
num_workers: 0
pl_trainer:
limit_val_batches: 2
check_val_every_n_epoch: 3
enable_checkpointing: False
devices: 1
callbacks:
model_checkpoint: null
================================================
FILE: eval/GVHMR/hmr4d/configs/global/task/gvhmr/test_3dpw.yaml
================================================
# @package _global_
defaults:
- override /data: mocap/testY
- override /test_datasets:
- 3dpw/fliptest
- override /callbacks:
- metric_3dpw
- _self_
task: test
data_name: test_mocap
ckpt_path: ??? # will not override previous setting if already set
# lightning utilities
pl_trainer:
devices: 1
logger: null
================================================
FILE: eval/GVHMR/hmr4d/configs/global/task/gvhmr/test_3dpw_emdb_rich.yaml
================================================
# @package _global_
defaults:
- override /data: mocap/testY
- override /test_datasets:
- rich/all
- emdb1/v1_fliptest
- emdb2/v1_fliptest
- 3dpw/fliptest
- override /callbacks:
- metric_rich
- metric_emdb1
- metric_emdb2
- metric_3dpw
- _self_
task: test
data_name: test_mocap
ckpt_path: ??? # will not override previous setting if already set
# lightning utilities
pl_trainer:
devices: 1
logger: null
================================================
FILE: eval/GVHMR/hmr4d/configs/global/task/gvhmr/test_emdb.yaml
================================================
# @package _global_
defaults:
- override /data: mocap/testY
- override /test_datasets:
- emdb1/v1_fliptest
- emdb2/v1_fliptest
- override /callbacks:
- metric_emdb1
- metric_emdb2
- _self_
task: test
data_name: test_mocap
ckpt_path: ??? # will not override previous setting if already set
# lightning utilities
pl_trainer:
devices: 1
logger: null
================================================
FILE: eval/GVHMR/hmr4d/configs/global/task/gvhmr/test_rich.yaml
================================================
# @package _global_
defaults:
- override /data: mocap/testY
- override /test_datasets:
- rich/all
- override /callbacks:
- metric_rich
- _self_
task: test
data_name: test_mocap
ckpt_path: ??? # will not override previous setting if already set
# lightning utilities
pl_trainer:
devices: 1
logger: null
================================================
FILE: eval/GVHMR/hmr4d/configs/hydra/default.yaml
================================================
# enable color logging
defaults:
- override hydra_logging: colorlog
- override job_logging: colorlog
job_logging:
formatters:
simple:
datefmt: '%m/%d %H:%M:%S'
format: '[%(asctime)s][%(levelname)s] %(message)s'
colorlog:
datefmt: '%m/%d %H:%M:%S'
format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] %(message)s'
handlers:
file:
filename: ${output_dir}/${hydra.job.name}.log
run:
dir: ${output_dir}
================================================
FILE: eval/GVHMR/hmr4d/configs/siga24_release.yaml
================================================
pipeline:
_target_: hmr4d.model.gvhmr.pipeline.gvhmr_pipeline.Pipeline
args_denoiser3d: ${network}
args:
endecoder_opt: ${endecoder}
normalize_cam_angvel: true
weights: null
static_conf: null
model:
_target_: hmr4d.model.gvhmr.gvhmr_pl_demo.DemoPL
pipeline: ${pipeline}
network:
_target_: hmr4d.network.gvhmr.relative_transformer.NetworkEncoderRoPEV2
output_dim: 151
max_len: 120
kp2d_mapping: linear_v2
cliffcam_dim: 3
cam_angvel_dim: 6
imgseq_dim: 1024
f_imgseq_filter: null
cond_ver: v1
latent_dim: 512
num_layers: 12
num_heads: 8
mlp_ratio: 4.0
pred_cam_ver: v2
pred_cam_dim: 3
static_conf_dim: 6
pred_coco17_dim: 0
dropout: 0.1
avgbeta: true
endecoder:
_target_: hmr4d.model.gvhmr.utils.endecoder.EnDecoder
stats_name: MM_V1_AMASS_LOCAL_BEDLAM_CAM
noise_pose_k: 10
================================================
FILE: eval/GVHMR/hmr4d/configs/store_gvhmr.py
================================================
# Dataset
import hmr4d.dataset.pure_motion.amass
import hmr4d.dataset.emdb.emdb_motion_test
import hmr4d.dataset.rich.rich_motion_test
import hmr4d.dataset.threedpw.threedpw_motion_test
import hmr4d.dataset.threedpw.threedpw_motion_train
import hmr4d.dataset.bedlam.bedlam
import hmr4d.dataset.h36m.h36m
# Trainer: Model Optimizer Loss
import hmr4d.model.gvhmr.gvhmr_pl
import hmr4d.model.gvhmr.utils.endecoder
import hmr4d.model.common_utils.optimizer
import hmr4d.model.common_utils.scheduler_cfg
# Metric
import hmr4d.model.gvhmr.callbacks.metric_emdb
import hmr4d.model.gvhmr.callbacks.metric_rich
import hmr4d.model.gvhmr.callbacks.metric_3dpw
# PL Callbacks
import hmr4d.utils.callbacks.simple_ckpt_saver
import hmr4d.utils.callbacks.train_speed_timer
import hmr4d.utils.callbacks.prog_bar
import hmr4d.utils.callbacks.lr_monitor
# Networks
import hmr4d.network.gvhmr.relative_transformer
================================================
FILE: eval/GVHMR/hmr4d/configs/train.yaml
================================================
# ================================ #
# override #
# ================================ #
# specify default configuration; the order determines the override order
defaults:
- _self_
# pytorch-lightning
- data: ???
- model: ???
- callbacks: null
# system
- hydra: default
# utility groups that change a lot
- pipeline: null
- network: null
- optimizer: null
- scheduler_cfg: default
- train_datasets: null
- test_datasets: null
- endecoder: null # normalize/unnormalize data
- refiner: null
# global-override
- exp: ??? # set "data, model and callbacks" in yaml
- global/task: null # dump/test
- global/hsearch: null # hyper-param search
- global/debug: null # debug mode
# ================================ #
# global setting #
# ================================ #
# experiment information
task: fit # [fit, predict]
exp_name: ???
data_name: ???
# utilities in the entry file
output_dir: "outputs/${data_name}/${exp_name}"
ckpt_path: null
resume_mode: null
seed: 42
# lightning default settings
pl_trainer:
devices: 1
num_sanity_val_steps: 0 # disable sanity check
precision: 32
inference_mode: False
logger: null
================================================
FILE: eval/GVHMR/hmr4d/datamodule/mocap_trainX_testY.py
================================================
import pytorch_lightning as pl
from pytorch_lightning.utilities.combined_loader import CombinedLoader
from hydra.utils import instantiate
from torch.utils.data import DataLoader, ConcatDataset, Subset
from omegaconf import ListConfig, DictConfig
from hmr4d.utils.pylogger import Log
from numpy.random import choice
from torch.utils.data import default_collate
import resource
# Set the soft limit on open file descriptors for this process to 4096 while keeping the hard limit unchanged.
# NOTE(review): presumably needed for multi-worker data loading; this would fail if the hard limit is
# below 4096 and lowers the soft limit if it was already higher — confirm on the target machines.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
def collate_fn(batch):
    """Collate a list of samples, keeping meta entries unbatched and recording the batch size.

    Args:
        batch: list of dict, each dict is one data point; every dict is assumed
            to carry the same keys as the first one.

    Returns:
        dict with the same keys as the samples, plus "B" holding ``len(batch)``.
        Keys starting with "meta" map to the plain list of per-sample values;
        every other key is merged with ``default_collate``.
    """
    collated = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        # "meta*" entries are data information and must not be batched
        collated[key] = values if key.startswith("meta") else default_collate(values)
    collated["B"] = len(batch)
    return collated
class DataModule(pl.LightningDataModule):
    def __init__(self, dataset_opts: DictConfig, loader_opts: DictConfig, limit_each_trainset=None):
        """General-purpose datamodule that can be used with any dataset.

        Train: every dataset is instantiated and merged with ConcatDataset.
        Val / Test: one loader per dataset, wrapped in a sequential CombinedLoader, which completely
        consumes each iterable sequentially and returns a triplet (data, idx, iterable_idx).

        Args:
            dataset_opts: the target of the dataset per split, e.g. dataset_opts.train = {_target_: ..., limit_size: None}
            loader_opts: the dataloader options per split
            limit_each_trainset: number of indices drawn from each train dataset; None means no limit,
                useful for debugging
        """
        super().__init__()
        self.loader_opts = loader_opts
        self.limit_each_trainset = limit_each_trainset

        # ----- Train: concatenate all datasets ----- #
        if "train" in dataset_opts:
            assert "train" in self.loader_opts, "train not in loader_opts"
            train_opts = dataset_opts.get("train")
            assert isinstance(train_opts, DictConfig), "split_opts should be a dict for each dataset"
            total = len(train_opts)
            train_parts = []
            for i, (name, node) in enumerate(train_opts.items(), start=1):
                part = instantiate(node)
                if self.limit_each_trainset:
                    picked = choice(len(part), self.limit_each_trainset)
                    part = Subset(part, picked)
                train_parts.append(part)
                Log.info(f"[Train Dataset][{i}/{total}]: name={name}, size={len(part)}, {node._target_}")
            self.trainset = ConcatDataset(train_parts)
            Log.info(f"[Train Dataset][All]: ConcatDataset size={len(self.trainset)}")
            Log.info("")

        # ----- Val / Test: keep the datasets separate (stored as self.valsets / self.testsets) ----- #
        for split in ("val", "test"):
            if split in dataset_opts:
                assert split in self.loader_opts, f"split={split} not in loader_opts"
                eval_opts = dataset_opts.get(split)
                assert isinstance(eval_opts, DictConfig), "split_opts should be a dict for each dataset"
                tag = "Val Dataset" if split == "val" else "Test Dataset"
                total = len(eval_opts)
                eval_sets = []
                for i, (name, node) in enumerate(eval_opts.items(), start=1):
                    eval_sets.append(instantiate(node))
                    Log.info(f"[{tag}][{i}/{total}]: name={name}, size={len(eval_sets[-1])}, {node._target_}")
                setattr(self, f"{split}sets", eval_sets)
                Log.info("")

    def train_dataloader(self):
        """Return the shuffled train loader, or defer to the parent class when no train set exists."""
        if not hasattr(self, "trainset"):
            return super().train_dataloader()
        opts = self.loader_opts.train
        return DataLoader(
            self.trainset,
            shuffle=True,
            num_workers=opts.num_workers,
            # persistent workers are only requested when workers are actually used
            persistent_workers=opts.num_workers > 0,
            batch_size=opts.batch_size,
            drop_last=True,
            collate_fn=collate_fn,
        )

    def val_dataloader(self):
        """Return a sequential CombinedLoader over the val sets, or None when there are none."""
        if not hasattr(self, "valsets"):
            return None
        opts = self.loader_opts.val
        val_loaders = [
            DataLoader(
                valset,
                shuffle=False,
                num_workers=opts.num_workers,
                persistent_workers=opts.num_workers > 0,
                batch_size=opts.batch_size,
                collate_fn=collate_fn,
            )
            for valset in self.valsets
        ]
        return CombinedLoader(val_loaders, mode="sequential")

    def test_dataloader(self):
        """Return a sequential CombinedLoader over the test sets, or defer to the parent class."""
        if not hasattr(self, "testsets"):
            return super().test_dataloader()
        opts = self.loader_opts.test
        test_loaders = [
            DataLoader(
                testset,
                shuffle=False,
                num_workers=opts.num_workers,
                persistent_workers=False,
                batch_size=opts.batch_size,
                collate_fn=collate_fn,
            )
            for testset in self.testsets
        ]
        return CombinedLoader(test_loaders, mode="sequential")
================================================
FILE: eval/GVHMR/hmr4d/dataset/bedlam/bedlam.py
================================================
from pathlib import Path
import numpy as np
import torch
from hmr4d.utils.pylogger import Log
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle
from time import time
from hmr4d.configs import MainStore, builds
from hmr4d.utils.smplx_utils import make_smplx
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.utils.vis.renderer_utils import simple_render_mesh_background
from hmr4d.utils.video_io_utils import read_video_np, save_video
import hmr4d.utils.matrix as matrix
from hmr4d.utils.net_utils import get_valid_mask, repeat_to_max_len, repeat_to_max_len_dict
from hmr4d.dataset.imgfeat_motion.base_dataset import ImgfeatMotionDatasetBase
from hmr4d.dataset.bedlam.utils import mid2featname, mid2vname
from hmr4d.utils.geo_transform import compute_cam_angvel, apply_T_on_points
from hmr4d.utils.geo.hmr_global import get_T_w2c_from_wcparams, get_c_rootparam, get_R_c2gv
class BedlamDatasetV2(ImgfeatMotionDatasetBase):
    """BEDLAM training dataset built on pre-extracted image features.

    mid_to_valid_range and features are newly generated.

    Each item is a randomly chosen window of a motion sequence (see `_load_data`);
    `_process_data` turns it into the training dict and extends every per-frame
    entry to `max_motion_frames` so that samples can be batched.
    """

    # mid_index -> (file with the mid -> valid range mapping, image-feature directory), both relative to self.root
    MIDINDEX_TO_LOAD = {
        "all60": ("mid_to_valid_range_all60.pt", "imgfeats/bedlam_all60"),
        "maxspan60": ("mid_to_valid_range_maxspan60.pt", "imgfeats/bedlam_maxspan60"),
    }

    def __init__(
        self,
        mid_indices=["all60", "maxspan60"],
        lazy_load=True,  # Load from disk when needed
        random1024=False,  # Faster loading for debugging
    ):
        """
        Args:
            mid_indices: one key, or a list of keys, of MIDINDEX_TO_LOAD selecting which index files to use
            lazy_load: stored on the instance; load from disk when needed
            random1024: if True, work on a cached random subset of 1024 motions (faster loading for debugging)
        """
        self.root = Path("inputs/BEDLAM/hmr4d_support")
        self.min_motion_frames = 60
        self.max_motion_frames = 120
        self.lazy_load = lazy_load
        self.random1024 = random1024

        # specify mid_index to handle
        if not isinstance(mid_indices, list):
            mid_indices = [mid_indices]
        self.mid_indices = mid_indices
        assert all(
            [m in self.MIDINDEX_TO_LOAD for m in mid_indices]
        ), f"mid_indices must be chosen from {list(self.MIDINDEX_TO_LOAD)}"

        # NOTE(review): the base class is defined elsewhere; presumably its __init__ drives
        # _load_dataset / _get_idx2meta, which is why the attributes above are set first — confirm.
        super().__init__()

    def _load_dataset(self):
        """Load the mid -> valid-range mappings and the SMPL motion files from self.root."""
        Log.info(f"[BEDLAM] Loading from {self.root}")
        tic = time()

        # Load mid to valid range
        self.mid_to_valid_range = {}
        self.mid_to_imgfeat_dir = {}
        for m in self.mid_indices:
            fn, feat_dir = self.MIDINDEX_TO_LOAD[m]
            mid_to_valid_range_ = torch.load(self.root / fn)
            self.mid_to_valid_range.update(mid_to_valid_range_)
            self.mid_to_imgfeat_dir.update({mid: self.root / feat_dir for mid in mid_to_valid_range_})

        # Load motionfiles
        Log.info(f"[BEDLAM] Start loading motion files")
        if self.random1024:  # Debug, faster loading
            try:
                Log.info(f"[BEDLAM] Loading 1024 samples for debugging ...")
                self.motion_files = torch.load(self.root / "smplpose_v2_random1024.pth")
            except FileNotFoundError:
                # Only a missing cache triggers regeneration; a bare `except:` would also swallow
                # KeyboardInterrupt/SystemExit and hide unrelated loading errors.
                Log.info(f"[BEDLAM] Not found, saving 1024 samples to disk ...")
                self.motion_files = torch.load(self.root / "smplpose_v2.pth")
                keys = list(self.motion_files.keys())
                keys = np.random.choice(keys, 1024, replace=False)
                self.motion_files = {k: self.motion_files[k] for k in keys}
                torch.save(self.motion_files, self.root / "smplpose_v2_random1024.pth")
            # Keep only the sequences that are part of the 1024-sample subset
            self.mid_to_valid_range = {k: v for k, v in self.mid_to_valid_range.items() if k in self.motion_files}
        else:
            self.motion_files = torch.load(self.root / "smplpose_v2.pth")
        Log.info(f"[BEDLAM] Motion files loaded. Elapsed: {time() - tic:.2f}s")

    def _get_idx2meta(self):
        """Build the index: one entry (the mid) per sequence that has a valid range."""
        # sum_frame = sum([e-s for s, e in self.mid_to_valid_range.values()])
        self.idx2meta = list(self.mid_to_valid_range.keys())
        Log.info(f"[BEDLAM] {len(self.idx2meta)} sequences. ")

    def _load_data(self, idx):
        """Pick a random window of the idx-th sequence and load its motion and image features.

        Returns:
            A shallow copy of the motion dict cropped to [start, end), extended with
            "start_end", "length", "f_imgseq", "bbx_xys", "img_wh" and a zero "kp2d".
        """
        mid = self.idx2meta[idx]
        # neutral smplx : "pose": (F, 63), "trans": (F, 3), "beta": (10),
        #           and : "skeleton": (J, 3)
        data = self.motion_files[mid].copy()

        # Random select a subset
        range1, range2 = self.mid_to_valid_range[mid]  # [range1, range2)
        mlength = range2 - range1
        min_motion_len = self.min_motion_frames
        max_motion_len = self.max_motion_frames

        if mlength < min_motion_len:  # the minimal mlength is 30 when generating data
            # Too short to sample from: use the whole valid range
            start = range1
            length = mlength
        else:
            effect_max_motion_len = min(max_motion_len, mlength)
            length = np.random.randint(min_motion_len, effect_max_motion_len + 1)  # [low, high)
            start = np.random.randint(range1, range2 - length + 1)
        end = start + length
        data["start_end"] = (start, end)
        data["length"] = length

        # Update data to a subset (only per-frame tensors; "skeleton" is not per-frame)
        for k, v in data.items():
            if isinstance(v, torch.Tensor) and len(v.shape) > 1 and k != "skeleton":
                data[k] = v[start:end]

        # Load img(as feature) : {mid -> 'features', 'bbx_xys', 'img_wh', 'start_end'}
        imgfeat_dir = self.mid_to_imgfeat_dir[mid]
        f_img_dict = torch.load(imgfeat_dir / mid2featname(mid))

        # remap (start, end): the feature file has its own start offset
        start_mapped = start - f_img_dict["start_end"][0]
        end_mapped = end - f_img_dict["start_end"][0]

        data["f_imgseq"] = f_img_dict["features"][start_mapped:end_mapped].float()  # (L, 1024)
        data["bbx_xys"] = f_img_dict["bbx_xys"][start_mapped:end_mapped].float()  # (L, 4)
        data["img_wh"] = f_img_dict["img_wh"]  # (2)
        data["kp2d"] = torch.zeros((end - start), 17, 3)  # (L, 17, 3)  # do not provide kp2d

        return data

    def _process_data(self, data, idx):
        """Convert the dict produced by `_load_data` into a batchable training sample.

        Args:
            data: output of `_load_data`
            idx: dataset index, stored in the returned "meta" entry

        Returns:
            dict with SMPL params in camera and world frames, camera quantities, image
            features and masks; per-frame entries are extended to `max_motion_frames`.
        """
        length = data["length"]

        # SMPL params in cam
        body_pose = data["pose"][:, 3:]  # (F, 63)
        betas = data["beta"].repeat(length, 1)  # (F, 10)
        global_orient = data["global_orient_incam"]  # (F, 3)
        transl = data["trans_incam"] + data["cam_ext"][:, :3, 3]  # (F, 3), bedlam convention
        smpl_params_c = {"body_pose": body_pose, "betas": betas, "transl": transl, "global_orient": global_orient}

        # SMPL params in world
        global_orient_w = data["pose"][:, :3]  # (F, 3)
        transl_w = data["trans"]  # (F, 3)
        smpl_params_w = {"body_pose": body_pose, "betas": betas, "transl": transl_w, "global_orient": global_orient_w}

        gravity_vec = torch.tensor([0, -1, 0], dtype=torch.float32)  # (3), BEDLAM is ay
        T_w2c = get_T_w2c_from_wcparams(
            global_orient_w=global_orient_w,
            transl_w=transl_w,
            global_orient_c=global_orient,
            transl_c=transl,
            offset=data["skeleton"][0],
        )  # (F, 4, 4)
        R_c2gv = get_R_c2gv(T_w2c[:, :3, :3], gravity_vec)  # (F, 3, 3)

        # cam_angvel (slightly different from WHAM)
        cam_angvel = compute_cam_angvel(T_w2c[:, :3, :3])  # (F, 6)

        # Returns: do not forget to make it batchable! (last lines)
        max_len = self.max_motion_frames
        return_data = {
            "meta": {"data_name": "bedlam", "idx": idx},
            "length": length,
            "smpl_params_c": smpl_params_c,
            "smpl_params_w": smpl_params_w,
            "R_c2gv": R_c2gv,  # (F, 3, 3)
            "gravity_vec": gravity_vec,  # (3)
            "bbx_xys": data["bbx_xys"],  # (F, 3)
            "K_fullimg": data["cam_int"],  # (F, 3, 3)
            "f_imgseq": data["f_imgseq"],  # (F, D)
            "kp2d": data["kp2d"],  # (F, 17, 3)
            "cam_angvel": cam_angvel,  # (F, 6)
            "mask": {
                "valid": get_valid_mask(max_len, length),
                "vitpose": False,
                "bbx_xys": True,
                "f_imgseq": True,
                "spv_incam_only": False,
            },
        }

        # Disabled debugging aid; flip the condition manually to run it
        if False:  # check transformation, wis3d: sampled motion (global, incam)
            wis3d = make_wis3d(name="debug-data-bedlam")
            smplx = make_smplx("supermotion")

            # global
            smplx_out = smplx(**smpl_params_w)
            w_gt_joints = smplx_out.joints
            add_motion_as_lines(w_gt_joints, wis3d, name="w-gt_joints")

            # incam
            smplx_out = smplx(**smpl_params_c)
            c_gt_joints = smplx_out.joints
            add_motion_as_lines(c_gt_joints, wis3d, name="c-gt_joints")

            # Check transformation works correctly
            print("T_w2c", (apply_T_on_points(w_gt_joints, T_w2c) - c_gt_joints).abs().max())
            R_c, t_c = get_c_rootparam(
                smpl_params_w["global_orient"], smpl_params_w["transl"], T_w2c, data["skeleton"][0]
            )
            print("transl_c", (t_c - smpl_params_c["transl"]).abs().max())
            R_diff = matrix_to_axis_angle(
                (axis_angle_to_matrix(R_c) @ axis_angle_to_matrix(smpl_params_c["global_orient"]).transpose(-1, -2))
            ).norm(dim=-1)
            print("global_orient_c", R_diff.abs().max())  # < 1e-6

            skeleton_beta = smplx.get_skeleton(smpl_params_c["betas"])
            print("Skeleton", (skeleton_beta[0] - data["skeleton"]).abs().max())  # (1.2e-7)

        # Disabled debugging aid; flip the condition manually to run it
        if False:  # cam-overlay
            smplx = make_smplx("supermotion")

            # *. original bedlam param
            # mid = self.idx2meta[idx]
            # video_path = "-".join(mid.replace("bedlam_data/", "inputs/bedlam/").split("-")[:-1])
            # npz_file = "inputs/bedlam/processed_labels/20221024_3-10_100_batch01handhair_static_highSchoolGym.npz"
            # params = np.load(npz_file, allow_pickle=True)
            # mid2index = {}
            # for j in tqdm(range(len(params["video_name"]))):
            #     k = params["video_name"][j] + "-" + params["sub"][j]
            #     mid2index[k] = j
            # betas = params['shape'][mid2index[mid]][:length]
            # global_orient_incam = torch.from_numpy(params['pose_cam'][121][:, :3])
            # body_pose = torch.from_numpy(params['pose_cam'][121][:, 3:66])
            # transl_incam = torch.from_numpy(params["trans_cam"][121])

            smplx_out = smplx(**smpl_params_c)

            # ----- Render Overlay ----- #
            mid = self.idx2meta[idx]
            images = read_video_np(self.root / "videos" / mid2vname(mid), data["start_end"][0], data["start_end"][1])
            render_dict = {
                "K": data["cam_int"][:1],  # only support batch-size 1
                "faces": smplx.faces,
                "verts": smplx_out.vertices,
                "background": images,
            }
            img_overlay = simple_render_mesh_background(render_dict)
            save_video(img_overlay, "tmp.mp4", crf=23)

        # Batchable: extend every per-frame entry to max_len
        return_data["smpl_params_c"] = repeat_to_max_len_dict(return_data["smpl_params_c"], max_len)
        return_data["smpl_params_w"] = repeat_to_max_len_dict(return_data["smpl_params_w"], max_len)
        return_data["R_c2gv"] = repeat_to_max_len(return_data["R_c2gv"], max_len)
        return_data["bbx_xys"] = repeat_to_max_len(return_data["bbx_xys"], max_len)
        return_data["K_fullimg"] = repeat_to_max_len(return_data["K_fullimg"], max_len)
        return_data["f_imgseq"] = repeat_to_max_len(return_data["f_imgseq"], max_len)
        return_data["kp2d"] = repeat_to_max_len(return_data["kp2d"], max_len)
        return_data["cam_angvel"] = repeat_to_max_len(return_data["cam_angvel"], max_len)
        return return_data
# Register the dataset configs on MainStore under the "train_datasets/imgfeat_bedlam" group.
group_name = "train_datasets/imgfeat_bedlam"
# "v2": BedlamDatasetV2 built without overriding any constructor argument.
MainStore.store(name="v2", node=builds(BedlamDatasetV2), group=group_name)
# "v2_random1024": same dataset with random1024=True (1024-sample subset for debugging).
MainStore.store(name="v2_random1024", node=builds(BedlamDatasetV2, random1024=True), group=group_name)
================================================
FILE: eval/GVHMR/hmr4d/dataset/bedlam/utils.py
================================================
import torch
import numpy as np
from pathlib import Path
# "resource" directory located next to this module; files such as vname2lwh.pt are read from it.
resource_dir = Path(__file__).parent / "resource"
def mid2vname(mid):
    """Convert a motion id into its video name ``{scene}/{seq}`` (the result ends with .mp4).

    Example:
        mid:   "inputs/bedlam/bedlam_download/20221011_1_250_batch01hand_closeup_suburb_a/mp4/seq_000001.mp4-rp_emma_posed_008"
        vname: "20221011_1_250_batch01hand_closeup_suburb_a/seq_000001.mp4"
    """
    parts = mid.split("/")
    # scene is the third-from-last path component; seq is the last component up to the first "-"
    seq_name = parts[-1].partition("-")[0]
    return f"{parts[-3]}/{seq_name}"
def mid2featname(mid):
    """Convert a motion id into its feature file name ``{scene}/{seqsubj}.pt`` (an extra .pt is appended).

    Example:
        mid:      "inputs/bedlam/bedlam_download/20221011_1_250_batch01hand_closeup_suburb_a/mp4/seq_000001.mp4-rp_emma_posed_008"
        featname: "20221011_1_250_batch01hand_closeup_suburb_a/seq_000001.mp4-rp_emma_posed_008.pt"
    """
    parts = mid.split("/")
    # scene is the third-from-last path component; seqsubj is the full last component
    return f"{parts[-3]}/{parts[-1]}.pt"
def featname2mid(featname):
    """Reverse of mid2featname: rebuild the motion id from a feature file name (the extra .pt is removed).

    Example:
        featname: "20221011_1_250_batch01hand_closeup_suburb_a/seq_000001.mp4-rp_emma_posed_008.pt"
        mid:      "inputs/bedlam/bedlam_download/20221011_1_250_batch01hand_closeup_suburb_a/mp4/seq_000001.mp4-rp_emma_posed_008"

    Args:
        featname: string of the form "{scene}/{seqsubj}.pt"

    Returns:
        "inputs/bedlam/bedlam_download/{scene}/mp4/{seqsubj}"
    """
    parts = featname.split("/")
    scene = parts[0]
    seqsubj = parts[1]
    # Remove exactly one trailing ".pt". str.strip(".pt") would instead strip any of the
    # characters ".", "p", "t" from BOTH ends and corrupt names such as "..._scott.pt".
    if seqsubj.endswith(".pt"):
        seqsubj = seqsubj[: -len(".pt")]
    mid = f"inputs/bedlam/bedlam_download/{scene}/mp4/{seqsubj}"
    return mid
def load_vname2lwh():
    """Load and return the object stored in ``vname2lwh.pt`` under ``resource_dir``.

    NOTE(review): judging by the file name this is presumably a mapping from
    video name to (length, width, height) -- confirm against the resource file.
    """
    lwh_path = resource_dir / "vname2lwh.pt"
    return torch.load(lwh_path)
================================================
FILE: eval/GVHMR/hmr4d/dataset/emdb/emdb_motion_test.py
================================================
from pathlib import Path
import numpy as np
import torch
from torch.utils import data
from hmr4d.utils.pylogger import Log
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.utils.geo_transform import compute_cam_angvel
from pytorch3d.transforms import quaternion_to_matrix
from hmr4d.utils.geo.hmr_cam import estimate_K, resize_K
from hmr4d.utils.geo.flip_utils import flip_kp2d_coco17
from .utils import EMDB1_NAMES, EMDB2_NAMES
VID_PRESETS = {1: EMDB1_NAMES, 2: EMDB2_NAMES}
from hmr4d.configs import MainStore, builds
class EmdbSmplFullSeqDataset(data.Dataset):
    """EMDB test dataset that returns one complete video sequence per item.

    Labels and a DPVO camera trajectory are loaded once in ``__init__`` from
    ``inputs/EMDB/hmr4d_support``. Every item bundles the SMPL parameters,
    camera data, image features and rendering metadata of a single video.
    """

    def __init__(self, split=1, flip_test=False):
        """
        split: 1 for EMDB-1, 2 for EMDB-2
        flip_test: if True, extra flip data will be returned
        """
        super().__init__()
        self.dataset_name = "EMDB"
        self.split = split
        self.dataset_id = f"EMDB_{split}"
        Log.info(f"[{self.dataset_name}] Full sequence, split={split}")

        # Load evaluation protocol from WHAM labels
        tic = Log.time()
        support_dir = Path("inputs/EMDB/hmr4d_support")
        self.emdb_dir = support_dir
        # Per-video entries with keys:
        # 'name', 'gender', 'smpl_params', 'mask', 'K_fullimg', 'T_w2c', 'bbx_xys', 'kp2d', 'features'
        self.labels = torch.load(support_dir / "emdb_vit_v4.pt")
        self.cam_traj = torch.load(support_dir / "emdb_dpvo_traj.pt")  # estimated with DPVO

        # Dataset index: one (vid, start, end) triple per video, always the full range
        self.idx2meta = [(vid, 0, len(self.labels[vid]["mask"])) for vid in VID_PRESETS[split]]
        Log.info(f"[{self.dataset_name}] {len(self.idx2meta)} sequences. Elapsed: {Log.time() - tic:.2f}s")

        # With flip_test on, each item carries an extra "flip_test" entry
        self.flip_test = flip_test
        if flip_test:
            Log.info(f"[{self.dataset_name}] Flip test enabled")

    def __len__(self):
        """Number of indexed sequences."""
        return len(self.idx2meta)

    def _load_data(self, idx):
        """Assemble the raw sample dict for the ``idx``-th sequence."""
        vid, start, end = self.idx2meta[idx]
        sample = {
            "meta": {"dataset_id": self.dataset_id, "vid": vid, "vid-start-end": (start, end)},
            "length": end - start,
        }
        label = self.labels[vid]

        # SMPL parameters in world, together with gender and the frame mask
        gender = label["gender"]
        smpl_params = label["smpl_params"]
        mask = label["mask"]
        sample["smpl_params"] = smpl_params
        sample["gender"] = gender
        sample["mask"] = mask

        # Camera: the labelled K_fullimg is ignored in favour of an estimated K.
        # One video uses a different image size than all the others.
        if vid == "P0_09_outdoor_walk":
            width_height = (720, 960)
        else:
            width_height = (1440, 1920)
        K_fullimg = estimate_K(*width_height)
        sample["K_fullimg"] = K_fullimg
        sample["T_w2c"] = label["T_w2c"]

        # Rotation source for the camera angular velocity
        use_DPVO = False
        if not use_DPVO:  # GT
            R_w2c = sample["T_w2c"][:, :3, :3]  # (L, 3, 3)
        else:
            traj = self.cam_traj[vid]  # (L, 7)
            R_w2c = quaternion_to_matrix(traj[:, [6, 3, 4, 5]]).mT  # (L, 3, 3)
        sample["cam_angvel"] = compute_cam_angvel(R_w2c)  # (L, 6)

        # Bounding boxes, image features and 2D keypoints
        bbx_xys = label["bbx_xys"]
        f_imgseq = label["features"]
        kp2d = label["kp2d"]
        sample["bbx_xys"] = bbx_xys
        sample["f_imgseq"] = f_imgseq
        sample["kp2d"] = kp2d

        # Everything needed to render a video, scaled by resize_factor
        resize_factor = 0.5
        sample["meta_render"] = {
            "split": self.split,
            "name": vid,
            "video_path": str(self.emdb_dir / f"videos/{vid}.mp4"),
            "resize_factor": resize_factor,
            "frame_id": torch.where(mask)[0].long(),
            "width_height": (torch.tensor(width_height) * resize_factor).int(),
            "K": resize_K(K_fullimg, resize_factor),
            "bbx_xys": bbx_xys * resize_factor,
            "R_cam_type": "DPVO" if use_DPVO else "GtGyro",
        }

        # Optional horizontally-flipped inputs
        if self.flip_test:
            flip_dict = torch.load(self.emdb_dir / "imgfeats/emdb_flip" / f"{vid}.pt")
            mirror_x = torch.tensor([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]).float()
            sample["flip_test"] = {
                "bbx_xys": flip_dict["bbx_xys"].float(),  # (L, 3)
                "f_imgseq": flip_dict["features"].float(),  # (L, 1024)
                "kp2d": flip_kp2d_coco17(kp2d, width_height[0]),  # (L, 17, 3)
                "cam_angvel": compute_cam_angvel(mirror_x @ R_w2c.clone()),
            }

        return sample

    def _process_data(self, data):
        """Tile the single intrinsics matrix so there is one copy per frame."""
        n_frames = data["length"]
        K_single = data["K_fullimg"]
        data["K_fullimg"] = K_single[None].repeat(n_frames, 1, 1)
        return data

    def __getitem__(self, idx):
        """Load, post-process and return the ``idx``-th sequence."""
        return self._process_data(self._load_data(idx))
# EMDB-1 and EMDB-2
# Register four EmdbSmplFullSeqDataset configs. Each of the two groups
# ("test_datasets/emdb1", "test_datasets/emdb2") gets:
#   "v1"          -> flip_test left at its default
#   "v1_fliptest" -> flip_test=True
# The emdb1 entries rely on the constructor default split=1; the emdb2
# entries pass split=2 explicitly.
MainStore.store(
name="v1",
node=builds(EmdbSmplFullSeqDataset, populate_full_signature=True),
group="test_datasets/emdb1",
)
MainStore.store(
name="v1_fliptest",
node=builds(EmdbSmplFullSeqDataset, flip_test=True, populate_full_signature=True),
group="test_datasets/emdb1",
)
MainStore.store(
name="v1",
node=builds(EmdbSmplFullSeqDataset, split=2, populate_full_signature=True),
group="test_datasets/emdb2",
)
MainStore.store(
name="v1_fliptest",
node=builds(EmdbSmplFullSeqDataset, split=2, flip_test=True, populate_full_signature=True),
group="test_datasets/emdb2",
)
================================================
FILE: eval/GVHMR/hmr4d/dataset/emdb/utils.py
================================================
import torch
import pickle
import numpy as np
from pathlib import Path
from tqdm import tqdm
from hmr4d.utils.geo_transform import convert_lurb_to_bbx_xys
from hmr4d.utils.video_io_utils import get_video_lwh
def name_to_subfolder(name):
    """Turn a sequence name into its subfolder: first two characters, "/", then everything from index 3.

    Example: "P1_14_outdoor_climb" -> "P1/14_outdoor_climb"
    (the character at index 2 is dropped).
    """
    subject = name[:2]
    sequence = name[3:]
    return "/".join((subject, sequence))
def name_to_local_pkl_path(name):
    """Build the relative annotation path ``{subject}/{sequence}/{name}_data.pkl`` for a sequence name.

    Example: "P1_14_outdoor_climb" -> "P1/14_outdoor_climb/P1_14_outdoor_climb_data.pkl"
    """
    # Same split as name_to_subfolder: first two chars, skip index 2, keep the rest.
    subfolder = f"{name[:2]}/{name[3:]}"
    return subfolder + "/" + f"{name}_data.pkl"
def load_raw_pkl(fp):
    """Load a raw EMDB annotation pickle and attach its subfolder.

    Args:
        fp: path of the ``*_data.pkl`` annotation file.

    Returns:
        The unpickled annotation dict, with an added ``"subfolder"`` entry
        computed from ``annot["name"]`` via ``name_to_subfolder``.
    """
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # use this on annotation files from a trusted source.
    # The context manager closes the handle deterministically.
    with open(fp, "rb") as f:
        annot = pickle.load(f)
    annot["subfolder"] = name_to_subfolder(annot["name"])
    return annot
def load_pkl(fp):
    """Load an EMDB annotation pickle and convert the fields used downstream to torch tensors.

    Args:
        fp: path of the ``*_data.pkl`` annotation file.

    Returns:
        dict with keys ``name``, ``gender``, ``smpl_params`` (dict of float
        tensors: ``body_pose``, ``betas``, ``global_orient``, ``transl``),
        ``mask`` (bool tensor), ``K_fullimg``, ``T_w2c`` (float tensors) and
        ``bbx_xys`` (output of ``convert_lurb_to_bbx_xys`` on the boxes).
    """
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # use this on annotation files from a trusted source.
    # The context manager closes the handle deterministically.
    with open(fp, "rb") as f:
        annot = pickle.load(f)
    # ['gender', 'name', 'emdb1', 'emdb2', 'n_frames', 'good_frames_mask', 'camera', 'smpl', 'kp2d', 'bboxes', 'subfolder']

    data = {}
    F = annot["n_frames"]
    smpl_params = {
        "body_pose": annot["smpl"]["poses_body"],  # (F, 69)
        # betas are stored once per sequence; tile them to one row per frame
        "betas": annot["smpl"]["betas"][None].repeat(F, axis=0),  # (F, 10)
        "global_orient": annot["smpl"]["poses_root"],  # (F, 3)
        "transl": annot["smpl"]["trans"],  # (F, 3)
    }
    smpl_params = {k: torch.from_numpy(v).float() for k, v in smpl_params.items()}

    data["name"] = annot["name"]
    data["gender"] = annot["gender"]
    data["smpl_params"] = smpl_params
    data["mask"] = torch.from_numpy(annot["good_frames_mask"]).bool()  # (L,)
    data["K_fullimg"] = torch.from_numpy(annot["camera"]["intrinsics"]).float()  # (3, 3)
    data["T_w2c"] = torch.from_numpy(annot["camera"]["extrinsics"]).float()  # (L, 4, 4)
    bbx_lurb = torch.from_numpy(annot["bboxes"]["bboxes"]).float()
    data["bbx_xys"] = convert_lurb_to_bbx_xys(bbx_lurb)  # (L, 3)
    return data
# Annotation pickles of the EMDB-1 subset, as paths relative to the raw EMDB root
# (they are joined onto `emdb_raw_dir` by the checks at the bottom of this module).
# Entries tagged "# DUPLICATE" are also present in EMDB2_LIST.
EMDB1_LIST = [
"P1/14_outdoor_climb/P1_14_outdoor_climb_data.pkl",
"P2/23_outdoor_hug_tree/P2_23_outdoor_hug_tree_data.pkl",
"P3/31_outdoor_workout/P3_31_outdoor_workout_data.pkl",
"P3/32_outdoor_soccer_warmup_a/P3_32_outdoor_soccer_warmup_a_data.pkl",
"P3/33_outdoor_soccer_warmup_b/P3_33_outdoor_soccer_warmup_b_data.pkl",
"P5/42_indoor_dancing/P5_42_indoor_dancing_data.pkl",
"P5/44_indoor_rom/P5_44_indoor_rom_data.pkl",
"P6/49_outdoor_big_stairs_down/P6_49_outdoor_big_stairs_down_data.pkl",  # DUPLICATE
"P6/50_outdoor_workout/P6_50_outdoor_workout_data.pkl",
"P6/51_outdoor_dancing/P6_51_outdoor_dancing_data.pkl",
"P7/57_outdoor_rock_chair/P7_57_outdoor_rock_chair_data.pkl",  # DUPLICATE
"P7/59_outdoor_rom/P7_59_outdoor_rom_data.pkl",
"P7/60_outdoor_workout/P7_60_outdoor_workout_data.pkl",
"P8/64_outdoor_skateboard/P8_64_outdoor_skateboard_data.pkl",  # DUPLICATE
"P8/68_outdoor_handstand/P8_68_outdoor_handstand_data.pkl",
"P8/69_outdoor_cartwheel/P8_69_outdoor_cartwheel_data.pkl",
"P9/76_outdoor_sitting/P9_76_outdoor_sitting_data.pkl",
]
# Sequence names "<subject>_<sequence>", built by joining the first two path
# components of each entry with "_" (e.g. "P1_14_outdoor_climb").
EMDB1_NAMES = ["_".join(p.split("/")[:2]) for p in EMDB1_LIST]
# Annotation pickles of the EMDB-2 subset, in the same relative-path format.
# Entries tagged "# DUPLICATE" are also present in EMDB1_LIST.
EMDB2_LIST = [
"P0/09_outdoor_walk/P0_09_outdoor_walk_data.pkl",
"P2/19_indoor_walk_off_mvs/P2_19_indoor_walk_off_mvs_data.pkl",
"P2/20_outdoor_walk/P2_20_outdoor_walk_data.pkl",
"P2/24_outdoor_long_walk/P2_24_outdoor_long_walk_data.pkl",
"P3/27_indoor_walk_off_mvs/P3_27_indoor_walk_off_mvs_data.pkl",
"P3/28_outdoor_walk_lunges/P3_28_outdoor_walk_lunges_data.pkl",
"P3/29_outdoor_stairs_up/P3_29_outdoor_stairs_up_data.pkl",
"P3/30_outdoor_stairs_down/P3_30_outdoor_stairs_down_data.pkl",
"P4/35_indoor_walk/P4_35_indoor_walk_data.pkl",
"P4/36_outdoor_long_walk/P4_36_outdoor_long_walk_data.pkl",
"P4/37_outdoor_run_circle/P4_37_outdoor_run_circle_data.pkl",
"P5/40_indoor_walk_big_circle/P5_40_indoor_walk_big_circle_data.pkl",
"P6/48_outdoor_walk_downhill/P6_48_outdoor_walk_downhill_data.pkl",
"P6/49_outdoor_big_stairs_down/P6_49_outdoor_big_stairs_down_data.pkl",  # DUPLICATE
"P7/55_outdoor_walk/P7_55_outdoor_walk_data.pkl",
"P7/56_outdoor_stairs_up_down/P7_56_outdoor_stairs_up_down_data.pkl",
"P7/57_outdoor_rock_chair/P7_57_outdoor_rock_chair_data.pkl",  # DUPLICATE
"P7/58_outdoor_parcours/P7_58_outdoor_parcours_data.pkl",
"P7/61_outdoor_sit_lie_walk/P7_61_outdoor_sit_lie_walk_data.pkl",
"P8/64_outdoor_skateboard/P8_64_outdoor_skateboard_data.pkl",  # DUPLICATE
"P8/65_outdoor_walk_straight/P8_65_outdoor_walk_straight_data.pkl",
"P9/77_outdoor_stairs_up/P9_77_outdoor_stairs_up_data.pkl",
"P9/78_outdoor_stairs_up_down/P9_78_outdoor_stairs_up_down_data.pkl",
"P9/79_outdoor_walk_rectangle/P9_79_outdoor_walk_rectangle_data.pkl",
"P9/80_outdoor_walk_big_circle/P9_80_outdoor_walk_big_circle_data.pkl",
]
EMDB2_NAMES = ["_".join(p.split("/")[:2]) for p in EMDB2_LIST]
# Sorted union of both subsets; sequences shared by the two lists appear once.
EMDB_NAMES = list(sorted(set(EMDB1_NAMES + EMDB2_NAMES)))
def _check_annot(emdb_raw_dir=Path("inputs/EMDB/EMDB")):
    """Print the name of every sequence whose invalid-bbox indices disagree with its frame mask.

    For each annotation pickle in EMDB-1/EMDB-2, ``bboxes.invalid_idxs`` is
    compared with the indices where ``good_frames_mask`` is False; sequences
    for which the two differ are printed.

    Args:
        emdb_raw_dir: root directory of the raw EMDB release.
    """
    for pkl_local_path in set(EMDB1_LIST + EMDB2_LIST):
        annot = load_raw_pkl(emdb_raw_dir / pkl_local_path)
        invalid_idxs = np.asarray(annot["bboxes"]["invalid_idxs"])
        bad_frame_idxs = np.where(~annot["good_frames_mask"])[0]
        # array_equal also handles arrays of different length (-> not equal);
        # an elementwise `!=` cannot be evaluated in exactly that case.
        if not np.array_equal(invalid_idxs, bad_frame_idxs):
            print(annot["name"])
def _check_length(emdb_raw_dir=Path("inputs/EMDB/EMDB"), emdb_hmr4d_support_dir=Path("inputs/EMDB/hmr4d_support")):
    """Print the sorted video lengths of all EMDB sequences and a RAM estimate for one video.

    The estimate uses the last video visited (iteration order over the set of
    paths is arbitrary), at a quarter of its width and height with 3 values
    per pixel, reported in units of 1e6.

    Args:
        emdb_raw_dir: root directory of the raw EMDB release.
        emdb_hmr4d_support_dir: directory containing ``videos/{name}.mp4``.
    """
    lengths = []
    for local_pkl_path in tqdm(set(EMDB1_LIST + EMDB2_LIST)):
        data = load_pkl(emdb_raw_dir / local_pkl_path)
        video_path = emdb_hmr4d_support_dir / "videos" / f"{data['name']}.mp4"
        length, width, height = get_video_lwh(video_path)
        lengths.append(length)
    print(sorted(lengths))

    # `lengths` (the list), not `length` (an int): subscripting the int raised TypeError.
    video_ram = lengths[-1] * (width / 4) * (height / 4) * 3 / 1e6
    print(f"Video RAM for {lengths[-1]} x {width} x {height}: {video_ram:.2f} MB")
================================================
FILE: eval/GVHMR/hmr4d/dataset/h36m/camera-parameters.json
================================================
{
"intrinsics": {
"54138969": {
"calibration_matrix": [
[
1145.04940458804,
0.0,
512.541504956548
],
[
0.0,
1143.78109572365,
515.4514869776
],
[
0.0,
0.0,
1.0
]
],
"distortion": [
-0.207098910824901,
0.247775183068982,
-0.00142447157470321,
-0.000975698859470499,
-0.00307515035078854
]
},
"55011271": {
"calibration_matrix": [
[
1149.67569986785,
0.0,
508.848621645943
],
[
0.0,
1147.59161666764,
508.064917088557
],
[
0.0,
0.0,
1.0
]
],
"distortion": [
-0.194213629607385,
0.240408539138292,
-0.0027408943961907,
-0.001619026613787,
0.00681997559022603
]
},
"58860488": {
"calibration_matrix": [
[
1149.14071676148,
0.0,
519.815837182153
],
[
0.0,
1148.7989685676,
501.402658888552
],
[
0.0,
0.0,
1.0
]
],
"distortion": [
-0.208338188251856,
0.255488007488945,
-0.000759999321030303,
0.00148438698385668,
-0.00246049749891915
]
},
"60457274": {
"calibration_matrix": [
[
1145.51133842318,
0.0,
514.968197319863
],
[
0.0,
1144.77392807652,
501.882018537695
],
[
0.0,
0.0,
1.0
]
],
"distortion": [
-0.198384093827848,
0.218323676298049,
-0.00181336200488089,
-0.000587205583421232,
-0.00894780704152122
]
}
},
"extrinsics": {
"S1": {
"54138969": {
"R": [
[
-0.9153617321513369,
0.40180836633680234,
0.02574754463350265
],
[
0.051548117060134555,
0.1803735689384521,
-0.9822464900705729
],
[
-0.399319034032262,
-0.8977836111057917,
-0.185819527201491
]
],
"t": [
[
-346.05078140028075
],
[
546.9807793144001
],
[
5474.481087434061
]
]
},
"55011271": {
"R": [
[
0.9281683400814921,
0.3721538354721445,
0.002248380248018696
],
[
0.08166409428175585,
-0.1977722953267526,
-0.976840363061605
],
[
-0.3630902204349604,
0.9068559102440475,
-0.21395758897485287
]
],
"t": [
[
251.42516271750836
],
[
420.9422103702068
],
[
5588.195881837821
]
]
},
"58860488": {
"R": [
[
-0.9141549520542256,
-0.4027780222811878,
-0.045722952682337906
],
[
-0.04562341383935875,
0.21430849526487267,
-0.9756999400261069
],
[
0.40278930937200774,
-0.889854894701693,
-0.214287280609606
]
],
"t": [
[
480.482559565337
],
[
253.83237471361554
],
[
5704.2076793704555
]
]
},
"60457274": {
"R": [
[
0.9141562410494211,
-0.40060705854636447,
0.061905989962380774
],
[
-0.05641000739510571,
-0.2769531972942539,
-0.9592261660183036
],
[
0.40141783470104664,
0.8733904688919611,
-0.2757767409202658
]
],
"t": [
[
51.88347637559197
],
[
378.4208425426766
],
[
4406.149140878431
]
]
}
},
"S2": {
"54138969": {
"R": [
[
-0.9072826056858586,
0.4200536513985309,
0.019829356183203237
],
[
0.06404223092375372,
0.18462275321422528,
-0.9807206695353717
],
[
-0.4156162485733534,
-0.8885208882982778,
-0.1944061855483302
]
],
"t": [
[
-253.9473271477662
],
[
543.369692173605
],
[
5522.981999493327
]
]
},
"55011271": {
"R": [
[
0.9195695689704942,
0.3926824530407384,
0.013867187794489123
],
[
0.09616327770610274,
-0.190692439252443,
-0.9769283584955307
],
[
-0.38097825639298405,
0.8996871037676718,
-0.21311659595137136
]
],
"t": [
[
123.3506735789221
],
[
401.02404156275884
],
[
5743.522551411228
]
]
},
"58860488": {
"R": [
[
-0.9231022562305128,
-0.3793547679556717,
-0.06302526930870815
],
[
-0.023520852900409527,
0.21928184512961552,
-0.9753779994829639
],
[
0.3838345920067314,
-0.898891223911909,
-0.2113423136836923
]
],
"t": [
[
498.7689000990772
],
[
278.0695777621727
],
[
5618.721192968872
]
]
},
"60457274": {
"R": [
[
0.9239917699501332,
-0.37272063182115767,
0.08554846392108466
],
[
-0.01857104155727153,
-0.2671779087245581,
-0.9634682566151569
],
[
0.38196115703026423,
0.8886480156419687,
-0.2537919991167828
]
],
"t": [
[
-55.1478742462578
],
[
424.8747833741909
],
[
4452.137526291175
]
]
}
},
"S3": {
"54138969": {
"R": [
[
-0.909926063968229,
0.4142842734534348,
0.020077322541766036
],
[
0.06112258570603725,
0.18181129378483157,
-0.9814319553432596
],
[
-0.41024210855042,
-0.891803338310328,
-0.19075696094942407
]
],
"t": [
[
-144.30406670344493
],
[
546.2767112872957
],
[
5569.530692348755
]
]
},
"55011271": {
"R": [
[
0.9248703521034336,
0.3800681977315835,
0.012767022876799783
],
[
0.093795468138089,
-0.1954524371286302,
-0.9762175756342618
],
[
-0.3685339088290622,
0.9040721817938792,
-0.21641683887726407
]
],
"t": [
[
-38.93379836342622
],
[
375.57502666735104
],
[
5759.402838804998
]
]
},
"58860488": {
"R": [
[
-0.9218827889823751,
-0.38260686272952316,
-0.061189149122614306
],
[
-0.02577019492115,
0.21811470471458455,
-0.9755828374059251
],
[
0.3866109419452632,
-0.897796170731164,
-0.2109360457310579
]
],
"t": [
[
596.8162203909545
],
[
282.123966506171
],
[
5575.726600786697
]
]
},
"60457274": {
"R": [
[
0.9244960445794738,
-0.37161308683612865,
0.08491629554468147
],
[
-0.018795005038688972,
-0.26693214570791374,
-0.963532032354589
],
[
0.3807280017840865,
0.88918555053481,
-0.2537621827176058
]
],
"t": [
[
-158.57266932864025
],
[
433.1881250816
],
[
4413.555688648984
]
]
}
},
"S4": {
"54138969": {
"R": [
[
-0.906169211683753,
0.422346184383899,
0.021933087625945674
],
[
0.06180306305120707,
0.18355044391174938,
-0.9810655512947585
],
[
-0.4183751201899252,
-0.8876558652294037,
-0.19243004892662768
]
],
"t": [
[
-201.25197932223173
],
[
537.4605027947064
],
[
5553.966756732112
]
]
},
"55011271": {
"R": [
[
0.9205073288493492,
0.39058428754662783,
0.010496278213208041
],
[
0.0923916650578188,
-0.1914846595009468,
-0.9771373523735801
],
[
-0.3796446203523497,
0.900431862773358,
-0.21234976510469855
]
],
"t": [
[
63.12322044876507
],
[
396.6138950755392
],
[
5760.7235858284985
]
]
},
"58860488": {
"R": [
[
-0.9244800422436603,
-0.37641653359695837,
-0.060392422769829215
],
[
-0.02551533125481826,
0.2191523935220463,
-0.9753569583924211
],
[
0.3803756292983513,
-0.9001571094250204,
-0.21220640656557746
]
],
"t": [
[
559.4298619884164
],
[
278.041710381495
],
[
5601.2846874450925
]
]
},
"60457274": {
"R": [
[
0.9241606780958346,
-0.3729066880542538,
0.0828712439019392
],
[
-0.021270464031387784,
-0.2668345784720987,
-0.9635075895349796
],
[
0.38141133756265905,
0.8886731174824772,
-0.2545299232755129
]
],
"t": [
[
-98.61477305435534
],
[
432.68486951797627
],
[
4419.390974448715
]
]
}
},
"S5": {
"54138969": {
"R": [
[
-0.9042074184788829,
0.42657831374650107,
0.020973473936051274
],
[
0.06390493744399675,
0.18368565260974637,
-0.9809055713959477
],
[
-0.4222855708380685,
-0.8856017859436166,
-0.1933503902128034
]
],
"t": [
[
-219.3059666108619
],
[
544.4787497640639
],
[
5518.740477016156
]
]
},
"55011271": {
"R": [
[
0.9222116004775194,
0.38649075753002626,
0.012274293810989732
],
[
0.09333184463870337,
-0.19167233853095322,
-0.9770111982052265
],
[
-0.3752531555110883,
0.902156643264318,
-0.21283434941998647
]
],
"t": [
[
103.90282067751986
],
[
395.67169468951965
],
[
5767.97265758172
]
]
},
"58860488": {
"R": [
[
-0.9258288614330635,
-0.3728674116124112,
-0.06173178026768599
],
[
-0.023578112500148365,
0.220000562347259,
-0.9752147584905696
],
[
0.3772068291381898,
-0.9014264506460582,
-0.21247437993123308
]
],
"t": [
[
520.3272318446208
],
[
283.3690958234795
],
[
5591.123958858676
]
]
},
"60457274": {
"R": [
[
0.9222815489764817,
-0.3772688722588351,
0.0840532119677073
],
[
-0.021177649402562934,
-0.26645871124348197,
-0.9636136478735888
],
[
0.3859381447632816,
0.88694303832152,
-0.25373962085111357
]
],
"t": [
[
-79.116431351199
],
[
425.59047114848386
],
[
4454.481629705836
]
]
}
},
"S6": {
"54138969": {
"R": [
[
-0.9149503344107554,
0.4034864343564006,
0.008036345687245266
],
[
0.07174776353922047,
0.1822275975157708,
-0.9806351824867137
],
[
-0.3971374371533952,
-0.896655898321083,
-0.19567845056940925
]
],
"t": [
[
-239.5182864132218
],
[
545.8141831785044
],
[
5523.931578633363
]
]
},
"55011271": {
"R": [
[
0.9197364689900042,
0.39209901596964664,
0.018525368698999664
],
[
0.101478073351267,
-0.19191459963948,
-0.9761511087296542
],
[
-0.37919260045353465,
0.899681692667386,
-0.21630030892357308
]
],
"t": [
[
169.02510061389722
],
[
409.6671223380997
],
[
5714.338002825065
]
]
},
"58860488": {
"R": [
[
-0.916577698818659,
-0.39393483656788014,
-0.06856140726771254
],
[
-0.01984531630322392,
0.21607069980297702,
-0.9761760169700323
],
[
0.3993638509543854,
-0.8933805444629346,
-0.20586334624209834
]
],
"t": [
[
521.9864793089763
],
[
286.28272817103516
],
[
5643.2724406159
]
]
},
"60457274": {
"R": [
[
0.9182950552949388,
-0.3850769011116475,
0.09192372735651859
],
[
-0.015534985886560007,
-0.26706146429979655,
-0.9635542737695438
],
[
0.3955917790277871,
0.8833990913037544,
-0.25122338635033875
]
],
"t": [
[
-56.29675276801464
],
[
420.29579722027506
],
[
4499.322693551688
]
]
}
},
"S7": {
"54138969": {
"R": [
[
-0.9055764231419416,
0.42392653746206904,
0.014752378956221508
],
[
0.06862812683752326,
0.18074371881263407,
-0.9811329615890764
],
[
-0.41859469903024304,
-0.8874784498483331,
-0.19277053457045695
]
],
"t": [
[
-323.9118424584857
],
[
541.7715234126381
],
[
5506.569132699328
]
]
},
"55011271": {
"R": [
[
0.9212640765077017,
0.3886011826562522,
0.01617473877914905
],
[
0.09922277503271489,
-0.1946115441987536,
-0.9758489574618522
],
[
-0.3760682680727248,
0.9006194910741931,
-0.21784671226815075
]
],
"t": [
[
178.6238708832376
],
[
403.59193467821774
],
[
5694.8801003668095
]
]
},
"58860488": {
"R": [
[
-0.9245069728829368,
-0.37555597339631824,
-0.06515034871105972
],
[
-0.018955014220249332,
0.21601110989507338,
-0.9762068980691586
],
[
0.38069353097569036,
-0.9012751584550871,
-0.20682244613440448
]
],
"t": [
[
441.1064712697594
],
[
271.91614362573955
],
[
5660.120611352617
]
]
},
"60457274": {
"R": [
[
0.9228353966173104,
-0.3744001545228767,
0.09055029013436408
],
[
-0.014982084363704698,
-0.269786590656035,
-0.9628035794752281
],
[
0.3849030629889691,
0.8871525910436372,
-0.25457791009093983
]
],
"t": [
[
25.768533743836343
],
[
431.05581759025813
],
[
4461.872981411145
]
]
}
},
"S8": {
"54138969": {
"R": [
[
-0.9115694669712032,
0.4106494283805017,
0.020202818036194434
],
[
0.060907749548984036,
0.1834736632003901,
-0.9811359034082424
],
[
-0.40660958293025334,
-0.8931430243150293,
-0.19226072190306673
]
],
"t": [
[
-82.70216069652597
],
[
552.1896311377282
],
[
5557.353609418419
]
]
},
"55011271": {
"R": [
[
0.931016282525616,
0.3647626932499711,
0.01252434769597448
],
[
0.08939715221301257,
-0.19463753190599434,
-0.9767929055586687
],
[
-0.35385990285476776,
0.9105297407479727,
-0.2138194574051759
]
],
"t": [
[
-209.06289992510443
],
[
375.0691429434037
],
[
5818.276676972416
]
]
},
"58860488": {
"R": [
[
-0.9209075762929309,
-0.3847355178017309,
-0.0625125368875214
],
[
-0.02568138180824641,
0.21992027027623712,
-0.9751797482259595
],
[
0.38893405939143305,
-0.8964450100611084,
-0.21240678280563546
]
],
"t": [
[
623.0985110132146
],
[
290.9053651845054
],
[
5534.379001592981
]
]
},
"60457274": {
"R": [
[
0.927667052235436,
-0.3636062759574404,
0.08499597802942535
],
[
-0.01666268768012713,
-0.26770413351564454,
-0.9633570738505596
],
[
0.37303645269074087,
0.8922583555131325,
-0.2543989622245125
]
],
"t": [
[
-178.36705625795474
],
[
423.4669232560848
],
[
4421.6448791590965
]
]
}
},
"S9": {
"54138969": {
"R": [
[
-0.9033486204435297,
0.4269119782787646,
0.04132109321984796
],
[
0.04153061098352977,
0.182951140059007,
-0.9822444139329296
],
[
-0.4268916470184284,
-0.8855930460167476,
-0.18299857527497945
]
],
"t": [
[
-321.2078335720134
],
[
467.13452033013084
],
[
5514.330338522134
]
]
},
"55011271": {
"R": [
[
0.9315720471487059,
0.36348288012373176,
-0.007329176497134756
],
[
0.06810069482701912,
-0.19426747906725159,
-0.9785818524481906
],
[
-0.35712157080642226,
0.911120377575769,
-0.20572758986325015
]
],
"t": [
[
19.193095609487138
],
[
404.22842728571936
],
[
5702.169280033924
]
]
},
"58860488": {
"R": [
[
-0.9269344193869241,
-0.3732303525241731,
-0.03862235247246717
],
[
-0.04725991098820678,
0.218240494552814,
-0.9747500127472326
],
[
0.37223525218497616,
-0.901704048173249,
-0.21993345934341726
]
],
"t": [
[
455.40107288876885
],
[
273.3589338272866
],
[
5657.814488280711
]
]
},
"60457274": {
"R": [
[
0.915460708083783,
-0.39734606500700814,
0.06362229623477154
],
[
-0.04940628468469528,
-0.26789167566119776,
-0.9621814117644814
],
[
0.39936288133525055,
0.8776959352388969,
-0.26487569589663096
]
],
"t": [
[
-69.271255294384
],
[
422.1843366088847
],
[
4457.893374979773
]
]
}
},
"S10": {
"54138969": {
"R": [
[
-0.9199955359932982,
0.39133749168985454,
0.021521648410310328
],
[
0.0555185840851712,
0.18448351869097226,
-0.9812662829999691
],
[
-0.3879766752957989,
-0.9015657485337887,
-0.19145051709809383
]
],
"t": [
[
-181.4625993368258
],
[
543.5199110634021
],
[
5582.194377534298
]
]
},
"55011271": {
"R": [
[
0.9152587269115653,
0.40266346194010966,
0.01279059148853104
],
[
0.09843295287698457,
-0.1927270179143742,
-0.9763028476624197
],
[
-0.3906563919867918,
0.8948287171209015,
-0.21603043863220686
]
],
"t": [
[
-22.5707386911355
],
[
383.7773845053516
],
[
5727.149101385447
]
]
},
"58860488": {
"R": [
[
-0.9117691356172892,
-0.4060874893594546,
-0.06140027948988781
],
[
-0.03165845257462336,
0.21854812554171174,
-0.9753124931029975
],
[
0.4094811176553588,
-0.887315990956964,
-0.21212153703897946
]
],
"t": [
[
579.9870562891809
],
[
276.09388439709664
],
[
5616.656671116378
]
]
},
"60457274": {
"R": [
[
0.9374925472123639,
-0.3377263929586908,
0.08395598501825681
],
[
-0.009787543064644189,
-0.2667415901136707,
-0.9637183863060765
],
[
0.3478676873784507,
0.9026570819545717,
-0.25337376437829157
]
],
"t": [
[
-72.59483557976097
],
[
445.63607020105314
],
[
4402.73689876101
]
]
}
},
"S11": {
"54138969": {
"R": [
[
-0.9059013006181885,
0.4217144115102914,
0.038727105014486805
],
[
0.044493184429779696,
0.1857199061874203,
-0.9815948619389944
],
[
-0.4211450938543295,
-0.8875049698848251,
-0.1870073216538954
]
],
"t": [
[
-234.7208032216618
],
[
464.34018262882194
],
[
5536.652631113797
]
]
},
"55011271": {
"R": [
[
0.9216646531492915,
0.3879848687925067,
-0.0014172943441045224
],
[
0.07721054863099915,
-0.18699239961454955,
-0.979322405373477
],
[
-0.3802272982247548,
0.9024974149959955,
-0.20230080971229314
]
],
"t": [
[
-11.934348472090557
],
[
449.4165893644565
],
[
5541.113551868937
]
]
},
"58860488": {
"R": [
[
-0.9063540572469627,
-0.42053101768163204,
-0.04093880896680188
],
[
-0.0603212197838846,
0.22468715090881142,
-0.9725620980997899
],
[
0.4181909532208387,
-0.8790161246439863,
-0.2290130547809762
]
],
"t": [
[
781.127357651581
],
[
235.3131620173424
],
[
5576.37044019807
]
]
},
"60457274": {
"R": [
[
0.91754082476548,
-0.39226322025776267,
0.06517975852741943
],
[
-0.04531905395586976,
-0.26600517028098103,
-0.9629057236990188
],
[
0.395050652748768,
0.8805514269006645,
-0.2618476013752581
]
],
"t": [
[
-155.13650339749012
],
[
422.16256306729633
],
[
4435.416222660868
]
]
}
}
}
}
================================================
FILE: eval/GVHMR/hmr4d/dataset/h36m/h36m.py
================================================
import torch
import numpy as np
from pathlib import Path
from hmr4d.configs import MainStore, builds
from hmr4d.utils.pylogger import Log
from hmr4d.dataset.imgfeat_motion.base_dataset import ImgfeatMotionDatasetBase
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle
from hmr4d.utils import matrix
from hmr4d.utils.smplx_utils import make_smplx
from tqdm import tqdm
from hmr4d.utils.geo_transform import compute_cam_angvel, apply_T_on_points
from hmr4d.utils.geo.hmr_global import get_tgtcoord_rootparam, get_T_w2c_from_wcparams, get_c_rootparam, get_R_c2gv
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.utils.vis.renderer import Renderer
import imageio
from hmr4d.utils.video_io_utils import read_video_np
from hmr4d.utils.net_utils import get_valid_mask, repeat_to_max_len, repeat_to_max_len_dict
class H36mSmplDataset(ImgfeatMotionDatasetBase):
    """Human3.6M training dataset pairing SMPL motion with precomputed image features.

    Each item is a window of at most ``motion_frames`` consecutive frames cut
    at a random offset from one sequence; per-frame entries are then extended
    to ``motion_frames`` with ``repeat_to_max_len`` so items can be batched.
    """

    def __init__(
        self,
        root="inputs/H36M/hmr4d_support",
        original_coord="az",
        motion_frames=120,  # H36M's videos are 25fps and very long
        lazy_load=False,
    ):
        """
        Args:
            root: directory holding ``smplxpose_v1.pt`` and ``vitfeat_h36m.pt``.
            original_coord: stored on the instance; not otherwise used in this class.
            motion_frames: window length (in frames) of every returned sample.
            lazy_load: must be False; True raises NotImplementedError in ``_load_dataset``.
        """
        # Path
        self.root = Path(root)

        # Coord
        self.original_coord = original_coord

        # Setting
        self.motion_frames = motion_frames
        self.lazy_load = lazy_load

        # NOTE(review): called last, presumably because the base constructor
        # relies on the attributes set above -- confirm against the base class.
        super().__init__()

    def _load_dataset(self):
        """Load the SMPL motion file and the image-feature file into memory."""
        # smplpose
        tic = Log.time()
        fn = self.root / "smplxpose_v1.pt"
        self.smpl_model = make_smplx("supermotion")
        Log.info(f"[H36M] Loading from {fn} ...")
        self.motion_files = torch.load(fn)
        # Dict of {
        #          "smpl_params_glob": {'body_pose', 'global_orient', 'transl', 'betas'}, FxC
        #          "cam_Rt": tensor(F, 3),
        #          "cam_K": tensor(1, 10),
        #         }
        self.seqs = list(self.motion_files.keys())
        Log.info(f"[H36M] {len(self.seqs)} sequences. Elapsed: {Log.time() - tic:.2f}s")

        # img(as feature)
        # vid -> (features, vid, meta {bbx_xys, K_fullimg})
        if not self.lazy_load:
            tic = Log.time()
            fn = self.root / "vitfeat_h36m.pt"
            Log.info(f"[H36M] Fully Loading to RAM ViT-Feat: {fn}")
            self.f_img_dicts = torch.load(fn)
            Log.info(f"[H36M] Finished. Elapsed: {Log.time() - tic:.2f}s")
        else:
            raise NotImplementedError  # "Check BEDLAM-SMPL for lazy_load"

    def _get_idx2meta(self):
        """Build ``self.idx2meta``: each video id repeated once per window it can supply."""
        # We expect to see the entire sequence during one epoch,
        # so each sequence will be sampled max(SeqLength // MotionFrames, 1) times
        seq_lengths = []
        self.idx2meta = []
        for vid in self.f_img_dicts:
            seq_length = self.f_img_dicts[vid]["bbx_xys"].shape[0]
            num_samples = max(seq_length // self.motion_frames, 1)
            seq_lengths.append(seq_length)
            self.idx2meta.extend([vid] * num_samples)
        hours = sum(seq_lengths) / 25 / 3600
        Log.info(f"[H36M] has {hours:.1f} hours motion -> Resampled to {len(self.idx2meta)} samples.")

    def _load_data(self, idx):
        """Cut a random window out of the sequence behind ``idx`` and gather its raw data.

        Returns:
            dict with ``vid``, ``length``, ``start_end``, ``smpl_params_global``,
            ``f_imgseq``, ``bbx_xys``, ``K_fullimg``, ``kp2d`` (all zeros) and ``T_w2c``.
        """
        sampled_motion = {}
        vid = self.idx2meta[idx]
        motion = self.motion_files[vid]
        seq_length = self.f_img_dicts[vid]["bbx_xys"].shape[0]  # this is a better choice
        sampled_motion["vid"] = vid

        # Random select a subset
        target_length = self.motion_frames
        if target_length > seq_length:  # this should not happen
            start = 0
            length = seq_length
            Log.info(f"[H36M] ({idx}) target length > sequence length: {target_length} > {seq_length}")
        else:
            # randint's upper bound is exclusive: +1 makes the last valid start
            # reachable and avoids randint(0, 0) (a ValueError) when the
            # sequence is exactly target_length frames long.
            start = np.random.randint(0, seq_length - target_length + 1)
            length = target_length
        end = start + length
        sampled_motion["length"] = length
        sampled_motion["start_end"] = (start, end)

        # Select motion subset
        #   body_pose, global_orient, transl, betas
        sampled_motion["smpl_params_global"] = {k: v[start:end] for k, v in motion["smpl_params_glob"].items()}

        # Image as feature
        f_img_dict = self.f_img_dicts[vid]
        sampled_motion["f_imgseq"] = f_img_dict["features"][start:end].float()  # (L, 1024)
        sampled_motion["bbx_xys"] = f_img_dict["bbx_xys"][start:end]
        sampled_motion["K_fullimg"] = f_img_dict["K_fullimg"]
        # sampled_motion["kp2d"] = self.vitpose[vid][start:end].float()  # (L, 17, 3)
        sampled_motion["kp2d"] = torch.zeros((end - start), 17, 3)  # (L, 17, 3)

        # Camera
        sampled_motion["T_w2c"] = motion["cam_Rt"]  # (4, 4)

        return sampled_motion

    def _process_data(self, data, idx):
        """Derive camera-frame SMPL parameters and per-frame camera terms, then make the sample batchable.

        Args:
            data: dict produced by ``_load_data``.
            idx: dataset index, recorded in ``meta``.

        Returns:
            dict whose per-frame entries are extended to ``self.motion_frames``.
        """
        length = data["length"]

        # SMPL params in world
        smpl_params_w = data["smpl_params_global"].copy()  # in az

        # SMPL params in cam
        T_w2c = data["T_w2c"]  # (4, 4)
        offset = self.smpl_model.get_skeleton(smpl_params_w["betas"][0])[0]  # (3)
        global_orient_c, transl_c = get_c_rootparam(
            smpl_params_w["global_orient"],
            smpl_params_w["transl"],
            T_w2c,
            offset,
        )
        smpl_params_c = {
            "body_pose": smpl_params_w["body_pose"].clone(),  # (F, 63)
            "betas": smpl_params_w["betas"].clone(),  # (F, 10)
            "global_orient": global_orient_c,  # (F, 3)
            "transl": transl_c,  # (F, 3)
        }

        # World params
        gravity_vec = torch.tensor([0, 0, -1]).float()  # (3), H36M is az
        T_w2c = T_w2c.repeat(length, 1, 1)  # (F, 4, 4)
        R_c2gv = get_R_c2gv(T_w2c[..., :3, :3], axis_gravity_in_w=gravity_vec)  # (F, 3, 3)

        # Image
        bbx_xys = data["bbx_xys"]  # (F, 3)
        K_fullimg = data["K_fullimg"].repeat(length, 1, 1)  # (F, 3, 3)
        f_imgseq = data["f_imgseq"]  # (F, 1024)
        cam_angvel = compute_cam_angvel(T_w2c[:, :3, :3])  # (F, 6)  slightly different from WHAM

        # Returns: do not forget to make it batchable! (last lines)
        max_len = self.motion_frames
        return_data = {
            "meta": {"data_name": "h36m", "idx": idx, "vid": data["vid"]},
            "length": length,
            "smpl_params_c": smpl_params_c,
            "smpl_params_w": smpl_params_w,
            "R_c2gv": R_c2gv,  # (F, 3, 3)
            "gravity_vec": gravity_vec,  # (3)
            "bbx_xys": bbx_xys,  # (F, 3)
            "K_fullimg": K_fullimg,  # (F, 3, 3)
            "f_imgseq": f_imgseq,  # (F, D)
            "kp2d": data["kp2d"],  # (F, 17, 3)
            "cam_angvel": cam_angvel,  # (F, 6)
            "mask": {
                "valid": get_valid_mask(max_len, length),
                "vitpose": False,
                "bbx_xys": True,
                "f_imgseq": True,
                "spv_incam_only": False,
            },
        }

        # Disabled debug rendering. NOTE(review): as written it uses names this
        # class does not provide (self.smplx, meta["mid"]) and helpers that are
        # not among this file's visible imports (simple_render_mesh_background,
        # save_video); it needs fixing before it can be enabled.
        if False:  # Render to image to check
            smplx_out = self.smplx(**smpl_params_c)
            # ----- Overlay ----- #
            mid = return_data["meta"]["mid"]
            video_path = self.root / f"videos/{mid}.mp4"
            images = read_video_np(video_path, data["start_end"][0], data["start_end"][1])
            render_dict = {
                "K": K_fullimg[:1],  # only support batch size 1
                "faces": self.smplx.faces,
                "verts": smplx_out.vertices,
                "background": images,
            }
            img_overlay = simple_render_mesh_background(render_dict)
            save_video(img_overlay, f"tmp.mp4")

        # Batchable
        return_data["smpl_params_c"] = repeat_to_max_len_dict(return_data["smpl_params_c"], max_len)
        return_data["smpl_params_w"] = repeat_to_max_len_dict(return_data["smpl_params_w"], max_len)
        return_data["R_c2gv"] = repeat_to_max_len(return_data["R_c2gv"], max_len)
        return_data["bbx_xys"] = repeat_to_max_len(return_data["bbx_xys"], max_len)
        return_data["K_fullimg"] = repeat_to_max_len(return_data["K_fullimg"], max_len)
        return_data["f_imgseq"] = repeat_to_max_len(return_data["f_imgseq"], max_len)
        return_data["kp2d"] = repeat_to_max_len(return_data["kp2d"], max_len)
        return_data["cam_angvel"] = repeat_to_max_len(return_data["cam_angvel"], max_len)
        return return_data
# Register the H36M dataset config, built with the constructor defaults,
# as "v1" in the "train_datasets/imgfeat_h36m" group.
group_name = "train_datasets/imgfeat_h36m"
node_v1 = builds(H36mSmplDataset)
MainStore.store(name="v1", node=node_v1, group=group_name)
================================================
FILE: eval/GVHMR/hmr4d/dataset/h36m/utils.py
================================================
import json
import numpy as np
from pathlib import Path
from collections import defaultdict
import pickle
import torch
# Folder named "resource" located next to this module; get_cam_KRts reads camera-parameters.json from it.
RESOURCE_FOLDER = Path(__file__).resolve().parent / "resource"
# Camera index (0..3) -> camera id string.
camera_idx_to_name = {0: "54138969", 1: "55011271", 2: "58860488", 3: "60457274"}
def get_vid(pkl_path, cam_id):
    """Build a sequence id "<subject>@<sequence>@<camera>" from a raw pkl path.

    Example: ".../S6/Posing 1.pkl", 54138969 -> "S6@Posing_1@54138969"

    Args:
        pkl_path: str or path-like object. Only the last two components are used:
            the parent directory name (subject) and the file name (sequence).
        cam_id: camera identifier, formatted into the id as-is.

    Returns:
        str: the id. The sequence part is the file name up to its first ".",
        with spaces replaced by underscores.
    """
    # pathlib handles both str and Path inputs (and the platform's separator),
    # instead of assuming a "/"-separated string.
    pkl_path = Path(pkl_path)
    sub_id = pkl_path.parent.name
    seq_name = pkl_path.name.split(".")[0].replace(" ", "_")
    return f"{sub_id}@{seq_name}@{cam_id}"
def get_raw_pkl_paths(h36m_raw_root):
    """Collect the world-sequence SMPL pkl files of the training subjects.

    Args:
        h36m_raw_root: str or Path of the raw dataset root; files are searched in
            "<root>/neutrSMPL_H3.6/<subject>/*.pkl" for S1, S5, S6, S7 and S8.

    Returns:
        list[str]: paths of every pkl whose *file name* does not contain "aligned",
        grouped by subject and sorted by path within each subject.
    """
    smpl_param_dir = Path(h36m_raw_root) / "neutrSMPL_H3.6"
    pkl_paths = []
    for train_sub in ["S1", "S5", "S6", "S7", "S8"]:
        # sorted(): glob order depends on the filesystem; sort for reproducible output.
        for pth in sorted((smpl_param_dir / train_sub).glob("*.pkl")):
            # Use world sequence only. Test the file name, not the whole path, so a
            # parent directory that happens to contain "aligned" does not drop everything.
            if "aligned" not in pth.name:
                pkl_paths.append(str(pth))
    return pkl_paths
def get_cam_KRts():
    """Load camera intrinsics and extrinsics from the bundled camera-parameters.json.

    Returns:
        Ks (dict): {cam_id: 3x3 float tensor}, the "calibration_matrix" of each camera
        Rts (defaultdict): {subj_id: {cam_id: 4x4 float tensor}}, built from the "R" and
            "t" entries of each camera, with "t" divided by 1000
    """
    # this file is copied from https://github.com/karfly/human36m-camera-parameters
    with open(RESOURCE_FOLDER / "camera-parameters.json", "r") as fp:
        cam_params = json.load(fp)

    # 4 camera ids: '54138969', '55011271', '58860488', '60457274'
    Ks = {
        cam_id: torch.tensor(entry["calibration_matrix"]).float()
        for cam_id, entry in cam_params["intrinsics"].items()
    }

    # extrinsics: assemble a homogeneous [R | t] matrix per (subject, camera)
    Rts = defaultdict(dict)
    for subj_id, per_cam in cam_params["extrinsics"].items():
        for cam_id, ext in per_cam.items():
            pose = torch.eye(4)
            pose[:3, :3] = torch.tensor(ext["R"])
            # presumably converts millimetres to metres — TODO confirm the unit of "t"
            pose[:3, [3]] = torch.tensor(ext["t"]) / 1000
            Rts[subj_id][cam_id] = pose.float()
    return Ks, Rts
def parse_raw_pkl(pkl_path, to_50hz=True):
    """
    Read one raw SMPL pkl and return its parameters as float tensors.

    raw_pkl @ 200Hz, where video @ 50Hz.
    With to_50hz=True every 4th frame is kept; the frames still have to be
    manually aligned with the video.

    NOTE: the file is read with pickle, so only load files from a trusted source.

    Returns:
        dict with "body_pose" (poses[:, 3:]), "betas" (expanded to one row per frame),
        "global_orient" (poses[:, :3]) and "transl".
    """
    with open(str(pkl_path), "rb") as fh:
        raw = pickle.load(fh, encoding="bytes")

    poses, betas, trans = (torch.from_numpy(raw[key]).float() for key in (b"poses", b"betas", b"trans"))
    assert poses.shape[0] == trans.shape[0]

    if to_50hz:
        poses, trans = poses[::4], trans[::4]

    num_frames = poses.shape[0]  # 50FPS
    return {
        "body_pose": poses[:, 3:],
        "betas": betas[None].expand(num_frames, -1),
        "global_orient": poses[:, :3],
        "transl": trans,
    }
================================================
FILE: eval/GVHMR/hmr4d/dataset/imgfeat_motion/base_dataset.py
================================================
import torch
from torch.utils import data
import numpy as np
from pathlib import Path
from hmr4d.utils.pylogger import Log
class ImgfeatMotionDatasetBase(data.Dataset):
    """Template for datasets that load a sample and then post-process it.

    Subclasses implement the four hooks below. The constructor calls
    `_load_dataset()` and then `_get_idx2meta()`, which must set `self.idx2meta`
    (its length is the dataset length). `__getitem__` chains `_load_data` and
    `_process_data`.
    """

    def __init__(self):
        super().__init__()
        self._load_dataset()
        self._get_idx2meta()  # -> Set self.idx2meta

    def __len__(self):
        return len(self.idx2meta)

    # NOTE: the hooks raise NotImplementedError (an exception class). The constant
    # `NotImplemented` is not an exception and `raise NotImplemented` fails with TypeError.
    def _load_dataset(self):
        """Hook: load the underlying data. Must be overridden."""
        raise NotImplementedError

    def _get_idx2meta(self):
        """Hook: set `self.idx2meta`. Must be overridden."""
        raise NotImplementedError

    def _load_data(self, idx):
        """Hook: return the raw sample for index `idx`. Must be overridden."""
        raise NotImplementedError

    def _process_data(self, data, idx):
        """Hook: turn the raw sample into the returned item. Must be overridden."""
        raise NotImplementedError

    def __getitem__(self, idx):
        data = self._load_data(idx)
        data = self._process_data(data, idx)
        return data
================================================
FILE: eval/GVHMR/hmr4d/dataset/pure_motion/amass.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from pathlib import Path
from hmr4d.utils.pylogger import Log
from hmr4d.configs import MainStore, builds
from .base_dataset import BaseDataset
from .utils import *
from hmr4d.utils.geo.hmr_global import get_tgtcoord_rootparam
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines, convert_motion_as_line_mesh
class AmassDataset(BaseDataset):
    """AMASS motion clips resampled to a fixed number of frames with speed augmentation.

    Sequences are read from "inputs/AMASS/hmr4d_support/smplxpose_v2.pth". Each item is a
    random sub-clip, interpolated to `motion_frames` frames and converted with the
    "az->ay" transform; the base class then turns it into a training sample.
    """

    def __init__(
        self,
        motion_frames=120,
        l_factor=1.5,  # speed augmentation
        skip_moyo=True,  # not contained in the ICCV19 released version
        cam_augmentation="v11",
        random1024=False,  # DEBUG
        limit_size=None,
    ):
        # Attributes must be set before super().__init__, which calls
        # _load_dataset() and _get_idx2meta().
        self.root = Path("inputs/AMASS/hmr4d_support")
        self.motion_frames = motion_frames
        self.l_factor = l_factor
        self.random1024 = random1024
        self.skip_moyo = skip_moyo
        self.dataset_name = "AMASS"
        super().__init__(cam_augmentation, limit_size)

    def _load_dataset(self):
        """Load the motion dictionary into `self.motion_files` and list its keys in `self.seqs`.

        With `random1024`, a cached random subset is loaded instead; if that fails the
        full file is loaded, a subset is sampled and written to the cache.
        """
        filename = self.root / "smplxpose_v2.pth"
        Log.info(f"[{self.dataset_name}] Loading from {filename} ...")
        tic = Log.time()
        if self.random1024:  # Debug, faster loading
            debug_filename = self.root / "smplxpose_v2_random1024.pth"
            try:
                Log.info(f"[{self.dataset_name}] Loading 1024 samples for debugging ...")
                self.motion_files = torch.load(debug_filename)
            except Exception:
                # Best-effort fallback on any load error (missing or unreadable cache),
                # without swallowing KeyboardInterrupt / SystemExit like a bare except.
                Log.info(f"[{self.dataset_name}] Not found! Saving 1024 samples for debugging ...")
                self.motion_files = torch.load(filename)
                keys = list(self.motion_files.keys())
                # Cap the sample size: choice(..., replace=False) fails if it exceeds len(keys).
                keys = np.random.choice(keys, min(1024, len(keys)), replace=False)
                # str(): np.random.choice returns NumPy string scalars; store plain str keys.
                self.motion_files = {str(k): self.motion_files[k] for k in keys}
                torch.save(self.motion_files, debug_filename)
        else:
            self.motion_files = torch.load(filename)
        self.seqs = list(self.motion_files.keys())
        Log.info(f"[{self.dataset_name}] {len(self.seqs)} sequences. Elapsed: {Log.time() - tic:.2f}s")

    def _get_idx2meta(self):
        """Build `self.idx2meta`, a list of (sequence id, start frame) entries."""
        # We expect to see the entire sequence during one epoch,
        # so each sequence will be sampled max(SeqLength // MotionFrames, 1) times
        seq_lengths = []
        self.idx2meta = []

        # Skip too-long idle-prefix (currently empty, so every start_id is 0)
        motion_start_id = {}
        for vid in self.motion_files:
            if self.skip_moyo and "moyo_smplxn" in vid:
                continue
            seq_length = self.motion_files[vid]["pose"].shape[0]
            start_id = motion_start_id[vid] if vid in motion_start_id else 0
            seq_length = seq_length - start_id
            if seq_length < 25:  # Skip clips that are too short
                continue
            num_samples = max(seq_length // self.motion_frames, 1)
            seq_lengths.append(seq_length)
            self.idx2meta.extend([(vid, start_id)] * num_samples)
        # presumably assumes 30 frames per second — TODO confirm
        hours = sum(seq_lengths) / 30 / 3600
        Log.info(f"[{self.dataset_name}] has {hours:.1f} hours motion -> Resampled to {len(self.idx2meta)} samples.")

    def _load_data(self, idx):
        """
        - Load original data
        - Augmentation: speed-augmentation to L frames

        Returns the interpolated SMPL parameter dict (body_pose, betas, global_orient,
        transl) with an extra "data_name" entry set to "amass".
        """
        # Load original data
        mid, start_id = self.idx2meta[idx]
        raw_data = self.motion_files[mid]
        raw_len = raw_data["pose"].shape[0] - start_id
        data = {
            "body_pose": raw_data["pose"][start_id:, 3:],  # (F, 63)
            "betas": raw_data["beta"].repeat(raw_len, 1),  # beta repeated once per frame
            "global_orient": raw_data["pose"][start_id:, :3],  # (F, 3)
            "transl": raw_data["trans"][start_id:],  # (F, 3)
        }

        # Get {tgt_len} frames from data
        # Random select a subset with speed augmentation [start, end)
        tgt_len = self.motion_frames
        raw_subset_len = np.random.randint(int(tgt_len / self.l_factor), int(tgt_len * self.l_factor))
        if raw_subset_len <= raw_len:
            start = np.random.randint(0, raw_len - raw_subset_len + 1)
            end = start + raw_subset_len
        else:  # interpolation will use all possible frames (results in a slow motion)
            start = 0
            end = raw_len
        data = {k: v[start:end] for k, v in data.items()}

        # Interpolation (vec + r6d)
        data_interpolated = interpolate_smpl_params(data, tgt_len)

        # AZ -> AY
        data_interpolated["global_orient"], data_interpolated["transl"], _ = get_tgtcoord_rootparam(
            data_interpolated["global_orient"],
            data_interpolated["transl"],
            tsf="az->ay",
        )

        data_interpolated["data_name"] = "amass"
        return data_interpolated
# Register AmassDataset (built with cam_augmentation="v11") in MainStore as entry "v11"
# of the group named below.
group_name = "train_datasets/pure_motion_amass"
MainStore.store(name="v11", node=builds(AmassDataset, cam_augmentation="v11"), group=group_name)
================================================
FILE: eval/GVHMR/hmr4d/dataset/pure_motion/base_dataset.py
================================================
import torch
from torch.utils.data import Dataset
from pathlib import Path
from .utils import *
from .cam_traj_utils import CameraAugmentorV11
from hmr4d.utils.geo.hmr_cam import create_camera_sensor
from hmr4d.utils.geo.hmr_global import get_c_rootparam, get_R_c2gv
from hmr4d.utils.net_utils import get_valid_mask, repeat_to_max_len, repeat_to_max_len_dict
from hmr4d.utils.geo_transform import compute_cam_angvel, apply_T_on_points, project_p2d, cvt_p2d_from_i_to_c
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines, convert_motion_as_line_mesh
from hmr4d.utils.smplx_utils import make_smplx
class BaseDataset(Dataset):
    """Base class for motion-only datasets paired with a synthetic camera trajectory.

    Subclasses implement `_load_dataset`, `_get_idx2meta` (must set `self.idx2meta`)
    and `_load_data`. `_process_data` augments the loaded SMPL parameters, generates a
    camera trajectory and assembles the sample dictionary returned by `__getitem__`.
    """

    def __init__(self, cam_augmentation, limit_size=None):
        super().__init__()
        self.cam_augmentation = cam_augmentation
        self.limit_size = limit_size
        self.smplx = make_smplx("supermotion")
        self.smplx_lite = make_smplx("supermotion_smpl24")

        self._load_dataset()
        self._get_idx2meta()

    # NOTE: the three hooks below actually *raise*; constructing the exception without
    # `raise` would turn a missing override into a silent no-op.
    def _load_dataset(self):
        """Hook: load the underlying data. Must be overridden."""
        raise NotImplementedError("_load_dataset is not implemented")

    def _get_idx2meta(self):
        """Hook: set `self.idx2meta`. Must be overridden."""
        self.idx2meta = None
        raise NotImplementedError("_get_idx2meta is not implemented")

    def __len__(self):
        """Number of samples, capped by `limit_size` when it is set."""
        if self.limit_size is not None:
            return min(self.limit_size, len(self.idx2meta))
        return len(self.idx2meta)

    def _load_data(self, idx):
        """Hook: return the raw sample for index `idx`. Must be overridden."""
        raise NotImplementedError("_load_data is not implemented")

    def _process_data(self, data, idx):
        """
        Args:
            data: dict {
                "body_pose": (F, 63),
                "betas": (F, 10),
                "global_orient": (F, 3),  in the AY coordinates
                "transl": (F, 3),  in the AY coordinates
                "data_name": str,
            }
            idx: sample index, stored in the returned "meta"

        Returns:
            dict with world/camera SMPL params, camera quantities, placeholder
            image inputs ("bbx_xys", "f_imgseq", "kp2d" are zeros) and masks.

        Raises:
            NotImplementedError: if `self.cam_augmentation` is not "v11".
        """
        data_name = data["data_name"]
        length = data["body_pose"].shape[0]
        # Augmentation: betas, SMPL (gravity-axis)
        body_pose = data["body_pose"]
        betas = augment_betas(data["betas"], std=0.1)
        global_orient_w, transl_w = rotate_around_axis(data["global_orient"], data["transl"], axis="y")
        del data

        # SMPL_params in world
        smpl_params_w = {
            "body_pose": body_pose,  # (F, 63)
            "betas": betas,  # (F, 10)
            "global_orient": global_orient_w,  # (F, 3)
            "transl": transl_w,  # (F, 3)
        }

        # Camera trajectory augmentation
        if self.cam_augmentation == "v11":
            # interleave repeat to original length (faster)
            N = 10
            w_j3d = self.smplx_lite(
                smpl_params_w["body_pose"][::N],
                smpl_params_w["betas"][::N],
                smpl_params_w["global_orient"][::N],
                None,
            )
            # [:length]: when F is not a multiple of N the repeat yields ceil(F/N)*N rows;
            # trim so the sum with transl (F rows) is well-defined. No-op when F % N == 0.
            w_j3d = w_j3d.repeat_interleave(N, dim=0)[:length] + smpl_params_w["transl"][:, None]  # (F, 24, 3)

            if False:
                wis3d = make_wis3d(name="debug_amass")
                add_motion_as_lines(w_j3d, wis3d, "w_j3d")

            width, height, K_fullimg = create_camera_sensor(1000, 1000, 43.3)  # WHAM
            wham_cam_augmentor = CameraAugmentorV11()
            T_w2c = wham_cam_augmentor(w_j3d, length)  # (F, 4, 4)

        else:
            raise NotImplementedError

        if False:  # render (disabled debug code, kept as-is)
            for idx_render in range(10):
                T_w2c = wham_cam_augmentor(smpl_params_w["transl"])

                # targets
                w_j3d = self.smplx(**smpl_params_w).joints[:, :22]
                c_j3d = apply_T_on_points(w_j3d, T_w2c)
                verts, faces, vertex_colors = convert_motion_as_line_mesh(c_j3d)
                vertex_colors = vertex_colors[None] / 255.0
                bg = np.ones((height, width, 3), dtype=np.uint8) * 255

                # render
                renderer = Renderer(width, height, device="cuda", faces=faces, K=K_fullimg)
                vname = f"{idx_render:02d}"
                out_fn = Path(f"outputs/dump_render_wham_cam/{vname}.mp4")
                out_fn.parent.mkdir(exist_ok=True, parents=True)
                writer = imageio.get_writer(out_fn, fps=30, mode="I", format="FFMPEG", macro_block_size=1)
                for i in tqdm(range(len(verts)), desc=f"Rendering {vname}"):
                    # incam
                    # img_overlay_pred = renderer.render_mesh(verts[i].cuda(), bg, [0.8, 0.8, 0.8], VI=1)
                    img_overlay_pred = renderer.render_mesh(verts[i].cuda(), bg, vertex_colors, VI=1)
                    # if batch["meta_render"][0].get("bbx_xys", None) is not None:  # draw bbox lines
                    #     bbx_xys = batch["meta_render"][0]["bbx_xys"][i].cpu().numpy()
                    #     lu_point = (bbx_xys[:2] - bbx_xys[2:] / 2).astype(int)
                    #     rd_point = (bbx_xys[:2] + bbx_xys[2:] / 2).astype(int)
                    #     img_overlay_pred = cv2.rectangle(img_overlay_pred, lu_point, rd_point, (255, 178, 102), 2)

                    # write
                    writer.append_data(img_overlay_pred)
                writer.close()
                pass

        # SMPL params in cam
        offset = self.smplx.get_skeleton(smpl_params_w["betas"][0])[0]  # (3)
        global_orient_c, transl_c = get_c_rootparam(
            smpl_params_w["global_orient"],
            smpl_params_w["transl"],
            T_w2c,
            offset,
        )
        smpl_params_c = {
            "body_pose": smpl_params_w["body_pose"].clone(),  # (F, 63)
            "betas": smpl_params_w["betas"].clone(),  # (F, 10)
            "global_orient": global_orient_c,  # (F, 3)
            "transl": transl_c,  # (F, 3)
        }

        # World params
        gravity_vec = torch.tensor([0, -1, 0], dtype=torch.float32)  # (3), BEDLAM is ay
        R_c2gv = get_R_c2gv(T_w2c[:, :3, :3], gravity_vec)  # (F, 3, 3)

        # Image
        K_fullimg = K_fullimg.repeat(length, 1, 1)  # (F, 3, 3)
        cam_angvel = compute_cam_angvel(T_w2c[:, :3, :3])  # (F, 6)

        # Returns: do not forget to make it batchable! (last lines)
        # NOTE: bbx_xys and f_imgseq will be added later
        max_len = length
        return_data = {
            "meta": {"data_name": data_name, "idx": idx, "T_w2c": T_w2c},
            "length": length,
            "smpl_params_c": smpl_params_c,
            "smpl_params_w": smpl_params_w,
            "R_c2gv": R_c2gv,  # (F, 3, 3)
            "gravity_vec": gravity_vec,  # (3)
            "bbx_xys": torch.zeros((length, 3)),  # (F, 3)  # NOTE: a placeholder
            "K_fullimg": K_fullimg,  # (F, 3, 3)
            "f_imgseq": torch.zeros((length, 1024)),  # (F, D)  # NOTE: a placeholder
            "kp2d": torch.zeros(length, 17, 3),  # (F, 17, 3)
            "cam_angvel": cam_angvel,  # (F, 6)
            "mask": {
                "valid": get_valid_mask(length, length),
                "vitpose": False,
                "bbx_xys": False,
                "f_imgseq": False,
                "spv_incam_only": False,
            },
        }

        # Batchable
        return_data["smpl_params_c"] = repeat_to_max_len_dict(return_data["smpl_params_c"], max_len)
        return_data["smpl_params_w"] = repeat_to_max_len_dict(return_data["smpl_params_w"], max_len)
        return_data["R_c2gv"] = repeat_to_max_len(return_data["R_c2gv"], max_len)
        return_data["K_fullimg"] = repeat_to_max_len(return_data["K_fullimg"], max_len)
        return_data["cam_angvel"] = repeat_to_max_len(return_data["cam_angvel"], max_len)
        return return_data

    def __getitem__(self, idx):
        data = self._load_data(idx)
        data = self._process_data(data, idx)
        return data
================================================
FILE: eval/GVHMR/hmr4d/dataset/pure_motion/cam_traj_utils.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
from numpy.random import rand, randn
from pytorch3d.transforms import (
axis_angle_to_matrix,
matrix_to_axis_angle,
matrix_to_rotation_6d,
rotation_6d_to_matrix,
)
from einops import rearrange
from hmr4d.utils.geo.hmr_cam import create_camera_sensor
from hmr4d.utils.geo_transform import transform_mat, apply_T_on_points
from hmr4d.utils.geo.transforms import axis_rotate_to_matrix
import hmr4d.utils.matrix as matrix
# pi / 2; create_camera uses it as the clip bound for the sampled pitch and roll.
halfpi = np.pi / 2
# diag(-1, -1, 1): negates x and y, keeps z. create_camera applies it first in its
# rotation chain to use OpenCV's camera system.
R_y_upsidedown = torch.tensor([[-1, 0, 0], [0, -1, 0], [0, 0, 1]]).float()
def noisy_interpolation(x, length, step_noise_perc=0.2):
    """Interpolate between two keyframes along a jittered (non-uniform) parameter.

    The interpolation parameter is linspace(0, 1, length); every interior sample is
    shifted by uniform noise in +-(step_noise_perc * step), where step is the linspace
    spacing. The endpoints are left untouched, so the output starts at x[0] and ends at x[1].

    Args:
        x: (2, C) array, start and end keyframe
        length: scalar, number of output samples
        step_noise_perc: noise amplitude as a fraction of one linspace step
    Returns:
        (length, C) array
    """
    assert x.shape[0] == 2 and len(x.shape) == 2
    n_ch = x.shape[-1]

    # Reference parameter, one independent row per channel: (C, L)
    t = np.tile(np.linspace(0, 1, length), (n_ch, 1))
    bound = (t[0, 1] - t[0, 0]) * step_noise_perc
    jitter = np.random.uniform(-bound, bound, (n_ch, length - 2))  # (C, L-2)
    t[:, 1:-1] += jitter

    # Per-channel 1d interpolation between the two keyframes
    knots = np.array([0.0, 1.0])
    result = np.zeros((length, n_ch))
    for ch in range(n_ch):
        result[:, ch] = np.interp(t[ch], knots, x[:, ch])
    return result
def noisy_impluse_interpolation(data1, data2, step_noise_perc=0.2):
    """Blend data1 towards data2 and back with a jittered triangular weight.

    The weight ramps 0 -> 1 over the first L // 2 samples and 1 -> 0 over the remaining
    ones; interior weights get uniform noise of +-(step_noise_perc * step). The result
    is data1 * (1 - w) + data2 * w, so the first and last rows equal data1.

    Args:
        data1: (L, C) array, the original window
        data2: (L, C) array, the impulse target
        step_noise_perc: noise amplitude as a fraction of the first ramp step
    Returns:
        (L, C) array
    """
    dim = data1.shape[-1]
    L = data1.shape[0]

    # The falling ramp takes the remaining L - L // 2 samples, so the two ramps always
    # add up to L (for odd L, two ramps of L // 2 would be one sample short).
    n_rise = L // 2
    n_fall = L - n_rise
    linspace1 = np.stack([np.linspace(0, 1, n_rise) for _ in range(dim)])
    linspace2 = np.stack([np.linspace(0, 1, n_fall)[::-1] for _ in range(dim)])
    linspace = np.concatenate([linspace1, linspace2], axis=-1)  # (C, L)

    noise = (linspace[0, 1] - linspace[0, 0]) * step_noise_perc
    space_noise = np.stack([np.random.uniform(-noise, noise, L - 2) for _ in range(dim)])  # (C, L-2)
    linspace[:, 1:-1] = linspace[:, 1:-1] + space_noise

    linspace = linspace.T  # (L, C)
    output = data1 * (1 - linspace) + data2 * linspace
    return output
def create_camera(w_root, cfg):
    """Sample a static camera pose that looks at the subject.

    Args:
        w_root: (3,), y-up coordinates
        cfg: dict with keys "pitch_std", "pitch_mean", "roll_std", "tz_range1_prob",
            "tz_range1", "tz_range2", "f" and "w"
    Returns:
        R_w2c: (3, 3)
        t_w2c: (3)
    """
    # Parse
    pitch_std = cfg["pitch_std"]
    pitch_mean = cfg["pitch_mean"]
    roll_std = cfg["roll_std"]
    tz_range1_prob = cfg["tz_range1_prob"]
    tz_range1 = cfg["tz_range1"]
    tz_range2 = cfg["tz_range2"]
    f = cfg["f"]
    w = cfg["w"]

    # Orientation: uniform yaw, normally distributed pitch / roll clipped to +-pi/2
    yaw = rand() * 2 * np.pi  # Look at any direction in xz-plane
    pitch = np.clip(randn() * pitch_std + pitch_mean, -halfpi, halfpi)
    roll = np.clip(randn() * roll_std, -halfpi, halfpi)  # Normal-dist

    # Note we use OpenCV's camera system by first applying R_y_upsidedown
    rot_chain = (
        axis_rotate_to_matrix(roll, axis="z")
        @ axis_rotate_to_matrix(pitch, axis="x")
        @ axis_rotate_to_matrix(yaw, axis="y")
        @ R_y_upsidedown
    )
    R_w2c = rot_chain.squeeze(0)  # (3, 3)

    # Place people in the scene: pick one of the two depth ranges, then an offset
    use_range1 = rand() < tz_range1_prob
    tz_range = tz_range1 if use_range1 else tz_range2
    tz = rand() * (tz_range[1] - tz_range[0]) + tz_range[0]
    max_dist_in_fov = (w / 2) / f * tz
    if use_range1:
        # uniform offset inside a fraction of the field of view
        tx = (rand() * 2 - 1) * 0.7 * max_dist_in_fov
        ty = (rand() * 2 - 1) * 0.5 * max_dist_in_fov
    else:
        max_dist_in_fov *= 0.9  # add a threshold
        # normally distributed offset, clamped to the (shrunk) field of view
        tx = torch.clamp(torch.randn(1) * 1.6, -max_dist_in_fov, max_dist_in_fov)
        ty = torch.clamp(torch.randn(1) * 0.8, -max_dist_in_fov, max_dist_in_fov)

    # Translation that puts the root at `dist` in camera coordinates
    dist = torch.tensor([tx, ty, tz], dtype=torch.float)
    t_w2c = dist - torch.matmul(R_w2c, w_root)
    return R_w2c, t_w2c
def create_rotation_move(R, length, r_xyz_w_std=[np.pi / 8, np.pi / 4, np.pi / 8]):
    """Generate a rotation sequence from R to a randomly perturbed final rotation.

    Args:
        R: (3, 3)
        length: number of frames L
        r_xyz_w_std: per-axis bound of the uniformly sampled axis-angle offset
    Return:
        R_move: (L, 3, 3)
    """
    assert len(R.size()) == 2

    # Final camera pose: R composed with an offset drawn uniformly from +-r_xyz_w_std
    offset_aa = (2 * rand(3) - 1) * r_xyz_w_std
    R_final = R @ axis_angle_to_matrix(torch.from_numpy(offset_aa).float())

    # Inbetween the two poses in the 6D rotation representation
    keyframes_6d = matrix_to_rotation_6d(torch.stack((R, R_final))).numpy()  # (2, 6)
    path_6d = noisy_interpolation(keyframes_6d, length)  # (L, 6)
    return rotation_6d_to_matrix(torch.from_numpy(path_6d).float())
def create_translation_move(R_w2c, t_w2c, length, t_xyz_w_std=[1.0, 0.25, 1.0]):
    """Generate a translation sequence that starts at t_w2c and drifts randomly.

    Args:
        R_w2c: (3, 3),
        t_w2c: (3,),
        length: number of frames L
        t_xyz_w_std: per-axis std of the normally distributed final displacement
    Returns:
        t_move: (L, 3)
    """
    # Subject displacement: from zero to a random final offset
    final_disp = randn(3) * t_xyz_w_std
    keyframes = np.array([[0, 0, 0], final_disp])
    subj_path = torch.from_numpy(noisy_interpolation(keyframes, length)).float()  # (L, 3)

    # Equal to camera move: rotate the displacement by R_w2c and add it to t_w2c
    return t_w2c + torch.einsum("ij,lj->li", R_w2c, subj_path)
class CameraAugmentorV11:
    """Random world-to-camera trajectory generator for a sequence of 3D joints.

    Calling an instance samples a static camera with `create_camera`, picks one of six
    motion types (random / track / trackrotate / trackpush / trackpull / static), and
    finally shifts the camera along z so that the joints are not too close and the
    root stays inside the field-of-view tolerance.
    """

    cfg_create_camera = {
        "pitch_mean": np.pi / 36,
        "pitch_std": np.pi / 8,
        "roll_std": np.pi / 24,
        "tz_range1_prob": 0.4,
        "tz_range1": [1.0, 6.0],  # uniform sample
        "tz_range2": [4.0, 12.0],
        # NOTE(review): create_camera in this module does not read "tx_scale"/"ty_scale".
        "tx_scale": 0.7,
        "ty_scale": 0.3,
    }

    # r_xyz_w_std = [np.pi / 8, np.pi / 4, np.pi / 8]  # in world coords
    r_xyz_w_std = [np.pi / 6, np.pi / 3, np.pi / 6]  # in world coords
    t_xyz_w_std = [1.0, 0.25, 1.0]  # in world coords
    r_xyz_w_std_half = [x / 2 for x in r_xyz_w_std]
    t_xyz_w_std_half = [x / 2 for x in t_xyz_w_std]

    t_factor = 1.0
    tz_bias_factor = 1.0

    # Impulse noise (std) and maximum number of impulses for rotation
    rotx_impluse_noise = np.pi / 36
    roty_impluse_noise = np.pi / 36
    rotz_impluse_noise = np.pi / 36
    rot_impluse_n = 1

    # Per-frame step noise (std) on the translation
    tx_step_noise = 0.0025
    ty_step_noise = 0.0025
    tz_step_noise = 0.0025

    # Impulse noise (std) and maximum number of impulses for translation
    tx_impluse_noise = 0.15
    ty_impluse_noise = 0.15
    tz_impluse_noise = 0.15
    t_impluse_n = 1

    # === Postprocess === #
    height_max = 4.0
    height_min = -2.0  # -1.5 -> -2.0 allow look upside
    tz_post_min = 0.5

    def __init__(self):
        self.w = 1000
        self.f = create_camera_sensor(1000, 1000, 24)[2][0, 0]  # use 24mm camera
        self.half_fov_tol = (self.w / 2) / self.f

    def create_rotation_track(self, cam_mat, root, rx_factor=1.0, ry_factor=1.0, rz_factor=1.0):
        """Create rotational move for the camera with rotating human

        The final pose is the first-frame camera-to-human rotation composed with normally
        distributed yaw / pitch / roll (std ry_factor / rx_factor / rz_factor); the
        sequence is interpolated over `self.l` frames and then inverted.
        """
        human_mat = matrix.get_TRS(matrix.identity_mat()[None, :3, :3], root)
        cam2human_mat = matrix.get_mat_BtoA(human_mat, cam_mat)
        R = matrix.get_rotation(cam2human_mat)

        # Create final camera pose
        yaw = np.random.normal(scale=ry_factor)
        pitch = np.random.normal(scale=rx_factor)
        roll = np.random.normal(scale=rz_factor)

        yaw_rm = axis_angle_to_matrix(torch.tensor([0, yaw, 0]).float())
        pitch_rm = axis_angle_to_matrix(torch.tensor([pitch, 0, 0]).float())
        roll_rm = axis_angle_to_matrix(torch.tensor([0, 0, roll]).float())
        Rf = roll_rm @ pitch_rm @ yaw_rm @ R[0]

        # Inbetweening two poses
        Rs = torch.stack((R[0], Rf))
        rs = matrix_to_rotation_6d(Rs).numpy()
        rs_move = noisy_interpolation(rs, self.l)
        R_move = rotation_6d_to_matrix(torch.from_numpy(rs_move).float())
        R_move = torch.inverse(R_move)
        return R_move

    def create_translation_track(self, cam_mat, root, t_factor=1.0, tz_bias_factor=0.0):
        """Create translational move for the camera with tracking human

        The camera keeps its first-frame offset to the root on every later frame, and a
        random drift (std t_factor, plus a z bias scaled by tz_bias_factor) is
        interpolated over `self.l` frames and added to the resulting w2c translation.
        """
        delta_T0 = matrix.get_position(cam_mat)[0] - root[0]
        T_new = matrix.get_position(cam_mat)
        tz_bias = delta_T0.norm(dim=-1) * tz_bias_factor * np.clip(1 + np.random.normal(scale=0.1), 0.67, 1.5)

        T_new[1:] = root[1:] + delta_T0
        cam_mat = matrix.get_TRS(matrix.get_rotation(cam_mat), T_new)
        w2c = torch.inverse(cam_mat)
        T_new = matrix.get_position(w2c)

        # Create final camera position
        tx = np.random.normal(scale=t_factor)
        ty = np.random.normal(scale=t_factor)
        tz = np.random.normal(scale=t_factor) + tz_bias
        Ts = np.array([[0, 0, 0], [tx, ty, tz]])

        T_move = noisy_interpolation(Ts, self.l)
        T_move = torch.from_numpy(T_move).float()
        return T_move + T_new

    def add_stepnoise(self, R, T):
        """Add optional smooth impulses and per-frame translation noise to a w2c trajectory.

        Impulses are applied on the inverted (camera) transform and chosen at random among
        translation only / rotation only / both / none; the step noise is added to the
        w2c translation afterwards.

        Returns:
            (R_new, T_new): rotation and translation of the perturbed w2c transform
        """
        w2c = matrix.get_TRS(R, T)
        cam_mat = torch.inverse(w2c)
        R_new = matrix.get_rotation(cam_mat)
        T_new = matrix.get_position(cam_mat)

        L = R_new.shape[0]
        window = 10  # half-width (in frames) of the smoothed impulse

        def add_impulse_rot(R_new):
            N = np.random.randint(1, self.rot_impluse_n + 1)
            rx = np.random.normal(scale=self.rotx_impluse_noise, size=N)
            ry = np.random.normal(scale=self.roty_impluse_noise, size=N)
            rz = np.random.normal(scale=self.rotz_impluse_noise, size=N)
            R_impluse_noise = axis_angle_to_matrix(torch.from_numpy(np.array([rx, ry, rz])).float().transpose(0, 1))
            R_noise = R_new.clone()
            last_i = 0
            for i in range(N):
                # leave room for the remaining (N - i) windows after this impulse
                n_i = np.random.randint(last_i + window, L - (N - i) * window * 2)

                # make impluse smooth
                window_R = R_noise[n_i - window : n_i + window].clone()
                window_r = matrix_to_rotation_6d(window_R).numpy()
                impluse_R = R_impluse_noise[i] @ window_R[window]
                window_impluse_R = window_R.clone()
                window_impluse_R[:] = impluse_R[None]
                window_impluse_r = matrix_to_rotation_6d(window_impluse_R).numpy()

                window_new_r = noisy_impluse_interpolation(window_r, window_impluse_r)
                window_new_R = rotation_6d_to_matrix(torch.from_numpy(window_new_r).float())
                R_noise[n_i - window : n_i + window] = window_new_R
                last_i = n_i
            R_new = R_noise
            return R_new

        def add_impulse_t(T_new):
            N = np.random.randint(1, self.t_impluse_n + 1)
            tx = np.random.normal(scale=self.tx_impluse_noise, size=N)
            ty = np.random.normal(scale=self.ty_impluse_noise, size=N)
            tz = np.random.normal(scale=self.tz_impluse_noise, size=N)
            T_impluse_noise = torch.from_numpy(np.array([tx, ty, tz])).float().transpose(0, 1)
            T_noise = T_new.clone()
            last_i = 0
            for i in range(N):
                # Same bound as add_impulse_rot: reserve space only for the remaining
                # (N - i) windows. A fixed `L - N * window * 2` bound can yield an empty
                # range for later impulses when N > 1; identical for N == 1.
                n_i = np.random.randint(last_i + window, L - (N - i) * window * 2)

                # make impluse smooth
                window_T = T_noise[n_i - window : n_i + window].clone()
                window_impluse_T = window_T.clone()
                window_impluse_T += T_impluse_noise[i : i + 1]
                window_impluse_T = window_impluse_T.numpy()
                window_T = window_T.numpy()

                window_new_T = noisy_impluse_interpolation(window_T, window_impluse_T)
                window_new_T = torch.from_numpy(window_new_T).float()
                T_noise[n_i - window : n_i + window] = window_new_T
                last_i = n_i
            T_new = T_noise
            return T_new

        impulse_type_prob = {
            "t": 0.2,
            "r": 0.2,
            "both": 0.1,
            "pass": 0.5,
        }
        impulse_type = np.random.choice(list(impulse_type_prob.keys()), p=list(impulse_type_prob.values()))

        if impulse_type == "t":
            # impluse translation only
            T_new = add_impulse_t(T_new)
        elif impulse_type == "r":
            # impluse rotation only
            R_new = add_impulse_rot(R_new)
        elif impulse_type == "both":
            # impluse rotation and translation
            R_new = add_impulse_rot(R_new)
            T_new = add_impulse_t(T_new)
        else:
            assert impulse_type == "pass"

        # Back to w2c, then add small per-frame noise on the translation
        cam_mat_new = matrix.get_TRS(R_new, T_new)
        w2c_new = torch.inverse(cam_mat_new)
        R_new = matrix.get_rotation(w2c_new)
        T_new = matrix.get_position(w2c_new)

        tx = np.random.normal(scale=self.tx_step_noise, size=L)
        ty = np.random.normal(scale=self.ty_step_noise, size=L)
        tz = np.random.normal(scale=self.tz_step_noise, size=L)
        T_new = T_new + torch.from_numpy(np.array([tx, ty, tz])).float().transpose(0, 1)

        return R_new, T_new

    def __call__(self, w_j3d, length=120):
        """
        Args:
            w_j3d: (L, J, 3)
            length: scalar
        Returns:
            T_w2c: (L, 4, 4) world-to-camera transforms
        """
        # Check
        self.l = length
        assert w_j3d.size(0) == self.l, "currently, only support fixed length"

        # Setup
        w_j3d = w_j3d.clone()
        w_root = w_j3d[:, 0]  # (L, 3)

        # Simulate a static camera pose
        cfg_camera0 = {**self.cfg_create_camera, "w": self.w, "f": self.f}
        R0_w2c, t0_w2c = create_camera(w_root[0], cfg_camera0)  # (3, 3) and (3,)

        # Move camera
        camera_type_prob = {
            "random": 0.25,
            "track": 0.15,
            "trackrotate": 0.10,
            "trackpush": 0.05,
            "trackpull": 0.05,
            "static": 0.4,
        }
        camera_type = np.random.choice(list(camera_type_prob.keys()), p=list(camera_type_prob.values()))
        if camera_type == "random":  # random move + add noise on cam
            R_w2c = create_rotation_move(R0_w2c, length, self.r_xyz_w_std)
            t_w2c = create_translation_move(R0_w2c, t0_w2c, length, self.t_xyz_w_std)
            R_w2c, t_w2c = self.add_stepnoise(R_w2c, t_w2c)

        elif camera_type == "track":  # track human
            R_w2c = create_rotation_move(R0_w2c, length, self.r_xyz_w_std_half)
            cam_mat = torch.inverse(transform_mat(R0_w2c, t0_w2c)).repeat(length, 1, 1)  # (F, 4, 4)
            t_w2c = self.create_translation_track(cam_mat, w_root, 0.5)
            R_w2c, t_w2c = self.add_stepnoise(R_w2c, t_w2c)

        elif camera_type == "trackrotate":  # track human and rotate
            cam_mat = torch.inverse(transform_mat(R0_w2c, t0_w2c)).repeat(length, 1, 1)  # (F, 4, 4)
            t_w2c = self.create_translation_track(cam_mat, w_root, 0.5)
            cam_mat = matrix.get_TRS(matrix.get_rotation(cam_mat), t_w2c)
            R_w2c = self.create_rotation_track(cam_mat, w_root, np.pi / 16, np.pi, np.pi / 16)
            R_w2c, t_w2c = self.add_stepnoise(R_w2c, t_w2c)

        elif camera_type == "trackpush":  # track human and push close to human
            R_w2c = create_rotation_move(R0_w2c, length, self.r_xyz_w_std_half)
            # [1/tz_bias_factor, 1] * dist
            cam_mat = torch.inverse(transform_mat(R0_w2c, t0_w2c)).repeat(length, 1, 1)  # (F, 4, 4)
            t_w2c = self.create_translation_track(cam_mat, w_root, 0.5, (1.0 / (1 + self.tz_bias_factor) - 1))
            R_w2c, t_w2c = self.add_stepnoise(R_w2c, t_w2c)

        elif camera_type == "trackpull":  # track human and pull far from human
            R_w2c = create_rotation_move(R0_w2c, length, self.r_xyz_w_std_half)
            # [1, (tz_bias_factor + 1)] * dist
            cam_mat = torch.inverse(transform_mat(R0_w2c, t0_w2c)).repeat(length, 1, 1)  # (F, 4, 4)
            t_w2c = self.create_translation_track(cam_mat, w_root, 0.5, self.tz_bias_factor)
            R_w2c, t_w2c = self.add_stepnoise(R_w2c, t_w2c)

        else:
            assert camera_type == "static"
            R_w2c = R0_w2c.repeat(length, 1, 1)  # (F, 3, 3)
            t_w2c = t0_w2c.repeat(length, 1)  # (F, 3)

        # Recompute t_w2c for better camera height
        # cam_w = torch.einsum("lji,lj->li", R_w2c, -t_w2c)  # (L, 3), camera center in world: cam_w = - R_w2c^t_w2c @ t
        # height = cam_w[..., 1] - w_root[:, 1]
        # height = torch.clamp(height, self.height_min, self.height_max)
        # new_pos = cam_w.clone()
        # new_pos[:, 1] = w_root[:, 1] + height
        # t_w2c = torch.einsum("lij,lj->li", R_w2c, -new_pos)  # (L, 3), new t = -R_w2c @ cam_w

        # Recompute t_w2c for better depth and FoV
        c_j3d = torch.einsum("lij,lkj->lki", R_w2c, w_j3d) + t_w2c[:, None]  # (L, J, 3)
        delta = torch.zeros_like(t_w2c)  # (L, 3) this will be later added to t_w2c
        # - If the person is too close to the camera, push away the person in the z direction
        c_j3d_min = c_j3d[..., 2].min()  # scalar
        if c_j3d_min < self.tz_post_min:
            push_away = self.tz_post_min - c_j3d_min
            delta[..., 2] += push_away
            c_j3d[..., 2] += push_away
        # - If the person is not in the FoV, push away the person in the z direction
        c_root = c_j3d[:, 0]  # (L, 3)
        half_fov = torch.div(c_root[:, :2], c_root[:, 2:]).abs()  # (L, 2), [x/z, y/z]
        if half_fov.max() > self.half_fov_tol:
            max_idx1, max_idx2 = torch.where(torch.max(half_fov) == half_fov)
            max_idx1, max_idx2 = max_idx1[0], max_idx2[0]
            z_trg = c_root[max_idx1, max_idx2].abs() / self.half_fov_tol  # extreme fitted z in the fov
            push_away = z_trg - c_root[max_idx1, 2]
            delta[..., 2] += push_away

        t_w2c += delta
        T_w2c = transform_mat(R_w2c, t_w2c)  # (F, 4, 4)
        return T_w2c
================================================
FILE: eval/GVHMR/hmr4d/dataset/pure_motion/utils.py
================================================
import torch
import torch.nn.functional as F
from pytorch3d.transforms import (
axis_angle_to_matrix,
matrix_to_axis_angle,
matrix_to_rotation_6d,
rotation_6d_to_matrix,
)
from einops import rearrange
def aa_to_r6d(x):
    """Convert axis-angle rotations to the 6D representation, going through rotation matrices."""
    rotmat = axis_angle_to_matrix(x)
    return matrix_to_rotation_6d(rotmat)
def r6d_to_aa(x):
    """Convert 6D rotations to axis-angle, going through rotation matrices."""
    rotmat = rotation_6d_to_matrix(x)
    return matrix_to_axis_angle(rotmat)
def interpolate_smpl_params(smpl_params, tgt_len):
    """Linearly resample a sequence of SMPL parameters to `tgt_len` frames.

    Rotations (body_pose, global_orient) are interpolated in the 6D representation
    and converted back to axis-angle; betas and transl are interpolated directly.

    smpl_params['body_pose'] (L, 63)
    tgt_len: L->L'
    """
    betas = smpl_params["betas"]
    body_pose = smpl_params["body_pose"]
    global_orient = smpl_params["global_orient"]  # (L, 3)
    transl = smpl_params["transl"]  # (L, 3)

    def _resample(seq):
        # F.interpolate resamples along the last axis
        return F.interpolate(seq, tgt_len, mode="linear", align_corners=True)

    def _resample_vec(vec):
        # (L, C) -> (L', C)
        return rearrange(_resample(rearrange(vec, "l c -> c 1 l")), "c 1 l -> l c")

    def _resample_rot(aa, n_joints):
        # (L, J*3) axis-angle -> (L', J*3) axis-angle, interpolated as r6d
        r6d = rearrange(aa_to_r6d(aa.reshape(-1, n_joints, 3)), "l j c -> c j l")
        r6d = _resample(r6d)
        return r6d_to_aa(rearrange(r6d, "c j l -> l j c")).reshape(-1, n_joints * 3)

    return {
        "body_pose": _resample_rot(body_pose, 21),
        # betas are handled like any other vector, for consistency
        "betas": _resample_vec(betas),
        "global_orient": _resample_rot(global_orient, 1),
        "transl": _resample_vec(transl),
    }
def rotate_around_axis(global_orient, transl, axis="y"):
    """Global coordinate augmentation: random rotation around one coordinate axis.

    A single angle is drawn uniformly from [0, 2*pi) and the same rotation is applied
    to every frame.

    Args:
        global_orient: (F, 3) axis-angle root orientation
        transl: (F, 3) root translation
        axis: "x", "y" or "z" (default "y")
    Returns:
        (global_orient, transl) after the rotation
    Raises:
        ValueError: if `axis` is not one of "x", "y", "z"
    """
    axis_index = {"x": 0, "y": 1, "z": 2}
    if axis not in axis_index:
        # fail with a clear message instead of hitting an unbound rotation matrix below
        raise ValueError(f"Unsupported axis {axis!r}; expected one of 'x', 'y', 'z'")

    angle = torch.rand(1) * 2 * torch.pi
    aa = torch.zeros(1, 3)
    aa[0, axis_index[axis]] = angle[0]
    rmat = axis_angle_to_matrix(aa)

    global_orient = matrix_to_axis_angle(rmat @ axis_angle_to_matrix(global_orient))
    transl = (rmat.squeeze(0) @ transl.T).T
    return global_orient, transl
def augment_betas(betas, std=0.1):
    """Add Gaussian noise to shape coefficients.

    One noise vector is sampled per call and shared by all rows of ``betas``.

    Args:
        betas: (..., C) shape coefficients; C is read from the last dimension
            instead of being fixed to 10
        std: standard deviation of the zero-mean noise
    Returns:
        betas plus the (1, C) noise, broadcast over the leading dimension
    """
    num_betas = betas.shape[-1]
    # Sampled on CPU exactly as before (same RNG stream for C == 10), then moved
    # to the device of betas so that non-CPU inputs work as well.
    noise = torch.normal(mean=torch.zeros(num_betas), std=torch.ones(num_betas) * std)
    betas_aug = betas + noise[None].to(betas.device)
    return betas_aug
================================================
FILE: eval/GVHMR/hmr4d/dataset/rich/resource/seqname2imgrange.json
================================================
{"ParkingLot1_002_burpee3": [1, 351], "ParkingLot1_002_overfence1": [1, 268], "ParkingLot1_002_overfence2": [1, 270], "ParkingLot1_002_stretching1": [1, 327], "ParkingLot1_002_pushup1": [1, 220], "ParkingLot1_004_pushup2": [1, 347], "ParkingLot1_004_burpeejump1": [1, 296], "ParkingLot1_004_eating1": [1, 522], "ParkingLot1_004_takingphotos1": [1, 593], "ParkingLot1_004_phonetalk1": [1, 724], "ParkingLot1_005_burpeejump2": [1, 270], "ParkingLot1_005_overfence1": [1, 301], "ParkingLot1_005_pushup2": [1, 262], "ParkingLot1_005_pushup3": [1, 243], "ParkingLot1_004_005_greetingchattingeating1": [275, 849], "ParkingLot1_007_overfence2": [1, 263], "ParkingLot1_007_eating1": [1, 426], "ParkingLot1_007_eating2": [1, 498], "ParkingLot2_008_phonetalk1": [171, 1215], "ParkingLot2_008_burpeejump1": [78, 505], "ParkingLot2_008_overfence1": [161, 459], "ParkingLot2_008_pushup1": [165, 459], "ParkingLot2_008_pushup2": [107, 719], "ParkingLot2_008_overfence2": [138, 632], "ParkingLot2_008_overfence3": [100, 661], "ParkingLot2_008_eating1": [180, 1332], "ParkingLot2_014_pushup2": [80, 420], "ParkingLot2_014_burpeejump1": [50, 348], "ParkingLot2_014_burpeejump2": [50, 248], "ParkingLot2_014_phonetalk2": [121, 1141], "ParkingLot2_014_takingphotos2": [91, 906], "ParkingLot2_014_overfence3": [40, 502], "ParkingLot2_015_overfence1": [170, 692], "ParkingLot2_015_burpeejump2": [344, 678], "ParkingLot2_015_pushup1": [190, 817], "ParkingLot2_015_eating2": [31, 835], "ParkingLot2_016_burpeejump2": [100, 793], "ParkingLot2_016_overfence2": [100, 720], "ParkingLot2_016_pushup1": [61, 680], "ParkingLot2_016_pushup2": [100, 570], "ParkingLot2_016_stretching1": [100, 691], "Pavallion_000_yoga2": [1, 1643], "Pavallion_000_plankjack": [1, 900], "Pavallion_000_phonesiteat": [1, 1157], "Pavallion_000_sidebalancerun": [1, 1091], "Pavallion_002_plankjack": [110, 699], "Pavallion_002_phonesiteat": [1, 1030], "Pavallion_003_plankjack": [1, 764], "Pavallion_003_phonesiteat": [75, 838], 
"Pavallion_003_sidebalancerun": [1, 942], "Pavallion_006_phonesiteat": [130, 841], "Pavallion_006_sidebalancerun": [1, 798], "Pavallion_006_plankjack": [1, 615], "Pavallion_013_phonesiteat": [1, 1254], "Pavallion_013_plankjack": [1, 641], "Pavallion_013_yoga2": [1, 884], "Pavallion_003_018_tossball": [230, 949], "LectureHall_018_wipingchairs1": [1, 1166], "LectureHall_018_wipingspray1": [1, 904], "LectureHall_020_wipingtable1": [1, 897], "BBQ_001_juggle": [0, 297], "BBQ_001_guitar": [0, 381], "ParkingLot1_002_stretching2": [240, 240], "ParkingLot1_002_burpee1": [1, 286], "ParkingLot1_002_burpee2": [1, 203], "ParkingLot1_004_pushup1": [1, 354], "ParkingLot1_004_eating2": [1, 516], "ParkingLot1_004_phonetalk2": [1, 960], "ParkingLot1_004_takingphotos2": [1, 571], "ParkingLot1_004_stretching2": [1, 399], "ParkingLot1_005_overfence2": [1, 298], "ParkingLot1_005_pushup1": [1, 476], "ParkingLot1_005_burpeejump1": [1, 252], "ParkingLot1_007_burpee2": [1, 349], "ParkingLot2_008_eating2": [160, 1100], "ParkingLot2_008_burpeejump2": [129, 492], "ParkingLot2_014_overfence1": [95, 547], "ParkingLot2_014_eating2": [101, 986], "ParkingLot2_016_phonetalk5": [170, 1259], "Pavallion_002_sidebalancerun": [1, 655], "Pavallion_013_sidebalancerun": [1, 810], "Pavallion_018_sidebalancerun": [1, 873], "LectureHall_018_wipingtable1": [1, 1280], "LectureHall_020_wipingchairs1": [1, 1163], "LectureHall_003_wipingchairs1": [1, 724], "Pavallion_000_yoga1": [1, 1757], "Pavallion_002_yoga1": [1, 613], "Pavallion_003_yoga1": [1, 792], "Pavallion_006_yoga1": [1, 930], "Pavallion_018_yoga1": [1, 880], "ParkingLot2_017_burpeejump2": [118, 612], "ParkingLot2_017_burpeejump1": [40, 817], "ParkingLot2_017_overfence1": [110, 661], "ParkingLot2_017_overfence2": [90, 944], "ParkingLot2_017_eating1": [97, 895], "ParkingLot2_017_pushup1": [191, 719], "ParkingLot2_017_pushup2": [74, 811], "ParkingLot2_009_burpeejump1": [200, 1085], "ParkingLot2_009_burpeejump2": [150, 399], "ParkingLot2_009_overfence1": 
[140, 601], "ParkingLot2_009_overfence2": [150, 559], "LectureHall_009_sidebalancerun1": [1, 673], "LectureHall_010_plankjack1": [1, 532], "LectureHall_010_sidebalancerun1": [1, 919], "LectureHall_021_plankjack1": [1, 507], "LectureHall_021_sidebalancerun1": [1, 855], "LectureHall_019_wipingchairs1": [1, 978], "LectureHall_009_021_reparingprojector1": [1, 499], "ParkingLot2_009_spray1": [145, 1242], "ParkingLot2_009_impro1": [100, 990], "ParkingLot2_009_impro2": [100, 1140], "ParkingLot2_009_impro5": [100, 649], "Gym_010_pushup1": [1, 475], "Gym_010_pushup2": [1, 407], "Gym_011_pushup1": [1, 346], "Gym_011_pushup2": [1, 540], "Gym_011_burpee2": [1, 479], "Gym_012_pushup2": [1, 291], "Gym_010_mountainclimber1": [0, 0], "Gym_010_mountainclimber2": [1, 471], "Gym_013_dips1": [1, 503], "Gym_013_dips2": [1, 333], "Gym_013_dips3": [1, 502], "Gym_013_lunge1": [1, 690], "Gym_013_lunge2": [1, 834], "Gym_013_pushup1": [1, 861], "Gym_013_pushup2": [1, 477], "Gym_013_burpee4": [1, 320], "Gym_010_lunge1": [1, 337], "Gym_010_lunge2": [1, 312], "Gym_010_dips1": [1, 572], "Gym_010_dips2": [1, 603], "Gym_010_cooking1": [1, 779], "Gym_011_cooking1": [1, 1141], "Gym_011_cooking2": [1, 1145], "Gym_011_dips1": [1, 494], "Gym_011_dips4": [1, 495], "Gym_011_dips3": [1, 320], "Gym_011_dips2": [1, 382], "Gym_012_lunge1": [1, 225], "Gym_012_lunge2": [1, 318], "Gym_012_cooking2": [1, 993]}
================================================
FILE: eval/GVHMR/hmr4d/dataset/rich/resource/test.txt
================================================
sequence_name capture_name scan_name id moving_cam gender scene action/scene-interaction subjects view_id
ParkingLot2_017_burpeejump2 ParkingLot2 scan_camcoord 017 V female V V X 0,2,3
ParkingLot2_017_burpeejump1 ParkingLot2 scan_camcoord 017 V female V V X 0,1,5
ParkingLot2_017_overfence1 ParkingLot2 scan_camcoord 017 V female V V X 0,3,4
ParkingLot2_017_overfence2 ParkingLot2 scan_camcoord 017 V female V V X 0,1,4
ParkingLot2_017_eating1 ParkingLot2 scan_camcoord 017 V female V V X 0,2,4
ParkingLot2_017_pushup1 ParkingLot2 scan_camcoord 017 X female V V X 0,1,4,5
ParkingLot2_017_pushup2 ParkingLot2 scan_camcoord 017 V female V V X 0,4,5
ParkingLot2_009_burpeejump1 ParkingLot2 scan_camcoord 009 X female V V X 0,1,2,3
ParkingLot2_009_burpeejump2 ParkingLot2 scan_camcoord 009 X female V V X 0,2,3,4
ParkingLot2_009_overfence1 ParkingLot2 scan_camcoord 009 X female V V X 0,3,4,5
ParkingLot2_009_overfence2 ParkingLot2 scan_camcoord 009 X female V V X 0,1,4,5
LectureHall_009_sidebalancerun1 LectureHall scan_yoga_scene_camcoord 009 X female V V X 0,1,4,5
LectureHall_010_plankjack1 LectureHall scan_yoga_scene_camcoord 010 X female V V X 0,2,4,6
LectureHall_010_sidebalancerun1 LectureHall scan_yoga_scene_camcoord 010 X female V V X 0,1,2,4
LectureHall_021_plankjack1 LectureHall scan_yoga_scene_camcoord 021 X female V V X 0,3,5,6
LectureHall_021_sidebalancerun1 LectureHall scan_yoga_scene_camcoord 021 X female V V X 0,4,5,6
LectureHall_019_wipingchairs1 LectureHall scan_chair_scene_camcoord 019 X female V V X 0,1,2,3
LectureHall_009_021_reparingprojector1 LectureHall scan_yoga_scene_camcoord 009 X female V X X 0,3,4,5
LectureHall_009_021_reparingprojector1 LectureHall scan_yoga_scene_camcoord 021 X female V X X 0,3,4,5
ParkingLot2_009_spray1 ParkingLot2 scan_camcoord 009 X female V X X 0,1,2,3
ParkingLot2_009_impro1 ParkingLot2 scan_camcoord 009 X female V X X 0,2,3,4
ParkingLot2_009_impro2 ParkingLot2 scan_camcoord 009 X female V X X 0,3,4,5
ParkingLot2_009_impro5 ParkingLot2 scan_camcoord 009 X female V X X 0,2,4,5
Gym_010_pushup1 Gym scan_camcoord 010 X female X V X 3,4,5,6
Gym_010_pushup2 Gym scan_camcoord 010 X female X V X 2,3,4,5
Gym_011_pushup1 Gym scan_camcoord 011 X male X V X 2,3,4,5
Gym_011_pushup2 Gym scan_camcoord 011 X male X V X 2,3,4,5
Gym_011_burpee2 Gym scan_camcoord 011 X male X V X 2,3,4,5
Gym_012_pushup2 Gym scan_camcoord 012 X female X V X 3,4,5,6
Gym_010_mountainclimber1 Gym scan_camcoord 010 X female X V X 3,4,5,6
Gym_010_mountainclimber2 Gym scan_camcoord 010 X female X V X 3,4,5,6
Gym_013_dips1 Gym scan_camcoord 013 X female X X V 0,3,4,5
Gym_013_dips2 Gym scan_camcoord 013 X female X X V 1,2,4,5
Gym_013_dips3 Gym scan_camcoord 013 X female X X V 1,2,4,5
Gym_013_lunge1 Gym scan_camcoord 013 X female X X V 1,4,5,6
Gym_013_lunge2 Gym scan_camcoord 013 X female X X V 0,4,5,6
Gym_013_pushup1 Gym scan_camcoord 013 X female X V V 0,3,4,5
Gym_013_pushup2 Gym scan_camcoord 013 X female X V V 1,2,4,5
Gym_013_burpee4 Gym scan_camcoord 013 X female X V V 0,4,5,6
Gym_010_lunge1 Gym scan_camcoord 010 X female X X X 1,4,5,6
Gym_010_lunge2 Gym scan_camcoord 010 X female X X X 0,2,4,5
Gym_010_dips1 Gym scan_camcoord 010 X female X X X 0,4,5,6
Gym_010_dips2 Gym scan_camcoord 010 X female X X X 1,2,4,5
Gym_010_cooking1 Gym scan_table_camcoord 010 X female X X X 1,3,4,5
Gym_011_cooking1 Gym scan_table_camcoord 011 V male X X X 4,5,6
Gym_011_cooking2 Gym scan_table_camcoord 011 V male X X X 2,4,5
Gym_011_dips1 Gym scan_camcoord 011 X male X X X 1,3,4,5
Gym_011_dips4 Gym scan_camcoord 011 X male X X X 0,2,4,5
Gym_011_dips3 Gym scan_camcoord 011 X male X X X 0,3,4,5
Gym_011_dips2 Gym scan_camcoord 011 X male X X X 1,3,4,5
Gym_012_lunge1 Gym scan_camcoord 012 X female X X X 0,3,4,5
Gym_012_lunge2 Gym scan_camcoord 012 X female X X X 0,4,5,6
Gym_012_cooking2 Gym scan_table_camcoord 012 V female X X X 3,4,5
================================================
FILE: eval/GVHMR/hmr4d/dataset/rich/resource/train.txt
================================================
sequence_name capture_name scan_name id moving_cam gender view_id
ParkingLot1_002_burpee3 ParkingLot1 scan_camcoord 002 X male 0,1,2,3,4,5,6,7
ParkingLot1_002_overfence1 ParkingLot1 scan_camcoord 002 X male 0,1,2,3,4,5,6,7
ParkingLot1_002_overfence2 ParkingLot1 scan_camcoord 002 X male 0,1,2,3,4,5,6,7
ParkingLot1_002_stretching1 ParkingLot1 scan_camcoord 002 X male 0,1,2,3,4,5,6,7
ParkingLot1_002_pushup1 ParkingLot1 scan_camcoord 002 X male 0,1,2,3,4,5,6,7
ParkingLot1_004_pushup2 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7
ParkingLot1_004_burpeejump1 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7
ParkingLot1_004_eating1 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7
ParkingLot1_004_takingphotos1 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7
ParkingLot1_004_phonetalk1 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7
ParkingLot1_005_burpeejump2 ParkingLot1 scan_camcoord 005 X male 0,1,2,3,4,5,6,7
ParkingLot1_005_overfence1 ParkingLot1 scan_camcoord 005 X male 0,1,2,3,4,5,6,7
ParkingLot1_005_pushup2 ParkingLot1 scan_camcoord 005 X male 0,1,2,3,4,5,6,7
ParkingLot1_005_pushup3 ParkingLot1 scan_camcoord 005 X male 0,1,2,3,4,5,6,7
ParkingLot1_004_005_greetingchattingeating1 ParkingLot1 scan_camcoord 004 X male 0,1,2,3,4,5,6,7
ParkingLot1_004_005_greetingchattingeating1 ParkingLot1 scan_camcoord 005 X male 0,1,2,3,4,5,6,7
ParkingLot1_007_overfence2 ParkingLot1 scan_camcoord 007 X male 0,1,2,3,4,5,6,7
ParkingLot1_007_eating1 ParkingLot1 scan_camcoord 007 X male 0,1,2,3,4,5,6,7
ParkingLot1_007_eating2 ParkingLot1 scan_camcoord 007 X male 0,1,2,3,4,5,6,7
ParkingLot2_008_phonetalk1 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5
ParkingLot2_008_burpeejump1 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5
ParkingLot2_008_overfence1 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5
ParkingLot2_008_pushup1 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5
ParkingLot2_008_pushup2 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5
ParkingLot2_008_overfence2 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5
ParkingLot2_008_overfence3 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5
ParkingLot2_008_eating1 ParkingLot2 scan_camcoord 008 V male 0,1,2,3,4,5
ParkingLot2_014_pushup2 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5
ParkingLot2_014_burpeejump1 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5
ParkingLot2_014_burpeejump2 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5
ParkingLot2_014_phonetalk2 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5
ParkingLot2_014_takingphotos2 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5
ParkingLot2_014_overfence3 ParkingLot2 scan_camcoord 014 X male 0,1,2,3,4,5
ParkingLot2_015_overfence1 ParkingLot2 scan_camcoord 015 X male 0,1,2,3,4,5
ParkingLot2_015_burpeejump2 ParkingLot2 scan_camcoord 015 X male 0,1,2,3,4,5
ParkingLot2_015_pushup1 ParkingLot2 scan_camcoord 015 X male 0,1,2,3,4,5
ParkingLot2_015_eating2 ParkingLot2 scan_camcoord 015 X male 0,1,2,3,4,5
ParkingLot2_016_burpeejump2 ParkingLot2 scan_camcoord 016 V female 0,1,2,3,4,5
ParkingLot2_016_overfence2 ParkingLot2 scan_camcoord 016 V female 0,1,2,3,4,5
ParkingLot2_016_pushup1 ParkingLot2 scan_camcoord 016 V female 0,1,2,3,4,5
ParkingLot2_016_pushup2 ParkingLot2 scan_camcoord 016 V female 0,1,2,3,4,5
ParkingLot2_016_stretching1 ParkingLot2 scan_camcoord 016 V female 0,1,2,3,4,5
Pavallion_000_yoga2 Pavallion scan_camcoord 000 X male 0,1,2,3,4,5,6
Pavallion_000_plankjack Pavallion scan_camcoord 000 X male 0,1,2,3,4,5,6
Pavallion_000_phonesiteat Pavallion scan_camcoord 000 X male 0,1,3,4,6
Pavallion_000_sidebalancerun Pavallion scan_camcoord 000 X male 0,1,2,3,4,5,6
Pavallion_002_plankjack Pavallion scan_camcoord 002 V male 0,1,2,3,4,5,6
Pavallion_002_phonesiteat Pavallion scan_camcoord 002 V male 0,1,3,4,6
Pavallion_003_plankjack Pavallion scan_camcoord 003 V male 0,1,2,3,4,5,6
Pavallion_003_phonesiteat Pavallion scan_camcoord 003 V male 0,1,3,4,6
Pavallion_003_sidebalancerun Pavallion scan_camcoord 003 V male 0,1,2,3,4,5,6
Pavallion_006_phonesiteat Pavallion scan_camcoord 006 V male 0,1,3,4,6
Pavallion_006_sidebalancerun Pavallion scan_camcoord 006 V male 0,1,2,3,4,5,6
Pavallion_006_plankjack Pavallion scan_camcoord 006 V male 0,1,2,3,4,5,6
Pavallion_013_phonesiteat Pavallion scan_camcoord 013 X female 0,1,3,4,6
Pavallion_013_plankjack Pavallion scan_camcoord 013 X female 0,1,2,3,4,5,6
Pavallion_013_yoga2 Pavallion scan_camcoord 013 V female 0,1,2,3,4,5,6
Pavallion_003_018_tossball Pavallion scan_camcoord 003 X male 0,1,2,3,4,5,6
Pavallion_003_018_tossball Pavallion scan_camcoord 018 X female 0,1,2,3,4,5,6
LectureHall_018_wipingchairs1 LectureHall scan_chair_scene_camcoord 018 X female 0,1,2,3,4,5,6
LectureHall_018_wipingspray1 LectureHall scan_chair_scene_camcoord 018 X female 2,3,4
LectureHall_020_wipingtable1 LectureHall scan_chair_scene_camcoord 020 X male 0,2,4,5,6
BBQ_001_juggle BBQ scan_camcoord 001 X male 0,1,2,3,4,5,6,7
BBQ_001_guitar BBQ scan_camcoord 001 X male 0,1,2,3,4,5,6,7
================================================
FILE: eval/GVHMR/hmr4d/dataset/rich/resource/val.txt
================================================
sequence_name capture_name scan_name id moving_cam gender scene action/scene-interaction subjects view_id
ParkingLot1_002_stretching2 ParkingLot1 scan_camcoord 002 X male V V V 0,1,2,3,4,5,6,7
ParkingLot1_002_burpee1 ParkingLot1 scan_camcoord 002 X male V V V 0,1,2,3,4,5,6,7
ParkingLot1_002_burpee2 ParkingLot1 scan_camcoord 002 X male V V V 0,1,2,3,4,5,6,7
ParkingLot1_004_pushup1 ParkingLot1 scan_camcoord 004 X male V V V 0,1,2,3,4,5,6,7
ParkingLot1_004_eating2 ParkingLot1 scan_camcoord 004 X male V V V 0,1,2,3,4,5,6,7
ParkingLot1_004_phonetalk2 ParkingLot1 scan_camcoord 004 X male V V V 0,1,2,3,4,5,6,7
ParkingLot1_004_takingphotos2 ParkingLot1 scan_camcoord 004 X male V V V 0,1,2,3,4,5,6,7
ParkingLot1_004_stretching2 ParkingLot1 scan_camcoord 004 X male V V V 0,1,2,3,4,5,6,7
ParkingLot1_005_overfence2 ParkingLot1 scan_camcoord 005 X male V V V 0,1,2,3,4,5,6,7
ParkingLot1_005_pushup1 ParkingLot1 scan_camcoord 005 X male V V V 0,1,2,3,4,5,6,7
ParkingLot1_005_burpeejump1 ParkingLot1 scan_camcoord 005 X male V V V 0,1,2,3,4,5,6,7
ParkingLot1_007_burpee2 ParkingLot1 scan_camcoord 007 X male V V V 0,1,2,3,4,5,6,7
ParkingLot2_008_eating2 ParkingLot2 scan_camcoord 008 V male V V V 0,1,2,3,4,5
ParkingLot2_008_burpeejump2 ParkingLot2 scan_camcoord 008 V male V V V 0,1,2,3,4,5
ParkingLot2_014_overfence1 ParkingLot2 scan_camcoord 014 X male V V V 0,1,2,3,4,5
ParkingLot2_014_eating2 ParkingLot2 scan_camcoord 014 X male V V V 0,1,2,3,4,5
ParkingLot2_016_phonetalk5 ParkingLot2 scan_camcoord 016 V female V V V 0,1,2,3,4,5
Pavallion_002_sidebalancerun Pavallion scan_camcoord 002 V male V V V 0,1,2,3,4,5,6
Pavallion_013_sidebalancerun Pavallion scan_camcoord 013 X female V V V 0,1,2,3,4,5,6
Pavallion_018_sidebalancerun Pavallion scan_camcoord 018 V female V V V 0,1,2,3,4,5,6
LectureHall_018_wipingtable1 LectureHall scan_chair_scene_camcoord 018 X female V V V 0,2,4,5,6
LectureHall_020_wipingchairs1 LectureHall scan_chair_scene_camcoord 020 X male V V V 0,1,2,3,4,5,6
LectureHall_003_wipingchairs1 LectureHall scan_chair_scene_camcoord 003 X male V V V 0,1,2,3,4,5,6
Pavallion_000_yoga1 Pavallion scan_camcoord 000 X male V X V 0,1,2,3,4,5,6
Pavallion_002_yoga1 Pavallion scan_camcoord 002 V male V X V 0,1,2,3,4,5,6
Pavallion_003_yoga1 Pavallion scan_camcoord 003 V male V X V 0,1,2,3,4,5,6
Pavallion_006_yoga1 Pavallion scan_camcoord 006 V male V X V 0,1,2,3,4,5,6
Pavallion_018_yoga1 Pavallion scan_camcoord 018 V female V X V 0,1,2,3,4,5,6
================================================
FILE: eval/GVHMR/hmr4d/dataset/rich/resource/w2az_sahmr.json
================================================
{"BBQ_scan_camcoord": [[0.9989829107564298, 0.03367618890797693, -0.029984301180211045, 0.0008183751635392625], [0.03414262169451401, -0.1305975871406019, 0.9908473906797644, -0.005059823133706893], [0.02945208652127451, -0.9908633531086326, -0.13161455111748036, 1.4054905296083466], [0.0, 0.0, 0.0, 1.0]], "Gym_scan_camcoord": [[0.9932599733260449, -0.07628732032461205, 0.0872632233306122, -0.047601130084306706], [-0.10233962102690007, -0.22374853741942266, 0.9692590953768503, -0.04091804681182174], [-0.05441716049582774, -0.9716567484252654, -0.23004768176013274, 1.537911791136788], [0.0, 0.0, 0.0, 1.0]], "Gym_scan_table_camcoord": [[0.9974451989415423, -0.06250743213795668, 0.03458172980064169, 0.02231858470834599], [-0.04804912583358893, -0.22882402250236075, 0.972281259838159, 0.039081886755815726], [-0.05286167435026744, -0.9714588965331274, -0.2312428501197992, 1.5421821446346522], [0.0, 0.0, 0.0, 1.0]], "LectureHall_scan_chair_scene_camcoord": [[0.9992930513998263, 0.030087515976743376, -0.0225419343977731, 0.001998908749589632], [0.030705594681969043, -0.30721111058653017, 0.9511458878570781, -0.025811963513866963], [0.021692484396004613, -0.9511656401040444, -0.307917783192506, 2.060346184503773], [0.0, 0.0, 0.0, 1.0]], "LectureHall_scan_yoga_scene_camcoord": [[0.9993358324246812, 0.03030060260429296, -0.020242715082476024, -0.003510046042036605], [0.028600729415016745, -0.3079667078507395, 0.9509671419836329, -0.01748548118379142], [0.022580795137075255, -0.9509144968594153, -0.3086287856852993, 2.0424701474796567], [0.0, 0.0, 0.0, 1.0]], "ParkingLot1_scan_camcoord": [[0.9989627324729327, -0.03724260727951709, 0.02620013994738054, 0.0070941466745699025], [-0.03091587075252664, -0.13228243926883107, 0.9907298144280939, -0.0274920377236923], [-0.03343154297742938, -0.9905121627037764, -0.13329661462331338, 1.3859200914120975], [0.0, 0.0, 0.0, 1.0]], "ParkingLot2_scan_camcoord": [[0.9989532636786039, -0.04044665659892979, 0.021364572447267097, 
0.01646827411554571], [-0.026687287930043047, -0.13600581518076985, 0.9903485279940424, 0.030197722289598695], [-0.03715058073335097, -0.9898820567153364, -0.13694286452455984, 1.4372015171546513], [0.0, 0.0, 0.0, 1.0]], "Pavallion_scan_camcoord": [[0.9971864096076799, 0.05693557331723671, -0.048760690979605295, 0.0012478238054067193], [0.05746407703876882, -0.16289761936471214, 0.9849681443861059, -0.006002953831755452], [0.04813672552068054, -0.9849988355812122, -0.16571104235928033, 1.7638454838942128], [0.0, 0.0, 0.0, 1.0]]}
================================================
FILE: eval/GVHMR/hmr4d/dataset/rich/rich_motion_test.py
================================================
from pathlib import Path
import numpy as np
import torch
from torch.utils import data
from hmr4d.utils.pylogger import Log
from .rich_utils import (
get_cam2params,
get_w2az_sahmr,
parse_seqname_info,
get_cam_key_wham_vid,
)
from hmr4d.utils.geo_transform import apply_T_on_points, transform_mat, compute_cam_angvel
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.utils.smplx_utils import make_smplx
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle
from hmr4d.utils.geo.hmr_cam import resize_K
from hmr4d.configs import MainStore, builds
# Named subsets of test videos. A key of this dict can be passed as `vid_presets`
# to select_subset / RichSmplFullSeqDataset; each listed string is used as a key
# into the loaded labels.
# NOTE(review): the config store at the bottom of this file also registers
# vid_presets="postproc", which has no entry here.
VID_PRESETS = {
"easytohard": [
"test/Gym_013_burpee4/cam_06",
"test/Gym_011_pushup1/cam_02",
"test/LectureHall_019_wipingchairs1/cam_03",
"test/ParkingLot2_009_overfence1/cam_04",
"test/LectureHall_021_sidebalancerun1/cam_00",
"test/Gym_010_dips2/cam_05",
],
}
class RichSmplFullSeqDataset(data.Dataset):
    """RICH test set, one item per video, each item covering the full sequence.

    Labels and pre-extracted features are loaded once in ``__init__``; every item
    combines them with the camera parameters of the video and the transform from
    the world frame to the gravity-aligned "ay" frame.
    """

    def __init__(self, vid_presets=None):
        """
        Args:
            vid_presets: None to use every video found in the labels, or a key in VID_PRESETS
        """
        super().__init__()
        self.dataset_name = "RICH"
        self.dataset_id = "RICH"
        Log.info(f"[{self.dataset_name}] Full sequence, Test")
        tic = Log.time()

        # Load evaluation protocol from WHAM labels
        self.rich_dir = Path("inputs/RICH/hmr4d_support")
        self.labels = torch.load(self.rich_dir / "rich_test_labels.pt")
        self.preproc_data = torch.load(self.rich_dir / "rich_test_preproc.pt")
        vids = select_subset(self.labels, vid_presets)

        # Setup dataset index: (vid, start, end) with start=0 and end=seq_length
        self.idx2meta = [(vid, 0, len(self.labels[vid]["frame_id"])) for vid in vids]

        # Prepare ground truth motion in ay-coordinate
        self.w2az = get_w2az_sahmr()  # scan_name -> T_w2az, w-coordinate refers to cam-1-coordinate
        self.cam2params = get_cam2params()  # cam_key -> (T_w2c, K)
        seqname_info = parse_seqname_info(skip_multi_persons=True)  # {k: (scan_name, subject_id, gender, cam_ids)}
        self.seqname_to_scanname = {k: v[0] for k, v in seqname_info.items()}

        Log.info(f"[RICH] {len(self.idx2meta)} sequences. Elapsed: {Log.time() - tic:.2f}s")

    def __len__(self):
        """Number of sequences (one per selected video)."""
        return len(self.idx2meta)

    def _load_data(self, idx):
        """Collect labels, camera parameters, features and rendering info of one video.

        Args:
            idx: index into self.idx2meta
        Returns:
            dict of raw entries, consumed by _process_data
        """
        # [start, end), when loading data from labels
        vid, start, end = self.idx2meta[idx]
        label = self.labels[vid]
        preproc_data = self.preproc_data[vid]

        data = {
            "meta": {"dataset_id": "RICH", "vid": vid, "vid-start-end": (start, end)},
            "length": end - start,
            # SMPLX
            "gt_smpl_params": label["gt_smplx_params"],
            "gender": label["gender"],
        }

        # camera; the second "/"-separated field of vid is the sequence name
        cam_key = get_cam_key_wham_vid(vid)
        scan_name = self.seqname_to_scanname[vid.split("/")[1]]
        T_w2c, K = self.cam2params[cam_key]  # (4, 4) (3, 3)
        data.update({"T_w2c": T_w2c, "T_w2az": self.w2az[scan_name], "K": K})

        # image features
        for key in ("f_imgseq", "bbx_xys", "img_wh", "kp2d"):
            data[key] = preproc_data[key]

        # to render a video
        video_path = self.rich_dir / "video" / vid / "video.mp4"
        width, height = data["img_wh"] / 4  # Video saved has been downsampled 1/4
        data["meta_render"] = {
            "name": vid.replace("/", "@"),
            "video_path": str(video_path),
            "frame_id": label["frame_id"],  # (F,)
            "width_height": (width, height),
            "K": resize_K(K, 0.25),
            "bbx_xys": data["bbx_xys"] / 4,
        }
        return data

    def _process_data(self, data):
        """Turn the raw entries from _load_data into the dict returned by __getitem__."""
        # T_w2az is pre-computed by using floor clue. az2ay uses a rotation along x-axis.
        R_az2ay = axis_angle_to_matrix(torch.tensor([1.0, 0.0, 0.0]) * -torch.pi / 2)  # (3, 3)
        T_w2ay = transform_mat(R_az2ay, R_az2ay.new([0, 0, 0])) @ data["T_w2az"]  # (4, 4)

        # The same camera rotation is repeated for every frame of the sequence
        length = data["length"]
        R_w2c = data["T_w2c"][:3, :3].repeat(length, 1, 1)  # (L, 3, 3)
        cam_angvel = compute_cam_angvel(R_w2c)  # (L, 6)

        # Return
        return {
            # --- not batched
            "task": "CAP-Seq",
            "meta": data["meta"],
            "meta_render": data["meta_render"],
            # --- we test on single sequence, so set kv manually
            "length": length,
            "f_imgseq": data["f_imgseq"],  # (F, 1024)
            "cam_angvel": cam_angvel,
            "bbx_xys": data["bbx_xys"],  # (F, 3)
            "K_fullimg": data["K"][None].expand(length, -1, -1),  # (F, 3, 3)
            "kp2d": data["kp2d"],  # (F, 17, 3)
            # --- dataset specific
            "model": "smplx",
            "gender": data["gender"],
            "gt_smpl_params": data["gt_smpl_params"],
            "T_w2ay": T_w2ay,  # (4, 4)
            "T_w2c": data["T_w2c"],  # (4, 4)
        }

    def __getitem__(self, idx):
        """Load and process the sequence at position idx."""
        return self._process_data(self._load_data(idx))
def select_subset(labels, vid_presets):
    """Pick the videos to evaluate.

    Args:
        labels: dict keyed by video name
        vid_presets: None to use every key of ``labels``, otherwise a key in VID_PRESETS
    Returns:
        list of video names
    Raises:
        KeyError: if vid_presets is neither None nor a key in VID_PRESETS
    """
    if vid_presets is None:
        return list(labels.keys())
    if vid_presets not in VID_PRESETS:
        raise KeyError(f"Unknown vid_presets {vid_presets!r}; available presets: {sorted(VID_PRESETS)}")
    # Return a copy so callers cannot modify the module-level preset list
    return list(VID_PRESETS[vid_presets])
# Register the dataset configs in MainStore under the "test_datasets/rich" group:
# "all" uses every video, the other entries select a preset through vid_presets.
# NOTE(review): "postproc" passes vid_presets="postproc", but VID_PRESETS in this file
# only defines "easytohard", so select_subset would raise KeyError for it. Confirm
# whether the preset list is missing or this entry is stale.
group_name = "test_datasets/rich"
base_node = builds(RichSmplFullSeqDataset, vid_presets=None, populate_full_signature=True)
MainStore.store(name="all", node=base_node, group=group_name)
MainStore.store(name="easy_to_hard", node=base_node(vid_presets="easytohard"), group=group_name)
MainStore.store(name="postproc", node=base_node(vid_presets="postproc"), group=group_name)
================================================
FILE: eval/GVHMR/hmr4d/dataset/rich/rich_utils.py
================================================
import torch
import cv2
import numpy as np
from hmr4d.utils.geo_transform import apply_T_on_points, project_p2d
from pathlib import Path
import json
import time
# ----- Meta sample utils ----- #
def sample_idx2meta(idx2meta, sample_interval):
    """
    Filter and order the per-image meta entries.
    1. drop entries whose frame number is 45 or lower
    2. keep entries where (frame number + cam_id) is a multiple of sample_interval
    3. sort the remaining entries by "img_key"

    Args:
        idx2meta: dict whose values are meta dicts with "frame_name", "cam_id" and "img_key"
        sample_interval: sampling stride
    Returns:
        list of the selected meta dicts
    """
    selected = []
    for meta in idx2meta.values():
        frame = int(meta["frame_name"])
        if frame <= 45:
            continue
        # The cam_id offset makes different cameras pick different frames
        if (frame + int(meta["cam_id"])) % sample_interval != 0:
            continue
        selected.append(meta)
    selected.sort(key=lambda meta: meta["img_key"])
    return selected
def remove_bbx_invisible_frame(idx2meta, img2gtbbx):
    """Drop the entries whose ground-truth bbx center lies outside the raw image.

    Args:
        idx2meta: list of meta dicts with "img_key" and "cam_key"
        img2gtbbx: img_key -> array indexed as (x0, y0, x1, y1)
    Returns:
        list of the meta dicts whose bbx center is inside the image
    """
    lower_bound = np.array([0.0, 0.0])
    upper_horizontal = np.array([4112.0, 3008.0]) - 1  # horizontal
    upper_vertical = np.array([3008.0, 4112.0]) - 1  # vertical
    vertical_cams = ["Pavallion_3", "Pavallion_5"]  # the only cameras using the vertical layout

    def _center_is_visible(meta):
        bbx = img2gtbbx[meta["img_key"]]
        center = np.array([bbx[[0, 2]].mean(), bbx[[1, 3]].mean()])
        if (center < lower_bound).any():
            return False
        upper_bound = upper_vertical if meta["cam_key"] in vertical_cams else upper_horizontal
        return not (center > upper_bound).any()

    return [meta for meta in idx2meta if _center_is_visible(meta)]
def remove_extra_rules(idx2meta):
    """Drop the entries that belong to the listed multi-person sequences.

    Args:
        idx2meta: list of meta dicts with "seq_name"
    Returns:
        list of the remaining meta dicts, in the original order
    """
    multi_person_seqs = ["LectureHall_009_021_reparingprojector1"]
    kept = []
    for meta in idx2meta:
        if meta["seq_name"] in multi_person_seqs:
            continue
        kept.append(meta)
    return kept
# ----- Image utils ----- #
def compute_bbx(dataset, data):
    """
    Compute the bbx (w.r.t. original image resolution) from gt_smplh_params, as the
    min / max of the projected body vertices.
    Args:
        dataset: rich_pose.RichPose, must provide `dataset.smplh[gender]`
        data: dict with "meta" (containing "gender"), "gt_smplh_params", "T_w2c" and "K"
    Returns:
        bbx: numpy array, flattened [min_2d, max_2d]

    # This function need extra scripts to run, the dataset has to be prepared with:
    from hmr4d.utils.smplx_utils import make_smplx
    self.smplh_male = make_smplx("rich-smplh", gender="male")
    self.smplh_female = make_smplx("rich-smplh", gender="female")
    self.smplh = {
        "male": self.smplh_male,
        "female": self.smplh_female,
    }
    """
    gender = data["meta"]["gender"]
    params = {}
    for name, value in data["gt_smplh_params"].items():
        params[name] = value.reshape(1, -1)
    body_model = dataset.smplh[gender]
    verts_w = body_model(**params).vertices

    # world -> camera -> image plane
    T_w2c = data["T_w2c"]
    K = data["K"]
    verts_c = apply_T_on_points(verts_w, T_w2c[None])
    verts_2d = project_p2d(verts_c, K[None])[0]

    # extrema over the vertices, per image coordinate
    coords = verts_2d.T
    lower = coords.min(-1)[0]
    upper = coords.max(-1)[0]
    return torch.stack([lower, upper]).reshape(-1).numpy()
def get_2d(dataset, data):
    """Project the ground-truth SMPL-H joints into the image and attach a confidence of 1.

    Args:
        dataset: must provide `dataset.smplh[gender]`, a callable body model
        data: dict with "meta" (containing "gender"), "gt_smplh_params", "T_w2c" and "K"
    Returns:
        keypoints: (J, 3) tensor, the projected 2D joints followed by a column of ones
    """
    gender = data["meta"]["gender"]
    smplh_params = {k: v.reshape(1, -1) for k, v in data["gt_smplh_params"].items()}
    smplh_opt = dataset.smplh[gender](**smplh_params)
    joints_3d_w = smplh_opt.joints

    # world -> camera -> image plane
    T_w2c, K = data["T_w2c"], data["K"]
    joints_3d_c = apply_T_on_points(joints_3d_w, T_w2c[None])
    joints_2d = project_p2d(joints_3d_c, K[None])[0]

    # One confidence per joint, sized from the joints instead of a fixed 73 and
    # created with the dtype / device of the joints.
    conf = joints_2d.new_ones((joints_2d.shape[0], 1))
    keypoints = torch.cat([joints_2d, conf], dim=1)
    return keypoints
def squared_crop_and_resize(dataset, img, bbx_lurb, dst_size=224, state=None):
    """Crop a randomly shifted and zoomed square around a bbx and resize it.

    Args:
        dataset: provides BBX_CENTER (max center shift) and BBX_ZOOM (max extra zoom)
        img: image passed to cv2.warpAffine
        bbx_lurb: array (left, up, right, bottom)
        dst_size: side length of the output crop
        state: optional numpy random state, set before sampling for reproducibility
    Returns:
        img_crop: the (dst_size, dst_size) crop
        bbx_new: the square bbx that was cropped, same dtype as bbx_lurb
        A: the affine transform from the image to the crop
    """
    if state is not None:
        np.random.set_state(state)

    left, up, right, bottom = bbx_lurb[0], bbx_lurb[1], bbx_lurb[2], bbx_lurb[3]

    # First random draw: shift of the center
    jitter = dataset.BBX_CENTER * (np.random.random(2) * 2 - 1)
    cx = (left + right) / 2 + jitter[0]
    cy = (up + bottom) / 2 + jitter[1]

    # Second random draw: zoom of the half size
    half = max(right - left, bottom - up) / 2
    half *= 1 + 0.15 + dataset.BBX_ZOOM * np.random.random()  # zoom

    x0, y0 = cx - half, cy - half
    x1, y1 = cx + half, cy + half

    # Affine transform defined by three points: top-left, top-right and center
    src = np.array([[x0, y0], [x1, y0], [cx, cy]], dtype=np.float32)
    dst_center = dst_size / 2 - 0.5
    dst = np.array([[0, 0], [dst_size - 1, 0], [dst_center, dst_center]], dtype=np.float32)
    A = cv2.getAffineTransform(src, dst)
    img_crop = cv2.warpAffine(img, A, (dst_size, dst_size), flags=cv2.INTER_LINEAR)

    bbx_new = np.array([x0, y0, x1, y1], dtype=bbx_lurb.dtype)
    return img_crop, bbx_new, A
# Augment bbx
def get_augmented_square_bbx(bbx_lurb, per_shift=0.1, per_zoomout=0.2, base_zoomout=0.15, state=None):
    """
    Turn a bbx into a randomly shifted and enlarged square bbx.
    Args:
        bbx_lurb: (left, up, right, bottom)
        per_shift: in percent, maximum random shift
        per_zoomout: in percent, maximum random zoom
        base_zoomout: in percent, zoom that is always applied
        state: optional numpy random state, set before sampling for reproducibility
    Returns:
        numpy array (left, up, right, bottom) of the augmented square
    """
    if state is not None:
        np.random.set_state(state)

    left, up, right, bottom = bbx_lurb[0], bbx_lurb[1], bbx_lurb[2], bbx_lurb[3]
    side = max(right - left, bottom - up)

    # First random draw: shift of the center, relative to the longer side
    offset = side * per_shift * (np.random.random(2) * 2 - 1)
    cx = (left + right) / 2 + offset[0]
    cy = (up + bottom) / 2 + offset[1]

    # Second random draw: zoom-out of the half size
    zoom = 1 + base_zoomout + per_zoomout * np.random.random()
    half = (side / 2) * zoom

    return np.array([cx - half, cy - half, cx + half, cy + half])
def get_squared_bbx_region_and_resize(frames, bbx_xys, dst_size=224):
    """
    Crop the square bbx region of every frame and resize it to (dst_size, dst_size).
    Args:
        frames: (F, H, W, 3)
        bbx_xys: (F, 3), xys, i.e. center x, center y and size
        dst_size: side length of the output crops
    Returns:
        img_crops: tensor of the stacked crops
        As: tensor of the stacked affine transforms, one per frame
    """
    frames_np = frames.numpy() if isinstance(frames, torch.Tensor) else frames
    if not isinstance(bbx_xys, torch.Tensor):
        bbx_xys = torch.tensor(bbx_xys)  # use tensor

    # Source points of the affine transform: top-left, top-right and center
    cx, cy = bbx_xys[:, 0], bbx_xys[:, 1]
    half = bbx_xys[:, 2] / 2
    top_left = torch.stack([cx - half, cy - half], dim=-1)
    top_right = torch.stack([cx + half, cy - half], dim=-1)
    srcs = torch.stack([top_left, top_right, bbx_xys[:, :2]], dim=1)  # (F, 3, 2)

    dst_center = dst_size / 2 - 0.5
    dst = np.array([[0, 0], [dst_size - 1, 0], [dst_center, dst_center]], dtype=np.float32)

    As = np.stack([cv2.getAffineTransform(src, dst) for src in srcs.numpy()])
    crops = []
    for i, A in enumerate(As):
        crops.append(cv2.warpAffine(frames_np[i], A, (dst_size, dst_size), flags=cv2.INTER_LINEAR))
    img_crops = torch.from_numpy(np.stack(crops))
    return img_crops, torch.from_numpy(As)
# ----- Camera utils ----- #
def extract_cam_xml(xml_path="", dtype=torch.float32):
    """Read the camera calibration stored in an XML file.

    Args:
        xml_path: path to the calibration XML; it must contain the elements
            CameraMatrix/data, Intrinsics/data and Distortion/data, each holding
            whitespace-separated numbers
        dtype: dtype of the returned tensors (default torch.float32)
    Returns:
        dict with flat 1D tensors:
            "ext_mat": values of CameraMatrix/data
            "int_mat": values of Intrinsics/data
            "dis_vec": values of Distortion/data
    """
    import xml.etree.ElementTree as ET

    tree = ET.parse(xml_path)

    def _read_values(tag):
        # The requested dtype is applied here; it used to be ignored in favour of float32
        values = [float(s) for s in tree.find(f"./{tag}/data").text.split()]
        return torch.tensor(values, dtype=dtype)

    return {
        "ext_mat": _read_values("CameraMatrix"),
        "int_mat": _read_values("Intrinsics"),
        "dis_vec": _read_values("Distortion"),
    }
def get_cam2params(scene_info_root=None):
    """Return a mapping cam_key -> (T_w2c, K).

    Args:
        scene_info_root: this could be replaced by path to scan_calibration. When given,
            every ``*/calibration/*.xml`` below it is parsed; otherwise the cached
            ``resource/cam2params.pt`` next to this file is loaded.

    Returns:
        dict whose keys are ``f"{capture_name}_{cam_id}"`` and whose values are
        tuples of a (4, 4) world-to-camera matrix and a (3, 3) intrinsics matrix
        (when built from XML files).
    """
    if scene_info_root is None:
        return torch.load(Path(__file__).parent / "resource/cam2params.pt")

    cam_params = {}
    for xml_fn in Path(scene_info_root).glob("*/calibration/*.xml"):
        calib = extract_cam_xml(xml_fn)
        # Append the homogeneous row to turn the (3, 4) extrinsics into (4, 4).
        T_w2c = torch.cat([calib["ext_mat"].reshape(3, 4), torch.tensor([[0, 0, 0, 1.0]])], dim=0)
        K = calib["int_mat"].reshape(3, 3)
        # Path layout: <capture_name>/calibration/<cam_id>.xml
        capture = xml_fn.parts[-3]
        cam_idx = int(xml_fn.stem)
        cam_params[f"{capture}_{cam_idx}"] = (T_w2c, K)
    return cam_params
# ----- Parse Raw Resource ----- #
def get_w2az_sahmr():
    """Load ``resource/w2az_sahmr.json`` (next to this file) as tensors.

    Returns:
        w2az_sahmr: dict, {scan_name: Tw2az}, Tw2az is a tensor of (4,4)
    """
    json_fn = Path(__file__).parent / "resource/w2az_sahmr.json"
    with open(json_fn, "r") as fp:
        raw = json.load(fp)

    w2az_sahmr = {}
    for scan_name, matrix in raw.items():
        w2az_sahmr[scan_name] = torch.tensor(matrix)
    return w2az_sahmr
def has_multi_persons(seq_name):
    """Tell whether a sequence name does not have exactly three "_"-separated fields.

    Args:
        seq_name: e.g. LectureHall_009_021_reparingprojector1 (four fields -> True)
    """
    n_fields = seq_name.count("_") + 1
    return n_fields != 3
def parse_seqname_info(skip_multi_persons=True):
    """Parse the train/val/test split files in ``resource/`` next to this file.

    When ``skip_multi_persons`` is true, sequences flagged by ``has_multi_persons`` are dropped.

    Returns:
        sname_to_info: {seq_name: (scan_name, subject_id, gender, cam_ids)}
    """
    # Train / Val&Test Header:
    # sequence_name capture_name scan_name id moving_cam gender view_id
    # sequence_name capture_name scan_name id moving_cam gender scene action/scene-interaction subjects view_id
    resource_parent = Path(__file__).parent
    sname_to_info = {}
    for split in ["train", "val", "test"]:
        with open(resource_parent / f"resource/{split}.txt", "r") as f:
            rows = f.readlines()[1:]  # first line is the header
        for row in rows:
            fields = row.strip().split()
            seq_name = fields[0]
            if skip_multi_persons and has_multi_persons(seq_name):
                continue
            scan_name = f"{fields[1]}_{fields[2]}"
            subject_id = int(fields[3])
            gender = fields[5]
            # view_id is always the last column, a comma-separated list of camera ids.
            cam_ids = [int(c) for c in fields[-1].split(",")]
            sname_to_info[seq_name] = (scan_name, subject_id, gender, cam_ids)
    return sname_to_info
def get_seqnames_of_split(splits=["train"], skip_multi_persons=True):
    """List sequence names (first column) of the given split files in ``resource/``.

    Args:
        splits: a list of split names, or a single non-list value that is wrapped into a list.
        skip_multi_persons: drop names for which ``has_multi_persons`` is true.
    """
    split_list = splits if isinstance(splits, list) else [splits]
    split_files = [Path(__file__).parent / f"resource/{split}.txt" for split in split_list]

    seqnames = []
    for split_file in split_files:
        with open(split_file, "r") as f:
            rows = f.readlines()[1:]  # skip header line
            for row in rows:
                name = row.strip().split()[0]
                if not (skip_multi_persons and has_multi_persons(name)):
                    seqnames.append(name)
    return seqnames
def get_seqname_to_imgrange():
    """Each sequence has a different range of image ids.

    Scans ``inputs/RICH/images_ds4/<split>/<seqname>`` recursively for ``*.jpeg`` files and
    returns {seqname: (first_id, last_id)}, where ids are the integer prefix before the
    first "_" of the lexicographically first/last file name; (0, 0) when no image exists.
    """
    from tqdm import tqdm

    splits = ["train", "val", "test"]
    split_seqnames = {split: get_seqnames_of_split(split) for split in splits}

    seqname_to_imgrange = {}
    for split in splits:
        for seqname in tqdm(split_seqnames[split]):
            split_root = Path("inputs/RICH") / "images_ds4" / split  # compressed (not original)
            frame_names = sorted([p.name for p in (split_root / seqname).glob("**/*.jpeg")])
            if frame_names:
                first_id = int(frame_names[0].split("_")[0])
                last_id = int(frame_names[-1].split("_")[0])
                seqname_to_imgrange[seqname] = (first_id, last_id)
            else:
                seqname_to_imgrange[seqname] = (0, 0)
    return seqname_to_imgrange
# ----- Compose keys ----- #
def get_img_key(seq_name, cam_id, f_id):
    """Compose ``<seq_name>_<cam_id>_<5-digit f_id>_<subject_id>``.

    ``seq_name`` must consist of exactly three "_"-separated fields; the middle one is
    parsed as the integer subject id.
    """
    fields = seq_name.split("_")
    assert len(fields) == 3
    subject_id = int(fields[1])
    cam_part = int(cam_id)
    frame_part = int(f_id)
    return f"{seq_name}_{cam_part}_{frame_part:05d}_{subject_id}"
def get_seq_cam_fn(img_root, seq_name, cam_id):
    """Return the per-camera image directory ``<img_root>/<seq_name>/cam_<2-digit cam_id>`` as str.

    Args:
        img_root: "inputs/RICH/images_ds4/train"
    """
    cam_dir = f"{seq_name}/cam_{int(cam_id):02d}"
    return str(Path(img_root) / cam_dir)
def get_img_fn(img_root, seq_name, cam_id, f_id):
    """Return the image path ``<img_root>/<seq_name>/cam_<cc>/<fffff>_<cc>.jpeg`` as str.

    Args:
        img_root: "inputs/RICH/images_ds4/train"
    """
    cam = int(cam_id)
    frame = int(f_id)
    cam_dir = Path(img_root) / f"{seq_name}/cam_{cam:02d}"
    return str(cam_dir / f"{frame:05d}_{cam:02d}.jpeg")
# ----- WHAM ----- #
def get_cam_key_wham_vid(vid):
    """Build ``<scene>_<cam_id>`` from a vid of the form ``<x>/<seq_name>/<prefix>_<cam_id>...``.

    The scene is the part of the sequence name before its first "_"; the camera id is the
    second "_"-separated field of the last path component, parsed as int.
    """
    _, seq_part, cam_part = vid.split("/")
    scene, _, _ = seq_part.partition("_")
    cam_idx = int(cam_part.split("_")[1])
    return f"{scene}_{cam_idx}"
def get_K_wham_vid(vid):
    """Look up the intrinsics (second tuple entry of the cached camera params) for ``vid``."""
    key = get_cam_key_wham_vid(vid)
    return get_cam2params()[key][1]
class RichVid2Tc2az:
    """Callable mapping a vid to the camera-to-az transform ``T_w2az @ inverse(T_w2c)``."""

    def __init__(self) -> None:
        # scan_name -> tensor (4, 4)
        self.w2az = get_w2az_sahmr()
        # seq_name -> (scan_name, subject_id, gender, cam_ids); only scan_name is kept
        info_by_seq = parse_seqname_info(skip_multi_persons=True)
        self.seqname_to_scanname = {seq: info[0] for seq, info in info_by_seq.items()}
        # cam_key -> (T_w2c, K)
        self.cam2params = get_cam2params()

    def __call__(self, vid):
        """Return T_c2az for ``vid`` (sequence name is the second "/"-separated field)."""
        cam_key = get_cam_key_wham_vid(vid)
        seq_name = vid.split("/")[1]
        scan_name = self.seqname_to_scanname[seq_name]
        T_w2c, _ = self.cam2params[cam_key]  # (4, 4) (3, 3)
        T_w2az = self.w2az[scan_name]
        return T_w2az @ T_w2c.inverse()

    def get_T_w2az(self, vid):
        """Return the world-to-az transform of the scan that ``vid`` belongs to."""
        # Result is unused; the call is kept because it raises on a malformed vid.
        get_cam_key_wham_vid(vid)
        seq_name = vid.split("/")[1]
        return self.w2az[self.seqname_to_scanname[seq_name]]
================================================
FILE: eval/GVHMR/hmr4d/dataset/threedpw/threedpw_motion_test.py
================================================
import torch
from torch.utils import data
from pathlib import Path
from hmr4d.utils.pylogger import Log
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.utils.geo_transform import compute_cam_angvel
from hmr4d.utils.geo.hmr_cam import estimate_K, resize_K
from hmr4d.utils.geo.flip_utils import flip_kp2d_coco17
from hmr4d.configs import MainStore, builds
# Optional subset of video ids for fast testing: when non-empty,
# ThreedpwSmplFullSeqDataset indexes these ids instead of every labelled sequence.
VID_HARD = []
# VID_HARD = ["downtown_bar_00_1"]
class ThreedpwSmplFullSeqDataset(data.Dataset):
    """3DPW test set that yields one full sequence per item.

    Labels, bounding boxes, 2D keypoints and image features are read from
    ``inputs/3DPW/hmr4d_support``.

    Args:
        flip_test: also return inputs for the horizontally flipped sequence under
            ``data["flip_test"]``.
        skip_invalid: drop every frame whose ``mask_wham`` entry is False.
    """

    def __init__(self, flip_test=False, skip_invalid=False):
        super().__init__()
        self.dataset_name = "3DPW"
        self.skip_invalid = skip_invalid
        Log.info(f"[{self.dataset_name}] Full sequence")

        # Load evaluation protocol from WHAM labels
        self.threedpw_dir = Path("inputs/3DPW/hmr4d_support")
        # ['vname', 'K_fullimg', 'T_w2c', 'smpl_params', 'gender', 'mask_raw', 'mask_wham', 'img_wh']
        self.labels = torch.load(self.threedpw_dir / "test_3dpw_gt_labels.pt")
        self.vid2bbx = torch.load(self.threedpw_dir / "preproc_test_bbx.pt")
        self.vid2kp2d = torch.load(self.threedpw_dir / "preproc_test_kp2d_v0.pt")

        # Setup dataset index
        self.idx2meta = list(self.labels)
        if len(VID_HARD) > 0:  # Pick subsets for fast testing
            self.idx2meta = VID_HARD
        Log.info(f"[{self.dataset_name}] {len(self.idx2meta)} sequences.")

        # If flip_test is enabled, we will return extra data for flipped test
        self.flip_test = flip_test
        if self.flip_test:
            Log.info(f"[{self.dataset_name}] Flip test enabled")

    def __len__(self):
        return len(self.idx2meta)

    def _load_data(self, idx):
        """Collect labels, preprocessed inputs and rendering metadata for sequence ``idx``."""
        data = {}

        # [vid, start, end]
        vid = self.idx2meta[idx]
        meta = {"dataset_id": self.dataset_name, "vid": vid}
        data.update({"meta": meta})

        # Add useful data
        label = self.labels[vid]
        mask = label["mask_wham"]
        width_height = label["img_wh"]
        data.update(
            {
                "length": len(mask),  # F
                "smpl_params": label["smpl_params"],  # world
                "gender": label["gender"],  # str
                "T_w2c": label["T_w2c"],  # (F, 4, 4)
                "mask": mask,  # (F)
            }
        )
        K_fullimg = label["K_fullimg"]  # (3, 3)
        if False:  # disabled switch: use estimated intrinsics instead of the labelled ones
            K_fullimg = estimate_K(*width_height)
        data["K_fullimg"] = K_fullimg

        # Preprocessed:  bbx, kp2d, image as feature
        bbx_xys = self.vid2bbx[vid]["bbx_xys"]  # (F, 3)
        kp2d = self.vid2kp2d[vid]  # (F, 17, 3)
        cam_angvel = compute_cam_angvel(data["T_w2c"][:, :3, :3])  # (L, 6)
        data.update({"bbx_xys": bbx_xys, "kp2d": kp2d, "cam_angvel": cam_angvel})

        imgfeat_dir = self.threedpw_dir / "imgfeats/3dpw_test"
        f_img_dict = torch.load(imgfeat_dir / f"{vid}.pt")
        f_imgseq = f_img_dict["features"].float()
        data["f_imgseq"] = f_imgseq  # (F, 1024)

        # to render a video
        vname = label["vname"]
        video_path = self.threedpw_dir / f"videos/{vname}.mp4"
        frame_id = torch.where(mask)[0].long()
        ds = 0.5  # rendering scale; intrinsics, boxes and keypoints are scaled accordingly
        K_render = resize_K(K_fullimg, ds)
        bbx_xys_render = bbx_xys * ds
        kp2d_render = kp2d.clone()
        kp2d_render[..., :2] *= ds
        data["meta_render"] = {
            "name": vid,
            "video_path": str(video_path),
            "ds": ds,
            "frame_id": frame_id,
            "K": K_render,
            "bbx_xys": bbx_xys_render,
            "kp2d": kp2d_render,
        }

        if self.flip_test:
            imgfeat_dir = self.threedpw_dir / "imgfeats/3dpw_test_flip"
            f_img_dict = torch.load(imgfeat_dir / f"{vid}.pt")

            flipped_bbx_xys = f_img_dict["bbx_xys"].float()  # (L, 3)
            flipped_features = f_img_dict["features"].float()  # (L, 1024)
            flipped_kp2d = flip_kp2d_coco17(kp2d, width_height[0])  # (L, 17, 3)

            # Mirror the x axis of the world-to-camera rotation.
            R_flip_x = torch.tensor([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]).float()
            flipped_R_w2c = R_flip_x @ data["T_w2c"][:, :3, :3].clone()

            data_flip = {
                "bbx_xys": flipped_bbx_xys,
                "f_imgseq": flipped_features,
                "kp2d": flipped_kp2d,
                "cam_angvel": compute_cam_angvel(flipped_R_w2c),
            }
            data["flip_test"] = data_flip

        return data

    def _process_data(self, data):
        """Repeat the intrinsics per frame and, if requested, drop invalid frames.

        With ``skip_invalid`` every per-frame entry is indexed by ``data["mask"]``.
        The ``"flip_test"`` entry is filtered only when present, so ``skip_invalid=True``
        also works with ``flip_test=False`` (this used to raise ``KeyError``).
        """
        length = data["length"]
        data["K_fullimg"] = data["K_fullimg"][None].repeat(length, 1, 1)

        if self.skip_invalid:  # Drop all invalid frames
            mask = data["mask"].clone()
            data["length"] = sum(mask)
            data["smpl_params"] = {k: v[mask].clone() for k, v in data["smpl_params"].items()}
            for key in ["T_w2c", "mask", "K_fullimg", "bbx_xys", "kp2d", "cam_angvel", "f_imgseq"]:
                data[key] = data[key][mask].clone()
            if "flip_test" in data:  # only produced by _load_data when flip_test is enabled
                data["flip_test"] = {k: v[mask].clone() for k, v in data["flip_test"].items()}

        return data

    def __getitem__(self, idx):
        data = self._load_data(idx)
        data = self._process_data(data)
        return data
# 3DPW
# Register two configs in group "test_datasets/3dpw":
# "fliptest" builds the dataset with flip_test=True, "v1" with flip_test=False.
MainStore.store(
    name="fliptest",
    node=builds(ThreedpwSmplFullSeqDataset, flip_test=True),
    group="test_datasets/3dpw",
)
MainStore.store(
    name="v1",
    node=builds(ThreedpwSmplFullSeqDataset, flip_test=False),
    group="test_datasets/3dpw",
)
================================================
FILE: eval/GVHMR/hmr4d/dataset/threedpw/threedpw_motion_train.py
================================================
import torch
from torch.utils import data
from pathlib import Path
import numpy as np
from hmr4d.utils.pylogger import Log
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.utils.geo_transform import compute_cam_angvel
from hmr4d.utils.geo.hmr_cam import estimate_K, resize_K
from hmr4d.utils.geo.flip_utils import flip_kp2d_coco17
from hmr4d.dataset.imgfeat_motion.base_dataset import ImgfeatMotionDatasetBase
from hmr4d.utils.net_utils import get_valid_mask, repeat_to_max_len, repeat_to_max_len_dict
from hmr4d.utils.smplx_utils import make_smplx
from hmr4d.utils.video_io_utils import get_video_lwh, read_video_np, save_video
from hmr4d.utils.vis.renderer_utils import simple_render_mesh_background
from hmr4d.configs import MainStore, builds
class ThreedpwSmplDataset(ImgfeatMotionDatasetBase):
    """3DPW training clips built from refit SMPL-X parameters and cached image features.

    Every item is a random sub-clip (60 to 120 frames when the valid range allows it) that is
    repeated/padded to ``max_motion_frames`` so that items can be batched.
    """

    def __init__(self):
        # Path
        self.hmr4d_support_dir = Path("inputs/3DPW/hmr4d_support")
        self.dataset_name = "3DPW"

        # Setting
        self.min_motion_frames = 60
        self.max_motion_frames = 120

        # NOTE(review): attributes are set before the base initializer runs, presumably
        # because it invokes _load_dataset / _get_idx2meta — confirm in ImgfeatMotionDatasetBase.
        super().__init__()

    def _load_dataset(self):
        """Load labels and refit SMPL-X data, then override ranges of clips with known errors."""
        self.train_labels = torch.load(self.hmr4d_support_dir / "train_3dpw_gt_labels.pt")
        self.refit_smplx = torch.load(self.hmr4d_support_dir / "train_refit_smplx.pt")

        # Remove clips that have obvious error
        update_list = {
            "courtyard_basketball_00_1": [(0, 300), (340, 468)],
            "courtyard_laceShoe_00_0": [(0, 620), (780, 931)],
            "courtyard_rangeOfMotions_00_1": [(0, 370), (410, 601)],
            "courtyard_shakeHands_00_1": [(0, 100), (120, 391)],
        }
        for k, v in update_list.items():
            self.refit_smplx[k]["valid_range_list"] = v

        self.f_img_folder = self.hmr4d_support_dir / "imgfeats/3dpw_train_smplx_refit"
        Log.info(f"[{self.dataset_name}] Train")

    def _get_idx2meta(self):
        """Build ``self.idx2meta`` as a list of (vid, start, end) entries."""
        # We expect to see the entire sequence during one epoch,
        # so each sequence will be sampled max(SeqLength // MotionFrames, 1) times
        seq_lengths = []
        self.idx2meta = []
        for vid in self.refit_smplx:
            valid_range_list = self.refit_smplx[vid]["valid_range_list"]
            for start, end in valid_range_list:
                seq_length = end - start
                num_samples = max(seq_length // self.max_motion_frames, 1)
                seq_lengths.append(seq_length)
                self.idx2meta.extend([(vid, start, end)] * num_samples)
        minutes = sum(seq_lengths) / 25 / 60
        Log.info(
            f"[{self.dataset_name}] has {minutes:.1f} minutes motion -> Resampled to {len(self.idx2meta)} samples."
        )

    def _load_data(self, idx):
        """Sample a random sub-clip of entry ``idx`` and slice all per-frame data to it."""
        data = {}
        vid, range1, range2 = self.idx2meta[idx]

        # Random select a subset
        mlength = range2 - range1
        min_motion_len = self.min_motion_frames
        max_motion_len = self.max_motion_frames

        if mlength < min_motion_len:  # this may happen, the minimal mlength is around 30
            start = range1
            length = mlength
        else:
            effect_max_motion_len = min(max_motion_len, mlength)
            length = np.random.randint(min_motion_len, effect_max_motion_len + 1)  # [low, high)
            start = np.random.randint(range1, range2 - length + 1)
        end = start + length
        data["length"] = length
        data["meta"] = {"data_name": self.dataset_name, "idx": idx, "vid": vid, "start_end": (start, end)}

        # Select motion subset
        data["smplx_params_incam"] = {k: v[start:end] for k, v in self.refit_smplx[vid]["smplx_params_incam"].items()}
        data["K_fullimg"] = self.train_labels[vid]["K_fullimg"]
        data["T_w2c"] = self.train_labels[vid]["T_w2c"][start:end]

        # Img (as feature):
        f_img_dict = torch.load(self.f_img_folder / f"{vid}.pt")
        data["bbx_xys"] = f_img_dict["bbx_xys"][start:end]  # (F, 3)
        data["f_imgseq"] = f_img_dict["features"][start:end].float()  # (F, D)
        data["img_wh"] = f_img_dict["img_wh"]  # (2)
        data["kp2d"] = torch.zeros((end - start), 17, 3)  # (L, 17, 3)  # do not provide kp2d

        return data

    def _process_data(self, data, idx):
        """Assemble the training sample and repeat per-frame tensors up to ``max_motion_frames``.

        World-frame quantities (``smpl_params_w``, ``R_c2gv``, ``gravity_vec``) are zero
        placeholders and ``mask["spv_incam_only"]`` is set to True.
        """
        length = data["length"]

        smpl_params_c = data["smplx_params_incam"]
        smpl_params_w_zero = {k: torch.zeros_like(v) for k, v in smpl_params_c.items()}

        K_fullimg = data["K_fullimg"][None].repeat(length, 1, 1)
        cam_angvel = compute_cam_angvel(data["T_w2c"][:, :3, :3])

        max_len = self.max_motion_frames
        return_data = {
            "meta": data["meta"],
            "length": length,
            "smpl_params_c": smpl_params_c,
            "smpl_params_w": smpl_params_w_zero,
            "R_c2gv": torch.zeros(length, 3, 3),  # (F, 3, 3)
            "gravity_vec": torch.zeros(3),  # (3)
            "bbx_xys": data["bbx_xys"],  # (F, 3)
            "K_fullimg": K_fullimg,  # (F, 3, 3)
            "f_imgseq": data["f_imgseq"],  # (F, D)
            "kp2d": data["kp2d"],  # (F, 17, 3)
            "cam_angvel": cam_angvel,  # (F, 6)
            "mask": {
                "valid": get_valid_mask(max_len, length),
                "vitpose": False,
                "bbx_xys": True,
                "f_imgseq": True,
                "spv_incam_only": True,
            },
        }

        if False:  # Debug, render incam
            start, end = data["meta"]["start_end"]
            vid = data["meta"]["vid"]
            ds = 0.5
            # The body model must exist before its faces are read; the previous order
            # would raise UnboundLocalError as soon as this branch was enabled.
            smplx = make_smplx("supermotion")
            faces = smplx.faces
            smplx_c_verts = smplx(**return_data["smpl_params_c"]).vertices
            K_render = resize_K(K_fullimg, ds)
            video_path = self.hmr4d_support_dir / f"videos/{vid[:-2]}.mp4"
            images = read_video_np(video_path, scale=ds, start_frame=start, end_frame=end)
            render_dict = {
                "K": K_render[:1],  # only support batch size 1
                "faces": faces,
                "verts": smplx_c_verts,
                "background": images,
            }
            img_overlay = simple_render_mesh_background(render_dict, VI=10)
            save_video(img_overlay, f"tmp.mp4", crf=28)

        # Batchable
        return_data["smpl_params_c"] = repeat_to_max_len_dict(return_data["smpl_params_c"], max_len)
        return_data["smpl_params_w"] = repeat_to_max_len_dict(return_data["smpl_params_w"], max_len)
        for key in ["R_c2gv", "bbx_xys", "K_fullimg", "f_imgseq", "kp2d", "cam_angvel"]:
            return_data[key] = repeat_to_max_len(return_data[key], max_len)
        return return_data
# 3DPW
# Register the training dataset config as "v1" in group "train_datasets/imgfeat_3dpw".
MainStore.store(name="v1", node=builds(ThreedpwSmplDataset), group="train_datasets/imgfeat_3dpw")
================================================
FILE: eval/GVHMR/hmr4d/dataset/threedpw/utils.py
================================================
import json
import numpy as np
from pathlib import Path
from collections import defaultdict
import pickle
import torch
import joblib
# Absolute path of the "resource" directory that sits next to this module.
RESOURCE_FOLDER = Path(__file__).resolve().parent / "resource"
def read_raw_pkl(pkl_path):
    """Read a raw sequence pickle and convert its arrays to torch tensors.

    The pickle is loaded with ``encoding="bytes"``, so dictionary keys are bytes and
    Python-2 ``str`` values come back as ``bytes``. Gender codes are therefore decoded
    before being compared with ``"m"``; otherwise every subject would be labelled "female".

    Args:
        pkl_path: path of the pickle file.

    Returns:
        dict with keys "sequence", "K_fullimg", "T_w2c", "smpl_params" (one dict per
        subject), "genders" (list of "male"/"female") and "campose_valid".
    """
    with open(pkl_path, "rb") as f:
        # NOTE(security): pickle can execute arbitrary code; only load trusted dataset files.
        data = pickle.load(f, encoding="bytes")

    num_subjects = len(data[b"poses"])
    F = data[b"poses"][0].shape[0]  # frame count, taken from the first subject
    smpl_params = []
    for i in range(num_subjects):
        poses = data[b"poses"][i]
        smpl_params.append(
            {
                "body_pose": torch.from_numpy(poses[:, 3:72]).float(),  # (F, 69)
                "betas": torch.from_numpy(data[b"betas"][i][:10]).repeat(F, 1).float(),  # (F, 10)
                "global_orient": torch.from_numpy(poses[:, :3]).float(),  # (F, 3)
                "transl": torch.from_numpy(data[b"trans"][i]).float(),  # (F, 3)
            }
        )

    genders = []
    for g in data[b"genders"]:
        if isinstance(g, bytes):  # Python-2 str loaded with encoding="bytes"
            g = g.decode()
        genders.append("male" if g == "m" else "female")

    campose_valid = [torch.from_numpy(v).bool() for v in data[b"campose_valid"]]

    # NOTE(review): may be bytes for the same reason as the gender codes; left untouched
    # because callers' expectations are not visible here.
    seq_name = data[b"sequence"]
    K_fullimg = torch.from_numpy(data[b"cam_intrinsics"]).float()
    T_w2c = torch.from_numpy(data[b"cam_poses"]).float()

    return_data = {
        "sequence": seq_name,  # 'courtyard_bodyScannerMotions_00'
        "K_fullimg": K_fullimg,  # (3, 3), not 55FoV
        "T_w2c": T_w2c,  # (F, 4, 4)
        "smpl_params": smpl_params,  # list of dict
        "genders": genders,  # list of str
        "campose_valid": campose_valid,  # list of bool-array
        # "jointPositions": data[b'jointPositions'],  # SMPL, 24x3
        # "poses2d": data[b"poses2d"],  # COCO, 3x18(?)
    }
    return return_data
def load_and_convert_wham_pth(pth):
    """
    Convert to {vid: DataDict} style, Add smpl_params_incam

    The file is loaded with joblib and is expected to map every key to a per-video
    sequence, with the video ids stored under "vid".
    """
    raw_columns = joblib.load(pth)

    # Column layout -> one dict per video id
    wham_labels = {}
    for row, vid in enumerate(raw_columns["vid"]):
        wham_labels[vid] = {key: raw_columns[key][row] for key in raw_columns}

    # Split pose into global orientation / body pose (without transl)
    for entry in wham_labels.values():
        full_pose = entry["pose"]
        root_orient = full_pose[:, :3]  # (F, 3)
        joints_pose = full_pose[:, 3:]  # (F, 69)
        shape = entry["betas"]  # (F, 10), all frames are the same
        entry["smpl_params_incam"] = {
            "body_pose": joints_pose.float(),  # (F, 69)
            "betas": shape.float(),  # (F, 10)
            "global_orient": root_orient.float(),  # (F, 3)
        }

    return wham_labels
# Neural-Annot utils
def na_cam_param_to_K_fullimg(cam_param):
K = torch.eye(3)
K[[0, 1], [0, 1]] = torch.tensor(cam_param["focal"])
K[[0, 1], [2, 2]] = torch.tensor(cam_param["princpt"])
return K
================================================
FILE: eval/GVHMR/hmr4d/model/common_utils/optimizer.py
================================================
from torch.optim import AdamW, Adam
from hmr4d.configs import MainStore, builds
# Named optimizer presets: key -> config built for Adam/AdamW with a fixed learning rate.
# NOTE(review): zen_partial=True presumably defers the remaining constructor arguments
# (e.g. the model parameters) to instantiation time — confirm against hydra-zen docs.
optimizer_cfgs = {
    "adam_1e-3": builds(Adam, lr=1e-3, zen_partial=True),
    "adam_2e-4": builds(Adam, lr=2e-4, zen_partial=True),
    "adamw_2e-4": builds(AdamW, lr=2e-4, zen_partial=True),
    "adamw_1e-4": builds(AdamW, lr=1e-4, zen_partial=True),
    "adamw_5e-5": builds(AdamW, lr=5e-5, zen_partial=True),
    "adamw_1e-5": builds(AdamW, lr=1e-5, zen_partial=True),
    # zero-shot text-to-image generation
    "adamw_1e-3_dalle": builds(AdamW, lr=1e-3, weight_decay=1e-4, zen_partial=True),
}

# Register every preset in MainStore under group "optimizer", keyed by its preset name.
for name, cfg in optimizer_cfgs.items():
    MainStore.store(name=name, node=cfg, group=f"optimizer")
================================================
FILE: eval/GVHMR/hmr4d/model/common_utils/scheduler.py
================================================
import torch
from bisect import bisect_right
class WarmupMultiStepLR(torch.optim.lr_scheduler.LRScheduler):
    """Multi-step learning-rate decay preceded by a linear warmup.

    During the first ``warmup`` epochs the rate is ``base_lr * (epoch + 1) / (warmup + 1)``;
    afterwards it is ``base_lr * gamma ** k`` where ``k`` is the number of milestones
    that are <= the current epoch.
    """

    def __init__(self, optimizer, milestones, warmup=0, gamma=0.1, last_epoch=-1, verbose="deprecated"):
        """Assume optimizer does not change lr; Scheduler is called epoch-based

        Args:
            optimizer: wrapped optimizer.
            milestones: increasing epoch indices at which the rate is multiplied by ``gamma``.
            warmup: number of warmup epochs; must be smaller than the first milestone.
            gamma: multiplicative decay factor.
            last_epoch: index of the last epoch, forwarded to the base scheduler.
            verbose: legacy base-class flag; forwarded only when the caller sets it.
        """
        self.milestones = milestones
        self.warmup = warmup
        assert warmup < milestones[0]
        self.gamma = gamma
        if verbose == "deprecated":
            # Do not forward the sentinel: older torch would treat the string as a truthy
            # `verbose` flag, and a base class without the argument would reject it.
            super().__init__(optimizer, last_epoch)
        else:
            super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the learning rate of every param group for the coming epoch."""
        comming_epoch = self.last_epoch  # the lr will be set for the comming epoch, starts from 0

        # add extra warmup
        if comming_epoch < self.warmup:
            # e.g. comming_epoch [0, 1, 2] for warmup == 3
            # lr should be base_lr * (last_epoch+1) / (warmup + 1), e.g. [0.25, 0.5, 0.75] * base_lr
            lr_factor = (comming_epoch + 1) / (self.warmup + 1)
        else:
            # bisect_right([3,5,7], 0) -> 0; bisect_right([3,5,7], 5) -> 2
            lr_factor = self.gamma ** bisect_right(self.milestones, comming_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]
================================================
FILE: eval/GVHMR/hmr4d/model/common_utils/scheduler_cfg.py
================================================
from omegaconf import DictConfig, ListConfig
from hmr4d.configs import MainStore, builds
# do not perform scheduling
# Registered as "default" in group "scheduler_cfg"; its "scheduler" entry is None.
default = DictConfig({"scheduler": None})
MainStore.store(name="default", node=default, group=f"scheduler_cfg")
# epoch-based
def epoch_half_by(milestones=[100, 200, 300]):
    """Config for torch's MultiStepLR that halves the rate at each milestone, stepped per epoch."""
    scheduler_node = {
        "_target_": "torch.optim.lr_scheduler.MultiStepLR",
        "milestones": milestones,
        "gamma": 0.5,
    }
    cfg = {"scheduler": scheduler_node, "interval": "epoch", "frequency": 1}
    return DictConfig(cfg)
# Presets in group "scheduler_cfg"; the name lists the milestone epochs passed to epoch_half_by.
MainStore.store(name="epoch_half_100_200_300", node=epoch_half_by([100, 200, 300]), group=f"scheduler_cfg")
MainStore.store(name="epoch_half_100_200", node=epoch_half_by([100, 200]), group=f"scheduler_cfg")
MainStore.store(name="epoch_half_200_350", node=epoch_half_by([200, 350]), group=f"scheduler_cfg")
MainStore.store(name="epoch_half_300", node=epoch_half_by([300]), group=f"scheduler_cfg")
# epoch-based
def warmup_epoch_half_by(warmup=10, milestones=[100, 200, 300]):
    """Config for WarmupMultiStepLR: ``warmup`` warmup epochs, then halving at each milestone."""
    scheduler_node = {
        "_target_": "hmr4d.model.common_utils.scheduler.WarmupMultiStepLR",
        "milestones": milestones,
        "warmup": warmup,
        "gamma": 0.5,
    }
    cfg = {"scheduler": scheduler_node, "interval": "epoch", "frequency": 1}
    return DictConfig(cfg)
# Warmup presets in group "scheduler_cfg": 5 or 10 warmup epochs, milestones at 200 and 350.
MainStore.store(name="warmup_5_epoch_half_200_350", node=warmup_epoch_half_by(5, [200, 350]), group=f"scheduler_cfg")
MainStore.store(name="warmup_10_epoch_half_200_350", node=warmup_epoch_half_by(10, [200, 350]), group=f"scheduler_cfg")
================================================
FILE: eval/GVHMR/hmr4d/model/gvhmr/callbacks/metric_3dpw.py
================================================
import torch
import pytorch_lightning as pl
import numpy as np
from pathlib import Path
from einops import einsum, rearrange
from hmr4d.configs import MainStore, builds
from hmr4d.utils.pylogger import Log
from hmr4d.utils.comm.gather import all_gather
from hmr4d.utils.eval.eval_utils import compute_camcoord_metrics, as_np_array
from hmr4d.utils.smplx_utils import make_smplx
from hmr4d.utils.vis.cv2_utils import cv2, draw_bbx_xys_on_image_batch, draw_coco17_skeleton_batch
from hmr4d.utils.vis.renderer_utils import simple_render_mesh_background
from hmr4d.utils.video_io_utils import read_video_np, get_video_lwh, save_video
from hmr4d.utils.geo_transform import apply_T_on_points
from hmr4d.utils.seq_utils import rearrange_by_mask
class MetricMocap(pl.Callback):
    """Lightning callback that evaluates camera-coordinate metrics on 3DPW sequences.

    Per-sequence results are stored in ``self.metric_aggregator`` (metric name -> {vid: values})
    at batch end, then gathered across processes, logged and reset at epoch end.
    Batches whose ``dataset_id`` is not "3DPW" are ignored.
    """

    def __init__(self):
        super().__init__()
        # vid->result
        self.metric_aggregator = {
            "pa_mpjpe": {},
            "mpjpe": {},
            "pve": {},
            "accel": {},
        }

        # SMPLX and SMPL
        # NOTE: the regressor files are loaded from paths relative to the current working directory.
        self.smplx = make_smplx("supermotion_EVAL3DPW")
        self.smpl = {"male": make_smplx("smpl", gender="male"), "female": make_smplx("smpl", gender="female")}
        self.J_regressor = torch.load("hmr4d/utils/body_model/smpl_3dpw14_J_regressor_sparse.pt").to_dense()
        self.J_regressor24 = torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt")
        self.smplx2smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt")
        self.faces_smplx = self.smplx.faces
        self.faces_smpl = self.smpl["male"].faces

        # The metrics are calculated similarly for val/test/predict
        # (the predict hooks are aliased onto the instance as the test/validation hooks)
        self.on_test_batch_end = self.on_validation_batch_end = self.on_predict_batch_end

        # Only validation record the metrics with logger
        self.on_test_epoch_end = self.on_validation_epoch_end = self.on_predict_epoch_end

    # ================== Batch-based Computation ================== #
    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """The behaviour is the same for val/test/predict

        Computes the metrics of the single sequence in ``batch`` (batch size must be 1)
        from ``outputs["pred_smpl_params_incam"]`` and stores them under the sequence's vid.
        """
        assert batch["B"] == 1
        dataset_id = batch["meta"][0]["dataset_id"]
        if dataset_id != "3DPW":
            return

        # Move to cuda if not
        self.smplx = self.smplx.cuda()
        for g in ["male", "female"]:
            self.smpl[g] = self.smpl[g].cuda()
        self.J_regressor = self.J_regressor.cuda()
        self.J_regressor24 = self.J_regressor24.cuda()
        self.smplx2smpl = self.smplx2smpl.cuda()

        vid = batch["meta"][0]["vid"]
        seq_length = batch["length"][0].item()
        gender = batch["gender"][0]
        T_w2c = batch["T_w2c"][0]
        mask = batch["mask"][0]

        # Groundtruth (cam)
        # World-frame SMPL vertices of the matching gender, transformed by T_w2c.
        target_w_params = {k: v[0] for k, v in batch["smpl_params"].items()}
        target_w_output = self.smpl[gender](**target_w_params)
        target_w_verts = target_w_output.vertices
        target_c_verts = apply_T_on_points(target_w_verts, T_w2c)
        target_c_j3d = torch.matmul(self.J_regressor, target_c_verts)

        # + Prediction -> Metric
        # Predicted SMPL-X vertices are mapped with smplx2smpl before joint regression.
        smpl_out = self.smplx(**outputs["pred_smpl_params_incam"])
        pred_c_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in smpl_out.vertices])
        pred_c_j3d = einsum(self.J_regressor, pred_c_verts, "j v, l v i -> l j i")
        del smpl_out  # Prevent OOM

        # Metric of current sequence
        batch_eval = {
            "pred_j3d": pred_c_j3d,
            "target_j3d": target_c_j3d,
            "pred_verts": pred_c_verts,
            "target_verts": target_c_verts,
        }
        camcoord_metrics = compute_camcoord_metrics(batch_eval, mask=mask, pelvis_idxs=[2, 3])
        for k in camcoord_metrics:
            self.metric_aggregator[k][vid] = as_np_array(camcoord_metrics[k])

        # The two branches below are disabled debug visualisations.
        if False:  # Render incam (simple)
            meta_render = batch["meta_render"][0]
            images = read_video_np(meta_render["video_path"], scale=meta_render["ds"])
            render_dict = {
                "K": meta_render["K"][None],  # only support batch size 1
                "faces": self.smpl["male"].faces,
                "verts": pred_c_verts,
                "background": images,
            }
            img_overlay = simple_render_mesh_background(render_dict)
            output_fn = Path("outputs/3DPW_render_pred_flip") / f"{vid}.mp4"
            save_video(img_overlay, output_fn, crf=28)

        if False:  # Render incam (with details)
            meta_render = batch["meta_render"][0]
            images = read_video_np(meta_render["video_path"], scale=meta_render["ds"])
            render_dict = {
                "K": meta_render["K"][None],  # only support batch size 1
                "faces": self.smpl["male"].faces,
                "verts": pred_c_verts,
                "background": images,
            }
            img_overlay = simple_render_mesh_background(render_dict)

            # Add COCO17 and bbx to image
            bbx_xys_render = meta_render["bbx_xys"]
            kp2d_render = meta_render["kp2d"]
            img_overlay = draw_coco17_skeleton_batch(img_overlay, kp2d_render, conf_thr=0.5)
            img_overlay = draw_bbx_xys_on_image_batch(bbx_xys_render, img_overlay, mask)

            # Add metric
            metric_all = rearrange_by_mask(torch.tensor(camcoord_metrics["pa_mpjpe"]), mask)
            for i in range(len(img_overlay)):
                m = metric_all[i]
                if m == 0:  # a not evaluated frame
                    continue
                text = f"PA-MPJPE: {m:.1f}"
                color = (244, 10, 20) if m > 45 else (0, 205, 0)  # red or green
                cv2.putText(img_overlay[i], text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

            output_dir = Path("tmp_pred_details")
            output_dir.mkdir(exist_ok=True, parents=True)
            save_video(img_overlay, output_dir / f"{vid}.mp4", crf=24)

    # ================== Epoch Summary  ================== #
    def on_predict_epoch_end(self, trainer, pl_module):
        """Without logger

        Gathers the per-sequence metrics from all processes, logs the monitored metric per
        sequence and the averages (on local rank 0), writes the averages to the module's
        logger when one exists, and finally clears the aggregator. Returns early, without
        clearing, when no sequence was evaluated.
        """
        local_rank, world_size = trainer.local_rank, trainer.world_size
        monitor_metric = "pa_mpjpe"

        # Reduce metric_aggregator across all processes
        metric_keys = list(self.metric_aggregator.keys())
        with torch.inference_mode(False):  # allow in-place operation of all_gather
            metric_aggregator_gathered = all_gather(self.metric_aggregator)  # list of dict
        for metric_key in metric_keys:
            for d in metric_aggregator_gathered:
                self.metric_aggregator[metric_key].update(d[metric_key])

        if False:  # debug to make sure the all_gather is correct
            print(f"[RANK {local_rank}/{world_size}]: {self.metric_aggregator[monitor_metric].keys()}")

        total = len(self.metric_aggregator[monitor_metric])
        Log.info(f"{total} sequences evaluated in {self.__class__.__name__}")
        if total == 0:
            return

        # print monitored metric per sequence
        mm_per_seq = {k: v.mean() for k, v in self.metric_aggregator[monitor_metric].items()}
        if len(mm_per_seq) > 0:
            # Largest value first; during validation only the first five entries are printed.
            sorted_mm_per_seq = sorted(mm_per_seq.items(), key=lambda x: x[1], reverse=True)
            n_worst = 5 if trainer.state.stage == "validate" else len(sorted_mm_per_seq)
            if local_rank == 0:
                Log.info(
                    f"monitored metric {monitor_metric} per sequence\n"
                    + "\n".join([f"{m:5.1f} : {s}" for s, m in sorted_mm_per_seq[:n_worst]])
                    + "\n------"
                )

        # average over all batches
        # (values of all sequences are concatenated first, so the mean is taken over all entries)
        metrics_avg = {k: np.concatenate(list(v.values())).mean() for k, v in self.metric_aggregator.items()}
        if local_rank == 0:
            Log.info(f"[Metrics] 3DPW:\n" + "\n".join(f"{k}: {v:.1f}" for k, v in metrics_avg.items()) + "\n------")

        # save to logger if available
        if pl_module.logger is not None:
            cur_epoch = pl_module.current_epoch
            for k, v in metrics_avg.items():
                pl_module.logger.log_metrics({f"val_metric_3DPW/{k}": v}, step=cur_epoch)

        # reset
        for k in self.metric_aggregator:
            self.metric_aggregator[k] = {}
# Register the 3DPW metric callback as "metric_3dpw" in group "callbacks"
# (package "callbacks.metric_3dpw").
node_3dpw = builds(MetricMocap)
MainStore.store(name="metric_3dpw", node=node_3dpw, group="callbacks", package="callbacks.metric_3dpw")
================================================
FILE: eval/GVHMR/hmr4d/model/gvhmr/callbacks/metric_emdb.py
================================================
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from hmr4d.configs import MainStore, builds
from hmr4d.utils.comm.gather import all_gather
from hmr4d.utils.pylogger import Log
from hmr4d.utils.eval.eval_utils import (
compute_camcoord_metrics,
compute_global_metrics,
compute_camcoord_perjoint_metrics,
rearrange_by_mask,
as_np_array,
)
from hmr4d.utils.geo_transform import apply_T_on_points, compute_T_ayfz2ay
from hmr4d.utils.smplx_utils import make_smplx
from einops import einsum, rearrange
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.utils.vis.renderer import Renderer, get_global_cameras_static
from hmr4d.utils.geo.hmr_cam import estimate_focal_length
from hmr4d.utils.video_io_utils import read_video_np, save_video
import imageio
from tqdm import tqdm
from pathlib import Path
import numpy as np
import cv2
class MetricMocap(pl.Callback):
def __init__(self, emdb_split=1):
"""
Args:
emdb_split: 1 to evaluate incam, 2 to evaluate global
"""
super().__init__()
# vid->result
if emdb_split == 1:
self.target_dataset_id = "EMDB_1"
self.metric_aggregator = {
"pa_mpjpe": {},
"mpjpe": {},
"pve": {},
"accel": {},
}
elif emdb_split == 2:
self.target_dataset_id = "EMDB_2"
self.metric_aggregator = {
"wa2_mpjpe": {},
"waa_mpjpe": {},
"rte": {},
"jitter": {},
"fs": {},
}
else:
raise ValueError(f"Unknown emdb_split: {emdb_split}")
# SMPL
self.smplx = make_smplx("supermotion")
self.smpl_model = {"male": make_smplx("smpl", gender="male"), "female": make_smplx("smpl", gender="female")}
self.J_regressor = torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt")
self.smplx2smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt")
self.faces_smpl = self.smpl_model["male"].faces
self.faces_smplx = self.smplx.faces
# The metrics are calculated similarly for val/test/predict
self.on_test_batch_end = self.on_validation_batch_end = self.on_predict_batch_end
# Only validation record the metrics with logger
self.on_test_epoch_end = self.on_validation_epoch_end = self.on_predict_epoch_end
# ================== Batch-based Computation ================== #
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
"""The behaviour is the same for val/test/predict"""
# Per-batch evaluation of a single sequence for the EMDB split chosen in __init__.
# Reads batch["meta"], batch["length"], batch["gender"], batch["T_w2c"], batch["mask"] and
# batch["smpl_params"], plus outputs["pred_smpl_params_incam"] (EMDB_1) or
# outputs["pred_smpl_params_global"] (EMDB_2).
# Writes the computed metrics to self.metric_aggregator[metric_name][vid]; returns nothing.
# Only one sequence per batch is supported.
assert batch["B"] == 1
dataset_id = batch["meta"][0]["dataset_id"]
# Batches coming from a different dataset/split are ignored by this callback.
if dataset_id != self.target_dataset_id:
return
# Move to cuda if not
self.smplx = self.smplx.cuda()
for g in ["male", "female"]:
self.smpl_model[g] = self.smpl_model[g].cuda()
self.J_regressor = self.J_regressor.cuda()
self.smplx2smpl = self.smplx2smpl.cuda()
# Index 0 everywhere below: the batch holds exactly one sequence (see the assert above).
vid = batch["meta"][0]["vid"]
seq_length = batch["length"][0].item()
gender = batch["gender"][0]
T_w2c = batch["T_w2c"][0]
mask = batch["mask"][0]
# Groundtruth (world, cam)
# GT vertices come from the gender-specific SMPL model; joints are regressed from the
# vertices with J_regressor, then both are transformed with T_w2c.
target_w_params = {k: v[0] for k, v in batch["smpl_params"].items()}
target_w_output = self.smpl_model[gender](**target_w_params)
target_w_verts = target_w_output.vertices
target_w_j3d = torch.matmul(self.J_regressor, target_w_verts)
target_c_verts = apply_T_on_points(target_w_verts, T_w2c)
target_c_j3d = apply_T_on_points(target_w_j3d, T_w2c)
# + Prediction -> Metric
if self.target_dataset_id == "EMDB_1": # in camera metrics
# 1. cam
# Predicted SMPL-X vertices are mapped through smplx2smpl frame by frame before
# regressing joints, so prediction and GT are compared on the same vertex set.
pred_smpl_params_incam = outputs["pred_smpl_params_incam"]
smpl_out = self.smplx(**pred_smpl_params_incam)
pred_c_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in smpl_out.vertices])
pred_c_j3d = einsum(self.J_regressor, pred_c_verts, "j v, l v i -> l j i")
del smpl_out # Prevent OOM
batch_eval = {
"pred_j3d": pred_c_j3d,
"target_j3d": target_c_j3d,
"pred_verts": pred_c_verts,
"target_verts": target_c_verts,
}
camcoord_metrics = compute_camcoord_metrics(batch_eval, mask=mask)
# Store every returned metric under this sequence id (converted with as_np_array).
for k in camcoord_metrics:
self.metric_aggregator[k][vid] = as_np_array(camcoord_metrics[k])
elif self.target_dataset_id == "EMDB_2": # global metrics
# 2. global (align-y axis)
pred_smpl_params_global = outputs["pred_smpl_params_global"]
smpl_out = self.smplx(**pred_smpl_params_global)
pred_ay_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in smpl_out.vertices])
pred_ay_j3d = einsum(self.J_regressor, pred_ay_verts, "j v, l v i -> l j i")
del smpl_out # Prevent OOM
# The prediction is compared directly against the world-frame GT here.
batch_eval = {
"pred_j3d_glob": pred_ay_j3d,
"target_j3d_glob": target_w_j3d,
"pred_verts_glob": pred_ay_verts,
"target_verts_glob": target_w_verts,
}
global_metrics = compute_global_metrics(batch_eval, mask=mask)
for k in global_metrics:
self.metric_aggregator[k][vid] = as_np_array(global_metrics[k])
# ---- Everything below is guarded by a literal `if False:` and never runs. ----
# NOTE(review): pred_c_* appear to be assigned only in the EMDB_1 branch and pred_ay_* only
# in the EMDB_2 branch, and self.metric_aggregator only has an "mpjpe" key for split 1, so
# some of these sections would hit undefined names / missing keys if enabled as-is.
if False: # wis3d debug
wis3d = make_wis3d(name="debug-emdb-incam")
pred_cr_j3d = pred_c_j3d - pred_c_j3d[:, [0]] # (L, J, 3)
target_cr_j3d = target_c_j3d - target_c_j3d[:, [0]] # (L, J, 3)
add_motion_as_lines(pred_cr_j3d, wis3d, name="pred_cr_j3d", const_color="blue")
add_motion_as_lines(target_cr_j3d, wis3d, name="target_cr_j3d", const_color="green")
if False: # Dump wis3d
vid = batch["meta"][0]["vid"]
split = batch["meta_render"][0]["split"]
wis3d = make_wis3d(name=f"dump_emdb{split}-{vid}")
R_cam_type = batch["meta_render"][0]["R_cam_type"]
pred_cr_j3d = pred_c_j3d - pred_c_j3d[:, [0]] # (L, J, 3)
target_cr_j3d = target_c_j3d - target_c_j3d[:, [0]] # (L, J, 3)
add_motion_as_lines(pred_cr_j3d, wis3d, name="pred_cr_j3d", const_color="blue")
add_motion_as_lines(target_cr_j3d, wis3d, name="target_cr_j3d", const_color="green")
add_motion_as_lines(pred_ay_j3d, wis3d, name=f"pred_ay_j3d@{R_cam_type}")
# add_motion_as_lines(target_w_j3d, wis3d, name="target_ay_j3d")
if False: # Render incam
# -- rendering code -- #
vname = batch["meta_render"][0]["name"]
video_path = batch["meta_render"][0]["video_path"]
width, height = batch["meta_render"][0]["width_height"]
K = batch["meta_render"][0]["K"]
faces = self.faces_smpl
split = batch["meta_render"][0]["split"]
out_fn = f"outputs/dump_render_emdb{split}/{vname}.mp4"
Path(out_fn).parent.mkdir(exist_ok=True, parents=True)
# renderer
renderer = Renderer(width, height, device="cuda", faces=faces, K=K)
# not skipping invalid frames
resize_factor = 0.25
images = read_video_np(video_path, scale=resize_factor) # (F, H, W, 3), uint8, numpy
frame_id = batch["meta_render"][0]["frame_id"]
bbx_xys_render = batch["meta_render"][0]["bbx_xys"]
metric_vis = rearrange_by_mask(torch.from_numpy(self.metric_aggregator["mpjpe"][vid]), mask)
# -- render mesh -- #
verts_incam = pred_c_verts
output_images = []
for i in tqdm(range(len(images)), desc=f"Rendering {vname}"):
img = renderer.render_mesh(verts_incam[i].cuda(), images[i], [0.8, 0.8, 0.8])
# bbx
# bbx_xys_ is indexed as (center_x, center_y, size): corners are center -/+ size / 2.
bbx_xys_ = bbx_xys_render[i].cpu().numpy()
lu_point = (bbx_xys_[:2] - bbx_xys_[2:] / 2).astype(int)
rd_point = (bbx_xys_[:2] + bbx_xys_[2:] / 2).astype(int)
img = cv2.rectangle(img, lu_point, rd_point, (255, 178, 102), 2)
if metric_vis[i] > 0:
text = f"pred mpjpe: {metric_vis[i]:.1f}"
text_color = (244, 10, 20) if metric_vis[i] > 80 else (0, 205, 0) # red or green
cv2.putText(img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.75, text_color, 2)
output_images.append(img)
save_video(output_images, out_fn, quality=5)
if False: # Visualize incam + global results
def move_to_start_point_face_z(verts):
"XZ to origin, Start from the ground, Face-Z"
verts = verts.clone() # (L, V, 3)
# Offset = (mean x, min y, mean z) of the first frame's vertices.
xz_mean = verts[0].mean(0)[[0, 2]]
y_min = verts[0, :, [1]].min()
offset = torch.tensor([[[xz_mean[0], y_min, xz_mean[1]]]]).to(verts)
verts = verts - offset
T_ay2ayfz = compute_T_ayfz2ay(einsum(self.J_regressor, verts[[0]], "j v, l v i -> l j i"), inverse=True)
verts = apply_T_on_points(verts, T_ay2ayfz)
return verts
verts_incam = pred_c_verts.clone()
# verts_glob = move_to_start_point_face_z(target_ay_verts) # gt
verts_glob = move_to_start_point_face_z(pred_ay_verts)
global_R, global_T, global_lights = get_global_cameras_static(verts_glob.cpu())
# -- rendering code (global version FOV=55) -- #
vname = batch["meta_render"][0]["name"]
width, height = batch["meta_render"][0]["width_height"]
K = batch["meta_render"][0]["K"]
faces = self.faces_smpl
out_fn = f"outputs/dump_render_global/{vname}.mp4"
Path(out_fn).parent.mkdir(exist_ok=True, parents=True)
writer = imageio.get_writer(out_fn, fps=30, mode="I", format="FFMPEG", macro_block_size=1)
# two renderers
renderer_incam = Renderer(width, height, device="cuda", faces=faces, K=K)
renderer_glob = Renderer(width, height, estimate_focal_length(width, height), device="cuda", faces=faces)
# imgs
video_path = batch["meta_render"][0]["video_path"]
frame_id = batch["meta_render"][0]["frame_id"].cpu().numpy()
images = read_video_np(video_path, frame_id=frame_id) # (F, H/4, W/4, 3), uint8, numpy
# Actual rendering
# Ground plane: centred on the mid-point of the per-frame mean vertex positions in XZ,
# sized 1.5x their largest XZ extent.
cx, cz = (verts_glob.mean(1).max(0)[0] + verts_glob.mean(1).min(0)[0])[[0, 2]] / 2.0
scale = (verts_glob.mean(1).max(0)[0] - verts_glob.mean(1).min(0)[0])[[0, 2]].max() * 1.5
renderer_glob.set_ground(scale, cx.item(), cz.item())
color = torch.ones(3).float().cuda() * 0.8
for i in tqdm(range(seq_length), desc=f"Rendering {vname}"):
# incam
img_overlay_pred = renderer_incam.render_mesh(verts_incam[i].cuda(), images[i], [0.8, 0.8, 0.8])
if batch["meta_render"][0].get("bbx_xys", None) is not None: # draw bbox lines
bbx_xys = batch["meta_render"][0]["bbx_xys"][i].cpu().numpy()
lu_point = (bbx_xys[:2] - bbx_xys[2:] / 2).astype(int)
rd_point = (bbx_xys[:2] + bbx_xys[2:] / 2).astype(int)
img_overlay_pred = cv2.rectangle(img_overlay_pred, lu_point, rd_point, (255, 178, 102), 2)
pred_mpjpe_ = self.metric_aggregator["mpjpe"][vid][i]
text = f"pred mpjpe: {pred_mpjpe_:.1f}"
cv2.putText(img_overlay_pred, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (200, 100, 200), 2)
# glob
cameras = renderer_glob.create_camera(global_R[i], global_T[i])
img_glob = renderer_glob.render_with_ground(verts_glob[[i]], color[None], cameras, global_lights)
# write
# In-camera overlay and global view are placed side by side.
img = np.concatenate([img_overlay_pred, img_glob], axis=1)
writer.append_data(img)
writer.close()
pass
# ================== Epoch Summary ================== #
def on_predict_epoch_end(self, trainer, pl_module):
    """Gather the per-sequence metrics of all processes, report them and reset the storage.

    The monitored metric is "mpjpe" when it is tracked, otherwise the first key of
    ``self.metric_aggregator``. After merging the results of every process the method
    logs (on local rank 0) the monitored metric per sequence, worst first, and the
    average of every metric. The averages are also sent to ``pl_module.logger`` when
    one is attached. Finally the aggregator is emptied for the next epoch.

    Args:
        trainer: provides ``local_rank`` and ``state.stage``.
        pl_module: provides ``logger`` and ``current_epoch``.

    Returns:
        None. Returns early, without logging averages, when no sequence was evaluated.
    """
    local_rank = trainer.local_rank
    if "mpjpe" in self.metric_aggregator:
        monitor_metric = "mpjpe"
    else:
        monitor_metric = list(self.metric_aggregator.keys())[0]

    # Reduce metric_aggregator across all processes
    metric_keys = list(self.metric_aggregator.keys())
    with torch.inference_mode(False):  # allow in-place operation of all_gather
        metric_aggregator_gathered = all_gather(self.metric_aggregator)  # list of dict
    for metric_key in metric_keys:
        for d in metric_aggregator_gathered:
            self.metric_aggregator[metric_key].update(d[metric_key])

    total = len(self.metric_aggregator[monitor_metric])
    Log.info(f"{total} sequences evaluated in {self.__class__.__name__}")
    if total == 0:
        return

    # print monitored metric per sequence, worst (largest) first;
    # total > 0 guarantees there is at least one entry here
    mm_per_seq = {k: v.mean() for k, v in self.metric_aggregator[monitor_metric].items()}
    sorted_mm_per_seq = sorted(mm_per_seq.items(), key=lambda x: x[1], reverse=True)
    n_worst = 5 if trainer.state.stage == "validate" else len(sorted_mm_per_seq)
    if local_rank == 0:
        Log.info(
            f"monitored metric {monitor_metric} per sequence\n"
            + "\n".join([f"{m:5.1f} : {s}" for s, m in sorted_mm_per_seq[:n_worst]])
            + "\n------"
        )

    # average over the concatenated per-sequence results of every metric.
    # A metric without any entry is skipped: np.concatenate raises ValueError on an
    # empty list, and only the monitored metric is known to be non-empty here.
    metrics_avg = {}
    for k, v in self.metric_aggregator.items():
        if len(v) == 0:
            if local_rank == 0:
                Log.info(f"No results collected for metric {k}, skipping it in the summary")
            continue
        metrics_avg[k] = np.concatenate(list(v.values())).mean()
    if local_rank == 0:
        Log.info(
            f"[Metrics] {self.target_dataset_id}:\n"
            + "\n".join(f"{k}: {v:.1f}" for k, v in metrics_avg.items())
            + "\n------"
        )

    # save to logger if available
    if pl_module.logger is not None:
        cur_epoch = pl_module.current_epoch
        for k, v in metrics_avg.items():
            pl_module.logger.log_metrics({f"val_metric_{self.target_dataset_id}/{k}": v}, step=cur_epoch)

    # reset
    for k in self.metric_aggregator:
        self.metric_aggregator[k] = {}
# One callback config per EMDB split, registered under the "callbacks" group.
emdb1_node, emdb2_node = (builds(MetricMocap, emdb_split=split) for split in (1, 2))
MainStore.store(
    name="metric_emdb1",
    node=emdb1_node,
    group="callbacks",
    package="callbacks.metric_emdb1",
)
MainStore.store(
    name="metric_emdb2",
    node=emdb2_node,
    group="callbacks",
    package="callbacks.metric_emdb2",
)
================================================
FILE: eval/GVHMR/hmr4d/model/gvhmr/callbacks/metric_rich.py
================================================
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from hmr4d.configs import MainStore, builds
from hmr4d.utils.comm.gather import all_gather
from hmr4d.utils.pylogger import Log
from hmr4d.utils.eval.eval_utils import (
compute_camcoord_metrics,
compute_global_metrics,
compute_camcoord_perjoint_metrics,
as_np_array,
)
from hmr4d.utils.geo_transform import apply_T_on_points, compute_T_ayfz2ay
from hmr4d.utils.smplx_utils import make_smplx
from einops import einsum, rearrange
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines, get_colors_by_conf
from hmr4d.utils.vis.renderer import Renderer, get_global_cameras_static, get_ground_params_from_points
from hmr4d.utils.geo.hmr_cam import estimate_focal_length
from hmr4d.utils.video_io_utils import read_video_np, save_video, get_writer
import imageio
from tqdm import tqdm
from pathlib import Path
import numpy as np
import cv2
from smplx.joint_names import JOINT_NAMES
from hmr4d.utils.net_utils import repeat_to_max_len, gaussian_smooth
from hmr4d.utils.geo.hmr_global import rollout_vel, get_static_joint_mask
class MetricMocap(pl.Callback):
def __init__(self):
    """Prepare metric storage and the SMPL-X body models used for RICH evaluation."""
    super().__init__()

    # metric name -> {vid -> result}
    metric_names = (
        "pa_mpjpe",
        "mpjpe",
        "pve",
        "accel",
        "wa2_mpjpe",
        "waa_mpjpe",
        "rte",
        "jitter",
        "fs",
    )
    self.metric_aggregator = {name: {} for name in metric_names}

    # Per-joint storage is only created when the (hard-coded) switch is turned on.
    self.perjoint_metrics = False
    if self.perjoint_metrics:
        body_joint_names = JOINT_NAMES[:22] + ["left_hand", "right_hand"]
        self.body_joint_names = body_joint_names
        self.perjoint_metric_aggregator = {"mpjpe": {joint: {} for joint in body_joint_names}}
        self.perjoint_obs_metric_aggregator = {"mpjpe": {joint: {} for joint in body_joint_names}}

    # One SMPL-X model per gender, plus the joint regressor and the SMPL-X -> SMPL mapping
    self.smplx_model = {
        gender: make_smplx("rich-smplx", gender=gender) for gender in ("male", "female", "neutral")
    }
    self.J_regressor = torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt")
    self.smplx2smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt")
    self.faces_smpl = make_smplx("smpl").faces
    self.faces_smplx = self.smplx_model["neutral"].faces

    # Validation and test reuse the predict hooks, so the metrics are computed
    # and summarised the same way in all three stages.
    batch_end_hook = self.on_predict_batch_end
    self.on_test_batch_end = batch_end_hook
    self.on_validation_batch_end = batch_end_hook
    epoch_end_hook = self.on_predict_epoch_end
    self.on_test_epoch_end = epoch_end_hook
    self.on_validation_epoch_end = epoch_end_hook
# ================== Batch-based Computation ================== #
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
"""The behaviour is the same for val/test/predict"""
# Per-batch evaluation of a single RICH sequence.
# Reads batch["meta"], batch["length"], batch["gender"], batch["T_w2ay"], batch["T_w2c"] and
# batch["gt_smpl_params"], plus outputs["pred_smpl_params_incam"] and
# outputs["pred_smpl_params_global"].
# Writes camera-coordinate and global metrics to self.metric_aggregator[metric_name][vid];
# returns nothing.
# Only one sequence per batch is supported.
assert batch["B"] == 1
dataset_id = batch["meta"][0]["dataset_id"]
# Batches that are not from RICH are ignored by this callback.
if dataset_id != "RICH":
return
# Move to cuda if not
for g in ["male", "female", "neutral"]:
self.smplx_model[g] = self.smplx_model[g].cuda()
self.J_regressor = self.J_regressor.cuda()
self.smplx2smpl = self.smplx2smpl.cuda()
# Index 0 everywhere below: the batch holds exactly one sequence (see the assert above).
vid = batch["meta"][0]["vid"]
seq_length = batch["length"][0].item()
gender = batch["gender"][0]
T_w2ay = batch["T_w2ay"][0]
T_w2c = batch["T_w2c"][0]
# Groundtruth (world, cam)
# GT uses the gender-specific SMPL-X model; vertices are mapped through smplx2smpl frame by
# frame, transformed with T_w2c, and joints are regressed from the transformed vertices.
target_w_params = {k: v[0] for k, v in batch["gt_smpl_params"].items()}
target_w_output = self.smplx_model[gender](**target_w_params)
target_w_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in target_w_output.vertices])
target_c_verts = apply_T_on_points(target_w_verts, T_w2c)
target_c_j3d = torch.matmul(self.J_regressor, target_c_verts)
# Root-relative ("cr") copies: subtract the mean of joints 1 and 2
# (presumably the two hip joints -- TODO confirm against the regressor's joint order).
offset = target_c_j3d[..., [1, 2], :].mean(-2, keepdim=True) # (L, 1, 3)
target_cr_j3d = target_c_j3d - offset
target_cr_verts = target_c_verts - offset
# optional: ay for visual comparison
target_ay_verts = apply_T_on_points(target_w_verts, T_w2ay)
target_ay_j3d = torch.matmul(self.J_regressor, target_ay_verts)
# + Prediction -> Metric
# Both predictions are always run through the neutral SMPL-X model.
# 1. cam
pred_smpl_params_incam = outputs["pred_smpl_params_incam"]
smpl_out = self.smplx_model["neutral"](**pred_smpl_params_incam)
pred_c_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in smpl_out.vertices])
pred_c_j3d = einsum(self.J_regressor, pred_c_verts, "j v, l v i -> l j i")
# NOTE(review): this second `offset` is not read by any later statement of this method
# (the nested helper further down defines its own local `offset`).
offset = pred_c_j3d[..., [1, 2], :].mean(-2, keepdim=True) # (L, 1, 3)
# 2. ay
pred_smpl_params_global = outputs["pred_smpl_params_global"]
smpl_out = self.smplx_model["neutral"](**pred_smpl_params_global)
pred_ay_verts = torch.stack([torch.matmul(self.smplx2smpl, v_) for v_ in smpl_out.vertices])
pred_ay_j3d = einsum(self.J_regressor, pred_ay_verts, "j v, l v i -> l j i")
# Metric of current sequence
# Camera-coordinate metrics (no mask is passed here).
batch_eval = {
"pred_j3d": pred_c_j3d,
"target_j3d": target_c_j3d,
"pred_verts": pred_c_verts,
"target_verts": target_c_verts,
}
camcoord_metrics = compute_camcoord_metrics(batch_eval)
for k in camcoord_metrics:
self.metric_aggregator[k][vid] = as_np_array(camcoord_metrics[k])
# Global metrics, with the GT transformed by T_w2ay.
batch_eval = {
"pred_j3d_glob": pred_ay_j3d,
"target_j3d_glob": target_ay_j3d,
"pred_verts_glob": pred_ay_verts,
"target_verts_glob": target_ay_verts,
}
global_metrics = compute_global_metrics(batch_eval)
for k in global_metrics:
self.metric_aggregator[k][vid] = as_np_array(global_metrics[k])
# ---- Everything below is guarded by a literal `if False:` and never runs. ----
if False: # global wi3d debug
wis3d = make_wis3d(name="debug-metric-global")
add_motion_as_lines(pred_ay_j3d, wis3d, name="pred_ay_j3d")
add_motion_as_lines(target_ay_j3d, wis3d, name="target_ay_j3d")
# NOTE(review): the section below uses `obs_cr_verts`, `pred_cr_verts` and `decord`, none of
# which is assigned in this method or imported by this file's import block; it would raise
# NameError if enabled as-is.
if False: # incam visualize debug
# Print per-sequence error
Log.info(
f"seq {vid} metrics:\n"
+ "\n".join(
f"{k}: {self.metric_aggregator[k][vid].mean():.1f} (obs:{camcoord_metrics[k].mean():.1f})"
for k in camcoord_metrics.keys()
)
+ "\n------\n"
)
if self.perjoint_metrics:
Log.info(
f"\n".join(
f"{k}-{j}: {self.perjoint_metric_aggregator[k][j][vid].mean():.1f} (obs:{self.perjoint_obs_metric_aggregator[k][j][vid].mean():.1f})"
for j in self.body_joint_names
for k in self.perjoint_obs_metric_aggregator.keys()
)
+ "\n------"
)
# -- metric -- #
pred_mpjpe = self.metric_aggregator["mpjpe"][vid].mean()
obs_mpjpe = camcoord_metrics["mpjpe"].mean()
# -- render mesh -- #
vertices_gt = target_c_verts
vertices_cr_gt = target_cr_verts + target_cr_verts.new([0, 0, 3.0]) # move forward +z
vertices_pred = pred_c_verts
vertices_cr_obs = obs_cr_verts + obs_cr_verts.new([0, 0, 3.0]) # move forward +z
vertices_cr_pred = pred_cr_verts + pred_cr_verts.new([0, 0, 3.0]) # move forward +z
# -- rendering code -- #
vname = batch["meta_render"][0]["name"]
K = batch["meta_render"][0]["K"]
width, height = batch["meta_render"][0]["width_height"]
faces = self.faces_smpl
renderer = Renderer(width, height, device="cuda", faces=faces, K=K)
out_fn = f"outputs/dump_render/{vname}.mp4"
Path(out_fn).parent.mkdir(exist_ok=True, parents=True)
writer = imageio.get_writer(out_fn, fps=30, mode="I", format="FFMPEG", macro_block_size=1)
# imgs
video_path = batch["meta_render"][0]["video_path"]
frame_id = batch["meta_render"][0]["frame_id"].cpu().numpy()
vr = decord.VideoReader(video_path)
images = vr.get_batch(list(frame_id)).numpy() # (F, H/4, W/4, 3), uint8, numpy
for i in tqdm(range(seq_length), desc=f"Rendering {vname}"):
img_overlay_gt = renderer.render_mesh(vertices_gt[i].cuda(), images[i], [39, 194, 128])
if batch["meta_render"][0].get("bbx_xys", None) is not None: # draw bbox lines
# bbx_xys is indexed as (center_x, center_y, size): corners are center -/+ size / 2.
bbx_xys = batch["meta_render"][0]["bbx_xys"][i].cpu().numpy()
lu_point = (bbx_xys[:2] - bbx_xys[2:] / 2).astype(int)
rd_point = (bbx_xys[:2] + bbx_xys[2:] / 2).astype(int)
img_overlay_gt = cv2.rectangle(img_overlay_gt, lu_point, rd_point, (255, 178, 102), 2)
img_overlay_pred = renderer.render_mesh(vertices_pred[i].cuda(), images[i])
# img_overlay_pred = renderer.render_mesh(vertices_pred[i].cuda(), np.zeros_like(images[i]))
# Left column: GT overlay stacked on top of the prediction overlay.
img = np.concatenate([img_overlay_gt, img_overlay_pred], axis=0)
####### overlay gt cr first, then overlay pred cr with error color ########
# overlay gt cr first with blue color
black_overlay_obs = renderer.render_mesh(
vertices_cr_gt[i].cuda(), np.zeros_like(images[i]), colors=[39, 194, 128]
)
black_overlay_pred = renderer.render_mesh(
vertices_cr_gt[i].cuda(), np.zeros_like(images[i]), colors=[39, 194, 128]
)
# get error color
# Per-vertex distance to GT, normalised by the larger of the two maxima, drives the
# first colour channel; the other two channels are fixed at 0.6.
obs_error = (vertices_cr_gt[i] - vertices_cr_obs[i]).norm(dim=-1)
pred_error = (vertices_cr_gt[i] - vertices_cr_pred[i]).norm(dim=-1)
max_error = max(obs_error.max(), pred_error.max())
obs_error_color = torch.stack(
[obs_error / max_error, torch.ones_like(obs_error) * 0.6, torch.ones_like(obs_error) * 0.6],
dim=-1,
)
obs_error_color = torch.clip(obs_error_color, 0, 1)
pred_error_color = torch.stack(
[pred_error / max_error, torch.ones_like(pred_error) * 0.6, torch.ones_like(pred_error) * 0.6],
dim=-1,
)
pred_error_color = torch.clip(pred_error_color, 0, 1)
# overlay cr with error color
black_overlay_obs = renderer.render_mesh(
vertices_cr_obs[i].cuda(), black_overlay_obs, colors=obs_error_color[None]
)
black_overlay_pred = renderer.render_mesh(
vertices_cr_pred[i].cuda(), black_overlay_pred, colors=pred_error_color[None]
)
# write mpjpe on the img
obs_mpjpe_ = camcoord_metrics["mpjpe"][i]
text = f"obs mpjpe: {obs_mpjpe_:.1f} ({obs_mpjpe:.1f})"
cv2.putText(black_overlay_obs, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (100, 200, 200), 2)
pred_mpjpe_ = self.metric_aggregator["mpjpe"][vid][i]
text = f"pred mpjpe: {pred_mpjpe_:.1f} ({pred_mpjpe:.1f})"
if pred_mpjpe_ > obs_mpjpe_:
# large error -> purple
cv2.putText(black_overlay_pred, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (200, 100, 200), 2)
else:
# small error -> yellow
cv2.putText(black_overlay_pred, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (200, 200, 100), 2)
black = np.concatenate([black_overlay_obs, black_overlay_pred], axis=0)
###########################################
# Final frame: image overlays on the left, error-coloured renders on the right.
img = np.concatenate([img, black], axis=1)
writer.append_data(img)
writer.close()
if False: # Visualize incam + global results
def move_to_start_point_face_z(verts):
"XZ to origin, Start from the ground, Face-Z"
# position
verts = verts.clone() # (L, V, 3)
# Offset = first regressed joint of the first frame, with its y replaced by the
# minimum y over all frames and vertices.
offset = einsum(self.J_regressor, verts[0], "j v, v i -> j i")[0] # (3)
offset[1] = verts[:, :, [1]].min()
verts = verts - offset
# face direction
T_ay2ayfz = compute_T_ayfz2ay(einsum(self.J_regressor, verts[[0]], "j v, l v i -> l j i"), inverse=True)
verts = apply_T_on_points(verts, T_ay2ayfz)
return verts
verts_incam = pred_c_verts.clone()
# verts_glob = move_to_start_point_face_z(target_ay_verts) # gt
verts_glob = move_to_start_point_face_z(pred_ay_verts)
joints_glob = einsum(self.J_regressor, verts_glob, "j v, l v i -> l j i") # (L, J, 3)
global_R, global_T, global_lights = get_global_cameras_static(
verts_glob.cpu(),
beta=4.0,
cam_height_degree=20,
target_center_height=1.0,
vec_rot=-45,
)
# -- rendering code (global version FOV=55) -- #
vname = batch["meta_render"][0]["name"]
width, height = batch["meta_render"][0]["width_height"]
K = batch["meta_render"][0]["K"]
faces = self.faces_smpl
out_fn = f"outputs/dump_render_global/{vname}.mp4"
Path(out_fn).parent.mkdir(exist_ok=True, parents=True)
# two renderers
renderer_incam = Renderer(width, height, device="cuda", faces=faces, K=K)
renderer_glob = Renderer(width, height, estimate_focal_length(width, height), device="cuda", faces=faces)
# imgs
video_path = batch["meta_render"][0]["video_path"]
frame_id = batch["meta_render"][0]["frame_id"].cpu().numpy()
images = read_video_np(video_path)[frame_id] # (F, H/4, W/4, 3), uint8, numpy
# Actual rendering
scale, cx, cz = get_ground_params_from_points(joints_glob[:, 0], verts_glob)
renderer_glob.set_ground(scale * 1.5, cx, cz)
color = torch.ones(3).float().cuda() * 0.8
writer = get_writer(out_fn, fps=30, crf=23)
for i in tqdm(range(seq_length), desc=f"Rendering {vname}"):
# incam
img_overlay_pred = renderer_incam.render_mesh(verts_incam[i].cuda(), images[i], [0.8, 0.8, 0.8])
# if batch["meta_render"][0].get("bbx_xys", None) is not None: # draw bbox lines
# bbx_xys = batch["meta_render"][0]["bbx_xys"][i].cpu().numpy()
# lu_point = (bbx_xys[:2] - bbx_xys[2:] / 2).astype(int)
# rd_point = (bbx_xys[:2] + bbx_xys[2:] / 2).astype(int)
# img_overlay_pred = cv2.rectangle(img_overlay_pred, lu_point, rd_point, (255, 178, 102), 2)
# pred_mpjpe_ = self.metric_aggregator["mpjpe"][vid][i]
# text = f"pred mpjpe: {pred_mpjpe_:.1f}"
# cv2.putText(img_overlay_pred, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (200, 100, 200), 2)
# glob
cameras = renderer_glob.create_camera(global_R[i], global_T[i])
# img_glob = renderer_glob.render_with_ground(verts_glob[[i]], color_[None], cameras, global_lights)
img_glob = renderer_glob.render_with_ground(
verts_glob[[i]], color.clone()[None], cameras, global_lights
)
# write
# In-camera overlay and global view are placed side by side.
img = np.concatenate([img_overlay_pred, img_glob], axis=1)
writer.write_frame(img)
writer.close()
# ================== Epoch Summary ================== #
def on_predict_epoch_end(self, trainer, pl_module):
    """Gather the per-sequence metrics of all processes, report them and reset the storage.

    "mpjpe" is the monitored metric. After merging the results of every process the
    method logs (on local rank 0) the monitored metric per sequence, worst first, and
    the average of every metric. The averages are also sent to ``pl_module.logger``
    when one is attached. Finally the aggregator is emptied for the next epoch.

    Args:
        trainer: provides ``local_rank``, ``world_size`` and ``state.stage``.
        pl_module: provides ``logger`` and ``current_epoch``.

    Returns:
        None. Returns early, without logging averages, when no sequence was evaluated.
    """
    local_rank, world_size = trainer.local_rank, trainer.world_size
    monitor_metric = "mpjpe"

    # Reduce metric_aggregator across all processes
    metric_keys = list(self.metric_aggregator.keys())
    with torch.inference_mode(False):  # allow in-place operation of all_gather
        metric_aggregator_gathered = all_gather(self.metric_aggregator)  # list of dict
    for metric_key in metric_keys:
        for d in metric_aggregator_gathered:
            self.metric_aggregator[metric_key].update(d[metric_key])

    if False:  # debug to make sure the all_gather is correct
        print(f"[RANK {local_rank}/{world_size}]: {self.metric_aggregator[monitor_metric].keys()}")

    total = len(self.metric_aggregator[monitor_metric])
    Log.info(f"{total} sequences evaluated in {self.__class__.__name__}")
    if total == 0:
        return

    # print monitored metric per sequence, worst (largest) first;
    # total > 0 guarantees there is at least one entry here
    mm_per_seq = {k: v.mean() for k, v in self.metric_aggregator[monitor_metric].items()}
    sorted_mm_per_seq = sorted(mm_per_seq.items(), key=lambda x: x[1], reverse=True)
    n_worst = 5 if trainer.state.stage == "validate" else len(sorted_mm_per_seq)
    if local_rank == 0:
        Log.info(
            f"monitored metric {monitor_metric} per sequence\n"
            + "\n".join([f"{m:5.1f} : {s}" for s, m in sorted_mm_per_seq[:n_worst]])
            + "\n------"
        )

    # average over the concatenated per-sequence results of every metric.
    # A metric without any entry is skipped: np.concatenate raises ValueError on an
    # empty list, and only the monitored metric is known to be non-empty here.
    metrics_avg = {}
    for k, v in self.metric_aggregator.items():
        if len(v) == 0:
            if local_rank == 0:
                Log.info(f"No results collected for metric {k}, skipping it in the summary")
            continue
        metrics_avg[k] = np.concatenate(list(v.values())).mean()
    if local_rank == 0:
        Log.info(f"[Metrics] RICH:\n" + "\n".join(f"{k}: {v:.1f}" for k, v in metrics_avg.items()) + "\n------")

    # save to logger if available
    if pl_module.logger is not None:
        cur_epoch = pl_module.current_epoch
        for k, v in metrics_avg.items():
            pl_module.logger.log_metrics({f"val_metric_RICH/{k}": v}, step=cur_epoch)

    # reset
    for k in self.metric_aggregator:
        self.metric_aggregator[k] = {}
# Register the RICH metric callback config under the "callbacks" group.
rich_node = builds(MetricMocap)
MainStore.store(
    name="metric_rich",
    node=rich_node,
    group="callbacks",
    package="callbacks.metric_rich",
)
================================================
FILE: eval/GVHMR/hmr4d/model/gvhmr/gvhmr_pl.py
================================================
from typing import Any, Dict
import numpy as np
from pathlib import Path
import torch
import pytorch_lightning as pl
from hydra.utils import instantiate
from hmr4d.utils.pylogger import Log
from einops import rearrange, einsum
from hmr4d.configs import MainStore, builds
from hmr4d.utils.geo_transform import compute_T_ayfz2ay, apply_T_on_points
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.utils.smplx_utils import make_smplx
from hmr4d.utils.geo.augment_noisy_pose import (
get_wham_aug_kp3d,
get_visible_mask,
get_invisible_legs_mask,
randomly_occlude_lower_half,
randomly_modify_hands_legs,
)
from hmr4d.utils.geo.hmr_cam import perspective_projection, normalize_kp2d, safely_render_x3d_K, get_bbx_xys
from hmr4d.utils.video_io_utils import save_video
from hmr4d.utils.vis.cv2_utils import draw_bbx_xys_on_image_batch
from hmr4d.utils.geo.flip_utils import flip_smplx_params, avg_smplx_aa
from hmr4d.model.gvhmr.utils.postprocess import pp_static_joint, pp_static_joint_cam, process_ik
class GvhmrPL(pl.LightningModule):
def __init__(
    self,
    pipeline,
    optimizer=None,
    scheduler_cfg=None,
    ignored_weights_prefix=["smplx", "pipeline.endecoder"],
):
    """Build the Lightning module around an instantiated pipeline.

    Args:
        pipeline: config passed to hydra's ``instantiate`` with ``_recursive_=False``.
        optimizer: config passed to hydra's ``instantiate``; defaults to None.
        scheduler_cfg: stored as-is on the module.
        ignored_weights_prefix: iterable of name prefixes, stored as a fresh list on
            the module (presumably consulted when saving/loading weights; the usage
            is outside this method). None is stored as None.
    """
    super().__init__()
    self.pipeline = instantiate(pipeline, _recursive_=False)
    self.optimizer = instantiate(optimizer)
    self.scheduler_cfg = scheduler_cfg

    # Options
    # The default value is a list created once at function definition time. Store a
    # copy so that mutating the attribute on one instance can never leak into the
    # shared default (or into the caller's list) and thereby into other instances.
    if ignored_weights_prefix is None:
        self.ignored_weights_prefix = None
    else:
        self.ignored_weights_prefix = list(ignored_weights_prefix)

    # The test step is the same as validation
    self.test_step = self.predict_step = self.validation_step

    # SMPLX
    self.smplx = make_smplx("supermotion_v437coco17")
def training_step(self, batch, batch_idx):
# One training step: build noisy 2D keypoint observations from the ground-truth SMPL
# parameters, store them in batch["obs"], run the pipeline in training mode and log the
# losses. Returns the pipeline's output dict (which contains "loss").
# Note that `batch` is modified in place (gt_* entries, "bbx_xys", "obs").
# B, F = first two dimensions of the body pose tensor.
B, F = batch["smpl_params_c"]["body_pose"].shape[:2]
# Create augmented noisy-obs : gt_j3d(coco17)
with torch.no_grad():
# self.smplx returns a (vertices, joints) pair here.
gt_verts437, gt_j3d = self.smplx(**batch["smpl_params_c"])
# Root = mean of joints 11 and 12 (presumably the two hips in COCO-17 order -- TODO confirm).
root_ = gt_j3d[:, :, [11, 12], :].mean(-2, keepdim=True)
batch["gt_j3d"] = gt_j3d
batch["gt_cr_coco17"] = gt_j3d - root_
batch["gt_c_verts437"] = gt_verts437
batch["gt_cr_verts437"] = gt_verts437 - root_
# bbx_xys
# A bounding box is derived from the projected GT vertices (with augmentation enabled).
i_x2d = safely_render_x3d_K(gt_verts437, batch["K_fullimg"], thr=0.3)
bbx_xys = get_bbx_xys(i_x2d, do_augment=True)
# The condition is the literal False, so only the else-branch runs: the derived box
# replaces batch["bbx_xys"] only where batch["mask"]["bbx_xys"] is False.
if False: # trust image bbx_xys seems better
batch["bbx_xys"] = bbx_xys
else:
mask_bbx_xys = batch["mask"]["bbx_xys"]
batch["bbx_xys"][~mask_bbx_xys] = bbx_xys[~mask_bbx_xys]
# Disabled debug section (literal False): dumps a video of the projected points and boxes.
if False: # visualize bbx_xys from an iPhone view
render_w, render_h = 120, 160 # iphone main-lens 24mm 3:4
ratio = render_w / 1528
offset = torch.tensor([764 - 500, 1019 - 500]).to(i_x2d)
i_x2d_render = (i_x2d + offset).clone()
i_x2d_render = (i_x2d_render * ratio).long().clone()
torch.clamp_(i_x2d_render[..., 0], 0, render_w - 1)
torch.clamp_(i_x2d_render[..., 1], 0, render_h - 1)
bbx_xys_render = bbx_xys.clone()
bbx_xys_render[..., :2] += offset
bbx_xys_render *= ratio
output_dir = Path("outputs/simulated_bbx_xys")
output_dir.mkdir(parents=True, exist_ok=True)
video_list = []
for bid in range(B):
images = torch.zeros(F, render_h, render_w, 3, device=i_x2d.device)
for fid in range(F):
images[fid, i_x2d_render[bid, fid, :, 1], i_x2d_render[bid, fid, :, 0]] = 255
images = draw_bbx_xys_on_image_batch(bbx_xys_render[bid].cpu().numpy(), images.cpu().numpy())
images = np.stack(images).astype("uint8") # (L, H, W, 3)
# White one-pixel border around every frame.
images[:, 0, :] = np.array([255, 255, 255])
images[:, -1, :] = np.array([255, 255, 255])
images[:, :, 0] = np.array([255, 255, 255])
images[:, :, -1] = np.array([255, 255, 255])
video_list.append(images)
# stack videos
# Rows of 4 videos; a trailing group with fewer than 4 videos is dropped.
video_output = []
for i in range(0, len(video_list), 4):
if i + 4 <= len(video_list):
video_output.append(np.concatenate(video_list[i : i + 4], axis=2))
video_output = np.concatenate(video_output, axis=1)
save_video(video_output, output_dir / f"{batch_idx}.mp4", fps=30, quality=5)
# noisy_j3d -> project to i_j2d -> compute a bbx -> normalized kp2d [-1, 1]
noisy_j3d = gt_j3d + get_wham_aug_kp3d(gt_j3d.shape[:2])
if True:
noisy_j3d = randomly_modify_hands_legs(noisy_j3d)
obs_i_j2d = perspective_projection(noisy_j3d, batch["K_fullimg"]) # (B, L, J, 2)
# NOTE(review): the masks are moved with a hard-coded .cuda() rather than to the device of
# gt_j3d, so this step assumes training on a CUDA device.
j2d_visible_mask = get_visible_mask(gt_j3d.shape[:2]).cuda() # (B, L, J)
j2d_visible_mask[noisy_j3d[..., 2] < 0.3] = False # Set close-to-image-plane points as invisible
if True: # Set both legs as invisible for a period
legs_invisible_mask = get_invisible_legs_mask(gt_j3d.shape[:2]).cuda() # (B, L, J)
j2d_visible_mask[legs_invisible_mask] = False
# The visibility flag is appended to the 2D coordinates as a third channel.
obs_kp2d = torch.cat([obs_i_j2d, j2d_visible_mask[:, :, :, None].float()], dim=-1) # (B, L, J, 3)
obs = normalize_kp2d(obs_kp2d, batch["bbx_xys"]) # (B, L, J, 3)
obs[~j2d_visible_mask] = 0 # if not visible, set to (0,0,0)
batch["obs"] = obs
if True: # Use some detected vitpose (presave data)
# Each sample is picked with probability `prob`, and only where batch["mask"]["vitpose"]
# is set; picked samples get their observation from batch["kp2d"] instead.
prob = 0.5
mask_real_vitpose = (torch.rand(B).to(obs_kp2d) < prob) * batch["mask"]["vitpose"]
batch["obs"][mask_real_vitpose] = normalize_kp2d(batch["kp2d"], batch["bbx_xys"])[mask_real_vitpose]
# Set untrusted frames to False
batch["obs"][~batch["mask"]["valid"]] = 0
# Disabled debug section (literal False).
if False: # wis3d
wis3d = make_wis3d(name="debug-aug-kp3d")
add_motion_as_lines(gt_j3d[0], wis3d, name="gt_j3d", skeleton_type="coco17")
add_motion_as_lines(noisy_j3d[0], wis3d, name="noisy_j3d", skeleton_type="coco17")
# f_imgseq: apply random aug on offline extracted features
# f_imgseq = batch["f_imgseq"] + torch.randn_like(batch["f_imgseq"]) * 0.1
# f_imgseq[~batch["mask"]["f_imgseq"]] = 0
# batch["f_imgseq"] = f_imgseq.clone()
# Forward and get loss
outputs = self.pipeline.forward(batch, train=True)
# Log
log_kwargs = {
"on_epoch": True,
"prog_bar": True,
"logger": True,
"batch_size": B,
"sync_dist": True,
}
self.log("train/loss", outputs["loss"], **log_kwargs)
# Every output whose key contains "_loss" is logged as an individual loss term.
for k, v in outputs.items():
if "_loss" in k:
self.log(f"train/{k}", v, **log_kwargs)
return outputs
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """Run inference on one sequence, optionally averaging with a flipped pass.

    Args:
        batch: dict holding one sequence (``batch["B"]`` must be 1). Reads
            ``kp2d``, ``bbx_xys``, ``K_fullimg``, ``cam_angvel``, ``f_imgseq``
            and ``length``; optionally ``mask`` and ``flip_test`` (a dict with
            its own ``kp2d``, ``bbx_xys``, ``cam_angvel`` and ``f_imgseq``).
        batch_idx: unused here.
        dataloader_idx: unused here.

    Returns:
        The pipeline output dict, with ``pred_smpl_params_global`` and
        ``pred_smpl_params_incam`` stripped of their leading batch dimension.
    """
    # Options & Check
    do_postproc = self.trainer.state.stage == "test"  # Only apply postproc in test
    do_flip_test = "flip_test" in batch
    do_postproc_not_flip_test = do_postproc and not do_flip_test  # later pp when flip_test
    assert batch["B"] == 1, "Only support batch size 1 in evalution."

    def build_inputs(src):
        """Assemble pipeline inputs; `src` supplies the (possibly flipped) per-view entries."""
        obs = normalize_kp2d(src["kp2d"], src["bbx_xys"])
        if "mask" in batch:
            # Zero the observation wherever the (un-flipped) batch mask is False.
            obs[0, ~batch["mask"][0]] = 0
        return {
            "length": batch["length"],
            "obs": obs,
            "bbx_xys": src["bbx_xys"],
            "K_fullimg": batch["K_fullimg"],
            "cam_angvel": src["cam_angvel"],
            "f_imgseq": src["f_imgseq"],
        }

    # ROPE inference
    outputs = self.pipeline.forward(build_inputs(batch), train=False, postproc=do_postproc_not_flip_test)
    outputs["pred_smpl_params_global"] = {k: v[0] for k, v in outputs["pred_smpl_params_global"].items()}
    outputs["pred_smpl_params_incam"] = {k: v[0] for k, v in outputs["pred_smpl_params_incam"].items()}

    if do_flip_test:
        flipped_outputs = self.pipeline.forward(build_inputs(batch["flip_test"]), train=False)

        # First update incam results
        flipped_outputs["pred_smpl_params_incam"] = {
            k: v[0] for k, v in flipped_outputs["pred_smpl_params_incam"].items()
        }
        smpl_params1 = outputs["pred_smpl_params_incam"]
        smpl_params2 = flip_smplx_params(flipped_outputs["pred_smpl_params_incam"])

        # Shallow copy: entries that are not overwritten below (e.g. "transl")
        # keep the values of the un-flipped prediction.
        smpl_params_avg = smpl_params1.copy()
        smpl_params_avg["betas"] = (smpl_params1["betas"] + smpl_params2["betas"]) / 2
        smpl_params_avg["body_pose"] = avg_smplx_aa(smpl_params1["body_pose"], smpl_params2["body_pose"])
        smpl_params_avg["global_orient"] = avg_smplx_aa(
            smpl_params1["global_orient"], smpl_params2["global_orient"]
        )
        outputs["pred_smpl_params_incam"] = smpl_params_avg

        # Then update global results
        outputs["pred_smpl_params_global"]["betas"] = smpl_params_avg["betas"]
        outputs["pred_smpl_params_global"]["body_pose"] = smpl_params_avg["body_pose"]

        # Finally, apply postprocess (it was skipped inside the pipeline for the flip-test case)
        if do_postproc:
            # temporarily recover the original batch-dim
            outputs["pred_smpl_params_global"] = {k: v[None] for k, v in outputs["pred_smpl_params_global"].items()}
            outputs["pred_smpl_params_global"]["transl"] = pp_static_joint(outputs, self.pipeline.endecoder)
            body_pose = process_ik(outputs, self.pipeline.endecoder)
            outputs["pred_smpl_params_global"] = {k: v[0] for k, v in outputs["pred_smpl_params_global"].items()}
            outputs["pred_smpl_params_global"]["body_pose"] = body_pose[0]

    return outputs
def configure_optimizers(self):
    """Build the optimizer over trainable pipeline parameters, plus an optional scheduler.

    Returns:
        The bare optimizer when ``self.scheduler_cfg["scheduler"]`` is None;
        otherwise ``([optimizer], [scheduler_cfg])`` where the ``"scheduler"``
        entry of a copied config has been instantiated with the optimizer.
    """
    trainable = [param for _, param in self.pipeline.named_parameters() if param.requires_grad]
    optimizer = self.optimizer(params=trainable)

    if self.scheduler_cfg["scheduler"] is not None:
        # Work on a copy so the stored config is left untouched.
        sched_cfg = dict(self.scheduler_cfg)
        sched_cfg["scheduler"] = instantiate(sched_cfg["scheduler"], optimizer=optimizer)
        return [optimizer], [sched_cfg]

    return optimizer
# ============== Utils ================= #
def on_save_checkpoint(self, checkpoint) -> None:
    """Drop every state-dict entry whose key starts with an ignored prefix.

    The ``checkpoint["state_dict"]`` mapping is modified in place.
    """
    state_dict = checkpoint["state_dict"]
    for prefix in self.ignored_weights_prefix:
        # Collect first: the dict must not change size while being iterated.
        doomed = [name for name in state_dict.keys() if name.startswith(prefix)]
        for name in doomed:
            state_dict.pop(name)
def load_pretrained_model(self, ckpt_path):
    """Load pretrained checkpoint, and assign each weight to the corresponding part.

    Loading is non-strict. Missing keys are only reported when they do not
    start with one of ``self.ignored_weights_prefix``; unexpected keys are
    always reported.
    """
    Log.info(f"[PL-Trainer] Loading ckpt: {ckpt_path}")
    checkpoint = torch.load(ckpt_path, "cpu")
    missing, unexpected = self.load_state_dict(checkpoint["state_dict"], strict=False)

    # Keys under an ignored prefix are expected to be absent, so filter them out.
    real_missing = [
        key
        for key in missing
        if not any(key.startswith(prefix) for prefix in self.ignored_weights_prefix)
    ]

    if len(real_missing) > 0:
        Log.warn(f"Missing keys: {real_missing}")
    if len(unexpected) > 0:
        Log.warn(f"Unexpected keys: {unexpected}")
# Structured config for GvhmrPL. The "${...}" strings are passed through as-is;
# presumably they are interpolations resolved by the config system -- TODO confirm.
gvhmr_pl = builds(
GvhmrPL,
pipeline="${pipeline}",
optimizer="${optimizer}",
scheduler_cfg="${scheduler_cfg}",
populate_full_signature=True, # Adds all the arguments to the signature
)
# Register the config under the "model/gvhmr" group with the name "gvhmr_pl".
MainStore.store(name="gvhmr_pl", node=gvhmr_pl, group="model/gvhmr")
================================================
FILE: eval/GVHMR/hmr4d/model/gvhmr/gvhmr_pl_demo.py
================================================
import torch
import pytorch_lightning as pl
from hydra.utils import instantiate
from hmr4d.utils.pylogger import Log
from hmr4d.configs import MainStore, builds
from hmr4d.utils.geo.hmr_cam import normalize_kp2d
class DemoPL(pl.LightningModule):
    """Inference-only Lightning wrapper around the GVHMR pipeline."""

    def __init__(self, pipeline):
        super().__init__()
        self.pipeline = instantiate(pipeline, _recursive_=False)

    @torch.no_grad()
    def predict(self, data, static_cam=False):
        """Run the pipeline on one un-batched sequence (a batch dim is added automatically).

        Args:
            data: dict with per-sequence entries (no batch dimension):
                "length": int or torch.Tensor, number of effective frames
                "kp2d": 2D keypoints -- assumed (F, J, 3), TODO confirm
                "bbx_xys": (F, 3)
                "K_fullimg": (F, 3, 3)
                "cam_angvel": per-frame camera angular velocity
                "f_imgseq": per-frame image features -- assumed (F, C), TODO confirm
            static_cam: forwarded to the pipeline; selects its static-camera
                post-processing branch.

        Returns:
            dict with "smpl_params_global" and "smpl_params_incam" (batch dim
            removed), the input "K_fullimg", and the raw pipeline outputs under
            "net_outputs".
        """
        # ROPE inference
        batch = {
            # as_tensor accepts the documented plain-int length; an existing
            # tensor is passed through unchanged.
            "length": torch.as_tensor(data["length"])[None],
            "obs": normalize_kp2d(data["kp2d"], data["bbx_xys"])[None],
            "bbx_xys": data["bbx_xys"][None],
            "K_fullimg": data["K_fullimg"][None],
            "cam_angvel": data["cam_angvel"][None],
            "f_imgseq": data["f_imgseq"][None],
        }
        batch = {k: v.cuda() for k, v in batch.items()}
        outputs = self.pipeline.forward(batch, train=False, postproc=True, static_cam=static_cam)

        pred = {
            "smpl_params_global": {k: v[0] for k, v in outputs["pred_smpl_params_global"].items()},
            "smpl_params_incam": {k: v[0] for k, v in outputs["pred_smpl_params_incam"].items()},
            "K_fullimg": data["K_fullimg"],
            "net_outputs": outputs,  # intermediate outputs
        }
        return pred

    def load_pretrained_model(self, ckpt_path):
        """Load pretrained checkpoint, and assign each weight to the corresponding part.

        Loading is non-strict; missing and unexpected keys are logged as warnings.
        """
        Log.info(f"[PL-Trainer] Loading ckpt type: {ckpt_path}")
        state_dict = torch.load(ckpt_path, "cpu")["state_dict"]
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        if len(missing) > 0:
            Log.warn(f"Missing keys: {missing}")
        if len(unexpected) > 0:
            Log.warn(f"Unexpected keys: {unexpected}")
# Register the DemoPL config under the "model/gvhmr" group with the name "gvhmr_pl_demo".
MainStore.store(name="gvhmr_pl_demo", node=builds(DemoPL, pipeline="${pipeline}"), group="model/gvhmr")
================================================
FILE: eval/GVHMR/hmr4d/model/gvhmr/pipeline/gvhmr_pipeline.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import numpy as np
from einops import einsum, rearrange, repeat
from hydra.utils import instantiate
from hmr4d.utils.pylogger import Log
from hmr4d.utils.net_utils import gaussian_smooth
from hmr4d.model.gvhmr.utils.endecoder import EnDecoder
from hmr4d.model.gvhmr.utils.postprocess import (
pp_static_joint,
process_ik,
pp_static_joint_cam,
)
from hmr4d.model.gvhmr.utils import stats_compose
from pytorch3d.transforms import (
matrix_to_rotation_6d,
rotation_6d_to_matrix,
axis_angle_to_matrix,
matrix_to_axis_angle,
)
from hmr4d.utils.geo.hmr_cam import compute_bbox_info_bedlam, compute_transl_full_cam, get_a_pred_cam, project_to_bi01
from hmr4d.utils.geo.hmr_global import (
rollout_local_transl_vel,
get_static_joint_mask,
get_tgtcoord_rootparam,
)
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.utils.smplx_utils import make_smplx
class Pipeline(nn.Module):
    """Wraps the denoiser network, the feature en/decoder and the training losses."""

    def __init__(self, args, args_denoiser3d, **kwargs):
        super().__init__()
        self.args = args
        self.weights = args.weights  # loss weights

        # Networks
        self.denoiser3d = instantiate(args_denoiser3d, _recursive_=False)

        # Normalizer
        self.endecoder: EnDecoder = instantiate(args.endecoder_opt, _recursive_=False)
        if self.args.normalize_cam_angvel:
            manual_stats = stats_compose.cam_angvel["manual"]
            # Non-persistent buffers: they follow the module's device but are not checkpointed.
            for stat_key in ("mean", "std"):
                self.register_buffer(f"cam_angvel_{stat_key}", torch.tensor(manual_stats[stat_key]), persistent=False)

    # ========== Training ========== #
    def forward(self, inputs, train=False, postproc=False, static_cam=False):
        """Run the denoiser and either return predictions (eval) or losses (train).

        Args:
            inputs: batch dict; reads "length", "bbx_xys", "K_fullimg",
                "cam_angvel", "obs" and "f_imgseq", plus "mask" and the ground
                truth consumed by the loss functions when ``train`` is True.
            train: when True, conditions are randomly nulled and losses are computed.
            postproc: eval only; apply static-joint and IK post-processing.
            static_cam: eval only; use the static-camera variant of the
                translation post-processing.

        Returns:
            dict of outputs. In eval mode it contains the in-camera and global
            SMPL parameters; in train mode it contains "loss" and the
            individual loss terms.
        """
        outputs = dict()
        seq_len = inputs["length"]  # (B,) effective length of each sample

        # *. Conditions
        bbox_info = compute_bbox_info_bedlam(inputs["bbx_xys"], inputs["K_fullimg"])  # (B, L, 3)
        cam_angvel_feat = inputs["cam_angvel"]
        if self.args.normalize_cam_angvel:
            cam_angvel_feat = (cam_angvel_feat - self.cam_angvel_mean) / self.cam_angvel_std
        conditions = {
            "obs": inputs["obs"],  # (B, L, J, 3)
            "f_cliffcam": bbox_info,  # (B, L, 3)
            "f_cam_angvel": cam_angvel_feat,  # (B, L, C=6)
            "f_imgseq": inputs["f_imgseq"],  # (B, L, C=1024)
        }
        if train:
            conditions = randomly_set_null_condition(conditions, 0.1)

        # Forward & output
        model_output = self.denoiser3d(length=seq_len, **conditions)  # pred_x, pred_cam, static_conf_logits
        decode_dict = self.endecoder.decode(model_output["pred_x"])  # (B, L, C) -> dict
        outputs.update({"model_output": model_output, "decode_dict": decode_dict})

        # Post-processing
        transl_incam = compute_transl_full_cam(model_output["pred_cam"], inputs["bbx_xys"], inputs["K_fullimg"])
        outputs["pred_smpl_params_incam"] = {
            "body_pose": decode_dict["body_pose"],  # (B, L, 63)
            "betas": decode_dict["betas"],  # (B, L, 10)
            "global_orient": decode_dict["global_orient"],  # (B, L, 3)
            "transl": transl_incam,
        }

        if not train:
            global_Rt = get_smpl_params_w_Rt_v2(  # This function has for-loop
                global_orient_gv=decode_dict["global_orient_gv"],
                local_transl_vel=decode_dict["local_transl_vel"],
                global_orient_c=decode_dict["global_orient"],
                cam_angvel=inputs["cam_angvel"],
            )
            outputs["pred_smpl_params_global"] = {
                "body_pose": decode_dict["body_pose"],
                "betas": decode_dict["betas"],
                **global_Rt,
            }
            outputs["static_conf_logits"] = model_output["static_conf_logits"]

            if postproc:  # apply post-processing
                # static_cam selects the variant that also uses a static-camera prior
                refine_transl = pp_static_joint_cam if static_cam else pp_static_joint
                outputs["pred_smpl_params_global"]["transl"] = refine_transl(outputs, self.endecoder)

                ik_body_pose = process_ik(outputs, self.endecoder)
                decode_dict["body_pose"] = ik_body_pose
                outputs["pred_smpl_params_global"]["body_pose"] = ik_body_pose
                outputs["pred_smpl_params_incam"]["body_pose"] = ik_body_pose

            return outputs

        # ========== Compute Loss ========== #
        # Start from the int 0 so that the first `+=` creates a fresh tensor and
        # later in-place additions never alias an individual loss term.
        total_loss = 0
        valid = inputs["mask"]["valid"]  # (B, L)

        # 1. Simple loss: MSE
        pred_x = model_output["pred_x"]  # (B, L, C)
        target_x = self.endecoder.encode(inputs)  # (B, L, C)
        elementwise = F.mse_loss(pred_x, target_x, reduction="none")
        channel_mask = valid[:, :, None].expand(-1, -1, pred_x.size(2)).clone()  # (B, L, C)
        channel_mask[inputs["mask"]["spv_incam_only"], :, 142:] = False  # 3dpw training
        simple_loss = (elementwise * channel_mask).mean()
        total_loss += simple_loss
        outputs["simple_loss"] = simple_loss

        # 2. Extra loss
        for extra_func in (compute_extra_incam_loss, compute_extra_global_loss):
            extra_loss, extra_loss_dict = extra_func(inputs, outputs, self)
            total_loss += extra_loss
            outputs.update(extra_loss_dict)

        outputs["loss"] = total_loss
        return outputs
def randomly_set_null_condition(f_condition, uncond_prob=0.1):
    """Randomly zero out conditions per (batch, frame) position.

    Conditions are in shape (B, L, *). Every non-None entry is replaced by a
    clone in which each leading (B, L) position is zeroed independently with
    probability ``uncond_prob``. The input dict is updated and returned; the
    original tensors are not modified.
    """
    for name, value in list(f_condition.items()):
        if value is None:
            continue
        nulled = value.clone()
        drop_mask = torch.rand(value.shape[:2]) < uncond_prob
        nulled[drop_mask] = 0.0
        f_condition[name] = nulled
    return f_condition
def compute_extra_incam_loss(inputs, outputs, ppl):
    """Compute the auxiliary in-camera losses selected by ``ppl.weights``.

    Each term is enabled when its weight is positive: root-aligned joints
    (``cr_j3d``), camera parameters (``transl_c``), joint reprojection
    (``j2d``), root-aligned vertices (``cr_verts``) and vertex reprojection
    (``verts2d``).

    Args:
        inputs: batch dict with ground truth and masks.
        outputs: pipeline outputs; reads "model_output" and "pred_smpl_params_incam".
        ppl: the pipeline; provides ``endecoder`` and ``weights``.

    Returns:
        (extra_loss, extra_loss_dict): the weighted sum of the enabled terms
        (the int 0 when none is enabled) and the unweighted individual terms.
    """
    model_output = outputs["model_output"]
    endecoder = ppl.endecoder
    weights = ppl.weights

    extra_loss_dict = {}
    extra_loss = 0
    mask = inputs["mask"]["valid"]  # effective length mask
    mask_reproj = ~inputs["mask"]["spv_incam_only"]  # do not supervise reproj for 3DPW

    # Incam FK
    # prediction
    pred_c_j3d = endecoder.fk_v2(**outputs["pred_smpl_params_incam"])
    pred_cr_j3d = pred_c_j3d - pred_c_j3d[:, :, :1]  # (B, L, J, 3)

    # gt
    gt_c_j3d = endecoder.fk_v2(**inputs["smpl_params_c"])  # (B, L, J, 3)
    gt_cr_j3d = gt_c_j3d - gt_c_j3d[:, :, :1]  # (B, L, J, 3)

    # Root aligned C-MPJPE Loss
    if weights.cr_j3d > 0.0:
        cr_j3d_loss = F.mse_loss(pred_cr_j3d, gt_cr_j3d, reduction="none")
        cr_j3d_loss = (cr_j3d_loss * mask[..., None, None]).mean()
        extra_loss += cr_j3d_loss * weights.cr_j3d
        extra_loss_dict["cr_j3d_loss"] = cr_j3d_loss

    # Reprojection (to align with image)
    if weights.transl_c > 0.0:
        # Instead of supervising transl, we convert gt to pred_cam (prevent divide 0)
        pred_cam = model_output["pred_cam"]  # (B, L, 3)
        gt_transl = inputs["smpl_params_c"]["transl"]  # (B, L, 3)
        gt_pred_cam = get_a_pred_cam(gt_transl, inputs["bbx_xys"], inputs["K_fullimg"])  # (B, L, 3)
        gt_pred_cam[gt_pred_cam.isinf()] = -1  # this will be handled by valid_mask

        # Skip gts that are not good during random construction
        gt_j3d_z_min = inputs["gt_j3d"][..., 2].min(dim=-1)[0]
        valid_mask = (
            (gt_j3d_z_min > 0.3)
            * (gt_pred_cam[..., 0] > 0.3)
            * (gt_pred_cam[..., 0] < 5.0)
            * (gt_pred_cam[..., 1] > -3.0)
            * (gt_pred_cam[..., 1] < 3.0)
            * (gt_pred_cam[..., 2] > -3.0)
            * (gt_pred_cam[..., 2] < 3.0)
            * (inputs["bbx_xys"][..., 2] > 0)
        )[..., None]
        transl_c_loss = F.mse_loss(pred_cam, gt_pred_cam, reduction="none")
        transl_c_loss = (transl_c_loss * mask[..., None] * valid_mask).mean()

        extra_loss_dict["transl_c_loss"] = transl_c_loss
        extra_loss += transl_c_loss * weights.transl_c

    if weights.j2d > 0.0:
        # prevent divide 0 or small value to overflow(fp16)
        # Points too close to the image plane are overwritten in place; they are
        # excluded again by valid_mask below.
        reproj_z_thr = 0.3
        pred_c_j3d_z0_mask = pred_c_j3d[..., 2].abs() <= reproj_z_thr
        pred_c_j3d[pred_c_j3d_z0_mask] = reproj_z_thr
        gt_c_j3d_z0_mask = gt_c_j3d[..., 2].abs() <= reproj_z_thr
        gt_c_j3d[gt_c_j3d_z0_mask] = reproj_z_thr

        pred_j2d_01 = project_to_bi01(pred_c_j3d, inputs["bbx_xys"], inputs["K_fullimg"])
        gt_j2d_01 = project_to_bi01(gt_c_j3d, inputs["bbx_xys"], inputs["K_fullimg"])  # (B, L, J, 2)

        valid_mask = (
            (gt_c_j3d[..., 2] > reproj_z_thr)
            * (pred_c_j3d[..., 2] > reproj_z_thr)  # Be safe
            * (gt_j2d_01[..., 0] > 0.0)
            * (gt_j2d_01[..., 0] < 1.0)
            * (gt_j2d_01[..., 1] > 0.0)
            * (gt_j2d_01[..., 1] < 1.0)
        )[..., None]
        valid_mask[~mask_reproj] = False  # Do not supervise on 3dpw
        j2d_loss = F.mse_loss(pred_j2d_01, gt_j2d_01, reduction="none")
        j2d_loss = (j2d_loss * mask[..., None, None] * valid_mask).mean()

        extra_loss += j2d_loss * weights.j2d
        extra_loss_dict["j2d_loss"] = j2d_loss

    # SMPL forward: both vertex losses consume the predicted vertices, so run it
    # whenever either is enabled (previously enabling only `verts2d` hit an
    # unbound `pred_c_verts437`).
    if weights.cr_verts > 0 or weights.verts2d > 0:
        pred_c_verts437, pred_c_j17 = endecoder.smplx_model(**outputs["pred_smpl_params_incam"])

    if weights.cr_verts > 0:
        root_ = pred_c_j17[:, :, [11, 12], :].mean(-2, keepdim=True)
        pred_cr_verts437 = pred_c_verts437 - root_

        gt_cr_verts437 = inputs["gt_cr_verts437"]  # (B, L, 437, 3)
        cr_vert_loss = F.mse_loss(pred_cr_verts437, gt_cr_verts437, reduction="none")
        cr_vert_loss = (cr_vert_loss * mask[:, :, None, None]).mean()
        extra_loss += cr_vert_loss * weights.cr_verts
        extra_loss_dict["cr_vert_loss"] = cr_vert_loss

    if weights.verts2d > 0:
        gt_c_verts437 = inputs["gt_c_verts437"]  # (B, L, 437, 3)

        # prevent divide 0 or small value to overflow(fp16)
        reproj_z_thr = 0.3
        pred_c_verts437_z0_mask = pred_c_verts437[..., 2].abs() <= reproj_z_thr
        pred_c_verts437[pred_c_verts437_z0_mask] = reproj_z_thr
        gt_c_verts437_z0_mask = gt_c_verts437[..., 2].abs() <= reproj_z_thr
        gt_c_verts437[gt_c_verts437_z0_mask] = reproj_z_thr

        pred_verts2d_01 = project_to_bi01(pred_c_verts437, inputs["bbx_xys"], inputs["K_fullimg"])
        gt_verts2d_01 = project_to_bi01(gt_c_verts437, inputs["bbx_xys"], inputs["K_fullimg"])  # (B, L, 437, 2)

        valid_mask = (
            (gt_c_verts437[..., 2] > reproj_z_thr)
            * (pred_c_verts437[..., 2] > reproj_z_thr)  # Be safe
            * (gt_verts2d_01[..., 0] > 0.0)
            * (gt_verts2d_01[..., 0] < 1.0)
            * (gt_verts2d_01[..., 1] > 0.0)
            * (gt_verts2d_01[..., 1] < 1.0)
        )[..., None]
        valid_mask[~mask_reproj] = False  # Do not supervise on 3dpw
        verts2d_loss = F.mse_loss(pred_verts2d_01, gt_verts2d_01, reduction="none")
        verts2d_loss = (verts2d_loss * mask[..., None, None] * valid_mask).mean()

        extra_loss += verts2d_loss * weights.verts2d
        extra_loss_dict["verts2d_loss"] = verts2d_loss

    return extra_loss, extra_loss_dict
def compute_extra_global_loss(inputs, outputs, ppl):
    """Compute the auxiliary global-frame losses selected by ``ppl.weights``.

    Terms: rolled-out world translation (``transl_w``, L1) and static-joint
    confidence (``static_conf_bce``, BCE with logits). Samples flagged
    ``spv_incam_only`` are excluded from both.

    Returns:
        (extra_loss, extra_loss_dict): the weighted sum of the enabled terms
        (the int 0 when none is enabled) and the unweighted individual terms.
    """
    decode_dict = outputs["decode_dict"]
    endecoder, weights, args = ppl.endecoder, ppl.weights, ppl.args

    loss_terms = {}
    extra_loss = 0
    # Clone before editing so the shared batch mask stays intact.
    valid = inputs["mask"]["valid"].clone()  # (B, L)
    valid[inputs["mask"]["spv_incam_only"]] = False

    if weights.transl_w > 0:
        # compute pred_transl_w by rollout
        smpl_w = inputs["smpl_params_w"]
        gt_transl_w = smpl_w["transl"]
        gt_orient_w = smpl_w["global_orient"]
        pred_transl_w = rollout_local_transl_vel(
            decode_dict["local_transl_vel"], gt_orient_w, gt_transl_w[:, [0]]
        )
        transl_err = F.l1_loss(pred_transl_w, gt_transl_w, reduction="none")
        transl_err = (transl_err * valid[..., None]).mean()
        extra_loss += transl_err * weights.transl_w
        loss_terms["transl_w_loss"] = transl_err

    # Static-Conf loss
    if weights.static_conf_bce > 0:
        # Compute gt by thresholding velocity
        vel_thr = args.static_conf.vel_thr
        assert vel_thr > 0
        joint_ids = [7, 10, 8, 11, 20, 21]  # [L_Ankle, L_foot, R_Ankle, R_foot, L_wrist, R_wrist]
        gt_w_j3d = endecoder.fk_v2(**inputs["smpl_params_w"])  # (B, L, J=22, 3)
        static_target = get_static_joint_mask(gt_w_j3d, vel_thr=vel_thr, repeat_last=True)  # (B, L, J)
        static_target = static_target[:, :, joint_ids].float()  # (B, L, J')
        logits = outputs["model_output"]["static_conf_logits"]
        bce = F.binary_cross_entropy_with_logits(logits, static_target, reduction="none")
        bce = (bce * valid[..., None]).mean()
        extra_loss += bce * weights.static_conf_bce
        loss_terms["static_conf_loss"] = bce

    return extra_loss, loss_terms
@autocast(enabled=False)
def get_smpl_params_w_Rt_v2(
    global_orient_gv,
    local_transl_vel,
    global_orient_c,
    cam_angvel,
):
    """Get global R,t in GV0(ay)

    Args:
        global_orient_gv: root orientation in the GV frame, axis-angle.
        local_transl_vel: local translation velocity, rolled out into a trajectory.
        global_orient_c: root orientation in the camera frame, axis-angle.
        cam_angvel: (B, L, 6), defined as R @ R_{w2c}^{t} = R_{w2c}^{t+1}

    Returns:
        dict with "global_orient" and "transl".
    """

    def snap_to_identity(rot):
        """Overwrite (in place) rotations whose axis-angle norm is below 1e-5 with I."""
        near_identity = matrix_to_axis_angle(rot).norm(dim=-1) < 1e-5
        rot[near_identity] = torch.eye(3).to(rot)
        return rot

    def horizontal_dir(axis):
        """Zero the y component of a copy of `axis` and renormalize."""
        flat = axis.clone()
        flat[..., 1] = 0
        return F.normalize(flat, dim=-1)

    batch_size = cam_angvel.shape[0]

    # Get R_ct_to_c0 from cam_angvel
    R_t_to_tp1 = snap_to_identity(rotation_6d_to_matrix(cam_angvel))  # (B, L, 3, 3)

    # Get R_c2gv
    R_gv = axis_angle_to_matrix(global_orient_gv)  # (B, L, 3, 3)
    R_c = axis_angle_to_matrix(global_orient_c)  # (B, L, 3, 3)
    R_c2gv = R_gv @ R_c.mT

    # Camera view direction in GV coordinate: Rc2gv @ [0,0,1]
    # (Rc2gv is estimated, so the x-axis is not accurate, i.e. != 0)
    view_now = R_c2gv[:, :, :, 2]  # (B, L, 3)
    # Rotate axis use camera relative rotation
    view_next = (R_c2gv @ R_t_to_tp1.mT)[..., 2]

    dir_now = horizontal_dir(view_now)
    dir_next = horizontal_dir(view_next)

    # Axis-angle between the two horizontal directions.
    rot_axis = torch.cross(dir_next, dir_now, dim=-1)
    rot_angle = torch.acos(torch.clamp((dir_now * dir_next).sum(dim=-1, keepdim=True), -1.0, 1.0))
    aa_tp1_to_t = F.normalize(rot_axis, dim=-1) * rot_angle
    aa_tp1_to_t = gaussian_smooth(aa_tp1_to_t, dim=-2)  # Smooth
    R_tp1_to_t = axis_angle_to_matrix(aa_tp1_to_t).mT  # (B, L, 3, 3)

    # Get R_t_to_0 by chaining the per-frame rotations, starting from identity.
    chain = [torch.eye(3).to(R_t_to_tp1).expand(batch_size, 3, 3)]
    for frame in range(1, R_t_to_tp1.shape[1]):
        chain.append(chain[-1] @ R_tp1_to_t[:, frame])
    R_t_to_0 = snap_to_identity(torch.stack(chain, dim=1))  # (B, L, 3, 3)

    global_orient = matrix_to_axis_angle(R_t_to_0 @ R_gv)

    # Rollout to global transl
    # Start from transl0, in gv0 -> flip y-axis of gv0
    transl = rollout_local_transl_vel(local_transl_vel, global_orient)
    global_orient, transl, _ = get_tgtcoord_rootparam(global_orient, transl, tsf="any->ay")

    return {"global_orient": global_orient, "transl": transl}
================================================
FILE: eval/GVHMR/hmr4d/model/gvhmr/utils/endecoder.py
================================================
import torch
import torch.nn as nn
from pytorch3d.transforms import (
rotation_6d_to_matrix,
matrix_to_axis_angle,
axis_angle_to_matrix,
matrix_to_rotation_6d,
matrix_to_quaternion,
quaternion_to_matrix,
)
from hmr4d.configs import MainStore, builds
from hmr4d.utils.geo.augment_noisy_pose import gaussian_augment
import hmr4d.utils.matrix as matrix
from hmr4d.utils.pylogger import Log
from hmr4d.utils.geo.hmr_global import get_local_transl_vel, rollout_local_transl_vel
from hmr4d.utils.smplx_utils import make_smplx
from . import stats_compose
class EnDecoder(nn.Module):
    """Encodes SMPL parameters into a normalized 151-d feature vector and back.

    Feature layout (last dimension):
        0:126   body_pose_r6d          ((J-1)*6)
        126:136 betas
        136:142 global_orient_r6d      (in-camera)
        142:148 global_orient_gv_r6d   (gv)
        148:151 local_transl_vel       (smpl-coord)
    """

    def __init__(self, stats_name="DEFAULT_01", noise_pose_k=10):
        super().__init__()
        # Load mean, std
        stats = getattr(stats_compose, stats_name)
        Log.info(f"[EnDecoder] Use {stats_name} for statistics!")
        # Third positional argument is `persistent`: the buffers are not checkpointed.
        self.register_buffer("mean", torch.tensor(stats["mean"]).float(), False)
        self.register_buffer("std", torch.tensor(stats["std"]).float(), False)

        # option
        self.noise_pose_k = noise_pose_k

        # smpl
        self.smplx_model = make_smplx("supermotion_v437coco17")
        parents = self.smplx_model.parents[:22]
        self.register_buffer("parents_tensor", parents, False)
        self.parents = parents.tolist()

    def get_noisyobs(self, data, return_type="r6d"):
        """
        Noisy observation contains local pose with noise
        Args:
            data (dict):
                body_pose: (B, L, J*3) or (B, L, J, 3)
            return_type: one of "R", "r6d", "aa"; selects which element of the
                tuple returned by ``gaussian_augment`` is used.
        Returns:
            noisy_body_pose: the selected representation of the noised pose.
        """
        body_pose = data["body_pose"]
        # Only the two leading dims are needed; this accepts both documented layouts.
        B, L = body_pose.shape[:2]
        body_pose = body_pose.reshape(B, L, -1, 3)  # (B, L, J, 3)

        return_mapping = {"R": 0, "r6d": 1, "aa": 2}
        return_id = return_mapping[return_type]
        noisy_body_pose = gaussian_augment(body_pose, self.noise_pose_k, to_R=True)[return_id]
        return noisy_body_pose

    def normalize_body_pose_r6d(self, body_pose_r6d):
        """body_pose_r6d: (B, L, {J*6}/{J, 6}) -> (B, L, J*6)"""
        B, L = body_pose_r6d.shape[:2]
        body_pose_r6d = body_pose_r6d.reshape(B, L, -1)
        if self.mean.shape[-1] == 1:  # no mean, std provided
            return body_pose_r6d
        body_pose_r6d = (body_pose_r6d - self.mean[:126]) / self.std[:126]  # (B, L, C)
        return body_pose_r6d

    def fk_v2(self, body_pose, betas, global_orient=None, transl=None, get_intermediate=False):
        """
        Args:
            body_pose: (B, L, 63)
            betas: (B, L, 10)
            global_orient: (B, L, 3); zeros are used when None
            transl: optional root translation added to the root joint
            get_intermediate: also return the local and FK transform matrices
        Returns:
            joints: (B, L, 22, 3), or (joints, mat, fk_mat) when get_intermediate
        """
        B, L = body_pose.shape[:2]
        if global_orient is None:
            global_orient = torch.zeros((B, L, 3), device=body_pose.device)
        aa = torch.cat([global_orient, body_pose], dim=-1).reshape(B, L, -1, 3)
        rotmat = axis_angle_to_matrix(aa)  # (B, L, 22, 3, 3)

        # get_local_pos returns a freshly concatenated tensor, so the in-place
        # translation below does not touch the skeleton it was built from.
        local_skeleton = self.get_local_pos(betas)  # (B, L, 22, 3)
        if transl is not None:
            local_skeleton[..., 0, :] += transl  # B, L, 22, 3

        mat = matrix.get_TRS(rotmat, local_skeleton)  # B, L, 22, 4, 4
        fk_mat = matrix.forward_kinematics(mat, self.parents)  # B, L, 22, 4, 4
        joints = matrix.get_position(fk_mat)  # B, L, 22, 3
        if not get_intermediate:
            return joints
        else:
            return joints, mat, fk_mat

    def get_local_pos(self, betas):
        """Return parent-relative joint offsets; the root keeps its absolute position."""
        skeleton = self.smplx_model.get_skeleton(betas)[..., :22, :]  # (B, L, 22, 3)
        local_skeleton = skeleton - skeleton[:, :, self.parents_tensor]
        local_skeleton = torch.cat([skeleton[:, :, :1], local_skeleton[:, :, 1:]], dim=2)
        return local_skeleton

    def encode(self, inputs):
        """
        definition: {
                body_pose_r6d,  # (B, L, (J-1)*6) -> 0:126
                betas, # (B, L, 10) -> 126:136
                global_orient_r6d,  # (B, L, 6) -> 136:142  incam
                global_orient_gv_r6d: # (B, L, 6) -> 142:148  gv
                local_transl_vel,  # (B, L, 3) -> 148:151, smpl-coord
            }
        Reads "smpl_params_c", "R_c2gv" and "smpl_params_w" from `inputs` and
        returns the normalized concatenation.
        """
        B, L = inputs["smpl_params_c"]["body_pose"].shape[:2]
        # cam
        smpl_params_c = inputs["smpl_params_c"]
        body_pose = smpl_params_c["body_pose"].reshape(B, L, 21, 3)
        body_pose_r6d = matrix_to_rotation_6d(axis_angle_to_matrix(body_pose)).flatten(-2)
        betas = smpl_params_c["betas"]
        global_orient_R = axis_angle_to_matrix(smpl_params_c["global_orient"])
        global_orient_r6d = matrix_to_rotation_6d(global_orient_R)

        # global
        R_c2gv = inputs["R_c2gv"]  # (B, L, 3, 3)
        global_orient_gv_r6d = matrix_to_rotation_6d(R_c2gv @ global_orient_R)

        # local_transl_vel
        smpl_params_w = inputs["smpl_params_w"]
        local_transl_vel = get_local_transl_vel(smpl_params_w["transl"], smpl_params_w["global_orient"])

        # returns
        x = torch.cat([body_pose_r6d, betas, global_orient_r6d, global_orient_gv_r6d, local_transl_vel], dim=-1)
        x_norm = (x - self.mean) / self.std
        return x_norm

    def encode_translw(self, inputs):
        """Encode only the local translation velocity.

        Computes local_transl_vel from ``inputs["smpl_params_w"]`` and
        normalizes it with the last three entries of mean/std.
        """
        # local_transl_vel
        smpl_params_w = inputs["smpl_params_w"]
        local_transl_vel = get_local_transl_vel(smpl_params_w["transl"], smpl_params_w["global_orient"])

        # returns
        x = local_transl_vel
        x_norm = (x - self.mean[-3:]) / self.std[-3:]
        return x_norm

    def decode_translw(self, x_norm):
        """Inverse of the normalization applied by `encode_translw`."""
        return x_norm * self.std[-3:] + self.mean[-3:]

    def decode(self, x_norm):
        """x_norm: (B, L, C) -> dict of de-normalized SMPL parameters (axis-angle rotations)."""
        B, L, _ = x_norm.shape
        x = (x_norm * self.std) + self.mean

        body_pose_r6d = x[:, :, :126]
        betas = x[:, :, 126:136]
        global_orient_r6d = x[:, :, 136:142]
        global_orient_gv_r6d = x[:, :, 142:148]
        local_transl_vel = x[:, :, 148:151]

        body_pose = matrix_to_axis_angle(rotation_6d_to_matrix(body_pose_r6d.reshape(B, L, -1, 6)))
        body_pose = body_pose.flatten(-2)
        global_orient_c = matrix_to_axis_angle(rotation_6d_to_matrix(global_orient_r6d))
        global_orient_gv = matrix_to_axis_angle(rotation_6d_to_matrix(global_orient_gv_r6d))

        output = {
            "body_pose": body_pose,
            "betas": betas,
            "global_orient": global_orient_c,
            "global_orient_gv": global_orient_gv,
            "local_transl_vel": local_transl_vel,
        }
        return output
# Register EnDecoder config variants under one group; the variants differ only
# in the `stats_name` they pass (the first one keeps the constructor default).
group_name = "endecoder/gvhmr"
cfg_base = builds(EnDecoder, populate_full_signature=True)
MainStore.store(name="v1_no_stdmean", node=cfg_base, group=group_name)
MainStore.store(name="v1", node=cfg_base(stats_name="MM_V1"), group=group_name)
MainStore.store(
name="v1_amass_local_bedlam_cam",
node=cfg_base(stats_name="MM_V1_AMASS_LOCAL_BEDLAM_CAM"),
group=group_name,
)
MainStore.store(name="v2", node=cfg_base(stats_name="MM_V2"), group=group_name)
MainStore.store(name="v2_1", node=cfg_base(stats_name="MM_V2_1"), group=group_name)
================================================
FILE: eval/GVHMR/hmr4d/model/gvhmr/utils/postprocess.py
================================================
import torch
from torch.cuda.amp import autocast
from pytorch3d.transforms import (
matrix_to_rotation_6d,
rotation_6d_to_matrix,
axis_angle_to_matrix,
matrix_to_axis_angle,
)
import hmr4d.utils.matrix as matrix
from hmr4d.utils.ik.ccd_ik import CCD_IK
from hmr4d.utils.geo_transform import get_sequence_cammat, transform_mat, apply_T_on_points
from hmr4d.utils.net_utils import gaussian_smooth
from hmr4d.model.gvhmr.utils.endecoder import EnDecoder
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
@autocast(enabled=False)
def pp_static_joint(outputs, endecoder: "EnDecoder"):
    """Refine the global translation so that joints predicted as static stay put.

    Args:
        outputs: dict with "pred_smpl_params_global" (keyword arguments for
            ``endecoder.fk_v2``, including "transl" of shape (B, L, 3)) and
            "static_conf_logits" of shape (B, L, J') for the six joints below.
        endecoder: provides ``fk_v2`` for forward kinematics.

    Returns:
        post_w_transl: (B, L, 3) corrected translation, shifted so that the
        lowest joint of each sequence sits at y = 0.
    """
    # Global FK
    pred_w_j3d = endecoder.fk_v2(**outputs["pred_smpl_params_global"])

    joint_ids = [7, 10, 8, 11, 20, 21]  # [L_Ankle, L_foot, R_Ankle, R_foot, L_wrist, R_wrist]
    pred_j3d_static = pred_w_j3d.clone()[:, :, joint_ids]  # (B, L, J, 3)

    ######## update overall movement with static info, and make displacement ~[0,0,0]
    pred_j_disp = pred_j3d_static[:, 1:] - pred_j3d_static[:, :-1]  # (B, L-1, J, 3)

    static_conf_logits = outputs["static_conf_logits"][:, :-1].clone()
    static_label_ = static_conf_logits > 0  # (B, L-1, J) # avoid non-contact frame
    # Push non-static joints to a very negative logit so softmax ignores them.
    static_conf_logits = static_conf_logits.float() - (~static_label_ * 1e6)  # fp16 cannot go through softmax
    is_static = static_label_.sum(dim=-1) > 0  # (B, L-1)

    pred_disp = pred_j_disp * static_conf_logits[..., None].softmax(dim=-2)  # (B, L-1, J, 3)
    pred_disp = pred_disp * is_static[..., None, None]  # (B, L-1, J, 3)
    pred_disp = pred_disp.sum(-2)  # (B, L-1, 3)
    ####################

    # Overwrite results: remove the static-joint drift from the per-frame
    # displacement and re-integrate (vectorized form of subtracting
    # pred_disp[:, i - 1] from every frame >= i).
    pred_w_transl = outputs["pred_smpl_params_global"]["transl"].clone()  # (B, L, 3)
    pred_w_disp = pred_w_transl[:, 1:] - pred_w_transl[:, :-1]  # (B, L-1, 3)
    pred_w_disp_new = pred_w_disp - pred_disp
    post_w_transl = torch.cumsum(torch.cat([pred_w_transl[:, :1], pred_w_disp_new], dim=1), dim=1)
    post_w_transl[..., 0] = gaussian_smooth(post_w_transl[..., 0], dim=-1)
    post_w_transl[..., 2] = gaussian_smooth(post_w_transl[..., 2], dim=-1)

    # Put the sequence on the ground by -min(y), this does not consider foot height, for o3d vis
    post_w_j3d = pred_w_j3d - pred_w_transl.unsqueeze(-2) + post_w_transl.unsqueeze(-2)
    ground_y = post_w_j3d[..., 1].flatten(-2).min(dim=-1)[0]  # (B,)  Minimum y value
    # Explicit frame axis: (B, L) -= (B, 1). Without it the (B,) vector would be
    # broadcast along the frame axis, which is only correct for B == 1.
    post_w_transl[..., 1] -= ground_y[:, None]

    return post_w_transl
@autocast(enabled=False)
def pp_static_joint_cam(outputs, endecoder: EnDecoder):
    """Use static joint and static camera assumption to postprocess the global transl.

    Two corrections are applied in sequence:
      1. The in-camera joints are mapped to the world frame with a single camera-to-world
         transform estimated from frame 0. Wherever joint 0 of the global prediction is
         more than 0.25 away from that reference on an axis, the global track is nudged
         towards it by at most 0.02 per frame.
      2. The mean displacement of joints whose static probability exceeds 0.8 is removed
         from the translation (x and z only; y is left untouched).
    Finally the sequence is shifted so that its lowest joint sits at y = 0.

    Args:
        outputs: dict that provides
            ``"pred_smpl_params_incam"`` and ``"pred_smpl_params_global"``: kwargs accepted
                by ``endecoder.fk_v2``; both must contain ``"transl"`` and ``"global_orient"``.
            ``"static_conf_logits"``: static logits, (B, L, J) with J matching ``joint_ids``.
        endecoder: object exposing ``fk_v2(**smpl_params)`` that returns joints (B, L, J, 3).

    Returns:
        The corrected global translation, shape (B, L, 3).

    Raises:
        AssertionError: if the batch size is not 1.
    """
    # input
    # Shallow copy: only the "transl" entry is rebound below, the caller's dict is untouched.
    pred_smpl_params_incam = outputs["pred_smpl_params_incam"].copy()
    pred_smpl_params_global = outputs["pred_smpl_params_global"]
    static_conf_logits = outputs["static_conf_logits"][:, :-1]  # (B, L-1, J)
    joint_ids = [7, 10, 8, 11, 20, 21]  # [L_Ankle, L_foot, R_Ankle, R_foot, L_wrist, R_wrist]
    B, L = pred_smpl_params_incam["transl"].shape[:2]
    assert B == 1

    # FK
    pred_w_j3d = endecoder.fk_v2(**pred_smpl_params_global)  # (B, L, J, 3)
    # smooth incam results, as this could be noisy
    pred_smpl_params_incam["transl"] = gaussian_smooth(pred_smpl_params_incam["transl"], sigma=5, dim=-2)
    pred_c_j3d = endecoder.fk_v2(**pred_smpl_params_incam)  # (B, L, J, 3)

    # compute T_c2w (static) from first frame
    R_gv = axis_angle_to_matrix(pred_smpl_params_global["global_orient"][:, 0])  # (B, 3, 3)
    R_c = axis_angle_to_matrix(pred_smpl_params_incam["global_orient"][:, 0])  # (B, 3, 3)
    R_c2w = R_gv @ R_c.mT  # (B, 3, 3)
    t_c2w = pred_w_j3d[:, 0, 0] - torch.einsum("bij,bj->bi", R_c2w, pred_c_j3d[:, 0, 0])  # (B, 3)
    T_c2w = transform_mat(R_c2w, t_c2w)  # (B, 4, 4)
    pred_c_j3d_in_w = apply_T_on_points(pred_c_j3d, T_c2w[:, None])

    # 1. Make transl similar to incam
    # This loop is inherently sequential: the correction applied at frame i shifts every
    # later frame, which changes the difference measured at frame i + 1.
    post_w_transl = pred_smpl_params_global["transl"].clone()  # (B, L, 3)
    post_w_j3d = pred_w_j3d.clone()  # (B, L, J, 3)
    cp_thr = torch.tensor([0.25, 0.25, 0.25]).to(post_w_j3d)  # Only update very bad pred
    for i in range(1, L):
        cp_diff = post_w_j3d[:, i, 0] - pred_c_j3d_in_w[:, i, 0]  # (B, 3)
        # Keep only the components whose magnitude reaches the threshold ...
        cp_diff = cp_diff * ~((cp_diff > -cp_thr) * (cp_diff < cp_thr))
        # ... and limit the per-frame correction step.
        cp_diff = torch.clamp(cp_diff, -0.02, 0.02)
        post_w_transl[:, i:] -= cp_diff[:, None]
        post_w_j3d[:, i:] -= cp_diff[:, None, None]

    # 2. Make stationary joint stay stationary
    pred_j3d_static = post_w_j3d[:, :, joint_ids]  # (B, L, J, 3)
    pred_j_disp = pred_j3d_static[:, 1:] - pred_j3d_static[:, :-1]  # (B, L-1, J, 3)

    static_label = static_conf_logits.sigmoid() > 0.8  # (B, L-1, J)
    static_label_sumJ = static_label.sum(-1, keepdim=True)  # (B, L-1, 1)
    static_label_sumJ = torch.clamp_min(static_label_sumJ, 1)  # replace 0 with 1
    pred_disp_sumJ = (pred_j_disp * static_label[..., None]).sum(-2)  # (B, L-1, 3)
    pred_disp = pred_disp_sumJ / static_label_sumJ  # (B, L-1, 3)
    pred_disp[:, :, 1] = 0  # do not modify y

    # Overwrite results (vectorized). The displacements are fixed at this point, so frame i
    # simply loses the running sum of pred_disp[:, :i]; one cumsum replaces the former
    # per-frame suffix update.
    cum_disp = torch.cumsum(pred_disp, dim=1)  # (B, L-1, 3)
    post_w_transl[:, 1:] -= cum_disp
    post_w_j3d[:, 1:] -= cum_disp[:, :, None]

    # Put the sequence on the ground by -min(y), this does not consider foot height, for o3d vis
    ground_y = post_w_j3d[..., 1].flatten(-2).min(dim=-1)[0]  # (B,)  Minimum y value
    post_w_transl[..., 1] -= ground_y[:, None]

    return post_w_transl
@autocast(enabled=False)
def process_ik(outputs, endecoder):
    """Pin static hand/foot joints in place and re-solve the limb rotations with CCD IK.

    Target positions for the six end joints are built with a frame-by-frame roll-out:
    each frame blends the previous target with the current FK position, weighted by the
    previous frame's static probability. The four limb chains are then solved one after
    another, each solve starting from the result of the previous one.

    Args:
        outputs: dict providing ``"static_conf_logits"`` (B, L, J) and
            ``"pred_smpl_params_global"`` (kwargs accepted by ``endecoder.fk_v2``).
        endecoder: object whose ``fk_v2(..., get_intermediate=True)`` returns the joint
            positions, the local matrices and the global matrices.

    Returns:
        Axis-angle body pose of all joints except joint 0, flattened from dim 2 onwards.
    """
    contact_prob = outputs["static_conf_logits"].sigmoid()  # (B, L, J)
    fk_j3d, local_mat, fk_global_mat = endecoder.fk_v2(
        **outputs["pred_smpl_params_global"], get_intermediate=True
    )

    # sebas rollout merge
    joint_ids = [7, 10, 8, 11, 20, 21]  # [L_Ankle, L_foot, R_Ankle, R_foot, L_wrist, R_wrist]
    target_j3d = fk_j3d.clone()
    num_frames = fk_j3d.size(1)
    for t in range(1, num_frames):
        held = target_j3d[:, t - 1, joint_ids]  # already-blended target of the previous frame
        current = fk_j3d[:, t, joint_ids]
        weight = contact_prob[:, t - 1, :, None]
        target_j3d[:, t, joint_ids] = held * weight + current * (1 - weight)

    # ik
    global_rot = matrix.get_rotation(fk_global_mat)
    parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]

    # One entry per limb, solved in this order:
    # (kinematic chain, joints that carry a target, position of those joints inside the chain)
    limb_jobs = (
        ([0, 1, 4, 7, 10], [7, 10], [3, 4]),  # left leg
        ([0, 2, 5, 8, 11], [8, 11], [3, 4]),  # right leg
        ([9, 13, 16, 18, 20], [20], [4]),  # left hand
        ([9, 14, 17, 19, 21], [21], [4]),  # right hand
    )
    for chain, target_joints, target_ind in limb_jobs:
        solved = local_mat.clone()  # work on a copy; the solver receives the copy as well
        solver = CCD_IK(
            solved,
            parents,
            target_ind,
            target_j3d[:, :, target_joints],
            global_rot[:, :, target_joints],
            kinematic_chain=chain,
            max_iter=2,
        )
        chain_rot = matrix.get_rotation(solver.solve())
        # Write back every chain joint except the chain root.
        solved[:, :, chain[1:], :-1, :-1] = chain_rot[:, :, 1:]  # (B, L, J, 3, 3)
        local_mat = solved

    body_rot = matrix.get_rotation(local_mat[:, :, 1:])
    return matrix_to_axis_angle(body_rot).flatten(2)
================================================
FILE: eval/GVHMR/hmr4d/model/gvhmr/utils/stats_compose.py
================================================
# fmt:off
# Normalisation statistics for `body_pose_r6d`, keyed by the source they were computed on
# ("bedlam", "amass"). Each entry stores a "count" plus "mean" and "std" lists of 126 values.
# NOTE(review): 126 presumably corresponds to 21 body joints x 6D rotation, and "count" to the
# number of samples used for the estimate -- confirm against the code that consumes this table.
body_pose_r6d = {
"bedlam": {
"count": 5417929,
"mean": [ 0.9772, -0.0925, 0.0028, 0.1058, 0.9111, 0.1373, 0.9796, 0.0711,
-0.0193, -0.0816, 0.8910, 0.1953, 0.9935, 0.0072, 0.0270, -0.0046,
0.9200, -0.2511, 0.9752, 0.0477, -0.0990, -0.0613, 0.8242, -0.2730,
0.9836, -0.0400, 0.0067, 0.0148, 0.7836, -0.3471, 0.9931, -0.0300,
-0.0469, 0.0244, 0.9825, -0.0513, 0.9777, 0.0206, 0.1444, -0.0470,
0.9603, 0.1521, 0.9804, -0.0362, -0.0902, 0.0500, 0.9546, 0.1337,
0.9969, -0.0105, 0.0076, 0.0090, 0.9914, 0.0150, 0.9953, -0.0607,
0.0089, 0.0602, 0.9942, 0.0146, 0.9934, -0.0682, -0.0171, 0.0680,
0.9932, -0.0017, 0.9790, 0.0294, 0.0065, -0.0338, 0.9706, -0.0456,
0.9056, 0.2457, -0.1029, -0.2279, 0.9262, 0.0145, 0.9233, -0.1301,
0.1550, 0.1140, 0.9476, 0.0534, 0.9769, -0.0572, -0.0095, 0.0569,
0.9690, 0.0472, 0.6782, 0.5746, -0.2378, -0.5546, 0.7212, 0.0917,
0.6489, -0.5955, 0.2424, 0.5821, 0.6797, 0.0563, 0.5562, -0.1252,
-0.5860, 0.0937, 0.9176, -0.1287, 0.4453, 0.1421, 0.6119, -0.1427,
0.8996, -0.1136, 0.9186, -0.0881, -0.1463, 0.1087, 0.8692, 0.0845,
0.9175, 0.0257, 0.0663, -0.0385, 0.8603, 0.1020],
"std": [0.0429, 0.1392, 0.1236, 0.1323, 0.1645, 0.3086, 0.0375, 0.1406, 0.1172,
0.1275, 0.1934, 0.3280, 0.0119, 0.0835, 0.0716, 0.0741, 0.1528, 0.2484,
0.0349, 0.0947, 0.1633, 0.0924, 0.3469, 0.3370, 0.0273, 0.1009, 0.1411,
0.0680, 0.3876, 0.3323, 0.0103, 0.0735, 0.0712, 0.0690, 0.0246, 0.1617,
0.0216, 0.1097, 0.1016, 0.0924, 0.0509, 0.2035, 0.0245, 0.1188, 0.1212,
0.1056, 0.0634, 0.2308, 0.0054, 0.0579, 0.0517, 0.0575, 0.0124, 0.1158,
0.0076, 0.0654, 0.0367, 0.0644, 0.0118, 0.0592, 0.0116, 0.0829, 0.0361,
0.0832, 0.0124, 0.0422, 0.0343, 0.1060, 0.1680, 0.1075, 0.0473, 0.2023,
0.0701, 0.2344, 0.2213, 0.2632, 0.0589, 0.1318, 0.0767, 0.2456, 0.2009,
0.2666, 0.0542, 0.1106, 0.0347, 0.1080, 0.1718, 0.1117, 0.0459, 0.2025,
0.1882, 0.2769, 0.2032, 0.3072, 0.1447, 0.2204, 0.2018, 0.2820, 0.2126,
0.3213, 0.1760, 0.2486, 0.4749, 0.1677, 0.2791, 0.2239, 0.0963, 0.2705,
0.5540, 0.1846, 0.2572, 0.2411, 0.1287, 0.2878, 0.1151, 0.2993, 0.1557,
0.2812, 0.1880, 0.3334, 0.1286, 0.3355, 0.1553, 0.3216, 0.1880, 0.3306]
},
"amass": {
"count": 7114038,
"mean": [ 9.6969e-01, -5.9719e-02, -3.7700e-02, 5.8256e-02, 9.0800e-01,
1.0972e-01, 9.7636e-01, 4.3401e-02, 4.3110e-03, -4.3032e-02,
9.0261e-01, 1.4478e-01, 9.9288e-01, 3.5673e-03, 1.6264e-02,
-2.2260e-03, 9.3470e-01, -2.3495e-01, 9.7147e-01, 5.2553e-02,
-9.3666e-02, -5.4550e-02, 8.3321e-01, -2.4246e-01, 9.7971e-01,
-3.8429e-02, 5.3575e-03, 1.5537e-02, 8.1449e-01, -3.0926e-01,
9.9532e-01, -9.4398e-03, -3.8328e-02, 8.5141e-03, 9.8880e-01,
1.9976e-04, 9.5602e-01, -3.9528e-02, 2.0017e-01, 1.0363e-02,
9.5965e-01, 1.3770e-01, 9.6223e-01, -4.6278e-02, -1.5177e-01,
6.6705e-02, 9.5545e-01, 1.2519e-01, 9.9767e-01, -1.2616e-02,
-2.5442e-04, 1.1661e-02, 9.9376e-01, -3.6222e-02, 9.9511e-01,
-1.0583e-02, 1.2130e-02, 7.6461e-03, 9.9137e-01, 2.0029e-02,
9.9295e-01, 7.2917e-03, 4.9454e-03, -8.0286e-03, 9.9137e-01,
2.3707e-03, 9.7698e-01, 1.9943e-02, 1.3808e-03, -2.2006e-02,
9.7375e-01, -6.7936e-02, 9.2804e-01, 2.5005e-01, -5.7167e-02,
-2.4047e-01, 9.4246e-01, 2.5863e-02, 9.2957e-01, -2.1329e-01,
1.1112e-01, 2.0741e-01, 9.4876e-01, 2.9901e-02, 9.7683e-01,
-4.1210e-02, 2.3248e-03, 4.0967e-02, 9.7365e-01, 5.7309e-03,
6.4513e-01, 6.1999e-01, -2.5469e-01, -6.2342e-01, 6.8177e-01,
3.5524e-02, 6.6192e-01, -5.9341e-01, 2.7136e-01, 5.9269e-01,
6.8966e-01, 3.1309e-02, 6.8946e-01, -1.1676e-01, -4.9859e-01,
4.0969e-02, 9.3656e-01, -1.4875e-01, 6.2787e-01, 1.3793e-01,
5.4289e-01, -9.1946e-02, 9.2868e-01, -1.1927e-01, 9.3012e-01,
-8.3810e-02, -1.1951e-01, 9.7211e-02, 8.9118e-01, 5.9887e-02,
9.3033e-01, 7.1047e-02, 7.5264e-02, -8.0679e-02, 8.8562e-01,
4.8960e-02],
"std": [0.0612, 0.1390, 0.1779, 0.1415, 0.1826, 0.3268, 0.0440, 0.1382, 0.1542,
0.1348, 0.1930, 0.3272, 0.0132, 0.0801, 0.0855, 0.0729, 0.1255, 0.2238,
0.0554, 0.1088, 0.1727, 0.0939, 0.3294, 0.3559, 0.0532, 0.1082, 0.1554,
0.0768, 0.3446, 0.3407, 0.0120, 0.0650, 0.0584, 0.0632, 0.0198, 0.1335,
0.0631, 0.1250, 0.1574, 0.1047, 0.0730, 0.2091, 0.0759, 0.1241, 0.1667,
0.1112, 0.0831, 0.2185, 0.0060, 0.0441, 0.0502, 0.0441, 0.0102, 0.0946,
0.0237, 0.0722, 0.0610, 0.0738, 0.0479, 0.0949, 0.0369, 0.0943, 0.0610,
0.0966, 0.0498, 0.0729, 0.0425, 0.1001, 0.1824, 0.0972, 0.0408, 0.1887,
0.0594, 0.1842, 0.1884, 0.2020, 0.0457, 0.1018, 0.0640, 0.1990, 0.1854,
0.2133, 0.0467, 0.0910, 0.0392, 0.1049, 0.1776, 0.1037, 0.0413, 0.1945,
0.1733, 0.2612, 0.1905, 0.2963, 0.1512, 0.1861, 0.1710, 0.2663, 0.1896,
0.3135, 0.1568, 0.2219, 0.3976, 0.1594, 0.2810, 0.1855, 0.0845, 0.2398,
0.4398, 0.1629, 0.2685, 0.1990, 0.0998, 0.2556, 0.1137, 0.2837, 0.1419,
0.2761, 0.1678, 0.2973, 0.1172, 0.3010, 0.1394, 0.2910, 0.1724, 0.3039]
}
}
# Statistics of `betas` (10 values per list), one record per source.
betas = dict(
    bedlam=dict(
        count=37855,  # original author's remark: "so many subjects?"
        mean=[
            0.0378, -0.3562, 0.1185, 0.2245, 0.0204,
            0.0929, 0.0537, 0.1006, -0.1180, 0.0936,
        ],
        std=[
            0.8070, 1.3480, 0.8964, 0.7390, 0.6433,
            0.6089, 0.5374, 0.6984, 0.7263, 0.5395,
        ],
    ),
    amass=dict(
        count=18086,
        mean=[
            0.2310, 0.1750, 0.2931, -0.1859, -1.1163,
            -1.1028, -0.2573, 0.3555, 0.3732, 0.2852,
        ],
        std=[
            0.8831, 0.7965, 1.0899, 1.1788, 1.2128,
            1.1081, 0.9780, 1.1434, 0.8498, 1.1462,
        ],
    ),
)
# Statistics of `global_orient_c_r6d` (6 values per list); only a "bedlam" record exists.
global_orient_c_r6d = dict(
    bedlam=dict(
        count=5417929,
        mean=[-4.9862e-03, -8.7136e-04, -1.4187e-03, 1.4825e-02, -9.4419e-01, -5.1653e-02],
        std=[0.7048, 0.1713, 0.6884, 0.1548, 0.1546, 0.2403],
    ),
)
# Statistics of `global_orient_gv_r6d` (6 values per list); only a "bedlam" record exists.
global_orient_gv_r6d = dict(
    bedlam=dict(
        count=5134187,
        mean=[3.6018e-04, -2.2327e-04, 2.2316e-03, -4.4879e-02, -9.7435e-01, 1.0021e-01],
        std=[0.6070, 0.5355, 0.5873, 0.6285, 0.2336, 0.7675],
    ),
)
# Statistics of `local_transl_vel` (3 values per list), keyed by normalisation variant.
# "none" and "1e-2" are hand-written presets (zero mean with a fixed std) and carry no "count";
# the remaining records store a "count" next to their "mean"/"std".
# Every "mean"/"std" is a flat list of three numbers, so all variants are interchangeable.
local_transl_vel = {
    "none": {
        "mean": [0., 0., 0.],
        "std": [1., 1., 1.],
    },
    "1e-2": {
        "mean": [0., 0., 0.],
        "std": [1e-2, 1e-2, 1e-2],
    },
    "bedlam": {
        "count": 5417929,
        "mean": [7.3057e-05, -2.2142e-04, 3.2444e-03],
        "std": [0.0065, 0.0091, 0.0114],
    },
    "amass": {
        "count": 7113068,
        "mean": [-0.0002, -0.0006, 0.0069],
        "std": [0.0064, 0.0070, 0.0138],
    },
    "alignhead": {
        "count": 7113068,
        "mean": [-2.0822e-04, -1.7966e-06, 6.9816e-03],
        "std": [0.0065, 0.0066, 0.0139],
    },
    "alignhead_absy": {
        "count": 7113068,
        "mean": [-0.0002, -0.0316, 0.0070],
        "std": [0.0065, 0.1351, 0.0139],
    },
    "alignhead_absgy": {
        "count": 7113068,
        # Was wrapped in an extra list ([[...]]), unlike every other record in this table.
        "mean": [-2.0822e-04, 1.2627e+00, 6.9816e-03],
        "std": [0.0065, 0.1516, 0.0139],
    },
}
# Statistics of `pred_cam` (3 values per list); only a "bedlam" record exists.
pred_cam = dict(
    bedlam=dict(
        count=5096332,
        mean=[1.0606, -0.0027, 0.2702],
        std=[0.1784, 0.0956, 0.0764],
    ),
)
vitfeat = {
"bedlam": {
"count": 5546332,
"mean": [-1.3772, 0.2490, 0.0602, -0.1834, 0.2458, 0.5372, 0.3343, -0.3476, -0.1017, -0.0362, -0.0678, 0.2150, -0.2534, 0.1029, 0.8199, -0.4676, 0.6259, -0.3350, 0.0549, -0.4469, 0.2751, -0.1763, 0.1114, -0.2115, -0.0264, 0.5294, 0.8212, -0.4562, 0.4147, -0.0256, -0.1019, 0.2798, 0.9284, 0.4652, 0.6365, 0.6785, -0.0765, 0.0337, -0.2566, -0.0335, -0.1799, 0.7426, 0.2810, -0.7121, -0.0893, 0.1608, -0.2483, 1.5094, -1.4395, -0.3682, -0.4157, -0.0032, -0.0376, -0.0043, 0.2092, 0.3038, -0.2077, -0.4868, -0.1534, 0.2668, 1.2773, 0.2838, -0.4863, -1.2300, 0.0581, -0.3041, 0.1518, 0.7955, -0.4293, 1.4666, 0.3077, 0.3918, 0.1418, 0.1590, 0.8671, -0.3527, 0.5629, 0.1414, 0.0964, -0.1094, -0.0211, -0.0937, 0.1606, -0.7900, 0.0397, 0.0570, 0.7083, -0.5732, 0.1430, -0.2571, 0.5275, 0.6603, 0.3265, 0.4574, -0.3361, -0.1267, 0.3841, 0.1758, -0.6207, -0.3673, 0.8914, 0.4297, -0.8118, 0.2229, -0.2876, 0.2460, 0.4856, -0.1446, -0.2416, 0.1229, 0.2865, 0.7023, -0.2883, 0.3940, -1.5496, 0.4456, 0.6445, 0.2058, -0.4265, 0.3724, 0.1557, -1.4208, -0.1246, 0.1237, -0.3965, 0.0105, -0.0780, 0.6448, -0.1132, 0.8500, -0.2828, 0.4447, 0.6257, -0.2664, -0.8384, -1.8091, -0.2769, 0.1866, 0.6051, -0.2548, 0.9823, -0.2985, -0.2773, -0.4383, 0.1886, 0.2411, 0.2546, 0.2195, -0.0041, 0.1038, -0.6804, 1.2364, 0.5393, 0.0351, 0.4537, -0.8044, -0.1993, -2.1097, -0.8458, 0.1497, 1.6042, 0.6458, -0.5455, 0.0778, 0.0504, -0.5242, -0.3215, -0.0199, 1.1461, -0.3355, -0.3421, -0.3951, 0.0184, -0.0261, 0.2048, 0.0080, 0.6553, -1.3221, 0.5140, 0.5958, -0.2523, 0.9434, -0.0727, 0.1978, 1.1105, -0.4992, 0.3990, 0.2074, 0.3843, -0.0444, 0.0624, -0.8442, -0.0724, -0.5328, 1.1723, 0.8043, 0.6674, 1.5283, 4.2502, 0.0935, 0.3733, 0.1569, 0.0154, 0.0674, 0.0862, -0.2744, -0.4537, 0.1588, -1.9156, 0.0149, -1.0498, -0.0790, 0.0851, -0.5007, 0.3323, -0.1065, 0.0782, 0.0725, -0.5921, -0.1876, 0.0094, -0.3631, 0.0951, 0.1318, 0.0936, 0.5668, -0.0875, -0.4576, -0.4306, 0.5458, 1.0761, 1.1740, -0.0337, 1.3718, -0.2913, 
-0.3433, 0.5338, -0.4577, -0.4966, 0.2704, 0.3236, 0.4053, 0.0360, 1.1616, -0.2012, 0.7373, 0.0779, -0.0280, -0.4426, 0.0450, 0.2923, 0.0161, -0.4788, 0.1924, -0.3012, 0.0298, -0.7776, -0.2215, 0.4494, -0.1677, 0.2214, 0.0762, -0.3088, 0.4230, 0.0673, -1.0233, 0.0748, -0.4358, -0.2497, -0.0066, 0.1679, -0.1077, -0.4290, 2.5254, -0.8819, -0.8073, 0.2535, 2.0680, -0.4715, 0.3614, -2.9281, 3.1536, 0.3118, -0.0239, 0.7064, -0.6935, -1.1070, -0.1715, -0.0920, -0.2133, -1.0173, 0.0084, -0.1721, 0.2605, -0.6607, -0.0788, -0.3479, -0.2187, 1.0605, 0.2857, 0.7464, 0.9612, -1.1332, 1.5708, -1.0264, 0.6070, 0.4103, -0.1950, -0.0629, -0.0958, -0.2199, -0.2198, -0.4019, 0.2478, -0.3576, 0.0191, -5.8435, 0.0145, -0.2312, 0.9872, 1.1159, 0.3775, 0.1960, -0.5968, -0.2611, -0.0634, -0.1003, 0.7411, -0.8298, -0.1743, 1.8418, 0.3692, -0.4321, 0.0613, -1.9046, 0.5812, 0.2805, 0.1703, -0.2212, -0.0740, -0.2737, -0.3084, 2.9787, -0.1392, 0.3347, 0.0866, -0.8654, -0.4564, -0.7839, 0.1033, -0.0204, 0.1558, -0.1469, 0.2850, -0.1139, 0.8253, 0.7352, -0.6132, 0.0566, 0.3087, -0.1189, 0.1640, 0.2511, 0.5230, -0.0972, -0.5621, -2.5404, 0.3529, -0.2543, -0.6757, 0.2045, -0.0511, -0.2204, 0.1023, 0.0143, 0.4191, -0.3946, -1.0912, 0.8555, 1.0751, -0.0184, -0.3162, 0.1910, 0.6522, -0.5801, 0.2091, -0.8254, -0.3425, 0.3368, -0.0384, -0.4570, 2.5288, -0.3513, -0.1630, 0.1096, -0.5936, 1.5303, -0.4135, -0.2418, -0.0564, -2.6344, -0.1054, 0.8866, -0.2946, -0.4564, -0.6220, 0.2672, -0.9012, 0.3535, 0.2344, -0.0718, 0.0782, 0.0133, 0.2032, -1.2768, 0.1271, -0.5114, -0.0584, -0.8219, -0.1069, 1.5577, -0.1432, -0.6794, 0.9101, 0.6390, 0.3547, -0.6126, -0.1885, 0.2462, -1.1864, 0.0653, -0.7940, 0.5204, 0.5372, 0.5353, -0.4268, -0.2003, -0.2496, -0.0405, 0.3615, -0.1635, 0.1908, -0.0467, 0.7167, 0.1465, 0.4621, 0.1190, -1.6899, 0.6512, 1.3150, -0.1273, 0.0507, 0.2058, -0.1855, 0.1316, 0.1280, 0.5049, 0.0262, -0.0329, 2.0327, -0.6410, 0.4536, 0.0609, 0.1883, -0.5454, -0.5247, 0.1856, 0.7238, 1.4886, -0.1068, 
1.7239, -0.8228, -0.2155, 0.5159, 0.2941, -0.0782, -0.0159, 0.1844, -0.1808, -0.1132, 0.4861, 4.0106, 0.0130, 0.2455, -0.1101, 0.0792, 0.4720, -0.1022, 2.0154, -0.4013, 0.5604, 1.3600, -0.5614, 0.3793, -0.1245, 0.2444, 0.1657, 1.7616, 0.6198, 0.1761, -0.6036, -0.1931, 0.4449, 0.2574, -0.2360, 1.1118, 0.0804, 1.1533, 0.2549, 0.3386, 0.2463, 0.0930, -0.6093, -0.1464, 0.2889, 0.2294, -0.5943, 0.1323, 0.5119, 0.1093, -1.0178, 0.4735, 0.3068, 0.3213, -0.0585, -0.3682, -0.6105, -0.7776, 0.1999, 0.9439, -0.4209, 0.1488, 1.3119, -0.4679, -0.3882, 0.2677, -0.1673, -0.5921, -1.2811, -1.0972, 0.3873, 0.0798, -0.0538, 0.0659, -0.1439, -1.3106, -0.5175, 0.4538, -1.0376, -0.9015, 0.7454, -0.0714, -0.4641, 0.2083, 0.0596, -2.9637, 0.3057, 0.2121, -0.2399, 0.6963, 0.1400, 1.7446, 0.9707, -0.3118, -0.3371, 0.0130, 1.0006, -0.2740, 0.1100, -0.9666, 0.7636, 1.2002, -0.0018, -0.3380, 0.1262, 0.5829, -0.0374, 0.0689, 0.2022, -2.0056, -0.2051, -0.4549, 0.0519, 0.4217, -0.7413, 0.0601, 0.4385, 2.8503, -2.7656, 1.2281, -0.1280, 0.6028, 0.4995, 0.0638, -0.3376, 0.2527, -0.1572, -0.4385, -0.6372, 0.2569, 0.4115, 0.4507, 0.6063, -0.1051, 1.2529, 0.2453, -0.7905, -0.3797, -0.2674, 0.2662, 1.5347, -0.3908, 0.8839, -0.6054, -0.4827, -0.3495, 1.2107, -0.4419, -0.6177, 0.1054, 1.0132, -0.3246, -0.1776, 1.1740, -0.0252, 0.0368, -0.7937, -0.9988, -0.0228, 0.0742, -2.4925, 0.5785, 2.3900, 1.2726, -0.3682, -0.8625, -0.3299, 0.3934, 1.4045, -0.6200, -0.0024, 0.2348, -0.1827, -0.5913, -0.6982, 0.2648, 0.2601, 0.9986, 0.1636, 0.8982, -0.4269, 1.7454, -1.9136, -0.9865, -0.0451, 0.2851, -0.5938, -0.3066, 0.0910, -0.3150, -0.4002, 0.4789, 0.0337, -0.6997, -0.2555, -0.6602, -3.0103, 0.2491, -1.0346, 0.3651, 0.2319, 1.0224, -0.2613, 1.6970, 0.7515, 2.1477, 0.1310, 0.2060, 0.1372, 1.0049, -0.8758, -0.3804, -2.1513, 0.8010, -0.2271, -0.2108, 0.3728, -1.7321, -1.0250, -0.2584, -0.2513, 0.2418, -0.7641, 0.2084, -1.3560, 0.5803, 0.1556, -0.3612, 1.3099, -0.2673, 0.4371, -0.8022, 0.1776, -0.5019, 0.1880, -0.2093, 
0.0750, -0.7228, -1.3950, 0.1944, -1.5994, -0.2832, 0.0507, 0.1917, 1.2954, 0.0471, 0.3115, -2.2382, -0.3891, -0.0704, 0.3897, 0.0347, 0.9186, -0.8407, 0.9456, 0.5629, 0.3474, -0.4869, 0.4696, -0.4438, 0.0860, -0.8313, -0.0383, 0.2055, 0.4822, -0.1455, -0.1719, -0.2346, -0.4606, 0.8018, 0.3767, -0.0613, 1.9429, -0.6558, -0.0772, -0.1592, -0.1413, 0.4759, -0.0686, 0.9243, -0.2413, -0.1084, -0.2248, -0.0776, 1.4193, -0.0605, 0.1305, -0.2055, 0.0917, 0.6884, -0.0152, 0.1215, 0.2920, -0.0781, -0.0256, 0.3789, -0.1933, 0.1759, 2.3899, 1.0915, -0.7082, -0.4519, -0.2648, -1.2404, -0.2485, 1.0713, 0.1662, -0.1268, 0.3338, -0.0319, 0.1692, -0.5161, 0.9351, 0.1996, -0.2743, 0.0492, -0.0171, 0.1546, 0.2533, -0.0102, 0.6147, 0.0035, -0.2468, -0.2116, -1.7912, 0.2735, 0.4147, 0.4458, 0.6123, 0.0860, 0.2098, -0.3691, -0.2297, -0.6086, -1.0407, -0.7736, -0.3087, -0.0900, -0.1007, -0.3801, -0.3408, -0.4853, -0.3101, -0.8812, 0.0187, -0.9697, -0.2393, 0.1129, -0.5682, 0.4349, 0.1017, 0.2173, -0.0644, -0.9307, 0.9754, 0.2189, 0.2966, -0.4089, -0.2471, -0.7549, 0.3300, 0.7856, 0.1262, 0.2097, -0.5872, 0.9896, 0.5100, 1.0608, -0.7974, 0.1549, -0.1020, 0.4286, 0.0603, -0.6836, -0.4662, -1.2350, -0.0858, -0.5552, 0.0383, 0.2145, -0.4324, -0.5896, 0.9709, -0.0827, -0.2574, 0.2436, -0.1460, 0.5862, 0.4329, -1.2421, 0.0497, -0.0034, 0.2385, -0.1346, 2.0652, 0.8790, -0.2033, -2.6427, 0.3654, -0.1929, -0.0753, -0.9107, 0.9437, 0.3717, -0.7058, -0.2487, -1.0937, -0.7612, 0.9516, -0.7426, -0.0736, 1.2167, 0.6336, 0.2707, -0.7666, -0.1272, -0.8960, 0.3748, 0.7344, 0.7257, 0.3686, -0.5036, -0.2829, 0.0548, 0.3034, -0.2335, -0.3215, 0.0566, -0.2733, -0.3644, 0.0467, -0.0924, -0.5145, -1.7089, 0.4896, 0.0074, 0.2840, 0.1140, -0.0409, -0.3251, 1.0805, 3.0856, -0.3409, 1.2684, -0.0245, -0.0636, -0.0090, 0.1293, -0.3410, -0.0482, 0.1482, 0.2027, 0.5623, 0.0566, 0.6453, -0.0126, 0.0720, -0.0277, 0.0531, 0.1860, -0.1044, -0.6973, 0.3026, 0.4733, -0.1590, 0.4727, 0.8486, 0.4478, 0.1814, 1.0862, 0.0478, 
0.2437, -0.5269, -0.0796, -0.4291, 0.4937, -0.0407, -0.6961, -0.0412, 0.6865, 0.0457, 0.1085, -0.4717, -0.1339, 0.8600, 0.6718, -0.3542, -0.5655, 1.3711, 0.0034, 0.3077, 0.0903, 0.3618, 0.3287, -0.1007, 0.0332, -0.3841, -0.3981, 0.1079, -0.4399, 0.1836, 0.0939, -0.1425, -0.2531, -1.2103, 0.0234, -1.3023, -0.0570, -0.0587, 1.1733, 0.0079, 1.0809, 0.4697, -0.1427, 3.3793, -0.1503, 0.4354, 0.0274, 0.3112, -0.3816, 0.0187, -0.1282, -0.4136, 0.3684, 0.6930, 1.3605, 0.4949, 0.4162, -2.2398, 0.4104, 0.6839, 0.4519, 0.0546, -0.0816, 0.0357, 0.1977, -0.8450, 0.1481, 0.1588, -0.1392, -0.3304, -0.3499, -0.8669, 0.1510, 0.1127, 0.9853, -0.3019, -0.3493, -0.0783, -0.8491, 0.0696, 0.7295, -1.0612, 0.1232],
"std": [0.9277, 0.7470, 0.6154, 0.8520, 0.8682, 0.7121, 0.7048, 0.6865, 0.7543, 0.6952, 0.6186, 0.4204, 0.4614, 0.4731, 0.4421, 0.4068, 0.6927, 0.6540, 0.4717, 0.4993, 0.5945, 0.5480, 0.4898, 0.6438, 0.5551, 0.5686, 0.7287, 0.6033, 0.5590, 0.3768, 0.5304, 0.6748, 0.5559, 0.5265, 0.6214, 0.6490, 0.4639, 0.6465, 0.5575, 0.6202, 0.5369, 1.2466, 0.7340, 0.5462, 0.6508, 0.5766, 0.5405, 0.5581, 0.5687, 0.7549, 0.5743, 0.4748, 0.6308, 0.6292, 0.6391, 0.6284, 0.4202, 0.5970, 0.5587, 0.5364, 0.4655, 0.5201, 0.7140, 0.6220, 0.4978, 0.4479, 0.5452, 0.7489, 0.5866, 0.4592, 0.7493, 0.6548, 0.5497, 0.4658, 0.8663, 0.4574, 0.5351, 0.5595, 0.4579, 0.5141, 0.4824, 0.5504, 0.5468, 0.5726, 0.5155, 0.6679, 0.8433, 0.5278, 0.5666, 0.7699, 0.5682, 0.9431, 0.5344, 0.6562, 0.4749, 0.5241, 0.6869, 0.4117, 0.5839, 0.5115, 0.8811, 0.5335, 0.6476, 0.4883, 0.6034, 0.5778, 0.4764, 0.8787, 0.8589, 0.5168, 0.4548, 0.8146, 0.5860, 0.6087, 0.6758, 0.7049, 0.8292, 0.6547, 0.6043, 0.7242, 0.6158, 0.6435, 0.5219, 0.6148, 0.7738, 0.4871, 0.7944, 0.7605, 0.6120, 0.5482, 0.6107, 0.6106, 0.4295, 0.4549, 0.4167, 0.6142, 0.6368, 0.5432, 0.5412, 0.6568, 0.9641, 0.6413, 0.6634, 0.4222, 0.6917, 0.5664, 0.5554, 0.4098, 0.6949, 0.5890, 0.4995, 0.5475, 0.6446, 0.5599, 0.6439, 0.6220, 0.5761, 0.5862, 0.5126, 0.6037, 0.5377, 0.5817, 0.6216, 0.5986, 0.4834, 0.6929, 0.5819, 0.6781, 0.6088, 0.5425, 0.7211, 0.6253, 0.5408, 0.6826, 0.5454, 0.7614, 0.9767, 0.8721, 0.7527, 0.4022, 0.5061, 0.5921, 0.5945, 0.6048, 0.7206, 0.5533, 0.5506, 0.6816, 0.6116, 0.6424, 0.7484, 0.6350, 0.5953, 0.4941, 0.7675, 0.8244, 0.6885, 0.5751, 0.9304, 0.5252, 0.5741, 0.4537, 0.5610, 0.9873, 0.5155, 0.7180, 0.4421, 0.5171, 0.5343, 0.5225, 0.7952, 0.6149, 0.6401, 0.5667, 0.6946, 0.8172, 0.5188, 0.5082, 0.6298, 0.6904, 0.4820, 0.5600, 0.5584, 0.5600, 0.4776, 0.5008, 0.7215, 0.6071, 0.5571, 0.6174, 0.4049, 0.7368, 0.5996, 0.7888, 0.7609, 0.5913, 0.8778, 0.4462, 0.7460, 0.7240, 0.5705, 0.6267, 0.5684, 0.5707, 0.6560, 0.5310, 0.5278, 0.6833, 0.6420, 
0.6696, 0.8815, 0.4767, 0.7171, 0.4826, 0.6736, 0.5483, 0.4913, 0.5840, 0.5242, 0.4310, 0.5846, 0.4389, 0.5164, 0.6203, 0.5625, 0.8495, 0.5091, 0.6904, 0.5490, 0.5467, 0.4746, 0.8446, 0.6030, 0.6563, 1.0108, 0.5633, 0.6324, 0.6339, 0.6269, 1.2128, 0.6877, 0.5998, 0.4763, 0.4979, 0.7968, 0.6549, 1.0234, 0.5385, 0.6164, 0.5485, 0.8526, 0.5776, 0.5292, 0.5716, 0.5458, 0.5332, 0.5264, 0.6239, 0.6668, 0.7481, 0.3929, 0.5932, 0.5741, 0.4433, 0.7519, 0.4940, 0.7438, 0.5315, 0.3895, 0.5528, 0.6656, 0.6665, 0.9897, 0.8098, 0.6000, 0.5226, 1.2953, 0.5624, 0.6416, 0.5880, 0.5828, 0.4779, 0.6721, 0.6273, 0.7918, 0.5498, 0.5262, 0.6396, 0.6185, 0.6117, 0.8871, 0.5688, 0.5335, 0.6402, 0.5994, 0.9472, 0.5072, 0.7688, 0.6257, 0.6548, 0.6070, 0.7646, 0.5362, 0.5151, 0.6852, 0.4533, 0.6976, 0.6170, 0.5700, 0.5819, 0.4350, 0.5755, 0.4902, 0.9396, 0.5110, 0.5461, 0.6380, 1.0192, 0.5009, 0.8211, 0.6223, 0.5970, 0.5465, 0.8314, 0.4997, 0.5066, 0.5824, 0.6241, 0.4910, 0.4849, 0.5292, 0.5357, 0.4856, 0.6120, 0.4212, 0.6712, 0.4599, 0.4625, 0.7568, 0.8765, 0.8095, 0.7385, 0.5748, 0.7405, 0.6474, 0.6466, 0.6481, 0.5660, 0.6876, 0.9852, 0.5923, 0.6319, 0.6818, 0.4716, 0.6599, 0.5343, 0.5384, 0.9786, 0.4421, 0.5543, 1.0386, 0.5640, 0.5990, 0.5060, 0.6141, 0.3880, 0.6767, 0.5753, 0.4797, 0.4623, 0.5802, 0.6813, 0.5792, 0.4790, 0.6855, 0.5186, 0.4890, 0.5740, 0.6117, 0.5177, 0.5032, 0.6367, 0.4555, 0.6749, 0.6680, 0.6878, 0.7425, 0.8106, 0.5460, 1.0575, 0.5022, 0.7639, 0.5132, 0.5433, 0.7702, 0.4572, 0.4274, 0.6779, 0.5277, 0.5634, 0.4814, 0.5491, 0.5790, 0.5750, 0.5573, 0.4652, 0.5240, 0.6244, 0.6247, 0.7397, 0.7107, 0.5964, 0.4891, 0.7089, 0.6531, 0.6979, 0.4630, 0.5348, 0.4308, 0.8983, 0.5416, 0.4521, 0.6261, 0.4931, 0.7247, 0.5689, 0.5254, 0.4913, 0.6307, 0.5586, 0.5804, 0.5692, 0.5211, 0.6549, 0.6069, 0.5216, 0.4617, 0.7538, 0.4234, 0.4868, 0.7661, 1.1726, 0.8879, 0.4984, 0.6142, 0.4203, 0.5944, 0.6758, 0.5682, 0.6554, 0.7316, 0.5552, 0.7454, 0.3907, 0.7559, 0.4752, 0.5638, 0.7824, 0.7995, 
0.5728, 0.8546, 0.5663, 0.5545, 0.4785, 1.0497, 0.7177, 0.5461, 0.5134, 0.5432, 0.5964, 0.5879, 0.7046, 0.7501, 0.5707, 0.9907, 0.9337, 0.5682, 0.4887, 0.5970, 0.6229, 0.6501, 0.7529, 0.7062, 0.6775, 0.7286, 0.6250, 0.4521, 0.5357, 0.5479, 0.7957, 0.4596, 0.6440, 0.8665, 0.6024, 0.7485, 0.6478, 0.6483, 0.5785, 0.5500, 0.4802, 0.4465, 0.6829, 0.6890, 0.6180, 0.8767, 0.7419, 0.6193, 0.3918, 0.5888, 0.5440, 0.5146, 0.4297, 0.4410, 0.4894, 0.4422, 0.9614, 0.6290, 0.6717, 0.5415, 0.5442, 0.5862, 0.4967, 0.7102, 1.1356, 0.4818, 0.4557, 0.6403, 0.4971, 0.7491, 0.8534, 0.8754, 0.5308, 0.5591, 0.6415, 0.7715, 0.8137, 0.4898, 0.5460, 0.5476, 0.9199, 0.6195, 0.5949, 0.7990, 0.4444, 0.6199, 0.5166, 0.4646, 0.9060, 0.6261, 0.5149, 0.6533, 0.7420, 0.4830, 0.5314, 0.5503, 0.5777, 0.6284, 0.7288, 0.5743, 0.6041, 0.5674, 0.4661, 0.6211, 0.6172, 0.4094, 0.5787, 0.8089, 0.6061, 0.5882, 0.5498, 0.7239, 0.6387, 0.7910, 0.5267, 0.5569, 0.6382, 0.5492, 0.5444, 0.6476, 0.8666, 0.9807, 0.5594, 0.6814, 0.5467, 0.8900, 0.5321, 0.5516, 1.0188, 0.7193, 0.5044, 0.5717, 0.9741, 0.7856, 0.6849, 0.5604, 1.0236, 0.8399, 0.5065, 0.6475, 0.4055, 0.7975, 0.4454, 0.5726, 0.4489, 0.6851, 0.6504, 0.4737, 0.5995, 0.6226, 0.5917, 0.5394, 0.5240, 0.7863, 0.6008, 0.5330, 0.4760, 0.6163, 0.4679, 0.5712, 0.7180, 0.4908, 1.0175, 0.5942, 0.5170, 0.7534, 0.5569, 0.8764, 0.7314, 0.5474, 0.9083, 0.6677, 0.6286, 0.6759, 0.5397, 0.5748, 0.6215, 0.4800, 0.5206, 0.5591, 0.5884, 0.6291, 0.6633, 0.7693, 0.5104, 0.6564, 0.5489, 0.6270, 0.5935, 0.6236, 0.6108, 0.4794, 0.5974, 0.7061, 0.6686, 0.6512, 0.4998, 0.5933, 0.4956, 0.6610, 0.7542, 0.5869, 0.8418, 0.9938, 0.9021, 0.6323, 0.5777, 0.4343, 0.6098, 0.5338, 0.5906, 0.7783, 0.7423, 0.6426, 0.6236, 0.9643, 0.5780, 1.0100, 1.1266, 0.7556, 0.5229, 0.8272, 0.6900, 0.5175, 0.4124, 0.5741, 0.4516, 0.6266, 0.5630, 0.5275, 0.5692, 0.5075, 0.7549, 0.6359, 0.5804, 0.6680, 0.7558, 0.6250, 0.4314, 0.6496, 0.5479, 0.7524, 0.7088, 0.6644, 0.7214, 0.6450, 0.4467, 0.7789, 0.5168, 0.6297, 
0.6242, 0.4410, 0.8372, 0.5758, 0.4997, 0.8915, 0.6473, 0.5974, 0.5293, 0.7941, 0.4605, 0.9110, 0.5919, 0.5139, 0.5003, 0.4500, 0.6182, 0.5807, 0.4562, 0.5618, 0.6794, 0.7201, 0.6143, 0.8797, 0.8171, 0.6225, 0.7453, 0.7611, 0.4696, 1.0906, 0.8825, 0.7207, 0.5523, 0.7120, 0.5194, 0.5321, 1.0233, 0.5618, 0.5410, 0.4300, 0.7191, 0.5373, 0.4795, 0.4450, 0.6546, 0.7965, 0.7454, 0.6264, 0.5576, 0.7710, 0.5527, 0.6586, 0.5177, 0.4858, 0.5005, 0.5372, 0.5766, 0.4508, 0.5238, 0.8275, 0.4104, 0.5535, 0.8077, 0.4460, 0.7125, 0.7166, 0.6107, 0.4561, 0.6620, 0.4635, 0.6397, 0.4391, 0.6880, 0.6801, 0.5627, 0.8076, 0.7918, 1.0309, 0.5832, 0.6152, 0.7971, 0.4539, 0.5846, 0.7248, 0.4455, 0.6318, 0.6118, 0.4552, 0.6757, 0.5354, 0.6566, 0.6728, 0.4383, 0.6899, 1.0565, 0.6028, 0.6937, 0.5518, 0.8039, 0.4296, 0.6068, 0.5736, 0.4923, 0.7643, 0.7391, 0.4975, 0.5006, 0.5674, 0.5170, 0.4835, 0.4286, 0.5667, 0.6109, 0.6465, 0.6281, 0.7791, 0.5174, 0.5058, 0.6196, 0.6593, 0.5999, 0.5012, 0.5414, 0.7151, 0.6546, 0.6790, 0.5412, 0.4801, 0.6561, 1.0082, 0.5567, 0.6362, 0.4540, 0.8812, 0.6893, 0.6420, 0.6078, 0.5117, 0.7079, 0.8240, 0.7587, 0.6344, 0.6848, 0.4633, 0.5352, 0.6077, 0.5436, 0.7223, 0.5001, 0.9734, 0.5155, 0.5549, 0.4711, 0.9038, 0.5415, 1.0173, 0.5001, 0.5290, 0.5228, 0.5619, 0.9670, 0.7854, 0.5350, 0.5183, 0.9770, 0.5547, 0.9710, 0.5050, 0.4584, 0.6438, 0.4854, 0.5949, 0.6611, 0.4676, 0.4815, 0.8837, 0.6425, 0.6257, 0.6896, 0.4465, 0.7492, 0.6293, 0.7096, 0.5578, 0.5117, 0.4909, 0.5773, 0.4800, 0.5488, 0.6336, 0.6863, 0.5035, 0.6682, 0.7245, 0.5524, 0.4594, 0.5816, 0.5698, 0.6140, 0.5816, 0.5242, 0.4088, 0.4358, 0.6426, 0.4777, 0.6115, 0.4383, 0.5957, 0.8423, 0.5353, 0.5407, 0.8497, 0.6962, 0.7542, 0.5981, 0.5121, 0.6232, 0.5306, 0.5416, 0.5217, 0.5437, 0.5349, 0.5111, 0.8627, 0.6092, 0.5850, 0.5851, 0.7203, 0.3688, 0.5063, 0.5650, 0.5444, 0.5657, 0.7461, 0.4447, 0.7153, 0.4738, 0.5730, 0.4605, 0.4905, 0.6253, 0.8114, 0.8273, 0.5052, 0.6180, 0.6496, 0.4037, 0.5635, 0.5212, 0.7652, 
0.4872, 0.5764, 0.7834, 0.6888, 0.5313, 0.5379, 0.5710, 0.7474, 0.6535, 0.9660, 0.5257, 0.7157, 0.7150, 0.5430, 0.5331, 0.6820, 0.6872, 0.4904, 0.6592, 0.6256, 0.6107, 0.4939, 0.5986, 0.5172, 0.4583],
},
"emdb": {
"count": 62707,
"mean": [-1.1869, 0.1485, 0.1933, -0.6247, 0.0793, 0.5762, 0.1835, -0.2564, 0.1285, 0.3221, 0.0577, 0.1154, -0.0818, -0.2512, 0.9673, -0.5680, 0.5968, -0.2124, -0.0112, -0.5576, 0.5339, -0.1490, 0.3102, -0.4012, -0.0570, 0.6416, 0.9359, -0.2932, 0.8544, 0.1719, -0.4534, 0.1316, 0.8625, 0.3806, 0.4884, 1.0853, -0.3872, -0.2403, -0.4274, 0.1319, -0.3334, 0.6352, 0.5748, -0.8850, -0.4331, 0.3662, -0.3324, 1.3993, -1.5142, -0.3082, -0.5491, -0.1847, 0.0145, -0.0726, 0.0015, -0.0358, -0.2815, -0.4356, -0.3842, 0.1150, 1.1513, 0.6343, -0.7336, -1.1613, 0.1020, -0.1291, 0.1560, 0.4854, -0.4191, 1.6794, 0.4274, 0.4792, 0.3570, 0.0811, 1.0886, 0.0670, 0.5227, 0.1891, 0.1121, 0.1495, -0.2090, -0.2156, -0.2512, -0.9291, 0.1287, -0.0481, 0.6701, -0.4579, 0.2352, -0.1056, 0.5551, 0.4357, 0.8168, 0.6344, -0.6445, -0.1965, 0.5587, 0.3860, -0.2466, -0.1542, 0.6825, 0.5875, -0.5208, 0.1500, -0.3980, 0.2157, 0.8368, -0.1356, -0.3387, 0.1747, 0.1467, 0.2282, -0.1412, 0.6216, -1.8406, 0.0150, 0.2891, 0.0280, 0.0461, 0.8558, 0.2929, -1.3753, -0.5792, 0.2089, -0.3524, -0.1849, -0.0157, 0.4454, -0.5306, 0.8238, -0.3160, 0.3760, 0.8978, -0.1943, -0.9474, -1.7321, -0.0149, 0.2338, 0.6087, -0.4851, 0.5210, -0.4042, -0.5368, -0.6220, 0.1245, 0.3112, 0.6360, -0.1522, 0.0540, -0.2380, -0.8354, 1.7591, 0.5687, 0.1732, 0.7923, -0.5383, -0.3271, -2.0050, -0.5563, 0.2979, 1.6609, 0.7108, -1.0155, 0.3591, 0.0136, -0.4743, -0.5401, -0.0176, 1.3333, -0.2973, -0.1114, -0.1616, 0.1160, 0.1152, 0.0057, 0.2067, 0.3876, -1.5311, 0.0636, 0.4566, -0.2653, 1.0534, -0.4638, 0.2166, 0.8686, -0.1447, 0.5605, -0.3841, 0.7015, 0.0418, 0.0811, -0.6406, -0.2929, -0.6821, 1.3678, 0.7574, 0.8315, 2.0377, 4.9034, -0.0097, 0.0165, 0.3248, 0.2994, 0.0210, 0.2276, -0.6580, -0.6899, 0.1981, -2.3205, 0.0059, -0.9412, -0.3191, 0.0389, -0.4170, 0.3391, -0.1346, 0.1567, 0.1838, -0.4176, -0.2758, 0.1495, -0.2977, 0.0929, 0.7186, 0.1230, 0.8780, -0.1240, -0.7370, -0.7551, 0.3830, 1.0824, 1.4500, -0.1040, 1.4225, 0.0929, 0.4612, 
0.5167, -0.7093, -0.4729, 0.2321, 0.4156, -0.0696, -0.0626, 1.3341, -0.2398, 0.8453, 0.4048, 0.1690, 0.0074, -0.0474, 0.4134, 0.2043, -0.5962, 0.1643, -0.3821, 0.3012, -0.5690, 0.0133, 0.1876, -0.0727, 0.2896, 0.3253, 0.0313, 0.5141, -0.0055, -1.2889, -0.0983, -0.3212, -0.4173, -0.0804, 0.2591, -0.4160, -0.4815, 2.2822, -1.0033, -0.9814, 0.5290, 1.7943, -0.4217, -0.0373, -3.3970, 3.3067, 0.1174, -0.1369, 0.3847, -0.6960, -0.8867, -0.3825, -0.0134, -0.4367, -1.0273, -0.0623, 0.1520, 0.3816, -0.6543, -0.0118, -0.3019, -0.1190, 1.0490, 0.6255, 0.8503, 0.9500, -1.1942, 1.6886, -1.3958, 0.9389, 0.2318, -0.0460, 0.1140, -0.2352, -0.5648, 0.0363, -0.5636, 0.0661, -0.8680, -0.1223, -6.5336, 0.2139, -0.2734, 1.1739, 0.6003, 0.2183, 0.2154, -0.5902, -0.2916, -0.2748, 0.0787, 0.9065, -0.9764, -0.2278, 1.6248, 0.7941, -0.5014, 0.2422, -2.1474, 0.7818, 0.4370, 0.1361, -0.3936, -0.7724, 0.0941, -0.5762, 3.2182, -0.1101, 0.2677, -0.0101, -1.1798, -0.0122, -0.8163, 0.1115, -0.1697, -0.1466, -0.3549, 0.5360, -0.5183, 0.7519, 0.7093, -0.5946, 0.2787, 0.4822, -0.2680, 0.0934, 0.1483, 0.6706, -0.1150, -0.1945, -2.6643, 0.2194, -0.5014, -0.5869, 0.1022, 0.1988, -0.2558, 0.3732, -0.0644, 0.6440, -0.7403, -1.0228, 0.8158, 0.9543, -0.1226, -0.0929, 0.2716, 0.7962, -0.5293, 0.1538, -1.2074, -0.5093, 0.2037, 0.2156, -0.4407, 2.6976, -0.3653, 0.0458, -0.0899, -0.7584, 1.8329, -0.5082, -0.4776, -0.0265, -2.9437, -0.1675, 1.2358, 0.1571, -0.5022, -0.6370, 0.4087, -0.9664, 0.3533, 0.0928, -0.5308, 0.4462, 0.2476, 0.0976, -1.8347, 0.0468, -0.9309, -0.3712, -0.8578, -0.0568, 1.7377, -0.1299, -0.7187, 0.9764, 0.6858, 0.4272, -0.9588, 0.1038, 0.2520, -1.3775, 0.1491, -0.8507, 0.7052, 0.6483, 0.2818, -0.3305, -0.5913, -0.0907, -0.2438, -0.1932, -0.0564, -0.0777, -0.0748, 0.6530, 0.2393, 0.4476, 0.3941, -1.7061, 0.8876, 1.1888, 0.1423, 0.1737, 0.1330, 0.1115, 0.1525, -0.3715, 0.4657, -0.4010, -0.3089, 2.0455, -0.9555, 0.5093, 0.1502, -0.0865, -0.7851, -0.5175, 0.1613, 0.8113, 1.1943, 0.0612, 1.7087, 
-1.1616, -0.3204, 0.4428, 0.6120, -0.2282, 0.0174, -0.3141, -0.0045, 0.2204, 0.3966, 4.1174, -0.1531, 0.4325, -0.0245, -0.0310, 0.6541, 0.2904, 1.9309, -0.5405, 0.8576, 1.0352, -0.3592, -0.1056, -0.0047, 0.7218, 0.2350, 1.8817, 0.7558, -0.1575, -0.0544, 0.0234, 0.5841, 0.0996, -0.0503, 1.4150, 0.2260, 0.9152, 0.0688, 0.5286, 0.5885, 0.4606, -0.9186, 0.0441, 0.5233, 0.5305, -0.9086, 0.3728, 0.6752, 0.5453, -1.1360, 0.0613, -0.2365, 0.8856, -0.0512, -0.2589, -0.7055, -0.8111, 0.1787, 1.0393, -0.2469, -0.0922, 1.1790, -0.3284, 0.0402, 0.0746, -0.1033, -0.7248, -1.3859, -1.0511, 0.2797, 0.2777, -0.0877, 0.0271, 0.0740, -1.5863, -0.7014, 0.3677, -1.6786, -1.0769, 0.5594, 0.2428, -0.2664, 0.3454, -0.0490, -3.3762, 0.2004, 0.1913, -0.6461, 0.7643, -0.1239, 1.6487, 0.4942, -0.3305, -0.5069, -0.2183, 1.1533, -0.4380, 0.0219, -0.6319, 0.6743, 1.0648, 0.0587, -0.0989, -0.0995, 0.3757, 0.1813, 0.2854, 0.4345, -2.2154, 0.3601, -0.6406, -0.1099, 0.3583, -0.3726, 0.2892, 0.5897, 3.4282, -2.8781, 0.8985, 0.1550, 0.1102, 0.8008, -0.0811, -0.4199, 0.3145, -0.3236, -0.2425, -0.4502, 0.2431, 0.8504, 0.4597, 0.6396, 0.0902, 1.3885, 0.1297, -1.1721, -0.3227, -0.4472, 0.2575, 1.6201, -0.5444, 0.8665, -0.9622, 0.0035, -0.5908, 1.6270, 0.0351, -0.3419, 0.0039, 1.1001, -0.3767, -0.2270, 1.3332, 0.3555, 0.0667, -0.5392, -1.3500, -0.0842, 0.2591, -2.8862, 0.3166, 2.3757, 1.1254, -0.5208, -0.7074, -0.8110, 0.3715, 1.3720, -0.7236, -0.0665, 0.2772, -0.2840, -0.3515, -0.4777, 0.3030, 0.5417, 0.7752, -0.0182, 1.1569, -0.1614, 1.6521, -2.2844, -0.9332, -0.1472, 0.6151, -0.5020, -0.0719, 0.3361, -0.2722, -0.1500, 0.5092, -0.0348, -0.6530, -0.4159, -0.6603, -3.6738, 0.1421, -1.1267, 0.4267, 0.0699, 1.6415, 0.1451, 1.3309, 0.7792, 2.1801, -0.0886, 0.4233, 0.2828, 1.3708, -1.2021, -0.2627, -2.1505, 0.7701, -0.0167, -0.0247, 0.4665, -1.5951, -0.9997, -0.1568, -0.1108, 0.1543, -1.0055, 0.0001, -1.0355, 0.8421, -0.0485, -0.3064, 1.2358, -0.0448, 0.4038, -0.7671, 0.3624, -0.6197, 0.7966, -0.2266, 0.1130, 
-0.5302, -1.5468, 0.0700, -1.1711, -0.3307, 0.0086, -0.0416, 1.2763, -0.0574, 0.0121, -2.6334, -0.3180, -0.1954, 0.3944, 0.0076, 1.2025, -0.5634, 0.9271, 0.4198, 0.3251, -0.0041, 0.5236, -0.5314, 0.0639, -0.8840, -0.2680, 0.4958, 0.7804, 0.2942, -0.1935, -0.1405, -0.5670, 0.9489, 0.5726, -0.2529, 1.8878, -0.7204, -0.0050, -0.2448, 0.1725, 0.4253, 0.0058, 1.0247, -0.2908, -0.3978, -0.0963, 0.2107, 1.3576, 0.3074, 0.5527, -0.0927, 0.1521, 0.6300, -0.1377, -0.0497, 0.0425, -0.2248, -0.1534, 0.5778, 0.0033, 0.1789, 2.4935, 1.3225, -0.8038, -0.8864, 0.1176, -1.0532, -0.2375, 1.4582, -0.1168, 0.0548, 0.4221, -0.3585, 0.4043, -0.4371, 1.3289, -0.3674, -0.4286, -0.1730, 0.0535, 0.1441, 0.2703, 0.3826, 0.5123, -0.0401, -0.1230, -0.3143, -1.7583, 0.2582, 0.3484, 0.5722, 0.8621, 0.4420, 0.4442, -0.2445, 0.0532, -0.8102, -1.4058, -0.6382, -0.5799, -0.2456, -0.0906, -0.3191, -0.3395, -0.4364, -0.5810, -0.7970, 0.0831, -1.1570, -0.2573, -0.0644, -0.7106, 0.1313, 0.1944, -0.2329, 0.1409, -1.2096, 1.0822, 0.5523, 0.2151, -0.1106, -0.1034, -0.4873, 0.6932, 1.0196, -0.0521, 0.0569, -0.8759, 1.0084, 0.6800, 1.0768, -1.2878, -0.1161, 0.0447, 0.1888, -0.2371, -1.0470, -0.4027, -1.4363, 0.1606, -0.8026, -0.0244, -0.2893, -0.4938, -0.6921, 1.0140, -0.4158, -0.5957, 0.3313, -0.2462, 0.7703, 0.3403, -1.5113, -0.1231, -0.3776, 0.3326, 0.1634, 2.1520, 0.7302, -0.0300, -2.8234, 0.4553, -0.4652, -0.3331, -1.0286, 1.2882, -0.2797, -0.4759, 0.1470, -1.0253, -0.8175, 0.6936, -0.3728, -0.4594, 1.0876, 0.6229, -0.0461, -0.4342, -0.1686, -1.3960, 0.5283, 0.4002, 0.8179, 0.4787, -0.7147, -0.5052, -0.2552, 0.2817, -0.4022, -0.5289, 0.0815, -0.4814, -0.5451, -0.1384, -0.4303, -0.4506, -1.9036, 0.6884, 0.1361, 0.2678, -0.0052, 0.0119, -0.1882, 1.0507, 3.1094, -0.5746, 1.3087, -0.1831, -0.1917, 0.0633, 0.5083, -0.1448, -0.0134, 0.5002, 0.2579, 0.7755, 0.1579, 0.4157, -0.2610, -0.4953, 0.1709, 0.4063, 0.2068, 0.2666, -0.7872, 0.5325, 0.4910, -0.1599, 0.4387, 0.9262, 0.9245, 0.5763, 0.9292, -0.4531, 
-0.5367, -0.4911, 0.2302, -0.4182, 0.7188, 0.0342, -0.2079, 0.1310, 0.5718, -0.0331, 0.1861, -0.1287, -0.0427, 0.8478, 0.7278, -0.5664, -0.5335, 1.3976, 0.1697, 0.6063, -0.0220, 0.4921, -0.1349, -0.0531, -0.2408, -0.3858, -0.2741, 0.2285, -0.5532, 0.2704, -0.2687, -0.2161, -0.1179, -1.5228, -0.3683, -1.3004, 0.2431, -0.3305, 1.6118, -0.0328, 1.1503, 0.5712, -0.0423, 3.4830, -0.2760, 0.6307, -0.0419, 0.1553, -0.5602, 0.2106, -0.2213, -0.4543, 0.3034, 0.9189, 1.5738, 0.5071, 0.2238, -2.2069, 0.4104, 0.6224, 0.2836, -0.1620, -0.3043, -0.4012, 0.2410, -0.6261, -0.2435, 0.0211, -0.2227, -0.2392, -0.3634, -0.9207, 0.2260, 0.0929, 0.8206, -0.3214, -0.2296, 0.1274, -0.8615, 0.2329, 1.1085, -1.0565, 0.2258],
"std": [0.9963, 0.6391, 0.4956, 0.6280, 0.7591, 0.5610, 0.8236, 0.7139, 0.7494, 0.5686, 0.5042, 0.3464, 0.4228, 0.4171, 0.3526, 0.3710, 0.6288, 0.4674, 0.4413, 0.4741, 0.6553, 0.4882, 0.3697, 0.5507, 0.4961, 0.3683, 0.5604, 0.5302, 0.6027, 0.3023, 0.4882, 0.5746, 0.5314, 0.5031, 0.6145, 0.5994, 0.4285, 0.6399, 0.5362, 0.5403, 0.4677, 1.2902, 0.6126, 0.4145, 0.5068, 0.4667, 0.4825, 0.4275, 0.4381, 0.6758, 0.4866, 0.4136, 0.5262, 0.5698, 0.6550, 0.6492, 0.3450, 0.5948, 0.4219, 0.4973, 0.4483, 0.4336, 0.7440, 0.4595, 0.4366, 0.3634, 0.4430, 0.6587, 0.5073, 0.3533, 0.7036, 0.7039, 0.5312, 0.4701, 0.7512, 0.4102, 0.4227, 0.4488, 0.4158, 0.4676, 0.4521, 0.4560, 0.3917, 0.4757, 0.4348, 0.6013, 0.6715, 0.5179, 0.4834, 0.7451, 0.4845, 0.8893, 0.4188, 0.5963, 0.4306, 0.4551, 0.6417, 0.2886, 0.5378, 0.4316, 0.7568, 0.4818, 0.5494, 0.4736, 0.5841, 0.5043, 0.4265, 0.6994, 0.7652, 0.4344, 0.3931, 0.7198, 0.4169, 0.5794, 0.6720, 0.5694, 0.8603, 0.5307, 0.5893, 0.5763, 0.5292, 0.5228, 0.4156, 0.4901, 0.8334, 0.4574, 0.7241, 0.5346, 0.4063, 0.4147, 0.4979, 0.6599, 0.4173, 0.3715, 0.3828, 0.4492, 0.5576, 0.4060, 0.4353, 0.5315, 0.9834, 0.5548, 0.5679, 0.3506, 0.5419, 0.4256, 0.4187, 0.3570, 0.6316, 0.5870, 0.4832, 0.4862, 0.6072, 0.6781, 0.6152, 0.6708, 0.5008, 0.4435, 0.4229, 0.4973, 0.4301, 0.5363, 0.5478, 0.5388, 0.3952, 0.5961, 0.4721, 0.6389, 0.4450, 0.4841, 0.5594, 0.5234, 0.5224, 0.6326, 0.4469, 0.7397, 0.9551, 0.8426, 0.7576, 0.3893, 0.4382, 0.5222, 0.5234, 0.6035, 0.5764, 0.4043, 0.4741, 0.5471, 0.4229, 0.5962, 0.7127, 0.6205, 0.5671, 0.3766, 0.7455, 0.7315, 0.5891, 0.5372, 0.5957, 0.5342, 0.4010, 0.4453, 0.4609, 0.8789, 0.4353, 0.6297, 0.4126, 0.4149, 0.4597, 0.4859, 0.6733, 0.6096, 0.5719, 0.4494, 0.6353, 0.7537, 0.4643, 0.4577, 0.6485, 0.6069, 0.3603, 0.5821, 0.4807, 0.5192, 0.5329, 0.4153, 0.7329, 0.5444, 0.5742, 0.4593, 0.4003, 0.6770, 0.5428, 0.6781, 0.7920, 0.5037, 0.7615, 0.4537, 0.5931, 0.7333, 0.4880, 0.5469, 0.4698, 0.4917, 0.6256, 0.4947, 0.3974, 0.7559, 0.5916, 
0.6547, 0.7502, 0.4682, 0.4517, 0.4888, 0.6472, 0.4755, 0.3927, 0.5845, 0.4135, 0.4091, 0.5860, 0.4544, 0.4051, 0.5547, 0.5322, 0.7200, 0.4595, 0.5484, 0.4758, 0.5259, 0.4137, 0.7149, 0.5638, 0.6221, 0.9309, 0.5637, 0.5657, 0.5711, 0.5651, 1.0484, 0.4435, 0.4587, 0.3716, 0.4108, 0.8114, 0.5531, 1.0675, 0.5825, 0.3841, 0.4500, 0.7335, 0.4767, 0.4162, 0.5679, 0.4880, 0.4614, 0.5118, 0.5198, 0.5619, 0.6869, 0.3536, 0.5128, 0.4722, 0.3722, 0.7705, 0.4556, 0.5365, 0.4999, 0.3254, 0.5268, 0.7580, 0.5932, 0.9908, 0.6171, 0.4912, 0.4439, 0.9135, 0.4658, 0.6566, 0.5500, 0.5423, 0.4725, 0.5415, 0.5550, 0.7519, 0.4220, 0.6024, 0.4821, 0.5268, 0.4583, 0.7421, 0.5200, 0.4541, 0.5197, 0.4562, 0.8381, 0.4423, 0.7400, 0.6578, 0.6459, 0.5316, 0.6877, 0.5362, 0.4215, 0.6455, 0.4363, 0.6716, 0.5795, 0.5587, 0.5234, 0.4456, 0.4991, 0.4244, 0.8959, 0.4744, 0.4440, 0.4437, 0.8485, 0.4237, 0.6907, 0.5582, 0.4315, 0.5458, 0.7341, 0.4731, 0.5065, 0.6181, 0.5643, 0.4407, 0.4353, 0.4732, 0.3769, 0.4162, 0.5028, 0.3689, 0.6656, 0.4598, 0.3735, 0.6801, 0.7902, 0.7101, 0.6292, 0.5732, 0.7452, 0.6803, 0.5065, 0.5261, 0.4644, 0.5021, 0.6714, 0.5226, 0.4455, 0.7599, 0.4380, 0.5468, 0.4595, 0.5308, 0.8445, 0.4413, 0.5196, 0.9241, 0.5414, 0.5018, 0.3832, 0.4950, 0.3185, 0.5330, 0.4844, 0.4481, 0.4517, 0.5104, 0.6092, 0.5712, 0.4164, 0.6590, 0.4888, 0.3930, 0.5419, 0.5486, 0.5165, 0.4390, 0.5542, 0.3883, 0.4074, 0.6213, 0.6185, 0.7711, 0.6565, 0.4925, 1.0624, 0.4690, 0.7498, 0.5333, 0.5290, 0.6258, 0.4473, 0.3862, 0.6571, 0.4873, 0.5240, 0.4127, 0.4445, 0.5094, 0.4754, 0.5769, 0.4786, 0.4510, 0.5130, 0.4897, 0.7568, 0.7398, 0.5718, 0.4229, 0.4929, 0.7470, 0.5901, 0.3772, 0.4914, 0.4074, 0.9471, 0.4967, 0.4323, 0.5259, 0.3591, 0.7202, 0.6012, 0.4573, 0.4296, 0.5578, 0.5218, 0.4640, 0.4522, 0.4029, 0.8071, 0.6086, 0.4832, 0.4202, 0.6781, 0.3862, 0.3920, 0.7543, 1.0257, 0.8849, 0.4181, 0.4722, 0.4069, 0.4854, 0.5405, 0.4676, 0.5547, 0.6282, 0.4275, 0.8011, 0.3308, 0.7135, 0.4315, 0.4915, 0.6616, 0.7376, 
0.5742, 0.7461, 0.5443, 0.4749, 0.4906, 1.0020, 0.6306, 0.4435, 0.4559, 0.4360, 0.4047, 0.5802, 0.6109, 0.7836, 0.5163, 0.9777, 0.9272, 0.4618, 0.3534, 0.5218, 0.4479, 0.6498, 0.7145, 0.6224, 0.5671, 0.5042, 0.3885, 0.4079, 0.4481, 0.5406, 0.6944, 0.3744, 0.5942, 0.6770, 0.5934, 0.7417, 0.5662, 0.4753, 0.5063, 0.5003, 0.4510, 0.4358, 0.6455, 0.7740, 0.4780, 0.8687, 0.5533, 0.5700, 0.3518, 0.4868, 0.4154, 0.4798, 0.3266, 0.3536, 0.3789, 0.3805, 0.7909, 0.5760, 0.5784, 0.4993, 0.5787, 0.5324, 0.4496, 0.8483, 1.0794, 0.4820, 0.4135, 0.6231, 0.4668, 0.6684, 0.7052, 0.7616, 0.4881, 0.4150, 0.5793, 0.8068, 0.7793, 0.4721, 0.5230, 0.4810, 0.9577, 0.5537, 0.5583, 0.6645, 0.4334, 0.6398, 0.5011, 0.4081, 0.6255, 0.5372, 0.4846, 0.6125, 0.6509, 0.4413, 0.4762, 0.4917, 0.5940, 0.4950, 0.6753, 0.6653, 0.5210, 0.5599, 0.4678, 0.4868, 0.5985, 0.4160, 0.4874, 0.8380, 0.5382, 0.5701, 0.5448, 0.6131, 0.5674, 0.7120, 0.4070, 0.4434, 0.5725, 0.4919, 0.4805, 0.5997, 0.7108, 0.9824, 0.4765, 0.7575, 0.4452, 0.8892, 0.4639, 0.4962, 1.0346, 0.7584, 0.4312, 0.4835, 0.8968, 0.4799, 0.6864, 0.5641, 1.0694, 0.6750, 0.4288, 0.5159, 0.3649, 0.7699, 0.4386, 0.4449, 0.3923, 0.6499, 0.5612, 0.4541, 0.6261, 0.5444, 0.4369, 0.4124, 0.4174, 0.6129, 0.5005, 0.4779, 0.3929, 0.4865, 0.4338, 0.4114, 0.6266, 0.3669, 1.0147, 0.4856, 0.4867, 0.6250, 0.5368, 0.6699, 0.6411, 0.5296, 0.7614, 0.5643, 0.5843, 0.6846, 0.3923, 0.3928, 0.4964, 0.4490, 0.4755, 0.4104, 0.5468, 0.6040, 0.5808, 0.6283, 0.4316, 0.6127, 0.4635, 0.5303, 0.4261, 0.4668, 0.6121, 0.4063, 0.5571, 0.6130, 0.5874, 0.4987, 0.4113, 0.5401, 0.4028, 0.6598, 0.7740, 0.5384, 0.7890, 0.9379, 0.8801, 0.6222, 0.5356, 0.3990, 0.4802, 0.4107, 0.5475, 0.6936, 0.6865, 0.4776, 0.5211, 0.8844, 0.6517, 1.0729, 0.9252, 0.6953, 0.4177, 0.7587, 0.6628, 0.3629, 0.3685, 0.3758, 0.4439, 0.5236, 0.4905, 0.5290, 0.4184, 0.3940, 0.6498, 0.5411, 0.5662, 0.5519, 0.6107, 0.6385, 0.4127, 0.6277, 0.5255, 0.5926, 0.5653, 0.6570, 0.6034, 0.5312, 0.4128, 0.7292, 0.3620, 0.5067, 
0.5314, 0.3908, 0.7561, 0.4494, 0.4501, 0.7682, 0.4939, 0.4198, 0.5256, 0.6339, 0.5123, 0.9018, 0.5054, 0.4879, 0.4567, 0.4145, 0.6046, 0.3835, 0.4289, 0.5254, 0.6191, 0.6610, 0.5933, 0.7890, 0.7817, 0.6299, 0.5977, 0.7094, 0.3737, 1.0318, 0.7045, 0.7785, 0.5376, 0.5861, 0.4233, 0.5538, 1.0604, 0.5690, 0.5249, 0.3747, 0.6036, 0.4707, 0.3617, 0.3665, 0.6184, 0.4878, 0.6193, 0.5311, 0.6187, 0.6748, 0.4493, 0.6137, 0.4601, 0.3855, 0.4183, 0.4986, 0.4832, 0.4192, 0.4416, 0.7202, 0.3724, 0.4899, 0.6939, 0.4272, 0.7122, 0.6950, 0.5565, 0.4417, 0.6186, 0.4753, 0.5919, 0.3763, 0.5643, 0.5347, 0.5454, 0.9336, 0.6594, 0.9747, 0.4970, 0.4725, 0.7820, 0.4113, 0.4942, 0.6699, 0.4159, 0.6766, 0.6564, 0.3947, 0.5381, 0.3874, 0.6686, 0.5628, 0.3904, 0.6647, 0.9821, 0.4343, 0.5455, 0.4879, 0.8165, 0.4153, 0.5544, 0.5179, 0.3821, 0.6678, 0.7883, 0.3372, 0.4702, 0.5044, 0.4584, 0.4769, 0.3787, 0.4377, 0.5435, 0.5899, 0.5378, 0.5986, 0.4887, 0.5390, 0.5464, 0.6330, 0.5010, 0.4244, 0.5249, 0.6770, 0.6314, 0.6404, 0.4605, 0.3649, 0.6489, 1.0657, 0.5497, 0.5357, 0.3651, 0.8484, 0.8126, 0.4873, 0.6711, 0.4401, 0.6181, 0.8585, 0.6000, 0.5654, 0.5416, 0.3504, 0.4671, 0.5499, 0.4409, 0.7650, 0.4980, 0.9734, 0.3568, 0.6037, 0.4361, 0.7880, 0.4726, 0.9902, 0.5020, 0.5178, 0.5065, 0.4543, 0.9039, 0.8296, 0.4451, 0.4436, 0.8518, 0.5201, 0.8668, 0.5122, 0.3412, 0.5849, 0.4815, 0.5795, 0.5664, 0.4384, 0.4593, 0.7974, 0.6570, 0.6522, 0.5490, 0.4195, 0.6821, 0.6133, 0.5692, 0.4780, 0.4574, 0.5090, 0.4488, 0.4269, 0.4153, 0.5143, 0.6560, 0.4480, 0.5482, 0.6997, 0.4377, 0.4166, 0.6103, 0.4671, 0.4449, 0.5672, 0.3296, 0.3898, 0.3778, 0.6572, 0.5555, 0.4047, 0.3720, 0.5728, 0.6867, 0.5435, 0.5001, 0.6808, 0.6373, 0.6849, 0.4826, 0.4767, 0.3736, 0.5070, 0.4442, 0.4302, 0.4339, 0.4614, 0.4735, 0.7977, 0.5657, 0.4047, 0.5261, 0.6204, 0.3413, 0.3996, 0.4236, 0.3303, 0.4193, 0.6074, 0.3941, 0.4802, 0.4114, 0.3880, 0.3460, 0.3767, 0.6491, 0.6893, 0.8560, 0.4244, 0.4307, 0.5702, 0.3635, 0.5170, 0.3975, 0.6187, 
0.5012, 0.4976, 0.7149, 0.7001, 0.4834, 0.3844, 0.5179, 0.6909, 0.5862, 1.0062, 0.5099, 0.6410, 0.7432, 0.4219, 0.4655, 0.6067, 0.6674, 0.4618, 0.7115, 0.5300, 0.5284, 0.4208, 0.4955, 0.4561, 0.3723],
}
}
# Normalisation statistics (six-element mean/std) for the camera angular-velocity
# feature, keyed by the name of the source they come from.
cam_angvel = {
    # This entry additionally records a "count" next to its statistics.
    "emdb_none_test": {
        "count": 42622,
        "mean": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
        "std": [
            5.5702e-05, 3.2200e-03, 5.6530e-03,
            3.2191e-03, 2.4738e-05, 3.3406e-03,
        ],
    },
    # Hand-set values (the std was chosen manually).
    "manual": {
        "mean": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
        "std": [0.001, 0.1, 0.1, 0.1, 0.001, 0.1],
    },
}
# fmt:on
# ====== Compose ====== #
def compose(targets, sources):
    """Concatenate per-feature statistics into one flat mean/std record.

    Args:
        targets: Sequence of mappings. Each maps a source name to a dict that
            holds a "mean" list and a "std" list.
        sources: Sequence of source names, one per target. A sequence with a
            single element is broadcast to every target.

    Returns:
        dict with keys "mean" and "std"; each value is the concatenation, in
        target order, of the selected entries' lists.

    Raises:
        ValueError: if, after broadcasting, the number of sources differs from
            the number of targets (previously the extra items were silently
            dropped by ``zip``, yielding statistics of the wrong length).
        KeyError: if a target has no entry for its source name.
    """
    targets = list(targets)
    if len(sources) == 1:
        sources = sources * len(targets)
    if len(sources) != len(targets):
        raise ValueError(
            f"compose() got {len(targets)} targets but {len(sources)} sources; "
            "pass one source per target or a single source to broadcast"
        )

    mean, std = [], []
    for target, source in zip(targets, sources):
        stats = target[source]
        mean.extend(stats["mean"])
        std.extend(stats["std"])
    return {"mean": mean, "std": std}
# Identity normalisation: zero mean, unit std.
DEFAULT_01 = dict(mean=[0.0], std=[1.0])

# Every preset below concatenates the same five feature groups, in this order:
# body pose, betas, global orient (c), global orient (gv), local translation velocity.
# The presets differ only in which source each group's statistics are taken from.

# All five groups use the "bedlam" statistics.
MM_V1 = compose(
    [body_pose_r6d, betas, global_orient_c_r6d, global_orient_gv_r6d, local_transl_vel],
    ["bedlam", "bedlam", "bedlam", "bedlam", "bedlam"],
)

# "amass" for body pose, betas and translation velocity; "bedlam" for both orientations.
MM_V1_AMASS_LOCAL_BEDLAM_CAM = compose(
    [body_pose_r6d, betas, global_orient_c_r6d, global_orient_gv_r6d, local_transl_vel],
    ["amass"] * 2 + ["bedlam"] * 2 + ["amass"],
)

# "bedlam" everywhere except the translation velocity, which uses its "none" entry.
MM_V2 = compose(
    [body_pose_r6d, betas, global_orient_c_r6d, global_orient_gv_r6d, local_transl_vel],
    ["bedlam"] * 4 + ["none"],
)

# Same as MM_V2 but the translation velocity uses its "1e-2" entry.
MM_V2_1 = compose(
    [body_pose_r6d, betas, global_orient_c_r6d, global_orient_gv_r6d, local_transl_vel],
    ["bedlam"] * 4 + ["1e-2"],
)
================================================
FILE: eval/GVHMR/hmr4d/network/base_arch/embeddings/rotary_embedding.py
================================================
import torch
import torch.nn as nn
from einops import repeat, rearrange
from torch.cuda.amp import autocast
def rotate_half(x):
    """Map every adjacent feature pair (a, b) on the last axis to (-b, a).

    The last dimension is viewed as pairs, each pair is rotated, and the result
    is flattened back to the input's shape.
    """
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, "... d r -> ... (d r)")
@autocast(enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
    """Apply a rotary embedding to a slice of the last dimension of ``t``.

    Args:
        freqs: Angle table whose last dimension is the number of features to rotate.
        t: Tensor to rotate. When it is 3-D, ``freqs`` is cut to its trailing
            ``t.shape[seq_dim]`` rows and cast/moved to match ``t``.
        start_index: First feature index of the rotated slice.
        scale: Multiplier applied to both the cosine and the sine term.
        seq_dim: Axis of ``t`` holding the sequence (only read for 3-D input).

    Returns:
        Tensor shaped like ``t``; features outside
        ``[start_index, start_index + freqs.shape[-1])`` are passed through untouched.
    """
    if t.ndim == 3:
        # Keep only the trailing rows that line up with this sequence.
        freqs = freqs[-t.shape[seq_dim]:].to(t)

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    untouched_head = t[..., :start_index]
    to_rotate = t[..., start_index:end_index]
    untouched_tail = t[..., end_index:]

    cos_term = to_rotate * freqs.cos() * scale
    sin_term = rotate_half(to_rotate) * freqs.sin() * scale
    return torch.cat((untouched_head, cos_term + sin_term, untouched_tail), dim=-1)
def get_encoding(d_model, max_seq_len=4096):
    """Build the rotary angle table.

    Args:
        d_model: Feature dimension; one frequency is generated per pair of features.
        max_seq_len: Number of positions (rows) in the table.

    Return: (L, D) tensor of position * inverse-frequency angles, each frequency
    column duplicated so that adjacent feature pairs share the same angle.
    """
    positions = torch.arange(max_seq_len).float()
    exponents = torch.arange(0, d_model, 2).float() / d_model
    inv_freq = 1.0 / (10000 ** exponents)
    angles = torch.einsum("i, j -> i j", positions, inv_freq)
    return repeat(angles, "i j -> i (j r)", r=2)
class ROPE(nn.Module):
    """Minimal impl of a lang-style positional encoding."""

    def __init__(self, d_model, max_seq_len=4096):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Pre-compute the angle table once. It is registered as a
        # non-persistent buffer, so it is not written to the state_dict.
        self.register_buffer("encoding", get_encoding(d_model, max_seq_len), False)

    def rotate_queries_or_keys(self, x):
        """
        Args:
            x : (B, H, L, D)
        Returns:
            rotated_x: (B, H, L, D)
        """
        seq_len, d_model = x.shape[-2:]
        assert d_model == self.d_model

        # encoding: (L, D)
        if seq_len <= self.max_seq_len:
            encoding = self.encoding[:seq_len]
        else:
            # Sequence is longer than the cached table: build a larger one on the fly.
            encoding = get_encoding(d_model, seq_len).to(x)

        # x: (B, H, L, D)
        return apply_rotary_emb(encoding, x, seq_dim=-2)
================================================
FILE: eval/GVHMR/hmr4d/network/base_arch/transformer/encoder_rope.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from timm.models.vision_transformer import Mlp
from typing import Optional, Tuple
from einops import einsum, rearrange, repeat
from hmr4d.network.base_arch.embeddings.rotary_embedding import ROPE
class RoPEAttention(nn.Module):
    """Multi-head self-attention whose queries and keys are rotated by ``ROPE``."""

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.rope = ROPE(self.head_dim, max_seq_len=4096)

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """
        Args:
            x: (B, L, C)
            attn_mask: (L, L) boolean mask; True entries are filled with -inf.
            key_padding_mask: (B, L) boolean mask; True keys are filled with -inf.
        Returns:
            (B, L, C) attention output after the final projection.
        """
        B, L, _ = x.shape
        n_heads = self.num_heads

        def split_heads(projected):
            # (B, L, C) -> (B, N, L, C // N)
            return projected.reshape(B, L, n_heads, -1).transpose(1, 2)

        xq = self.rope.rotate_queries_or_keys(split_heads(self.query(x)))  # B, N, L, C
        xk = self.rope.rotate_queries_or_keys(split_heads(self.key(x)))  # B, N, L, C
        xv = split_heads(self.value(x))

        scores = einsum(xq, xk, "b n i c, b n j c -> b n i j") / math.sqrt(self.head_dim)
        neg_inf = float("-inf")
        if attn_mask is not None:
            blocked_pairs = attn_mask.reshape(1, 1, L, L).expand(B, n_heads, -1, -1)
            scores = scores.masked_fill(blocked_pairs, neg_inf)
        if key_padding_mask is not None:
            blocked_keys = key_padding_mask.reshape(B, 1, 1, L).expand(-1, n_heads, L, -1)
            scores = scores.masked_fill(blocked_keys, neg_inf)

        weights = self.dropout(torch.softmax(scores, dim=-1))
        attended = einsum(weights, xv, "b n i j, b n j c -> b n i c")  # B, N, L, C
        merged = attended.transpose(1, 2).reshape(B, L, -1)  # B, L, C
        return self.proj(merged)  # B, L, C
class EncoderRoPEBlock(nn.Module):
    """Pre-norm transformer block: gated RoPE self-attention followed by a gated MLP.

    Both residual branches are multiplied by learnable gates that start at zero,
    so a freshly constructed block returns its input unchanged.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, dropout=0.1, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.attn = RoPEAttention(hidden_size, num_heads, dropout)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=approx_gelu,
            drop=dropout,
        )

        self.gate_msa = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.gate_mlp = nn.Parameter(torch.zeros(1, 1, hidden_size))

        # Zero-out the gates so both residual branches start switched off.
        for gate in (self.gate_msa, self.gate_mlp):
            nn.init.constant_(gate, 0)

    def forward(self, x, attn_mask=None, tgt_key_padding_mask=None):
        """Run the block on ``x`` (B, L, C) and return a tensor of the same shape."""
        attn_out = self._sa_block(
            self.norm1(x), attn_mask=attn_mask, key_padding_mask=tgt_key_padding_mask
        )
        x = x + self.gate_msa * attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + self.gate_mlp * mlp_out

    def _sa_block(self, x, attn_mask=None, key_padding_mask=None):
        # x: (B, L, C)
        return self.attn(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
================================================
FILE: eval/GVHMR/hmr4d/network/base_arch/transformer/layer.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and hand the module back.

    The zeroing is done outside of autograd tracking; ``requires_grad`` flags
    are left as they were.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
================================================
FILE: eval/GVHMR/hmr4d/network/gvhmr/relative_transformer.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import einsum, rearrange, repeat
from hmr4d.configs import MainStore, builds
from hmr4d.network.base_arch.transformer.encoder_rope import EncoderRoPEBlock
from hmr4d.network.base_arch.transformer.layer import zero_module
from hmr4d.utils.net_utils import length_to_mask
from timm.models.vision_transformer import Mlp
class NetworkEncoderRoPE(nn.Module):
    """RoPE transformer encoder over per-frame 2D keypoint observations.

    Each frame's 17 keypoints are embedded into a single token, condition
    embeddings are added to it, the token sequence is run through a stack of
    ``EncoderRoPEBlock`` layers, and several MLP heads read out the result.
    """

    def __init__(
        self,
        # x
        output_dim=151,
        max_len=120,
        # condition
        cliffcam_dim=3,
        cam_angvel_dim=6,
        imgseq_dim=1024,
        # intermediate
        latent_dim=512,
        num_layers=12,
        num_heads=8,
        mlp_ratio=4.0,
        # output
        pred_cam_dim=3,
        static_conf_dim=6,
        # training
        dropout=0.1,
        # other
        avgbeta=True,
    ):
        super().__init__()

        # input
        self.output_dim = output_dim
        self.max_len = max_len

        # condition
        self.cliffcam_dim = cliffcam_dim
        self.cam_angvel_dim = cam_angvel_dim
        self.imgseq_dim = imgseq_dim

        # intermediate
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout

        # ===== build model ===== #
        # Input (Kp2d)
        # Main token: each keypoint's 2 coordinates are lifted to 32 features;
        # a learned 32-d vector per joint stands in for keypoints that are not visible.
        self.learned_pos_linear = nn.Linear(2, 32)
        self.learned_pos_params = nn.Parameter(torch.randn(17, 32), requires_grad=True)
        self.embed_noisyobs = Mlp(
            17 * 32, hidden_features=self.latent_dim * 2, out_features=self.latent_dim, drop=dropout
        )

        self._build_condition_embedder()

        # Transformer
        encoder_layers = [
            EncoderRoPEBlock(self.latent_dim, self.num_heads, mlp_ratio=mlp_ratio, dropout=dropout)
            for _ in range(self.num_layers)
        ]
        self.blocks = nn.ModuleList(encoder_layers)

        # Output heads
        self.final_layer = Mlp(self.latent_dim, out_features=self.output_dim)

        # The optional heads hold either an Mlp or the flag False, so the
        # attribute itself can be used as the "head enabled" test in forward().
        # keep extra_output for easy-loading old ckpt
        if pred_cam_dim > 0:
            self.pred_cam_head = Mlp(self.latent_dim, out_features=pred_cam_dim)
            self.register_buffer("pred_cam_mean", torch.tensor([1.0606, -0.0027, 0.2702]), False)
            self.register_buffer("pred_cam_std", torch.tensor([0.1784, 0.0956, 0.0764]), False)
        else:
            self.pred_cam_head = False

        if static_conf_dim > 0:
            self.static_conf_head = Mlp(self.latent_dim, out_features=static_conf_dim)
        else:
            self.static_conf_head = False

        self.avgbeta = avgbeta

    def _build_condition_embedder(self):
        """Create the condition embedders; each one ends in a zero-initialised Linear."""
        latent_dim = self.latent_dim
        dropout = self.dropout

        def make_mlp_embedder(in_dim):
            return nn.Sequential(
                nn.Linear(in_dim, latent_dim),
                nn.SiLU(),
                nn.Dropout(dropout),
                zero_module(nn.Linear(latent_dim, latent_dim)),
            )

        self.cliffcam_embedder = make_mlp_embedder(self.cliffcam_dim)
        if self.cam_angvel_dim > 0:
            self.cam_angvel_embedder = make_mlp_embedder(self.cam_angvel_dim)
        if self.imgseq_dim > 0:
            self.imgseq_embedder = nn.Sequential(
                nn.LayerNorm(self.imgseq_dim),
                zero_module(nn.Linear(self.imgseq_dim, latent_dim)),
            )

    def _local_window_mask(self, L, device):
        """Return an (L, L) bool mask that is False inside each row's attention window.

        Row ``i`` may attend to a window of ``max_len`` positions around ``i``;
        the window is shifted where needed so it stays inside ``[0, L)``.
        """
        window = self.max_len
        half = window // 2
        mask = torch.ones((L, L), device=device, dtype=torch.bool)
        for i in range(L):
            start = min(L - window, max(0, i - half))
            stop = max(window, min(L, i + half))
            mask[i, start:stop] = False
        return mask

    def forward(self, length, obs=None, f_cliffcam=None, f_cam_angvel=None, f_imgseq=None):
        """
        Args:
            length: (B,), valid length of every sequence
            obs: (B, L, 17, 3), keypoint observations; the third channel is
                thresholded at 0.5 to decide visibility. Required although
                the default is None.
            f_cliffcam: (B, L, 3), CLIFF-Cam parameters (bbx-detection in the full-image)
            f_cam_angvel: (B, L, 6), Camera angular velocity
            f_imgseq: (B, L, C), image features; ignored when None
        Returns:
            dict with "pred_context" (transformer output), "pred_x" (main head
            output), "pred_cam" and "static_conf_logits" (None when the
            corresponding head is disabled).
        """
        B, L, J, C = obs.shape
        assert J == 17 and C == 3

        # Main token from observation (2D pose)
        obs = obs.clone()
        visible = obs[..., [2]] > 0.5  # (B, L, J, 1)
        hidden = ~visible
        obs[hidden[..., 0]] = 0  # set low-conf to all zeros
        joint_feat = self.learned_pos_linear(obs[..., :2])  # (B, L, J, 32)
        placeholder = self.learned_pos_params.repeat(B, L, 1, 1)
        joint_feat = joint_feat * visible + placeholder * hidden
        x = self.embed_noisyobs(joint_feat.view(B, L, -1))  # (B, L, J*32) -> (B, L, C)

        # Condition: every available embedding is added onto the main token.
        x = x + self.cliffcam_embedder(f_cliffcam)
        if hasattr(self, "cam_angvel_embedder"):
            x = x + self.cam_angvel_embedder(f_cam_angvel)
        if f_imgseq is not None and hasattr(self, "imgseq_embedder"):
            x = x + self.imgseq_embedder(f_imgseq)

        # Setup length and make padding mask
        assert B == length.size(0)
        pmask = ~length_to_mask(length, L)  # (B, L)

        # Sequences longer than max_len use windowed attention.
        attnmask = self._local_window_mask(L, x.device) if L > self.max_len else None

        # Transformer
        for block in self.blocks:
            x = block(x, attn_mask=attnmask, tgt_key_padding_mask=pmask)

        # Output
        sample = self.final_layer(x)  # (B, L, C)
        if self.avgbeta:
            # Replace channels 126:136 by their average over the valid frames.
            valid = ~pmask[..., None]
            betas = (sample[..., 126:136] * valid).sum(1) / length[:, None]  # (B, C)
            betas = repeat(betas, "b c -> b l c", l=L)
            sample = torch.cat([sample[..., :126], betas, sample[..., 136:]], dim=-1)

        # Output (extra)
        pred_cam = None
        if self.pred_cam_head:
            pred_cam = self.pred_cam_head(x) * self.pred_cam_std + self.pred_cam_mean
            torch.clamp_min_(pred_cam[..., 0], 0.25)  # min_clamp s to 0.25 (prevent negative prediction)

        static_conf_logits = None
        if self.static_conf_head:
            static_conf_logits = self.static_conf_head(x)  # (B, L, C')

        return {
            "pred_context": x,
            "pred_x": sample,
            "pred_cam": pred_cam,
            "static_conf_logits": static_conf_logits,
        }
# Register the network's config node in MainStore under the "network/gvhmr" group.
group_name = "network/gvhmr"
MainStore.store(
    group=group_name,
    name="relative_transformer",
    node=builds(NetworkEncoderRoPE, populate_full_signature=True),
)
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/__init__.py
================================================
import torch
from .hmr2 import HMR2
from pathlib import Path
from .configs import get_config
from hmr4d import PROJ_ROOT
# Default checkpoint: this is HMR2.0a, follow WHAM
HMR2A_CKPT = PROJ_ROOT / "inputs/checkpoints/hmr2/epoch=10-step=25000.ckpt"
def load_hmr2(checkpoint_path=HMR2A_CKPT):
    """Build an ``HMR2`` model and load its backbone / SMPL-head weights.

    Args:
        checkpoint_path: Path of the checkpoint file; it must contain a
            "state_dict" entry.

    Returns:
        The ``HMR2`` instance with the filtered weights loaded (strict loading).
    """
    config_file = Path(__file__).parent / "configs/model_config.yaml"
    model_cfg = get_config(str(config_file.resolve()))

    # Override some config values, to crop bbox correctly
    if (model_cfg.MODEL.BACKBONE.TYPE == "vit") and ("BBOX_SHAPE" not in model_cfg.MODEL):
        model_cfg.defrost()
        assert (
            model_cfg.MODEL.IMAGE_SIZE == 256
        ), f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
        model_cfg.MODEL.BBOX_SHAPE = [192, 256]  # (W, H)
        model_cfg.freeze()

    # Setup model and Load weights.
    # model = HMR2.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg)
    model = HMR2(model_cfg)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Keep only the weights of the two sub-modules, filtering in a single pass
    # (the previous version built a key list and re-scanned it for every item).
    wanted_prefixes = ("backbone", "smpl_head")
    state_dict = {
        key: value
        for key, value in checkpoint["state_dict"].items()
        if key.split(".")[0] in wanted_prefixes
    }
    model.load_state_dict(state_dict, strict=True)
    return model
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/components/__init__.py
================================================
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/components/pose_transformer.py
================================================
from inspect import isfunction
from typing import Callable, Optional
import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn
from .t_cond_mlp import (
AdaptiveLayerNorm1D,
FrequencyEmbedder,
normalization_layer,
)
# from .vit import Attention, FeedForward
def exists(val):
    """Return True for every value except ``None``."""
    return not (val is None)
def default(val, d):
    """Return ``val`` unless it is None, in which case fall back to ``d``.

    When ``d`` is a plain Python function (per ``inspect.isfunction``) it is
    called with no arguments and its result is returned; any other ``d`` is
    returned as it is.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
class PreNorm(nn.Module):
    """Apply a normalisation layer to the input before calling the wrapped ``fn``."""

    def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
        super().__init__()
        self.norm = normalization_layer(norm, dim, norm_cond_dim)
        self.fn = fn

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Normalise ``x`` and pass the result, plus ``kwargs``, to ``fn``.

        Positional ``args`` are handed to the norm only when it is an
        ``AdaptiveLayerNorm1D``; for any other norm they are ignored.
        """
        adaptive = isinstance(self.norm, AdaptiveLayerNorm1D)
        normed = self.norm(x, *args) if adaptive else self.norm(x)
        return self.fn(normed, **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Map ``x`` (last dimension ``dim``) to a tensor of the same shape."""
        return self.net(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # With a single head whose width already equals ``dim`` no output
        # projection is needed.
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x):
        """Attend over the token axis of ``x`` (b, n, dim); returns (b, n, dim)."""
        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        similarity = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(similarity))

        mixed = torch.matmul(weights, v)
        merged = rearrange(mixed, "b h n d -> b n (h d)")
        return self.to_out(merged)
class CrossAttention(nn.Module):
    """Multi-head attention whose keys/values come from a separate context tensor."""

    def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # The context width falls back to ``dim`` when it is not given.
        context_dim = default(context_dim, dim)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)

        # With a single head whose width already equals ``dim`` no output
        # projection is needed.
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, context=None):
        """Attend from ``x`` to ``context``; with no context, ``x`` attends to itself."""
        context = default(context, x)
        keys, values = self.to_kv(context).chunk(2, dim=-1)
        queries = self.to_q(x)
        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in (queries, keys, values)
        )

        similarity = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(similarity))

        mixed = torch.matmul(weights, v)
        merged = rearrange(mixed, "b h n d -> b n (h d)")
        return self.to_out(merged)
class Transformer(nn.Module):
    """Stack of ``depth`` layers, each a pre-normed self-attention and feed-forward pair."""

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
    ):
        super().__init__()

        def pre_normed(fn):
            return PreNorm(dim, fn, norm=norm, norm_cond_dim=norm_cond_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # Build both sub-modules first, then wrap them, in this order.
            self_attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            pair = nn.ModuleList([pre_normed(self_attn), pre_normed(feed_forward)])
            self.layers.append(pair)

    def forward(self, x: torch.Tensor, *args):
        """Run every sub-layer with a residual connection; ``args`` go to each PreNorm."""
        for layer in self.layers:
            # Each layer holds [attention, feed-forward], applied in that order.
            for sublayer in layer:
                x = sublayer(x, *args) + x
        return x
class TransformerCrossAttn(nn.Module):
    """Stack of layers, each made of self-attention, cross-attention and feed-forward."""

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
    ):
        super().__init__()

        def pre_normed(fn):
            return PreNorm(dim, fn, norm=norm, norm_cond_dim=norm_cond_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # Build the three sub-modules first, then wrap them, in this order.
            self_attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            cross_attn = CrossAttention(
                dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
            )
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            triple = nn.ModuleList(
                [pre_normed(self_attn), pre_normed(cross_attn), pre_normed(feed_forward)]
            )
            self.layers.append(triple)

    def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
        """Run all layers with residual connections.

        Args:
            x: Input tokens.
            *args: Extra positional inputs forwarded to every PreNorm.
            context: Context shared by all layers; used when ``context_list`` is None.
            context_list: One context per layer.

        Raises:
            ValueError: if ``context_list`` does not have one entry per layer.
        """
        if context_list is None:
            context_list = [context] * len(self.layers)
        if len(context_list) != len(self.layers):
            raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")

        for idx, layer in enumerate(self.layers):
            self_attn, cross_attn, ff = layer
            x = self_attn(x, *args) + x
            x = cross_attn(x, *args, context=context_list[idx]) + x
            x = ff(x, *args) + x
        return x
class DropTokenDropout(nn.Module):
    """During training, remove randomly chosen token positions from the sequence.

    One Bernoulli(p) mask is drawn per call over the sequence axis and applied
    to every batch element, so all of them lose the same positions.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p > 1 or p < 0:
            raise ValueError(
                "dropout probability has to be between 0 and 1, but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, dim)
        if not (self.training and self.p > 0):
            return x

        # The mask is sampled from the first batch element's sequence axis only.
        drop = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
        # TODO: permutation idx for each batch using torch.argsort
        return x[:, ~drop, :] if drop.any() else x
class ZeroTokenDropout(nn.Module):
    """Dropout that zeroes whole tokens instead of individual features.

    In training mode with ``p > 0`` one Bernoulli(``p``) draw is made per
    (batch, position) entry and every feature of a selected token is set to
    zero. In eval mode, or when ``p`` is 0, the input is returned unchanged.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        """Return ``x`` with randomly selected tokens zeroed.

        Args:
            x: Tensor of shape (batch_size, seq_len, dim).

        Returns:
            A new tensor when dropout is active; ``x`` itself otherwise.
        """
        if self.training and self.p > 0:
            zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
            # Out-of-place on purpose: an indexed assignment would overwrite the
            # caller's tensor and fails on leaf tensors that require grad.
            x = x.masked_fill(zero_mask.unsqueeze(-1), 0)
        return x
class TransformerEncoder(nn.Module):
    """Token embedding + learned positional embedding + ``Transformer`` stack.

    Tokens are projected to ``dim`` (optionally through a frequency embedding
    when ``token_pe_numfreq > 0``), token-level dropout is applied at the place
    selected by ``emb_dropout_loc`` ("input", "token" or "token_afterpos"),
    and the result is fed to the transformer.
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = "drop",
        emb_dropout_loc: str = "token",
        norm: str = "layer",
        norm_cond_dim: int = -1,
        token_pe_numfreq: int = -1,
    ):
        super().__init__()
        if token_pe_numfreq > 0:
            # sin + cos per frequency plus the raw value, for every input feature.
            embedded_dim = token_dim * (2 * token_pe_numfreq + 1)
            embedding_steps = [
                Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
                FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
                Rearrange("(b n) d -> b n d", n=num_tokens, d=embedded_dim),
                nn.Linear(embedded_dim, dim),
            ]
            self.to_token_embedding = nn.Sequential(*embedding_steps)
        else:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))

        if emb_dropout_type == "drop":
            dropout_layer = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            dropout_layer = ZeroTokenDropout(emb_dropout)
        else:
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
        self.dropout = dropout_layer
        self.emb_dropout_loc = emb_dropout_loc

        self.transformer = Transformer(
            dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
        )

    def forward(self, inp: torch.Tensor, *args, **kwargs):
        """Encode ``inp``; extra positional ``args`` go to the transformer.

        Keyword arguments are accepted but not used.
        """
        where = self.emb_dropout_loc
        tokens = self.dropout(inp) if where == "input" else inp
        tokens = self.to_token_embedding(tokens)
        if where == "token":
            tokens = self.dropout(tokens)
        # Only the first ``seq_len`` positional embeddings are added.
        _, seq_len, _ = tokens.shape
        tokens += self.pos_embedding[:, :seq_len]
        if where == "token_afterpos":
            tokens = self.dropout(tokens)
        return self.transformer(tokens, *args)
class TransformerDecoder(nn.Module):
    """Token embedding + learned positional embedding + cross-attention stack.

    Args:
        num_tokens: Number of learned positional embeddings.
        token_dim: Feature size of the incoming tokens.
        dim: Model width.
        depth, heads, mlp_dim, dim_head, dropout, norm, norm_cond_dim,
            context_dim: Forwarded to ``TransformerCrossAttn``.
        emb_dropout: Probability used by the embedding dropout layer.
        emb_dropout_type: "drop", "zero" or "normal".
        skip_token_embedding: Use the input tokens as-is (requires
            ``token_dim == dim``) instead of a linear projection.

    Raises:
        ValueError: If ``skip_token_embedding`` is set with
            ``token_dim != dim`` or ``emb_dropout_type`` is unknown.
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = 'drop',
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
        skip_token_embedding: bool = False,
    ):
        super().__init__()
        if not skip_token_embedding:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        else:
            self.to_token_embedding = nn.Identity()
            if token_dim != dim:
                raise ValueError(
                    f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
                )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
        if emb_dropout_type == "drop":
            self.dropout = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            self.dropout = ZeroTokenDropout(emb_dropout)
        elif emb_dropout_type == "normal":
            self.dropout = nn.Dropout(emb_dropout)
        else:
            # Fail at construction time instead of leaving ``self.dropout``
            # undefined and crashing later inside ``forward``.
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
        self.transformer = TransformerCrossAttn(
            dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            norm=norm,
            norm_cond_dim=norm_cond_dim,
            context_dim=context_dim,
        )

    def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
        """Decode ``inp`` against ``context`` (or per-layer ``context_list``).

        Extra positional ``args`` are forwarded to the transformer.
        """
        x = self.to_token_embedding(inp)
        x = self.dropout(x)
        # Read the length after dropout: token-dropping dropout shortens the
        # sequence, and the positional slice has to match what is left.
        _, n, _ = x.shape
        # Out-of-place add: with an identity embedding ``x`` can still be the
        # caller's tensor, which must not be modified.
        x = x + self.pos_embedding[:, :n]
        x = self.transformer(x, *args, context=context, context_list=context_list)
        return x
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/components/t_cond_mlp.py
================================================
import copy
from typing import List, Optional
import torch
class AdaptiveLayerNorm1D(torch.nn.Module):
    """LayerNorm whose output is scaled and shifted by a conditioning vector.

    A linear layer maps the condition ``t`` to ``(alpha, beta)`` and the
    result is ``norm(x) * (1 + alpha) + beta``. The linear layer is
    zero-initialised, so the module starts out as a plain LayerNorm.
    """

    def __init__(self, data_dim: int, norm_cond_dim: int):
        super().__init__()
        for label, value in (("data_dim", data_dim), ("norm_cond_dim", norm_cond_dim)):
            if value <= 0:
                raise ValueError(f"{label} must be positive, but got {value}")

        # TODO: Check if elementwise_affine=True is correct
        self.norm = torch.nn.LayerNorm(data_dim)
        self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
        for param in (self.linear.weight, self.linear.bias):
            torch.nn.init.zeros_(param)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` and modulate it with ``t``.

        Args:
            x: (batch, ..., data_dim)
            t: (batch, norm_cond_dim)

        Returns:
            Tensor with the same shape as ``x``.
        """
        normed = self.norm(x)
        alpha, beta = self.linear(t).chunk(2, dim=-1)
        extra_dims = x.dim() - 2
        if extra_dims > 0:
            # Insert singleton axes so (batch, data_dim) broadcasts over x.
            alpha = alpha.view((alpha.shape[0],) + (1,) * extra_dims + (alpha.shape[1],))
            beta = beta.view((beta.shape[0],) + (1,) * extra_dims + (beta.shape[1],))
        return normed * (1 + alpha) + beta
class SequentialCond(torch.nn.Sequential):
    """``Sequential`` that passes conditioning arguments to modules that take them.

    Extra positional/keyword arguments are forwarded only to
    ``AdaptiveLayerNorm1D``, nested ``SequentialCond`` and ``ResidualMLPBlock``
    children; every other child is called with the running value alone.
    """

    def forward(self, input, *args, **kwargs):
        out = input
        for layer in self:
            takes_condition = isinstance(
                layer, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)
            )
            out = layer(out, *args, **kwargs) if takes_condition else layer(out)
        return out
def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
    """Build the normalisation module selected by ``norm``.

    Args:
        norm: "batch", "layer", "ada" or ``None`` (identity).
        dim: Feature size given to the normalisation module.
        norm_cond_dim: Conditioning size; must be positive for "ada".

    Raises:
        ValueError: If ``norm`` is none of the supported values.
    """
    if norm is None:
        return torch.nn.Identity()
    if norm == "batch":
        return torch.nn.BatchNorm1d(dim)
    if norm == "layer":
        return torch.nn.LayerNorm(dim)
    if norm == "ada":
        assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
        return AdaptiveLayerNorm1D(dim, norm_cond_dim)
    raise ValueError(f"Unknown norm: {norm}")
def linear_norm_activ_dropout(
    input_dim: int,
    output_dim: int,
    activation: torch.nn.Module = torch.nn.ReLU(),
    bias: bool = True,
    norm: Optional[str] = "layer",  # Options: ada/batch/layer
    dropout: float = 0.0,
    norm_cond_dim: int = -1,
) -> SequentialCond:
    """Build one Linear -> (norm) -> activation -> (dropout) stage.

    The normalisation is skipped when ``norm`` is ``None`` and the dropout
    layer is only added when ``dropout > 0``. ``activation`` is deep-copied,
    so the instance passed in is never shared between stages.
    """
    stage = [torch.nn.Linear(input_dim, output_dim, bias=bias)]
    if norm is not None:
        stage.append(normalization_layer(norm, output_dim, norm_cond_dim))
    stage.append(copy.deepcopy(activation))
    if dropout > 0.0:
        stage.append(torch.nn.Dropout(dropout))
    return SequentialCond(*stage)
def create_simple_mlp(
    input_dim: int,
    hidden_dims: List[int],
    output_dim: int,
    activation: torch.nn.Module = torch.nn.ReLU(),
    bias: bool = True,
    norm: Optional[str] = "layer",  # Options: ada/batch/layer
    dropout: float = 0.0,
    norm_cond_dim: int = -1,
) -> SequentialCond:
    """Build an MLP: one Linear/norm/activation/dropout stage per hidden size,
    followed by a final plain Linear to ``output_dim``.

    The modules of every stage are flattened into a single ``SequentialCond``.
    """
    widths = [input_dim, *hidden_dims]
    modules = []
    for in_width, out_width in zip(widths[:-1], widths[1:]):
        stage = linear_norm_activ_dropout(
            in_width, out_width, activation, bias, norm, dropout, norm_cond_dim
        )
        # Iterating a stage yields its child modules, which flattens it.
        modules.extend(stage)
    modules.append(torch.nn.Linear(widths[-1], output_dim, bias=bias))
    return SequentialCond(*modules)
class ResidualMLPBlock(torch.nn.Module):
    """Residual block: ``x + model(x)`` with ``num_hidden_layers`` stages.

    Every stage is a Linear/norm/activation/dropout stack built by
    ``linear_norm_activ_dropout``. Because the skip connection is an identity,
    all three sizes have to be equal.

    Raises:
        NotImplementedError: If ``input_dim``, ``hidden_dim`` and
            ``output_dim`` are not all the same.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_hidden_layers: int,
        output_dim: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        bias: bool = True,
        norm: Optional[str] = "layer",  # Options: ada/batch/layer
        dropout: float = 0.0,
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        if not (input_dim == output_dim == hidden_dim):
            # Report all three sizes: the check also covers hidden_dim.
            raise NotImplementedError(
                "ResidualMLPBlock requires input_dim == hidden_dim == output_dim, "
                f"but got input_dim={input_dim}, hidden_dim={hidden_dim}, output_dim={output_dim}"
            )
        layers = []
        prev_dim = input_dim
        for _ in range(num_hidden_layers):
            layers.append(
                linear_norm_activ_dropout(
                    prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
                )
            )
            prev_dim = hidden_dim
        self.model = SequentialCond(*layers)
        # Not used by forward(); kept so the attribute stays available.
        self.skip = torch.nn.Identity()

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Return ``x + model(x, *args, **kwargs)``."""
        return x + self.model(x, *args, **kwargs)
class ResidualMLP(torch.nn.Module):
    """Input stage -> ``num_blocks`` residual blocks -> output Linear.

    The input stage maps ``input_dim`` to ``hidden_dim``, every
    ``ResidualMLPBlock`` works at ``hidden_dim``, and a final Linear maps to
    ``output_dim``. Extra arguments given to ``forward`` are passed on to the
    underlying ``SequentialCond``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_hidden_layers: int,
        output_dim: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        bias: bool = True,
        norm: Optional[str] = "layer",  # Options: ada/batch/layer
        dropout: float = 0.0,
        num_blocks: int = 1,
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        self.input_dim = input_dim
        stem = linear_norm_activ_dropout(
            input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
        )
        blocks = []
        for _ in range(num_blocks):
            blocks.append(
                ResidualMLPBlock(
                    hidden_dim,
                    hidden_dim,
                    num_hidden_layers,
                    hidden_dim,
                    activation,
                    bias,
                    norm,
                    dropout,
                    norm_cond_dim,
                )
            )
        head = torch.nn.Linear(hidden_dim, output_dim, bias=bias)
        self.model = SequentialCond(stem, *blocks, head)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        out = self.model(x, *args, **kwargs)
        return out
class FrequencyEmbedder(torch.nn.Module):
    """Sin/cos frequency embedding of scalar features.

    The frequencies are ``2 ** linspace(0, max_freq_log2, num_frequencies)``.
    For every input feature the output holds the sines, then the cosines,
    then the raw value, i.e. ``2 * num_frequencies + 1`` numbers per feature.
    """

    def __init__(self, num_frequencies, max_freq_log2):
        super().__init__()
        exponents = torch.linspace(0, max_freq_log2, steps=num_frequencies)
        self.register_buffer("frequencies", 2 ** exponents)

    def forward(self, x):
        """Embed ``x`` of shape (N,) or (N, D) into (N, D * (2 * num_frequencies + 1))."""
        num_rows = x.size(0)
        if x.dim() == 1:
            # Treat a flat vector as N samples with a single feature.
            x = x.unsqueeze(1)
        raw = x.unsqueeze(-1)  # (N, D, 1)
        phases = self.frequencies.view(1, 1, -1) * raw  # (N, D, num_frequencies)
        pieces = [torch.sin(phases), torch.cos(phases), raw]
        return torch.cat(pieces, dim=-1).view(num_rows, -1)
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/configs/__init__.py
================================================
import os
from typing import Dict
from yacs.config import CfgNode as CN
from pathlib import Path
# CACHE_DIR = os.path.join(os.environ.get("HOME"), "Code/4D-Humans/cache")
# CACHE_DIR_4DHUMANS = os.path.join(CACHE_DIR, "4DHumans")
def to_lower(x: Dict) -> Dict:
    """
    Return a copy of ``x`` whose keys are lowercased.
    Args:
        x (dict): Input dictionary with string keys.
    Returns:
        dict: New dictionary; values are kept as they are. If two keys only
        differ in case, the value of the later one wins.
    """
    lowered = {}
    for key, value in x.items():
        lowered[key.lower()] = value
    return lowered
# Module-level default configuration tree. default_config() hands out clones
# of it, so these values themselves are never modified by callers.
# Every node is created with new_allowed=True (yacs: keys that are not
# declared here may still be added when another config is merged in).
_C = CN(new_allowed=True)
# General run-time / infrastructure settings.
_C.GENERAL = CN(new_allowed=True)
_C.GENERAL.RESUME = True
_C.GENERAL.TIME_TO_RUN = 3300
_C.GENERAL.VAL_STEPS = 100
_C.GENERAL.LOG_STEPS = 100
_C.GENERAL.CHECKPOINT_STEPS = 20000
_C.GENERAL.CHECKPOINT_DIR = "checkpoints"
_C.GENERAL.SUMMARY_DIR = "tensorboard"
_C.GENERAL.NUM_GPUS = 1
_C.GENERAL.NUM_WORKERS = 4
_C.GENERAL.MIXED_PRECISION = True
_C.GENERAL.ALLOW_CUDA = True
_C.GENERAL.PIN_MEMORY = False
_C.GENERAL.DISTRIBUTED = False
_C.GENERAL.LOCAL_RANK = 0
_C.GENERAL.USE_SYNCBN = False
_C.GENERAL.WORLD_SIZE = 1
# Training-loop settings.
_C.TRAIN = CN(new_allowed=True)
_C.TRAIN.NUM_EPOCHS = 100
_C.TRAIN.BATCH_SIZE = 32
_C.TRAIN.SHUFFLE = True
_C.TRAIN.WARMUP = False
_C.TRAIN.NORMALIZE_PER_IMAGE = False
_C.TRAIN.CLIP_GRAD = False
_C.TRAIN.CLIP_GRAD_VALUE = 1.0
# Empty containers; their entries come from the merged config files.
_C.LOSS_WEIGHTS = CN(new_allowed=True)
_C.DATASETS = CN(new_allowed=True)
_C.MODEL = CN(new_allowed=True)
_C.MODEL.IMAGE_SIZE = 224
_C.EXTRA = CN(new_allowed=True)
_C.EXTRA.FOCAL_LENGTH = 5000
# Dataset settings; judging by the key names these are data-augmentation
# parameters -- their consumers are not part of this file.
_C.DATASETS.CONFIG = CN(new_allowed=True)
_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
_C.DATASETS.CONFIG.ROT_FACTOR = 30
_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
_C.DATASETS.CONFIG.COLOR_SCALE = 0.2
_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
_C.DATASETS.CONFIG.DO_FLIP = True
_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
def default_config() -> CN:
    """
    Return a fresh copy of the module-level default configuration.

    A clone is returned so that changes made by the caller never alter the
    shared defaults (the "local variable" use pattern).
    """
    defaults = _C.clone()
    return defaults
def dataset_config(name="datasets_tar.yaml") -> CN:
    """
    Load a dataset config file that lives next to this module.
    Args:
        name (str): File name of the YAML config inside this module's directory.
    Returns:
        CfgNode: Frozen dataset config as a yacs CfgNode object.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(module_dir, name)
    dataset_cfg = CN(new_allowed=True)
    dataset_cfg.merge_from_file(config_path)
    dataset_cfg.freeze()
    return dataset_cfg
def dataset_eval_config() -> CN:
    """Load ``datasets_eval.yaml`` through ``dataset_config``."""
    eval_cfg = dataset_config("datasets_eval.yaml")
    return eval_cfg
def get_config(config_file: str, merge: bool = True) -> CN:
    """
    Read a config file, optionally on top of the default config.
    Args:
        config_file (str): Path to config file.
        merge (bool): Start from the default config when True, from an empty
            config otherwise.
    Returns:
        CfgNode: Frozen config as a yacs CfgNode object. ``SMPL.MEAN_PARAMS``
        is replaced by the path of ``smpl_mean_params.npz`` next to this module.
    """
    cfg = default_config() if merge else CN(new_allowed=True)
    cfg.merge_from_file(config_file)

    # ---- Update ---- #
    smpl_cfg = cfg.SMPL
    # Self-assignments kept from the original code: they leave the values
    # untouched but still require both keys to be present in the config.
    smpl_cfg.MODEL_PATH = smpl_cfg.MODEL_PATH  # Not used
    smpl_cfg.JOINT_REGRESSOR_EXTRA = smpl_cfg.JOINT_REGRESSOR_EXTRA  # Not Used
    module_dir = Path(__file__).parent
    smpl_cfg.MEAN_PARAMS = str(module_dir / "smpl_mean_params.npz")
    # ---------------- #

    cfg.freeze()
    return cfg
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/configs/model_config.yaml
================================================
task_name: train
tags:
- dev
train: true
test: false
ckpt_path: null
seed: null
DATASETS:
TRAIN:
H36M-TRAIN:
WEIGHT: 0.3
MPII-TRAIN:
WEIGHT: 0.1
COCO-TRAIN-2014:
WEIGHT: 0.4
MPI-INF-TRAIN:
WEIGHT: 0.2
VAL:
COCO-VAL:
WEIGHT: 1.0
MOCAP: CMU-MOCAP
CONFIG:
SCALE_FACTOR: 0.3
ROT_FACTOR: 30
TRANS_FACTOR: 0.02
COLOR_SCALE: 0.2
ROT_AUG_RATE: 0.6
TRANS_AUG_RATE: 0.5
DO_FLIP: true
FLIP_AUG_RATE: 0.5
EXTREME_CROP_AUG_RATE: 0.1
trainer:
_target_: pytorch_lightning.Trainer
default_root_dir: ${paths.output_dir}
accelerator: gpu
devices: 8
deterministic: false
num_sanity_val_steps: 0
log_every_n_steps: ${GENERAL.LOG_STEPS}
val_check_interval: ${GENERAL.VAL_STEPS}
precision: 16
max_steps: ${GENERAL.TOTAL_STEPS}
move_metrics_to_cpu: true
limit_val_batches: 1
track_grad_norm: 2
strategy: ddp
num_nodes: 1
sync_batchnorm: true
paths:
root_dir: ${oc.env:PROJECT_ROOT}
data_dir: ${paths.root_dir}/data/
log_dir: /fsx/shubham/code/hmr2023/logs_hydra/
output_dir: ${hydra:runtime.output_dir}
work_dir: ${hydra:runtime.cwd}
extras:
ignore_warnings: false
enforce_tags: true
print_config: true
exp_name: 3001d
SMPL:
MODEL_PATH: data/smpl
GENDER: neutral
NUM_BODY_JOINTS: 23
JOINT_REGRESSOR_EXTRA: data/SMPL_to_J19.pkl
MEAN_PARAMS: data/smpl_mean_params.npz
EXTRA:
FOCAL_LENGTH: 5000
NUM_LOG_IMAGES: 4
NUM_LOG_SAMPLES_PER_IMAGE: 8
PELVIS_IND: 39
MODEL:
IMAGE_SIZE: 256
IMAGE_MEAN:
- 0.485
- 0.456
- 0.406
IMAGE_STD:
- 0.229
- 0.224
- 0.225
BACKBONE:
TYPE: vit
FREEZE: true
NUM_LAYERS: 50
OUT_CHANNELS: 2048
ADD_NECK: false
FLOW:
DIM: 144
NUM_LAYERS: 4
CONTEXT_FEATURES: 2048
LAYER_HIDDEN_FEATURES: 1024
LAYER_DEPTH: 2
FC_HEAD:
NUM_FEATURES: 1024
SMPL_HEAD:
TYPE: transformer_decoder
IN_CHANNELS: 2048
TRANSFORMER_DECODER:
depth: 6
heads: 8
mlp_dim: 1024
dim_head: 64
dropout: 0.0
emb_dropout: 0.0
norm: layer
context_dim: 1280
GENERAL:
TOTAL_STEPS: 100000
LOG_STEPS: 100
VAL_STEPS: 100
CHECKPOINT_STEPS: 1000
CHECKPOINT_SAVE_TOP_K: -1
NUM_WORKERS: 6
PREFETCH_FACTOR: 2
TRAIN:
LR: 0.0001
WEIGHT_DECAY: 0.0001
BATCH_SIZE: 512
LOSS_REDUCTION: mean
NUM_TRAIN_SAMPLES: 2
NUM_TEST_SAMPLES: 64
POSE_2D_NOISE_RATIO: 0.01
SMPL_PARAM_NOISE_RATIO: 0.005
LOSS_WEIGHTS:
KEYPOINTS_3D: 0.05
KEYPOINTS_2D: 0.01
GLOBAL_ORIENT: 0.001
BODY_POSE: 0.001
BETAS: 0.0005
ADVERSARIAL: 0.0005
local: {}
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/hmr2.py
================================================
import torch
import pytorch_lightning as pl
from yacs.config import CfgNode
from .vit import ViT
from .smpl_head import SMPLTransformerDecoderHead
from pytorch3d.transforms import matrix_to_axis_angle
from hmr4d.utils.geo.hmr_cam import compute_transl_full_cam
class HMR2(pl.LightningModule):
    """ViT backbone followed by the SMPL transformer-decoder head."""

    def __init__(self, cfg: CfgNode):
        super().__init__()
        self.cfg = cfg
        backbone_kwargs = dict(
            img_size=(256, 192),
            patch_size=16,
            embed_dim=1280,
            depth=32,
            num_heads=16,
            ratio=1,
            use_checkpoint=False,
            mlp_ratio=4,
            qkv_bias=True,
            drop_path_rate=0.55,
        )
        self.backbone = ViT(**backbone_kwargs)
        self.smpl_head = SMPLTransformerDecoderHead(cfg)

    def forward(self, batch, feat_mode=True):
        """Run the backbone and the SMPL head on ``batch["img"]``.

        Args:
            batch: Dict with key "img"; the full mode additionally reads
                "bbx_xys" and "K_fullimg".
            feat_mode: default True, as only the feature token output is needed
                for the HMR4D project; when False, the full process of HMR2 is
                executed.

        Returns:
            The head's token output when ``feat_mode`` is True, otherwise a dict
            with "token_out" and "smpl_params" (body_pose, betas,
            global_orient, transl).
        """
        # Backbone: 32 columns are cut from both ends of the last image axis.
        images = batch["img"][:, :, :, 32:-32]
        vit_feats = self.backbone(images)

        # Output head
        if feat_mode:
            return self.smpl_head(vit_feats, only_return_token_out=True)

        # Full process
        pred_smpl_params, pred_cam, _, token_out = self.smpl_head(vit_feats, only_return_token_out=False)
        # The last two axes of the per-joint axis-angle are flattened together.
        body_pose = matrix_to_axis_angle(pred_smpl_params["body_pose"]).flatten(-2)
        global_orient = matrix_to_axis_angle(pred_smpl_params["global_orient"])[:, 0]  # (B, 3)
        transl = compute_transl_full_cam(pred_cam, batch["bbx_xys"], batch["K_fullimg"])  # (B, 3)
        smpl_params = {
            "body_pose": body_pose,
            "betas": pred_smpl_params["betas"],  # (B, 10)
            "global_orient": global_orient,
            "transl": transl,
        }
        return {"token_out": token_out, "smpl_params": smpl_params}
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/smpl_head.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
from .utils.geometry import rot6d_to_rotmat, aa_to_rotmat
from .components.pose_transformer import TransformerDecoder
class SMPLTransformerDecoderHead(nn.Module):
    """Cross-attention based SMPL Transformer decoder

    A single query token attends to the backbone features through a
    ``TransformerDecoder``; three linear read-outs then predict residuals for
    pose, shape (10 values) and camera (3 values) that are added to the mean
    parameters loaded from ``cfg.SMPL.MEAN_PARAMS``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.joint_rep_type = cfg.MODEL.SMPL_HEAD.get("JOINT_REP", "6d")
        self.joint_rep_dim = {"6d": 6, "aa": 3}[self.joint_rep_type]
        # Global orientation plus all body joints.
        npose = self.joint_rep_dim * (cfg.SMPL.NUM_BODY_JOINTS + 1)
        self.npose = npose
        self.input_is_mean_shape = cfg.MODEL.SMPL_HEAD.get("TRANSFORMER_INPUT", "zero") == "mean_shape"
        transformer_args = dict(
            num_tokens=1,
            token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1,
            dim=1024,
        )
        # Values from the config override the defaults above.
        transformer_args.update(**dict(cfg.MODEL.SMPL_HEAD.TRANSFORMER_DECODER))
        self.transformer = TransformerDecoder(**transformer_args)
        dim = transformer_args["dim"]
        self.decpose = nn.Linear(dim, npose)
        self.decshape = nn.Linear(dim, 10)
        self.deccam = nn.Linear(dim, 3)

        if cfg.MODEL.SMPL_HEAD.get("INIT_DECODER_XAVIER", False):
            # True by default in MLP. False by default in Transformer
            nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
            nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
            nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        # The archive returned by np.load keeps the file open; read the arrays
        # inside a context manager so the handle is released right away.
        with np.load(cfg.SMPL.MEAN_PARAMS) as mean_params:
            init_body_pose = torch.from_numpy(mean_params["pose"].astype(np.float32)).unsqueeze(0)
            init_betas = torch.from_numpy(mean_params["shape"].astype("float32")).unsqueeze(0)
            init_cam = torch.from_numpy(mean_params["cam"].astype(np.float32)).unsqueeze(0)
        self.register_buffer("init_body_pose", init_body_pose)
        self.register_buffer("init_betas", init_betas)
        self.register_buffer("init_cam", init_cam)

    def forward(self, x, only_return_token_out=False):
        """Predict SMPL parameters from backbone features.

        Args:
            x: Backbone features, channel-first (b, c, h, w).
            only_return_token_out: Return the decoder token right after the
                transformer and skip the parameter read-out.

        Returns:
            ``token_out`` alone when ``only_return_token_out`` is True,
            otherwise ``(pred_smpl_params, pred_cam, pred_smpl_params_list,
            token_out)``.

        Raises:
            NotImplementedError: If the joint representation is "aa".
        """
        batch_size = x.shape[0]
        # vit pretrained backbone is channel-first. Change to token-first
        x = einops.rearrange(x, "b c h w -> b (h w) c")

        init_body_pose = self.init_body_pose.expand(batch_size, -1)
        init_betas = self.init_betas.expand(batch_size, -1)
        init_cam = self.init_cam.expand(batch_size, -1)

        # TODO: Convert init_body_pose to aa rep if needed
        if self.joint_rep_type == "aa":
            raise NotImplementedError

        pred_body_pose = init_body_pose
        pred_betas = init_betas
        pred_cam = init_cam
        pred_body_pose_list = []
        pred_betas_list = []
        pred_cam_list = []
        for i in range(self.cfg.MODEL.SMPL_HEAD.get("IEF_ITERS", 1)):
            assert i == 0, "Only support 1 iteration for now"
            # Input token to transformer is zero token
            if self.input_is_mean_shape:
                token = torch.cat([pred_body_pose, pred_betas, pred_cam], dim=1)[:, None, :]
            else:
                # Allocate on the feature device directly instead of on the
                # CPU followed by a copy.
                token = torch.zeros(batch_size, 1, 1, device=x.device)

            # Pass through transformer
            token_out = self.transformer(token, context=x)
            token_out = token_out.squeeze(1)  # (B, C)

            if only_return_token_out:
                return token_out
            else:
                # Readout from token_out: residuals on top of the current estimate
                pred_body_pose = self.decpose(token_out) + pred_body_pose
                pred_betas = self.decshape(token_out) + pred_betas
                pred_cam = self.deccam(token_out) + pred_cam
                pred_body_pose_list.append(pred_body_pose)
                pred_betas_list.append(pred_betas)
                pred_cam_list.append(pred_cam)

        # Convert self.joint_rep_type -> rotmat
        joint_conversion_fn = {"6d": rot6d_to_rotmat, "aa": lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())}[
            self.joint_rep_type
        ]

        pred_smpl_params_list = {}
        # Index 0 along the joint axis is the global orientation; it is left out here.
        pred_smpl_params_list["body_pose"] = torch.cat(
            [joint_conversion_fn(pbp).view(batch_size, -1, 3, 3)[:, 1:, :, :] for pbp in pred_body_pose_list], dim=0
        )
        pred_smpl_params_list["betas"] = torch.cat(pred_betas_list, dim=0)
        pred_smpl_params_list["cam"] = torch.cat(pred_cam_list, dim=0)
        pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, self.cfg.SMPL.NUM_BODY_JOINTS + 1, 3, 3)

        pred_smpl_params = {
            "global_orient": pred_body_pose[:, [0]],
            "body_pose": pred_body_pose[:, 1:],
            "betas": pred_betas,
        }
        return pred_smpl_params, pred_cam, pred_smpl_params_list, token_out
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/utils/geometry.py
================================================
from typing import Optional
import torch
from torch.nn import functional as F
def aa_to_rotmat(theta: torch.Tensor):
    """
    Convert axis-angle vectors to rotation matrices via unit quaternions.
    Args:
        theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    """
    # The small offset keeps the norm non-zero for an all-zero rotation vector.
    angle = torch.norm(theta + 1e-8, p=2, dim=1).unsqueeze(-1)
    axis = theta / angle
    half_angle = angle * 0.5
    # Quaternion layout is (w, x, y, z).
    quat = torch.cat([torch.cos(half_angle), torch.sin(half_angle) * axis], dim=1)
    return quat_to_rotmat(quat)
def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions to rotation matrices.
    The input is normalised first, so it does not have to be unit length.
    Args:
        quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    """
    unit = quat / quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = unit[:, 0], unit[:, 1], unit[:, 2], unit[:, 3]

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    # Build the matrix row by row.
    row0 = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz], dim=1)
    row1 = torch.stack([2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx], dim=1)
    row2 = torch.stack([2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1)
    return torch.stack([row0, row1, row2], dim=1)
def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
    """
    Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Args:
        x (torch.Tensor): (B,6) Batch of 6-D rotation representations. Any
            shape whose elements regroup into rows of 6 is accepted.
    Returns:
        torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
    """
    x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous()
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    # Gram-Schmidt: normalise a1, then remove its component from a2.
    b1 = F.normalize(a1)
    b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1)
    # dim must be explicit: without it torch.cross picks the first dimension
    # of size 3, which is the batch axis whenever exactly 3 rotations are
    # converted.
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-1)
def perspective_projection(
    points: torch.Tensor,
    translation: torch.Tensor,
    focal_length: torch.Tensor,
    camera_center: Optional[torch.Tensor] = None,
    rotation: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Project 3D points to the image plane with a pinhole camera.
    Args:
        points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
        translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
        focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
        camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
            Defaults to zeros.
        rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
            Defaults to the identity.
    Returns:
        torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
    """
    batch_size = points.shape[0]
    like_points = dict(device=points.device, dtype=points.dtype)
    if rotation is None:
        rotation = torch.eye(3, **like_points).unsqueeze(0).expand(batch_size, -1, -1)
    if camera_center is None:
        camera_center = torch.zeros(batch_size, 2, **like_points)

    # Intrinsic camera matrix K.
    intrinsics = torch.zeros([batch_size, 3, 3], **like_points)
    intrinsics[:, 0, 0] = focal_length[:, 0]
    intrinsics[:, 1, 1] = focal_length[:, 1]
    intrinsics[:, 2, 2] = 1.0
    intrinsics[:, :-1, -1] = camera_center

    # Rigid transform into the camera frame.
    cam_points = torch.einsum("bij,bkj->bki", rotation, points) + translation.unsqueeze(1)
    # Perspective divide by depth.
    normalized = cam_points / cam_points[..., -1:]
    # Apply the intrinsics and drop the homogeneous coordinate.
    pixels = torch.einsum("bij,bkj->bki", intrinsics, normalized)
    return pixels[:, :, :-1]
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/utils/preproc.py
================================================
import cv2
import numpy as np
import torch
from pathlib import Path
# Per-channel mean and standard deviation (three values each). The numbers
# match the commonly used ImageNet RGB statistics; they are not referenced by
# the functions visible in this file -- presumably used by callers for image
# normalization, TODO confirm.
IMAGE_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGE_STD = torch.tensor([0.229, 0.224, 0.225])
def expand_to_aspect_ratio(input_shape, target_aspect_ratio=(192, 256)):
    """Increase the size of the bounding box to match the target shape.

    Args:
        input_shape: ``(w, h)`` of the box. Anything that cannot be unpacked
            into two values is returned unchanged.
        target_aspect_ratio: ``(w_t, h_t)`` describing the wanted aspect
            ratio, or ``None`` to return ``input_shape`` unchanged.

    Returns:
        ``np.array([w_new, h_new])`` where one side has been enlarged (never
        shrunk) so that ``h_new / w_new`` equals ``h_t / w_t``.
    """
    if target_aspect_ratio is None:
        return input_shape

    try:
        w, h = input_shape
    except (ValueError, TypeError):
        return input_shape

    w_t, h_t = target_aspect_ratio
    if h / w < h_t / w_t:
        # Box is too wide for the target ratio: grow the height.
        h_new = max(w * h_t / w_t, h)
        w_new = w
    else:
        # Box is too tall (or already matching): grow the width.
        h_new = h
        w_new = max(h * w_t / h_t, w)
    # Both sides come out of max(..., original), so they cannot shrink; the
    # debugger trap that used to guard this case has been removed.
    return np.array([w_new, h_new])
def crop_and_resize(img, bbx_xy, bbx_s, dst_size=256, enlarge_ratio=1.2):
    """
    Crop an enlarged square box out of ``img`` and resize it to ``dst_size``.
    Args:
        img: (H, W, 3)
        bbx_xy: (2,) box center
        bbx_s: scalar box size
        dst_size: side length of the square output crop
        enlarge_ratio: factor applied to ``bbx_s`` before cropping
    Returns:
        img_crop: the warped (dst_size, dst_size) crop
        bbx_xys_final: array ``[x, y, bbx_s * enlarge_ratio]``
    """
    enlarged_size = bbx_s * enlarge_ratio
    half = enlarged_size / 2
    # Three source points define the affine map: two upper corners + center.
    left_up = bbx_xy - half
    right_up = bbx_xy + np.array([half, -half])
    src = np.stack([left_up, right_up, bbx_xy]).astype(np.float32)
    mid = dst_size / 2 - 0.5
    dst = np.array([[0, 0], [dst_size - 1, 0], [mid, mid]], dtype=np.float32)

    affine = cv2.getAffineTransform(src, dst)
    img_crop = cv2.warpAffine(img, affine, (dst_size, dst_size), flags=cv2.INTER_LINEAR)
    bbx_xys_final = np.array([*bbx_xy, enlarged_size])
    return img_crop, bbx_xys_final
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/utils/smpl_wrapper.py
================================================
import torch
import numpy as np
import pickle
from typing import Optional
import smplx
from smplx.lbs import vertices2joints
from smplx.utils import SMPLOutput
class SMPL(smplx.SMPLLayer):
    def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs):
        """
        Extension of the official SMPL implementation to support more joints.
        Args:
            Same as SMPLLayer.
            joint_regressor_extra (str): Path to extra joint regressor.
            update_hips (bool): Apply the hip-joint adjustment in ``forward``.
        """
        super(SMPL, self).__init__(*args, **kwargs)
        smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]

        if joint_regressor_extra is not None:
            # NOTE(security): unpickling can execute arbitrary code; only load
            # regressor files that come from a trusted source.
            # The context manager closes the file handle right after loading.
            with open(joint_regressor_extra, "rb") as regressor_file:
                extra_regressor = pickle.load(regressor_file, encoding="latin1")
            self.register_buffer(
                "joint_regressor_extra",
                torch.tensor(extra_regressor, dtype=torch.float32),
            )
        self.register_buffer("joint_map", torch.tensor(smpl_to_openpose, dtype=torch.long))
        self.update_hips = update_hips

    def forward(self, *args, **kwargs) -> SMPLOutput:
        """
        Run forward pass. Same as SMPL and also append an extra set of joints if joint_regressor_extra is specified.
        The joints of the returned output are re-ordered with ``joint_map``.
        """
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        # Indexing with the joint map yields a new tensor, so the in-place hip
        # update below does not touch the parent's joints.
        joints = smpl_output.joints[:, self.joint_map, :]
        if self.update_hips:
            joints[:, [9, 12]] = (
                joints[:, [9, 12]]
                + 0.25 * (joints[:, [9, 12]] - joints[:, [12, 9]])
                + 0.5 * (joints[:, [8]] - 0.5 * (joints[:, [9, 12]] + joints[:, [12, 9]]))
            )
        if hasattr(self, "joint_regressor_extra"):
            extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)
        smpl_output.joints = joints
        return smpl_output
================================================
FILE: eval/GVHMR/hmr4d/network/hmr2/vit.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
def vit(cfg):
    """Build the ViT backbone with fixed hyper-parameters.

    ``cfg`` is accepted but not used.
    """
    backbone_kwargs = dict(
        img_size=(256, 192),
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        ratio=1,
        use_checkpoint=False,
        mlp_ratio=4,
        qkv_bias=True,
        drop_path_rate=0.55,
    )
    return ViT(**backbone_kwargs)
def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
    """
    Resize absolute positional embeddings to a new token grid if needed.
    The cls-token embedding (if present) is split off before resizing and
    put back in front afterwards.
    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        h (int): target grid height.
        w (int): target grid width.
        ori_h (int): grid height the embeddings were created for.
        ori_w (int): grid width the embeddings were created for.
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
    Returns:
        Positional embeddings of shape (B, h * w (+1 with cls token), C); the
        input embeddings are returned as they are when the grid is unchanged.
    """
    B, L, C = abs_pos.shape
    cls_token = None
    if has_cls_token:
        cls_token, abs_pos = abs_pos[:, 0:1], abs_pos[:, 1:]

    if ori_h != h or ori_w != w:
        # Interpolate on the 2D grid in channel-first layout.
        grid = abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(h, w), mode="bicubic", align_corners=False)
        resized = grid.permute(0, 2, 3, 1).reshape(B, -1, C)
    else:
        resized = abs_pos

    if cls_token is not None:
        resized = torch.cat([cls_token, resized], dim=1)
    return resized
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin module wrapper around the ``drop_path`` function, which receives the
    stored probability and the module's training flag.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        result = drop_path(x, self.drop_prob, self.training)
        return result

    def extra_repr(self):
        return f"p={self.drop_prob}"
class Mlp(nn.Module):
    """Two-layer MLP: Linear -> activation -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        # Dropout is only applied after the second projection.
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with a fused qkv projection.

    ``attn_head_dim`` overrides the per-head size (default ``dim // num_heads``)
    and ``qk_scale`` overrides the default ``head_dim ** -0.5`` scaling.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None,):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        inner_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, _ = x.shape
        # (3, batch, heads, tokens, head_dim)
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        query, key, value = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        scores = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the feature axis.
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, -1)
        return self.proj_drop(self.proj(merged))
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Both residual branches go through ``drop_path`` (stochastic depth), which
    is an identity when ``drop_path`` is not greater than 0.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, attn_head_dim=None
                 ):
        super().__init__()

        self.norm1 = norm_layer(dim)
        attention_kwargs = dict(
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim,
        )
        self.attn = Attention(dim, **attention_kwargs)

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def forward(self, x):
        attended = self.attn(self.norm1(x))
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x))
        x = x + self.drop_path(transformed)
        return x
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    A single convolution turns the image into a grid of patch tokens.
    ``forward`` returns the tokens together with the grid size produced by
    the convolution.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]

        self.patch_shape = (int(grid_h * ratio), int(grid_w * ratio))
        self.origin_patch_shape = (int(grid_h), int(grid_w))
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_w * grid_h * (ratio ** 2)

        conv_stride = patch_size[0] // ratio
        conv_padding = 4 + 2 * (ratio // 2 - 1)
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=conv_stride, padding=conv_padding
        )

    def forward(self, x, **kwargs):
        # Unpacking requires a 4D (B, C, H, W) input; the values are not used.
        B, C, H, W = x.shape
        feature_map = self.proj(x)
        grid_size = (feature_map.shape[2], feature_map.shape[3])
        # (B, embed_dim, Hp, Wp) -> (B, Hp * Wp, embed_dim)
        tokens = feature_map.flatten(2).transpose(1, 2)
        return tokens, grid_size
class HybridEmbed(nn.Module):
    """CNN feature-map embedding.

    Runs a CNN backbone, takes the last feature map it returns, flattens the
    spatial dimensions and linearly projects each location to ``embed_dim``.

    Args:
        backbone: ``nn.Module`` whose call returns an indexable collection of
            feature maps; the last one is used.
        img_size: input size (int or pair) used for the shape-probing pass.
        feature_size: spatial size of the last feature map; when ``None`` it is
            measured by a dummy forward pass.
        in_chans: number of input channels for the dummy forward pass.
        embed_dim: output token width.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone

        if feature_size is not None:
            feature_size = to_2tuple(feature_size)
            # NOTE(review): relies on the backbone exposing `feature_info.channels()`.
            feature_dim = self.backbone.feature_info.channels()[-1]
        else:
            # Probe the backbone with a zero image to discover the feature shape.
            with torch.no_grad():
                was_training = backbone.training
                if was_training:
                    backbone.eval()
                dummy = torch.zeros(1, in_chans, img_size[0], img_size[1])
                probe = self.backbone(dummy)[-1]
                feature_size = probe.shape[-2:]
                feature_dim = probe.shape[1]
                # Restore whichever mode the backbone was in before probing.
                backbone.train(was_training)

        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        """Return projected tokens of shape ``(B, H * W, embed_dim)``."""
        last_map = self.backbone(x)[-1]
        tokens = last_map.flatten(2).transpose(1, 2)
        return self.proj(tokens)
class ViT(nn.Module):
    """Plain Vision Transformer backbone that outputs a 2-D feature map.

    Args:
        img_size, patch_size, in_chans, embed_dim, ratio: patch-embedding
            configuration (see ``PatchEmbed``).
        num_classes: stored on the module; not used by the layers built here.
        depth: number of transformer blocks.
        num_heads, mlp_ratio, qkv_bias, qk_scale: per-block configuration.
        drop_rate, attn_drop_rate: dropout probabilities inside each block.
        drop_path_rate: maximum stochastic-depth rate; it grows linearly from 0
            at the first block to this value at the last one.
        hybrid_backbone: optional CNN used through ``HybridEmbed`` instead of
            ``PatchEmbed``.
        norm_layer: normalisation constructor; defaults to ``LayerNorm(eps=1e-6)``.
        use_checkpoint: run every block through activation checkpointing.
        frozen_stages: ``>= 0`` freezes the patch embedding and the first
            ``frozen_stages`` blocks; ``-1`` freezes nothing.
        last_norm: apply a final normalisation layer after the blocks.
        patch_padding: stored on the module; not used in this class.
        freeze_attn: freeze the attention sub-layer (and ``norm1``) of every block.
        freeze_ffn: freeze the MLP sub-layer (and ``norm2``) of every block, plus
            the positional embedding and the patch embedding.
    """

    def __init__(self,
                 img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
                 frozen_stages=-1, ratio=1, last_norm=True,
                 patch_padding='pad', freeze_attn=False, freeze_ffn=False,
                 ):
        super(ViT, self).__init__()
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.frozen_stages = frozen_stages
        self.use_checkpoint = use_checkpoint
        self.patch_padding = patch_padding
        self.freeze_attn = freeze_attn
        self.freeze_ffn = freeze_ffn
        self.depth = depth

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
        num_patches = self.patch_embed.num_patches

        # One extra slot because the pre-trained model carries a class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
            )
            for i in range(depth)])

        self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put the configured sub-modules in eval mode and disable their gradients."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        # Stage i (1-based) is block i - 1. Indexing with `i` directly skipped
        # block 0 and raised IndexError when frozen_stages == depth.
        for i in range(1, self.frozen_stages + 1):
            m = self.blocks[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

        if self.freeze_attn:
            for i in range(0, self.depth):
                m = self.blocks[i]
                m.attn.eval()
                m.norm1.eval()
                for param in m.attn.parameters():
                    param.requires_grad = False
                for param in m.norm1.parameters():
                    param.requires_grad = False

        if self.freeze_ffn:
            self.pos_embed.requires_grad = False
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            for i in range(0, self.depth):
                m = self.blocks[i]
                m.mlp.eval()
                m.norm2.eval()
                for param in m.mlp.parameters():
                    param.requires_grad = False
                for param in m.norm2.parameters():
                    param.requires_grad = False

    def init_weights(self):
        """Re-initialise the backbone weights.

        ``nn.Linear`` weights get a truncated normal (std 0.02) and zero bias;
        ``nn.LayerNorm`` gets unit weight and zero bias. No checkpoint is loaded
        here.
        """
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        """Encode a ``(B, C, H, W)`` batch into a ``(B, embed_dim, Hp, Wp)`` feature map."""
        B, C, H, W = x.shape
        x, (Hp, Wp) = self.patch_embed(x)

        if self.pos_embed is not None:
            # fit for multiple GPU training
            # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
            x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]

        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        x = self.last_norm(x)

        # Tokens back to a spatial map: (B, Hp * Wp, C) -> (B, C, Hp, Wp).
        xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
        return xp

    def forward(self, x):
        """Alias of :meth:`forward_features`."""
        x = self.forward_features(x)
        return x

    def train(self, mode=True):
        """Switch the training mode, then re-apply freezing.

        ``nn.Module.train`` resets every sub-module's mode, so frozen parts are
        put back into eval mode afterwards.
        """
        super().train(mode)
        self._freeze_stages()
================================================
FILE: eval/GVHMR/hmr4d/utils/body_model/README.md
================================================
# README
The contents of this folder are adapted from the HuMoR repository.
================================================
FILE: eval/GVHMR/hmr4d/utils/body_model/__init__.py
================================================
from .body_model import BodyModel
from .body_model_smplh import BodyModelSMPLH
from .body_model_smplx import BodyModelSMPLX
================================================
FILE: eval/GVHMR/hmr4d/utils/body_model/body_model.py
================================================
from turtle import forward
import numpy as np
import torch
import torch.nn as nn
from smplx import SMPL, SMPLH, SMPLX
from smplx.vertex_ids import vertex_ids
from smplx.utils import Struct
class BodyModel(nn.Module):
    """
    Wrapper around SMPLX body model class.
    modified by Zehong Shen

    The wrapped model is built without per-instance pose/shape parameters
    (all ``create_*`` flags are False), so the batch size is taken from the
    tensors passed to :meth:`forward`.
    """

    def __init__(self,
                 bm_path,
                 num_betas=16,
                 use_vtx_selector=False,
                 model_type='smplh'):
        '''
        Creates the body model object at the given path.

        :param bm_path: path to the body model pkl file
        :param num_betas: number of shape coefficients handed to the body model
        :param model_type: one of [smpl, smplh, smplx]
        :param use_vtx_selector: if true, returns additional vertices as joints that correspond to OpenPose joints
        '''
        super().__init__()
        self.use_vtx_selector = use_vtx_selector
        cur_vertex_ids = None
        if self.use_vtx_selector:
            cur_vertex_ids = vertex_ids[model_type]
        data_struct = None
        if '.npz' in bm_path:
            # smplx does not support .npz by default, so have to load in manually.
            # The archive is closed as soon as its arrays have been copied into
            # the Struct, instead of leaking the file handle.
            with np.load(bm_path, encoding='latin1') as smpl_dict:
                data_struct = Struct(**smpl_dict)
            if model_type == 'smplh':
                data_struct.hands_componentsl = np.zeros((0))
                data_struct.hands_componentsr = np.zeros((0))
                data_struct.hands_meanl = np.zeros((15 * 3))
                data_struct.hands_meanr = np.zeros((15 * 3))
                V, D, B = data_struct.shapedirs.shape
                data_struct.shapedirs = np.concatenate([data_struct.shapedirs, np.zeros(
                    (V, D, SMPL.SHAPE_SPACE_DIM-B))], axis=-1)  # super hacky way to let smplh use 16-size beta
        kwargs = {
            'model_type': model_type,
            'data_struct': data_struct,
            'num_betas': num_betas,
            'vertex_ids': cur_vertex_ids,
            'use_pca': False,
            'flat_hand_mean': True,
            # - enable variable batchsize, since we don't need module variable - #
            'create_body_pose': False,
            'create_betas': False,
            'create_global_orient': False,
            'create_transl': False,
            'create_left_hand_pose': False,
            'create_right_hand_pose': False,
        }
        assert(model_type in ['smpl', 'smplh', 'smplx'])
        if model_type == 'smpl':
            self.bm = SMPL(bm_path, **kwargs)
            self.num_joints = SMPL.NUM_JOINTS
        elif model_type == 'smplh':
            self.bm = SMPLH(bm_path, **kwargs)
            self.num_joints = SMPLH.NUM_JOINTS
        elif model_type == 'smplx':
            self.bm = SMPLX(bm_path, **kwargs)
            self.num_joints = SMPLX.NUM_JOINTS
        self.model_type = model_type

    def forward(self, root_orient=None, pose_body=None, pose_hand=None, pose_jaw=None, pose_eye=None, betas=None,
                trans=None, dmpls=None, expression=None, return_dict=False, **kwargs):
        '''
        Run the wrapped body model. Note dmpls are not supported.

        ``pose_body`` is required; its first dimension defines the batch size B
        and its device/dtype are used for any defaults created here.

        :param root_orient: global orientation, zeros of shape (B, 3) when None
        :param pose_body: body pose, first dimension is the batch
        :param pose_hand: both hands concatenated (left first), zeros when None
        :param pose_jaw: forwarded as ``jaw_pose``
        :param pose_eye: left/right eye pose concatenated, split at column 3
        :param betas: shape coefficients; a 1-D tensor is broadcast over the
            batch, None falls back to zeros
        :param trans: forwarded as ``transl``
        :param expression: forwarded unchanged
        :param return_dict: return a plain dict instead of a ``Struct``
        :return: object/dict with vertices ``v``, faces ``f`` and joints ``Jtr``
        '''
        assert(dmpls is None)
        B = pose_body.shape[0]
        device, dtype = pose_body.device, pose_body.dtype
        if pose_hand is None:
            # Match pose_body's dtype as well as its device.
            pose_hand = torch.zeros((B, 2*SMPLH.NUM_HAND_JOINTS*3), device=device, dtype=dtype)
        if root_orient is None:
            # The wrapped model has no stored global_orient to fall back on.
            root_orient = torch.zeros((B, 3), device=device, dtype=dtype)
        if betas is None:
            # Likewise no stored betas: default to zeros, as the sibling wrappers do.
            betas = torch.zeros((B, self.bm.num_betas), device=device, dtype=dtype)
        elif len(betas.shape) == 1:
            betas = betas.reshape((1, -1)).expand(B, -1)

        out_obj = self.bm(
            betas=betas,
            global_orient=root_orient,
            body_pose=pose_body,
            left_hand_pose=pose_hand[:, :(SMPLH.NUM_HAND_JOINTS*3)],
            right_hand_pose=pose_hand[:, (SMPLH.NUM_HAND_JOINTS*3):],
            transl=trans,
            expression=expression,
            jaw_pose=pose_jaw,
            leye_pose=None if pose_eye is None else pose_eye[:, :3],
            reye_pose=None if pose_eye is None else pose_eye[:, 3:],
            return_full_pose=True,
            **kwargs
        )

        out = {
            'v': out_obj.vertices,
            'f': self.bm.faces_tensor,
            'Jtr': out_obj.joints,
        }

        if not self.use_vtx_selector:
            # don't need extra joints
            out['Jtr'] = out['Jtr'][:, :self.num_joints+1]  # add one for the root

        if not return_dict:
            out = Struct(**out)

        return out

    def forward_motion(self, **kwargs):
        '''
        Sequence variant of :meth:`forward`.

        Every tensor keyword argument has its two leading dimensions (B, W)
        flattened before the call; ``None`` values are passed through untouched.
        The resulting ``v`` and ``Jtr`` are reshaped back to (B, W, -1, 3).
        '''
        B, W, _ = kwargs['pose_body'].shape
        kwargs = {k: (v if v is None else v.reshape(B*W, v.shape[-1])) for k, v in kwargs.items()}

        smpl_opt = self.forward(**kwargs)
        smpl_opt.v = smpl_opt.v.reshape(B, W, -1, 3)
        smpl_opt.Jtr = smpl_opt.Jtr.reshape(B, W, -1, 3)

        return smpl_opt
================================================
FILE: eval/GVHMR/hmr4d/utils/body_model/body_model_smplh.py
================================================
import torch
import torch.nn as nn
import smplx
# Constructor flags that stop the body model from creating per-instance pose
# and shape parameters; BodyModelSMPLH merges them into its kwargs before
# calling ``smplx.create``.
kwargs_disable_member_var = dict.fromkeys(
    (
        "create_body_pose",
        "create_betas",
        "create_global_orient",
        "create_transl",
        "create_left_hand_pose",
        "create_right_hand_pose",
    ),
    False,
)
class BodyModelSMPLH(nn.Module):
    """Batch-friendly wrapper around a body model built with ``smplx.create``.

    The wrapped model is created without per-instance pose/shape parameters,
    so any argument left out of :meth:`forward` is replaced by zeros sized to
    the batch inferred from the arguments that were given.
    """

    def __init__(self, model_path, **kwargs):
        super().__init__()
        # enable flexible batchsize, handle missing variable at forward()
        kwargs.update(kwargs_disable_member_var)
        self.bm = smplx.create(model_path=model_path, **kwargs)
        self.faces = self.bm.faces

        # Hand poses are only prepared when the model type is not plain "smpl".
        self.is_smpl = kwargs.get("model_type", "smpl") == "smpl"
        if not self.is_smpl:
            if self.bm.use_pca:
                self.hand_pose_dim = self.bm.num_pca_comps
            else:
                self.hand_pose_dim = 3 * self.bm.NUM_HAND_JOINTS

        # For fast computing of skeleton under beta (first 22 joints only)
        joint_regressor = self.bm.J_regressor[:22, :]  # (22, V)
        rest_joints = joint_regressor @ self.bm.v_template  # (22, 3)
        joint_shapedirs = torch.einsum("jv, vcd -> jcd", joint_regressor, self.bm.shapedirs)  # (22, 3, num_betas)
        self.register_buffer("J_template", rest_joints, persistent=False)
        self.register_buffer("J_shapedirs", joint_shapedirs, persistent=False)

    def forward(
        self,
        betas=None,
        global_orient=None,
        transl=None,
        body_pose=None,
        left_hand_pose=None,
        right_hand_pose=None,
        **kwargs
    ):
        """Call the wrapped model, filling every missing argument with zeros.

        The batch size is the largest ``len()`` among the provided arguments
        (1 when none is given). Extra keyword arguments are forwarded as-is.
        """
        device, dtype = self.bm.shapedirs.device, self.bm.shapedirs.dtype

        provided = [betas, global_orient, body_pose, transl, left_hand_pose, right_hand_pose]
        batch_size = max([1] + [len(item) for item in provided if item is not None])

        def zeros(width):
            # Contiguous (batch_size, width) zero tensor on the model's device/dtype.
            return torch.zeros([batch_size, width], dtype=dtype, device=device)

        if global_orient is None:
            global_orient = zeros(3)
        if body_pose is None:
            body_pose = zeros(3 * self.bm.NUM_BODY_JOINTS)
        if not self.is_smpl:
            if left_hand_pose is None:
                left_hand_pose = zeros(self.hand_pose_dim)
            if right_hand_pose is None:
                right_hand_pose = zeros(self.hand_pose_dim)
        if betas is None:
            betas = zeros(self.bm.num_betas)
        if transl is None:
            transl = zeros(3)

        return self.bm(
            betas=betas,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            transl=transl,
            **kwargs
        )

    def get_skeleton(self, betas):
        """betas: (*, num_betas) -> rest-pose skeleton of the first 22 joints: (*, 22, 3)"""
        shape_offset = torch.einsum("...d, jcd -> ...jc", betas, self.J_shapedirs)
        return self.J_template + shape_offset
================================================
FILE: eval/GVHMR/hmr4d/utils/body_model/body_model_smplx.py
================================================
import torch
import torch.nn as nn
import smplx
# Constructor flags that stop the body model from creating per-instance pose,
# shape and expression parameters; BodyModelSMPLX merges them into its kwargs
# before calling ``smplx.create``.
kwargs_disable_member_var = dict.fromkeys(
    (
        "create_body_pose",
        "create_betas",
        "create_global_orient",
        "create_transl",
        "create_left_hand_pose",
        "create_right_hand_pose",
        "create_expression",
        "create_jaw_pose",
        "create_leye_pose",
        "create_reye_pose",
    ),
    False,
)
class BodyModelSMPLX(nn.Module):
    """Batch-friendly wrapper around a body model built with ``smplx.create``.

    The wrapped model is created without per-instance parameters, so any
    argument left out of :meth:`forward` is replaced by zeros sized to the
    batch inferred from the arguments that were given.
    """

    def __init__(self, model_path, **kwargs):
        super().__init__()
        # enable flexible batchsize, handle missing variable at forward()
        kwargs.update(kwargs_disable_member_var)
        self.bm = smplx.create(model_path=model_path, **kwargs)
        self.faces = self.bm.faces

        if self.bm.use_pca:
            self.hand_pose_dim = self.bm.num_pca_comps
        else:
            self.hand_pose_dim = 3 * self.bm.NUM_HAND_JOINTS

        # For fast computing of skeleton under beta (first 22 joints only)
        joint_regressor = self.bm.J_regressor[:22, :]  # (22, V)
        rest_joints = joint_regressor @ self.bm.v_template  # (22, 3)
        joint_shapedirs = torch.einsum("jv, vcd -> jcd", joint_regressor, self.bm.shapedirs)  # (22, 3, num_betas)
        self.register_buffer("J_template", rest_joints, persistent=False)
        self.register_buffer("J_shapedirs", joint_shapedirs, persistent=False)

    def forward(
        self,
        betas=None,
        global_orient=None,
        transl=None,
        body_pose=None,
        left_hand_pose=None,
        right_hand_pose=None,
        expression=None,
        jaw_pose=None,
        leye_pose=None,
        reye_pose=None,
        **kwargs
    ):
        """Call the wrapped model, filling every missing argument with zeros.

        The batch size is the largest ``len()`` among the provided arguments
        (1 when none is given). Extra keyword arguments are forwarded as-is.
        """
        device, dtype = self.bm.shapedirs.device, self.bm.shapedirs.dtype

        provided = (
            betas,
            global_orient,
            body_pose,
            transl,
            expression,
            left_hand_pose,
            right_hand_pose,
            jaw_pose,
            leye_pose,
            reye_pose,
        )
        batch_size = max([1] + [len(item) for item in provided if item is not None])

        def zeros(width):
            # Contiguous (batch_size, width) zero tensor on the model's device/dtype.
            return torch.zeros([batch_size, width], dtype=dtype, device=device)

        if global_orient is None:
            global_orient = zeros(3)
        if body_pose is None:
            body_pose = zeros(3 * self.bm.NUM_BODY_JOINTS)
        if left_hand_pose is None:
            left_hand_pose = zeros(self.hand_pose_dim)
        if right_hand_pose is None:
            right_hand_pose = zeros(self.hand_pose_dim)
        if jaw_pose is None:
            jaw_pose = zeros(3)
        if leye_pose is None:
            leye_pose = zeros(3)
        if reye_pose is None:
            reye_pose = zeros(3)
        if expression is None:
            expression = zeros(self.bm.num_expression_coeffs)
        if betas is None:
            betas = zeros(self.bm.num_betas)
        if transl is None:
            transl = zeros(3)

        return self.bm(
            betas=betas,
            global_orient=global_orient,
            body_pose=body_pose,
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            transl=transl,
            expression=expression,
            jaw_pose=jaw_pose,
            leye_pose=leye_pose,
            reye_pose=reye_pose,
            **kwargs
        )

    def get_skeleton(self, betas):
        """betas: (*, num_betas) -> rest-pose skeleton of the first 22 joints: (*, 22, 3)"""
        shape_offset = torch.einsum("...d, jcd -> ...jc", betas, self.J_shapedirs)
        return self.J_template + shape_offset

    def forward_bfc(self, **kwargs):
        """Wrap (B, F, C) to (B*F, C) and unwrap (B*F, C) to (B, F, C)"""
        for value in kwargs.values():
            assert len(value.shape) == 3
        num_seq, num_frames = kwargs["body_pose"].shape[:2]
        flat_kwargs = {name: value.reshape(num_seq * num_frames, -1) for name, value in kwargs.items()}
        smplx_out = self.forward(**flat_kwargs)
        smplx_out.vertices = smplx_out.vertices.reshape(num_seq, num_frames, -1, 3)
        smplx_out.joints = smplx_out.joints.reshape(num_seq, num_frames, -1, 3)
        return smplx_out
================================================
FILE: eval/GVHMR/hmr4d/utils/body_model/min_lbs.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch3d.transforms import axis_angle_to_matrix
from smplx.utils import Struct, to_np, to_tensor
from hmr4d.utils.smplx_utils import forward_kinematics_motion
class MinimalLBS(nn.Module):
    """Linear blend skinning restricted to a fixed set of "sensor point" vertices.

    Body-model weights are loaded for the male and female models and reduced
    to the selected vertex ids, so :meth:`forward` only has to skin those points.

    Args:
        sp_ids: vertex indices to keep.
        bm_dir: directory containing ``male/model.npz`` and ``female/model.npz``.
        num_betas: number of shape coefficients kept from ``shapedirs``.
        model_type: accepted but not used by this class.
        **kwargs: accepted and ignored.
    """

    def __init__(self, sp_ids, bm_dir='models/smplh', num_betas=16, model_type='smplh', **kwargs):
        super().__init__()
        self.num_betas = num_betas
        self.sensor_point_vid = torch.tensor(sp_ids)

        # load struct data on predefined sensor-point
        self.load_struct_on_sp(f'{bm_dir}/male/model.npz', prefix='male')
        self.load_struct_on_sp(f'{bm_dir}/female/model.npz', prefix='female')

    def load_struct_on_sp(self, bm_path, prefix='m'):
        """
        Load 4 weights from body-model-struct.
        Keep the sensor points only. Use prefix to label different bm.

        Registers the non-persistent buffers ``{prefix}_v_template_sp``,
        ``{prefix}_shapedirs_sp``, ``{prefix}_posedirs_sp`` and
        ``{prefix}_lbs_weights_sp``.
        """
        num_betas = self.num_betas
        sp_vid = self.sensor_point_vid

        # load data; the archive is closed once its arrays are copied into the
        # Struct instead of leaving the file handle open
        with np.load(bm_path, encoding='latin1') as npz_file:
            data_struct = Struct(**npz_file)

        # v-template
        v_template = to_tensor(to_np(data_struct.v_template))  # (V, 3)
        v_template_sp = v_template[sp_vid]  # (N, 3)
        self.register_buffer(f'{prefix}_v_template_sp', v_template_sp, False)

        # shapedirs
        shapedirs = to_tensor(to_np(data_struct.shapedirs[:, :, :num_betas]))  # (V, 3, NB)
        shapedirs_sp = shapedirs[sp_vid]
        self.register_buffer(f'{prefix}_shapedirs_sp', shapedirs_sp, False)

        # posedirs
        posedirs = to_tensor(to_np(data_struct.posedirs))  # (V, 3, 51*9)
        posedirs_sp = posedirs[sp_vid]
        posedirs_sp = posedirs_sp.reshape(len(sp_vid)*3, -1).T  # (51*9, N*3)
        self.register_buffer(f'{prefix}_posedirs_sp', posedirs_sp, False)

        # lbs_weights
        lbs_weights = to_tensor(to_np(data_struct.weights))  # (V, J+1)
        lbs_weights_sp = lbs_weights[sp_vid]
        self.register_buffer(f'{prefix}_lbs_weights_sp', lbs_weights_sp, False)

    def forward(self, root_orient=None, pose_body=None, trans=None, betas=None, A=None, recompute_A=False, genders=None,
                joints_zero=None):
        """
        Args:
            root_orient, Optional: (B, T, 3)
            pose_body: (B, T, J*3)
            trans: (B, T, 3)
            betas: (B, T, 16); a length-1 time axis is broadcast to T
            A, Optional: (B, T, J+1, 4, 4)
            recompute_A: if True, root_orient should be given, otherwise use A
            genders, List: ['male', 'female', ...]
            joints_zero: (B, J+1, 3), required when recompute_A is True
        Returns:
            sensor_verts: (B, T, N, 3)
        Raises:
            ValueError: if neither ``A`` is given nor ``recompute_A`` is set.
        """
        B, T = pose_body.shape[:2]

        # Pick the per-sample buffers according to gender.
        v_template = torch.stack([getattr(self, f'{g}_v_template_sp') for g in genders])  # (B, N, 3)
        shapedirs = torch.stack([getattr(self, f'{g}_shapedirs_sp') for g in genders])  # (B, N, 3, NB)
        posedirs = torch.stack([getattr(self, f'{g}_posedirs_sp') for g in genders])  # (B, 51*9, N*3)
        lbs_weights = torch.stack([getattr(self, f'{g}_lbs_weights_sp') for g in genders])  # (B, N, J+1)

        # ===== LBS, handle T ===== #
        # 2. Add shape contribution
        if betas.shape[1] == 1:
            betas = betas.expand(-1, T, -1)
        blend_shape = torch.einsum('btl,bmkl->btmk', betas, shapedirs)
        v_shaped = v_template[:, None] + blend_shape

        # 3. Add pose blend shapes
        ident = torch.eye(3).to(pose_body)
        aa = pose_body.reshape(B, T, -1, 3)
        R = axis_angle_to_matrix(aa)
        pose_feature = (R - ident).view(B, T, -1)
        dim_pf = pose_feature.shape[-1]
        # Only the first dim_pf rows of posedirs correspond to the joints in pose_body.
        # (B, T, P) @ (B, P, N*3) -> (B, T, N, 3)
        pose_offsets = torch.matmul(pose_feature, posedirs[:, :dim_pf]).view(B, T, -1, 3)
        v_posed = pose_offsets + v_shaped

        # 4. Compute A
        if recompute_A:
            _, _, A = forward_kinematics_motion(root_orient, pose_body, trans, joints_zero)
        if A is None:
            raise ValueError("MinimalLBS.forward needs either `A` or `recompute_A=True`")

        # 5. Skinning
        W = lbs_weights
        # (B, 1, N, J+1)) @ (B, T, J+1, 16)
        num_joints = A.shape[-3]  # 22
        Ts = torch.matmul(W[:, None, :, :num_joints], A.view(B, T, num_joints, 16))
        Ts = Ts.view(B, T, -1, 4, 4)  # (B, T, N, 4, 4)
        v_posed_homo = F.pad(v_posed, (0, 1), value=1)  # (B, T, N, 4)
        v_homo = torch.matmul(Ts, torch.unsqueeze(v_posed_homo, dim=-1))

        # 6. translate
        sensor_verts = v_homo[:, :, :, :3, 0] + trans[:, :, None]
        return sensor_verts
================================================
FILE: eval/GVHMR/hmr4d/utils/body_model/smpl_lite.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from pytorch3d.transforms import axis_angle_to_matrix
from smplx.utils import Struct, to_np, to_tensor
from einops import einsum, rearrange
from time import time
import pickle
from .smplx_lite import batch_rigid_transform_v2
class SmplLite(nn.Module):
    """Lightweight SMPL forward pass that returns posed vertices only.

    Args:
        model_path: directory containing ``SMPL_{GENDER}.pkl``, or the path of a
            model pickle itself.
        gender: used (upper-cased) to select the file inside a directory.
        num_betas: number of shape coefficients kept from ``shapedirs``.
    """

    def __init__(
        self,
        model_path="inputs/checkpoints/body_models/smpl",
        gender="neutral",
        num_betas=10,
    ):
        super().__init__()

        # Resolve the pickle location
        model_path = Path(model_path)
        smpl_path = model_path / f"SMPL_{gender.upper()}.pkl" if model_path.is_dir() else model_path
        assert smpl_path.exists()

        # NOTE(review): unpickling executes arbitrary code; only load model files
        # that come from a trusted source.
        with open(smpl_path, "rb") as smpl_file:
            raw_model = pickle.load(smpl_file, encoding="latin1")
        data_struct = Struct(**raw_model)

        self.faces = data_struct.f  # (F, 3)
        self.register_smpl_buffers(data_struct, num_betas)
        self.register_fast_skeleton_computing_buffers()

    def register_smpl_buffers(self, data_struct, num_betas):
        """Register the model arrays as non-persistent float (or long) buffers."""

        def as_float(array):
            return to_tensor(to_np(array)).float()

        # shapedirs, (V, 3, N_betas)
        self.register_buffer("shapedirs", as_float(data_struct.shapedirs[:, :, :num_betas]), False)

        # v_template, (V, 3)
        self.register_buffer("v_template", as_float(data_struct.v_template), False)

        # J_regressor, (J, V)
        self.register_buffer("J_regressor", as_float(data_struct.J_regressor), False)

        # posedirs, stored as (K, V, 3); the global orientation is not part of K
        pose_dirs = rearrange(as_float(data_struct.posedirs), "v c n -> n v c")
        self.register_buffer("posedirs", pose_dirs, False)

        # lbs_weights, (V, J)
        self.register_buffer("lbs_weights", as_float(data_struct.weights), False)

        # parents, (J), long; the root is marked with -1
        kin_parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
        kin_parents[0] = -1
        self.register_buffer("parents", kin_parents, False)

    def register_fast_skeleton_computing_buffers(self):
        """Pre-compute the joint template and joint shape directions."""
        # For fast computing of skeleton under beta
        rest_joints = self.J_regressor @ self.v_template  # (J, 3)
        joint_dirs = torch.einsum("jv, vcd -> jcd", self.J_regressor, self.shapedirs)  # (J, 3, N_betas)
        self.register_buffer("J_template", rest_joints, False)
        self.register_buffer("J_shapedirs", joint_dirs, False)

    def get_skeleton(self, betas):
        """Return rest-pose joints for the given betas: (..., K) -> (..., J, 3)."""
        shape_offset = einsum(betas, self.J_shapedirs, "... k, j c k -> ... j c")
        return self.J_template + shape_offset

    def forward(
        self,
        body_pose,
        betas,
        global_orient,
        transl,
    ):
        """
        Args:
            body_pose: (B, L, 63)
            betas: (B, L, 10)
            global_orient: (B, L, 3)
            transl: (B, L, 3)
        Returns:
            vertices: (B, L, V, 3)
        """
        # 1. Convert [global_orient, body_pose] to rot_mats
        full_pose = torch.cat([global_orient, body_pose], dim=-1)
        num_rot = full_pose.shape[-1] // 3
        rot_mats = axis_angle_to_matrix(full_pose.reshape(*full_pose.shape[:-1], num_rot, 3))

        # 2. Forward Kinematics
        joints_rest = self.get_skeleton(betas)  # (*, J, 3)
        A = batch_rigid_transform_v2(rot_mats, joints_rest, self.parents)[1]

        # 3. Canonical v_posed = v_template + shaped_offsets + pose_offsets
        identity = torch.eye(3, dtype=rot_mats.dtype, device=rot_mats.device)
        pose_feature = rot_mats[..., 1:, :, :] - identity  # the root rotation is skipped
        pose_feature = pose_feature.view(*pose_feature.shape[:-3], -1)
        shape_offsets = einsum(betas, self.shapedirs, "... k, v c k -> ... v c")
        pose_offsets = einsum(pose_feature, self.posedirs, "... k, k v c -> ... v c")
        v_posed = self.v_template + shape_offsets + pose_offsets
        # Release intermediates before the skinning step.
        del pose_feature, rot_mats, full_pose, shape_offsets, pose_offsets

        # 4. Skinning
        T = einsum(self.lbs_weights, A, "v j, ... j c d -> ... v c d")
        rotated = einsum(T[..., :3, :3], v_posed, "... v c d, ... v d -> ... v c")
        verts = rotated + T[..., :3, 3]

        # 5. Translation
        return verts + transl[..., None, :]
class SmplxLiteJ24(SmplLite):
    """``SmplLite`` variant that outputs regressed joints instead of vertices.

    Only vertices with a non-zero entry in ``J_regressor`` are kept, so the
    parent forward pass skins just those vertices before they are combined
    into joints.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Compute mapping: one row per non-zero regressor entry
        joint_regressor = self.J_regressor  # (24, 6890)
        joint_ids, vert_ids = torch.where(joint_regressor != 0)
        num_entries = len(vert_ids)
        interestd = torch.zeros([num_entries, 24])
        interestd[torch.arange(num_entries), joint_ids] = joint_regressor[joint_ids, vert_ids]
        self.register_buffer("interestd", interestd, False)  # (236, 24)

        # Update to vertices of interest
        self.v_template = self.v_template[vert_ids].clone()  # (V', 3)
        self.shapedirs = self.shapedirs[vert_ids].clone()  # (V', 3, K)
        self.posedirs = self.posedirs[:, vert_ids].clone()  # (K, V', 3)
        self.lbs_weights = self.lbs_weights[vert_ids].clone()  # (V', J)

    def forward(self, body_pose, betas, global_orient, transl):
        """Returns: joints (*, J, 3). (B, L) or (B,) are both supported."""
        # The parent forward skins the reduced vertex set
        sparse_verts = super().forward(body_pose, betas, global_orient, transl)  # (*, 236, 3)
        return einsum(self.interestd, sparse_verts, "v j, ... v c -> ... j c")
================================================
FILE: eval/GVHMR/hmr4d/utils/body_model/smpl_vert_segmentation.json
================================================
{
"rightHand": [
5442,
5443,
5444,
5445,
5446,
5447,
5448,
5449,
5450,
5451,
5452,
5453,
5454,
5455,
5456,
5457,
5458,
5459,
5460,
5461,
5462,
5463,
5464,
5465,
5466,
5467,
5468,
5469,
5470,
5471,
5472,
5473,
5474,
5475,
5476,
5477,
5478,
5479,
5480,
5481,
5482,
5483,
5484,
5485,
5486,
5487,
5492,
5493,
5494,
5495,
5496,
5497,
5502,
5503,
5504,
5505,
5506,
5507,
5508,
5509,
5510,
5511,
5512,
5513,
5514,
5515,
5516,
5517,
5518,
5519,
5520,
5521,
5522,
5523,
5524,
5525,
5526,
5527,
5530,
5531,
5532,
5533,
5534,
5535,
5536,
5537,
5538,
5539,
5540,
5541,
5542,
5543,
5544,
5545,
5546,
5547,
5548,
5549,
5550,
5551,
5552,
5553,
5554,
5555,
5556,
5557,
5558,
5559,
5560,
5561,
5562,
5569,
5571,
5574,
5575,
5576,
5577,
5578,
5579,
5580,
5581,
5582,
5583,
5588,
5589,
5592,
5593,
5594,
5595,
5596,
5597,
5598,
5599,
5600,
5601,
5602,
5603,
5604,
5605,
5610,
5611,
5612,
5613,
5614,
5621,
5622,
5625,
5631,
5632,
5633,
5634,
5635,
5636,
5637,
5638,
5639,
5640,
5641,
5643,
5644,
5645,
5646,
5649,
5650,
5652,
5653,
5654,
5655,
5656,
5657,
5658,
5659,
5660,
5661,
5662,
5663,
5664,
5667,
5670,
5671,
5672,
5673,
5674,
5675,
5682,
5683,
5684,
5685,
5686,
5687,
5688,
5689,
5690,
5692,
5695,
5697,
5698,
5699,
5700,
5701,
5707,
5708,
5709,
5710,
5711,
5712,
5713,
5714,
5715,
5716,
5717,
5718,
5719,
5720,
5721,
5723,
5724,
5725,
5726,
5727,
5728,
5729,
5730,
5731,
5732,
5735,
5736,
5737,
5738,
5739,
5740,
5745,
5746,
5748,
5749,
5750,
5751,
5752,
6056,
6057,
6066,
6067,
6158,
6159,
6160,
6161,
6162,
6163,
6164,
6165,
6166,
6167,
6168,
6169,
6170,
6171,
6172,
6173,
6174,
6175,
6176,
6177,
6178,
6179,
6180,
6181,
6182,
6183,
6184,
6185,
6186,
6187,
6188,
6189,
6190,
6191,
6192,
6193,
6194,
6195,
6196,
6197,
6198,
6199,
6200,
6201,
6202,
6203,
6204,
6205,
6206,
6207,
6208,
6209,
6210,
6211,
6212,
6213,
6214,
6215,
6216,
6217,
6218,
6219,
6220,
6221,
6222,
6223,
6224,
6225,
6226,
6227,
6228,
6229,
6230,
6231,
6232,
6233,
6234,
6235,
6236,
6237,
6238,
6239
],
"rightUpLeg": [
4320,
4321,
4323,
4324,
4333,
4334,
4335,
4336,
4337,
4338,
4339,
4340,
4356,
4357,
4358,
4359,
4360,
4361,
4362,
4363,
4364,
4365,
4366,
4367,
4383,
4384,
4385,
4386,
4387,
4388,
4389,
4390,
4391,
4392,
4393,
4394,
4395,
4396,
4397,
4398,
4399,
4400,
4401,
4419,
4420,
4421,
4422,
4430,
4431,
4432,
4433,
4434,
4435,
4436,
4437,
4438,
4439,
4440,
4441,
4442,
4443,
4444,
4445,
4446,
4447,
4448,
4449,
4450,
4451,
4452,
4453,
4454,
4455,
4456,
4457,
4458,
4459,
4460,
4461,
4462,
4463,
4464,
4465,
4466,
4467,
4468,
4469,
4470,
4471,
4472,
4473,
4474,
4475,
4476,
4477,
4478,
4479,
4480,
4481,
4482,
4483,
4484,
4485,
4486,
4487,
4488,
4489,
4490,
4491,
4492,
4493,
4494,
4495,
4496,
4497,
4498,
4499,
4500,
4501,
4502,
4503,
4504,
4505,
4506,
4507,
4508,
4509,
4510,
4511,
4512,
4513,
4514,
4515,
4516,
4517,
4518,
4519,
4520,
4521,
4522,
4523,
4524,
4525,
4526,
4527,
4528,
4529,
4530,
4531,
4532,
4623,
4624,
4625,
4626,
4627,
4628,
4629,
4630,
4631,
4632,
4633,
4634,
4645,
4646,
4647,
4648,
4649,
4650,
4651,
4652,
4653,
4654,
4655,
4656,
4657,
4658,
4659,
4660,
4670,
4671,
4672,
4673,
4704,
4705,
4706,
4707,
4708,
4709,
4710,
4711,
4712,
4713,
4745,
4746,
4757,
4758,
4759,
4760,
4801,
4802,
4829,
4834,
4835,
4836,
4837,
4838,
4839,
4840,
4841,
4924,
4925,
4926,
4928,
4929,
4930,
4931,
4932,
4933,
4934,
4935,
4936,
4948,
4949,
4950,
4951,
4952,
4970,
4971,
4972,
4973,
4983,
4984,
4985,
4986,
4987,
4988,
4989,
4990,
4991,
4992,
4993,
5004,
5005,
6546,
6547,
6548,
6549,
6552,
6553,
6554,
6555,
6556,
6873,
6877
],
"leftArm": [
626,
627,
628,
629,
634,
635,
680,
681,
716,
717,
718,
719,
769,
770,
771,
772,
773,
774,
775,
776,
777,
778,
779,
780,
784,
785,
786,
787,
788,
789,
790,
791,
792,
793,
1231,
1232,
1233,
1234,
1258,
1259,
1260,
1261,
1271,
1281,
1282,
1310,
1311,
1314,
1315,
1340,
1341,
1342,
1343,
1355,
1356,
1357,
1358,
1376,
1377,
1378,
1379,
1380,
1381,
1382,
1383,
1384,
1385,
1386,
1387,
1388,
1389,
1390,
1391,
1392,
1393,
1394,
1395,
1396,
1397,
1398,
1399,
1400,
1402,
1403,
1405,
1406,
1407,
1408,
1409,
1410,
1411,
1412,
1413,
1414,
1415,
1416,
1428,
1429,
1430,
1431,
1432,
1433,
1438,
1439,
1440,
1441,
1442,
1443,
1444,
1445,
1502,
1505,
1506,
1507,
1508,
1509,
1510,
1538,
1541,
1542,
1543,
1545,
1619,
1620,
1621,
1622,
1631,
1632,
1633,
1634,
1635,
1636,
1637,
1638,
1639,
1640,
1641,
1642,
1645,
1646,
1647,
1648,
1649,
1650,
1651,
1652,
1653,
1654,
1655,
1656,
1658,
1659,
1661,
1662,
1664,
1666,
1667,
1668,
1669,
1670,
1671,
1672,
1673,
1674,
1675,
1676,
1677,
1678,
1679,
1680,
1681,
1682,
1683,
1684,
1696,
1697,
1698,
1703,
1704,
1705,
1706,
1707,
1708,
1709,
1710,
1711,
1712,
1713,
1714,
1715,
1716,
1717,
1718,
1719,
1720,
1725,
1731,
1732,
1733,
1734,
1735,
1737,
1739,
1740,
1745,
1746,
1747,
1748,
1749,
1751,
1761,
1830,
1831,
1844,
1845,
1846,
1850,
1851,
1854,
1855,
1858,
1860,
1865,
1866,
1867,
1869,
1870,
1871,
1874,
1875,
1876,
1877,
1878,
1882,
1883,
1888,
1889,
1892,
1900,
1901,
1902,
1903,
1904,
1909,
2819,
2820,
2821,
2822,
2895,
2896,
2897,
2898,
2899,
2900,
2901,
2902,
2903,
2945,
2946,
2974,
2975,
2976,
2977,
2978,
2979,
2980,
2981,
2982,
2983,
2984,
2985,
2986,
2987,
2988,
2989,
2990,
2991,
2992,
2993,
2994,
2995,
2996,
3002,
3013
],
"leftLeg": [
995,
998,
999,
1002,
1004,
1005,
1008,
1010,
1012,
1015,
1016,
1018,
1019,
1043,
1044,
1047,
1048,
1049,
1050,
1051,
1052,
1053,
1054,
1055,
1056,
1057,
1058,
1059,
1060,
1061,
1062,
1063,
1064,
1065,
1066,
1067,
1068,
1069,
1070,
1071,
1072,
1073,
1074,
1075,
1076,
1077,
1078,
1079,
1080,
1081,
1082,
1083,
1084,
1085,
1086,
1087,
1088,
1089,
1090,
1091,
1092,
1093,
1094,
1095,
1096,
1097,
1098,
1099,
1100,
1101,
1102,
1103,
1104,
1105,
1106,
1107,
1108,
1109,
1110,
1111,
1112,
1113,
1114,
1115,
1116,
1117,
1118,
1119,
1120,
1121,
1122,
1123,
1124,
1125,
1126,
1127,
1128,
1129,
1130,
1131,
1132,
1133,
1134,
1135,
1136,
1148,
1149,
1150,
1151,
1152,
1153,
1154,
1155,
1156,
1157,
1158,
1175,
1176,
1177,
1178,
1179,
1180,
1181,
1182,
1183,
1369,
1370,
1371,
1372,
1373,
1374,
1375,
1464,
1465,
1466,
1467,
1468,
1469,
1470,
1471,
1472,
1473,
1474,
1522,
1523,
1524,
1525,
1526,
1527,
1528,
1529,
1530,
1531,
1532,
3174,
3175,
3176,
3177,
3178,
3179,
3180,
3181,
3182,
3183,
3184,
3185,
3186,
3187,
3188,
3189,
3190,
3191,
3192,
3193,
3194,
3195,
3196,
3197,
3198,
3199,
3200,
3201,
3202,
3203,
3204,
3205,
3206,
3207,
3208,
3209,
3210,
3319,
3320,
3321,
3322,
3323,
3324,
3325,
3326,
3327,
3328,
3329,
3330,
3331,
3332,
3333,
3334,
3335,
3432,
3433,
3434,
3435,
3436,
3469,
3472,
3473,
3474
],
"leftToeBase": [
3211,
3212,
3213,
3214,
3215,
3216,
3217,
3218,
3219,
3220,
3221,
3222,
3223,
3224,
3225,
3226,
3227,
3228,
3229,
3230,
3231,
3232,
3233,
3234,
3235,
3236,
3237,
3238,
3239,
3240,
3241,
3242,
3243,
3244,
3245,
3246,
3247,
3248,
3249,
3250,
3251,
3252,
3253,
3254,
3255,
3256,
3257,
3258,
3259,
3260,
3261,
3262,
3263,
3264,
3265,
3266,
3267,
3268,
3269,
3270,
3271,
3272,
3273,
3274,
3275,
3276,
3277,
3278,
3279,
3280,
3281,
3282,
3283,
3284,
3285,
3286,
3287,
3288,
3289,
3290,
3291,
3292,
3293,
3294,
3295,
3296,
3297,
3298,
3299,
3300,
3301,
3302,
3303,
3304,
3305,
3306,
3307,
3308,
3309,
3310,
3311,
3312,
3313,
3314,
3315,
3316,
3317,
3318,
3336,
3337,
3340,
3342,
3344,
3346,
3348,
3350,
3352,
3354,
3357,
3358,
3360,
3362
],
"leftFoot": [
3327,
3328,
3329,
3330,
3331,
3332,
3333,
3334,
3335,
3336,
3337,
3338,
3339,
3340,
3341,
3342,
3343,
3344,
3345,
3346,
3347,
3348,
3349,
3350,
3351,
3352,
3353,
3354,
3355,
3356,
3357,
3358,
3359,
3360,
3361,
3362,
3363,
3364,
3365,
3366,
3367,
3368,
3369,
3370,
3371,
3372,
3373,
3374,
3375,
3376,
3377,
3378,
3379,
3380,
3381,
3382,
3383,
3384,
3385,
3386,
3387,
3388,
3389,
3390,
3391,
3392,
3393,
3394,
3395,
3396,
3397,
3398,
3399,
3400,
3401,
3402,
3403,
3404,
3405,
3406,
3407,
3408,
3409,
3410,
3411,
3412,
3413,
3414,
3415,
3416,
3417,
3418,
3419,
3420,
3421,
3422,
3423,
3424,
3425,
3426,
3427,
3428,
3429,
3430,
3431,
3432,
3433,
3434,
3435,
3436,
3437,
3438,
3439,
3440,
3441,
3442,
3443,
3444,
3445,
3446,
3447,
3448,
3449,
3450,
3451,
3452,
3453,
3454,
3455,
3456,
3457,
3458,
3459,
3460,
3461,
3462,
3463,
3464,
3465,
3466,
3467,
3468,
3469
],
"spine1": [
598,
599,
600,
601,
610,
611,
612,
613,
614,
615,
616,
617,
618,
619,
620,
621,
642,
645,
646,
647,
652,
653,
658,
659,
660,
661,
668,
669,
670,
671,
684,
685,
686,
687,
688,
689,
690,
691,
692,
722,
723,
724,
725,
736,
750,
751,
761,
764,
766,
767,
794,
795,
891,
892,
893,
894,
925,
926,
927,
928,
929,
940,
941,
942,
943,
1190,
1191,
1192,
1193,
1194,
1195,
1196,
1197,
1200,
1201,
1202,
1212,
1236,
1252,
1253,
1254,
1255,
1268,
1269,
1270,
1329,
1330,
1348,
1349,
1351,
1420,
1421,
1423,
1424,
1425,
1426,
1436,
1437,
1756,
1757,
1758,
2839,
2840,
2841,
2842,
2843,
2844,
2845,
2846,
2847,
2848,
2849,
2850,
2851,
2870,
2871,
2883,
2906,
2908,
3014,
3017,
3025,
3030,
3033,
3034,
3037,
3039,
3040,
3041,
3042,
3043,
3044,
3076,
3077,
3079,
3480,
3505,
3511,
4086,
4087,
4088,
4089,
4098,
4099,
4100,
4101,
4102,
4103,
4104,
4105,
4106,
4107,
4108,
4109,
4130,
4131,
4134,
4135,
4140,
4141,
4146,
4147,
4148,
4149,
4156,
4157,
4158,
4159,
4172,
4173,
4174,
4175,
4176,
4177,
4178,
4179,
4180,
4210,
4211,
4212,
4213,
4225,
4239,
4240,
4249,
4250,
4255,
4256,
4282,
4283,
4377,
4378,
4379,
4380,
4411,
4412,
4413,
4414,
4415,
4426,
4427,
4428,
4429,
4676,
4677,
4678,
4679,
4680,
4681,
4682,
4683,
4686,
4687,
4688,
4695,
4719,
4735,
4736,
4737,
4740,
4751,
4752,
4753,
4824,
4825,
4828,
4893,
4894,
4895,
4897,
4898,
4899,
4908,
4909,
5223,
5224,
5225,
6300,
6301,
6302,
6303,
6304,
6305,
6306,
6307,
6308,
6309,
6310,
6311,
6312,
6331,
6332,
6342,
6366,
6367,
6475,
6477,
6478,
6481,
6482,
6485,
6487,
6488,
6489,
6490,
6491,
6878
],
"spine2": [
570,
571,
572,
573,
584,
585,
586,
587,
588,
589,
590,
591,
592,
593,
594,
595,
596,
597,
602,
603,
604,
605,
606,
607,
608,
609,
622,
623,
624,
625,
638,
639,
640,
641,
643,
644,
648,
649,
650,
651,
666,
667,
672,
673,
674,
675,
680,
681,
682,
683,
693,
694,
695,
696,
697,
698,
699,
700,
701,
702,
703,
704,
713,
714,
715,
716,
717,
726,
727,
728,
729,
730,
731,
732,
733,
735,
737,
738,
739,
740,
741,
742,
743,
744,
745,
746,
747,
748,
749,
752,
753,
754,
755,
756,
757,
758,
759,
760,
762,
763,
803,
804,
805,
806,
811,
812,
813,
814,
817,
818,
819,
820,
821,
824,
825,
826,
827,
828,
895,
896,
930,
931,
1198,
1199,
1213,
1214,
1215,
1216,
1217,
1218,
1219,
1220,
1235,
1237,
1256,
1257,
1271,
1272,
1273,
1279,
1280,
1283,
1284,
1285,
1286,
1287,
1288,
1289,
1290,
1291,
1292,
1293,
1294,
1295,
1296,
1297,
1298,
1299,
1300,
1301,
1302,
1303,
1304,
1305,
1306,
1307,
1308,
1309,
1312,
1313,
1319,
1320,
1346,
1347,
1350,
1352,
1401,
1417,
1418,
1419,
1422,
1427,
1434,
1435,
1503,
1504,
1536,
1537,
1544,
1545,
1753,
1754,
1755,
1759,
1760,
1761,
1762,
1763,
1808,
1809,
1810,
1811,
1816,
1817,
1818,
1819,
1820,
1834,
1835,
1836,
1837,
1838,
1839,
1868,
1879,
1880,
2812,
2813,
2852,
2853,
2854,
2855,
2856,
2857,
2858,
2859,
2860,
2861,
2862,
2863,
2864,
2865,
2866,
2867,
2868,
2869,
2872,
2875,
2876,
2877,
2878,
2881,
2882,
2884,
2885,
2886,
2904,
2905,
2907,
2931,
2932,
2933,
2934,
2935,
2936,
2937,
2941,
2950,
2951,
2952,
2953,
2954,
2955,
2956,
2957,
2958,
2959,
2960,
2961,
2962,
2963,
2964,
2965,
2966,
2967,
2968,
2969,
2970,
2971,
2972,
2973,
2997,
2998,
3006,
3007,
3012,
3015,
3026,
3027,
3028,
3029,
3031,
3032,
3035,
3036,
3038,
3059,
3060,
3061,
3062,
3063,
3064,
3065,
3066,
3067,
3073,
3074,
3075,
3078,
3168,
3169,
3171,
3470,
3471,
3482,
3483,
3495,
3496,
3497,
3498,
3506,
3508,
4058,
4059,
4060,
4061,
4072,
4073,
4074,
4075,
4076,
4077,
4078,
4079,
4080,
4081,
4082,
4083,
4084,
4085,
4090,
4091,
4092,
4093,
4094,
4095,
4096,
4097,
4110,
4111,
4112,
4113,
4126,
4127,
4128,
4129,
4132,
4133,
4136,
4137,
4138,
4139,
4154,
4155,
4160,
4161,
4162,
4163,
4168,
4169,
4170,
4171,
4181,
4182,
4183,
4184,
4185,
4186,
4187,
4188,
4189,
4190,
4191,
4192,
4201,
4202,
4203,
4204,
4207,
4214,
4215,
4216,
4217,
4218,
4219,
4220,
4221,
4223,
4224,
4226,
4227,
4228,
4229,
4230,
4231,
4232,
4233,
4234,
4235,
4236,
4237,
4238,
4241,
4242,
4243,
4244,
4245,
4246,
4247,
4248,
4251,
4252,
4291,
4292,
4293,
4294,
4299,
4300,
4301,
4302,
4305,
4306,
4307,
4308,
4309,
4312,
4313,
4314,
4315,
4381,
4382,
4416,
4417,
4684,
4685,
4696,
4697,
4698,
4699,
4700,
4701,
4702,
4703,
4718,
4720,
4738,
4739,
4754,
4755,
4756,
4761,
4762,
4765,
4766,
4767,
4768,
4769,
4770,
4771,
4772,
4773,
4774,
4775,
4776,
4777,
4778,
4779,
4780,
4781,
4782,
4783,
4784,
4785,
4786,
4787,
4788,
4789,
4792,
4793,
4799,
4800,
4822,
4823,
4826,
4827,
4874,
4890,
4891,
4892,
4896,
4900,
4907,
4910,
4975,
4976,
5007,
5008,
5013,
5014,
5222,
5226,
5227,
5228,
5229,
5230,
5269,
5270,
5271,
5272,
5277,
5278,
5279,
5280,
5281,
5295,
5296,
5297,
5298,
5299,
5300,
5329,
5340,
5341,
6273,
6274,
6313,
6314,
6315,
6316,
6317,
6318,
6319,
6320,
6321,
6322,
6323,
6324,
6325,
6326,
6327,
6328,
6329,
6330,
6333,
6336,
6337,
6340,
6341,
6343,
6344,
6345,
6363,
6364,
6365,
6390,
6391,
6392,
6393,
6394,
6395,
6396,
6398,
6409,
6410,
6411,
6412,
6413,
6414,
6415,
6416,
6417,
6418,
6419,
6420,
6421,
6422,
6423,
6424,
6425,
6426,
6427,
6428,
6429,
6430,
6431,
6432,
6456,
6457,
6465,
6466,
6476,
6479,
6480,
6483,
6484,
6486,
6496,
6497,
6498,
6499,
6500,
6501,
6502,
6503,
6879
],
"leftShoulder": [
591,
604,
605,
606,
609,
634,
635,
636,
637,
674,
706,
707,
708,
709,
710,
711,
712,
713,
715,
717,
730,
733,
734,
735,
781,
782,
783,
1238,
1239,
1240,
1241,
1242,
1243,
1244,
1245,
1290,
1291,
1294,
1316,
1317,
1318,
1401,
1402,
1403,
1404,
1509,
1535,
1545,
1808,
1810,
1811,
1812,
1813,
1814,
1815,
1818,
1819,
1821,
1822,
1823,
1824,
1825,
1826,
1827,
1828,
1829,
1830,
1831,
1832,
1833,
1837,
1840,
1841,
1842,
1843,
1844,
1845,
1846,
1847,
1848,
1849,
1850,
1851,
1852,
1853,
1854,
1855,
1856,
1857,
1858,
1859,
1861,
1862,
1863,
1864,
1872,
1873,
1880,
1881,
1884,
1885,
1886,
1887,
1890,
1891,
1893,
1894,
1895,
1896,
1897,
1898,
1899,
2879,
2880,
2881,
2886,
2887,
2888,
2889,
2890,
2891,
2892,
2893,
2894,
2903,
2938,
2939,
2940,
2941,
2942,
2943,
2944,
2945,
2946,
2947,
2948,
2949,
2965,
2967,
2969,
2999,
3000,
3001,
3002,
3003,
3004,
3005,
3008,
3009,
3010,
3011
],
"rightShoulder": [
4077,
4091,
4092,
4094,
4095,
4122,
4123,
4124,
4125,
4162,
4194,
4195,
4196,
4197,
4198,
4199,
4200,
4201,
4203,
4207,
4218,
4219,
4222,
4223,
4269,
4270,
4271,
4721,
4722,
4723,
4724,
4725,
4726,
4727,
4728,
4773,
4774,
4778,
4796,
4797,
4798,
4874,
4875,
4876,
4877,
4982,
5006,
5014,
5269,
5271,
5272,
5273,
5274,
5275,
5276,
5279,
5281,
5282,
5283,
5284,
5285,
5286,
5287,
5288,
5289,
5290,
5291,
5292,
5293,
5294,
5298,
5301,
5302,
5303,
5304,
5305,
5306,
5307,
5308,
5309,
5310,
5311,
5312,
5313,
5314,
5315,
5316,
5317,
5318,
5319,
5320,
5322,
5323,
5324,
5325,
5333,
5334,
5341,
5342,
5345,
5346,
5347,
5348,
5351,
5352,
5354,
5355,
5356,
5357,
5358,
5359,
5360,
6338,
6339,
6340,
6345,
6346,
6347,
6348,
6349,
6350,
6351,
6352,
6353,
6362,
6397,
6398,
6399,
6400,
6401,
6402,
6403,
6404,
6405,
6406,
6407,
6408,
6424,
6425,
6428,
6458,
6459,
6460,
6461,
6462,
6463,
6464,
6467,
6468,
6469,
6470
],
"rightFoot": [
6727,
6728,
6729,
6730,
6731,
6732,
6733,
6734,
6735,
6736,
6737,
6738,
6739,
6740,
6741,
6742,
6743,
6744,
6745,
6746,
6747,
6748,
6749,
6750,
6751,
6752,
6753,
6754,
6755,
6756,
6757,
6758,
6759,
6760,
6761,
6762,
6763,
6764,
6765,
6766,
6767,
6768,
6769,
6770,
6771,
6772,
6773,
6774,
6775,
6776,
6777,
6778,
6779,
6780,
6781,
6782,
6783,
6784,
6785,
6786,
6787,
6788,
6789,
6790,
6791,
6792,
6793,
6794,
6795,
6796,
6797,
6798,
6799,
6800,
6801,
6802,
6803,
6804,
6805,
6806,
6807,
6808,
6809,
6810,
6811,
6812,
6813,
6814,
6815,
6816,
6817,
6818,
6819,
6820,
6821,
6822,
6823,
6824,
6825,
6826,
6827,
6828,
6829,
6830,
6831,
6832,
6833,
6834,
6835,
6836,
6837,
6838,
6839,
6840,
6841,
6842,
6843,
6844,
6845,
6846,
6847,
6848,
6849,
6850,
6851,
6852,
6853,
6854,
6855,
6856,
6857,
6858,
6859,
6860,
6861,
6862,
6863,
6864,
6865,
6866,
6867,
6868,
6869
],
"head": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
60,
61,
62,
63,
64,
65,
66,
67,
68,
69,
70,
71,
72,
73,
74,
75,
76,
77,
78,
79,
80,
81,
82,
83,
84,
85,
86,
87,
88,
89,
90,
91,
92,
93,
94,
95,
96,
97,
98,
99,
100,
101,
102,
103,
104,
105,
106,
107,
108,
109,
110,
111,
112,
113,
114,
115,
116,
117,
118,
119,
120,
121,
122,
123,
124,
125,
126,
127,
128,
129,
130,
131,
132,
133,
134,
135,
136,
137,
138,
139,
140,
141,
142,
143,
144,
145,
146,
147,
148,
149,
154,
155,
156,
157,
158,
159,
160,
161,
162,
163,
164,
165,
166,
167,
168,
169,
170,
171,
172,
173,
176,
177,
178,
179,
180,
181,
182,
183,
184,
185,
186,
187,
188,
189,
190,
191,
192,
193,
194,
195,
196,
197,
198,
199,
200,
201,
202,
203,
204,
205,
220,
221,
225,
226,
227,
228,
229,
230,
231,
232,
233,
234,
235,
236,
237,
238,
239,
240,
241,
242,
243,
244,
245,
246,
247,
248,
249,
250,
251,
252,
253,
254,
255,
258,
259,
260,
261,
262,
263,
264,
265,
266,
267,
268,
269,
270,
271,
272,
273,
274,
275,
276,
277,
278,
279,
280,
281,
282,
283,
286,
287,
288,
289,
290,
291,
292,
293,
294,
295,
303,
304,
306,
307,
310,
311,
312,
313,
314,
315,
316,
317,
318,
319,
320,
321,
322,
323,
324,
325,
326,
327,
328,
329,
330,
331,
332,
335,
336,
337,
338,
339,
340,
341,
342,
343,
344,
345,
346,
347,
348,
349,
350,
351,
352,
353,
354,
355,
356,
357,
358,
359,
360,
361,
362,
363,
364,
365,
366,
367,
368,
369,
370,
371,
372,
373,
374,
375,
376,
377,
378,
379,
380,
381,
382,
383,
384,
385,
386,
387,
388,
389,
390,
391,
392,
393,
394,
395,
396,
397,
398,
399,
400,
401,
402,
403,
404,
405,
406,
407,
408,
409,
410,
411,
412,
413,
414,
415,
416,
417,
418,
419,
420,
421,
422,
427,
428,
429,
430,
431,
432,
433,
434,
435,
436,
437,
438,
439,
442,
443,
444,
445,
446,
447,
448,
449,
450,
454,
455,
456,
457,
458,
459,
461,
462,
463,
464,
465,
466,
467,
468,
469,
470,
471,
472,
473,
474,
475,
476,
477,
478,
479,
480,
481,
482,
483,
484,
485,
486,
487,
488,
489,
490,
491,
492,
493,
494,
495,
496,
497,
498,
499,
500,
501,
502,
503,
504,
505,
506,
507,
508,
509,
510,
511,
512,
513,
514,
515,
516,
517,
518,
519,
520,
521,
522,
523,
524,
525,
526,
527,
528,
529,
530,
531,
532,
533,
534,
535,
536,
537,
538,
539,
540,
541,
542,
543,
544,
545,
546,
547,
548,
549,
550,
551,
552,
553,
554,
555,
556,
557,
558,
559,
560,
561,
562,
563,
564,
565,
566,
567,
568,
569,
574,
575,
576,
577,
578,
579,
580,
581,
582,
583,
1764,
1765,
1766,
1770,
1771,
1772,
1773,
1774,
1775,
1776,
1777,
1778,
1905,
1906,
1907,
1908,
2779,
2780,
2781,
2782,
2783,
2784,
2785,
2786,
2787,
2788,
2789,
2790,
2791,
2792,
2793,
2794,
2795,
2796,
2797,
2798,
2799,
2800,
2801,
2802,
2803,
2804,
2805,
2806,
2807,
2808,
2809,
2810,
2811,
2814,
2815,
2816,
2817,
2818,
3045,
3046,
3047,
3048,
3051,
3052,
3053,
3054,
3055,
3056,
3058,
3069,
3070,
3071,
3072,
3161,
3162,
3163,
3165,
3166,
3167,
3485,
3486,
3487,
3488,
3489,
3490,
3491,
3492,
3493,
3494,
3499,
3512,
3513,
3514,
3515,
3516,
3517,
3518,
3519,
3520,
3521,
3522,
3523,
3524,
3525,
3526,
3527,
3528,
3529,
3530,
3531,
3532,
3533,
3534,
3535,
3536,
3537,
3538,
3539,
3540,
3541,
3542,
3543,
3544,
3545,
3546,
3547,
3548,
3549,
3550,
3551,
3552,
3553,
3554,
3555,
3556,
3557,
3558,
3559,
3560,
3561,
3562,
3563,
3564,
3565,
3566,
3567,
3568,
3569,
3570,
3571,
3572,
3573,
3574,
3575,
3576,
3577,
3578,
3579,
3580,
3581,
3582,
3583,
3584,
3585,
3586,
3587,
3588,
3589,
3590,
3591,
3592,
3593,
3594,
3595,
3596,
3597,
3598,
3599,
3600,
3601,
3602,
3603,
3604,
3605,
3606,
3607,
3608,
3609,
3610,
3611,
3612,
3613,
3614,
3615,
3616,
3617,
3618,
3619,
3620,
3621,
3622,
3623,
3624,
3625,
3626,
3627,
3628,
3629,
3630,
3631,
3632,
3633,
3634,
3635,
3636,
3637,
3638,
3639,
3640,
3641,
3642,
3643,
3644,
3645,
3646,
3647,
3648,
3649,
3650,
3651,
3652,
3653,
3654,
3655,
3656,
3657,
3658,
3659,
3660,
3661,
3666,
3667,
3668,
3669,
3670,
3671,
3672,
3673,
3674,
3675,
3676,
3677,
3678,
3679,
3680,
3681,
3682,
3683,
3684,
3685,
3688,
3689,
3690,
3691,
3692,
3693,
3694,
3695,
3696,
3697,
3698,
3699,
3700,
3701,
3702,
3703,
3704,
3705,
3706,
3707,
3708,
3709,
3710,
3711,
3712,
3713,
3714,
3715,
3716,
3717,
3732,
3733,
3737,
3738,
3739,
3740,
3741,
3742,
3743,
3744,
3745,
3746,
3747,
3748,
3749,
3750,
3751,
3752,
3753,
3754,
3755,
3756,
3757,
3758,
3759,
3760,
3761,
3762,
3763,
3764,
3765,
3766,
3767,
3770,
3771,
3772,
3773,
3774,
3775,
3776,
3777,
3778,
3779,
3780,
3781,
3782,
3783,
3784,
3785,
3786,
3787,
3788,
3789,
3790,
3791,
3792,
3793,
3794,
3795,
3798,
3799,
3800,
3801,
3802,
3803,
3804,
3805,
3806,
3807,
3815,
3816,
3819,
3820,
3821,
3822,
3823,
3824,
3825,
3826,
3827,
3828,
3829,
3830,
3831,
3832,
3833,
3834,
3835,
3836,
3837,
3838,
3841,
3842,
3843,
3844,
3845,
3846,
3847,
3848,
3849,
3850,
3851,
3852,
3853,
3854,
3855,
3856,
3857,
3858,
3859,
3860,
3861,
3862,
3863,
3864,
3865,
3866,
3867,
3868,
3869,
3870,
3871,
3872,
3873,
3874,
3875,
3876,
3877,
3878,
3879,
3880,
3881,
3882,
3883,
3884,
3885,
3886,
3887,
3888,
3889,
3890,
3891,
3892,
3893,
3894,
3895,
3896,
3897,
3898,
3899,
3900,
3901,
3902,
3903,
3904,
3905,
3906,
3907,
3908,
3909,
3910,
3911,
3912,
3913,
3914,
3915,
3916,
3917,
3922,
3923,
3924,
3925,
3926,
3927,
3928,
3929,
3930,
3931,
3932,
3933,
3936,
3937,
3938,
3939,
3940,
3941,
3945,
3946,
3947,
3948,
3949,
3950,
3951,
3952,
3953,
3954,
3955,
3956,
3957,
3958,
3959,
3960,
3961,
3962,
3963,
3964,
3965,
3966,
3967,
3968,
3969,
3970,
3971,
3972,
3973,
3974,
3975,
3976,
3977,
3978,
3979,
3980,
3981,
3982,
3983,
3984,
3985,
3986,
3987,
3988,
3989,
3990,
3991,
3992,
3993,
3994,
3995,
3996,
3997,
3998,
3999,
4000,
4001,
4002,
4003,
4004,
4005,
4006,
4007,
4008,
4009,
4010,
4011,
4012,
4013,
4014,
4015,
4016,
4017,
4018,
4019,
4020,
4021,
4022,
4023,
4024,
4025,
4026,
4027,
4028,
4029,
4030,
4031,
4032,
4033,
4034,
4035,
4036,
4037,
4038,
4039,
4040,
4041,
4042,
4043,
4044,
4045,
4046,
4047,
4048,
4049,
4050,
4051,
4052,
4053,
4054,
4055,
4056,
4057,
4062,
4063,
4064,
4065,
4066,
4067,
4068,
4069,
4070,
4071,
5231,
5232,
5233,
5235,
5236,
5237,
5238,
5239,
5240,
5241,
5242,
5243,
5366,
5367,
5368,
5369,
6240,
6241,
6242,
6243,
6244,
6245,
6246,
6247,
6248,
6249,
6250,
6251,
6252,
6253,
6254,
6255,
6256,
6257,
6258,
6259,
6260,
6261,
6262,
6263,
6264,
6265,
6266,
6267,
6268,
6269,
6270,
6271,
6272,
6275,
6276,
6277,
6278,
6279,
6492,
6493,
6494,
6495,
6880,
6881,
6882,
6883,
6884,
6885,
6886,
6887,
6888,
6889
],
"rightArm": [
4114,
4115,
4116,
4117,
4122,
4125,
4168,
4171,
4204,
4205,
4206,
4207,
4257,
4258,
4259,
4260,
4261,
4262,
4263,
4264,
4265,
4266,
4267,
4268,
4272,
4273,
4274,
4275,
4276,
4277,
4278,
4279,
4280,
4281,
4714,
4715,
4716,
4717,
4741,
4742,
4743,
4744,
4756,
4763,
4764,
4790,
4791,
4794,
4795,
4816,
4817,
4818,
4819,
4830,
4831,
4832,
4833,
4849,
4850,
4851,
4852,
4853,
4854,
4855,
4856,
4857,
4858,
4859,
4860,
4861,
4862,
4863,
4864,
4865,
4866,
4867,
4868,
4869,
4870,
4871,
4872,
4873,
4876,
4877,
4878,
4879,
4880,
4881,
4882,
4883,
4884,
4885,
4886,
4887,
4888,
4889,
4901,
4902,
4903,
4904,
4905,
4906,
4911,
4912,
4913,
4914,
4915,
4916,
4917,
4918,
4974,
4977,
4978,
4979,
4980,
4981,
4982,
5009,
5010,
5011,
5012,
5014,
5088,
5089,
5090,
5091,
5100,
5101,
5102,
5103,
5104,
5105,
5106,
5107,
5108,
5109,
5110,
5111,
5114,
5115,
5116,
5117,
5118,
5119,
5120,
5121,
5122,
5123,
5124,
5125,
5128,
5129,
5130,
5131,
5134,
5135,
5136,
5137,
5138,
5139,
5140,
5141,
5142,
5143,
5144,
5145,
5146,
5147,
5148,
5149,
5150,
5151,
5152,
5153,
5165,
5166,
5167,
5172,
5173,
5174,
5175,
5176,
5177,
5178,
5179,
5180,
5181,
5182,
5183,
5184,
5185,
5186,
5187,
5188,
5189,
5194,
5200,
5201,
5202,
5203,
5204,
5206,
5208,
5209,
5214,
5215,
5216,
5217,
5218,
5220,
5229,
5292,
5293,
5303,
5306,
5309,
5311,
5314,
5315,
5318,
5319,
5321,
5326,
5327,
5328,
5330,
5331,
5332,
5335,
5336,
5337,
5338,
5339,
5343,
5344,
5349,
5350,
5353,
5361,
5362,
5363,
5364,
5365,
5370,
6280,
6281,
6282,
6283,
6354,
6355,
6356,
6357,
6358,
6359,
6360,
6361,
6362,
6404,
6405,
6433,
6434,
6435,
6436,
6437,
6438,
6439,
6440,
6441,
6442,
6443,
6444,
6445,
6446,
6447,
6448,
6449,
6450,
6451,
6452,
6453,
6454,
6455,
6461,
6471
],
"leftHandIndex1": [
2027,
2028,
2029,
2030,
2037,
2038,
2039,
2040,
2057,
2067,
2068,
2123,
2124,
2125,
2126,
2127,
2128,
2129,
2130,
2132,
2145,
2146,
2152,
2153,
2154,
2156,
2157,
2158,
2159,
2160,
2161,
2162,
2163,
2164,
2165,
2166,
2167,
2168,
2169,
2177,
2178,
2179,
2181,
2186,
2187,
2190,
2191,
2204,
2205,
2215,
2216,
2217,
2218,
2219,
2220,
2232,
2233,
2245,
2246,
2247,
2258,
2259,
2261,
2262,
2263,
2269,
2270,
2272,
2273,
2274,
2276,
2277,
2280,
2281,
2282,
2283,
2291,
2292,
2293,
2294,
2295,
2296,
2297,
2298,
2299,
2300,
2301,
2302,
2303,
2304,
2305,
2306,
2307,
2308,
2309,
2310,
2311,
2312,
2313,
2314,
2315,
2316,
2317,
2318,
2319,
2320,
2321,
2322,
2323,
2324,
2325,
2326,
2327,
2328,
2329,
2330,
2331,
2332,
2333,
2334,
2335,
2336,
2337,
2338,
2339,
2340,
2341,
2342,
2343,
2344,
2345,
2346,
2347,
2348,
2349,
2350,
2351,
2352,
2353,
2354,
2355,
2356,
2357,
2358,
2359,
2360,
2361,
2362,
2363,
2364,
2365,
2366,
2367,
2368,
2369,
2370,
2371,
2372,
2373,
2374,
2375,
2376,
2377,
2378,
2379,
2380,
2381,
2382,
2383,
2384,
2385,
2386,
2387,
2388,
2389,
2390,
2391,
2392,
2393,
2394,
2395,
2396,
2397,
2398,
2399,
2400,
2401,
2402,
2403,
2404,
2405,
2406,
2407,
2408,
2409,
2410,
2411,
2412,
2413,
2414,
2415,
2416,
2417,
2418,
2419,
2420,
2421,
2422,
2423,
2424,
2425,
2426,
2427,
2428,
2429,
2430,
2431,
2432,
2433,
2434,
2435,
2436,
2437,
2438,
2439,
2440,
2441,
2442,
2443,
2444,
2445,
2446,
2447,
2448,
2449,
2450,
2451,
2452,
2453,
2454,
2455,
2456,
2457,
2458,
2459,
2460,
2461,
2462,
2463,
2464,
2465,
2466,
2467,
2468,
2469,
2470,
2471,
2472,
2473,
2474,
2475,
2476,
2477,
2478,
2479,
2480,
2481,
2482,
2483,
2484,
2485,
2486,
2487,
2488,
2489,
2490,
2491,
2492,
2493,
2494,
2495,
2496,
2497,
2498,
2499,
2500,
2501,
2502,
2503,
2504,
2505,
2506,
2507,
2508,
2509,
2510,
2511,
2512,
2513,
2514,
2515,
2516,
2517,
2518,
2519,
2520,
2521,
2522,
2523,
2524,
2525,
2526,
2527,
2528,
2529,
2530,
2531,
2532,
2533,
2534,
2535,
2536,
2537,
2538,
2539,
2540,
2541,
2542,
2543,
2544,
2545,
2546,
2547,
2548,
2549,
2550,
2551,
2552,
2553,
2554,
2555,
2556,
2557,
2558,
2559,
2560,
2561,
2562,
2563,
2564,
2565,
2566,
2567,
2568,
2569,
2570,
2571,
2572,
2573,
2574,
2575,
2576,
2577,
2578,
2579,
2580,
2581,
2582,
2583,
2584,
2585,
2586,
2587,
2588,
2589,
2590,
2591,
2592,
2593,
2594,
2596,
2597,
2599,
2600,
2601,
2602,
2603,
2604,
2606,
2607,
2609,
2610,
2611,
2612,
2613,
2614,
2615,
2616,
2617,
2618,
2619,
2620,
2621,
2622,
2623,
2624,
2625,
2626,
2627,
2628,
2629,
2630,
2631,
2632,
2633,
2634,
2635,
2636,
2637,
2638,
2639,
2640,
2641,
2642,
2643,
2644,
2645,
2646,
2647,
2648,
2649,
2650,
2651,
2652,
2653,
2654,
2655,
2656,
2657,
2658,
2659,
2660,
2661,
2662,
2663,
2664,
2665,
2666,
2667,
2668,
2669,
2670,
2671,
2672,
2673,
2674,
2675,
2676,
2677,
2678,
2679,
2680,
2681,
2682,
2683,
2684,
2685,
2686,
2687,
2688,
2689,
2690,
2691,
2692,
2693,
2694,
2695,
2696
],
"rightLeg": [
4481,
4482,
4485,
4486,
4491,
4492,
4493,
4495,
4498,
4500,
4501,
4505,
4506,
4529,
4532,
4533,
4534,
4535,
4536,
4537,
4538,
4539,
4540,
4541,
4542,
4543,
4544,
4545,
4546,
4547,
4548,
4549,
4550,
4551,
4552,
4553,
4554,
4555,
4556,
4557,
4558,
4559,
4560,
4561,
4562,
4563,
4564,
4565,
4566,
4567,
4568,
4569,
4570,
4571,
4572,
4573,
4574,
4575,
4576,
4577,
4578,
4579,
4580,
4581,
4582,
4583,
4584,
4585,
4586,
4587,
4588,
4589,
4590,
4591,
4592,
4593,
4594,
4595,
4596,
4597,
4598,
4599,
4600,
4601,
4602,
4603,
4604,
4605,
4606,
4607,
4608,
4609,
4610,
4611,
4612,
4613,
4614,
4615,
4616,
4617,
4618,
4619,
4620,
4621,
4622,
4634,
4635,
4636,
4637,
4638,
4639,
4640,
4641,
4642,
4643,
4644,
4661,
4662,
4663,
4664,
4665,
4666,
4667,
4668,
4669,
4842,
4843,
4844,
4845,
4846,
4847,
4848,
4937,
4938,
4939,
4940,
4941,
4942,
4943,
4944,
4945,
4946,
4947,
4993,
4994,
4995,
4996,
4997,
4998,
4999,
5000,
5001,
5002,
5003,
6574,
6575,
6576,
6577,
6578,
6579,
6580,
6581,
6582,
6583,
6584,
6585,
6586,
6587,
6588,
6589,
6590,
6591,
6592,
6593,
6594,
6595,
6596,
6597,
6598,
6599,
6600,
6601,
6602,
6603,
6604,
6605,
6606,
6607,
6608,
6609,
6610,
6719,
6720,
6721,
6722,
6723,
6724,
6725,
6726,
6727,
6728,
6729,
6730,
6731,
6732,
6733,
6734,
6735,
6832,
6833,
6834,
6835,
6836,
6869,
6870,
6871,
6872
],
"rightHandIndex1": [
5488,
5489,
5490,
5491,
5498,
5499,
5500,
5501,
5518,
5528,
5529,
5584,
5585,
5586,
5587,
5588,
5589,
5590,
5591,
5592,
5606,
5607,
5613,
5615,
5616,
5617,
5618,
5619,
5620,
5621,
5622,
5623,
5624,
5625,
5626,
5627,
5628,
5629,
5630,
5638,
5639,
5640,
5642,
5647,
5648,
5650,
5651,
5665,
5666,
5676,
5677,
5678,
5679,
5680,
5681,
5693,
5694,
5706,
5707,
5708,
5719,
5721,
5722,
5723,
5724,
5730,
5731,
5733,
5734,
5735,
5737,
5738,
5741,
5742,
5743,
5744,
5752,
5753,
5754,
5755,
5756,
5757,
5758,
5759,
5760,
5761,
5762,
5763,
5764,
5765,
5766,
5767,
5768,
5769,
5770,
5771,
5772,
5773,
5774,
5775,
5776,
5777,
5778,
5779,
5780,
5781,
5782,
5783,
5784,
5785,
5786,
5787,
5788,
5789,
5790,
5791,
5792,
5793,
5794,
5795,
5796,
5797,
5798,
5799,
5800,
5801,
5802,
5803,
5804,
5805,
5806,
5807,
5808,
5809,
5810,
5811,
5812,
5813,
5814,
5815,
5816,
5817,
5818,
5819,
5820,
5821,
5822,
5823,
5824,
5825,
5826,
5827,
5828,
5829,
5830,
5831,
5832,
5833,
5834,
5835,
5836,
5837,
5838,
5839,
5840,
5841,
5842,
5843,
5844,
5845,
5846,
5847,
5848,
5849,
5850,
5851,
5852,
5853,
5854,
5855,
5856,
5857,
5858,
5859,
5860,
5861,
5862,
5863,
5864,
5865,
5866,
5867,
5868,
5869,
5870,
5871,
5872,
5873,
5874,
5875,
5876,
5877,
5878,
5879,
5880,
5881,
5882,
5883,
5884,
5885,
5886,
5887,
5888,
5889,
5890,
5891,
5892,
5893,
5894,
5895,
5896,
5897,
5898,
5899,
5900,
5901,
5902,
5903,
5904,
5905,
5906,
5907,
5908,
5909,
5910,
5911,
5912,
5913,
5914,
5915,
5916,
5917,
5918,
5919,
5920,
5921,
5922,
5923,
5924,
5925,
5926,
5927,
5928,
5929,
5930,
5931,
5932,
5933,
5934,
5935,
5936,
5937,
5938,
5939,
5940,
5941,
5942,
5943,
5944,
5945,
5946,
5947,
5948,
5949,
5950,
5951,
5952,
5953,
5954,
5955,
5956,
5957,
5958,
5959,
5960,
5961,
5962,
5963,
5964,
5965,
5966,
5967,
5968,
5969,
5970,
5971,
5972,
5973,
5974,
5975,
5976,
5977,
5978,
5979,
5980,
5981,
5982,
5983,
5984,
5985,
5986,
5987,
5988,
5989,
5990,
5991,
5992,
5993,
5994,
5995,
5996,
5997,
5998,
5999,
6000,
6001,
6002,
6003,
6004,
6005,
6006,
6007,
6008,
6009,
6010,
6011,
6012,
6013,
6014,
6015,
6016,
6017,
6018,
6019,
6020,
6021,
6022,
6023,
6024,
6025,
6026,
6027,
6028,
6029,
6030,
6031,
6032,
6033,
6034,
6035,
6036,
6037,
6038,
6039,
6040,
6041,
6042,
6043,
6044,
6045,
6046,
6047,
6048,
6049,
6050,
6051,
6052,
6053,
6054,
6055,
6058,
6059,
6060,
6061,
6062,
6063,
6064,
6065,
6068,
6069,
6070,
6071,
6072,
6073,
6074,
6075,
6076,
6077,
6078,
6079,
6080,
6081,
6082,
6083,
6084,
6085,
6086,
6087,
6088,
6089,
6090,
6091,
6092,
6093,
6094,
6095,
6096,
6097,
6098,
6099,
6100,
6101,
6102,
6103,
6104,
6105,
6106,
6107,
6108,
6109,
6110,
6111,
6112,
6113,
6114,
6115,
6116,
6117,
6118,
6119,
6120,
6121,
6122,
6123,
6124,
6125,
6126,
6127,
6128,
6129,
6130,
6131,
6132,
6133,
6134,
6135,
6136,
6137,
6138,
6139,
6140,
6141,
6142,
6143,
6144,
6145,
6146,
6147,
6148,
6149,
6150,
6151,
6152,
6153,
6154,
6155,
6156,
6157
],
"leftForeArm": [
1546,
1547,
1548,
1549,
1550,
1551,
1552,
1553,
1554,
1555,
1556,
1557,
1558,
1559,
1560,
1561,
1562,
1563,
1564,
1565,
1566,
1567,
1568,
1569,
1570,
1571,
1572,
1573,
1574,
1575,
1576,
1577,
1578,
1579,
1580,
1581,
1582,
1583,
1584,
1585,
1586,
1587,
1588,
1589,
1590,
1591,
1592,
1593,
1594,
1595,
1596,
1597,
1598,
1599,
1600,
1601,
1602,
1603,
1604,
1605,
1606,
1607,
1608,
1609,
1610,
1611,
1612,
1613,
1614,
1615,
1616,
1617,
1618,
1620,
1621,
1623,
1624,
1625,
1626,
1627,
1628,
1629,
1630,
1643,
1644,
1646,
1647,
1650,
1651,
1654,
1655,
1657,
1658,
1659,
1660,
1661,
1662,
1663,
1664,
1665,
1666,
1685,
1686,
1687,
1688,
1689,
1690,
1691,
1692,
1693,
1694,
1695,
1699,
1700,
1701,
1702,
1721,
1722,
1723,
1724,
1725,
1726,
1727,
1728,
1729,
1730,
1732,
1736,
1738,
1741,
1742,
1743,
1744,
1750,
1752,
1900,
1909,
1910,
1911,
1912,
1913,
1914,
1915,
1916,
1917,
1918,
1919,
1920,
1921,
1922,
1923,
1924,
1925,
1926,
1927,
1928,
1929,
1930,
1931,
1932,
1933,
1934,
1935,
1936,
1937,
1938,
1939,
1940,
1941,
1942,
1943,
1944,
1945,
1946,
1947,
1948,
1949,
1950,
1951,
1952,
1953,
1954,
1955,
1956,
1957,
1958,
1959,
1960,
1961,
1962,
1963,
1964,
1965,
1966,
1967,
1968,
1969,
1970,
1971,
1972,
1973,
1974,
1975,
1976,
1977,
1978,
1979,
1980,
2019,
2059,
2060,
2073,
2089,
2098,
2099,
2100,
2101,
2102,
2103,
2104,
2105,
2106,
2107,
2108,
2109,
2110,
2111,
2112,
2147,
2148,
2206,
2207,
2208,
2209,
2228,
2230,
2234,
2235,
2241,
2242,
2243,
2244,
2279,
2286,
2873,
2874
],
"rightForeArm": [
5015,
5016,
5017,
5018,
5019,
5020,
5021,
5022,
5023,
5024,
5025,
5026,
5027,
5028,
5029,
5030,
5031,
5032,
5033,
5034,
5035,
5036,
5037,
5038,
5039,
5040,
5041,
5042,
5043,
5044,
5045,
5046,
5047,
5048,
5049,
5050,
5051,
5052,
5053,
5054,
5055,
5056,
5057,
5058,
5059,
5060,
5061,
5062,
5063,
5064,
5065,
5066,
5067,
5068,
5069,
5070,
5071,
5072,
5073,
5074,
5075,
5076,
5077,
5078,
5079,
5080,
5081,
5082,
5083,
5084,
5085,
5086,
5087,
5090,
5091,
5092,
5093,
5094,
5095,
5096,
5097,
5098,
5099,
5112,
5113,
5116,
5117,
5120,
5121,
5124,
5125,
5126,
5127,
5128,
5129,
5130,
5131,
5132,
5133,
5134,
5135,
5154,
5155,
5156,
5157,
5158,
5159,
5160,
5161,
5162,
5163,
5164,
5168,
5169,
5170,
5171,
5190,
5191,
5192,
5193,
5194,
5195,
5196,
5197,
5198,
5199,
5202,
5205,
5207,
5210,
5211,
5212,
5213,
5219,
5221,
5361,
5370,
5371,
5372,
5373,
5374,
5375,
5376,
5377,
5378,
5379,
5380,
5381,
5382,
5383,
5384,
5385,
5386,
5387,
5388,
5389,
5390,
5391,
5392,
5393,
5394,
5395,
5396,
5397,
5398,
5399,
5400,
5401,
5402,
5403,
5404,
5405,
5406,
5407,
5408,
5409,
5410,
5411,
5412,
5413,
5414,
5415,
5416,
5417,
5418,
5419,
5420,
5421,
5422,
5423,
5424,
5425,
5426,
5427,
5428,
5429,
5430,
5431,
5432,
5433,
5434,
5435,
5436,
5437,
5438,
5439,
5440,
5441,
5480,
5520,
5521,
5534,
5550,
5559,
5560,
5561,
5562,
5563,
5564,
5565,
5566,
5567,
5568,
5569,
5570,
5571,
5572,
5573,
5608,
5609,
5667,
5668,
5669,
5670,
5689,
5691,
5695,
5696,
5702,
5703,
5704,
5705,
5740,
5747,
6334,
6335
],
"neck": [
148,
150,
151,
152,
153,
172,
174,
175,
201,
202,
204,
205,
206,
207,
208,
209,
210,
211,
212,
213,
214,
215,
216,
217,
218,
219,
222,
223,
224,
225,
256,
257,
284,
285,
295,
296,
297,
298,
299,
300,
301,
302,
303,
304,
305,
306,
307,
308,
309,
333,
334,
423,
424,
425,
426,
440,
441,
451,
452,
453,
460,
461,
571,
572,
824,
825,
826,
827,
828,
829,
1279,
1280,
1312,
1313,
1319,
1320,
1331,
3049,
3050,
3057,
3058,
3059,
3068,
3164,
3661,
3662,
3663,
3664,
3665,
3685,
3686,
3687,
3714,
3715,
3716,
3717,
3718,
3719,
3720,
3721,
3722,
3723,
3724,
3725,
3726,
3727,
3728,
3729,
3730,
3731,
3734,
3735,
3736,
3737,
3768,
3769,
3796,
3797,
3807,
3808,
3809,
3810,
3811,
3812,
3813,
3814,
3815,
3816,
3817,
3818,
3819,
3839,
3840,
3918,
3919,
3920,
3921,
3934,
3935,
3942,
3943,
3944,
3950,
4060,
4061,
4312,
4313,
4314,
4315,
4761,
4762,
4792,
4793,
4799,
4800,
4807
],
"rightToeBase": [
6611,
6612,
6613,
6614,
6615,
6616,
6617,
6618,
6619,
6620,
6621,
6622,
6623,
6624,
6625,
6626,
6627,
6628,
6629,
6630,
6631,
6632,
6633,
6634,
6635,
6636,
6637,
6638,
6639,
6640,
6641,
6642,
6643,
6644,
6645,
6646,
6647,
6648,
6649,
6650,
6651,
6652,
6653,
6654,
6655,
6656,
6657,
6658,
6659,
6660,
6661,
6662,
6663,
6664,
6665,
6666,
6667,
6668,
6669,
6670,
6671,
6672,
6673,
6674,
6675,
6676,
6677,
6678,
6679,
6680,
6681,
6682,
6683,
6684,
6685,
6686,
6687,
6688,
6689,
6690,
6691,
6692,
6693,
6694,
6695,
6696,
6697,
6698,
6699,
6700,
6701,
6702,
6703,
6704,
6705,
6706,
6707,
6708,
6709,
6710,
6711,
6712,
6713,
6714,
6715,
6716,
6717,
6718,
6736,
6739,
6741,
6743,
6745,
6747,
6749,
6750,
6752,
6754,
6757,
6758,
6760,
6762
],
"spine": [
616,
617,
630,
631,
632,
633,
654,
655,
656,
657,
662,
663,
664,
665,
720,
721,
765,
766,
767,
768,
796,
797,
798,
799,
889,
890,
916,
917,
918,
919,
921,
922,
923,
924,
925,
926,
1188,
1189,
1211,
1212,
1248,
1249,
1250,
1251,
1264,
1265,
1266,
1267,
1323,
1324,
1325,
1326,
1327,
1328,
1332,
1333,
1334,
1335,
1336,
1344,
1345,
1481,
1482,
1483,
1484,
1485,
1486,
1487,
1488,
1489,
1490,
1491,
1492,
1493,
1494,
1495,
1496,
1767,
2823,
2824,
2825,
2826,
2827,
2828,
2829,
2830,
2831,
2832,
2833,
2834,
2835,
2836,
2837,
2838,
2839,
2840,
2841,
2842,
2843,
2844,
2845,
2847,
2848,
2851,
3016,
3017,
3018,
3019,
3020,
3023,
3024,
3124,
3173,
3476,
3477,
3478,
3480,
3500,
3501,
3502,
3504,
3509,
3511,
4103,
4104,
4118,
4119,
4120,
4121,
4142,
4143,
4144,
4145,
4150,
4151,
4152,
4153,
4208,
4209,
4253,
4254,
4255,
4256,
4284,
4285,
4286,
4287,
4375,
4376,
4402,
4403,
4405,
4406,
4407,
4408,
4409,
4410,
4411,
4412,
4674,
4675,
4694,
4695,
4731,
4732,
4733,
4734,
4747,
4748,
4749,
4750,
4803,
4804,
4805,
4806,
4808,
4809,
4810,
4811,
4812,
4820,
4821,
4953,
4954,
4955,
4956,
4957,
4958,
4959,
4960,
4961,
4962,
4963,
4964,
4965,
4966,
4967,
4968,
5234,
6284,
6285,
6286,
6287,
6288,
6289,
6290,
6291,
6292,
6293,
6294,
6295,
6296,
6297,
6298,
6299,
6300,
6301,
6302,
6303,
6304,
6305,
6306,
6308,
6309,
6312,
6472,
6473,
6474,
6545,
6874,
6875,
6876,
6878
],
"leftUpLeg": [
833,
834,
838,
839,
847,
848,
849,
850,
851,
852,
853,
854,
870,
871,
872,
873,
874,
875,
876,
877,
878,
879,
880,
881,
897,
898,
899,
900,
901,
902,
903,
904,
905,
906,
907,
908,
909,
910,
911,
912,
913,
914,
915,
933,
934,
935,
936,
944,
945,
946,
947,
948,
949,
950,
951,
952,
953,
954,
955,
956,
957,
958,
959,
960,
961,
962,
963,
964,
965,
966,
967,
968,
969,
970,
971,
972,
973,
974,
975,
976,
977,
978,
979,
980,
981,
982,
983,
984,
985,
986,
987,
988,
989,
990,
991,
992,
993,
994,
995,
996,
997,
998,
999,
1000,
1001,
1002,
1003,
1004,
1005,
1006,
1007,
1008,
1009,
1010,
1011,
1012,
1013,
1014,
1015,
1016,
1017,
1018,
1019,
1020,
1021,
1022,
1023,
1024,
1025,
1026,
1027,
1028,
1029,
1030,
1031,
1032,
1033,
1034,
1035,
1036,
1037,
1038,
1039,
1040,
1041,
1042,
1043,
1044,
1045,
1046,
1137,
1138,
1139,
1140,
1141,
1142,
1143,
1144,
1145,
1146,
1147,
1148,
1159,
1160,
1161,
1162,
1163,
1164,
1165,
1166,
1167,
1168,
1169,
1170,
1171,
1172,
1173,
1174,
1184,
1185,
1186,
1187,
1221,
1222,
1223,
1224,
1225,
1226,
1227,
1228,
1229,
1230,
1262,
1263,
1274,
1275,
1276,
1277,
1321,
1322,
1354,
1359,
1360,
1361,
1362,
1365,
1366,
1367,
1368,
1451,
1452,
1453,
1455,
1456,
1457,
1458,
1459,
1460,
1461,
1462,
1463,
1475,
1477,
1478,
1479,
1480,
1498,
1499,
1500,
1501,
1511,
1512,
1513,
1514,
1516,
1517,
1518,
1519,
1520,
1521,
1522,
1533,
1534,
3125,
3126,
3127,
3128,
3131,
3132,
3133,
3134,
3135,
3475,
3479
],
"leftHand": [
1981,
1982,
1983,
1984,
1985,
1986,
1987,
1988,
1989,
1990,
1991,
1992,
1993,
1994,
1995,
1996,
1997,
1998,
1999,
2000,
2001,
2002,
2003,
2004,
2005,
2006,
2007,
2008,
2009,
2010,
2011,
2012,
2013,
2014,
2015,
2016,
2017,
2018,
2019,
2020,
2021,
2022,
2023,
2024,
2025,
2026,
2031,
2032,
2033,
2034,
2035,
2036,
2041,
2042,
2043,
2044,
2045,
2046,
2047,
2048,
2049,
2050,
2051,
2052,
2053,
2054,
2055,
2056,
2057,
2058,
2059,
2060,
2061,
2062,
2063,
2064,
2065,
2066,
2069,
2070,
2071,
2072,
2073,
2074,
2075,
2076,
2077,
2078,
2079,
2080,
2081,
2082,
2083,
2084,
2085,
2086,
2087,
2088,
2089,
2090,
2091,
2092,
2093,
2094,
2095,
2096,
2097,
2098,
2099,
2100,
2101,
2107,
2111,
2113,
2114,
2115,
2116,
2117,
2118,
2119,
2120,
2121,
2122,
2127,
2130,
2131,
2132,
2133,
2134,
2135,
2136,
2137,
2138,
2139,
2140,
2141,
2142,
2143,
2144,
2149,
2150,
2151,
2152,
2155,
2160,
2163,
2164,
2170,
2171,
2172,
2173,
2174,
2175,
2176,
2177,
2178,
2179,
2180,
2182,
2183,
2184,
2185,
2188,
2189,
2191,
2192,
2193,
2194,
2195,
2196,
2197,
2198,
2199,
2200,
2201,
2202,
2203,
2207,
2209,
2210,
2211,
2212,
2213,
2214,
2221,
2222,
2223,
2224,
2225,
2226,
2227,
2228,
2229,
2231,
2234,
2236,
2237,
2238,
2239,
2240,
2246,
2247,
2248,
2249,
2250,
2251,
2252,
2253,
2254,
2255,
2256,
2257,
2258,
2259,
2260,
2262,
2263,
2264,
2265,
2266,
2267,
2268,
2269,
2270,
2271,
2274,
2275,
2276,
2277,
2278,
2279,
2284,
2285,
2287,
2288,
2289,
2290,
2293,
2595,
2598,
2605,
2608,
2697,
2698,
2699,
2700,
2701,
2702,
2703,
2704,
2705,
2706,
2707,
2708,
2709,
2710,
2711,
2712,
2713,
2714,
2715,
2716,
2717,
2718,
2719,
2720,
2721,
2722,
2723,
2724,
2725,
2726,
2727,
2728,
2729,
2730,
2731,
2732,
2733,
2734,
2735,
2736,
2737,
2738,
2739,
2740,
2741,
2742,
2743,
2744,
2745,
2746,
2747,
2748,
2749,
2750,
2751,
2752,
2753,
2754,
2755,
2756,
2757,
2758,
2759,
2760,
2761,
2762,
2763,
2764,
2765,
2766,
2767,
2768,
2769,
2770,
2771,
2772,
2773,
2774,
2775,
2776,
2777,
2778
],
"hips": [
631,
632,
654,
657,
662,
665,
676,
677,
678,
679,
705,
720,
796,
799,
800,
801,
802,
807,
808,
809,
810,
815,
816,
822,
823,
830,
831,
832,
833,
834,
835,
836,
837,
838,
839,
840,
841,
842,
843,
844,
845,
846,
855,
856,
857,
858,
859,
860,
861,
862,
863,
864,
865,
866,
867,
868,
869,
871,
878,
881,
882,
883,
884,
885,
886,
887,
888,
889,
890,
912,
915,
916,
917,
918,
919,
920,
932,
937,
938,
939,
1163,
1166,
1203,
1204,
1205,
1206,
1207,
1208,
1209,
1210,
1246,
1247,
1262,
1263,
1276,
1277,
1278,
1321,
1336,
1337,
1338,
1339,
1353,
1354,
1361,
1362,
1363,
1364,
1446,
1447,
1448,
1449,
1450,
1454,
1476,
1497,
1511,
1513,
1514,
1515,
1533,
1534,
1539,
1540,
1768,
1769,
1779,
1780,
1781,
1782,
1783,
1784,
1785,
1786,
1787,
1788,
1789,
1790,
1791,
1792,
1793,
1794,
1795,
1796,
1797,
1798,
1799,
1800,
1801,
1802,
1803,
1804,
1805,
1806,
1807,
2909,
2910,
2911,
2912,
2913,
2914,
2915,
2916,
2917,
2918,
2919,
2920,
2921,
2922,
2923,
2924,
2925,
2926,
2927,
2928,
2929,
2930,
3018,
3019,
3021,
3022,
3080,
3081,
3082,
3083,
3084,
3085,
3086,
3087,
3088,
3089,
3090,
3091,
3092,
3093,
3094,
3095,
3096,
3097,
3098,
3099,
3100,
3101,
3102,
3103,
3104,
3105,
3106,
3107,
3108,
3109,
3110,
3111,
3112,
3113,
3114,
3115,
3116,
3117,
3118,
3119,
3120,
3121,
3122,
3123,
3124,
3128,
3129,
3130,
3136,
3137,
3138,
3139,
3140,
3141,
3142,
3143,
3144,
3145,
3146,
3147,
3148,
3149,
3150,
3151,
3152,
3153,
3154,
3155,
3156,
3157,
3158,
3159,
3160,
3170,
3172,
3481,
3484,
3500,
3502,
3503,
3507,
3510,
4120,
4121,
4142,
4143,
4150,
4151,
4164,
4165,
4166,
4167,
4193,
4208,
4284,
4285,
4288,
4289,
4290,
4295,
4296,
4297,
4298,
4303,
4304,
4310,
4311,
4316,
4317,
4318,
4319,
4320,
4321,
4322,
4323,
4324,
4325,
4326,
4327,
4328,
4329,
4330,
4331,
4332,
4341,
4342,
4343,
4344,
4345,
4346,
4347,
4348,
4349,
4350,
4351,
4352,
4353,
4354,
4355,
4356,
4364,
4365,
4368,
4369,
4370,
4371,
4372,
4373,
4374,
4375,
4376,
4398,
4399,
4402,
4403,
4404,
4405,
4406,
4418,
4423,
4424,
4425,
4649,
4650,
4689,
4690,
4691,
4692,
4693,
4729,
4730,
4745,
4746,
4759,
4760,
4801,
4812,
4813,
4814,
4815,
4829,
4836,
4837,
4919,
4920,
4921,
4922,
4923,
4927,
4969,
4983,
4984,
4986,
5004,
5005,
5244,
5245,
5246,
5247,
5248,
5249,
5250,
5251,
5252,
5253,
5254,
5255,
5256,
5257,
5258,
5259,
5260,
5261,
5262,
5263,
5264,
5265,
5266,
5267,
5268,
6368,
6369,
6370,
6371,
6372,
6373,
6374,
6375,
6376,
6377,
6378,
6379,
6380,
6381,
6382,
6383,
6384,
6385,
6386,
6387,
6388,
6389,
6473,
6474,
6504,
6505,
6506,
6507,
6508,
6509,
6510,
6511,
6512,
6513,
6514,
6515,
6516,
6517,
6518,
6519,
6520,
6521,
6522,
6523,
6524,
6525,
6526,
6527,
6528,
6529,
6530,
6531,
6532,
6533,
6534,
6535,
6536,
6537,
6538,
6539,
6540,
6541,
6542,
6543,
6544,
6545,
6549,
6550,
6551,
6557,
6558,
6559,
6560,
6561,
6562,
6563,
6564,
6565,
6566,
6567,
6568,
6569,
6570,
6571,
6572,
6573
]
}
================================================
FILE: eval/GVHMR/hmr4d/utils/body_model/smplx_lite.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from pytorch3d.transforms import axis_angle_to_matrix, rotation_6d_to_matrix
from smplx.utils import Struct, to_np, to_tensor
from einops import einsum, rearrange
from time import time
from hmr4d import PROJ_ROOT
class SmplxLite(nn.Module):
    """Lightweight SMPL-X forward pass driven only by global orientation and body pose.

    The remaining 33 joints (jaw, eyes and both hands) are filled in from
    ``other_default_pose``: zeros for the first 9 values followed by the left
    and right hand mean poses stored in the model file.
    """

    def __init__(
        self,
        model_path=PROJ_ROOT / "inputs/checkpoints/body_models/smplx",
        gender="neutral",
        num_betas=10,
    ):
        """
        Args:
            model_path: directory containing ``SMPLX_{GENDER}.npz`` or a direct path to the model file.
            gender: used to build the file name when ``model_path`` is a directory.
            num_betas: number of shape coefficients kept from ``shapedirs``.
        """
        super().__init__()

        # Load the model
        model_path = Path(model_path)
        if model_path.is_dir():
            smplx_path = Path(model_path) / f"SMPLX_{gender.upper()}.npz"
        else:
            smplx_path = model_path
        assert smplx_path.exists(), f"SMPL-X model file not found: {smplx_path}"

        model_data = np.load(smplx_path, allow_pickle=True)
        data_struct = Struct(**model_data)
        self.faces = data_struct.f  # (F, 3)

        self.register_smpl_buffers(data_struct, num_betas)
        # self.register_smplh_buffers(data_struct, num_pca_comps, flat_hand_mean)
        # self.register_smplx_buffers(data_struct)
        self.register_fast_skeleton_computing_buffers()

        # default_pose (99,) for torch.cat([global_orient, body_pose, default_pose])
        other_default_pose = torch.cat(
            [
                torch.zeros(9),
                to_tensor(data_struct.hands_meanl).float(),
                to_tensor(data_struct.hands_meanr).float(),
            ]
        )
        self.register_buffer("other_default_pose", other_default_pose, False)

    def register_smpl_buffers(self, data_struct, num_betas):
        """Register the non-persistent buffers shared by all SMPL-family models."""
        # shapedirs, (V, 3, N_betas), V=10475 for SMPLX
        shapedirs = to_tensor(to_np(data_struct.shapedirs[:, :, :num_betas])).float()
        self.register_buffer("shapedirs", shapedirs, False)

        # v_template, (V, 3)
        v_template = to_tensor(to_np(data_struct.v_template)).float()
        self.register_buffer("v_template", v_template, False)

        # J_regressor, (J, V), J=55 for SMPLX
        J_regressor = to_tensor(to_np(data_struct.J_regressor)).float()
        self.register_buffer("J_regressor", J_regressor, False)

        # posedirs, (54*9, V, 3), note that the first global_orient is not included
        posedirs = to_tensor(to_np(data_struct.posedirs)).float()  # (V, 3, 54*9)
        posedirs = rearrange(posedirs, "v c n -> n v c")
        self.register_buffer("posedirs", posedirs, False)

        # lbs_weights, (V, J), J=55
        lbs_weights = to_tensor(to_np(data_struct.weights)).float()
        self.register_buffer("lbs_weights", lbs_weights, False)

        # parents, (J), long
        parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
        parents[0] = -1  # the root has no parent
        self.register_buffer("parents", parents, False)

    def register_smplh_buffers(self, data_struct, num_pca_comps, flat_hand_mean):
        """Register hand PCA components and hand mean poses (currently not called from ``__init__``).

        Args:
            num_pca_comps: number of PCA components kept per hand.
            flat_hand_mean: if True the hand mean pose is zeroed (flat hand), otherwise the
                mean stored in the model file is used. This follows the smplx convention.
        """
        # hand_pca, (N_pca, 45)
        left_hand_components = to_tensor(data_struct.hands_componentsl[:num_pca_comps]).float()
        right_hand_components = to_tensor(data_struct.hands_componentsr[:num_pca_comps]).float()
        self.register_buffer("left_hand_components", left_hand_components, False)
        self.register_buffer("right_hand_components", right_hand_components, False)

        # hand_mean, (45,)
        left_hand_mean = to_tensor(data_struct.hands_meanl).float()
        right_hand_mean = to_tensor(data_struct.hands_meanr).float()
        # A "flat" hand mean is the all-zero pose; the previous condition was inverted.
        if flat_hand_mean:
            left_hand_mean = torch.zeros_like(left_hand_mean)
            right_hand_mean = torch.zeros_like(right_hand_mean)
        self.register_buffer("left_hand_mean", left_hand_mean, False)
        self.register_buffer("right_hand_mean", right_hand_mean, False)

    def register_smplx_buffers(self, data_struct):
        """Register expression blend shapes (currently not called from ``__init__``)."""
        # expr_dirs, (V, 3, N_expr)
        expr_dirs = to_tensor(to_np(data_struct.shapedirs[:, :, 300:310])).float()
        self.register_buffer("expr_dirs", expr_dirs, False)

    def register_fast_skeleton_computing_buffers(self):
        """Pre-multiply the joint regressor so the skeleton is a linear function of betas."""
        # For fast computing of skeleton under beta
        J_template = self.J_regressor @ self.v_template  # (J, 3)
        J_shapedirs = torch.einsum("jv, vcd -> jcd", self.J_regressor, self.shapedirs)  # (J, 3, 10)
        self.register_buffer("J_template", J_template, False)
        self.register_buffer("J_shapedirs", J_shapedirs, False)

    def get_skeleton(self, betas):
        """Return rest-pose joints (*, J, 3) for shape coefficients ``betas`` (*, K)."""
        return self.J_template + einsum(betas, self.J_shapedirs, "... k, j c k -> ... j c")

    def forward(
        self,
        body_pose,
        betas,
        global_orient,
        transl=None,
        rotation_type="aa",
    ):
        """
        Args:
            body_pose: (B, L, 63) for "aa"; (B, L, 126) for "r6d"
            betas: (B, L, 10)
            global_orient: (B, L, 3) for "aa"; (B, L, 6) for "r6d"
            transl: (B, L, 3), optional
            rotation_type: "aa" (axis-angle) or "r6d" (6D rotation)
        Returns:
            vertices: (B, L, V, 3)
        """
        # 1. Convert [global_orient, body_pose, other_default_pose] to rot_mats
        other_default_pose = self.other_default_pose  # (99,)
        if rotation_type == "aa":
            other_default_pose = other_default_pose.expand(*body_pose.shape[:-1], -1)
            full_pose = torch.cat([global_orient, body_pose, other_default_pose], dim=-1)
            rot_mats = axis_angle_to_matrix(full_pose.reshape(*full_pose.shape[:-1], 55, 3))
            del full_pose, other_default_pose
        else:
            assert rotation_type == "r6d"  # useful when doing smplify
            other_default_pose = axis_angle_to_matrix(other_default_pose.view(33, 3))
            part_full_pose = torch.cat([global_orient, body_pose], dim=-1)
            rot_mats = rotation_6d_to_matrix(part_full_pose.view(*part_full_pose.shape[:-1], 22, 6))
            other_default_pose = other_default_pose.expand(*rot_mats.shape[:-3], -1, -1, -1)
            rot_mats = torch.cat([rot_mats, other_default_pose], dim=-3)
            del part_full_pose, other_default_pose

        # 2. Forward Kinematics
        J = self.get_skeleton(betas)  # (*, 55, 3)
        A = batch_rigid_transform_v2(rot_mats, J, self.parents)[1]

        # 3. Canonical v_posed = v_template + shaped_offsets + pose_offsets
        # The root rotation is excluded, so 54 joints contribute to the pose feature.
        pose_feature = rot_mats[..., 1:, :, :] - rot_mats.new([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        pose_feature = pose_feature.view(*pose_feature.shape[:-3], -1)  # (*, 54*3*3)
        v_posed = (
            self.v_template
            + einsum(betas, self.shapedirs, "... k, v c k -> ... v c")
            + einsum(pose_feature, self.posedirs, "... k, k v c -> ... v c")
        )
        del pose_feature, rot_mats

        # 4. Skinning
        T = einsum(self.lbs_weights, A, "v j, ... j c d -> ... v c d")
        verts = einsum(T[..., :3, :3], v_posed, "... v c d, ... v d -> ... v c") + T[..., :3, 3]

        # 5. Translation
        if transl is not None:
            verts = verts + transl[..., None, :]
        return verts
class SmplxLiteCoco17(SmplxLite):
    """Output COCO17 joints (Faster, but cannot output vertices)"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Compose SMPL-X -> SMPL -> COCO17 into one dense (17, V) regressor
        to_smpl = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smplx2smpl_sparse.pt")
        coco17_from_smpl = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smpl_coco17_J_regressor.pt")
        smplx2coco17 = torch.matmul(coco17_from_smpl, to_smpl.to_dense())

        # Keep one row per non-zero (joint, vertex) entry of the regressor
        jids, smplx_vids = torch.where(smplx2coco17 != 0)
        weights = torch.zeros([len(smplx_vids), 17])
        for row, (jid, vid) in enumerate(zip(jids.tolist(), smplx_vids.tolist())):
            weights[row, jid] = smplx2coco17[jid, vid]
        self.register_buffer("smplx2coco17_interestd", weights, False)  # (V', 17)

        # Shrink the per-vertex buffers to the vertices of interest
        self.v_template = self.v_template[smplx_vids].clone()  # (V', 3)
        self.shapedirs = self.shapedirs[smplx_vids].clone()  # (V', 3, K)
        self.posedirs = self.posedirs[:, smplx_vids].clone()  # (K, V', 3)
        self.lbs_weights = self.lbs_weights[smplx_vids].clone()  # (V', J)

    def forward(self, body_pose, betas, global_orient, transl):
        """Returns: joints (*, 17, 3). (B, L) or (B,) are both supported."""
        # The parent forward now yields only the selected vertices
        selected_verts = super().forward(body_pose, betas, global_orient, transl)  # (*, V', 3)
        return einsum(self.smplx2coco17_interestd, selected_verts, "v j, ... v c -> ... j c")
class SmplxLiteV437Coco17(SmplxLite):
    """Output both a 437-vertex subset and COCO17 joints from a reduced SMPL-X model."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Compose SMPL-X -> SMPL -> COCO17 into one dense (17, V) regressor
        to_smpl = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smplx2smpl_sparse.pt")
        coco17_from_smpl = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smpl_coco17_J_regressor.pt")
        smplx2coco17 = torch.matmul(coco17_from_smpl, to_smpl.to_dense())

        # Keep one row per non-zero (joint, vertex) entry of the regressor
        jids, smplx_vids = torch.where(smplx2coco17 != 0)
        weights = torch.zeros([len(smplx_vids), 17])
        for row, (jid, vid) in enumerate(zip(jids.tolist(), smplx_vids.tolist())):
            weights[row, jid] = smplx2coco17[jid, vid]
        self.register_buffer("smplx2coco17_interestd", weights, False)  # (132, 17)
        assert len(smplx_vids) == 132

        # Append the Verts437 subset after the 132 regressor vertices
        verts437_ids = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smplx_verts437.pt")
        smplx_vids = torch.cat([smplx_vids, verts437_ids])

        # Shrink the per-vertex buffers to the vertices of interest
        self.v_template = self.v_template[smplx_vids].clone()  # (V', 3)
        self.shapedirs = self.shapedirs[smplx_vids].clone()  # (V', 3, K)
        self.posedirs = self.posedirs[:, smplx_vids].clone()  # (K, V', 3)
        self.lbs_weights = self.lbs_weights[smplx_vids].clone()  # (V', J)

    def forward(self, body_pose, betas, global_orient, transl):
        """
        Returns:
            verts_437: (*, 437, 3)
            joints (*, 17, 3). (B, L) or (B,) are both supported.
        """
        # The parent forward yields the 132 regressor vertices followed by the 437 subset
        all_verts = super().forward(body_pose, betas, global_orient, transl)  # (*, 132+437, 3)
        regressor_verts = all_verts[..., :132, :]
        verts_437 = all_verts[..., 132:, :].clone()
        joints = einsum(self.smplx2coco17_interestd, regressor_verts, "v j, ... v c -> ... j c")
        return verts_437, joints
class SmplxLiteSmplN24(SmplxLite):
    """Output SMPL(not smplx)-Neutral 24 joints (Faster, but cannot output vertices)"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Compose SMPL-X -> SMPL -> SMPL joints into one dense (J, V) regressor
        to_smpl = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smplx2smpl_sparse.pt")
        joints_from_smpl = torch.load(PROJ_ROOT / "hmr4d/utils/body_model/smpl_neutral_J_regressor.pt")
        smplx2joints = torch.matmul(joints_from_smpl, to_smpl.to_dense())

        # Keep one row per non-zero (joint, vertex) entry of the regressor
        jids, smplx_vids = torch.where(smplx2joints != 0)
        weights = torch.zeros([len(smplx_vids), smplx2joints.size(0)])
        for row, (jid, vid) in enumerate(zip(jids.tolist(), smplx_vids.tolist())):
            weights[row, jid] = smplx2joints[jid, vid]
        self.register_buffer("smplx2joints_interested", weights, False)  # (V', J)

        # Shrink the per-vertex buffers to the vertices of interest
        self.v_template = self.v_template[smplx_vids].clone()  # (V', 3)
        self.shapedirs = self.shapedirs[smplx_vids].clone()  # (V', 3, K)
        self.posedirs = self.posedirs[:, smplx_vids].clone()  # (K, V', 3)
        self.lbs_weights = self.lbs_weights[smplx_vids].clone()  # (V', J)

    def forward(self, body_pose, betas, global_orient, transl):
        """Returns: joints (*, J, 3). (B, L) or (B,) are both supported."""
        # The parent forward now yields only the selected vertices
        selected_verts = super().forward(body_pose, betas, global_orient, transl)  # (*, V', 3)
        return einsum(self.smplx2joints_interested, selected_verts, "v j, ... v c -> ... j c")
def batch_rigid_transform_v2(rot_mats, joints, parents):
    """Run forward kinematics along a kinematic tree.

    Args:
        rot_mats: (*, J, 3, 3) local rotation of every joint
        joints: (*, J, 3) rest-pose joint locations; the leading dims are expanded to
            match ``rot_mats`` when they differ (e.g. a single beta for a whole batch)
        parents: (J,) index of each joint's parent; entry 0 (the root) is never read.
            Every parent index must be smaller than its child's index.
    Returns:
        posed_joints: (*, J, 3) joint locations after posing
        rel_transforms: (*, J, 4, 4) transforms relative to the rest pose, as used for skinning
    """
    # check shape, since sometimes beta has shape=1
    rot_mats_shape_prefix = rot_mats.shape[:-3]
    if rot_mats_shape_prefix != joints.shape[:-2]:
        joints = joints.expand(*rot_mats_shape_prefix, -1, -1)

    # Express each joint relative to its parent (the root keeps its absolute location)
    rel_joints = joints.clone()
    rel_joints[..., 1:, :] -= joints[..., parents[1:], :]

    # Build homogeneous local transforms [R | t; 0 0 0 1]
    transforms_mat = torch.cat([rot_mats, rel_joints[..., :, None]], dim=-1)  # (*, J, 3, 4)
    transforms_mat = F.pad(transforms_mat, [0, 0, 0, 1], value=0.0)
    transforms_mat[..., 3, 3] = 1.0  # (*, J, 4, 4)

    # Convert the parent indices once: indexing a Python list with a 0-d tensor would
    # trigger a tensor->int conversion (and a device sync on GPU) for every joint.
    parent_ids = parents.tolist()

    transform_chain = [transforms_mat[..., 0, :, :]]
    for i in range(1, len(parent_ids)):
        # Global transform of joint i = global transform of its parent @ local transform
        curr_res = torch.matmul(transform_chain[parent_ids[i]], transforms_mat[..., i, :, :])
        transform_chain.append(curr_res)

    transforms = torch.stack(transform_chain, dim=-3)  # (*, J, 4, 4)

    # The last column of the transformations contains the posed joints
    posed_joints = transforms[..., :3, 3].clone()

    # Remove the rest-pose joint location so the transform maps rest-pose points directly
    rel_transforms = transforms.clone()
    rel_transforms[..., :3, 3] -= torch.matmul(transforms[..., :3, :3], joints[..., None]).squeeze(-1)
    return posed_joints, rel_transforms
def sync_time():
    """Return the current wall-clock time after pending CUDA work has finished.

    Synchronization is skipped when CUDA is not available, so the helper can also be
    used for timing on CPU-only machines instead of raising.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time()
================================================
FILE: eval/GVHMR/hmr4d/utils/body_model/utils.py
================================================
import os
import numpy as np
import torch
# Joint names in index order: 22 body joints, 15 joints per hand, 5 face landmarks,
# 6 foot landmarks and 10 finger tips.
SMPLH_JOINT_NAMES = (
    [
        'pelvis', 'left_hip', 'right_hip', 'spine1', 'left_knee', 'right_knee',
        'spine2', 'left_ankle', 'right_ankle', 'spine3', 'left_foot', 'right_foot',
        'neck', 'left_collar', 'right_collar', 'head', 'left_shoulder', 'right_shoulder',
        'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist',
    ]
    # finger joints: all left-hand joints first, then the right hand
    + [
        '{}_{}{}'.format(side, finger, knuckle)
        for side in ('left', 'right')
        for finger in ('index', 'middle', 'pinky', 'ring', 'thumb')
        for knuckle in (1, 2, 3)
    ]
    + ['nose', 'right_eye', 'left_eye', 'right_ear', 'left_ear']
    + [
        '{}_{}'.format(side, part)
        for side in ('left', 'right')
        for part in ('big_toe', 'small_toe', 'heel')
    ]
    # finger tips
    + [
        '{}_{}'.format(side, finger)
        for side in ('left', 'right')
        for finger in ('thumb', 'index', 'middle', 'ring', 'pinky')
    ]
)

# Named joint groups
SMPLH_LEFT_LEG = ['left_hip', 'left_knee', 'left_ankle', 'left_foot']
SMPLH_RIGHT_LEG = ['right_hip', 'right_knee', 'right_ankle', 'right_foot']
SMPLH_LEFT_ARM = ['left_collar', 'left_shoulder', 'left_elbow', 'left_wrist']
SMPLH_RIGHT_ARM = ['right_collar', 'right_shoulder', 'right_elbow', 'right_wrist']
SMPLH_HEAD = ['neck', 'head']
SMPLH_SPINE = ['spine1', 'spine2', 'spine3']

# name to 21 index (without pelvis, hand, and extra)
_name_2_idx = {name: idx for idx, name in enumerate(SMPLH_JOINT_NAMES[1:22])}
SMPLH_PART_IDX = {
    part: [_name_2_idx[name] for name in names]
    for part, names in (
        ('left_leg', SMPLH_LEFT_LEG),
        ('right_leg', SMPLH_RIGHT_LEG),
        ('left_arm', SMPLH_LEFT_ARM),
        ('right_arm', SMPLH_RIGHT_ARM),
        ('two_legs', SMPLH_LEFT_LEG + SMPLH_RIGHT_LEG),
        ('left_arm_and_leg', SMPLH_LEFT_ARM + SMPLH_LEFT_LEG),
        ('right_arm_and_leg', SMPLH_RIGHT_ARM + SMPLH_RIGHT_LEG),
    )
}

# name to full index
_name_2_idx_full = {name: idx for idx, name in enumerate(SMPLH_JOINT_NAMES)}
SMPLH_PART_IDX_FULL = {
    'lower_body': [
        _name_2_idx_full[name] for name in ['pelvis'] + SMPLH_LEFT_LEG + SMPLH_RIGHT_LEG
    ],
}
# ===== ⬇️ Fitting optimizer ⬇️ ===== #
# Joint name -> index for the 22 joints used by the fitting optimizer
SMPL_JOINTS = {
    name: idx
    for idx, name in enumerate([
        'hips', 'leftUpLeg', 'rightUpLeg', 'spine', 'leftLeg', 'rightLeg',
        'spine1', 'leftFoot', 'rightFoot', 'spine2', 'leftToeBase', 'rightToeBase',
        'neck', 'leftShoulder', 'rightShoulder', 'head', 'leftArm', 'rightArm',
        'leftForeArm', 'rightForeArm', 'leftHand', 'rightHand',
    ])
}

# chosen virtual mocap markers that are "keypoints" to work with
KEYPT_VERTS = [
    4404, 920, 3076, 3169, 823, 4310, 1010, 1085, 4495, 4569,
    6615, 3217, 3313, 6713, 6785, 3383, 6607, 3207, 1241, 1508,
    4797, 4122, 1618, 1569, 5135, 5040, 5691, 5636, 5404, 2230,
    2173, 2108, 134, 3645, 6543, 3123, 3024, 4194, 1306, 182,
    3694, 4294, 744,
]
# From https://github.com/vchoutas/smplify-x/blob/master/smplifyx/utils.py
# Please see license for usage restrictions.
def smpl_to_openpose(model_type='smplx', use_hands=True, use_face=True,
                     use_face_contour=False, openpose_format='coco25'):
    ''' Returns the indices of the permutation that maps SMPL to OpenPose

        Parameters
        ----------
        model_type: str, optional
            The type of SMPL-like model that is used ('smpl', 'smplh' or
            'smplx'). The default mapping returned is for the SMPLX model
        use_hands: bool, optional
            Flag for adding to the returned permutation the mapping for the
            hand keypoints ('smplh' and 'smplx' only). Defaults to True
        use_face: bool, optional
            Flag for adding to the returned permutation the mapping for the
            face keypoints ('smplx' only). Defaults to True
        use_face_contour: bool, optional
            Flag for appending the facial contour keypoints. Only has an
            effect together with use_face. Defaults to False
        openpose_format: str, optional
            The output format of OpenPose, matched case-insensitively. For
            now only COCO-25 and COCO-19 is supported. Defaults to 'coco25'

        Returns
        -------
        np.ndarray of dtype int32 holding the model joint index of every
        OpenPose keypoint

        Raises
        ------
        ValueError
            If model_type or openpose_format is not supported
    '''
    # Normalize once so both formats are matched case-insensitively
    # (previously only 'coco25' was lower-cased before the comparison).
    joint_format = openpose_format.lower()

    if joint_format == 'coco25':
        if model_type == 'smpl':
            return np.array([24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
                             7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
                            dtype=np.int32)
        elif model_type == 'smplh':
            body_mapping = np.array([52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5,
                                     8, 1, 4, 7, 53, 54, 55, 56, 57, 58, 59,
                                     60, 61, 62], dtype=np.int32)
            mapping = [body_mapping]
            if use_hands:
                lhand_mapping = np.array([20, 34, 35, 36, 63, 22, 23, 24, 64,
                                          25, 26, 27, 65, 31, 32, 33, 66, 28,
                                          29, 30, 67], dtype=np.int32)
                rhand_mapping = np.array([21, 49, 50, 51, 68, 37, 38, 39, 69,
                                          40, 41, 42, 70, 46, 47, 48, 71, 43,
                                          44, 45, 72], dtype=np.int32)
                mapping += [lhand_mapping, rhand_mapping]
            return np.concatenate(mapping)
        # SMPLX
        elif model_type == 'smplx':
            body_mapping = np.array([55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5,
                                     8, 1, 4, 7, 56, 57, 58, 59, 60, 61, 62,
                                     63, 64, 65], dtype=np.int32)
            mapping = [body_mapping]
            if use_hands:
                lhand_mapping = np.array([20, 37, 38, 39, 66, 25, 26, 27,
                                          67, 28, 29, 30, 68, 34, 35, 36, 69,
                                          31, 32, 33, 70], dtype=np.int32)
                rhand_mapping = np.array([21, 52, 53, 54, 71, 40, 41, 42, 72,
                                          43, 44, 45, 73, 49, 50, 51, 74, 46,
                                          47, 48, 75], dtype=np.int32)
                mapping += [lhand_mapping, rhand_mapping]
            if use_face:
                # end_idx = 127 + 17 * use_face_contour
                face_mapping = np.arange(76, 127 + 17 * use_face_contour,
                                         dtype=np.int32)
                mapping += [face_mapping]
            return np.concatenate(mapping)
        else:
            raise ValueError('Unknown model type: {}'.format(model_type))
    elif joint_format == 'coco19':
        if model_type == 'smpl':
            return np.array([24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8,
                             1, 4, 7, 25, 26, 27, 28],
                            dtype=np.int32)
        elif model_type == 'smplh':
            body_mapping = np.array([52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5,
                                     8, 1, 4, 7, 53, 54, 55, 56],
                                    dtype=np.int32)
            mapping = [body_mapping]
            if use_hands:
                lhand_mapping = np.array([20, 34, 35, 36, 57, 22, 23, 24, 58,
                                          25, 26, 27, 59, 31, 32, 33, 60, 28,
                                          29, 30, 61], dtype=np.int32)
                rhand_mapping = np.array([21, 49, 50, 51, 62, 37, 38, 39, 63,
                                          40, 41, 42, 64, 46, 47, 48, 65, 43,
                                          44, 45, 66], dtype=np.int32)
                mapping += [lhand_mapping, rhand_mapping]
            return np.concatenate(mapping)
        # SMPLX
        elif model_type == 'smplx':
            body_mapping = np.array([55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5,
                                     8, 1, 4, 7, 56, 57, 58, 59],
                                    dtype=np.int32)
            mapping = [body_mapping]
            if use_hands:
                lhand_mapping = np.array([20, 37, 38, 39, 60, 25, 26, 27,
                                          61, 28, 29, 30, 62, 34, 35, 36, 63,
                                          31, 32, 33, 64], dtype=np.int32)
                rhand_mapping = np.array([21, 52, 53, 54, 65, 40, 41, 42, 66,
                                          43, 44, 45, 67, 49, 50, 51, 68, 46,
                                          47, 48, 69], dtype=np.int32)
                mapping += [lhand_mapping, rhand_mapping]
            if use_face:
                face_mapping = np.arange(70, 70 + 51 +
                                         17 * use_face_contour,
                                         dtype=np.int32)
                mapping += [face_mapping]
            return np.concatenate(mapping)
        else:
            raise ValueError('Unknown model type: {}'.format(model_type))
    else:
        raise ValueError('Unknown joint format: {}'.format(openpose_format))
================================================
FILE: eval/GVHMR/hmr4d/utils/callbacks/lr_monitor.py
================================================
from pytorch_lightning.callbacks import LearningRateMonitor
from hmr4d.configs import builds, MainStore
# Register Lightning's LearningRateMonitor as the "pl" option of the lr_monitor callback group.
MainStore.store(
    name="pl",
    node=builds(LearningRateMonitor),
    group="callbacks/lr_monitor",
)
================================================
FILE: eval/GVHMR/hmr4d/utils/callbacks/prog_bar.py
================================================
from collections import OrderedDict
from numbers import Number
from datetime import datetime, timedelta
from typing import Any, Dict, Union
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar, Tqdm, convert_inf
from pytorch_lightning.callbacks.progress import ProgressBar
from pytorch_lightning.utilities import rank_zero_only
import pytorch_lightning as pl
from hmr4d.utils.pylogger import Log
from time import time
from collections import deque
import sys
from hmr4d.configs import MainStore, builds
# ========== Helper functions ========== #
def format_num(n):
    """Return the shorter of ``str(n)`` and a 3-significant-digit rendering of ``n``.

    In the compact rendering every "+0" / "-0" is shortened to "+" / "-"
    (e.g. "1e-07" becomes "1e-7"). Ties are resolved in favour of ``str(n)``.
    """
    compact = f"{n:.3g}".replace("+0", "+").replace("-0", "-")
    plain = str(n)
    if len(compact) < len(plain):
        return compact
    return plain
def convert_kwargs_to_str(**kwargs):
    """Render keyword stats as a ``"key=value, key=value"`` string.

    Keys are visited in alphabetical order and reduced to the part after the last "/".
    Numbers go through ``format_num``, strings are used as they are, and anything else
    is converted with ``str``. Every value is stripped of surrounding whitespace.
    """
    # Sort in alphabetical order to be more deterministic
    stats = OrderedDict()
    for full_key in sorted(kwargs):
        stats[full_key.split("/")[-1]] = kwargs[full_key]

    def _as_text(value):
        # Number: limit the length of the string
        if isinstance(value, Number):
            return format_num(value)
        # Strings need no preprocessing
        if isinstance(value, str):
            return value
        # Any other type: fall back to its string conversion
        return str(value)

    # Stitch together to get the final postfix
    return ", ".join(f"{name}={_as_text(value).strip()}" for name, value in stats.items())
def convert_t_to_str(t):
    """Format a duration given in seconds as ``[hour:]minute:second``.

    Fractional seconds are dropped. A zero hour field is omitted, while
    minutes and seconds are always shown.
    """
    # str(timedelta) looks like "0:00:00.704186"; keep what precedes the fraction
    whole_seconds, _, _ = str(timedelta(seconds=t)).partition(".")
    return whole_seconds[2:] if whole_seconds.startswith("0:") else whole_seconds
class MyTQDMProgressBar(TQDMProgressBar, pl.Callback):
    """TQDM progress bar with a compact layout and a GPU-memory / metrics postfix."""

    def init_train_tqdm(self):
        """Create the training bar with a custom ``bar_format``."""
        return Tqdm(
            desc="Training",  # this will be overwritten anyway
            bar_format="{desc}{percentage:3.0f}%[{bar:10}][{n_fmt}/{total_fmt}, {elapsed}→{remaining},{rate_fmt}]{postfix}",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=False,
            smoothing=0,
            dynamic_ncols=False,
        )

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Advance the bar through the parent hook, then refresh the postfix when due."""
        # the parent implementation also updates the main progress bar
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

        # here we only set the postfix of the main progress bar
        if not self._should_update(batch_idx + 1, self.train_progress_bar.total):
            return

        # 1. maximum GPU usage, in GB
        peak_mem_gb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0
        postfix_parts = [f"maxGPU={peak_mem_gb:.1f}GB"]

        # 2. training metrics
        metrics = self.get_metrics(trainer, pl_module)
        metrics.pop("v_num", None)
        postfix_parts.append(convert_kwargs_to_str(**metrics))

        # 3. extra message if applicable
        if "message" in outputs:
            postfix_parts.append(outputs["message"])

        self.train_progress_bar.set_postfix_str(", ".join(postfix_parts))
class ProgressReporter(ProgressBar, pl.Callback):
def __init__(
self,
log_every_percent: float = 0.1, # report interval
exp_name=None, # if None, use pl_module.exp_name or "Unnamed Experiment"
data_name=None, # if None, use pl_module.exp_name or "Unknown Data"
**kwargs,
):
super().__init__()
self.enable = True
# 1. Store experiment meta data.
self.log_every_percent = log_every_percent
self.exp_name = exp_name
self.data_name = data_name
self.batch_time_queue = deque(maxlen=5)
self.start_prompt = "🚀"
self.finish_prompt = "✅"
# 2. Utils for evaluation
self.n_finished = 0
def disable(self):
self.enable = False
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
# Connect to the trainer object.
super().setup(trainer, pl_module, stage)
self.stage = stage
self.time_exp_start = time()
self.epoch_exp_start = trainer.current_epoch
if self.exp_name is None:
if hasattr(pl_module, "exp_name"):
self.exp_name = pl_module.exp_name
else:
self.exp_name = "Unnamed Experiment"
if self.data_name is None:
if hasattr(pl_module, "data_name"):
self.data_name = pl_module.data_name
else:
self.data_name = "Unknown Data"
def print(self, *args: Any, **kwargs: Any) -> None:
print(*args)
def get_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> Dict[str, Union[str, float]]:
"""Get metrics from trainer for progress bar."""
items = super().get_metrics(trainer, pl_module)
items.pop("v_num", None)
return items
def _should_update(self, n_finished: int, total: int) -> bool:
"""
Rule: Log every `log_every_percent` percent, or the last batch.
"""
log_interval = max(int(total * self.log_every_percent), 1)
able = n_finished % log_interval == 0 or n_finished == total
if log_interval > 10:
able = able or n_finished in [5, 10] # always log
able = able and self.enable
return able
@rank_zero_only
def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
self.print("=" * 80)
Log.info(
f"{self.start_prompt}[FIT][Epoch {trainer.current_epoch}] Data: {self.data_name} Experiment: {self.exp_name}"
)
self.time_train_epoch_start = time()
@rank_zero_only
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) # don't forget this :)
total = self.total_train_batches
# Speed
n_finished = batch_idx + 1
percent = 100 * n_finished / total
time_current = time()
self.batch_time_queue.append(time_current)
time_elapsed = time_current - self.time_train_epoch_start # second
time_remaining = time_elapsed * (total - n_finished) / n_finished # second
if len(self.batch_time_queue) == 1: # cannot compute speed
speed = 1 / time_elapsed
else:
speed = (len(self.batch_time_queue) - 1) / (self.batch_time_queue[-1] - self.batch_time_queue[0])
# Skip if not update
if not self._should_update(n_finished, total):
return
# ===== Set Prefix string ===== #
# General
desc = f"[Train]"
# Speed: Get elapsed time and estimated remaining time
time_elapsed_str = convert_t_to_str(time_elapsed)
time_remaining_str = convert_t_to_str(time_remaining)
speed_str = f"{speed:.2f}it/s" if speed > 1 else f"{1/speed:.1f}s/it"
n_digit = len(str(total))
desc_speed = (
f"[{n_finished:{n_digit}d}/{total}={percent:3.0f}%, {time_elapsed_str} → {time_remaining_str}, {speed_str}]"
)
# ===== Set postfix string ===== #
# 1. maximum GPU usage
max_mem = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0
post_fix_str = f"maxGPU={max_mem:.1f}GB"
# 2. training step metrics
train_metrics = self.get_metrics(trainer, pl_module)
train_metrics = {k: v for k, v in train_metrics.items() if ("train" in k and "epoch" not in k)}
post_fix_str += ", " + convert_kwargs_to_str(**train_metrics)
# extra message if applicable
if "message" in outputs:
post_fix_str += ", " + outputs["message"]
post_fix_str = f"[{post_fix_str}]"
# ===== Output ===== #
bar_output = f"{desc}{desc_speed}{post_fix_str}"
self.print(bar_output)
@rank_zero_only
def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
super().on_train_epoch_end(trainer, pl_module)
# Clear
self.batch_time_queue.clear()
# Estimate Epoch time
n_finished = trainer.current_epoch + 1 - self.epoch_exp_start
n_to_finish = trainer.max_epochs - trainer.current_epoch - 1
time_current = time()
time_elapsed = time_current - self.time_exp_start
time_remaining = time_elapsed * n_to_finish / n_finished
time_elapsed_str = convert_t_to_str(time_elapsed)
time_remaining_str = convert_t_to_str(time_remaining)
# Metrics
# training epoch metrics
train_metrics = self.get_metrics(trainer, pl_module)
train_metrics = {k: v for k, v in train_metrics.items() if ("train" in k and "epoch" in k)}
train_metrics_str = convert_kwargs_to_str(**train_metrics)
Log.info(
f"{self.finish_prompt}[FIT][Epoch {trainer.current_epoch}] finished! {time_elapsed_str}→{time_remaining_str} | {train_metrics_str}"
)
# ===== Validation/Test/Prediction ===== #
@rank_zero_only
def on_validation_epoch_start(self, trainer, pl_module):
    """Remember when the validation epoch started; used later for the validation ETA."""
    started_at = time()
    self.time_val_epoch_start = started_at
@rank_zero_only
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
    """Count one finished validation batch and, when due, print a progress line.

    The line shows finished/total batches, the percentage, the time elapsed since
    the validation epoch started and a linear estimate of the remaining time.
    """
    self.n_finished += 1
    done = self.n_finished
    total = self.total_val_batches
    if self._should_update(done, total):
        percent = 100 * done / total
        elapsed = time() - self.time_val_epoch_start  # second
        remaining = elapsed * (total - done) / done  # second
        elapsed_str = convert_t_to_str(elapsed)
        remaining_str = convert_t_to_str(remaining)
        progress = f"[{done}/{total} ={percent:3.0f}%, {elapsed_str}→{remaining_str}]"
        self.print(f"[Val] {progress}")
def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    """Reset the finished-batch counter so the next validation epoch starts from zero."""
    finished_batches = 0
    self.n_finished = finished_batches
class EmojiProgressReporter(ProgressBar, pl.Callback):
    """Text progress reporter that prints emoji-prefixed status lines instead of drawing a bar.

    Batch reports are throttled by `refresh_rate_batch` (``None`` disables them) and
    epoch reports by `refresh_rate_epoch`. The experiment name is read from
    `pl_module.exp_name` during `setup`.
    """

    def __init__(
        self,
        refresh_rate_batch: Union[int, None] = 1,  # report interval of batch, set None to disable it
        refresh_rate_epoch: int = 1,  # report interval of epoch
        **kwargs,
    ):
        super().__init__()
        self.enable = True
        # Store experiment meta data.
        self.refresh_rate_batch = refresh_rate_batch
        self.refresh_rate_epoch = refresh_rate_epoch
        # Style of the progress bar.
        self.title_prompt = "📝"
        self.prog_prompt = "🚀"
        self.timer_prompt = "⌛️"
        self.metric_prompt = "📌"
        self.finish_prompt = "✅"

    def disable(self):
        """Turn off all batch and epoch reports."""
        self.enable = False

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str):
        """Connect to the trainer and cache the stage name and experiment name."""
        super().setup(trainer, pl_module, stage)
        self.stage = stage
        self.time_start_batch = None
        self.time_start_epoch = None
        if hasattr(pl_module, "exp_name"):
            self.exp_name = pl_module.exp_name
        else:
            self.exp_name = "Unnamed Experiment"
            Log.warn("Experiment name not found, please set it to `pl_module.exp_name`!")

    def print(self, *args: Any, **kwargs: Any):
        """Print a report line; keyword arguments (`sep`, `end`, `file`, ...) are forwarded to `print`."""
        print(*args, **kwargs)

    def get_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> Dict[str, Union[str, float]]:
        """Get metrics from trainer for progress bar (without `v_num`, sorted by name)."""
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)
        return dict(sorted(items.items()))

    def _should_log_batch(self, n: int) -> bool:
        """Return True when batch `n` (0-based) should be reported."""
        # Disable batch log.
        if self.refresh_rate_batch is None:
            return False
        # Log at the first & last batch, and every `self.refresh_rate_batch` batches.
        able = n % self.refresh_rate_batch == 0 or n == self.total_train_batches - 1
        return able and self.enable

    def _should_log_epoch(self, n: int) -> bool:
        """Return True when epoch `n` (0-based) should be reported."""
        # Log at the first & last epoch, and every `self.refresh_rate_epoch` epochs.
        max_epochs = self.trainer.max_epochs
        # `max_epochs` may be unset or negative; then there is no "last" epoch to test for.
        is_last = max_epochs is not None and n == max_epochs - 1
        able = n % self.refresh_rate_epoch == 0 or is_last
        return able and self.enable

    @staticmethod
    def _percent(n_done, n_total) -> float:
        """Finished percentage with the reporter's `(total + 1)` denominator; 0.0 when the total is unknown."""
        if n_total is None or n_total + 1 <= 0:
            return 0.0
        return 100 * n_done / (n_total + 1)

    def _rest_time_str(self, time_start, time_cur: float, percent: float) -> str:
        """Linearly extrapolate the remaining time; "N/A" when no estimate is possible."""
        if time_start is None or not (0 < percent < float("inf")):
            return "N/A"
        return self.timestamp_delta_to_str((time_cur - time_start) * (100 - percent) / percent)

    def timestamp_delta_to_str(self, timestamp_delta: float):
        """Convert a duration in seconds to a string such as "1h 2m 3s".

        Zero-valued components are omitted. Hours are not wrapped at 24, so long
        estimates are not truncated; a duration shorter than one second (or a
        negative one) is reported as "0s".
        """
        total_seconds = max(int(timestamp_delta), 0)
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        parts = []
        if hours > 0:
            parts.append(f"{hours}h")
        if minutes > 0:
            parts.append(f"{minutes}m")
        if seconds > 0:
            parts.append(f"{seconds}s")
        return " ".join(parts) if parts else "0s"

    @rank_zero_only
    def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int):
        """Record the time of the first batch of the epoch (reference point for the epoch ETA)."""
        super().on_train_batch_start(trainer, pl_module, batch, batch_idx)
        # Initialize some meta data.
        if self.time_start_batch is None:
            self.time_start_batch = datetime.now().timestamp()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Print the batch-level report (progress, wall-clock time, epoch ETA, metrics) when due."""
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)  # don't forget this :)
        # Skip all bookkeeping when this batch is not reported.
        if not self._should_log_batch(batch_idx):
            return

        # Get some meta data.
        epoch_idx = trainer.current_epoch
        percent = self._percent(batch_idx + 1, self.total_train_batches)
        metrics = self.get_metrics(trainer, pl_module)

        # Current time.
        time_cur_stamp = datetime.now().timestamp()
        time_cur_str = datetime.fromtimestamp(time_cur_stamp).strftime("%m-%d %H:%M:%S")

        # Rest time.
        time_rest_str = self._rest_time_str(self.time_start_batch, time_cur_stamp, percent)

        # Print the logs.
        self.print(f"{self.title_prompt} [{self.stage.upper()}] Exp: {self.exp_name}...")
        self.print(
            f"{self.prog_prompt} Ep {epoch_idx}: {int(percent):02d}% <= [{batch_idx}/{self.total_train_batches}]"
        )
        self.print(f"{self.timer_prompt} Time: {time_cur_str} | Ep Rest: {time_rest_str}")
        for k, v in metrics.items():
            self.print(f"{self.metric_prompt} {k}: {v}")
        self.print("")  # Add a blank line.

    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Reset the per-epoch batch timer and record the start time of the first epoch."""
        super().on_train_epoch_start(trainer, pl_module)
        # Initialize some meta data.
        self.time_start_batch = None
        if self.time_start_epoch is None:
            self.time_start_epoch = datetime.now().timestamp()

    @rank_zero_only
    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Print the epoch-level report (wall-clock time, overall ETA, metrics) when due."""
        super().on_train_epoch_end(trainer, pl_module)
        epoch_idx = trainer.current_epoch
        # Epoch reports follow the epoch interval, not the batch interval.
        if not self._should_log_epoch(epoch_idx):
            return

        # Get some meta data.
        percent = self._percent(epoch_idx + 1, self.trainer.max_epochs)
        metrics = self.get_metrics(trainer, pl_module)

        # Current time.
        time_cur = datetime.now().timestamp()
        time_str = datetime.fromtimestamp(time_cur).strftime("%m-%d %H:%M:%S")

        # Rest time.
        time_rest_str = self._rest_time_str(self.time_start_epoch, time_cur, percent)

        # Print the logs.
        self.print(f">> >> >> >>")
        self.print(f"{self.title_prompt} [{self.stage.upper()}] Exp: {self.exp_name}")
        self.print(f"{self.finish_prompt} Ep {epoch_idx} finished!")
        self.print(f"{self.timer_prompt} Time: {time_str} | Rest: {time_rest_str}")
        for k, v in metrics.items():
            self.print(f"{self.metric_prompt} {k}: {v}")
        self.print(f"<< << << <<")
        self.print("")  # Add a blank line.
# Register ProgressReporter configurations in MainStore under the "callbacks/prog_bar" group.
group_name = "callbacks/prog_bar"
# Base config: ProgressReporter with log_every_percent=0.1. The "${...}" strings look like
# config interpolations that are resolved outside this file -- TODO confirm.
prog_reporter_base = builds(
ProgressReporter,
log_every_percent=0.1,
exp_name="${exp_name}",
data_name="${data_name}",
populate_full_signature=True,
)
# Two stored variants that differ only in `log_every_percent` (0.1 and 0.2).
MainStore.store(name="prog_reporter_every0.1", node=prog_reporter_base, group=group_name)
MainStore.store(name="prog_reporter_every0.2", node=prog_reporter_base(log_every_percent=0.2), group=group_name)
================================================
FILE: eval/GVHMR/hmr4d/utils/callbacks/simple_ckpt_saver.py
================================================
from pathlib import Path
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.checkpoint import Checkpoint
from pytorch_lightning.utilities import rank_zero_only
from hmr4d.utils.pylogger import Log
from hmr4d.configs import MainStore, builds
class SimpleCkptSaver(Checkpoint):
    """
    This callback runs at the end of each training epoch.
    Check {every_n_epochs} and save at most {save_top_k} model if it is time.

    Args:
        output_dir: directory that receives the checkpoints (created on rank 0).
        filename: `str.format` template that receives `epoch` and `step`.
        save_top_k: number of checkpoints to keep. 0 disables saving; a negative
            value keeps every checkpoint.
        every_n_epochs: save every this many epochs; values < 1 disable saving.
        save_last: stored but not used by this class.
        save_weights_only: when False, optimizer and lr-scheduler states are saved too.
    """

    def __init__(
        self,
        output_dir,
        filename="e{epoch:03d}-s{step:06d}.ckpt",
        save_top_k=1,
        every_n_epochs=1,
        save_last=None,
        save_weights_only=True,
    ):
        super().__init__()
        self.output_dir = Path(output_dir)
        self.filename = filename
        self.save_top_k = save_top_k
        self.every_n_epochs = every_n_epochs
        self.save_last = save_last
        self.save_weights_only = save_weights_only

        # Setup output dir
        if rank_zero_only.rank == 0:
            self.output_dir.mkdir(parents=True, exist_ok=True)
            Log.info(f"[Simple Ckpt Saver]: Save to `{self.output_dir}'")

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        """Save a checkpoint at the end of the training epoch."""
        if self.every_n_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return
        if self.save_top_k == 0:
            return

        filepath = self.output_dir / self.filename.format(epoch=trainer.current_epoch, step=trainer.global_step)

        # Checkpoints already on disk, sorted by name (the default zero-padded template
        # sorts chronologically). The file about to be written is excluded so that an
        # overwritten checkpoint is never selected for removal right after saving it.
        existing = sorted(p for p in self.output_dir.glob("*.ckpt") if p != filepath)
        # A negative `save_top_k` keeps everything; otherwise drop the oldest file once
        # the directory already holds `save_top_k` checkpoints.
        keep_all = self.save_top_k < 0
        model_to_remove = existing[0] if (not keep_all and len(existing) >= self.save_top_k) else None

        # Save current checkpoint
        checkpoint = {
            "epoch": trainer.current_epoch,
            "global_step": trainer.global_step,
            "pytorch-lightning_version": pl.__version__,
            "state_dict": pl_module.state_dict(),
        }
        pl_module.on_save_checkpoint(checkpoint)

        if not self.save_weights_only:
            # optimizer: rely on the strategy to dump optimizer state
            checkpoint["optimizer_states"] = [
                trainer.strategy.optimizer_state(optimizer) for optimizer in trainer.optimizers
            ]
            # lr_scheduler
            checkpoint["lr_schedulers"] = [
                config.scheduler.state_dict() for config in trainer.lr_scheduler_configs
            ]

        # trainer.strategy.checkpoint_io.save_checkpoint(checkpoint, filepath)
        torch.save(checkpoint, filepath)

        # Remove the earliest checkpoint
        if model_to_remove is not None:
            trainer.strategy.remove_checkpoint(model_to_remove)
# Register SimpleCkptSaver configurations in MainStore under "callbacks/simple_ckpt_saver".
group_name = "callbacks/simple_ckpt_saver"
# Base config writes to "${output_dir}/checkpoints/"; "${output_dir}" looks like a config
# interpolation resolved outside this file -- TODO confirm.
base = builds(SimpleCkptSaver, output_dir="${output_dir}/checkpoints/", populate_full_signature=True)
# "base" and "every1e" store the same node (class defaults: every_n_epochs=1, save_top_k=1).
MainStore.store(name="base", node=base, group=group_name)
MainStore.store(name="every1e", node=base, group=group_name)
# Variants overriding the save interval and, for the "_top100" entries, save_top_k.
MainStore.store(name="every2e", node=base(every_n_epochs=2), group=group_name)
MainStore.store(name="every5e", node=base(every_n_epochs=5), group=group_name)
MainStore.store(name="every5e_top100", node=base(every_n_epochs=5, save_top_k=100), group=group_name)
MainStore.store(name="every10e", node=base(every_n_epochs=10), group=group_name)
MainStore.store(name="every10e_top100", node=base(every_n_epochs=10, save_top_k=100), group=group_name)
================================================
FILE: eval/GVHMR/hmr4d/utils/callbacks/train_speed_timer.py
================================================
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from time import time
from collections import deque
from hmr4d.configs import MainStore, builds
class TrainSpeedTimer(pl.Callback):
    def __init__(self, N_avg=5):
        """
        Callback that measures training speed as a moving average over the last `N_avg` iterations.
        1. Data waiting time: gap between the end of one batch and the start of the next;
           this should be small, otherwise the data loading should be improved
        2. Single batch time: time spent on one training batch (excluding data waiting)
        """
        super().__init__()
        self.last_batch_end = None
        self.this_batch_start = None

        # Fixed-size windows used for the moving averages.
        self.data_waiting_time_queue = deque(maxlen=N_avg)
        self.single_batch_time_queue = deque(maxlen=N_avg)

    @rank_zero_only
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Measure how long the loop waited for data and mark the start of this batch."""
        previous_end = self.last_batch_end
        if previous_end is not None:
            # This should be small, otherwise the data loading should be improved
            waited = time() - previous_end
            window = self.data_waiting_time_queue
            window.append(waited)
            # Log the moving average to the prog-bar
            pl_module.log(
                "train_timer/data_waiting",
                sum(window) / len(window),
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                logger=True,
            )
        self.this_batch_start = time()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Measure the effective time of this batch and mark its end."""
        # Effective training time elapsed (excluding data waiting)
        spent = time() - self.this_batch_start
        window = self.single_batch_time_queue
        window.append(spent)
        # Log the moving average of the iteration time
        pl_module.log(
            "train_timer/single_batch",
            sum(window) / len(window),
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
        )
        # Reference point for the next data-waiting measurement
        self.last_batch_end = time()

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        """Forget all timestamps and averages so the next epoch starts fresh."""
        self.last_batch_end = None
        self.this_batch_start = None
        for window in (self.data_waiting_time_queue, self.single_batch_time_queue):
            window.clear()
# Register the default TrainSpeedTimer configuration in MainStore under
# "callbacks/train_speed_timer" with the name "base".
group_name = "callbacks/train_speed_timer"
base = builds(TrainSpeedTimer, populate_full_signature=True)
MainStore.store(name="base", node=base, group=group_name)
================================================
FILE: eval/GVHMR/hmr4d/utils/comm/gather.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
[Copied from detectron2]
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import functools
import logging
import numpy as np
import pickle
import torch
import torch.distributed as dist
# Per-machine process group read by get_local_rank() and get_local_size(). Nothing in the
# visible part of this file assigns it; the note below is inherited from detectron2 --
# TODO confirm that an equivalent `launch()` exists in this project.
_LOCAL_PROCESS_GROUP = None
"""
A torch process group which only includes processes that on the same machine as the current process.
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
"""
def get_world_size() -> int:
    """Return the number of distributed processes, or 1 when torch.distributed is unavailable or uninitialized."""
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if distributed_ready else 1
def get_rank() -> int:
    """Return the global rank of this process, or 0 when torch.distributed is unavailable or uninitialized."""
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_ready else 0
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group,
        or 0 when torch.distributed is unavailable or uninitialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    # The local group must have been created before this is called in a distributed run.
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group, i.e. the number of processes per machine,
        or 1 when torch.distributed is unavailable or uninitialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
def is_main_process() -> bool:
    """Return True on the process whose global rank is 0."""
    rank = get_rank()
    return rank == 0
def synchronize():
    """
    Barrier among all processes when using distributed training.
    Does nothing when torch.distributed is unavailable, uninitialized, or the world size is 1.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() != 1:
        dist.barrier()
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group containing all ranks: a newly created gloo group when the
    current backend is nccl, otherwise the default WORLD group. The result is cached.
    """
    uses_nccl = dist.get_backend() == "nccl"
    return dist.new_group(backend="gloo") if uses_nccl else dist.group.WORLD
def _serialize_to_tensor(data, group):
    """Pickle `data` into a 1-D byte tensor placed on the device matching the group's backend.

    gloo groups use a CPU tensor, nccl groups a CUDA tensor. A warning is logged when
    the pickled payload exceeds 1 GB.
    """
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cuda" if backend != "gloo" else "cpu")

    payload = pickle.dumps(data)
    n_bytes = len(payload)
    if n_bytes > 1024**3:
        logging.getLogger(__name__).warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), n_bytes / (1024**3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(payload)
    return torch.ByteTensor(storage).to(device=device)
def _pad_to_largest_tensor(tensor, group):
    """
    Exchange tensor sizes across the group and zero-pad the local tensor to the largest one.

    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert world_size >= 1, "comm.gather/all_gather must be called from ranks within the given group!"

    device = tensor.device
    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=device)
    gathered_sizes = [torch.zeros([1], dtype=torch.int64, device=device) for _ in range(world_size)]
    dist.all_gather(gathered_sizes, local_size, group=group)
    size_list = [int(entry.item()) for entry in gathered_sizes]
    max_size = max(size_list)

    # torch all_gather does not support gathering tensors of different shapes,
    # so shorter payloads are padded with zero bytes.
    if local_size != max_size:
        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # One receive buffer per rank, all of the padded size.
    received = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list]
    dist.all_gather(received, tensor, group=group)

    # Strip the padding of each payload before unpickling it.
    return [pickle.loads(chunk.cpu().numpy().tobytes()[:size]) for size, chunk in zip(size_list, received)]
def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]

    rank = dist.get_rank(group=group)
    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # Non-destination ranks only send.
    if rank != dst:
        dist.gather(tensor, [], dst=dst, group=group)
        return []

    # Destination rank: receive every padded payload, then strip padding and unpickle.
    max_size = max(size_list)
    received = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list]
    dist.gather(tensor, received, dst=dst, group=group)
    return [pickle.loads(chunk.cpu().numpy().tobytes()[:size]) for size, chunk in zip(size_list, received)]
def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers (the value drawn by rank 0).
        If workers need a shared RNG, they can use this shared seed to create one.

    All workers must call this function, otherwise it will deadlock.
    """
    local_draw = np.random.randint(2**31)
    return all_gather(local_draw)[0]
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the reduced results.

    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum

    Returns:
        a dict with the same keys as input_dict, after reduction.
        With fewer than two processes the input dict itself is returned.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sorted keys give every process the same stacking order.
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[name] for name in names], dim=0)
        dist.reduce(values, dst=0)
        # Only the main process holds the accumulated sum, so only it divides.
        if dist.get_rank() == 0 and average:
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
================================================
FILE: eval/GVHMR/hmr4d/utils/eval/eval_utils.py
================================================
import torch
import numpy as np
@torch.no_grad()
def compute_camcoord_metrics(batch, pelvis_idxs=[1, 2], fps=30, mask=None):
    """
    Camera-coordinate metrics between predicted and target joints / vertices.

    Args:
        batch (dict): {
            "pred_j3d": (..., J, 3) tensor
            "target_j3d":
            "pred_verts":
            "target_verts":
        }
        pelvis_idxs: joint indices whose mean defines the pelvis used for alignment
        fps: frame rate forwarded to the acceleration error
        mask: optional index applied to the first dimension of all four tensors;
            it must be passed here, not stored in `batch`
    Returns:
        cam_coord_metrics (dict): {
            "pa_mpjpe": (..., ) numpy array
            "mpjpe":
            "pve":
            "accel":
        }
    """
    # All data is in camera coordinates
    names = ("pred_j3d", "target_j3d", "pred_verts", "target_verts")
    tensors = [batch[name].cpu() for name in names]
    if mask is not None:
        mask = mask.cpu()
        tensors = [t[mask].clone() for t in tensors]
    assert "mask" not in batch

    # Align by pelvis
    pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(tensors, pelvis_idxs=pelvis_idxs)

    # Metrics (positions converted from meters to millimeters)
    m2mm = 1000
    procrustes_j3d = batch_compute_similarity_transform_torch(pred_j3d, target_j3d)
    return {
        "pa_mpjpe": compute_jpe(procrustes_j3d, target_j3d) * m2mm,
        "mpjpe": compute_jpe(pred_j3d, target_j3d) * m2mm,
        "pve": compute_jpe(pred_verts, target_verts) * m2mm,
        "accel": compute_error_accel(joints_pred=pred_j3d, joints_gt=target_j3d, fps=fps),
    }
@torch.no_grad()
def compute_global_metrics(batch, mask=None, fps=30, chunk_length=100):
    """Follow WHAM, the input has skipped invalid frames

    Args:
        batch (dict): {
            "pred_j3d_glob": (F, J, 3) tensor
            "target_j3d_glob":
            "pred_verts_glob":
            "target_verts_glob":
        }
        mask: optional index applied to the first dimension of all four tensors;
            it must be passed here, not stored in `batch`
        fps: frame rate used by the jitter metric (default 30, the previous fixed value)
        chunk_length: number of frames per alignment chunk (default 100, the previous fixed value)
    Returns:
        global_metrics (dict): {
            "wa2_mpjpe": (F, ) numpy array, prediction aligned on the first two frames of each chunk
            "waa_mpjpe": (F, ) numpy array, prediction aligned on the whole chunk
            "rte":
            "jitter":
            "fs":
        }
    """
    # All data is in global coordinates
    pred_j3d_glob = batch["pred_j3d_glob"].cpu()  # (..., J, 3)
    target_j3d_glob = batch["target_j3d_glob"].cpu()
    pred_verts_glob = batch["pred_verts_glob"].cpu()
    target_verts_glob = batch["target_verts_glob"].cpu()
    if mask is not None:
        mask = mask.cpu()
        pred_j3d_glob = pred_j3d_glob[mask].clone()
        target_j3d_glob = target_j3d_glob[mask].clone()
        pred_verts_glob = pred_verts_glob[mask].clone()
        target_verts_glob = target_verts_glob[mask].clone()
    assert "mask" not in batch

    seq_length = pred_j3d_glob.shape[0]

    # Compare chunk by chunk: every chunk is aligned independently.
    wa2_mpjpe, waa_mpjpe = [], []
    for start in range(0, seq_length, chunk_length):
        end = min(seq_length, start + chunk_length)
        # Tensors are already on CPU and the alignment helpers build new tensors,
        # so plain slices are sufficient here.
        target_j3d = target_j3d_glob[start:end]
        pred_j3d = pred_j3d_glob[start:end]

        w_j3d = first_align_joints(target_j3d, pred_j3d)
        wa_j3d = global_align_joints(target_j3d, pred_j3d)

        wa2_mpjpe.append(compute_jpe(target_j3d, w_j3d))
        waa_mpjpe.append(compute_jpe(target_j3d, wa_j3d))

    # Metrics (positions converted from meters to millimeters)
    m2mm = 1000
    wa2_mpjpe = np.concatenate(wa2_mpjpe) * m2mm
    waa_mpjpe = np.concatenate(waa_mpjpe) * m2mm

    # Additional Metrics
    rte = compute_rte(target_j3d_glob[:, 0], pred_j3d_glob[:, 0]) * 1e2
    jitter = compute_jitter(pred_j3d_glob, fps=fps)
    foot_sliding = compute_foot_sliding(target_verts_glob, pred_verts_glob) * m2mm

    global_metrics = {
        "wa2_mpjpe": wa2_mpjpe,
        "waa_mpjpe": waa_mpjpe,
        "rte": rte,
        "jitter": jitter,
        "fs": foot_sliding,
    }
    return global_metrics
@torch.no_grad()
def compute_camcoord_perjoint_metrics(batch, pelvis_idxs=[1, 2]):
    """
    Per-joint position error in camera coordinates after pelvis alignment.

    Args:
        batch (dict): {
            "pred_j3d": (..., J, 3) tensor
            "target_j3d":
            "pred_verts":   read and aligned, but not used in the result
            "target_verts": read and aligned, but not used in the result
        }
        pelvis_idxs: joint indices whose mean defines the pelvis used for alignment
    Returns:
        camcoord_perjoint_metrics (dict): {
            "mpjpe": numpy array with one error per joint, scaled by 1000 (m -> mm)
        }
    """
    # All data is in camera coordinates
    names = ("pred_j3d", "target_j3d", "pred_verts", "target_verts")
    inputs = [batch[name].cpu() for name in names]

    # Align by pelvis
    pred_j3d, target_j3d, _, _ = batch_align_by_pelvis(inputs, pelvis_idxs=pelvis_idxs)

    # Metrics
    m2mm = 1000
    return {"mpjpe": compute_perjoint_jpe(pred_j3d, target_j3d) * m2mm}
# ===== Utilities =====
def compute_jpe(S1, S2):
    """Euclidean distance between corresponding points, averaged over the second-to-last axis, as a numpy array."""
    offset = S1 - S2
    per_point = torch.sqrt((offset**2).sum(dim=-1))
    return per_point.mean(dim=-1).numpy()
def compute_perjoint_jpe(S1, S2):
    """Euclidean distance between corresponding points (no averaging), as a numpy array."""
    offset = S1 - S2
    squared = (offset**2).sum(dim=-1)
    return torch.sqrt(squared).numpy()
def batch_align_by_pelvis(data_list, pelvis_idxs=[1, 2]):
    """
    Translate joints and vertices so that the pelvis sits at the origin.

    `data_list` is [pred_j3d, target_j3d, pred_verts, target_verts], each of shape
    (frames, num_points, 3). The pelvis is the mean of the joints at `pelvis_idxs`
    (one or two indices); predictions use the predicted pelvis and targets the
    target pelvis.
    """
    pred_j3d, target_j3d, pred_verts, target_verts = data_list

    def _pelvis_of(joints):
        # (frames, 1, 3) so that it broadcasts over all points of a frame
        return joints[:, pelvis_idxs].mean(dim=1, keepdims=True).clone()

    pred_pelvis = _pelvis_of(pred_j3d)
    target_pelvis = _pelvis_of(target_j3d)

    return (
        pred_j3d - pred_pelvis,
        target_j3d - target_pelvis,
        pred_verts - pred_pelvis,
        target_verts - target_pelvis,
    )
def batch_compute_similarity_transform_torch(S1, S2):
    """
    Computes a similarity transform (sR, t) that takes each set of points in S1
    closest to the corresponding set in S2, where R is a rotation matrix,
    t a translation and s a scale, i.e. solves the orthogonal Procrustes problem.

    Args:
        S1: batched points, either (B, N, D) or (B, D, N) with D in {2, 3}.
        S2: batched points with the same layout and shape as S1.
    Returns:
        S1 after applying the optimal similarity transform, in the input layout.

    The layout is detected from the point dimension: inputs whose last dimension
    is 2 or 3 are treated as (B, N, D). The batch size plays no role in this
    decision, so batches of 2 or 3 samples are handled like any other.
    """
    # Work internally in the (B, D, N) layout.
    transposed = S1.shape[-1] in (2, 3)
    if transposed:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean.
    mu1 = S1.mean(axis=-1, keepdims=True)
    mu2 = S2.mean(axis=-1, keepdims=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = torch.sum(X1**2, dim=1).sum(dim=1)

    # 3. The outer product of X1 and X2.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K.
    U, s, V = torch.svd(K)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    # Z follows the input dtype/device so that non-float32 inputs work as well.
    Z = torch.eye(U.shape[1], device=S1.device, dtype=S1.dtype).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))

    # Construct R.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Recover scale: batched trace(R K) / var(X1).
    scale = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1) / var1
    scale = scale.unsqueeze(-1).unsqueeze(-1)

    # 6. Recover translation.
    t = mu2 - scale * R.bmm(mu1)

    # 7. Apply the transform.
    S1_hat = scale * R.bmm(S1) + t

    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)

    return S1_hat
def compute_error_accel(joints_gt, joints_pred, valid_mask=None, fps=None):
    r"""
    Use [i-1, i, i+1] to compute acc at frame_i. The acceleration at frame i is the
    second finite difference X_{i-1} - 2 X_i + X_{i+1}; the error of a frame is the
    norm of (pred - gt) acceleration, averaged over joints:
        e_i = mean_j || a^{pred}_{i,j} - a^{gt}_{i,j} ||,   i = 1 .. F-2
    When `fps` is given the error is multiplied by fps^2.
    Note that for each frame that is not visible, three entries(-1, 0, +1) in the
    acceleration error will be dropped.

    Args:
        joints_gt : (F, J, 3)
        joints_pred : (F, J, 3)
        valid_mask : (F) boolean visibility per frame, optional
        fps : optional frame rate
    Returns:
        error_accel (F-2) when valid_mask is None, else (F'), F' <= F-2
    """
    # (F, J, 3) -> (F-2) per-joint
    accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
    accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:]
    normed = np.linalg.norm(accel_pred - accel_gt, axis=-1).mean(axis=-1)
    if fps is not None:
        normed = normed * fps**2

    if valid_mask is None:
        new_vis = np.ones(len(normed), dtype=bool)
    else:
        # A window [i, i+1, i+2] is valid only if all three of its frames are visible.
        invis = np.logical_not(valid_mask)
        invis1 = np.roll(invis, -1)
        invis2 = np.roll(invis, -2)
        # The last two entries contain wrapped-around values from np.roll and are cut off.
        new_invis = np.logical_or(invis, np.logical_or(invis1, invis2))[:-2]
        new_vis = np.logical_not(new_invis)
        if new_vis.sum() == 0:
            print("Warning!!! no valid acceleration error to compute.")

    return normed[new_vis]
def compute_rte(target_trans, pred_trans):
    """Root translation error normalized by the total displacement of the target trajectory.

    The predicted trajectory is first rigidly aligned (fixed scale) to the target one.
    Returns a numpy array with one value per frame.
    """
    # Compute the global alignment
    _, rot, trans = align_pcl(target_trans[None, :], pred_trans[None, :], fixed_scale=True)
    aligned = torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :]
    pred_trans_hat = aligned[0]

    # Compute the entire displacement of ground truth trajectory
    disp = 0
    for prev_pos, next_pos in zip(target_trans[:-1], target_trans[1:]):
        disp = disp + (next_pos - prev_pos).norm(2, dim=-1)

    # Compute absolute root-translation-error (RTE)
    rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1)

    # Normalize it to the displacement
    return (rte / disp).numpy()
def compute_jitter(joints, fps=30):
    """compute jitter of the motion

    The jitter is the norm of the third finite difference of the joint positions,
    scaled by fps^3, averaged over joints and divided by 10.

    Args:
        joints (N, J, 3).
        fps (float).
    Returns:
        jitter (N-3).
    """
    third_diff = joints[3:] - 3 * joints[2:-1] + 3 * joints[1:-2] - joints[:-3]
    jerk_norm = torch.norm(third_diff * (fps**3), dim=2)
    per_frame = jerk_norm.mean(dim=-1)
    return per_frame.cpu().numpy() / 10.0
def compute_foot_sliding(target_verts, pred_verts, thr=1e-2):
    """compute foot sliding error

    A foot vertex counts as in contact between two frames when its target
    displacement is below `thr` (default 1 cm/frame). The error is the predicted
    displacement of the foot vertices at those contact entries.

    Args:
        target_verts (N, 6890, 3).
        pred_verts (N, 6890, 3).
    Returns:
        error (N frames in contact).
    """
    assert target_verts.shape == pred_verts.shape
    assert target_verts.shape[-2] == 6890

    # Foot vertices idxs
    foot_idxs = [3216, 3387, 6617, 6787]

    def _frame_to_frame_disp(verts):
        feet = verts[:, foot_idxs]
        return (feet[1:] - feet[:-1]).norm(2, dim=-1)

    # Contact label from the target motion
    contact = _frame_to_frame_disp(target_verts) < thr

    # Predicted displacement where the target foot is in contact
    sliding = _frame_to_frame_disp(pred_verts)[contact]
    return sliding.cpu().numpy()
def convert_joints22_to_24(joints22, ratio2220=0.3438, ratio2321=0.3345):
    """Extend a 22-joint skeleton to 24 joints by extrapolating two extra joints.

    Joint 22 is placed beyond joint 20 along the 18->20 direction (scaled by
    `ratio2220`), joint 23 beyond joint 21 along the 19->21 direction (scaled by
    `ratio2321`). Input is (..., 22, 3), output is (..., 24, 3).
    """
    out_shape = (*joints22.shape[:-2], 24, 3)
    joints24 = torch.zeros(*out_shape).to(joints22.device)
    joints24[..., :22, :] = joints22
    extrapolated = ((22, 20, 18, ratio2220), (23, 21, 19, ratio2321))
    for new_idx, tip_idx, base_idx, ratio in extrapolated:
        tip = joints22[..., tip_idx, :]
        joints24[..., new_idx, :] = tip + ratio * (tip - joints22[..., base_idx, :])
    return joints24
def align_pcl(Y, X, weight=None, fixed_scale=False):
    """align similarity transform to align X with Y using umeyama method
    X' = s * R * X + t is aligned with Y
    :param Y (*, N, 3) first trajectory
    :param X (*, N, 3) second trajectory
    :param weight (*, N, 1) optional weight of valid correspondences
    :param fixed_scale if True the scale is fixed to 1 and only R, t are estimated
    :returns s (*, 1), R (*, 3, 3), t (*, 3)

    All helper tensors are created on the device of `Y` and in its floating dtype,
    so CUDA and float64 inputs are supported.
    """
    *dims, N, _ = Y.shape
    device = Y.device
    # Non-floating inputs fall back to the default float dtype (the SVD needs floats).
    dtype = Y.dtype if Y.is_floating_point() else torch.get_default_dtype()

    N = torch.ones(*dims, 1, 1, device=device, dtype=dtype) * N

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)

    # subtract mean
    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)
    mx = X.sum(dim=-2) / N[..., 0]
    y0 = Y - my[..., None, :]  # (*, N, 3)
    x0 = X - mx[..., None, :]

    if weight is not None:
        y0 = y0 * weight
        x0 = x0 * weight

    # correlation
    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)

    # S flips the last axis where needed so that R is a proper rotation (det = +1).
    S = torch.eye(3, device=device, dtype=dtype).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1)
    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
    S[neg, 2, 2] = -1

    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)

    D = torch.diag_embed(D)  # (*, 3, 3)
    if fixed_scale:
        s = torch.ones(*dims, 1, device=device, dtype=dtype)
    else:
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
        s = torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(dim=-1, keepdim=True) / var[..., 0]  # (*, 1)

    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)

    return s, R, t
def global_align_joints(gt_joints, pred_joints):
    """
    Align the predicted joints to the ground truth with a single similarity
    transform estimated from all frames and joints.
    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns aligned prediction (T, J, 3)
    """
    gt_points = gt_joints.reshape(-1, 3)
    pred_points = pred_joints.reshape(-1, 3)
    s_glob, R_glob, t_glob = align_pcl(gt_points, pred_points)
    rotated = torch.einsum("ij,tnj->tni", R_glob, pred_joints)
    return s_glob * rotated + t_glob[None, None]
def first_align_joints(gt_joints, pred_joints):
    """
    Align the predicted joints to the ground truth with a similarity transform
    estimated from the first two frames only.
    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    :returns aligned prediction (T, J, 3)
    """
    gt_head = gt_joints[:2].reshape(1, -1, 3)
    pred_head = pred_joints[:2].reshape(1, -1, 3)
    # (1, 1), (1, 3, 3), (1, 3)
    s_first, R_first, t_first = align_pcl(gt_head, pred_head)
    rotated = torch.einsum("tij,tnj->tni", R_first, pred_joints)
    return s_first * rotated + t_first[:, None]
def rearrange_by_mask(x, mask):
    """
    Scatter the rows of `x` to the positions where `mask` is set; other rows are zero.
    x (L, *)
    mask (M,), M >= L
    Returns `x` itself when M == L, otherwise a new (M, *) tensor.
    """
    n_slots = mask.size(0)
    n_rows = x.size(0)
    if n_slots == n_rows:
        return x

    assert n_slots > n_rows
    assert mask.sum() == n_rows
    scattered = torch.zeros((n_slots, *x.size()[1:]), dtype=x.dtype, device=x.device)
    scattered[mask] = x
    return scattered
def as_np_array(d):
    """Convert a torch tensor (moved to CPU), a numpy array (returned as is) or any other value to a numpy array."""
    if isinstance(d, torch.Tensor):
        return d.cpu().numpy()
    if isinstance(d, np.ndarray):
        return d
    return np.array(d)
================================================
FILE: eval/GVHMR/hmr4d/utils/geo/augment_noisy_pose.py
================================================
import torch
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle, matrix_to_rotation_6d
import hmr4d.utils.matrix as matrix
from hmr4d import PROJ_ROOT
# Augmentation statistics loaded once at import time from coco_aug_dict.pth; every value is flattened to 1-D.
# The functions in this module read the keys "jittering", "pmask", "peak" and "bias".
COCO17_AUG = {k: v.flatten() for k, v in torch.load(PROJ_ROOT / "hmr4d/utils/body_model/coco_aug_dict.pth").items()}
# Lazily populated cache of CUDA copies of COCO17_AUG entries (filled by the *_cuda helpers on first use).
COCO17_AUG_CUDA = {}
# Parent lookup indexed by joint id: -1 means "no parent", a list means "two parents".
# get_visible_mask only reads the first 17 entries; the purpose of the trailing entries is not visible here.
COCO17_TREE = [[5, 6], 0, 0, 1, 2, -1, -1, 5, 6, 7, 8, -1, -1, 11, 12, 13, 14, 15, 15, 15, 16, 16, 16]
def gaussian_augment(body_pose, std_angle=10.0, to_R=True):
    """Perturb joint rotations with a random rotation of Gaussian-distributed angle.

    Args:
        body_pose torch.Tensor: (..., J, 3) axis-angle if to_R is True, else rotmat (..., J, 3, 3)
        std_angle: scalar or list, in degree
        to_R: whether `body_pose` must first be converted from axis-angle to matrices
    Returns:
        (rotation matrices, 6D rotations, axis-angles) of the perturbed pose.
    """
    body_pose = body_pose.clone()
    body_pose_R = axis_angle_to_matrix(body_pose) if to_R else body_pose  # (..., J, 3, 3)

    device = body_pose.device
    lead_shape = body_pose_R.shape[:-2]

    # --- noise angle (radians); reshape(-1) lets std_angle be a scalar or a list
    std = torch.tensor(std_angle).to(device).reshape(-1)
    angle = torch.randn(lead_shape, device=device) * std * torch.pi / 180

    # --- noise axis: uniform components in [0, 1); degenerate (near-zero) draws are replaced by ones
    axis = torch.rand((*lead_shape, 3), device=device)
    degenerate = torch.norm(axis, dim=-1) < 1e-6
    axis[degenerate] = 1
    axis = axis / torch.norm(axis, dim=-1, keepdim=True)

    noise_R = axis_angle_to_matrix(angle[..., None] * axis)  # (..., J, 3, 3)

    # --- combine pose and noise via the project helper
    # (the original kept `torch.matmul(noise_R, body_pose_R)` as a commented-out alternative)
    noisy_R = matrix.get_mat_BfromA(body_pose_R, noise_R)
    noisy_r6d = matrix_to_rotation_6d(noisy_R)  # (..., J, 6)
    noisy_aa = matrix_to_axis_angle(noisy_R)  # (..., J, 3)
    return noisy_R, noisy_r6d, noisy_aa
# ========= Augment Joint 3D ======== #
def get_jitter(shape=(8, 120), s_jittering=5e-2):
    """Gaussian jitter modeling.

    Draws zero-mean normal noise of shape (*shape, 17, 3) whose per-joint standard
    deviation comes from COCO17_AUG["jittering"], then scales it by `s_jittering`.
    """
    zero_mean = torch.zeros((*shape, 17, 3))
    per_joint_std = COCO17_AUG["jittering"].reshape(1, 1, 17, 1).expand(*shape, -1, 3)
    return torch.normal(mean=zero_mean, std=per_joint_std) * s_jittering
def get_jitter_cuda(shape=(8, 120), s_jittering=5e-2):
    """CUDA variant of the jitter noise: (*shape, 17, 3) normal noise scaled per joint.

    The per-joint scale is copied to the GPU once and cached in COCO17_AUG_CUDA.
    """
    jittering = COCO17_AUG_CUDA.get("jittering")
    if jittering is None:
        jittering = COCO17_AUG["jittering"].cuda().reshape(1, 1, 17, 1)
        COCO17_AUG_CUDA["jittering"] = jittering
    return torch.randn((*shape, 17, 3), device="cuda") * jittering * s_jittering
def get_lfhp(shape=(8, 120), s_peak=3e-1, s_peak_mask=5e-3):
    """Low-frequency high-peak noise modeling.

    A joint/frame is selected when a uniform draw scaled by COCO17_AUG["pmask"]
    falls below `s_peak_mask`. Selected entries receive one shared random 3-vector
    scaled by the per-joint COCO17_AUG["peak"] and `s_peak`; all others are zero.
    Returns a tensor of shape (*shape, 17, 3).
    """
    # Draw order (uniform mask first, then the normal vector) is kept for RNG reproducibility.
    selected = (torch.rand(*shape, 17) * COCO17_AUG["pmask"]) < s_peak_mask  # (B, L, 17)
    indicator = selected.float().unsqueeze(-1).repeat(1, 1, 1, 3)  # (B, L, 17, 3)
    shared_vec = torch.randn(3)
    per_joint_peak = COCO17_AUG["peak"].reshape(17, 1)
    return indicator * shared_vec * per_joint_peak * s_peak
def get_lfhp_cuda(shape=(8, 120), s_peak=3e-1, s_peak_mask=5e-3):
    """CUDA variant of the low-frequency high-peak noise, shape (*shape, 17, 3).

    "pmask" and "peak" are copied to the GPU together on first use and cached in
    COCO17_AUG_CUDA.
    """
    if "peak" not in COCO17_AUG_CUDA:
        COCO17_AUG_CUDA.update(
            pmask=COCO17_AUG["pmask"].cuda(),
            peak=COCO17_AUG["peak"].cuda().reshape(17, 1),
        )
    pmask, peak = COCO17_AUG_CUDA["pmask"], COCO17_AUG_CUDA["peak"]

    selected = torch.rand(*shape, 17, device="cuda") * pmask < s_peak_mask
    indicator = selected.float().unsqueeze(-1).expand(-1, -1, -1, 3)
    shared_vec = torch.randn(3, device="cuda")
    return indicator * shared_vec * peak * s_peak
def get_bias(shape=(8, 120), s_bias=1e-1):
    """Bias noise modeling.

    One (17, 3) offset is drawn per batch element (per-joint std from
    COCO17_AUG["bias"], scaled by `s_bias`) and shared by all frames, so the whole
    sequence is moved by the same bias. Returns an expanded (B, L, 17, 3) view.
    """
    num_seq, num_frames = shape
    per_joint_std = COCO17_AUG["bias"].reshape(1, 17, 1)
    per_seq_bias = torch.normal(mean=torch.zeros((num_seq, 17, 3)), std=per_joint_std) * s_bias
    return per_seq_bias[:, None].expand(-1, num_frames, -1, -1)
def get_bias_cuda(shape=(8, 120), s_bias=1e-1):
    """CUDA variant of the bias noise: one (17, 3) offset per sequence, expanded over frames.

    The per-joint scale is copied to the GPU once and cached in COCO17_AUG_CUDA.
    """
    bias = COCO17_AUG_CUDA.get("bias")
    if bias is None:
        bias = COCO17_AUG["bias"].cuda().reshape(1, 17, 1)
        COCO17_AUG_CUDA["bias"] = bias

    num_seq, num_frames = shape[0], shape[1]
    per_seq_bias = torch.randn((num_seq, 17, 3), device="cuda") * bias * s_bias
    return per_seq_bias[:, None].expand(-1, num_frames, -1, -1)
def get_wham_aug_kp3d(shape=(8, 120)):
    """Sum of the three CUDA noise components: bias + low-frequency peaks + jitter."""
    # Generated in this fixed order so the random stream is consumed the same way every call.
    bias_noise = get_bias_cuda(shape)
    peak_noise = get_lfhp_cuda(shape)
    jitter_noise = get_jitter_cuda(shape)
    return bias_noise + peak_noise + jitter_noise
def get_visible_mask(shape=(8, 120), s_mask=0.03):
    """Mask modeling.

    Every (frame, joint) is dropped independently with probability `s_mask`; a joint
    is additionally hidden whenever its parent(s) in COCO17_TREE are hidden.
    Returns a bool tensor (*shape, 17) where True means visible.
    """
    dropped = torch.rand(*shape, 17) < s_mask
    visible = (~dropped).reshape(-1, 17)  # (BL, 17)

    # Joints are visited in index order, so a parent's final state is already
    # known whenever it has a smaller index than its child.
    for child in range(17):
        parent = COCO17_TREE[child]
        if isinstance(parent, list):
            visible[:, child] &= visible[:, parent[0]] & visible[:, parent[1]]
        elif parent != -1:
            visible[:, child] &= visible[:, parent]

    return visible.reshape(*shape, 17).clone()  # (B, L, J)
def get_invisible_legs_mask(shape, s_mask=0.03):
    """
    Both legs are invisible for a random duration.

    Args:
        shape: (B, L); L must exceed 90 because window starts are drawn from [0, L - 90).
        s_mask: probability that a sequence gets an occlusion window at all.
    Returns:
        Bool tensor (B, L, 17). True marks joints 11..16 inside the window of a
        selected sequence; joints 0..10 are always False.
    """
    B, L = shape
    window_start = torch.randint(0, L - 90, (B,))
    window_end = window_start + torch.randint(30, 90, (B,))  # window length in [30, 90)

    frame_idx = torch.arange(L)[None].expand(B, -1)
    in_window = (frame_idx >= window_start[:, None]) & (frame_idx < window_end[:, None])

    legs_hidden = in_window[:, :, None].expand(-1, -1, 17).clone()
    legs_hidden[:, :, :11] = False  # only both legs are invisible

    sequence_selected = torch.rand(B, 1, 1) < s_mask
    return legs_hidden & sequence_selected
def randomly_occlude_lower_half(i_x2d, s_mask=0.03):
    """
    Randomly occlude the lower half of the image.

    Not implemented: the function always raises. The previous body kept an
    unreachable draft after the `raise`; it has been removed.

    Args:
        i_x2d: (B, L, N, 2) image-space points (shape taken from the removed draft).
        s_mask: intended per-sequence probability of applying the occlusion.
    Raises:
        NotImplementedError: always.
    """
    # TODO: the draft picked, per sequence, a window of 30..89 frames starting in
    # [0, L - 90), and intended to mark points in the lower half of the image
    # (based on i_x2d[..., 1]) as occluded, gated by `torch.rand(B, 1, 1) < s_mask`.
    raise NotImplementedError
def randomly_modify_hands_legs(
    j3d,
    p_switch_hand=0.001,
    p_switch_leg=0.001,
    p_wrong_hand0=0.001,
    p_wrong_hand1=0.001,
    p_wrong_leg0=0.001,
    p_wrong_leg1=0.001,
):
    """Randomly swap or duplicate hand / leg joints on individual frames.

    Six corruptions are applied in sequence, each on an independently drawn set of
    frames: swap joints 9/10, swap joints 15/16, copy 10 -> 9, copy 9 -> 10,
    copy 16 -> 15, copy 15 -> 16.

    The earlier implementation assigned through `j3d[mask][:, idx] = ...`. Boolean
    indexing returns a copy, so those writes were discarded and the function was a
    no-op. The selected frames are now modified and written back through the mask.

    Args:
        j3d: (B, L, J, 3) joints; modified in place.
        p_switch_hand, p_switch_leg: per-frame probability of swapping the pair.
        p_wrong_hand0, p_wrong_hand1, p_wrong_leg0, p_wrong_leg1: per-frame
            probability of overwriting one joint with its counterpart.
    Returns:
        The same tensor object `j3d`.
    """
    hands = [9, 10]
    legs = [15, 16]
    B, L, J, _ = j3d.shape

    # (probability, destination joints, source joints)
    corruptions = [
        (p_switch_hand, hands, hands[::-1]),
        (p_switch_leg, legs, legs[::-1]),
        (p_wrong_hand0, [9], [10]),
        (p_wrong_hand1, [10], [9]),
        (p_wrong_leg0, [15], [16]),
        (p_wrong_leg1, [16], [15]),
    ]
    for prob, dst, src in corruptions:
        mask = torch.rand(B, L, device=j3d.device) < prob
        frames = j3d[mask]  # (K, J, 3) copy of the selected frames
        # The right-hand side is gathered into a new tensor first, so swaps are safe.
        frames[:, dst] = frames[:, src]
        j3d[mask] = frames  # write the modified frames back into the input
    return j3d
================================================
FILE: eval/GVHMR/hmr4d/utils/geo/flip_utils.py
================================================
import torch
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle
def flip_heatmap_coco17(output_flipped):
    """Undo a horizontal flip on COCO-17 heatmaps.

    Left/right channel pairs (1,2), (3,4), ..., (15,16) are exchanged and the
    result is mirrored along the width axis.

    Args:
        output_flipped: (B, J, H, W) heatmaps predicted on the flipped image.
    Returns:
        (B, J, H, W) heatmaps in the original image frame (a new tensor).
    """
    assert output_flipped.ndim == 4, "output_flipped should be [B, J, H, W]"

    # Build a channel permutation that swaps each left/right pair.
    channel_order = list(range(output_flipped.shape[1]))
    for left in range(1, 17, 2):
        right = left + 1
        channel_order[left], channel_order[right] = channel_order[right], channel_order[left]

    swapped = output_flipped[:, channel_order]
    return swapped.flip(3)  # mirror horizontally
def flip_bbx_xys(bbx_xys, w):
    """Mirror bounding-box centers horizontally for an image of width `w`.

    Args:
        bbx_xys: (F, 3); column 0 is replaced by `w - column 0`, other columns are kept.
        w: image width.
    Returns:
        A new (F, 3) tensor; the input is left untouched.
    """
    mirrored = bbx_xys.clone()
    mirrored.select(1, 0).copy_(w - bbx_xys[:, 0])
    return mirrored
def flip_kp2d_coco17(kp2d, w):
    """Flip keypoints.

    Exchanges the left/right COCO-17 joints along the second-to-last axis and
    mirrors the first coordinate as `w - x`. Returns a new tensor.
    """
    left_right_swapped = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
    # Indexing with a list already yields a fresh tensor, so the input is not modified.
    flipped = kp2d[..., left_right_swapped, :]
    flipped[..., 0] = w - flipped[..., 0]
    return flipped
def flip_smplx_params(smplx_params):
    """Flip pose.

    The flipping is based on SMPLX parameters: "global_orient" and "body_pose" are
    concatenated along dim 1, left/right joints are exchanged for the first 22
    joints, and the 2nd and 3rd axis-angle components are negated.

    Returns:
        A shallow copy of `smplx_params` with "global_orient" (B, 1, 3) and
        "body_pose" (B, 21, 3) replaced by their flipped versions.
    """
    joint_perm = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20]  # , 23, 22]
    # Expand the joint permutation to the flattened (x, y, z) columns.
    column_perm = [3 * joint + axis for joint in joint_perm for axis in range(3)]

    rotation = torch.cat([smplx_params["global_orient"], smplx_params["body_pose"]], dim=1)
    batch = rotation.shape[0]
    flipped = rotation.reshape(batch, -1)[:, column_perm].reshape(batch, -1, 3)

    # we also negate the second and the third dimension of the axis-angle
    flipped[..., 1:] = -flipped[..., 1:]

    result = smplx_params.copy()
    result["global_orient"] = flipped[:, :1]
    result["body_pose"] = flipped[:, 1:]
    return result
def avg_smplx_aa(aa1, aa2):
    """Average two sets of axis-angle rotations joint by joint.

    Both inputs are converted to rotation matrices, averaged element-wise, and
    the mean is mapped back to an orthogonal matrix through its SVD (U @ V^T).

    Args:
        aa1, aa2: (B, J*3) flattened axis-angle parameters.
    Returns:
        (B, J*3) averaged axis-angle parameters.
    """
    B, J3 = aa1.shape
    rotmats = torch.stack([axis_angle_to_matrix(aa.reshape(B, -1, 3)) for aa in (aa1, aa2)])

    mean_rot = rotmats.mean(dim=0)  # [2, B, J, 3, 3] --> [B, J, 3, 3]
    U, _, V = torch.svd(mean_rot)
    R_avg = U @ V.transpose(-1, -2)

    return matrix_to_axis_angle(R_avg).reshape(B, -1)
================================================
FILE: eval/GVHMR/hmr4d/utils/geo/hmr_cam.py
================================================
import torch
import numpy as np
from hmr4d.utils.geo_transform import project_p2d, convert_bbx_xys_to_lurb, cvt_to_bi01_p2d
def estimate_focal_length(img_w, img_h):
    """Heuristic focal length in pixels: the length of the image diagonal.

    Diagonal FOV = 2*arctan(0.5) * 180/pi = 53
    """
    diag_sq = img_w**2 + img_h**2
    return diag_sq**0.5
def estimate_K(img_w, img_h):
    """Build a 3x3 float intrinsic matrix from the image size alone.

    The focal length comes from `estimate_focal_length` and the principal point is
    the image center.
    """
    focal = estimate_focal_length(img_w, img_h)
    K = torch.eye(3).float()
    K[0, 0] = K[1, 1] = focal
    K[0, 2], K[1, 2] = img_w / 2.0, img_h / 2.0
    return K
def convert_K_to_K4(K):
    """Pack an intrinsic matrix into the float vector [fx, fy, cx, cy]."""
    rows = [0, 1, 0, 1]
    cols = [0, 1, 2, 2]
    return K[rows, cols].float()
def convert_f_to_K(focal_length, img_w, img_h):
    """Build a 3x3 float intrinsic matrix from a focal length and the image size.

    fx = fy = focal_length and the principal point is the image center.
    """
    K = torch.eye(3).float()
    K[0, 0] = K[1, 1] = focal_length
    K[0, 2], K[1, 2] = img_w / 2.0, img_h / 2.0
    return K
def resize_K(K, f=0.5):
    """Scale intrinsics by the factor `f`, keeping the homogeneous entry K[..., 2, 2] at 1.

    Returns a new tensor; the input is not modified.
    """
    scaled = K.clone() * f
    scaled[..., 2, 2] = 1.0
    return scaled
def create_camera_sensor(width=None, height=None, f_fullframe=None):
    """Create (or randomly sample) an image size and matching intrinsics.

    Args:
        width, height: image size in pixels; if either is None a 4:3 size is
            sampled (portrait 1200x1600 or landscape 1600x1200, equal odds).
        f_fullframe: full-frame-equivalent focal length; sampled from a fixed
            list of common values when None.
    Returns:
        (width, height, K_fullimg) with K_fullimg a 3x3 tensor.
    """
    size_missing = width is None or height is None
    if size_missing:
        # The 4:3 aspect ratio is widely adopted by image sensors in mobile phones.
        portrait = np.random.rand() < 0.5
        width, height = (1200, 1600) if portrait else (1600, 1200)

    # Sample FOV from common options:
    # 1. wide-angle lenses are common in mobile phones,
    # 2. telephoto lenses have less perspective effect, which should make them easier to learn
    if f_fullframe is None:
        f_fullframe = np.random.choice([24, 26, 28, 30, 35, 40, 50, 60, 70])

    # We use diag to map focal-length: https://www.nikonians.org/reviews/fov-tables
    diag_fullframe = (24**2 + 36**2) ** 0.5
    diag_img = (width**2 + height**2) ** 0.5
    focal_length = diag_img / diag_fullframe * f_fullframe

    K_fullimg = torch.eye(3)
    K_fullimg[0, 0] = K_fullimg[1, 1] = focal_length
    K_fullimg[0, 2], K_fullimg[1, 2] = width / 2, height / 2
    return width, height, K_fullimg
# ====== Compute cliffcam ===== #
def convert_xys_to_cliff_cam_wham(xys, res):
    """Normalize box centers and sizes by the longer image side.

    Args:
        xys: (N, 3) in pixel. Note s should not be touched by 200
        res: (2), e.g. [4112., 3008.] (w,h)
    Returns:
        cliff_cam: (N, 3), normalized representation:
        centers become `2 * xy / max(res) - res / max(res)` and the size becomes
        `s / max(res)`.
    """
    centers_px = xys[:, :2]  # (N, 2)

    res_dev = res.to(centers_px.device)
    longest_side = res_dev.max(-1)[0].reshape(-1)
    offset = torch.stack(
        [res_dev[..., 0] / longest_side, res_dev[..., 1] / longest_side], dim=-1
    ).to(centers_px.device)

    # Reshape the divisor / offset so they broadcast over the leading axis of the centers.
    n_trailing = centers_px.dim() - 1
    divisor = longest_side.reshape(*([1] * n_trailing))
    shift = offset.reshape(*([1] * (n_trailing - 1)), -1)
    centers = 2 * centers_px / divisor - shift  # (N, 2)

    box_scale = xys[:, 2:] / res.max()
    return torch.cat((centers, box_scale), dim=-1)
def compute_bbox_info_bedlam(bbx_xys, K_fullimg):
    """impl as in BEDLAM

    The box center is expressed relative to the principal point, and both the
    offset and the box size are divided by the focal length K[..., 0, 0].

    Args:
        bbx_xys: ((B), N, 3), in pixel space described by K_fullimg
        K_fullimg: ((B), (N), 3, 3)
    Returns:
        bbox_info: ((B), N, 3)
    """
    box_cx, box_cy, box_size = bbx_xys[..., 0], bbx_xys[..., 1], bbx_xys[..., 2]
    principal_x = K_fullimg[..., 0, 2]
    principal_y = K_fullimg[..., 1, 2]
    focal = K_fullimg[..., 0, 0].unsqueeze(-1)

    unnormalized = torch.stack([box_cx - principal_x, box_cy - principal_y, box_size], dim=-1)
    return unnormalized / focal
# ====== Convert Prediction to Cam-t ===== #
def compute_transl_full_cam(pred_cam, bbx_xys, K_fullimg):
    """Convert a predicted (s, tx, ty) crop camera into a translation in the full-image camera.

    Args:
        pred_cam: (..., 3) as [s, tx, ty]
        bbx_xys: (..., 3) as [center_x, center_y, size]
        K_fullimg: (..., 3, 3)
    Returns:
        cam_t: (..., 3) as [tx + cx, ty + cy, tz], with tz = 2 * f / (s * size).
    """
    scale, tx, ty = pred_cam[..., 0], pred_cam[..., 1], pred_cam[..., 2]
    focal = K_fullimg[..., 0, 0]
    principal_x = K_fullimg[..., 0, 2]
    principal_y = K_fullimg[..., 1, 2]

    # Small epsilon guards against division by zero when s * size vanishes.
    denom = scale * bbx_xys[..., 2] + 1e-9
    shift_x = 2 * (bbx_xys[..., 0] - principal_x) / denom
    shift_y = 2 * (bbx_xys[..., 1] - principal_y) / denom
    depth = 2 * focal / denom

    return torch.stack([tx + shift_x, ty + shift_y, depth], dim=-1)
def get_a_pred_cam(transl, bbx_xys, K_fullimg):
    """Inverse operation of compute_transl_full_cam

    Args:
        transl: (*, L, 3) translation in the full-image camera
        bbx_xys: (*, L, 3) as [center_x, center_y, size]
        K_fullimg: (*, L, 3, 3)
    Returns:
        (*, L, 3) as [s, tx, ty]
    """
    assert transl.ndim == bbx_xys.ndim  # (*, L, 3)
    assert K_fullimg.ndim == (bbx_xys.ndim + 1)  # (*, L, 3, 3)

    focal = K_fullimg[..., 0, 0]
    principal_x = K_fullimg[..., 0, 2]
    principal_y = K_fullimg[..., 1, 2]
    tx, ty, tz = transl[..., 0], transl[..., 1], transl[..., 2]

    scale = 2 * focal / (tz * bbx_xys[..., 2])  # (B, L)
    crop_tx = tx - tz / focal * (bbx_xys[..., 0] - principal_x)
    crop_ty = ty - tz / focal * (bbx_xys[..., 1] - principal_y)
    return torch.stack([scale, crop_tx, crop_ty], dim=-1)
# ====== 3D to 2D ===== #
def project_to_bi01(points, bbx_xys, K_fullimg):
    """Project 3D points to the image and convert them with `cvt_to_bi01_p2d` relative to the box.

    points: (B, L, J, 3)
    bbx_xys: (B, L, 3)
    K_fullimg: (B, L, 3, 3)
    """
    # The original keeps `project_p2d(points, K_fullimg)` as a commented-out alternative.
    pixels = perspective_projection(points, K_fullimg)
    box_lurb = convert_bbx_xys_to_lurb(bbx_xys)
    return cvt_to_bi01_p2d(pixels, box_lurb)
def perspective_projection(points, K):
    """Pinhole projection of 3D points with intrinsics `K`.

    Args:
        points: (B, L, J, 3)
        K: (B, L, 3, 3)
    Returns:
        (B, L, J, 2) pixel coordinates. Points are divided by their own depth
        without any check, so a zero z yields inf/nan.
    """
    depth = points[..., -1].unsqueeze(-1)
    normalized = (points / depth).float()
    homogeneous = torch.einsum("...ij,...kj->...ki", K, normalized)
    return homogeneous[..., :-1]
# ====== 2D (bbx from j2d) ===== #
def normalize_kp2d(obs_kp2d, bbx_xys, clamp_scale_min=False):
    """Normalize 2D keypoints relative to their bounding box.

    Keypoints strictly outside the box get confidence 0. Coordinates are mapped
    with `2 * (xy - center) / size`, so the box spans [-1, 1].

    Args:
        obs_kp2d: (B, L, J, 3) [x, y, c]
        bbx_xys: (B, L, 3)
        clamp_scale_min: clamp the box size to at least 1e-5 before dividing
    Returns:
        obs: (B, L, J, 3) [x, y, c]
    """
    xy = obs_kp2d[..., :2]  # (B, L, J, 2)
    conf = obs_kp2d[..., 2]  # (B, L, J)
    center = bbx_xys[..., :2]
    size = bbx_xys[..., [2]]

    # Mark keypoints outside the bounding box as invisible
    half = size / 2
    lower = (center - half).unsqueeze(-2)  # (B, L, 1, 2)
    upper = (center + half).unsqueeze(-2)
    outside = ((xy < lower) | (xy > upper)).any(dim=-1)  # (B, L, J)
    conf = conf * ~outside

    # The visibility test above intentionally uses the unclamped size.
    if clamp_scale_min:
        size = size.clamp(min=1e-5)

    normalized_xy = 2 * (xy - center.unsqueeze(-2)) / size.unsqueeze(-2)
    return torch.cat([normalized_xy, conf[..., None]], dim=-1)
def get_bbx_xys(i_j2d, bbx_ratio=[192, 256], do_augment=False, base_enlarge=1.2):
    """Args: (B, L, J, 3) [x,y,c] -> Returns: (B, L, 3)

    Computes a square box [center_x, center_y, size] around the keypoints of each
    frame. The tight extent is first fitted to the `bbx_ratio` (w:h) aspect ratio,
    then its longer side is enlarged by `base_enlarge`. With `do_augment`, the size
    is instead the raw size times a random factor in [1.05, 1.35) and the center is
    shifted randomly, independently per frame.
    """
    xs, ys = i_j2d[..., 0], i_j2d[..., 1]
    min_x, max_x = xs.min(-1)[0], xs.max(-1)[0]
    min_y, max_y = ys.min(-1)[0], ys.max(-1)[0]

    # Center
    center_x = (min_x + max_x) / 2
    center_y = (min_y + max_y) / 2

    # Size
    h = max_y - min_y  # (B, L)
    w = max_x - min_x  # (B, L)

    # fit w and h into aspect-ratio: grow the short side; the second test sees the updated h
    aspect_ratio = bbx_ratio[0] / bbx_ratio[1]
    too_wide = w > aspect_ratio * h
    h[too_wide] = w[too_wide] / aspect_ratio
    too_tall = w < aspect_ratio * h
    w[too_tall] = h[too_tall] * aspect_ratio

    # apply a common factor to enlarge the bounding box
    bbx_size = torch.max(h, w) * base_enlarge

    if do_augment:
        B, L = bbx_size.shape[:2]
        device = bbx_size.device
        # Drawn in this order: scale, x-shift, y-shift.
        scaleFactor = torch.rand((B, L), device=device) * 0.3 + 1.05  # 1.05~1.35
        txFactor = torch.rand((B, L), device=device) * 1.6 - 0.8  # -0.8~0.8
        tyFactor = torch.rand((B, L), device=device) * 1.6 - 0.8  # -0.8~0.8

        raw_bbx_size = bbx_size / base_enlarge
        bbx_size = raw_bbx_size * scaleFactor
        half_raw = raw_bbx_size / 2
        center_x += half_raw * ((scaleFactor - 1) * txFactor)
        center_y += half_raw * ((scaleFactor - 1) * tyFactor)

    return torch.stack([center_x, center_y, bbx_size], dim=-1)
def safely_render_x3d_K(x3d, K_fullimg, thr):
    """Project 3D points after clamping depths that are too small.

    Any point whose z is below `thr` has its z set to `thr` before projection, so
    the perspective division never sees a depth below the threshold.

    NOTE(review): the first statements of this body were garbled in the source.
    The mask `x3d[..., 2] < thr` and the defensive clone are reconstructed from
    the surviving lines — confirm against the upstream file.

    Args:
        x3d: (B, L, V, 3), should at least have one safe point (not examined here)
        K_fullimg: (B, L, 3, 3)
        thr: minimum allowed depth
    Returns:
        i_x2d: (B, L, V, 2)
    """
    # Work on a copy so the caller's points are not modified.
    x3d = x3d.clone()  # (B, L, V, 3)

    # For each frame, update unsafe z (< thr) to thr
    x3d_unsafe_mask = x3d[..., 2] < thr  # (B, L, V)
    if x3d_unsafe_mask.any():
        x3d[..., 2][x3d_unsafe_mask] = thr

    # render
    i_x2d = perspective_projection(x3d, K_fullimg)  # (B, L, V, 2)
    return i_x2d
def get_bbx_xys_from_xyxy(bbx_xyxy, base_enlarge=1.2):
    """Convert corner-format boxes to [center_x, center_y, size] via `get_bbx_xys`.

    Args:
        bbx_xyxy: (N, 4) [x1, y1, x2, y2]
    Returns:
        bbx_xys: (N, 3) [center_x, center_y, size]
    """
    # Treat the two corners of each box as a 2-point "keypoint set": (N, 2, 2).
    top_left = bbx_xyxy[:, [0, 1]]
    bottom_right = bbx_xyxy[:, [2, 3]]
    corners = torch.stack([top_left, bottom_right], dim=1)

    # get_bbx_xys expects a leading batch axis; add it and strip it again.
    return get_bbx_xys(corners[None], base_enlarge=base_enlarge)[0]
def bbx_xyxy_from_x(p2d):
    """Tight axis-aligned bounding box of a set of 2D points.

    Args:
        p2d: (*, V, 2) - Tensor containing 2D points.
    Returns:
        bbx_xyxy: (*, 4) - Bounding box coordinates in the format (xmin, ymin, xmax, ymax).
    """
    lower, _ = torch.min(p2d, dim=-2)  # (*, 2)
    upper, _ = torch.max(p2d, dim=-2)  # (*, 2)
    return torch.cat([lower, upper], dim=-1)  # (*, 4)
def bbx_xyxy_from_masked_x(p2d, mask):
    """Tight bounding box over the points selected by `mask`.

    Args:
        p2d: (*, V, 2) - Tensor containing 2D points.
        mask: (*, V) - Boolean tensor indicating valid points.
    Returns:
        bbx_xyxy: (*, 4) - Bounding box coordinates in the format (xmin, ymin, xmax, ymax).
        A group with no valid point yields (inf, inf, -inf, -inf).
    """
    assert p2d.shape[:-1] == mask.shape, "The shape of p2d and mask are not compatible."

    # Collapse all leading axes so the reduction runs over dim 1.
    num_points, num_coords = p2d.shape[-2], p2d.shape[-1]
    points = p2d.view(-1, num_points, num_coords)
    invalid = ~mask.view(-1, mask.shape[-1]).unsqueeze(-1)

    # Invalid points are pushed to +inf / -inf so they can never win the min / max.
    lower = points.masked_fill(invalid, float("inf")).min(dim=1).values  # (BL, 2)
    upper = points.masked_fill(invalid, float("-inf")).max(dim=1).values  # (BL, 2)

    boxes = torch.cat([lower, upper], dim=-1)  # (BL, 4)
    return boxes.view(*p2d.shape[:-2], 4)
def bbx_xyxy_ratio(xyxy1, xyxy2):
    """Designed for fov/unbounded

    Args:
        xyxy1: (*, 4)
        xyxy2: (*, 4)
    Return:
        ratio: (*), area(xyxy1) / area(xyxy2). A non-finite area of xyxy1 (e.g. from
        an empty masked box) counts as 0; every area of xyxy2 must be positive.
    """

    def box_area(xyxy):
        width = xyxy[..., 2] - xyxy[..., 0]
        height = xyxy[..., 3] - xyxy[..., 1]
        return width * height

    area1 = box_area(xyxy1)
    area2 = box_area(xyxy2)

    # Check
    area1 = torch.where(torch.isfinite(area1), area1, torch.zeros_like(area1))  # replace inf in area1 with 0
    assert (area2 > 0).all(), "area2 should be positive"
    return area1 / area2
def get_mesh_in_fov_category(mask):
    """mask: (L, V)

    The definition:
    1. FullyVisible: The mesh in every frame is entirely within the field of view (FOV).
    2. PartiallyVisible: In some frames, parts of the mesh are outside the FOV, while other parts are within the FOV.
    3. PartiallyOut: In some frames, the mesh is completely outside the FOV, while in others, it is visible.
    4. FullyOut: The mesh is completely outside the FOV in every frame.

    Returns:
        (class id in {1, 2, 3, 4}, bool tensor (L,) that is True for frames with at
        least one vertex in the FOV)
    """
    mask = mask.clone().cpu()
    frame_has_visible_vert = mask.any(1)

    fully_visible = bool(mask.all())  # class 1
    partially_visible = bool(frame_has_visible_vert.all()) and not fully_visible  # class 2
    fully_out = not bool(mask.any())  # class 4
    partially_out = not (fully_visible or partially_visible or fully_out)  # class 3

    flags = (fully_visible, partially_visible, partially_out, fully_out)
    assert sum(int(flag) for flag in flags) == 1

    class_type = sum(class_id * int(flag) for class_id, flag in enumerate(flags, start=1))
    return class_type, frame_has_visible_vert
def get_infov_mask(p2d, w_real, h_real):
    """Test which 2D points fall inside the image rectangle [0, w) x [0, h).

    Args:
        p2d: (B, L, V, 2)
        w_real, h_real: (B, L) or int
    Returns:
        mask: (B, L, V)
    """
    px, py = p2d[..., 0], p2d[..., 1]
    if not isinstance(w_real, int):
        # Per-frame sizes: add a trailing axis so they broadcast over the V points.
        w_real, h_real = w_real[..., None], h_real[..., None]

    inside_x = (px >= 0) & (px < w_real)
    inside_y = (py >= 0) & (py < h_real)
    return inside_x & inside_y
================================================
FILE: eval/GVHMR/hmr4d/utils/geo/hmr_global.py
================================================
import torch
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle, matrix_to_quaternion, quaternion_to_matrix
import hmr4d.utils.matrix as matrix
from hmr4d.utils.net_utils import gaussian_smooth
def get_R_c2gv(R_w2c, axis_gravity_in_w=[0, 0, -1]):
    """Rotation from the camera frame to a gravity-aligned view frame.

    The gv-frame y-axis is the world gravity direction expressed in the camera
    frame, its x-axis is `y x camera-z`, and its z-axis is `x x y`. When gravity is
    (anti-)parallel to the camera z-axis the cross product vanishes and the camera
    x-axis is used instead.

    All helper tensors are created with the dtype and device of `R_w2c`, so
    float64 or GPU inputs work (they were previously forced to CPU float32).

    Args:
        R_w2c: (*, 3, 3)
        axis_gravity_in_w: gravity direction in world coordinates (list or tensor of 3)
    Returns:
        R_c2gv: (*, 3, 3)
    """
    # gravity direction in world coord, matched to R_w2c's dtype/device
    axis_gravity_in_w = torch.as_tensor(axis_gravity_in_w).to(R_w2c)
    axis_z_in_c = torch.tensor([0.0, 0.0, 1.0]).to(R_w2c)
    fallback_x = torch.tensor([1.0, 0.0, 0.0]).to(R_w2c)

    # get gv-coord axes in c-coord
    axis_y_of_gv = R_w2c @ axis_gravity_in_w  # (*, 3)
    axis_x_of_gv = axis_y_of_gv.cross(axis_z_in_c.expand_as(axis_y_of_gv), dim=-1)

    # normalize
    axis_x_of_gv_norm = axis_x_of_gv.norm(dim=-1, keepdim=True)
    axis_x_of_gv = axis_x_of_gv / (axis_x_of_gv_norm + 1e-5)
    # degenerate case: use cam x-axis as axis_x_of_gv
    axis_x_of_gv = torch.where(axis_x_of_gv_norm < 1e-5, fallback_x, axis_x_of_gv)

    axis_z_of_gv = axis_x_of_gv.cross(axis_y_of_gv, dim=-1)

    R_gv2c = torch.stack([axis_x_of_gv, axis_y_of_gv, axis_z_of_gv], dim=-1)  # (*, 3, 3)
    R_c2gv = R_gv2c.transpose(-1, -2)  # (*, 3, 3)
    return R_c2gv
# Axis-angle vectors (radians) of fixed rotations, keyed by transform name.
# get_tgtcoord_rootparam looks these up with its `tsf` argument and converts the
# vector to a rotation matrix.
# NOTE(review): the key naming ("ay", "any", "az") presumably encodes
# source->target gravity-axis conventions — confirm with the callers.
tsf_axisangle = {
    "ay->ay": [0, 0, 0],
    "any->ay": [0, 0, torch.pi],
    "az->ay": [-torch.pi / 2, 0, 0],
    "ay->any": [0, 0, torch.pi],
}
def get_tgtcoord_rootparam(global_orient, transl, gravity_vec=None, tgt_gravity_vec=None, tsf="ay->ay"):
    """Rotate around the origin center, to match the new gravity direction

    Only the named transforms in `tsf_axisangle` are supported; passing explicit
    gravity vectors is not implemented.

    Args:
        global_orient: torch.tensor, (*, 3)
        transl: torch.tensor, (*, 3)
        gravity_vec: must be None (not implemented)
        tgt_gravity_vec: must be None (not implemented)
        tsf: key into `tsf_axisangle`
    Returns:
        tgt_global_orient: torch.tensor, (*, 3)
        tgt_transl: torch.tensor, (*, 3)
        R_g2tg: (3, 3)
    Raises:
        NotImplementedError: if `gravity_vec` or `tgt_gravity_vec` is given.
        KeyError: if `tsf` is not a known transform name.
    """
    if gravity_vec is not None or tgt_gravity_vec is not None:
        # TODO: Impl rotation between two arbitrary gravity vectors
        raise NotImplementedError

    # get rotation matrix
    # Build the axis-angle in a floating dtype: an all-integer entry such as
    # [0, 0, 0] would otherwise become an int64 tensor.
    dtype = global_orient.dtype if global_orient.is_floating_point() else torch.float32
    aa = torch.tensor(tsf_axisangle[tsf], dtype=dtype, device=global_orient.device)
    R_g2tg = axis_angle_to_matrix(aa)  # (3, 3)

    # rotate
    global_orient_R = axis_angle_to_matrix(global_orient)  # (*, 3, 3)
    tgt_global_orient = matrix_to_axis_angle(R_g2tg @ global_orient_R)  # (*, 3)
    tgt_transl = torch.einsum("...ij,...j->...i", R_g2tg, transl)

    return tgt_global_orient, tgt_transl, R_g2tg
def get_c_rootparam(global_orient, transl, T_w2c, offset):
    """Express world-frame root parameters in the camera frame.

    Args:
        global_orient: torch.tensor, (F, 3)
        transl: torch.tensor, (F, 3)
        T_w2c: torch.tensor, (*, 4, 4)
        offset: torch.tensor, (3,)
    Returns:
        R_c: torch.tensor, (F, 3)
        t_c: torch.tensor, (F, 3), computed as R_w2c @ (transl + offset) + t_w2c - offset
    """
    assert global_orient.shape == transl.shape and len(global_orient.shape) == 2
    num_frames = transl.size(0)

    R_w2c = T_w2c[..., :3, :3]  # (*, 3, 3)
    t_w2c = T_w2c[..., :3, 3]  # (*, 3)
    if R_w2c.dim() == 2:
        # A single transform is shared by all frames.
        R_w2c = R_w2c[None].expand(num_frames, -1, -1)  # (F, 3, 3)
        t_w2c = t_w2c[None].expand(num_frames, -1)

    R_w = axis_angle_to_matrix(global_orient)  # (F, 3, 3)
    R_c = matrix_to_axis_angle(R_w2c @ R_w)  # (F, 3)

    rotated_root = torch.einsum("fij,fj->fi", R_w2c, transl + offset)
    t_c = rotated_root + t_w2c - offset  # (F, 3)
    return R_c, t_c
def get_T_w2c_from_wcparams(global_orient_w, transl_w, global_orient_c, transl_c, offset):
    """Recover the world-to-camera transform from root parameters given in both frames.

    Args:
        global_orient_w: torch.tensor, (F, 3)
        transl_w: torch.tensor, (F, 3)
        global_orient_c: torch.tensor, (F, 3)
        transl_c: torch.tensor, (F, 3)
        offset: torch.tensor, (*, 3)
    Returns:
        T_w2c: torch.tensor, (F, 4, 4)
    """
    assert global_orient_w.shape == transl_w.shape and len(global_orient_w.shape) == 2
    assert global_orient_c.shape == transl_c.shape and len(global_orient_c.shape) == 2

    R_w = axis_angle_to_matrix(global_orient_w)  # (F, 3, 3)
    R_c = axis_angle_to_matrix(global_orient_c)  # (F, 3, 3)

    # R_c = R_w2c @ R_w  =>  R_w2c = R_c @ R_w^T
    R_w2c = R_c @ R_w.transpose(-1, -2)  # (F, 3, 3)
    rotated_root_w = torch.einsum("fij,fj->fi", R_w2c, transl_w + offset)
    t_w2c = transl_c + offset - rotated_root_w  # (F, 3)

    num_frames = R_w.size(0)
    T_w2c = torch.eye(4, device=global_orient_w.device).repeat(num_frames, 1, 1)  # (F, 4, 4)
    T_w2c[..., :3, :3] = R_w2c
    T_w2c[..., :3, 3] = t_w2c
    return T_w2c
def get_local_transl_vel(transl, global_orient):
    """
    transl velocity is in local coordinate (or, SMPL-coord)

    The per-frame displacement (last frame repeats the previous one) is rotated by
    the transpose of the global orientation.

    Args:
        transl: (*, L, 3)
        global_orient: (*, L, 3)
    Returns:
        transl_vel: (*, L, 3)
    """
    assert len(transl.shape) == len(global_orient.shape)

    step = transl[..., 1:, :] - transl[..., :-1, :]  # (B, L-1, 3)
    step = torch.cat([step, step[..., [-1], :]], dim=-2)  # (B, L, 3) last-padding

    R = axis_angle_to_matrix(global_orient)  # (B, L, 3, 3)
    # v_local = R^T @ v_global
    return torch.einsum("...lij,...li->...lj", R, step)
def rollout_local_transl_vel(local_transl_vel, global_orient, transl_0=None):
    """
    transl velocity is in local coordinate (or, SMPL-coord)

    Rotates each local velocity by the global orientation and integrates it from
    the start point. The last velocity is not used.

    Args:
        local_transl_vel: (*, L, 3)
        global_orient: (*, L, 3)
        transl_0: (*, 1, 3), if not provided, the start point is 0
    Returns:
        transl: (*, L, 3)
    """
    R = axis_angle_to_matrix(global_orient)
    global_vel = torch.einsum("...lij,...lj->...li", R, local_transl_vel)

    # set start point
    start = torch.zeros_like(global_vel[..., :1, :]) if transl_0 is None else transl_0

    # rollout from start point: frame k is the start plus the first k velocities
    increments = torch.cat([start, global_vel[..., :-1, :]], dim=-2)
    return torch.cumsum(increments, dim=-2)
def get_local_transl_vel_alignhead(transl, global_orient):
    """Translation velocity expressed in the heading frame derived from `global_orient`.

    The heading rotation comes from `matrix.calc_heading_quat` (gravity axis "y")
    and replaces the full orientation when calling `get_local_transl_vel`.
    """
    # assume global_orient is ay
    quat_wxyz = matrix_to_quaternion(axis_angle_to_matrix(global_orient))  # (*, 4)
    # calc_heading_quat is called with xyzw ordering, hence the conversions around it
    heading_xyzw = matrix.calc_heading_quat(
        matrix.quat_wxyz2xyzw(quat_wxyz), head_ind=2, gravity_axis="y"
    )  # (*, 4)
    heading_rot = quaternion_to_matrix(matrix.quat_xyzw2wxyz(heading_xyzw))
    heading_aa = matrix_to_axis_angle(heading_rot)

    return get_local_transl_vel(transl, heading_aa)
def rollout_local_transl_vel_alignhead(local_transl_vel_alignhead, global_orient, transl_0=None):
    """Integrate heading-frame velocities back into a translation sequence.

    Counterpart of `get_local_transl_vel_alignhead`: the heading rotation is
    recomputed from `global_orient` and handed to `rollout_local_transl_vel`.
    """
    # assume global_orient is ay
    quat_wxyz = matrix_to_quaternion(axis_angle_to_matrix(global_orient))  # (*, 4)
    # calc_heading_quat is called with xyzw ordering, hence the conversions around it
    heading_xyzw = matrix.calc_heading_quat(
        matrix.quat_wxyz2xyzw(quat_wxyz), head_ind=2, gravity_axis="y"
    )  # (*, 4)
    heading_rot = quaternion_to_matrix(matrix.quat_xyzw2wxyz(heading_xyzw))
    heading_aa = matrix_to_axis_angle(heading_rot)

    return rollout_local_transl_vel(local_transl_vel_alignhead, heading_aa, transl_0)
def get_local_transl_vel_alignhead_absy(transl, global_orient):
    """Heading-frame velocity whose y channel is the cumulative sum of the y velocity.

    Returns (*, L, 3): [vel_x, cumsum(vel_y), vel_z].
    """
    # assume global_orient is ay
    quat_wxyz = matrix_to_quaternion(axis_angle_to_matrix(global_orient))  # (*, 4)
    heading_xyzw = matrix.calc_heading_quat(
        matrix.quat_wxyz2xyzw(quat_wxyz), head_ind=2, gravity_axis="y"
    )  # (*, 4)
    heading_aa = matrix_to_axis_angle(quaternion_to_matrix(matrix.quat_xyzw2wxyz(heading_xyzw)))

    vel = get_local_transl_vel(transl, heading_aa)
    vel_x, vel_y, vel_z = vel[..., [0]], vel[..., [1]], vel[..., [2]]

    accumulated_y = torch.cumsum(vel_y, dim=-2)  # (*, L, 1)
    return torch.cat([vel_x, accumulated_y, vel_z], dim=-1)
def rollout_local_transl_vel_alignhead_absy(local_transl_vel_alignhead_absy, global_orient, transl_0=None):
    """Invert `get_local_transl_vel_alignhead_absy` and integrate to a translation.

    The accumulated y channel is differenced back into a per-frame y velocity
    (the first frame keeps its value), then the velocities are rolled out in the
    heading frame.
    """
    # assume global_orient is ay
    quat_wxyz = matrix_to_quaternion(axis_angle_to_matrix(global_orient))  # (*, 4)
    heading_xyzw = matrix.calc_heading_quat(
        matrix.quat_wxyz2xyzw(quat_wxyz), head_ind=2, gravity_axis="y"
    )  # (*, 4)
    heading_aa = matrix_to_axis_angle(quaternion_to_matrix(matrix.quat_xyzw2wxyz(heading_xyzw)))

    vel_x = local_transl_vel_alignhead_absy[..., [0]]
    accumulated_y = local_transl_vel_alignhead_absy[..., [1]]  # (*, L, 1)
    vel_z = local_transl_vel_alignhead_absy[..., [2]]

    # Undo the cumulative sum: first entry as-is, then successive differences.
    y_steps = accumulated_y[..., 1:, :] - accumulated_y[..., :-1, :]
    vel_y = torch.cat([accumulated_y[..., :1, :], y_steps], dim=-2)

    vel = torch.cat([vel_x, vel_y, vel_z], dim=-1)
    return rollout_local_transl_vel(vel, heading_aa, transl_0)
def get_local_transl_vel_alignhead_absgy(transl, global_orient):
    """Heading-frame velocity whose y channel is the absolute y of `transl`.

    Returns (*, L, 3): [vel_x, transl_y, vel_z].
    """
    # assume global_orient is ay
    quat_wxyz = matrix_to_quaternion(axis_angle_to_matrix(global_orient))  # (*, 4)
    heading_xyzw = matrix.calc_heading_quat(
        matrix.quat_wxyz2xyzw(quat_wxyz), head_ind=2, gravity_axis="y"
    )  # (*, 4)
    heading_aa = matrix_to_axis_angle(quaternion_to_matrix(matrix.quat_xyzw2wxyz(heading_xyzw)))

    vel = get_local_transl_vel(transl, heading_aa)
    absolute_y = transl[..., [1]]  # (*, L, 1)
    return torch.cat([vel[..., [0]], absolute_y, vel[..., [2]]], dim=-1)
def rollout_local_transl_vel_alignhead_absgy(local_transl_vel_alignhead_absgy, global_orient, transl_0=None):
    """Invert `get_local_transl_vel_alignhead_absgy` and integrate to a translation.

    The absolute y channel is differenced into a y velocity (last value repeated),
    and the start point's y is always taken from the first absolute y. When
    `transl_0` is not given, the start point has x = z = 0.
    """
    # assume global_orient is ay
    quat_wxyz = matrix_to_quaternion(axis_angle_to_matrix(global_orient))  # (*, 4)
    heading_xyzw = matrix.calc_heading_quat(
        matrix.quat_wxyz2xyzw(quat_wxyz), head_ind=2, gravity_axis="y"
    )  # (*, 4)
    heading_aa = matrix_to_axis_angle(quaternion_to_matrix(matrix.quat_xyzw2wxyz(heading_xyzw)))

    absolute_y = local_transl_vel_alignhead_absgy[..., [1]]  # (*, L, 1)
    y_steps = absolute_y[..., 1:, :] - absolute_y[..., :-1, :]
    vel_y = torch.cat([y_steps, y_steps[..., -1:, :]], dim=-2)  # last-padding

    # Start point: caller-provided x/z (or zeros), y overwritten by the first absolute y.
    if transl_0 is None:
        start = torch.zeros_like(local_transl_vel_alignhead_absgy[..., :1, :])  # (*, 1, 3)
    else:
        start = transl_0.clone()
    start[..., 1] = local_transl_vel_alignhead_absgy[..., :1, 1]

    vel = torch.cat(
        [local_transl_vel_alignhead_absgy[..., [0]], vel_y, local_transl_vel_alignhead_absgy[..., [2]]],
        dim=-1,
    )
    return rollout_local_transl_vel(vel, heading_aa, start)
def rollout_vel(vel, transl_0=None):
    """Integrate per-frame velocities into positions.

    Frame 0 is the start point and frame i (i > 0) is the start point plus the sum of
    vel[0..i-1]; the last velocity entry is therefore not used.

    Args:
        vel: (*, L, 3)
        transl_0: (*, 1, 3), if not provided, the start point is 0
    Returns:
        transl: (*, L, 3)
    """
    # set start point
    if transl_0 is None:
        # The original asserted on `transl_0.shape` inside this branch, which always raised
        # AttributeError on None and made the documented default unusable.
        transl_0 = torch.zeros_like(vel[..., :1, :])
    else:
        assert len(vel.shape) == len(transl_0.shape)
    transl_ = torch.cat([transl_0, vel[..., :-1, :]], dim=-2)

    # rollout from start point
    transl = torch.cumsum(transl_, dim=-2)
    return transl
def get_static_joint_mask(w_j3d, vel_thr=0.25, smooth=False, repeat_last=False):
    """Flag joints whose frame-to-frame speed is below a threshold.

    Args:
        w_j3d: (*, L, J, 3) joint positions.
        vel_thr: speed threshold; displacement is divided by a fixed 0.033 frame interval
            (HuMoR uses 0.15m/s).
        smooth: when True, pass the speeds through `gaussian_smooth(speed, 3, -2)` first.
        repeat_last: when True, duplicate the last frame so the mask has L frames like w_j3d.
    Returns:
        Boolean mask (*, L-1, J), or (*, L, J) with repeat_last; True means static.
    """
    frame_delta = w_j3d[..., 1:, :, :] - w_j3d[..., :-1, :, :]
    speed = frame_delta.pow(2).sum(-1).sqrt() / 0.033  # (*, L-1, J)
    if smooth:
        speed = gaussian_smooth(speed, 3, -2)

    is_static = speed < vel_thr  # True as stable, False as moving
    if not repeat_last:
        return is_static
    # Pad with a copy of the final frame to match the length of w_j3d.
    return torch.cat([is_static, is_static[..., [-1], :]], dim=-2)
================================================
FILE: eval/GVHMR/hmr4d/utils/geo/quaternion.py
================================================
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import torch
import numpy as np
_EPS4 = np.finfo(float).eps * 4.0
try:
_FLOAT_EPS = np.finfo(np.float).eps
except:
_FLOAT_EPS = np.finfo(float).eps
# PyTorch-backed implementations
def qinv(q):
    """Conjugate quaternion(s): keep component 0 and negate components 1..3.

    Expects a tensor of shape (*, 4) and returns a tensor of the same shape.
    """
    assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
    sign = torch.ones_like(q)
    sign[..., 1:].neg_()  # +1 for the first component, -1 for the remaining three
    return q * sign
def qinv_np(q):
    """NumPy wrapper around `qinv`: converts to a float32 tensor and back to an array."""
    assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
    q_tensor = torch.from_numpy(q).float()
    return qinv(q_tensor).numpy()
def qnormalize(q):
    """Scale quaternion(s) of shape (*, 4) to unit length.

    The norm is clamped to at least 1e-8 before dividing, so an all-zero input stays zero
    instead of producing NaN.
    """
    assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
    length = torch.norm(q, dim=-1, keepdim=True)
    return q / torch.clamp(length, min=1e-8)
def qmul(q, r):
    """
    Hamilton product of quaternion(s) q with quaternion(s) r.
    Both inputs must be equally-sized tensors of shape (*, 4) with the real part first.
    Returns q*r as a tensor with the shape of q.
    """
    assert q.shape[-1] == 4
    assert r.shape[-1] == 4
    out_shape = q.shape

    # Batched outer product: outer[:, a, b] == r[a] * q[b]
    outer = torch.bmm(r.reshape(-1, 4, 1), q.reshape(-1, 1, 4))

    real = outer[:, 0, 0] - outer[:, 1, 1] - outer[:, 2, 2] - outer[:, 3, 3]
    imag_i = outer[:, 0, 1] + outer[:, 1, 0] - outer[:, 2, 3] + outer[:, 3, 2]
    imag_j = outer[:, 0, 2] + outer[:, 1, 3] + outer[:, 2, 0] - outer[:, 3, 1]
    imag_k = outer[:, 0, 3] - outer[:, 1, 2] + outer[:, 2, 1] + outer[:, 3, 0]
    product = torch.stack((real, imag_i, imag_j, imag_k), dim=1)
    return product.view(out_shape)
def qrot(q, v):
    """
    Rotate vector(s) v by the rotation described by quaternion(s) q.
    q has shape (*, 4) (real part first) and v has shape (*, 3) with matching leading
    dimensions. Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    out_shape = list(v.shape)
    q_flat = q.contiguous().view(-1, 4)
    v_flat = v.contiguous().view(-1, 3)

    # v' = v + 2 * (w * (u x v) + u x (u x v)), with u the vector part and w the real part
    axis = q_flat[:, 1:]
    first_cross = torch.cross(axis, v_flat, dim=1)
    second_cross = torch.cross(axis, first_cross, dim=1)
    rotated = v_flat + 2 * (q_flat[:, :1] * first_cross + second_cross)
    return rotated.view(out_shape)
def qeuler(q, order, epsilon=0, deg=True):
    """
    Convert quaternion(s) q to Euler angles.
    Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).

    Args:
        q: quaternions; components are read as q0..q3 from the last axis.
        order: one of "xyz", "yzx", "zxy", "xzy", "yxz", "zyx"; selects the formula set.
        epsilon: shrinks the clamp range of the asin argument to [-1 + epsilon, 1 - epsilon].
        deg: when True the result is multiplied by 180 / pi (degrees), otherwise radians.
    """
    assert q.shape[-1] == 4

    original_shape = list(q.shape)
    original_shape[-1] = 3
    q = q.view(-1, 4)

    q0 = q[:, 0]
    q1 = q[:, 1]
    q2 = q[:, 2]
    q3 = q[:, 3]

    # In every branch two angles come from atan2 and the remaining one from a clamped asin.
    if order == "xyz":
        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
    elif order == "yzx":
        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
        z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
    elif order == "zxy":
        x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
    elif order == "xzy":
        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
        y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
        z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
    elif order == "yxz":
        x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
        y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
        z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
    elif order == "zyx":
        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
        z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
    else:
        # NOTE(review): a bare `raise` with no active exception surfaces as RuntimeError;
        # an explicit ValueError naming the unsupported order would be clearer.
        raise

    # The output is always stacked as (x, y, z), regardless of `order`.
    if deg:
        return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
    else:
        return torch.stack((x, y, z), dim=1).view(original_shape)
# Numpy-backed implementations
def qmul_np(q, r):
    """NumPy wrapper around `qmul`: inputs are converted to float32 tensors, output to an array."""
    q_tensor = torch.from_numpy(q).contiguous().float()
    r_tensor = torch.from_numpy(r).contiguous().float()
    product = qmul(q_tensor, r_tensor)
    return product.numpy()
def qrot_np(q, v):
    """NumPy wrapper around `qrot`: inputs are converted to float32 tensors, output to an array."""
    q_tensor = torch.from_numpy(q).contiguous().float()
    v_tensor = torch.from_numpy(v).contiguous().float()
    rotated = qrot(q_tensor, v_tensor)
    return rotated.numpy()
def qeuler_np(q, order, epsilon=0, use_gpu=False):
    """NumPy wrapper around `qeuler`.

    The array is converted to a float32 tensor (moved to CUDA when `use_gpu` is set) and the
    result is returned as an array. `deg` is not forwarded, so `qeuler`'s default applies.
    """
    if use_gpu:
        q_tensor = torch.from_numpy(q).cuda().float()
        return qeuler(q_tensor, order, epsilon).cpu().numpy()
    q_tensor = torch.from_numpy(q).contiguous().float()
    return qeuler(q_tensor, order, epsilon).numpy()
def qfix(q):
    """
    Make a quaternion sequence continuous over time by choosing, per frame and joint,
    between q and -q so that consecutive frames have a non-negative dot product.

    Expects a NumPy array of shape (L, J, 4), where L is the sequence length and J is the
    number of joints. Returns a new array of the same shape; the input is not modified.
    """
    assert len(q.shape) == 3
    assert q.shape[-1] == 4

    fixed = q.copy()
    # A negative dot product between neighbouring frames marks a sign flip at that frame.
    flips = np.sum(q[1:] * q[:-1], axis=2) < 0
    # An odd number of flips so far means the frame must be negated.
    needs_negation = (np.cumsum(flips, axis=0) % 2).astype(bool)
    fixed[1:][needs_negation] *= -1
    return fixed
def euler2quat(e, order, deg=True):
    """
    Convert Euler angles to quaternions (PyTorch version).

    Args:
        e: tensor of shape (*, 3); the last axis is read as (x, y, z) angles.
        order: string of axis letters ("x", "y", "z"); the per-axis quaternions are
            combined with `qmul` in the order the letters appear.
        deg: when True the angles are converted from degrees to radians first.
    Returns:
        Tensor of shape (*, 4), real part first.
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4

    e = e.view(-1, 3)

    ## if euler angles in degrees
    if deg:
        e = e * np.pi / 180.0

    x = e[:, 0]
    y = e[:, 1]
    z = e[:, 2]

    # Single-axis rotation quaternions (cos(a/2), sin(a/2) on the matching axis).
    rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
    ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
    rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)

    result = None
    for coord in order:
        if coord == "x":
            r = rx
        elif coord == "y":
            r = ry
        elif coord == "z":
            r = rz
        else:
            # NOTE(review): bare `raise` with no active exception surfaces as RuntimeError.
            raise
        if result is None:
            result = r
        else:
            result = qmul(result, r)

    # Reverse antipodal representation to have a non-negative "w"
    if order in ["xyz", "yzx", "zxy"]:
        result *= -1

    return result.view(original_shape)
def expmap_to_quaternion(e):
    """
    Convert axis-angle rotations (exponential maps) to quaternions with NumPy.
    Uses the stable formula from "Practical Parameterization of Rotations Using the
    Exponential Map": the sinc form keeps the vector part finite at zero angle.
    Expects an array of shape (*, 3) and returns an array of shape (*, 4), real part first.
    """
    assert e.shape[-1] == 3

    out_shape = list(e.shape)
    out_shape[-1] = 4
    flat = e.reshape(-1, 3)

    angle = np.linalg.norm(flat, axis=1).reshape(-1, 1)
    real = np.cos(0.5 * angle).reshape(-1, 1)
    # np.sinc(x) = sin(pi x) / (pi x), hence the division of the half angle by pi.
    imag = 0.5 * np.sinc(0.5 * angle / np.pi) * flat
    return np.concatenate((real, imag), axis=1).reshape(out_shape)
def euler_to_quaternion(e, order):
    """
    Convert Euler angles to quaternions (NumPy version).

    Args:
        e: array of shape (*, 3); the last axis is read as (x, y, z) angles. No degree
            conversion is applied, so the angles go straight into cos/sin (radians).
        order: string of axis letters ("x", "y", "z"); the per-axis quaternions are
            combined with `qmul_np` in the order the letters appear.
    Returns:
        Array of shape (*, 4), real part first.
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4

    e = e.reshape(-1, 3)

    x = e[:, 0]
    y = e[:, 1]
    z = e[:, 2]

    # Single-axis rotation quaternions (cos(a/2), sin(a/2) on the matching axis).
    rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
    ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
    rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)

    result = None
    for coord in order:
        if coord == "x":
            r = rx
        elif coord == "y":
            r = ry
        elif coord == "z":
            r = rz
        else:
            # NOTE(review): bare `raise` with no active exception surfaces as RuntimeError.
            raise
        if result is None:
            result = r
        else:
            result = qmul_np(result, r)

    # Reverse antipodal representation to have a non-negative "w"
    if order in ["xyz", "yzx", "zxy"]:
        result *= -1

    return result.reshape(original_shape)
def quaternion_to_matrix(quaternions):
    """
    Convert quaternions to rotation matrices.

    Args:
        quaternions: tensor of shape (..., 4) with the real part first. The input does not
            need to be unit length; the squared norm is divided out.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    # Row-major entries of the 3x3 matrix.
    entries = [
        1 - two_s * (j * j + k * k),
        two_s * (i * j - k * r),
        two_s * (i * k + j * r),
        two_s * (i * j + k * r),
        1 - two_s * (i * i + k * k),
        two_s * (j * k - i * r),
        two_s * (i * k - j * r),
        two_s * (j * k + i * r),
        1 - two_s * (i * i + j * j),
    ]
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
def quaternion_to_matrix_np(quaternions):
    """NumPy wrapper around `quaternion_to_matrix` (computed in float32)."""
    q_tensor = torch.from_numpy(quaternions).contiguous().float()
    rotation = quaternion_to_matrix(q_tensor)
    return rotation.numpy()
def quaternion_to_cont6d_np(quaternions):
    """Convert quaternions (NumPy) to a 6D rotation: matrix columns 0 and 1 concatenated."""
    rot = quaternion_to_matrix_np(quaternions)
    first_col, second_col = rot[..., 0], rot[..., 1]
    return np.concatenate([first_col, second_col], axis=-1)
def quaternion_to_cont6d(quaternions):
    """Convert quaternions (tensor) to a 6D rotation: matrix columns 0 and 1 concatenated."""
    rot = quaternion_to_matrix(quaternions)
    first_col, second_col = rot[..., 0], rot[..., 1]
    return torch.cat([first_col, second_col], dim=-1)
def cont6d_to_matrix(cont6d):
    """Turn a 6D rotation (*, 6) into a rotation matrix (*, 3, 3).

    The first three values give the (unnormalised) x axis and the last three a second
    direction; z is their normalised cross product and y completes the frame. The axes are
    stored as the columns of the result.
    """
    assert cont6d.shape[-1] == 6, "The last dimension must be 6"
    x_raw, y_raw = cont6d[..., 0:3], cont6d[..., 3:6]

    x_axis = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
    z_axis = torch.cross(x_axis, y_raw, dim=-1)
    z_axis = z_axis / torch.norm(z_axis, dim=-1, keepdim=True)
    y_axis = torch.cross(z_axis, x_axis, dim=-1)

    return torch.stack([x_axis, y_axis, z_axis], dim=-1)
def cont6d_to_matrix_np(cont6d):
    """NumPy wrapper around `cont6d_to_matrix` (computed in float32)."""
    cont6d_tensor = torch.from_numpy(cont6d).contiguous().float()
    rotation = cont6d_to_matrix(cont6d_tensor)
    return rotation.numpy()
def qpow(q0, t, dtype=torch.float):
    """Raise unit quaternion(s) to the power(s) t.

    Args:
        q0: tensor of quaternions, shape (*, 4), real part first; normalised internally.
        t: a Python number, or a tensor of powers of any shape.
        dtype: dtype of the returned tensor.
    Returns:
        Tensor of shape q0.shape for a scalar t, or t.shape + q0.shape for a tensor t.
    """
    q0 = qnormalize(q0)
    theta0 = torch.acos(q0[..., :1])  # (*, 1)

    ## if theta0 is close to zero, add epsilon to avoid NaNs
    mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
    mask = mask.float()
    theta0 = (1 - mask) * theta0 + mask * 10e-10
    v0 = q0[..., 1:] / torch.sin(theta0)  # (*, 3)

    if isinstance(t, torch.Tensor):
        q = torch.zeros(t.shape + q0.shape, device=q0.device)
        # Give t one trailing singleton axis per axis of q0 so that theta has shape
        # t.shape + (*, 1) and lines up with q[..., :1]. The previous flattened outer
        # product `t.view(-1, 1) * theta0.view(1, -1)` only matched that layout for a
        # single, unbatched quaternion; for that case the values are unchanged.
        theta = t.reshape(t.shape + (1,) * q0.dim()) * theta0
    else:  ## if t is a number
        q = torch.zeros(q0.shape, device=q0.device)
        theta = t * theta0

    q[..., :1] = torch.cos(theta)
    q[..., 1:] = v0 * torch.sin(theta)

    return q.to(dtype)
def qslerp(q0, q1, t):
    """
    Spherical linear interpolation from q0 to q1.

    q0: starting quaternion, shape (*, 4)
    q1: ending quaternion, same shape as q0
    t: a Python number, or a tensor of points along the way

    Returns:
        Tensor of Slerps: q0.shape for a scalar t, t.shape + q0.shape for a tensor t
    """
    q0 = qnormalize(q0)
    q1 = qnormalize(q1)
    q_ = qpow(qmul(q1, qinv(q0)), t)

    # `qmul` requires equally-sized operands. For a tensor t, q_ carries the extra leading
    # t axes, so q0 is broadcast to match; previously this call failed inside bmm whenever t
    # had more than one element. For a scalar t the expand is a no-op.
    return qmul(q_, q0.expand_as(q_))
def qbetween(v0, v1):
    """
    find the quaternion used to rotate v0 to v1

    Args:
        v0, v1: tensors of shape (*, 3); they are combined with cross/dot products, so their
            leading dimensions must be compatible.
    Returns:
        Normalised quaternion(s) of shape (*, 4), real part first.
    """
    assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)"
    assert v1.shape[-1] == 3, "v1 must be of the shape (*, 3)"

    # Vector part: v0 x v1. Real part: |v0||v1| + v0 . v1.
    v = torch.cross(v0, v1, dim=-1)
    w = torch.sqrt((v0**2).sum(dim=-1, keepdim=True) * (v1**2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(
        dim=-1, keepdim=True
    )

    y_vec = torch.zeros_like(v)
    y_vec[..., 1] = 1.0

    # if v0 is (0, 0, -1), v1 is (0, 0, 1), v will be 0 and w will also be 0 -> this makes below situation comes v=1 w = 2
    mask = v.norm(dim=-1) == 0
    # if v0 is (0, 0, 1), v1 is (0, 0, 1), v will be 0 and w will be 2 -> do nothing
    mask2 = w.sum(dim=-1).abs() <= 1e-4
    # Only the opposite-direction case (zero cross product AND near-zero real part) is patched:
    # the vector part is replaced by the +y axis, giving a half-turn about y.
    mask = torch.logical_and(mask, mask2)
    v[mask] = y_vec[mask]

    return qnormalize(torch.cat([w, v], dim=-1))
def qbetween_np(v0, v1):
    """
    NumPy wrapper around `qbetween`: find the quaternion used to rotate v0 to v1.
    Inputs are arrays of shape (*, 3); computation runs in float32.
    """
    assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)"
    assert v1.shape[-1] == 3, "v1 must be of the shape (*, 3)"

    src = torch.from_numpy(v0).float()
    dst = torch.from_numpy(v1).float()
    return qbetween(src, dst).numpy()
def lerp(p0, p1, t):
    """Linearly interpolate between p0 and p1 at the fractions t.

    A non-tensor t is wrapped into a one-element tensor. The result has shape
    t.shape + p0.shape and equals p0 + t * (p1 - p0).
    """
    if not isinstance(t, torch.Tensor):
        t = torch.Tensor([t])

    t_dims, p_dims = t.shape, p0.shape
    full_shape = t_dims + p_dims
    # t gets trailing singleton axes, p0/p1 get leading ones, then all expand to full_shape.
    point_view = torch.Size([1] * len(t_dims)) + p_dims
    weight_view = t_dims + torch.Size([1] * len(p_dims))

    start = p0.view(point_view).expand(full_shape)
    end = p1.view(point_view).expand(full_shape)
    weight = t.view(weight_view).expand(full_shape)
    return start + weight * (end - start)
================================================
FILE: eval/GVHMR/hmr4d/utils/geo/transforms.py
================================================
import torch
def axis_rotate_to_matrix(angle, axis="x"):
    """Build rotation matrices for a rotation about a single coordinate axis.

    Args:
        angle: (N, 1) tensor of angles; a Python float is wrapped into a one-element tensor.
        axis: "x", "y" or "z" (anything else fails the assertion).
    Returns:
        R: (N, 3, 3)
    """
    if isinstance(angle, float):
        angle = torch.tensor([angle], dtype=torch.float)

    c = torch.cos(angle)
    s = torch.sin(angle)
    z = torch.zeros_like(angle)
    o = torch.ones_like(angle)

    # Row-major entries of the matrix for the requested axis.
    if axis == "x":
        entries = [o, z, z, z, c, -s, z, s, c]
    elif axis == "y":
        entries = [c, z, s, z, o, z, -s, z, c]
    else:
        assert axis == "z"
        entries = [c, -s, z, s, c, z, z, z, o]
    return torch.stack(entries, dim=1).view(-1, 3, 3)
================================================
FILE: eval/GVHMR/hmr4d/utils/geo_transform.py
================================================
import numpy as np
import cv2
import torch
import torch.nn.functional as F
from pytorch3d.transforms import so3_exp_map, so3_log_map
from pytorch3d.transforms import matrix_to_quaternion, quaternion_to_axis_angle, matrix_to_rotation_6d
import pytorch3d.ops.knn as knn
from hmr4d.utils.pylogger import Log
from pytorch3d.transforms import euler_angles_to_matrix
import hmr4d.utils.matrix as matrix
from einops import einsum, rearrange, repeat
from hmr4d.utils.geo.quaternion import qbetween
def homo_points(points):
    """
    Append a homogeneous coordinate of 1 to the last axis.
    Args:
        points: (..., C)
    Returns: (..., C+1), with 1 padded
    """
    return F.pad(points, (0, 1), mode="constant", value=1.0)
def apply_Ts_on_seq_points(points, Ts):
    """
    Apply one rigid transform per point: p' = R p + t, with R = Ts[..., :3, :3] and
    t = Ts[..., :3, 3].
    Args:
        points: (..., N, 3)
        Ts: (..., N, 4, 4)
    Returns: (..., N, 3)
    """
    # Was `torch.torch.einsum`, which only resolved because the torch package happens to
    # expose itself as an attribute; call the public function directly.
    points = torch.einsum("...ki,...i->...k", Ts[..., :3, :3], points) + Ts[..., :3, 3]
    return points
def apply_T_on_points(points, T):
    """
    Apply one rigid transform to a whole set of points: p' = R p + t.
    Args:
        points: (..., N, 3)
        T: (..., 4, 4)
    Returns: (..., N, 3)
    """
    rotation = T[..., :3, :3]
    translation = T[..., None, :3, 3]  # extra axis broadcasts over the N points
    rotated = torch.einsum("...ki,...ji->...jk", rotation, points)
    return rotated + translation
def T_transforms_points(T, points, pattern):
    """Manual mode of apply_T_on_points: the caller supplies the einsum pattern.

    T: (..., 4, 4)
    points: (..., 3), padded to homogeneous coordinates before the contraction
    pattern: e.g. "... c d, ... d -> ... c"
    Returns the first three components of the contracted result.
    """
    points_h = homo_points(points)
    transformed = einsum(T, points_h, pattern)
    return transformed[..., :3]
def project_p2d(points, K=None, is_pinhole=True):
    """
    Project 3D points to 2D.
    Args:
        points: (..., (N), 3)
        K: (..., 3, 3), optional intrinsics applied after the projection
        is_pinhole: divide by depth when True; otherwise orthographic (x, y kept, 1 appended)
    Returns: shape is similar to points but without z
    """
    pts = points.clone()
    if is_pinhole:
        depth = pts[..., [-1]]
        # Depths with magnitude below 1e-6 are replaced by 1e-6 to avoid dividing by ~0.
        depth.masked_fill_(depth.abs() < 1e-6, 1e-6)
        proj = pts / depth
    else:  # orthogonal
        proj = F.pad(pts[..., :2], (0, 1), value=1)

    if K is None:
        return proj[..., :2]

    # Same rank means the points carry an N axis that K does not have.
    if len(proj.shape) == len(K.shape):
        pattern = "...ki,...ji->...jk"
    else:
        pattern = "...ki,...i->...k"
    return torch.einsum(pattern, K, proj)[..., :2]
def gen_uv_from_HW(H, W, device="cpu"):
    """Build a pixel-coordinate grid.

    Returns: (H, W, 2), as float. Note: uv not ij, i.e. out[v, u] == (u, v).
    """
    # Pass indexing="ij" explicitly: it matches the previous implicit behaviour, silences the
    # warning torch emits when `indexing` is omitted, and keeps the result stable should the
    # default ever change.
    grid_v, grid_u = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    return torch.stack([grid_u, grid_v], dim=-1).float().to(device)  # (H, W, 2)
def unproject_p2d(uv, z, K):
    """Lift pixels to 3D camera points, assuming a pinhole camera.

    uv: (B, N, 2)
    z: (B, N, 1)
    K: (B, 3, 3)
    Returns: (B, N, 3)
    """
    principal_point = K[:, None, :2, 2]  # (B, 1, 2): cx, cy
    focal = K[:, None, [0, 1], [0, 1]]  # (B, 1, 2): fx, fy
    xy_at_unit_depth = (uv - principal_point) / focal  # (B, N, 2)
    return torch.cat([xy_at_unit_depth * z, z], dim=-1)
def cvt_p2d_from_i_to_c(uv, K):
    """
    Convert image-plane pixels to normalised camera coordinates: (uv - c) / f.
    Args:
        uv: (..., 2) or (..., N, 2)
        K: (..., 3, 3)
    Returns: the same shape as input uv
    """
    has_point_axis = len(uv.shape) == len(K.shape)
    if has_point_axis:
        principal_point = K[..., None, :2, 2]
        focal = K[..., None, [0, 1], [0, 1]]
    else:  # without N
        principal_point = K[..., :2, 2]
        focal = K[..., [0, 1], [0, 1]]
    return (uv - principal_point) / focal
def cvt_to_bi01_p2d(p2d, bbx_lurb):
    """
    Express image points relative to a bounding box: (p - left_up) / box_size.
    p2d: (..., (N), 2)
    bbx_lurb: (..., 4)
    """
    if len(p2d.shape) == len(bbx_lurb.shape) + 1:
        # Points carry an N axis: add a matching singleton axis to the box.
        bbx_lurb = bbx_lurb[..., None, :]
    left_up = bbx_lurb[..., :2]
    bbx_wh = bbx_lurb[..., 2:] - left_up
    return (p2d - left_up) / bbx_wh
def cvt_from_bi01_p2d(bi01_p2d, bbx_lurb):
    """Use bbx_lurb to resize bi01_p2d to p2d (image-coordinates)
    Args:
        bi01_p2d: (..., 2) or (..., N, 2), box-relative points
        bbx_lurb: (..., 4)
    Returns:
        p2d: shape is the same as input
    """
    left_up = bbx_lurb[..., :2]
    bbx_wh = bbx_lurb[..., 2:] - left_up  # (..., 2)
    has_point_axis = len(bi01_p2d.shape) == len(bbx_wh.shape) + 1
    if has_point_axis:
        return (bi01_p2d * bbx_wh.unsqueeze(-2)) + bbx_lurb[..., None, :2]
    return (bi01_p2d * bbx_wh) + left_up
def cvt_p2d_from_bi01_to_c(bi01, bbxs_lurb, Ks):
    """
    Chain `cvt_from_bi01_p2d` (box-relative -> image) and `cvt_p2d_from_i_to_c`
    (image -> normalised camera coordinates).
    Args:
        bi01: (..., (N), 2), value in range (0,1), the point in the bbx image
        bbxs_lurb: (..., 4)
        Ks: (..., 3, 3)
    Returns:
        c: (..., (N), 2)
    """
    image_points = cvt_from_bi01_p2d(bi01, bbxs_lurb)
    return cvt_p2d_from_i_to_c(image_points, Ks)
def cvt_p2d_from_pm1_to_i(p2d_pm1, bbx_xys):
    """
    Map box-normalised points in (-1, 1) to image coordinates: center + p * size / 2.
    Args:
        p2d_pm1: (..., (N), 2), value in range (-1,1), the point in the bbx image
        bbx_xys: (..., 3), as (center_x, center_y, size)
    Returns:
        p2d: (..., (N), 2)
    """
    if len(p2d_pm1.shape) == len(bbx_xys.shape) + 1:
        # Points carry the optional N axis: give the box a matching singleton axis, as the
        # sibling bi01 converters do. Without it the box broadcast against the wrong axis
        # and the documented (..., N, 2) input either failed or was silently mis-paired.
        bbx_xys = bbx_xys[..., None, :]
    return bbx_xys[..., :2] + p2d_pm1 * bbx_xys[..., [2]] / 2
def uv2l_index(uv, W):
    """Flatten (u, v) pixel coordinates into a row-major linear index u + v * W."""
    u, v = uv[..., 0], uv[..., 1]
    return u + v * W
def l2uv_index(l, W):
    """Split a row-major linear index into (u, v) = (l % W, l // W), stacked on the last axis."""
    row = torch.div(l, W, rounding_mode="floor")
    col = l % W
    return torch.stack([col, row], dim=-1)
def transform_mat(R, t):
    """
    Assemble homogeneous transforms from rotations and translations.
    Args:
        R: Bx3x3 array of a batch of rotation matrices
        t: Bx3x(1) array of a batch of translation vectors
    Returns:
        T: Bx4x4 Transformation matrix
    """
    if len(R.shape) > len(t.shape):
        t = t[..., None]  # make the translation a column vector
    # Pad only at the bottom: a zero row under R and a 1 under t.
    rotation_block = F.pad(R, [0, 0, 0, 1])
    translation_block = F.pad(t, [0, 0, 0, 1], value=1)
    return torch.cat([rotation_block, translation_block], dim=-1)
def axis_angle_to_matrix_exp_map(aa):
    """Axis-angle to rotation matrix through pytorch3d's so3_exp_map.

    Prints a notice recommending pytorch3d.transforms.axis_angle_to_matrix on every call.
    Args:
        aa: (*, 3)
    Returns:
        R: (*, 3, 3)
    """
    print("Use pytorch3d.transforms.axis_angle_to_matrix instead!!!")
    batch_shape = aa.shape[:-1]
    rotation = so3_exp_map(aa.reshape(-1, 3))
    return rotation.reshape(*batch_shape, 3, 3)
def matrix_to_axis_angle_log_map(R):
    """Rotation matrix to axis-angle through pytorch3d's so3_log_map.

    Prints a warning on every call: the author hit singularity problems with this path and
    recommends `matrix_to_axis_angle` instead.
    Args:
        R: (*, 3, 3)
    Returns:
        aa: (*, 3)
    """
    print("WARINING! I met singularity problem with this function, use matrix_to_axis_angle instead!")
    batch_shape = R.shape[:-2]
    axis_angle = so3_log_map(R.reshape(-1, 3, 3))
    return axis_angle.reshape(*batch_shape, 3)
def matrix_to_axis_angle(R):
    """Rotation matrix to axis-angle, going through a quaternion
    (pytorch3d matrix_to_quaternion followed by quaternion_to_axis_angle).
    Args:
        R: (*, 3, 3)
    Returns:
        aa: (*, 3)
    """
    quat = matrix_to_quaternion(R)
    return quaternion_to_axis_angle(quat)
def ransac_PnP(K, pts_2d, pts_3d, err_thr=10):
    """Solve PnP with OpenCV's RANSAC solver (EPnP flag, 10000 iterations).

    Args:
        K: camera intrinsics as a NumPy array (cast to float64).
        pts_2d: 2D image points as a NumPy array (cast to float64).
        pts_3d: matching 3D points as a NumPy array (cast to float64).
        err_thr: reprojection error threshold passed to `cv2.solvePnPRansac`.
    Returns:
        pose: (3, 4) array [R | t]
        pose_homo: (4, 4) array with the row [0, 0, 0, 1] appended
        inliers: inlier indices reported by OpenCV, or [] when it returns None
        On `cv2.error`, "CV ERROR" is printed and an identity pose with no inliers is returned.
    """
    dist_coeffs = np.zeros(shape=[8, 1], dtype="float64")  # all-zero distortion coefficients

    pts_2d = np.ascontiguousarray(pts_2d.astype(np.float64))
    pts_3d = np.ascontiguousarray(pts_3d.astype(np.float64))
    K = K.astype(np.float64)

    try:
        _, rvec, tvec, inliers = cv2.solvePnPRansac(
            pts_3d, pts_2d, K, dist_coeffs, reprojectionError=err_thr, iterationsCount=10000, flags=cv2.SOLVEPNP_EPNP
        )

        # Rodrigues vector -> 3x3 rotation matrix
        rotation = cv2.Rodrigues(rvec)[0]

        pose = np.concatenate([rotation, tvec], axis=-1)
        pose_homo = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)

        inliers = [] if inliers is None else inliers

        return pose, pose_homo, inliers
    except cv2.error:
        # Best-effort fallback: report the failure and return an identity pose.
        print("CV ERROR")
        return np.eye(4)[:3], np.eye(4), []
def ransac_PnP_batch(K_raw, pts_2d, pts_3d, err_thr=10):
    """Run `ransac_PnP` on every batch element and stack the results.

    The batch size is taken from K_raw.shape[0]; pts_2d and pts_3d are indexed the same way.
    Returns:
        fit_R: (B, 3, 3) rotations
        fit_t: (B, 3) translations
    """
    rotations, translations = [], []
    for idx in range(K_raw.shape[0]):
        pose, _, _ = ransac_PnP(K_raw[idx], pts_2d[idx], pts_3d[idx], err_thr=err_thr)
        rotations.append(pose[:3, :3])
        translations.append(pose[:3, 3])
    return np.stack(rotations, axis=0), np.stack(translations, axis=0)
def triangulate_point(Ts_w2c, c_p2d, **kwargs):
    """Deprecated forwarder to `hmr4d.utils.geo.triangulation.triangulate_persp`.

    Prints a deprecation notice on every call and passes all arguments through unchanged.
    """
    # Imported inside the function, so the triangulation module is only loaded when called.
    from hmr4d.utils.geo.triangulation import triangulate_persp

    print("Deprecated, please import from hmr4d.utils.geo.triangulation")
    return triangulate_persp(Ts_w2c, c_p2d, **kwargs)
def triangulate_point_ortho(Ts_w2c, c_p2d, **kwargs):
    """Deprecated forwarder to `hmr4d.utils.geo.triangulation.triangulate_ortho`.

    Prints a deprecation notice on every call and passes all arguments through unchanged.
    """
    # Imported inside the function, so the triangulation module is only loaded when called.
    from hmr4d.utils.geo.triangulation import triangulate_ortho

    print("Deprecated, please import from hmr4d.utils.geo.triangulation")
    return triangulate_ortho(Ts_w2c, c_p2d, **kwargs)
def get_nearby_points(points, query_verts, padding=0.0, p=1):
    """Select the points that lie close to a set of query vertices.

    Args:
        points: (S, 3)
        query_verts: (V, 3)
        padding: margin added around the query vertices.
        p: 1 keeps points strictly inside the axis-aligned bounding box of query_verts grown
            by `padding`; 2 keeps points whose squared distance to the nearest query vertex
            (via pytorch3d knn_points) is below padding**2.
    Returns:
        The selected rows of `points`.
    Raises:
        ValueError: if p is neither 1 nor 2 (previously this fell through to an unbound local).
    """
    if p == 1:
        max_xyz = query_verts.max(0)[0] + padding
        min_xyz = query_verts.min(0)[0] - padding
        idx = (((points - min_xyz) > 0).all(dim=-1) * ((points - max_xyz) < 0).all(dim=-1)).nonzero().squeeze(-1)
        nearby_points = points[idx]
    elif p == 2:
        squared_dist, _, _ = knn.knn_points(points[None], query_verts[None], K=1, return_nn=False)
        mask = squared_dist[0, :, 0] < padding**2  # (S,)
        nearby_points = points[mask]
    else:
        raise ValueError(f"Unsupported p={p!r}; expected 1 or 2")

    return nearby_points
def unproj_bbx_to_fst(bbx_lurb, K, near_z=0.5, far_z=12.5):
    """Unproject the four corners of each box at a near and a far depth (8 frustum points).

    Args:
        bbx_lurb: (B, 4) boxes; corners are taken as (l,u), (r,u), (r,b), (l,b).
        K: intrinsics forwarded to `unproject_p2d`.
        near_z, far_z: Python floats shared by all boxes, or per-box tensors indexed as (B,).
    Returns:
        c_frustum_points: (B, 8, 3), near-plane corners first.
    """
    batch = bbx_lurb.size(0)
    corner_index = [[0, 1], [2, 1], [2, 3], [0, 3]] * 2  # same 4 corners for near and far
    uv = bbx_lurb[:, corner_index]
    if isinstance(near_z, float):
        depth = uv.new([near_z] * 4 + [far_z] * 4).reshape(1, 8, 1).repeat(batch, 1, 1)
    else:
        near = near_z[:, None, None].repeat(1, 4, 1)
        far = far_z[:, None, None].repeat(1, 4, 1)
        depth = torch.cat([near, far], dim=1)
    return unproject_p2d(uv, depth, K)  # (B, 8, 3)
def convert_bbx_xys_to_lurb(bbx_xys):
    """
    Convert (center_x, center_y, size) boxes to (left, up, right, bottom).
    Args: bbx_xys (..., 3) -> bbx_lurb (..., 4)
    """
    center, half_size = bbx_xys[..., :2], bbx_xys[..., 2:] / 2
    return torch.cat([center - half_size, center + half_size], dim=-1)
def convert_lurb_to_bbx_xys(bbx_lurb):
    """
    Convert (left, up, right, bottom) boxes to (center_x, center_y, size).
    Args: bbx_lurb (..., 4) -> bbx_xys (..., 3) be aware that it is squared:
    size is the larger of width and height.
    """
    corner_min, corner_max = bbx_lurb[..., :2], bbx_lurb[..., 2:]
    side = (corner_max - corner_min).max(-1, keepdim=True)[0]
    center = (corner_min + corner_max) / 2
    return torch.cat([center, side], dim=-1)
# ================== AZ/AY Transformations ================== #
def compute_T_ayf2az(joints, inverse=False):
    """
    Build the transform between the az frame and a body-facing ayf frame from start-frame joints.

    Args:
        joints: (B, J, 3), in the start-frame, az-coordinate. Joint 0 gives the origin;
            joints 1/2 and 16/17 give the hip and shoulder right-to-left directions.
        inverse: return the az -> ayf transform instead of ayf -> az.
    Returns:
        if inverse == False:
            T_ayf2az: (B, 4, 4)
        else :
            T_az2ayf: (B, 4, 4)
    """
    t_ayf2az = joints[:, 0, :].detach().clone()
    t_ayf2az[:, 2] = 0  # do not modify z

    RL_xy_h = joints[:, 1, [0, 1]] - joints[:, 2, [0, 1]]  # (B, 2), hip point to left side
    RL_xy_s = joints[:, 16, [0, 1]] - joints[:, 17, [0, 1]]  # (B, 2), shoulder point to left side
    RL_xy = RL_xy_h + RL_xy_s
    I_mask = RL_xy.pow(2).sum(-1) < 1e-4  # do not rotate, when can't decided the face direction
    if I_mask.sum() > 0:
        Log.warn("{} samples can't decide the face direction".format(I_mask.sum()))

    # Columns of the rotation: x = normalised left direction in the xy plane, y = az's +z axis,
    # z = x cross y.
    x_dir = F.pad(F.normalize(RL_xy, 2, -1), (0, 1), value=0)  # (B, 3)
    y_dir = torch.zeros_like(x_dir)
    y_dir[..., 2] = 1
    z_dir = torch.cross(x_dir, y_dir, dim=-1)
    R_ayf2az = torch.stack([x_dir, y_dir, z_dir], dim=-1)  # (B, 3, 3)
    R_ayf2az[I_mask] = torch.eye(3).to(R_ayf2az)  # undecidable samples fall back to identity

    if inverse:
        # Inverse of [R | t] is [R^T | -R^T t]; the einsum pattern contracts over i, i.e. R^T t.
        R_az2ayf = R_ayf2az.transpose(1, 2)  # (B, 3, 3)
        t_az2ayf = -einsum(R_ayf2az, t_ayf2az, "b i j , b i -> b j")  # (B, 3)
        return transform_mat(R_az2ayf, t_az2ayf)
    else:
        return transform_mat(R_ayf2az, t_ayf2az)
def compute_T_ayfz2ay(joints, inverse=False):
    """
    Build the transform between the ay frame and a body-facing ayfz frame from start-frame joints.

    Args:
        joints: (B, J, 3), in the start-frame, ay-coordinate. Joint 0 gives the origin;
            joints 1/2 and 16/17 give the hip and shoulder right-to-left directions.
        inverse: return the ay -> ayfz transform instead of ayfz -> ay.
    Returns:
        if inverse == False:
            T_ayfz2ay: (B, 4, 4)
        else :
            T_ay2ayfz: (B, 4, 4)
    """
    t_ayfz2ay = joints[:, 0, :].detach().clone()
    t_ayfz2ay[:, 1] = 0  # do not modify y

    RL_xz_h = joints[:, 1, [0, 2]] - joints[:, 2, [0, 2]]  # (B, 2), hip point to left side
    RL_xz_s = joints[:, 16, [0, 2]] - joints[:, 17, [0, 2]]  # (B, 2), shoulder point to left side
    RL_xz = RL_xz_h + RL_xz_s
    I_mask = RL_xz.pow(2).sum(-1) < 1e-4  # do not rotate, when can't decided the face direction
    if I_mask.sum() > 0:
        Log.warn("{} samples can't decide the face direction".format(I_mask.sum()))

    # Columns of the rotation: x = normalised left direction in the xz plane, y = +y,
    # z = x cross y.
    x_dir = torch.zeros_like(t_ayfz2ay)  # (B, 3)
    x_dir[:, [0, 2]] = F.normalize(RL_xz, 2, -1)
    y_dir = torch.zeros_like(x_dir)
    y_dir[..., 1] = 1  # (B, 3)
    z_dir = torch.cross(x_dir, y_dir, dim=-1)
    R_ayfz2ay = torch.stack([x_dir, y_dir, z_dir], dim=-1)  # (B, 3, 3)
    R_ayfz2ay[I_mask] = torch.eye(3).to(R_ayfz2ay)  # undecidable samples fall back to identity

    if inverse:
        # Inverse of [R | t] is [R^T | -R^T t]; the einsum pattern contracts over i, i.e. R^T t.
        R_ay2ayfz = R_ayfz2ay.transpose(1, 2)
        t_ay2ayfz = -einsum(R_ayfz2ay, t_ayfz2ay, "b i j , b i -> b j")
        return transform_mat(R_ay2ayfz, t_ay2ayfz)
    else:
        return transform_mat(R_ayfz2ay, t_ayfz2ay)
def compute_T_ay2ayrot(joints):
    """
    Build a transform that moves the root's xz position to the origin and applies a random
    rotation about the y axis (one angle per sample, drawn with `torch.rand`).

    Args:
        joints: (B, J, 3), in the start-frame, ay-coordinate
    Returns:
        T_ay2ayrot: (B, 4, 4)
    """
    t_ayrot2ay = joints[:, 0, :].detach().clone()
    t_ayrot2ay[:, 1] = 0  # do not modify y

    B = joints.shape[0]
    euler_angle = torch.zeros((B, 3), device=joints.device)
    yrot_angle = torch.rand((B,), device=joints.device) * 2 * torch.pi  # uniform in [0, 2*pi)
    euler_angle[:, 0] = yrot_angle  # slot 0 is the Y angle under the "YXZ" convention below
    R_ay2ayrot = euler_angles_to_matrix(euler_angle, "YXZ")  # (B, 3, 3)

    # With R_ayrot2ay = R_ay2ayrot^T, the contraction over i computes R_ay2ayrot @ t,
    # so the translation is -R_ay2ayrot t.
    R_ayrot2ay = R_ay2ayrot.transpose(1, 2)
    t_ay2ayrot = -einsum(R_ayrot2ay, t_ayrot2ay, "b i j , b i -> b j")
    return transform_mat(R_ay2ayrot, t_ay2ayrot)
def compute_root_quaternion_ay(joints):
    """
    Compute the quaternion rotating the +z axis onto the body's facing direction.

    Args:
        joints: (*, J, 3), in the start-frame, ay-coordinate; leading dimensions are
            flattened internally and restored on the output.
    Returns:
        root_quat: (*, 4) from z-axis to fz
    """
    joints_shape = joints.shape
    joints = joints.reshape((-1,) + joints_shape[-2:])
    t_ayfz2ay = joints[:, 0, :].detach().clone()
    t_ayfz2ay[:, 1] = 0  # do not modify y

    RL_xz_h = joints[:, 1, [0, 2]] - joints[:, 2, [0, 2]]  # (B, 2), hip point to left side
    RL_xz_s = joints[:, 16, [0, 2]] - joints[:, 17, [0, 2]]  # (B, 2), shoulder point to left side
    RL_xz = RL_xz_h + RL_xz_s
    I_mask = RL_xz.pow(2).sum(-1) < 1e-4  # do not rotate, when can't decided the face direction
    if I_mask.sum() > 0:
        # Only a warning here: unlike the compute_T_* helpers, I_mask is not used afterwards.
        Log.warn("{} samples can't decide the face direction".format(I_mask.sum()))

    x_dir = torch.zeros_like(t_ayfz2ay)  # (B, 3)
    x_dir[:, [0, 2]] = F.normalize(RL_xz, 2, -1)
    y_dir = torch.zeros_like(x_dir)
    y_dir[..., 1] = 1  # (B, 3)
    z_dir = torch.cross(x_dir, y_dir, dim=-1)

    # Tiny offset on the z component -- presumably to keep z_dir from being exactly zero
    # before it is passed to qbetween; confirm the intent.
    z_dir[..., 2] += 1e-9

    pos_z_vec = torch.tensor([0, 0, 1]).to(joints.device).float()  # (3,)
    root_quat = qbetween(pos_z_vec[None], z_dir)  # (B, 4)
    root_quat = root_quat.reshape(joints_shape[:-2] + (4,))
    return root_quat
# ================== Transformations between two sets of features ================== #
def similarity_transform_batch(S1, S2):
    """
    Computes a similarity transform (sR, t) that solves the orthogonal Procrustes problem,
    i.e. aligns S1 to S2.

    Args:
        S1, S2: (*, L, 3)
    Returns:
        (scale, R): scale of shape (*, 1, 1) and rotation of shape (*, 3, 3)
        t: translation of shape (*, 3, 1)
    """
    assert S1.shape == S2.shape
    S_shape = S1.shape
    # Flatten leading dimensions into one batch axis and move points to the last axis.
    S1 = S1.reshape(-1, *S_shape[-2:])
    S2 = S2.reshape(-1, *S_shape[-2:])

    S1 = S1.transpose(-2, -1)
    S2 = S2.transpose(-2, -1)

    # --- The code is borrowed from WHAM ---
    # 1. Remove mean.
    mu1 = S1.mean(axis=-1, keepdims=True)  # axis is along N, S1(B, 3, N)
    mu2 = S2.mean(axis=-1, keepdims=True)

    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = torch.sum(X1**2, dim=1).sum(dim=1)

    # 3. The outer product of X1 and X2.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K.
    U, s, V = torch.svd(K)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))

    # Construct R.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Recover scale.
    scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1

    # 6. Recover translation.
    t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1)))
    # -------

    # reshape back
    # sR = scale[:, None, None] * R
    # sR = sR.reshape(*S_shape[:-2], 3, 3)
    scale = scale.reshape(*S_shape[:-2], 1, 1)
    R = R.reshape(*S_shape[:-2], 3, 3)
    t = t.reshape(*S_shape[:-2], 3, 1)

    return (scale, R), t
def kabsch_algorithm_batch(X1, X2):
    """
    Computes the rigid transform (R, t) that best aligns X1 to X2 in the least-squares
    sense (Kabsch algorithm), so that X2 ~= R @ X1 + t.

    Args:
        X1, X2: (*, L, 3)
    Returns:
        R: (*, 3, 3) proper rotation (det = +1)
        t: (*, 3, 1) translation
    """
    assert X1.shape == X2.shape
    X_shape = X1.shape
    X1 = X1.reshape(-1, *X_shape[-2:])
    X2 = X2.reshape(-1, *X_shape[-2:])

    # 1. Centroids
    centroid_X1 = torch.mean(X1, dim=-2, keepdim=True)
    centroid_X2 = torch.mean(X2, dim=-2, keepdim=True)

    # 2. Remove the centroids
    X1_centered = X1 - centroid_X1
    X2_centered = X2 - centroid_X2

    # 3. Cross-covariance matrix
    H = torch.matmul(X1_centered.transpose(-2, -1), X2_centered)

    # 4. Singular value decomposition
    U, S, Vt = torch.linalg.svd(H)

    # 5. Rotation
    R = torch.matmul(Vt.transpose(-2, -1), U.transpose(-2, -1))

    # Reflection fix: when det(R) < 0, flip only the singular direction with the smallest
    # singular value (the last row of Vt, since singular values are sorted descending).
    # Negating all of Vt, as before, yields -R: a rotation, but not the optimal one.
    reflected = torch.det(R) < 0  # (B,)
    row_sign = torch.ones_like(S)  # (B, 3)
    row_sign[..., -1] = torch.where(reflected, -row_sign[..., -1], row_sign[..., -1])
    Vt = Vt * row_sign.unsqueeze(-1)
    R = torch.matmul(Vt.transpose(-2, -1), U.transpose(-2, -1))

    # 6. Translation
    t = centroid_X2.transpose(-2, -1) - torch.matmul(R, centroid_X1.transpose(-2, -1))
    # -------

    # reshape back
    R = R.reshape(*X_shape[:-2], 3, 3)
    t = t.reshape(*X_shape[:-2], 3, 1)

    return R, t
# ===== WHAM cam_angvel ===== #
def compute_cam_angvel(R_w2c, padding_last=True):
    """Per-frame relative camera rotation in the 6D representation.

    R_w2c : (F, 3, 3)
    padding_last: must be True; the last entry is repeated so the output has F rows.
    Returns: (F, 6) float tensor
    """
    # R @ R0 = R1, so R = R1 @ R0^T
    relative_rot = R_w2c[1:] @ R_w2c[:-1].transpose(-1, -2)
    cam_angvel = matrix_to_rotation_6d(relative_rot)  # (F-1, 6)
    # cam_angvel = (cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]])) * FPS
    assert padding_last
    padded = torch.cat([cam_angvel, cam_angvel[-1:]], dim=0)  # (F, 6)
    return padded.float()
def ransac_gravity_vec(xyz, num_iterations=100, threshold=0.05, verbose=False):
    """Estimate a dominant direction from noisy vectors with a simple RANSAC loop.

    Each iteration picks one vector at random (np.random) and counts the vectors whose
    angle to it is below `threshold`; the largest inlier set wins.

    Args:
        xyz: (L, 3) candidate vectors.
        num_iterations: maximum number of random samples.
        threshold: inlier angle threshold in radians.
        verbose: print the inlier count.
    Returns:
        result: (3,) mean of the best inlier set
        max_inliers: (M, 3) the inlier vectors
    """
    # xyz: (L, 3)
    N = xyz.shape[0]
    max_inliers = []
    norms = xyz.norm(dim=-1)  # (L,)

    for _ in range(num_iterations):
        # Pick one sample at random
        sample_index = np.random.randint(N)
        sample = xyz[sample_index]  # (3,)

        # Angle between every vector and the sample. The cosine must divide by the product
        # of both norms; the former `dot / norms * norms[i]` multiplied by the sample norm
        # instead and was only right for unit-length input.
        dot_product = (xyz * sample).sum(dim=-1)  # (L,)
        cos_angles = dot_product / (norms * norms[sample_index])  # (L,)
        cos_angles = torch.clamp(cos_angles, -1, 1)  # guard acos against round-off
        angles = torch.acos(cos_angles)

        # Collect inliers
        inliers = xyz[angles < threshold]

        if len(inliers) > len(max_inliers):
            max_inliers = inliers
        if len(max_inliers) == N:
            break
    if verbose:
        print(f"Inliers: {len(max_inliers)} / {N}")
    result = max_inliers.mean(dim=0)

    return result, max_inliers
def sequence_best_cammat(w_j3d, c_j3d, cam_rot):
    """Pick the per-frame camera matrix that best explains the whole sequence.

    For every frame a camera matrix is built from that frame's rotation and root joint; each
    candidate is scored by the mean 2D reprojection error of all frames' world joints against
    the projected camera-space joints, and the candidate with the lowest error is returned.

    Returns:
        cam_mat[ind]: the selected (4, 4) camera matrix
        ind: index of the selected frame
    """
    # get best camera estimation along the sequence, requires static camera
    # w_j3d: (L, J, 3)
    # c_j3d: (L, J, 3)
    # cam_rot: (L, 3, 3)

    L, J, _ = w_j3d.shape

    # Candidate camera per frame, built from the root joint in world and camera space.
    root_in_w = w_j3d[:, 0]  # (L, 3)
    root_in_c = c_j3d[:, 0]  # (L, 3)
    cam_mat = matrix.get_TRS(cam_rot, root_in_w)  # (L, 4, 4)
    cam_pos = matrix.get_position_from(-root_in_c[:, None], cam_mat)[:, 0]  # (L, 3)
    cam_mat = matrix.set_position(cam_mat, cam_pos)  # (L, 4, 4)

    # Every candidate camera sees the joints of all L frames.
    w_j3d_expand = w_j3d[None].expand(L, -1, -1, -1)  # (L, L, J, 3)
    w_j3d_expand = w_j3d_expand.reshape(L, -1, 3)  # (L, L*J, 3)

    # get reproject error
    w_j3d_expand_in_c = matrix.get_relative_position_to(w_j3d_expand, cam_mat)  # (L, L*J, 3)
    w_j2d_expand_in_c = project_p2d(w_j3d_expand_in_c)  # (L, L*J, 2)
    w_j2d_expand_in_c = w_j2d_expand_in_c.reshape(L, L, J, 2)  # (L, L, J, 2)
    c_j2d = project_p2d(c_j3d)  # (L, J, 2)
    error = w_j2d_expand_in_c - c_j2d[None]  # (L, L, J, 2)
    error = error.norm(dim=-1).mean(dim=-1)  # (L, L)
    error = error.mean(dim=-1)  # (L,)  one score per candidate camera
    ind = error.argmin()
    return cam_mat[ind], ind
def get_sequence_cammat(w_j3d, c_j3d, cam_rot):
    """Build one camera matrix per frame from the camera rotation and the root joint.

    w_j3d: (L, J, 3) world-space joints
    c_j3d: (L, J, 3) camera-space joints
    cam_rot: (L, 3, 3)
    Returns: (L, 4, 4)
    """
    L, J, _ = w_j3d.shape  # unpacking also enforces a rank-3 input

    world_root = w_j3d[:, 0]  # (L, 3)
    cam_root = c_j3d[:, 0]  # (L, 3)
    cam_mat = matrix.get_TRS(cam_rot, world_root)  # (L, 4, 4)
    cam_pos = matrix.get_position_from(-cam_root[:, None], cam_mat)[:, 0]  # (L, 3)
    return matrix.set_position(cam_mat, cam_pos)  # (L, 4, 4)
def ransac_vec(vel, min_multiply=20, verbose=False):
    """Average the largest cluster of mutually close vectors, discarding outliers.

    The distance threshold is `min_multiply` times the smallest pairwise distance; the row
    with the most neighbours under that threshold defines the inlier set.

    Args:
        vel: (L, 3)
    Returns:
        result: (3,) mean of the inliers
        mask: (L,) boolean inlier mask
    """
    count = vel.shape[0]
    row_view = vel[None].expand(count, -1, -1)  # (L, L, 3)
    col_view = vel[:, None].expand(-1, count, -1)  # (L, L, 3)
    pairwise = (row_view - col_view).norm(dim=-1)  # (L, L)

    # Add a large value on the diagonal so the zero self-distance is not the minimum.
    off_diagonal = pairwise + torch.eye(count, device=vel.device) * 1e6
    threshold = off_diagonal.min() * min_multiply

    close = pairwise < threshold  # (L, L)
    best_row = close.sum(dim=-1).argmax()
    best_mask = close[best_row]
    result = vel[best_mask].mean(dim=0)  # (3,)
    if verbose:
        print(best_mask.sum().item())
    return result, best_mask
================================================
FILE: eval/GVHMR/hmr4d/utils/ik/ccd_ik.py
================================================
# Sebastian IK
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import einsum, rearrange, repeat
from pytorch3d.transforms import (
matrix_to_rotation_6d,
rotation_6d_to_matrix,
axis_angle_to_matrix,
matrix_to_axis_angle,
quaternion_to_matrix,
matrix_to_quaternion,
)
import hmr4d.utils.matrix as matrix
from hmr4d.utils.geo.quaternion import qbetween, qslerp, qinv, qmul, qrot
class CCD_IK:
    """Iterative IK solver that re-orients the joints of a single kinematic chain.

    The name suggests cyclic coordinate descent: joints are visited from the one
    after the chain root towards the end, and each joint is rotated so that the
    target joints move towards their target positions and/or rotations.
    """

    def __init__(
        self,
        local_mat,
        parent,
        target_ind,
        target_pos=None,
        target_rot=None,
        kinematic_chain=None,
        max_iter=2,  # sebas sets 25 but with converged flag, 2 is enough
        threshold=0.001,
        pos_weight=1.0,
        rot_weight=0.0,  # this makes optimization unstable, although sebas uses 1.0
    ):
        """Store the chain-local transforms and the IK targets.

        Args:
            local_mat: (*, J, 4, 4) local joint transforms of the full skeleton.
            parent: parent indices of the full skeleton, passed to ``matrix.forward_kinematics``.
            target_ind: indices (within the chain) of the joints that have targets.
            target_pos: optional (*, O, 3) target positions.
            target_rot: optional (*, O, 3, 3) target rotations, stored as quaternions.
            kinematic_chain: joint indices forming the chain; defaults to all joints.
            max_iter: number of full passes performed by ``solve``.
            threshold: distance below which a position target counts as reached.
            pos_weight: only its sign is used here (``> 0`` counts position targets).
            rot_weight: scale applied to the rotation-target contribution.
        """
        if kinematic_chain is None:
            kinematic_chain = range(local_mat.shape[-3])
        global_mat = matrix.forward_kinematics(local_mat, parent)

        # get kinematic chain only local mat and assign root mat (do not modify root during IK)
        local_mat = local_mat.clone()
        local_mat = local_mat[..., kinematic_chain, :, :]
        local_mat[..., 0, :, :] = global_mat[..., kinematic_chain[0], :, :]

        # Inside the chain every joint's parent is simply the previous joint (root -> -1).
        parent = [i - 1 for i in range(len(kinematic_chain))]
        self.local_mat = local_mat
        self.global_mat = matrix.forward_kinematics(local_mat, parent)  # (*, J, 4, 4)
        self.parent = parent

        self.target_ind = target_ind
        self.target_pos = target_pos  # (*, O, 3) or None
        if target_rot is not None:
            self.target_q = matrix_to_quaternion(target_rot)  # (*, O, 4)
        else:
            self.target_q = None

        self.threshold = threshold
        self.J_N = self.local_mat.shape[-3]
        self.target_N = len(target_ind)
        self.max_iter = max_iter
        self.pos_weight = pos_weight
        self.rot_weight = rot_weight

    def is_converged(self):
        """Return True when every position target is within ``threshold`` of its joint.

        The per-target result is also stored in ``self.converged_mask``. Without
        position targets there is nothing to measure, so False is returned.
        """
        if self.target_pos is None:
            return False
        end_pos = matrix.get_position(self.global_mat)[..., self.target_ind, :]  # (*, OJ, 3)
        converged_mask = (self.target_pos - end_pos).norm(dim=-1) < self.threshold
        self.converged_mask = converged_mask
        # The previous logic was inverted: it reported "not converged" as soon as
        # any target had been reached and "converged" when none had.
        return bool(converged_mask.all())

    def solve(self):
        """Run ``max_iter`` optimisation passes and return the chain-local matrices."""
        for _ in range(self.max_iter):
            # if self.is_converged():
            #     return self.local_mat
            # do not optimize root, so start from 1
            self.optimize(1)
        return self.local_mat

    def optimize(self, i):
        """Re-orient joint ``i`` towards the targets, then recurse to joint ``i + 1``.

        Args:
            i: index of the joint (within the chain) to update.
        """
        if i == self.J_N - 1:
            # The last joint of the chain is never rotated.
            return

        pos = matrix.get_position(self.global_mat)[..., i, :]  # (*, 3)
        rot = matrix.get_rotation(self.global_mat)[..., i, :, :]  # (*, 3, 3)
        quat = matrix_to_quaternion(rot)  # (*, 4)

        # The solved orientations are averaged through their rotated x and y axes.
        x_vec = torch.zeros((quat.shape[:-1] + (3,)), device=quat.device)
        x_vec[..., 0] = 1.0
        x_vec_sum = torch.zeros_like(x_vec)
        y_vec = torch.zeros((quat.shape[:-1] + (3,)), device=quat.device)
        y_vec[..., 1] = 1.0
        y_vec_sum = torch.zeros_like(y_vec)

        count = 0
        for target_i, j in enumerate(self.target_ind):
            if i >= j:
                # do not optimise same joint or child joint of targets
                continue
            end_pos = matrix.get_position(self.global_mat)[..., j, :]  # (*, 3)
            end_rot = matrix.get_rotation(self.global_mat)[..., j, :, :]  # (*, 3, 3)
            end_quat = matrix_to_quaternion(end_rot)  # (*, 4)

            if self.target_pos is not None:
                target_pos = self.target_pos[..., target_i, :]  # (*, 3)
                # Solve objective position
                solved_pos_target_quat = qslerp(
                    quat,
                    qmul(qbetween(end_pos - pos, target_pos - pos), quat),
                    self.get_weight(i),
                )
                # NOTE(review): this contribution is not scaled by pos_weight,
                # unlike the rotation branch below; kept as in the original.
                x_vec_sum += qrot(solved_pos_target_quat, x_vec)
                y_vec_sum += qrot(solved_pos_target_quat, y_vec)
                if self.pos_weight > 0:
                    count += 1

            if self.target_q is not None:
                if target_i < self.target_N - 1:
                    # multiple rot target makes more unstable, only keep the last one
                    continue
                # optimize rotation target is not stable
                target_q = self.target_q[..., target_i, :]  # (*, 4)
                # Solve objective rotation
                solved_q_target_quat = qslerp(
                    quat,
                    qmul(qmul(target_q, qinv(end_quat)), quat),
                    self.get_weight(i),
                )
                x_vec_sum += qrot(solved_q_target_quat, x_vec) * self.rot_weight
                y_vec_sum += qrot(solved_q_target_quat, y_vec) * self.rot_weight
                if self.rot_weight > 0:
                    count += 1

        if count > 0:
            x_vec_avg = matrix.normalize(x_vec_sum / count)
            y_vec_avg = matrix.normalize(y_vec_sum / count)
            z_vec_avg = torch.cross(x_vec_avg, y_vec_avg, dim=-1)
            solved_rot = torch.stack([x_vec_avg, y_vec_avg, z_vec_avg], dim=-1)  # column

            # Express the solved global rotation relative to the parent joint.
            parent_rot = matrix.get_rotation(self.global_mat)[..., self.parent[i], :, :]
            solved_local_rot = matrix.get_mat_BtoA(parent_rot, solved_rot)
            self.local_mat[..., i, :-1, :-1] = solved_local_rot
            self.global_mat = matrix.forward_kinematics(self.local_mat, self.parent)
        self.optimize(i + 1)

    def get_weight(self, i):
        """Return the slerp weight of joint ``i``: ``(i + 1) / J_N``."""
        weight = (i + 1) / self.J_N
        return weight
================================================
FILE: eval/GVHMR/hmr4d/utils/kpts/kp2d_utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import cv2
import numpy as np
# Public API for `from ... import *`: only keypoints_from_heatmaps is exported (_taylor is not)
__all__ = ["keypoints_from_heatmaps"]
def _taylor(heatmap, coord):
"""Distribution aware coordinate decoding method.
Note:
- heatmap height: H
- heatmap width: W
Args:
heatmap (np.ndarray[H, W]): Heatmap of a particular joint type.
coord (np.ndarray[2,]): Coordinates of the predicted keypoints.
Returns:
np.ndarray[2,]: Updated coordinates.
"""
H, W = heatmap.shape[:2]
px, py = int(coord[0]), int(coord[1])
if 1 < px < W - 2 and 1 < py < H - 2:
dx = 0.5 * (heatmap[py][px + 1] - heatmap[py][px - 1])
dy = 0.5 * (heatmap[py + 1][px] - heatmap[py - 1][px])
dxx = 0.25 * (heatmap[py][px + 2] - 2 * heatmap[py][px] + heatmap[py][px - 2])
dxy = 0.25 * (
heatmap[py + 1][px + 1] - heatmap[py - 1][px + 1] - heatmap[py + 1][px - 1] + heatmap[py - 1][px - 1]
)
dyy = 0.25 * (heatmap[py + 2 * 1][px] - 2 * heatmap[py][px] + heatmap[py - 2 * 1][px])
derivative = np.array([[dx], [dy]])
hessian = np.array([[dxx, dxy], [dxy, dyy]])
if dxx * dyy - dxy**2 != 0:
hessianinv = np.linalg.inv(hessian)
offset = -hessianinv @ derivative
offset = np.squeeze(np.array(offset.T), axis=0)
coord += offset
return coord
def _get_max_preds(heatmaps):
"""Get keypoint predictions from score maps.
Note:
batch_size: N
num_keypoints: K
heatmap height: H
heatmap width: W
Args:
heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps.
Returns:
tuple: A tuple containing aggregated results.
- preds (np.ndarray[N, K, 2]): Predicted keypoint location.
- maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
"""
assert isinstance(heatmaps, np.ndarray), "heatmaps should be numpy.ndarray"
assert heatmaps.ndim == 4, "batch_images should be 4-ndim"
N, K, _, W = heatmaps.shape
heatmaps_reshaped = heatmaps.reshape((N, K, -1))
idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1))
maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1))
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
preds[:, :, 0] = preds[:, :, 0] % W
preds[:, :, 1] = preds[:, :, 1] // W
preds = np.where(np.tile(maxvals, (1, 1, 2)) > 0.0, preds, -1)
return preds, maxvals
def post_dark_udp(coords, batch_heatmaps, kernel=3):
    """DARK post-pocessing. Implemented by udp. Paper ref: Huang et al. The
    Devil is in the Details: Delving into Unbiased Data Processing for Human
    Pose Estimation (CVPR 2020). Zhang et al. Distribution-Aware Coordinate
    Representation for Human Pose Estimation (CVPR 2020).

    Note:
        - batch size: B
        - num keypoints: K
        - num persons: N
        - height of heatmaps: H
        - width of heatmaps: W

        B=1 for bottom_up paradigm where all persons share the same heatmap.
        B=N for top_down paradigm where each person has its own heatmaps.

        ``batch_heatmaps`` (when given as ndarray) and ``coords`` are both
        modified in place.

    Args:
        coords (np.ndarray[N, K, 2]): Initial coordinates of human pose.
        batch_heatmaps (np.ndarray[B, K, H, W]): batch_heatmaps
        kernel (int): Gaussian kernel size (K) for modulation.

    Returns:
        np.ndarray([N, K, 2]): Refined coordinates.
    """
    if not isinstance(batch_heatmaps, np.ndarray):
        batch_heatmaps = batch_heatmaps.cpu().numpy()
    B, K, H, W = batch_heatmaps.shape
    N = coords.shape[0]
    assert B == 1 or B == N
    # Smooth, clamp and take the log of every heatmap in place.
    for heatmaps in batch_heatmaps:
        for heatmap in heatmaps:
            cv2.GaussianBlur(heatmap, (kernel, kernel), 0, heatmap)
    np.clip(batch_heatmaps, 0.001, 50, batch_heatmaps)
    np.log(batch_heatmaps, batch_heatmaps)

    # Pad by one pixel so the 3x3 neighbourhood of any location is addressable,
    # then flatten so neighbours can be fetched with flat-index offsets.
    batch_heatmaps_pad = np.pad(batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="edge").flatten()

    index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (W + 2)
    index += (W + 2) * (H + 2) * np.arange(0, B * K).reshape(-1, K)
    index = index.astype(int).reshape(-1, 1)
    i_ = batch_heatmaps_pad[index]
    ix1 = batch_heatmaps_pad[index + 1]
    iy1 = batch_heatmaps_pad[index + W + 2]
    ix1y1 = batch_heatmaps_pad[index + W + 3]
    ix1_y1_ = batch_heatmaps_pad[index - W - 3]
    ix1_ = batch_heatmaps_pad[index - 1]
    iy1_ = batch_heatmaps_pad[index - 2 - W]

    dx = 0.5 * (ix1 - ix1_)
    dy = 0.5 * (iy1 - iy1_)
    derivative = np.concatenate([dx, dy], axis=1)
    derivative = derivative.reshape(N, K, 2, 1)
    dxx = ix1 - 2 * i_ + ix1_
    dyy = iy1 - 2 * i_ + iy1_
    dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
    hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1)
    hessian = hessian.reshape(N, K, 2, 2)
    # A tiny diagonal term keeps the inversion defined for singular Hessians.
    hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
    # Squeeze only the trailing singleton axis: a bare squeeze() would also drop
    # N or K when either equals 1 and break the in-place update for K == 1, N > 1.
    coords -= np.einsum("ijmn,ijnk->ijmk", hessian, derivative).squeeze(-1)
    return coords
def _gaussian_blur(heatmaps, kernel=11):
    """Modulate heatmap distribution with Gaussian.
     sigma = 0.3*((kernel_size-1)*0.5-1)+0.8
     sigma~=3 if k=17
     sigma=2 if k=11;
     sigma~=1.5 if k=7;
     sigma~=1 if k=3;

    Note:
        - batch_size: N
        - num_keypoints: K
        - heatmap height: H
        - heatmap width: W

    Args:
        heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps;
            modified in place.
        kernel (int): Gaussian kernel size (K) for modulation, which should
            match the heatmap gaussian sigma when training.
            K=17 for sigma=3 and k=11 for sigma=2.

    Returns:
        np.ndarray ([N, K, H, W]): Modulated heatmap distribution.
    """
    assert kernel % 2 == 1

    border = (kernel - 1) // 2
    batch_size = heatmaps.shape[0]
    num_joints = heatmaps.shape[1]
    height = heatmaps.shape[2]
    width = heatmaps.shape[3]
    # Explicit end indices instead of "-border": the latter is an empty slice
    # when border == 0 (kernel == 1).
    rows = slice(border, border + height)
    cols = slice(border, border + width)
    for i in range(batch_size):
        for j in range(num_joints):
            origin_max = np.max(heatmaps[i, j])
            # Blur on a zero-padded canvas, then crop back to the original size.
            dr = np.zeros((height + 2 * border, width + 2 * border), dtype=np.float32)
            dr[rows, cols] = heatmaps[i, j].copy()
            dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
            heatmaps[i, j] = dr[rows, cols].copy()
            # Restore the original peak value; skip when the blurred map has a
            # zero maximum (e.g. an all-zero heatmap) to avoid a 0/0 -> NaN.
            blurred_max = np.max(heatmaps[i, j])
            if blurred_max != 0:
                heatmaps[i, j] *= origin_max / blurred_max
    return heatmaps
def keypoints_from_heatmaps(
    heatmaps,
    center,
    scale,
    unbiased=False,
    post_process="default",
    kernel=11,
    valid_radius_factor=0.0546875,
    use_udp=False,
    target_type="GaussianHeatmap",
):
    """Get final keypoint predictions from heatmaps and transform them back to
    the image.

    Note:
        - batch size: N
        - num keypoints: K
        - heatmap height: H
        - heatmap width: W

    Args:
        heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. A copy is
            taken first, so the caller's array is not modified.
        center (np.ndarray[N, 2]): Center of the bounding box (x, y).
        scale (np.ndarray[N, 2]): Scale of the bounding box
            wrt height/width.
        post_process (str/None): Choice of methods to post-process
            heatmaps. Currently supported: None, 'default', 'unbiased',
            'megvii'.
        unbiased (bool): Option to use unbiased decoding. Mutually
            exclusive with megvii.
            Note: this arg is deprecated and unbiased=True can be replaced
            by post_process='unbiased'
            Paper ref: Zhang et al. Distribution-Aware Coordinate
            Representation for Human Pose Estimation (CVPR 2020).
        kernel (int): Gaussian kernel size (K) for modulation, which should
            match the heatmap gaussian sigma when training.
            K=17 for sigma=3 and k=11 for sigma=2.
        valid_radius_factor (float): The radius factor of the positive area
            in classification heatmap for UDP.
        use_udp (bool): Use unbiased data processing.
        target_type (str): 'GaussianHeatmap' or 'CombinedTarget'.
            GaussianHeatmap: Classification target with gaussian distribution.
            CombinedTarget: The combination of classification target
            (response map) and regression target (offset map).
            Paper ref: Huang et al. The Devil is in the Details: Delving into
            Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Returns:
        tuple: A tuple containing keypoint predictions and scores.

        - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images.
        - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.

    Raises:
        ValueError: if ``use_udp`` is set and ``target_type`` is neither
            'GaussianHeatmap' nor 'CombinedTarget' (case-insensitive).
    """
    # Avoid being affected
    heatmaps = heatmaps.copy()

    # detect conflicts
    if unbiased:
        assert post_process not in [False, None, "megvii"]
    if post_process in ["megvii", "unbiased"]:
        assert kernel > 0
    if use_udp:
        assert not post_process == "megvii"

    # normalize configs
    if post_process is False:
        warnings.warn("post_process=False is deprecated, " "please use post_process=None instead", DeprecationWarning)
        post_process = None
    elif post_process is True:
        if unbiased is True:
            warnings.warn(
                "post_process=True, unbiased=True is deprecated," " please use post_process='unbiased' instead",
                DeprecationWarning,
            )
            post_process = "unbiased"
        else:
            warnings.warn(
                "post_process=True, unbiased=False is deprecated, " "please use post_process='default' instead",
                DeprecationWarning,
            )
            post_process = "default"
    elif post_process == "default":
        if unbiased is True:
            warnings.warn(
                "unbiased=True is deprecated, please use " "post_process='unbiased' instead", DeprecationWarning
            )
            post_process = "unbiased"

    # start processing
    if post_process == "megvii":
        heatmaps = _gaussian_blur(heatmaps, kernel=kernel)

    N, K, H, W = heatmaps.shape
    if use_udp:
        if target_type.lower() == "GaussianHeatMap".lower():
            preds, maxvals = _get_max_preds(heatmaps)
            preds = post_dark_udp(preds, heatmaps, kernel=kernel)
        elif target_type.lower() == "CombinedTarget".lower():
            # Channels come in triplets: (response map, x-offset map, y-offset map).
            for person_heatmaps in heatmaps:
                for i, heatmap in enumerate(person_heatmaps):
                    kt = 2 * kernel + 1 if i % 3 == 0 else kernel
                    cv2.GaussianBlur(heatmap, (kt, kt), 0, heatmap)
            # valid radius is in direct proportion to the height of heatmap.
            valid_radius = valid_radius_factor * H
            offset_x = heatmaps[:, 1::3, :].flatten() * valid_radius
            offset_y = heatmaps[:, 2::3, :].flatten() * valid_radius
            heatmaps = heatmaps[:, ::3, :]
            preds, maxvals = _get_max_preds(heatmaps)
            # Flat index of every predicted location inside the flattened offset
            # maps. The per-map base offset is shaped (N, K // 3) so that it lines
            # up with preds for any batch size; the former 1-D arange could only
            # be broadcast when N == 1.
            num_kpts = K // 3
            index = preds[..., 0] + preds[..., 1] * W
            index = index + W * H * np.arange(N * num_kpts).reshape(N, num_kpts)
            index = index.astype(int).reshape(N, num_kpts, 1)
            preds += np.concatenate((offset_x[index], offset_y[index]), axis=2)
        else:
            raise ValueError("target_type should be either " "'GaussianHeatmap' or 'CombinedTarget'")
    else:
        preds, maxvals = _get_max_preds(heatmaps)
        if post_process == "unbiased":  # alleviate biased coordinate
            # apply Gaussian distribution modulation.
            heatmaps = np.log(np.maximum(_gaussian_blur(heatmaps, kernel), 1e-10))
            for n in range(N):
                for k in range(K):
                    preds[n][k] = _taylor(heatmaps[n][k], preds[n][k])
        elif post_process is not None:
            # add +/-0.25 shift to the predicted locations for higher acc.
            for n in range(N):
                for k in range(K):
                    heatmap = heatmaps[n][k]
                    px = int(preds[n][k][0])
                    py = int(preds[n][k][1])
                    if 1 < px < W - 1 and 1 < py < H - 1:
                        diff = np.array(
                            [heatmap[py][px + 1] - heatmap[py][px - 1], heatmap[py + 1][px] - heatmap[py - 1][px]]
                        )
                        preds[n][k] += np.sign(diff) * 0.25
                        if post_process == "megvii":
                            preds[n][k] += 0.5

    # Transform back to the image
    for i in range(N):
        preds[i] = transform_preds(preds[i], center[i], scale[i], [W, H], use_udp=use_udp)

    if post_process == "megvii":
        maxvals = maxvals / 255.0 + 0.5

    return preds, maxvals
def transform_preds(coords, center, scale, output_size, use_udp=False):
    """Map heatmap-space coordinates back to image space by scaling and shifting.

    Note:
        num_keypoints: K

    Args:
        coords (np.ndarray[K, ndims]):

            * If ndims=2, coords are predicted keypoint location.
            * If ndims=4, coords are composed of (x, y, scores, tags)
            * If ndims=5, coords are composed of (x, y, scores, tags,
              flipped_tags)

        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        scale (np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height].
        output_size (np.ndarray[2, ] | list(2,)): Size of the
            destination heatmaps.
        use_udp (bool): Use unbiased data processing

    Returns:
        np.ndarray: Array shaped like ``coords`` whose first two columns hold
        the image-space x and y; any further columns are filled with ones.
    """
    assert coords.shape[1] in (2, 4, 5)
    assert len(center) == 2
    assert len(scale) == 2
    assert len(output_size) == 2

    # Recover the scale which is normalized by a factor of 200.
    scale = scale * 200.0

    # With UDP the heatmap extent is (size - 1) instead of size.
    if use_udp:
        extent_x, extent_y = output_size[0] - 1.0, output_size[1] - 1.0
    else:
        extent_x, extent_y = output_size[0], output_size[1]
    scale_x = scale[0] / extent_x
    scale_y = scale[1] / extent_y

    target_coords = np.ones_like(coords)
    target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5
    target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5
    return target_coords
================================================
FILE: eval/GVHMR/hmr4d/utils/matrix.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import List, Optional
import numpy as np
import math
def identity_mat(x=None, device="cpu", is_numpy=False):
    """Create a 4x4 identity matrix, optionally batched like ``x``.

    Args:
        x: optional tensor/array of shape (..., 4, 4); its leading dimensions
            define the batch shape of the result.
        device: device of the torch result (used with or without ``x``).
        is_numpy: when ``x`` is None, return a numpy array instead of a tensor.

    Returns:
        Identity matrices of shape ``x.shape[:-2] + (4, 4)`` (or (4, 4) without
        ``x``); numpy results are float32.

    Raises:
        ValueError: if ``x`` is neither a torch tensor nor a numpy array.
    """
    if x is None:
        # (4, 4)
        return np.eye(4, dtype=np.float32) if is_numpy else torch.eye(4, device=device)
    if isinstance(x, torch.Tensor):
        return torch.eye(4, device=device).repeat(x.shape[:-2] + (1, 1))
    if isinstance(x, np.ndarray):
        # np.tile prepends the missing leading axes itself.
        return np.tile(np.eye(4, dtype=np.float32), x.shape[:-2] + (1, 1))
    raise ValueError
def vec2mat(vec):
    """Convert a 12-D pose vector into a 4x4 matrix.

    Args:
        vec (tensor): [12], pos, forward, up and right

    Returns:
        mat_world(tensor): [4, 4] with columns (right, up, forward, pos) and a
        (0, 0, 0, 1) bottom row, passed through ``normalized_matrix``.

    Raises:
        ValueError: if ``vec`` is neither a torch tensor nor a numpy array.
    """
    # Assume bs = 1
    is_tensor = isinstance(vec, torch.Tensor)
    bottom_row = np.array([[0, 0, 0, 1]])
    if is_tensor:
        bottom_row = torch.tensor(bottom_row, device=vec.device, dtype=vec.dtype)

    pos, forward, up, right = vec[:3], vec[3:6], vec[6:9], vec[9:12]
    columns = [right, up, forward, pos]
    if is_tensor:
        upper = torch.stack(columns, dim=-1)
        mat_world = torch.cat([upper, bottom_row], dim=-2)
    elif isinstance(vec, np.ndarray):
        upper = np.stack(columns, axis=-1)
        mat_world = np.concatenate([upper, bottom_row], axis=-2)
    else:
        raise ValueError
    return normalized_matrix(mat_world)
def mat2vec(mat):
    """Flatten a 4x4 matrix into a 12-D pose vector.

    Args:
        mat(tensor): [4, 4]

    Returns:
        vec (tensor): [12], pos, forward, up and right, where forward/up/right
        are the normalised columns 2/1/0 and pos is column 3.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    # Assume bs = 1
    upper = mat[:-1]  # drop the bottom row
    pos = upper[:, 3]
    forward = normalized(upper[:, 2])
    up = normalized(upper[:, 1])
    right = normalized(upper[:, 0])
    parts = (pos, forward, up, right)
    if isinstance(mat, torch.Tensor):
        return torch.cat(parts)
    if isinstance(mat, np.ndarray):
        return np.concatenate(parts)
    raise ValueError
def vec2mat_batch(vec):
    """Convert a batch of 12-D pose vectors into 4x4 matrices.

    Args:
        vec (tensor): [B, 12], pos, forward, up and right

    Returns:
        mat_world(tensor): [B, 4, 4] with columns (right, up, forward, pos) and a
        (0, 0, 0, 1) bottom row, passed through ``normalized_matrix``.

    Raises:
        ValueError: if ``vec`` is neither a torch tensor nor a numpy array.
    """
    is_tensor = isinstance(vec, torch.Tensor)
    # One (1, 4) bottom row per batch entry -> (B, 1, 4).
    bottom_rows = np.tile(np.array([[0, 0, 0, 1]], dtype=np.float32), (vec.shape[0], 1, 1))
    if is_tensor:
        bottom_rows = torch.tensor(bottom_rows, device=vec.device, dtype=vec.dtype)

    pos, forward, up, right = vec[..., :3], vec[..., 3:6], vec[..., 6:9], vec[..., 9:12]
    columns = [right, up, forward, pos]
    if is_tensor:
        upper = torch.stack(columns, dim=-1)
        mat_world = torch.cat([upper, bottom_rows], dim=-2)
    elif isinstance(vec, np.ndarray):
        upper = np.stack(columns, axis=-1)
        mat_world = np.concatenate([upper, bottom_rows], axis=-2)
    else:
        raise ValueError
    return normalized_matrix(mat_world)
def rotmat2tan_norm(mat):
    """Encode rotation matrices as a 6-D (tan, norm) vector.

    ``tan`` is row 2 and ``norm`` is row 0 of the matrix, each reversed along
    its last axis and with its (new) last component negated.

    NOTE: a second ``rotmat2tan_norm`` is defined later in this module and
    replaces this one at import time.

    Args:
        mat(tensor): [B, 3, 3]

    Returns:
        vec (tensor): [B, 6], tan norm

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(mat, np.ndarray):
        tan = mat[..., 2, ::-1].copy()
        norm = mat[..., 0, ::-1].copy()
    elif isinstance(mat, torch.Tensor):
        # torch does not support negative-step slicing, so flip explicitly.
        tan = torch.flip(mat[..., 2, :], dims=[-1])
        norm = torch.flip(mat[..., 0, :], dims=[-1])
    else:
        raise ValueError
    # tan/norm are fresh copies, so the in-place sign flip leaves ``mat`` untouched.
    tan[..., -1] *= -1
    norm[..., -1] *= -1
    if isinstance(mat, np.ndarray):
        return np.concatenate((tan, norm), axis=-1)
    return torch.cat((tan, norm), dim=-1)
def mat2tan_norm(mat):
    """Encode the rotation block of a transform as a 6-D (tan, norm) vector.

    Args:
        mat(tensor): [B, 4, 4]

    Returns:
        vec (tensor): [B, 6], tan norm, as produced by ``rotmat2tan_norm`` on
        the matrix with its last row and last column removed.
    """
    return rotmat2tan_norm(mat[..., :-1, :-1])
def rotmat2tan_norm(mat):
    """Encode rotation matrices as a 6-D (tan, norm) vector.

    ``tan`` is row 2 and ``norm`` is row 0 of the matrix, each reversed along
    its last axis and with its (new) last component negated. This is the layout
    that ``tan_norm2rotmat`` writes back into rows 2 and 0.

    Args:
        mat(tensor): [B, 3, 3]

    Returns:
        vec (tensor): [B, 6], tan norm

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(mat, np.ndarray):
        tan = mat[..., 2, ::-1].copy()
        norm = mat[..., 0, ::-1].copy()
    elif isinstance(mat, torch.Tensor):
        # Take ROWS 2 and 0, matching the numpy branch; the previous
        # ``mat[..., 2]`` selected columns, so the two backends disagreed.
        tan = torch.flip(mat[..., 2, :], dims=[-1])
        norm = torch.flip(mat[..., 0, :], dims=[-1])
    else:
        raise ValueError
    # tan/norm are fresh copies, so the in-place sign flip leaves ``mat`` untouched.
    tan[..., -1] *= -1
    norm[..., -1] *= -1
    if isinstance(mat, np.ndarray):
        return np.concatenate((tan, norm), axis=-1)
    return torch.cat((tan, norm), dim=-1)
def tan_norm2rotmat(tan_norm):
    """Rebuild rotation matrices from a 6-D (tan, norm) vector.

    The last component of each half is negated and the half is reversed; the
    results become row 2 (tan) and row 0 (norm), and row 1 is their cross
    product ``tan x norm``.

    Args:
        tan_norm (tensor): [..., 6], tan followed by norm.

    Returns:
        rotmat (tensor): [..., 3, 3]. The numpy result uses numpy's default
        float dtype and the torch result torch's default dtype.

    Raises:
        ValueError: if ``tan_norm`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(tan_norm, np.ndarray):
        # Copy so the in-place sign flip does not touch the caller's data.
        tan = tan_norm[..., :3].copy()
        norm = tan_norm[..., 3:].copy()
        tan[..., -1] *= -1
        norm[..., -1] *= -1
        rotmat = np.zeros(tan_norm.shape[:-1] + (3, 3))
        tan = tan[..., ::-1]
        norm = norm[..., ::-1]
        other = np.cross(tan, norm)
    elif isinstance(tan_norm, torch.Tensor):
        # clone() (unlike copy.deepcopy) also works on non-leaf autograd tensors.
        tan = tan_norm[..., :3].clone()
        norm = tan_norm[..., 3:].clone()
        tan[..., -1] *= -1
        norm[..., -1] *= -1
        rotmat = torch.zeros(tan_norm.shape[:-1] + (3, 3), device=tan_norm.device)
        tan = torch.flip(tan, dims=[-1])
        norm = torch.flip(norm, dims=[-1])
        # dim must be explicit: without it torch.cross picks the first axis of
        # size 3, which is the batch axis for a batch of three vectors.
        other = torch.cross(tan, norm, dim=-1)
    else:
        raise ValueError
    rotmat[..., 2, :] = tan
    rotmat[..., 0, :] = norm
    rotmat[..., 1, :] = other
    return rotmat
def rotmat332vec_batch(mat):
    """Flatten 3x3 rotation matrices into their forward, up and right columns.

    Args:
        mat(tensor): [B, 3, 3]

    Returns:
        vec (tensor): forward, up, right concatenated on the last axis
        (3 values each), i.e. columns 2, 1 and 0 of the normalised matrix.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    mat = normalized_matrix(mat)
    # forward, up, right are columns 2, 1, 0 respectively.
    axes = tuple(mat[..., :, col] for col in (2, 1, 0))
    if isinstance(mat, torch.Tensor):
        return torch.cat(axes, dim=-1)
    if isinstance(mat, np.ndarray):
        return np.concatenate(axes, axis=-1)
    raise ValueError
def rotmat2vec_batch(mat):
    """Flatten the rotation part of transforms into forward, up and right.

    Args:
        mat(tensor): [B, 4, 4]

    Returns:
        vec (tensor): [B, 9], forward, up, right (columns 2, 1, 0 of the
        normalised matrix without its last row).

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    mat = normalized_matrix(mat)
    axes = tuple(mat[..., :-1, col] for col in (2, 1, 0))
    if isinstance(mat, torch.Tensor):
        return torch.cat(axes, dim=-1)
    if isinstance(mat, np.ndarray):
        return np.concatenate(axes, axis=-1)
    raise ValueError
def mat2vec_batch(mat):
    """Flatten transforms into 12-D pose vectors.

    Args:
        mat(tensor): [B, 4, 4]

    Returns:
        vec (tensor): [B, 12], pos, forward, up and right (columns 3, 2, 1, 0
        of the normalised matrix without its last row).

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    mat = normalized_matrix(mat)
    parts = tuple(mat[..., :-1, col] for col in (3, 2, 1, 0))
    if isinstance(mat, torch.Tensor):
        return torch.cat(parts, dim=-1)
    if isinstance(mat, np.ndarray):
        return np.concatenate(parts, axis=-1)
    raise ValueError
def mat2pose_batch(mat, returnvel=True):
    """Flatten transforms into pose vectors, optionally padded with a zero velocity.

    Args:
        mat(tensor): [B, 4, 4]
        returnvel (bool): append three zeros (a placeholder velocity) when True.

    Returns:
        vec (tensor): [B, 12] as pos, forward, up, zeros when ``returnvel`` is
        True, otherwise [B, 9] as pos, forward, up.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    mat = normalized_matrix(mat)
    pos = mat[..., :-1, 3]
    forward = mat[..., :-1, 2]
    up = mat[..., :-1, 1]
    parts = [pos, forward, up]

    if isinstance(mat, torch.Tensor):
        if returnvel:
            parts.append(torch.zeros_like(up))
        return torch.cat(parts, dim=-1)
    if isinstance(mat, np.ndarray):
        if returnvel:
            parts.append(np.zeros_like(up))
        return np.concatenate(parts, axis=-1)
    raise ValueError
def get_mat_BinA(matCtoA, matCtoB):
    """Compute ``matCtoA @ inverse(matCtoB)`` with normalised rotation columns.

    Given the matrix of the same object in two coordinate systems A and B,
    this yields matrix B in the coordinate of A.

    Args:
        matCtoA (tensor): [4, 4] world matrix
        matCtoB (tensor): [4, 4] world matrix

    Raises:
        ValueError: if ``matCtoA`` is neither a torch tensor nor a numpy array.
    """
    # The backend is chosen from the type of matCtoA only.
    if isinstance(matCtoA, torch.Tensor):
        invert, multiply = torch.inverse, torch.matmul
    elif isinstance(matCtoA, np.ndarray):
        invert, multiply = np.linalg.inv, np.matmul
    else:
        raise ValueError

    matCtoB_inv = normalized_matrix(invert(matCtoB))
    return normalized_matrix(multiply(matCtoA, matCtoB_inv))
def get_mat_BtoA(matA, matB):
    """Compute ``inverse(matA) @ matB`` with normalised rotation columns,
    i.e. matrix B in the coordinate of A.

    Args:
        matA (tensor): [4, 4] world matrix
        matB (tensor): [4, 4] world matrix

    Raises:
        ValueError: if ``matA`` is neither a torch tensor nor a numpy array.
    """
    # The backend is chosen from the type of matA only.
    if isinstance(matA, torch.Tensor):
        invert, multiply = torch.inverse, torch.matmul
    elif isinstance(matA, np.ndarray):
        invert, multiply = np.linalg.inv, np.matmul
    else:
        raise ValueError

    matA_inv = normalized_matrix(invert(matA))
    return normalized_matrix(multiply(matA_inv, matB))
def get_mat_BfromA(matA, matBtoA):
    """Return world matrix B given matrix A and matrix B relative to A.

    Computes ``matA @ matBtoA`` and normalises the rotation columns.

    Args:
        matA (_type_): [4, 4] world matrix
        matBtoA (_type_): [4, 4] matrix B relative to A

    Raises:
        ValueError: if ``matA`` is neither a torch tensor nor a numpy array
            (previously this surfaced as an UnboundLocalError).
    """
    if isinstance(matA, torch.Tensor):
        matB = torch.matmul(matA, matBtoA)
    elif isinstance(matA, np.ndarray):
        matB = np.matmul(matA, matBtoA)
    else:
        # Same error style as the other helpers in this module.
        raise ValueError
    matB = normalized_matrix(matB)
    return matB
def get_relative_position_to(pos, mat):
    """Transform positions by the inverse of ``mat``.

    Args:
        pos (_type_): [N, M, 3] or [N, 3]
        mat (_type_): [N, 4, 4] or [4, 4]

    Returns:
        ``R_inv @ pos + t_inv`` where ``R_inv`` / ``t_inv`` are the rotation and
        translation of the normalised inverse of ``mat``.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(mat, torch.Tensor):
        mat_inv = normalized_matrix(torch.inverse(mat))
        # Points are stored as rows, so transpose around the matmul.
        rotated = torch.matmul(mat_inv[..., :-1, :-1], pos.transpose(-1, -2)).transpose(-1, -2)
    elif isinstance(mat, np.ndarray):
        mat_inv = normalized_matrix(np.linalg.inv(mat))
        rotated = np.matmul(mat_inv[..., :-1, :-1], pos.swapaxes(-1, -2)).swapaxes(-1, -2)
    else:
        raise ValueError
    return rotated + mat_inv[..., None, :-1, 3]
def get_rotation(mat):
    """Return the rotation block of a transform.

    Args:
        mat (_type_): [..., 4, 4]

    Returns:
        The matrix without its last row and last column ([..., 3, 3] for 4x4
        input), obtained by slicing ``mat``.
    """
    rot_block = mat[..., :-1, :-1]
    return rot_block
def set_rotation(mat, rotmat):
    """Overwrite the rotation block of ``mat`` in place.

    Args:
        mat (_type_): [..., 4, 4]; modified in place.
        rotmat (_type_): values assigned to ``mat[..., :-1, :-1]``.

    Returns:
        The same ``mat`` object, after the assignment.
    """
    rot_slot = (Ellipsis, slice(None, -1), slice(None, -1))
    mat[rot_slot] = rotmat
    return mat
def set_position(mat, pos):
    """Overwrite the translation column of ``mat`` in place.

    Args:
        mat (_type_): [..., 4, 4]; modified in place.
        pos (_type_): values assigned to ``mat[..., :-1, 3]``.

    Returns:
        The same ``mat`` object, after the assignment.
    """
    pos_slot = (Ellipsis, slice(None, -1), 3)
    mat[pos_slot] = pos
    return mat
def get_position(mat):
    """Return the translation column of a transform.

    Args:
        mat (_type_): [..., 4, 4]

    Returns:
        Column 3 without the last row ([..., 3] for 4x4 input), obtained by
        slicing ``mat``.
    """
    translation = mat[..., :-1, 3]
    return translation
def get_position_from(pos, mat):
    """Transform positions by ``mat`` (rotate, then translate).

    Args:
        pos (_type_): [N, M, 3] or [N, 3]
        mat (_type_): [N, 4, 4] or [4, 4]

    Returns:
        ``R @ pos + t`` with ``R = mat[..., :3, :3]`` and ``t = mat[..., :3, 3]``.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    rotation = mat[..., :-1, :-1] if isinstance(mat, (torch.Tensor, np.ndarray)) else None
    # Points are stored as rows, so transpose around the matmul.
    if isinstance(mat, torch.Tensor):
        rotated = torch.matmul(rotation, pos.transpose(-1, -2)).transpose(-1, -2)
    elif isinstance(mat, np.ndarray):
        rotated = np.matmul(rotation, pos.swapaxes(-1, -2)).swapaxes(-1, -2)
    else:
        raise ValueError
    translation = mat[..., None, :-1, 3]
    return rotated + translation
def get_position_from_rotmat(pos, mat):
    """Multiply positions by ``mat`` without adding any translation.

    Args:
        pos (_type_): [N, M, 3] or [N, 3]
        mat (_type_): matrix applied in full via matmul; it must be compatible
            with the last axis of ``pos`` (NOTE(review): the original docstring
            lists 4x4 shapes, which would not match 3-D points — confirm).

    Returns:
        ``mat @ pos`` with points kept as rows.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(mat, torch.Tensor):
        as_columns = pos.transpose(-1, -2)
        return torch.matmul(mat, as_columns).transpose(-1, -2)
    if isinstance(mat, np.ndarray):
        as_columns = pos.swapaxes(-1, -2)
        return np.matmul(mat, as_columns).swapaxes(-1, -2)
    raise ValueError
def get_relative_direction_to(dir, mat):
    """Rotate directions by the inverse of ``mat`` (no translation applied).

    Args:
        dir (_type_): [N, M, 3] or [N, 3]
        mat (_type_): [N, 4, 4] or [4, 4]

    Returns:
        ``R_inv @ dir`` where ``R_inv`` is the top-left 3x3 block of the
        normalised inverse of ``mat``.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    is_tensor = isinstance(mat, torch.Tensor)
    if is_tensor:
        mat_inv = torch.inverse(mat)
    elif isinstance(mat, np.ndarray):
        mat_inv = np.linalg.inv(mat)
    else:
        raise ValueError
    rot_mat_inv = normalized_matrix(mat_inv)[..., :3, :3]

    # Directions are stored as rows, so transpose around the matmul.
    if is_tensor:
        return torch.matmul(rot_mat_inv, dir.transpose(-1, -2)).transpose(-1, -2)
    return np.matmul(rot_mat_inv, dir.swapaxes(-1, -2)).swapaxes(-1, -2)
def get_direction_from(dir, mat):
    """Rotate directions by the rotation block of ``mat`` (no translation).

    Args:
        dir (_type_): [N, M, 3] or [N, 3]
        mat (_type_): [N, 4, 4] or [4, 4]

    Returns:
        tensor: [N, M, 3] or [N, 3], ``mat[..., :3, :3] @ dir``.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    rotation = mat[..., :3, :3]
    # Directions are stored as rows, so transpose around the matmul.
    if isinstance(mat, torch.Tensor):
        return torch.matmul(rotation, dir.transpose(-1, -2)).transpose(-1, -2)
    if isinstance(mat, np.ndarray):
        return np.matmul(rotation, dir.swapaxes(-1, -2)).swapaxes(-1, -2)
    raise ValueError
def get_coord_vis(pos, rot_mat, scale=1.0):
    """Return the end points of the three coordinate axes drawn from ``pos``.

    Args:
        pos: origin of the frame, broadcastable against a column of ``rot_mat``.
        rot_mat: [..., 3, 3] matrix whose columns 0, 1, 2 are right, up, forward.
        scale: length of each axis.

    Returns:
        tuple: ``(pos + right * scale, pos + up * scale, pos + forward * scale)``.
    """
    right, up, forward = (rot_mat[..., :, col] for col in range(3))
    right_tip = pos + right * scale
    up_tip = pos + up * scale
    forward_tip = pos + forward * scale
    return right_tip, up_tip, forward_tip
def project_vec(vec):
    """Keep only the x and z components of the position and forward vectors.

    Args:
        vec (tensor): [*, 12], pos, forward, up and right

    Returns:
        proj_vec (tensor): [*, 4], posx, posz, forwardx, forwardz

    Raises:
        ValueError: if ``vec`` is neither a torch tensor nor a numpy array.
    """
    # Entries 0, 2 are pos x/z and entries 3, 5 are forward x/z.
    kept = tuple(vec[..., k : k + 1] for k in (0, 2, 3, 5))
    if isinstance(vec, torch.Tensor):
        return torch.cat(kept, dim=-1)
    if isinstance(vec, np.ndarray):
        return np.concatenate(kept, axis=-1)
    raise ValueError
def xz2xyz(vec):
    """Lift (x, z) pairs to (x, 0, z) triples.

    Args:
        vec: [..., 2+] tensor/array; entries 0 and 1 of the last axis are used
            as x and z.

    Returns:
        [..., 3] tensor/array with a zero y component inserted.

    Raises:
        ValueError: if ``vec`` is neither a torch tensor nor a numpy array.
    """
    x, z = vec[..., 0:1], vec[..., 1:2]
    y_shape = vec.shape[:-1] + (1,)
    if isinstance(vec, torch.Tensor):
        return torch.cat((x, torch.zeros(y_shape, device=vec.device), z), dim=-1)
    if isinstance(vec, np.ndarray):
        return np.concatenate((x, np.zeros(y_shape), z), axis=-1)
    raise ValueError
def normalized(vec):
    """Scale vectors to unit L2 length along the last axis.

    A small constant (1e-9) is added to the norm, so zero vectors stay zero
    instead of producing NaN.

    Raises:
        ValueError: if ``vec`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(vec, torch.Tensor):
        length = vec.norm(2, dim=-1, keepdim=True)
    elif isinstance(vec, np.ndarray):
        length = np.linalg.norm(vec, ord=2, axis=-1, keepdims=True)
    else:
        raise ValueError
    return vec / (length + 1e-9)
def normalized_matrix(mat):
    """Normalise the columns of the rotation part of a matrix.

    When the last dimension is 4, the result keeps the translation column,
    gets unit-length rotation columns and a bottom row of (0, 0, 0, 1).
    Otherwise the whole matrix is treated as a rotation and its columns are
    normalised. A new array/tensor is returned in both cases.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    is_homogeneous = mat.shape[-1] == 4
    rot_mat = mat[..., :-1, :-1] if is_homogeneous else mat

    # Norm over axis -2 = length of every column.
    if isinstance(mat, torch.Tensor):
        rot_mat_norm = rot_mat / (rot_mat.norm(2, dim=-2, keepdim=True) + 1e-9)
        norm_mat = torch.zeros_like(mat)
    elif isinstance(mat, np.ndarray):
        rot_mat_norm = rot_mat / (np.linalg.norm(rot_mat, ord=2, axis=-2, keepdims=True) + 1e-9)
        norm_mat = np.zeros_like(mat)
    else:
        raise ValueError

    if not is_homogeneous:
        return rot_mat_norm
    norm_mat[..., :-1, :-1] = rot_mat_norm
    norm_mat[..., :-1, -1] = mat[..., :-1, -1]
    norm_mat[..., -1, -1] = 1.0
    return norm_mat
def get_rot_mat_from_forward(forward):
    """Build rotation matrices from forward vectors, keeping the y axis as up.

    Column 2 is the normalised forward vector, column 0 is the normalised
    ``(forward_z, 0, -forward_x)`` and column 1 stays ``(0, 1, 0)``.

    Args:
        forward (tensor): [N, M, 3]

    Returns:
        mat (tensor): [N, M, 3, 3]

    Raises:
        ValueError: if ``forward`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(forward, torch.Tensor):
        mat = torch.eye(3, device=forward.device).repeat(forward.shape[:-1] + (1, 1))
        right = torch.zeros_like(forward)
    elif isinstance(forward, np.ndarray):
        # np.tile prepends the missing leading axes itself.
        mat = np.tile(np.eye(3, dtype=np.float32), forward.shape[:-1] + (1, 1))
        right = np.zeros_like(forward)
    else:
        raise ValueError

    # Written component-wise instead of using a cross product.
    # right = torch.cross(mat[..., 1], forward) # cannot backward
    right[..., 0] = forward[..., 2]
    right[..., 1] = 0.0
    right[..., 2] = -forward[..., 0]

    mat[..., 2] = normalized(forward)
    mat[..., 0] = normalized(right)
    return mat
def get_rot_mat_from_forward_up(forward, up):
    """Build rotation matrices from forward and up vectors.

    Columns of the result are (right, up, forward) with
    ``right = normalize(up x forward)``; forward and up are normalised but not
    re-orthogonalised.

    Args:
        forward (tensor): [N, M, 3]
        up (tensor): [N, M, 3]

    Returns:
        mat (tensor): [N, M, 3, 3]

    Raises:
        ValueError: if ``forward`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(forward, torch.Tensor):
        mat = torch.eye(3, device=forward.device).repeat(forward.shape[:-1] + (1, 1))
        # dim must be explicit: without it torch.cross uses the first axis of
        # size 3, which is wrong whenever a leading batch axis has length 3.
        right = torch.cross(up, forward, dim=-1)
    elif isinstance(forward, np.ndarray):
        mat = np.eye(3, dtype=np.float32)
        for _ in range(len(forward.shape) - 1):
            mat = mat[None]
        mat = np.tile(mat, forward.shape[:-1] + (1, 1))
        right = np.cross(up, forward)
    else:
        raise ValueError

    right = normalized(right)
    mat[..., 2] = normalized(forward)
    mat[..., 1] = normalized(up)
    mat[..., 0] = right
    return mat
def get_rot_mat_from_pose_vec(vec):
    """Build rotation matrices from vectors that store forward then up.

    Args:
        vec (tensor): [N, M, 6]; entries 0-2 are forward, entries 3-5 are up.

    Returns:
        mat (tensor): [N, M, 3, 3], see ``get_rot_mat_from_forward_up``.
    """
    forward, up = vec[..., :3], vec[..., 3:6]
    return get_rot_mat_from_forward_up(forward, up)
def get_TRS(rot_mat, pos):
    """Compose 4x4 transforms from rotations and positions.

    Args:
        rot_mat (tensor): [N, 3, 3]
        pos (tensor): [N, 3]

    Returns:
        mat (tensor): [N, 4, 4] with ``rot_mat`` in the top-left block, ``pos``
        in the last column, passed through ``normalized_matrix``.

    Raises:
        ValueError: if ``rot_mat`` is neither a torch tensor nor a numpy array.
    """
    # Start from identities batched like pos; the backend follows rot_mat's type.
    if isinstance(rot_mat, torch.Tensor):
        mat = torch.eye(4, device=pos.device).repeat(pos.shape[:-1] + (1, 1))
    elif isinstance(rot_mat, np.ndarray):
        # np.tile prepends the missing leading axes itself.
        mat = np.tile(np.eye(4, dtype=np.float32), pos.shape[:-1] + (1, 1))
    else:
        raise ValueError

    mat[..., :3, :3] = rot_mat
    mat[..., :3, 3] = pos
    return normalized_matrix(mat)
def xzvec2mat(vec):
    """Build 4x4 transforms from planar (posx, posz, forwardx, forwardz) vectors.

    The y components of position and forward are set to zero; the rotation
    comes from ``get_rot_mat_from_forward`` and the matrix from ``get_TRS``.

    Args:
        vec (tensor): [N, 4]

    Returns:
        mat (tensor): [N, 4, 4]

    Raises:
        ValueError: if ``vec`` is neither a torch tensor nor a numpy array.
    """
    vec_shape = vec.shape[:-1]
    if isinstance(vec, torch.Tensor):
        # Allocate on vec's device: the helpers below take their device from
        # these buffers, so CPU buffers would silently move the result to the CPU.
        pos = torch.zeros(vec_shape + (3,), device=vec.device)
        forward = torch.zeros(vec_shape + (3,), device=vec.device)
    elif isinstance(vec, np.ndarray):
        pos = np.zeros(vec_shape + (3,))
        forward = np.zeros(vec_shape + (3,))
    else:
        raise ValueError

    pos[..., 0] = vec[..., 0]
    pos[..., 2] = vec[..., 1]
    forward[..., 0] = vec[..., 2]
    forward[..., 2] = vec[..., 3]

    rot_mat = get_rot_mat_from_forward(forward)
    mat = get_TRS(rot_mat, pos)
    return mat
def distance(vec1, vec2):
    """Euclidean distance between vec1 and vec2, summed over all elements."""
    offset = vec1 - vec2
    squared_sum = (offset ** 2).sum()
    return squared_sum ** 0.5
def get_relative_pose_from_vec(pose, root, N):
    """Express 12D per-joint pose vectors relative to a planar root transform.

    pose is reshaped to (-1, N, 12): entries 0..2 are treated as a position and
    the three following triplets as directions, all re-expressed with
    get_position_from / get_direction_from. The result concatenates the position
    of joint 0 with entries 3..8 of every joint.
    NOTE: the reshaped array is written in place, so the caller's data may change
    when reshape returns a view.
    """
    root_mat = xzvec2mat(root)
    pose = pose.reshape(-1, N, 12)
    pose[..., :3] = get_position_from(pose[..., :3], root_mat)
    pose[..., 3:6] = get_direction_from(pose[..., 3:6], root_mat)
    pose[..., 6:9] = get_direction_from(pose[..., 6:9], root_mat)
    pose[..., 9:] = get_direction_from(pose[..., 9:], root_mat)
    root_pos = pose[..., 0, :3]
    joint_rot = pose[..., 3:9].reshape(-1, N * 6)
    return np.concatenate((root_pos, joint_rot), axis=-1)
def get_forward_from_pos(pos):
    """Estimate the facing direction from hip and shoulder joints.

    Args:
        pos (tensor): [..., F, J, 3] joint positions; only index 0 of the
            third-from-last axis is used
    Returns:
        forward_vec (tensor): [..., 3] normalized cross product of the y axis
            with the summed right-minus-left hip and shoulder vectors
    """
    r_hip, l_hip, r_sdr, l_sdr = 2, 1, 17, 16  # joint indices used as the body "across" axis
    y_axis = torch.tensor([0, 1, 0], dtype=torch.float32).to(pos.device)
    hip_across = pos[..., 0, r_hip, :] - pos[..., 0, l_hip, :]
    sdr_across = pos[..., 0, r_sdr, :] - pos[..., 0, l_sdr, :]
    across = hip_across + sdr_across
    return normalized(torch.cross(y_axis, across, dim=-1))
def project_point_along_ray(p, ray, keepnorm=False):
    """Project points onto a ray direction.

    Args:
        p (*, 3): point positions
        ray (*, 3): ray direction (normalized internally)
        keepnorm: False -> orthogonal projection of p onto the ray,
            True -> point along the ray with the same length as p
    Returns:
        (*, 3) projected points
    """
    direction = normalized(ray)
    if keepnorm:
        return direction * p.norm(dim=-1, keepdim=True)
    along = torch.sum(p * direction, dim=-1, keepdim=True)
    return along * direction
def solve_point_along_ray_with_constraint(c, ray, p, constraint="x"):
    """Find the point on a ray whose chosen coordinate equals a given value.

    Args:
        c (*,): required value of the constrained coordinate
        ray (*, 3): ray direction (normalized internally)
        p (*, 3): start point of the ray
        constraint: which coordinate is constrained, one of "x", "y", "z"
    Returns:
        (*, 3) point p + t * ray with out[..., axis] == c
    Raises:
        ValueError: if constraint is not "x", "y" or "z"
    """
    ray = normalized(ray)
    axis_by_name = {"x": 0, "y": 1, "z": 2}
    if constraint not in axis_by_name:
        raise ValueError
    ind = axis_by_name[constraint]
    # No guard against a zero ray component along the constrained axis.
    step = (c - p[..., ind]) / ray[..., ind]
    return ray * step[..., None] + p
def calc_cosine(vec1, vec2, return_angle=False):
    """Cosine (or angle) between two sets of vectors.

    Args:
        vec1 (*, 3): vector
        vec2 (*, 3): vector
        return_angle: True -> return the angle in radians, False -> return the cosine
    Returns:
        (*,) cosine of the angle between vec1 and vec2, or the angle itself
    """
    vec1 = normalized(vec1)
    vec2 = normalized(vec2)
    cosine = torch.sum(vec1 * vec2, dim=-1)
    if return_angle:
        # Rounding can push the dot product of unit vectors slightly outside
        # [-1, 1]; acos would then return NaN, so clamp first.
        return torch.acos(cosine.clamp(-1.0, 1.0))
    return cosine
############################################
#
# quaternion assumes xyzw
#
############################################
def quat_xyzw2wxyz(quat):
    """Reorder quaternion components from (x, y, z, w) to (w, x, y, z)."""
    real = quat[..., 3:4]
    imag = quat[..., :3]
    return torch.cat((real, imag), dim=-1)
def quat_wxyz2xyzw(quat):
    """Reorder quaternion components from (w, x, y, z) to (x, y, z, w)."""
    imag = quat[..., 1:4]
    real = quat[..., :1]
    return torch.cat((imag, real), dim=-1)
def quat_mul(a, b):
    """
    Hamilton product a * b of xyzw quaternions; leading dimensions broadcast.

    NOTE: a later definition of quat_mul in this file rebinds the name.
    """
    ax, ay, az, aw = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bx, by, bz, bw = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    real = aw * bw - ax * bx - ay * by - az * bz
    i = aw * bx + ax * bw + ay * bz - az * by
    j = aw * by + ay * bw + az * bx - ax * bz
    k = aw * bz + az * bw + ax * by - ay * bx
    return torch.stack((i, j, k, real), dim=-1)
def quat_pos(x):
    """
    Flip the sign of every quaternion whose real part (last component) is
    negative, so all real parts end up non-negative.
    """
    negative_real = (x[..., 3:] < 0).float()
    sign = 1 - 2 * negative_real  # -1 where w < 0, +1 elsewhere
    return sign * x
def quat_abs(x):
    """
    L2 norm of each quaternion along the last axis (1 for a unit quaternion).
    """
    return torch.norm(x, p=2, dim=-1)
def quat_unit(x):
    """
    Scale each quaternion to unit norm; the divisor is clamped to at least 1e-4.

    NOTE: a later definition of quat_unit in this file rebinds the name.
    """
    length = quat_abs(x).unsqueeze(-1)
    safe_length = length.clamp(min=1e-4)
    return x / safe_length
def quat_conjugate(x):
    """
    Conjugate of an xyzw quaternion: imaginary part negated, real part kept.

    NOTE: a later definition of quat_conjugate in this file rebinds the name.
    """
    imag = x[..., :3]
    real = x[..., 3:]
    return torch.cat((-imag, real), dim=-1)
def quat_real(x):
    """
    Real (w) component of an xyzw quaternion, i.e. index 3 of the last axis.
    """
    w = x[..., 3]
    return w
def quat_imaginary(x):
    """
    Imaginary (x, y, z) components of an xyzw quaternion.
    """
    xyz = x[..., :3]
    return xyz
def quat_norm_check(x):
    """
    Assert that every quaternion has norm 1 (tolerance 1e-3) and a non-negative
    real part.
    """
    norm_error = abs(x.norm(p=2, dim=-1) - 1)
    norm_ok = bool((norm_error < 1e-3).all())
    assert norm_ok, "the quaternion is has non-1 norm: {}".format(norm_error)
    real_ok = bool((x[..., 3] >= 0).all())
    assert real_ok, "the quaternion has negative real part"
def quat_normalize(q):
    """
    Canonicalize a quaternion (not necessarily normalized): make the real part
    non-negative via quat_pos, then scale to unit norm via quat_unit.
    """
    positive = quat_pos(q)
    return quat_unit(positive)
def quat_from_xyz(xyz):
    """
    Construct a unit xyzw quaternion from its imaginary component.

    Args:
        xyz (tensor): [..., 3] imaginary parts, each with norm <= 1
    Returns:
        (tensor): [..., 4] quaternion with w = sqrt(1 - |xyz|^2) >= 0
    Raises:
        AssertionError: if any xyz has norm greater than 1
    """
    # Work per quaternion (last axis). The previous whole-tensor norm broke
    # batched input, and w = 1 - |xyz| did not give a unit quaternion.
    w_sq = 1.0 - (xyz ** 2).sum(dim=-1, keepdim=True)
    assert bool((w_sq >= -1e-6).all()), "xyz has its norm greater than 1"
    w = w_sq.clamp(min=0.0).sqrt()
    return torch.cat([xyz, w], dim=-1)
def quat_identity(shape: List[int]):
    """
    Construct identity rotations (0, 0, 0, 1) with the given leading shape.

    Args:
        shape: leading dimensions; any sequence of ints (list, tuple, torch.Size)
    Returns:
        (tensor): [*shape, 4] float tensor of identity xyzw quaternions
    """
    # Normalize to a tuple first: concatenating a list with a tuple raises
    # TypeError, so the annotated List[int] input used to fail.
    lead = tuple(shape)
    q = torch.zeros(lead + (4,))
    q[..., 3] = 1.0
    # (0, 0, 0, 1) is already unit length with a positive real part, so no
    # further normalization is needed.
    return q
def tgm_quat_from_angle_axis(angle, axis, degree: bool = False):
    """Create a 3D rotation from an angle and an axis of rotation.

    :param angle: angle of rotation
    :type angle: Tensor
    :param axis: axis of rotation (normalized internally, norm clamped to >= 1e-4)
    :type axis: Tensor
    :param degree: put True here if the angle is given in degrees
    :type degree: bool, optional, default=False

    NOTE(review): the components are concatenated as [w, xyz] before
    quat_normalize, unlike the xyzw layout stated for this section. Presumably
    this is the torchgeometry ("tgm") layout; confirm with callers.
    """
    if degree:
        angle = angle / 180.0 * math.pi
    half_angle = (angle / 2).unsqueeze(-1)
    axis_norm = axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-4)
    unit_axis = axis / axis_norm
    imag = unit_axis * half_angle.sin()
    real = half_angle.cos()
    return quat_normalize(torch.cat([real, imag], dim=-1))
def quat_from_rotation_matrix(m):
    """
    Construct xyzw quaternions from rotation matrices.

    Component magnitudes are taken from the matrix diagonal; signs are then
    fixed relative to whichever component has the largest magnitude, using the
    off-diagonal entries. The result is passed through quat_normalize.
    Reference can be found here:
    http://www.cg.info.hiroshima-cu.ac.jp/~miyazaki/knowledge/teche52.html
    :param m: rotation matrices; only entries [..., :3, :3] are read.
    :type m: Tensor
    :rtype: Tensor
    """
    # Add a leading axis (removed again on return); presumably so the boolean
    # mask indexing below also works for a single unbatched matrix.
    m = m.unsqueeze(0)
    diag0 = m[..., 0, 0]
    diag1 = m[..., 1, 1]
    diag2 = m[..., 2, 2]
    # Math stuff.
    # Magnitudes of w, x, y, z; clamped at 0 so rounding cannot make the sqrt NaN.
    w = (((diag0 + diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5
    x = (((diag0 - diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5
    y = (((-diag0 + diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5
    z = (((-diag0 - diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5
    # The four masks use >= and are not mutually exclusive, so on ties more
    # than one of the in-place sign updates below is applied, in this order.
    # Only modify quaternions where w > x, y, z.
    c0 = (w >= x) & (w >= y) & (w >= z)
    x[c0] *= (m[..., 2, 1][c0] - m[..., 1, 2][c0]).sign()
    y[c0] *= (m[..., 0, 2][c0] - m[..., 2, 0][c0]).sign()
    z[c0] *= (m[..., 1, 0][c0] - m[..., 0, 1][c0]).sign()
    # Only modify quaternions where x > w, y, z
    c1 = (x >= w) & (x >= y) & (x >= z)
    w[c1] *= (m[..., 2, 1][c1] - m[..., 1, 2][c1]).sign()
    y[c1] *= (m[..., 1, 0][c1] + m[..., 0, 1][c1]).sign()
    z[c1] *= (m[..., 0, 2][c1] + m[..., 2, 0][c1]).sign()
    # Only modify quaternions where y > w, x, z.
    c2 = (y >= w) & (y >= x) & (y >= z)
    w[c2] *= (m[..., 0, 2][c2] - m[..., 2, 0][c2]).sign()
    x[c2] *= (m[..., 1, 0][c2] + m[..., 0, 1][c2]).sign()
    z[c2] *= (m[..., 2, 1][c2] + m[..., 1, 2][c2]).sign()
    # Only modify quaternions where z > w, x, y.
    c3 = (z >= w) & (z >= x) & (z >= y)
    w[c3] *= (m[..., 1, 0][c3] - m[..., 0, 1][c3]).sign()
    x[c3] *= (m[..., 2, 0][c3] + m[..., 0, 2][c3]).sign()
    y[c3] *= (m[..., 2, 1][c3] + m[..., 1, 2][c3]).sign()
    return quat_normalize(torch.stack([x, y, z, w], dim=-1)).squeeze(0)
def quat_mul_norm(x, y):
    """
    Compose two sets of rotations (quaternion product) and normalize the
    result with quat_normalize. The shapes need to be broadcastable.
    """
    product = quat_mul(x, y)
    return quat_normalize(product)
def quat_rotate(rot, vec):
    """
    Rotate 3D vectors by quaternions via rot * (vec, 0) * conj(rot).

    NOTE: a later definition of quat_rotate in this file rebinds the name.
    """
    zero_real = torch.zeros_like(vec[..., :1])
    pure_quat = torch.cat([vec, zero_real], dim=-1)
    rotated = quat_mul(quat_mul(rot, pure_quat), quat_conjugate(rot))
    return quat_imaginary(rotated)
def quat_inverse(x):
    """
    Inverse rotation, computed as the quaternion conjugate (which equals the
    inverse only for unit quaternions).
    """
    inverse = quat_conjugate(x)
    return inverse
def quat_identity_like(x):
    """
    Identity rotations with the same leading shape as x (all axes except the last).
    """
    leading_shape = x.shape[:-1]
    return quat_identity(leading_shape)
def quat_angle_axis(x):
    """
    The (angle, axis) representation of the rotation.

    Args:
        x (tensor): [..., 4] xyzw quaternions
    Returns:
        angle (tensor): [...] arccos(2 * w^2 - 1), guaranteed to lie in [0, pi]
        axis (tensor): [..., 3] imaginary part scaled to unit length (the norm
            is clamped to at least 1e-4 before dividing)
    """
    s = 2 * (x[..., 3] ** 2) - 1
    angle = s.clamp(-1, 1).arccos()  # just to be safe
    imag = x[..., :3]
    # Out-of-place division: an in-place `/=` on this slice would rescale the
    # caller's quaternion, because the slice is a view of x.
    axis = imag / imag.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-4)
    return angle, axis
def quat_yaw_rotation(x, z_up: bool = True):
    """
    Yaw part of a rotation: keep only the vertical-axis imaginary component
    (z when z_up, otherwise y) together with w, zero the rest, and renormalize.
    """
    zeros = torch.zeros_like
    if z_up:
        parts = [zeros(x[..., 0:2]), x[..., 2:3], x[..., 3:]]
    else:
        parts = [zeros(x[..., 0:1]), x[..., 1:2], zeros(x[..., 2:3]), x[..., 3:4]]
    return quat_normalize(torch.cat(parts, dim=-1))
def transform_from_rotation_translation(r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None):
    """
    Construct a transform from a quaternion and a 3D translation.

    Args:
        r: [..., 4] xyzw quaternion, or None for the identity rotation
        t: [..., 3] translation, or None for zero translation
    Returns:
        (tensor): [..., 7] concatenation of rotation (first 4) and translation (last 3)
    Raises:
        AssertionError: if both r and t are None
    """
    assert r is not None or t is not None, "rotation and translation can't be all None"
    if r is None:
        assert t is not None
        # The identity rotation shares t's leading dimensions: drop t's
        # trailing 3 and append 4, so the final concatenation lines up.
        r = t.new_zeros(t.shape[:-1] + (4,))
        r[..., 3] = 1
    if t is None:
        # Same idea: r's leading dimensions plus a trailing 3.
        t = r.new_zeros(r.shape[:-1] + (3,))
    return torch.cat([r, t], dim=-1)
def transform_identity(shape: List[int]):
    """
    Identity transformation with the given leading shape.

    Args:
        shape: leading dimensions; any sequence of ints (list, tuple, torch.Size)
    Returns:
        (tensor): [*shape, 7] identity rotation followed by zero translation
    """
    # Use a tuple everywhere: the old code needed `shape` to concatenate with
    # both a tuple (in quat_identity) and a list (here), which no single
    # sequence type supports.
    lead = tuple(shape)
    r = quat_identity(lead)
    t = torch.zeros(lead + (3,))
    return transform_from_rotation_translation(r, t)
def transform_rotation(x):
    """Rotation part of a transform: the first four entries of the last axis."""
    rotation = x[..., 0:4]
    return rotation
def transform_translation(x):
    """Translation part of a transform: everything after the first four entries."""
    translation = x[..., 4:]
    return translation
def transform_inverse(x):
    """
    Inverse transformation: the inverse rotation, paired with the negated
    translation rotated by that inverse rotation.
    """
    inv_rot = quat_inverse(transform_rotation(x))
    inv_trans = quat_rotate(inv_rot, -transform_translation(x))
    return transform_from_rotation_translation(r=inv_rot, t=inv_trans)
def transform_identity_like(x):
    """
    Identity transformation matching x's leading shape, dtype and device.

    Args:
        x (tensor): [..., 7] transform (rotation xyzw followed by translation)
    Returns:
        (tensor): [..., 7] with rotation (0, 0, 0, 1) and zero translation
    """
    # Built directly from x's leading dimensions. The old code forwarded the
    # full shape (including the trailing 7) and hit a list/tuple TypeError in
    # transform_identity.
    out = x.new_zeros(x.shape[:-1] + (7,))
    out[..., 3] = 1
    return out
def transform_mul(x, y):
    """
    Compose two transformations: rotations are multiplied (and normalized);
    y's translation is rotated by x's rotation and offset by x's translation.
    """
    rot_x = transform_rotation(x)
    combined_rot = quat_mul_norm(rot_x, transform_rotation(y))
    combined_trans = quat_rotate(rot_x, transform_translation(y)) + transform_translation(x)
    return transform_from_rotation_translation(r=combined_rot, t=combined_trans)
def transform_apply(rot, vec):
    """
    Apply a transform to 3D vectors: rotate by its rotation part, then add its
    translation part. vec must be a torch.Tensor.
    """
    assert isinstance(vec, torch.Tensor)
    rotated = quat_rotate(transform_rotation(rot), vec)
    return rotated + transform_translation(rot)
def rot_matrix_det(x):
    """
    Determinant of 3x3 matrices by cofactor expansion along the first row.
    The result has the matrices' leading (batch) shape.
    """
    m00, m01, m02 = x[..., 0, 0], x[..., 0, 1], x[..., 0, 2]
    m10, m11, m12 = x[..., 1, 0], x[..., 1, 1], x[..., 1, 2]
    m20, m21, m22 = x[..., 2, 0], x[..., 2, 1], x[..., 2, 2]
    first = m00 * (m11 * m22 - m12 * m21)
    second = m01 * (m10 * m22 - m12 * m20)
    third = m02 * (m10 * m21 - m11 * m20)
    return first - second + third
def rot_matrix_integrity_check(x):
    """
    Verify that rotation matrices have determinant one and are orthogonal.

    Args:
        x (tensor): [..., 3, 3] matrices to check
    Raises:
        AssertionError: if a determinant differs from 1 by 1e-3 or more, or if
            x @ x^T differs from the identity by 1e-3 or more in any entry
    """
    det = rot_matrix_det(x)
    assert bool((abs(det - 1) < 1e-3).all()), "the matrix has non-one determinant"
    # Transpose the last two axes; the old permute call passed a tensor plus
    # negative indices, which is not a valid permutation.
    rtr = x @ x.transpose(-1, -2)
    # Tensors have no zeros_like method; use the torch-level function.
    rtr_gt = torch.zeros_like(rtr)
    rtr_gt[..., 0, 0] = 1
    rtr_gt[..., 1, 1] = 1
    rtr_gt[..., 2, 2] = 1
    # Compare absolute deviations, otherwise large negative errors would pass.
    assert bool(((rtr - rtr_gt).abs() < 1e-3).all()), "the matrix is not orthogonal"
def rot_matrix_from_quaternion(q):
    """
    Build [..., 3, 3] rotation matrices from [..., 4] xyzw quaternions.
    """
    # i, j, k, r follow wikipedia's naming for the quaternion components
    i, j, k, r = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    entries = [
        1.0 - 2.0 * (j**2 + k**2),
        2 * (i * j - k * r),
        2 * (i * k + j * r),
        2 * (i * j + k * r),
        1.0 - 2.0 * (i**2 + k**2),
        2 * (j * k - i * r),
        2 * (i * k - j * r),
        2 * (j * k + i * r),
        1.0 - 2.0 * (i**2 + j**2),
    ]
    # Entries are listed row by row, so a plain reshape gives the matrix layout.
    return torch.stack(entries, dim=-1).reshape(q.shape[:-1] + (3, 3))
def euclidean_to_rotation_matrix(x):
    """
    Top-left 3x3 block (the rotation part) of a Euclidean transformation matrix.
    """
    rotation_block = x[..., 0:3, 0:3]
    return rotation_block
def euclidean_integrity_check(x):
    """
    Assert that the last row of each Euclidean matrix is (0, 0, 0, 1).

    NOTE(review): the rotation block is extracted but never validated, although
    the original inline comment read "check 3d-rotation matrix".
    """
    euclidean_to_rotation_matrix(x)  # extraction only, no validation is performed
    last_row_head_ok = bool((x[..., 3, :3] == 0).all())
    assert last_row_head_ok, "the last row is illegal"
    last_row_tail_ok = bool((x[..., 3, 3] == 1).all())
    assert last_row_tail_ok, "the last row is illegal"
def euclidean_translation(x):
    """
    Translation vector: the first three rows of the last (index 3) column.
    """
    translation = x[..., 0:3, 3]
    return translation
def euclidean_inverse(x):
    """
    Compute the inverse of Euclidean (rigid) transformation matrices.

    Args:
        x (tensor): [..., 4, 4] matrices with rotation R in the top-left 3x3
            block and translation t in the last column
    Returns:
        (tensor): [..., 4, 4] matrices with rotation R^T and translation -R^T t
    """
    # The previous version could not run: Tensor has no zeros_like method, a
    # quaternion was assigned into the 3x3 block, column index 4 is out of
    # range, and the bottom-right 1 was never set.
    rot = x[..., :3, :3]
    inv_rot = rot.transpose(-1, -2)  # inverse of an orthogonal matrix
    trans = x[..., :3, 3]
    s = torch.zeros_like(x)
    s[..., :3, :3] = inv_rot
    s[..., :3, 3] = -(inv_rot @ trans.unsqueeze(-1)).squeeze(-1)
    s[..., 3, 3] = 1
    return s
def euclidean_to_transform(transformation_matrix):
    """
    Convert a Euclidean transformation matrix into a (quaternion, translation)
    transform.
    """
    rotation_block = euclidean_to_rotation_matrix(transformation_matrix)
    rotation_quat = quat_from_rotation_matrix(m=rotation_block)
    translation = euclidean_translation(transformation_matrix)
    return transform_from_rotation_translation(r=rotation_quat, t=translation)
def to_torch(x, dtype=torch.float, device="cuda:0", requires_grad=False):
    """Copy x into a new torch tensor with the given dtype, device and grad flag."""
    options = dict(dtype=dtype, device=device, requires_grad=requires_grad)
    return torch.tensor(x, **options)
def quat_mul(a, b):
    """Hamilton product of xyzw quaternions of identical shape.

    The inputs are flattened to (-1, 4), multiplied using shared intermediate
    products, and the result is viewed back to the input shape.
    """
    assert a.shape == b.shape
    out_shape = a.shape
    lhs = a.reshape(-1, 4)
    rhs = b.reshape(-1, 4)
    x1, y1, z1, w1 = lhs[:, 0], lhs[:, 1], lhs[:, 2], lhs[:, 3]
    x2, y2, z2, w2 = rhs[:, 0], rhs[:, 1], rhs[:, 2], rhs[:, 3]
    # shared intermediate products
    ww = (z1 + x1) * (x2 + y2)
    yy = (w1 - y1) * (w2 + z2)
    zz = (w1 + y1) * (w2 - z2)
    xx = ww + yy + zz
    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
    out_w = qq - ww + (z1 - y1) * (y2 - z2)
    out_x = qq - xx + (x1 + w1) * (x2 + w2)
    out_y = qq - yy + (w1 - x1) * (y2 + z2)
    out_z = qq - zz + (z1 + y1) * (w2 - x2)
    return torch.stack([out_x, out_y, out_z, out_w], dim=-1).view(out_shape)
def normalize(x, eps: float = 1e-9):
    """Scale x to unit L2 norm along the last axis; the norm is clamped to at least eps."""
    length = x.norm(p=2, dim=-1).clamp(min=eps, max=None)
    return x / length.unsqueeze(-1)
def quat_apply(a, b):
    """Rotate vectors b by xyzw quaternions a; the result has b's shape.

    Both inputs are flattened to (-1, 4) and (-1, 3) before the computation.
    """
    out_shape = b.shape
    quat = a.reshape(-1, 4)
    vec = b.reshape(-1, 3)
    imag = quat[:, :3]
    twice_cross = imag.cross(vec, dim=-1) * 2
    rotated = vec + quat[:, 3:] * twice_cross + imag.cross(twice_cross, dim=-1)
    return rotated.view(out_shape)
def quat_rotate(q, v):
    """Rotate vectors v (B, 3) by xyzw quaternions q (B, 4)."""
    batch = q.shape[0]
    real = q[:, -1]
    imag = q[:, :3]
    scaled = v * (2.0 * real**2 - 1.0).unsqueeze(-1)
    crossed = torch.cross(imag, v, dim=-1) * real.unsqueeze(-1) * 2.0
    dot = torch.bmm(imag.view(batch, 1, 3), v.view(batch, 3, 1)).squeeze(-1)
    projected = imag * dot * 2.0
    return scaled + crossed + projected
def quat_rotate_inverse(q, v):
    """Rotate vectors v (B, 3) by the inverse of xyzw quaternions q (B, 4)."""
    batch = q.shape[0]
    real = q[:, -1]
    imag = q[:, :3]
    scaled = v * (2.0 * real**2 - 1.0).unsqueeze(-1)
    crossed = torch.cross(imag, v, dim=-1) * real.unsqueeze(-1) * 2.0
    dot = torch.bmm(imag.view(batch, 1, 3), v.view(batch, 3, 1)).squeeze(-1)
    projected = imag * dot * 2.0
    # only the sign of the cross term differs from quat_rotate
    return scaled - crossed + projected
def quat_conjugate(a):
    """Conjugate of xyzw quaternions (imaginary part negated); input shape is kept."""
    out_shape = a.shape
    flat = a.reshape(-1, 4)
    imag, real = flat[:, :3], flat[:, -1:]
    return torch.cat((-imag, real), dim=-1).view(out_shape)
def quat_unit(a):
    """Unit-norm version of quaternion a (delegates to normalize)."""
    unit = normalize(a)
    return unit
def quat_from_angle_axis(angle, axis):
    """Unit xyzw quaternion for a rotation of `angle` about `axis` (axis is normalized)."""
    half_angle = (angle / 2).unsqueeze(-1)
    imag = normalize(axis) * torch.sin(half_angle.clone())
    real = torch.cos(half_angle.clone())
    return quat_unit(torch.cat([imag, real], dim=-1))
def normalize_angle(x):
    """Wrap angles into atan2's output range [-pi, pi]."""
    sin_x = torch.sin(x.clone())
    cos_x = torch.cos(x.clone())
    return torch.atan2(sin_x, cos_x)
def tf_inverse(q, t):
    """Inverse of the rigid transform (q, t): returns (conj(q), -conj(q) applied to t)."""
    inv_rot = quat_conjugate(q)
    inv_trans = -quat_apply(inv_rot, t)
    return inv_rot, inv_trans
def tf_apply(q, t, v):
    """Apply the rigid transform (q, t) to v: rotate by q, then translate by t."""
    rotated = quat_apply(q, v)
    return rotated + t
def tf_vector(q, v):
    """Rotate the vector v by quaternion q (no translation)."""
    rotated = quat_apply(q, v)
    return rotated
def tf_combine(q1, t1, q2, t2):
    """Compose rigid transforms: returns (q1 * q2, q1 applied to t2, plus t1)."""
    combined_rot = quat_mul(q1, q2)
    combined_trans = quat_apply(q1, t2) + t1
    return combined_rot, combined_trans
def get_basis_vector(q, v):
    """Rotate the basis vector v by quaternion q (delegates to quat_rotate)."""
    basis = quat_rotate(q, v)
    return basis
def get_axis_params(value, axis_idx, x_value=0.0, dtype=float, n_dims=3):
    """construct arguments to `Vec` according to axis index.

    Returns a list of n_dims entries that is zero everywhere except `value` at
    axis_idx; entry 0 is then overwritten with x_value.
    """
    one_hot = np.zeros((n_dims,))
    assert axis_idx < n_dims, "the axis dim should be within the vector dimensions"
    one_hot[axis_idx] = 1.0
    components = np.where(one_hot == 1.0, value, one_hot)
    components[0] = x_value  # applied last, so it wins when axis_idx == 0
    return list(components.astype(dtype))
def copysign(a, b):
    # type: (float, Tensor) -> Tensor
    """Tensor with magnitude |a| and the elementwise sign of b (one entry per row of b)."""
    magnitude = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])
    signs = torch.sign(b)
    return torch.abs(magnitude) * signs
def get_euler_xyz(q):
    """Roll, pitch and yaw of (B, 4) xyzw quaternions, each wrapped into [0, 2*pi)."""
    x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    # roll (x-axis rotation)
    sinr_cosp = 2.0 * (w * x + y * z)
    cosr_cosp = w * w - x * x - y * y + z * z
    roll = torch.atan2(sinr_cosp, cosr_cosp)
    # pitch (y-axis rotation); saturates at +-pi/2 when |sinp| >= 1
    sinp = 2.0 * (w * y - z * x)
    saturated = copysign(np.pi / 2.0, sinp)
    pitch = torch.where(torch.abs(sinp) >= 1, saturated, torch.asin(sinp))
    # yaw (z-axis rotation)
    siny_cosp = 2.0 * (w * z + x * y)
    cosy_cosp = w * w + x * x - y * y - z * z
    yaw = torch.atan2(siny_cosp, cosy_cosp)
    full_turn = 2 * np.pi
    return roll % full_turn, pitch % full_turn, yaw % full_turn
def quat_from_euler_xyz(roll, pitch, yaw):
    """xyzw quaternion built from roll, pitch and yaw angles (radians)."""
    half_yaw, half_roll, half_pitch = yaw * 0.5, roll * 0.5, pitch * 0.5
    cy, sy = torch.cos(half_yaw), torch.sin(half_yaw)
    cr, sr = torch.cos(half_roll), torch.sin(half_roll)
    cp, sp = torch.cos(half_pitch), torch.sin(half_pitch)
    components = [
        cy * sr * cp - sy * cr * sp,  # x
        cy * cr * sp + sy * sr * cp,  # y
        sy * cr * cp - cy * sr * sp,  # z
        cy * cr * cp + sy * sr * sp,  # w
    ]
    return torch.stack(components, dim=-1)
def torch_rand_float(lower, upper, shape, device):
    # type: (float, float, Tuple[int, int], str) -> Tensor
    """Uniform random tensor of the given shape with values in [lower, upper)."""
    unit_samples = torch.rand(*shape, device=device)
    return (upper - lower) * unit_samples + lower
def torch_random_dir_2(shape, device):
    # type: (Tuple[int, int], str) -> Tensor
    """Random 2D unit directions (cos, sin) with angles drawn from [-pi, pi)."""
    theta = torch_rand_float(-np.pi, np.pi, shape, device).squeeze(-1)
    direction = torch.stack([torch.cos(theta), torch.sin(theta)], dim=-1)
    return direction
def tensor_clamp(t, min_t, max_t):
    """Elementwise clamp of t between tensor bounds min_t and max_t (the lower bound wins)."""
    capped = torch.min(t, max_t)
    return torch.max(capped, min_t)
def scale(x, lower, upper):
    """Map x from [-1, 1] linearly onto [lower, upper]."""
    span = upper - lower
    return 0.5 * (x + 1.0) * span + lower
def unscale(x, lower, upper):
    """Map x from [lower, upper] linearly onto [-1, 1] (inverse of scale)."""
    span = upper - lower
    return (2.0 * x - upper - lower) / span
def unscale_np(x, lower, upper):
    """Map x from [lower, upper] linearly onto [-1, 1]; same formula as unscale."""
    span = upper - lower
    return (2.0 * x - upper - lower) / span
def quat_to_angle_axis(q):
    # type: (Tensor) -> Tuple[Tensor, Tensor]
    """Axis-angle representation of normalized xyzw quaternions.

    Where |sin(theta / 2)| is at most 1e-5, the angle is set to 0 and the axis
    to (0, 0, 1).
    """
    min_theta = 1e-5
    w_idx = 3
    real = q[..., w_idx]
    sin_half = torch.sqrt(1 - real * real)
    angle = normalize_angle(2 * torch.acos(real))
    axis = q[..., 0:w_idx] / sin_half.unsqueeze(-1)
    well_defined = torch.abs(sin_half) > min_theta
    fallback_axis = torch.zeros_like(axis)
    fallback_axis[..., -1] = 1
    angle = torch.where(well_defined, angle, torch.zeros_like(angle))
    axis = torch.where(well_defined.unsqueeze(-1), axis, fallback_axis)
    return angle, axis
def angle_axis_to_exp_map(angle, axis):
    # type: (Tensor, Tensor) -> Tensor
    """Exponential map (axis scaled by angle) from an axis-angle pair."""
    return angle.unsqueeze(-1) * axis
def quat_to_exp_map(q):
    # type: (Tensor) -> Tensor
    """Exponential map of normalized quaternions, via their axis-angle form."""
    rot_angle, rot_axis = quat_to_angle_axis(q)
    return angle_axis_to_exp_map(rot_angle, rot_axis)
def quat_to_tan_norm(q):
    # type: (Tensor) -> Tensor
    """Represent a rotation by its rotated tangent (1, 0, 0) and normal (0, 0, 1).

    Returns the two rotated vectors concatenated along the last axis.
    """
    tangent_ref = torch.zeros_like(q[..., 0:3])
    tangent_ref[..., 0] = 1
    normal_ref = torch.zeros_like(q[..., 0:3])
    normal_ref[..., -1] = 1
    tan = quat_rotate(q, tangent_ref)
    norm = quat_rotate(q, normal_ref)
    return torch.cat([tan, norm], dim=len(tan.shape) - 1)
def euler_xyz_to_exp_map(roll, pitch, yaw):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """Exponential map for the rotation given by roll, pitch and yaw."""
    quat = quat_from_euler_xyz(roll, pitch, yaw)
    return quat_to_exp_map(quat)
def exp_map_to_angle_axis(exp_map):
    """Split an exponential map into (angle, axis).

    The angle is the norm of exp_map (plus 1e-6 to avoid division by zero),
    wrapped by normalize_angle. Where the wrapped angle has magnitude at most
    1e-5, the result is angle 0 with axis (0, 0, 1).
    """
    min_theta = 1e-5
    raw_angle = torch.norm(exp_map.clone(), dim=-1) + 1e-6
    axis = exp_map.clone() / torch.unsqueeze(raw_angle, dim=-1).clone()
    angle = normalize_angle(raw_angle)
    fallback_axis = torch.zeros_like(exp_map)
    fallback_axis[..., -1] = 1
    well_defined = torch.abs(angle) > min_theta
    angle = torch.where(well_defined, angle, torch.zeros_like(angle))
    axis = torch.where(well_defined.unsqueeze(-1), axis, fallback_axis)
    return angle, axis
def exp_map_to_quat(exp_map):
    """xyzw quaternion for an exponential map, via its axis-angle form."""
    rot_angle, rot_axis = exp_map_to_angle_axis(exp_map)
    return quat_from_angle_axis(rot_angle, rot_axis)
def slerp(q0, q1, t):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """Spherical linear interpolation between quaternions q0 and q1 at fraction t.

    q1 is negated where its dot product with q0 is negative, so interpolation
    takes the shorter arc. Near-degenerate cases fall back to the midpoint
    (|sin| < 0.001) or to q0 (|cos| >= 1).
    """
    dot = torch.sum(q0 * q1, dim=-1)
    flip = dot < 0
    q1 = q1.clone()
    q1[flip] = -q1[flip]
    cos_half = torch.unsqueeze(torch.abs(dot), dim=-1)
    half_angle = torch.acos(cos_half)
    sin_half = torch.sqrt(1.0 - cos_half * cos_half)
    weight0 = torch.sin((1 - t) * half_angle) / sin_half
    weight1 = torch.sin(t * half_angle) / sin_half
    blended = weight0 * q0 + weight1 * q1
    blended = torch.where(torch.abs(sin_half) < 0.001, 0.5 * q0 + 0.5 * q1, blended)
    blended = torch.where(torch.abs(cos_half) >= 1, q0, blended)
    return blended
def calc_heading_vec(q, head_ind=0):
    # type: (Tensor, int) -> Tensor
    """Heading direction vector: the unit axis `head_ind` rotated by q.

    q must be normalized.
    """
    reference = torch.zeros_like(q[..., 0:3])
    reference[..., head_ind] = 1
    return quat_rotate(q, reference)
def calc_heading(q, head_ind=0, gravity_axis="z"):
    # type: (Tensor, int, str) -> Tensor
    """Heading angle of normalized quaternions about the gravity axis.

    The unit axis `head_ind` is rotated by q and the angle of the result in the
    plane perpendicular to gravity_axis ("x", "y" or "z") is returned.
    """
    reference = torch.zeros_like(q[..., 0:3])
    reference[..., head_ind] = 1
    lead_shape = reference.shape[:-1]
    # quat_rotate works on flat (B, 4) / (B, 3) batches
    flat_q = q.reshape((-1, 4))
    flat_ref = reference.reshape(-1, 3)
    rot_dir = quat_rotate(flat_q, flat_ref).reshape(lead_shape + (3,))
    if gravity_axis == "z":
        heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0])
    elif gravity_axis == "y":
        heading = torch.atan2(rot_dir[..., 0], rot_dir[..., 2])
    elif gravity_axis == "x":
        heading = torch.atan2(rot_dir[..., 2], rot_dir[..., 1])
    return heading
def calc_heading_quat(q, head_ind=0, gravity_axis="z"):
    # type: (Tensor, int, str) -> Tensor
    """Quaternion for the heading-only rotation of q about the gravity axis.

    q must be normalized; gravity_axis is one of "x", "y" or "z".
    """
    heading = calc_heading(q, head_ind, gravity_axis=gravity_axis)
    rotation_axis = torch.zeros_like(q[..., 0:3])
    if gravity_axis == "z":
        g_axis = 2
    elif gravity_axis == "y":
        g_axis = 1
    elif gravity_axis == "x":
        g_axis = 0
    rotation_axis[..., g_axis] = 1
    return quat_from_angle_axis(heading, rotation_axis)
def calc_heading_quat_inv(q, head_ind=0):
    # type: (Tensor, int) -> Tensor
    """Quaternion that undoes the heading of q: a rotation by -heading about z.

    q must be normalized; calc_heading is called with its default gravity axis.
    """
    heading = calc_heading(q, head_ind)
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., 2] = 1
    return quat_from_angle_axis(-heading, z_axis)
def forward_kinematics(mat, parent):
    """Accumulate per-joint matrices along a kinematic tree.

    Joints are visited in index order. A joint whose parent index is -1 keeps
    its own matrix; every other joint gets get_mat_BfromA(result[parent], mat[i]).
    This presumably requires parents to come before their children in index
    order; confirm against the skeleton definition.

    Args:
        mat ([..., N, 3, 3]): per-joint matrices (torch.Tensor or np.ndarray);
            the identity initialisation uses mat.shape[-1] as the matrix size
        parent (): indexable by joint; parent[i] is the parent joint index of
            joint i, or -1 for a root
    Returns:
        rotations: same shape as mat, holding the accumulated matrices
    """
    # Start from identities with the same shape as mat.
    if isinstance(mat, torch.Tensor):
        rotations = torch.eye(mat.shape[-1], device=mat.device)
        rotations = rotations.repeat(mat.shape[:-2] + (1, 1))
    else:
        rotations = np.eye(mat.shape[-1], dtype=np.float32)
        rotations = np.tile(rotations, mat.shape[:-2] + (1, 1))
    for i in range(mat.shape[-3]):
        if parent[i] != -1:
            if isinstance(mat, torch.Tensor):
                # this way make gradient flow
                # (slot i is rebuilt by concatenation instead of in-place assignment)
                new_mat = get_mat_BfromA(rotations[..., parent[i], :, :], mat[..., i, :, :])
                rotations = torch.cat(
                    (
                        rotations[..., :i, :, :],
                        new_mat[..., None, :, :],
                        rotations[..., i + 1 :, :, :],
                    ),
                    dim=-3,
                )
            else:
                rotations[..., i, :, :] = get_mat_BfromA(rotations[..., parent[i], :, :], mat[..., i, :, :])
        else:
            if isinstance(mat, torch.Tensor):
                # this way make gradient flow
                # NOTE: this copies mat[..., :i + 1], so slots 0..i-1 are also
                # reset to the input matrices, not only slot i.
                rotations = torch.cat((mat[..., : i + 1, :, :], rotations[..., i + 1 :, :, :]), dim=-3)
            else:
                rotations[..., i, :, :] = mat[..., i, :, :]
    return rotations
================================================
FILE: eval/GVHMR/hmr4d/utils/net_utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from hmr4d.utils.pylogger import Log
from pytorch_lightning.utilities.memory import recursive_detach
from einops import repeat, rearrange
from scipy.ndimage._filters import _gaussian_kernel1d
def load_pretrained_model(model, ckpt_path):
    """
    Load a checkpoint into `model`.

    If the model defines its own `load_pretrained_model` method, that method is
    used; otherwise the checkpoint is loaded on CPU and applied with
    `load_state_dict(strict=True)`. The checkpoint path must exist.
    """
    assert Path(ckpt_path).exists()
    if not hasattr(model, "load_pretrained_model"):
        Log.info(f"Loading ckpt: {ckpt_path}")
        state = torch.load(ckpt_path, "cpu")
        model.load_state_dict(state, strict=True)
        return
    # the model brings its own loading strategy
    model.load_pretrained_model(ckpt_path)
def find_last_ckpt_path(dirpath):
    """
    Pick the checkpoint to resume from inside `dirpath`.

    Priority 1 is `last.ckpt`. Otherwise the lexicographically last `*.ckpt`
    whose name does not contain "last" is returned, or None when there is none.
    """
    assert dirpath is not None
    dirpath = Path(dirpath)
    assert dirpath.exists()
    # Priority 1: last.ckpt
    last_ckpt = dirpath / "last.ckpt"
    if last_ckpt.exists():
        return last_ckpt
    # Priority 2: newest by name among the remaining checkpoints
    candidates = [p for p in sorted(list(dirpath.glob("*.ckpt"))) if "last" not in p.name]
    if len(candidates) == 0:
        Log.info("No checkpoint found, set model_path to None")
        return None
    return candidates[-1]
def get_resume_ckpt_path(resume_mode, ckpt_dir=None):
    """Resolve the checkpoint to resume from.

    An existing path is returned unchanged; otherwise resume_mode must be
    "last" and the latest checkpoint in ckpt_dir is looked up.
    """
    is_existing_path = Path(resume_mode).exists()
    if not is_existing_path:
        assert resume_mode == "last"
        return find_last_ckpt_path(ckpt_dir)
    return resume_mode
def select_state_dict_by_prefix(state_dict, prefix, new_prefix=""):
    """
    Keep only the weights whose key starts with `prefix`, replacing that prefix
    with `new_prefix`.
    Args:
        state_dict: dict
        prefix: str
        new_prefix: str, the new key will be {new_prefix} + {old_key[len(prefix):]}
    Returns:
        state_dict_new: dict
    """
    cut = len(prefix)
    return {
        new_prefix + key[cut:]: state_dict[key]
        for key in list(state_dict.keys())
        if key.startswith(prefix)
    }
def detach_to_cpu(in_dict):
    """Run pytorch-lightning's recursive_detach on in_dict with to_cpu=True."""
    detached = recursive_detach(in_dict, to_cpu=True)
    return detached
def to_cuda(data):
    """Move tensors inside (nested) dicts and lists to CUDA; other values are returned unchanged."""
    if isinstance(data, torch.Tensor):
        return data.cuda()
    if isinstance(data, dict):
        return {key: to_cuda(value) for key, value in data.items()}
    if isinstance(data, list):
        return [to_cuda(item) for item in data]
    return data
def get_valid_mask(max_len, valid_len, device="cpu"):
    """Boolean mask of length max_len whose first valid_len entries are True."""
    valid = torch.zeros(max_len, dtype=torch.bool).to(device)
    valid[:valid_len] = True
    return valid
def length_to_mask(lengths, max_len):
    """
    Build a boolean mask that is True for positions below each length.
    Returns: (B, max_len)
    """
    positions = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len)
    return positions < lengths.unsqueeze(1)
def repeat_to_max_len(x, max_len, dim=0):
    """Pad x along `dim` up to `max_len` by repeating its last frame.

    Args:
        x (torch.Tensor): tensor to pad
        max_len (int): target length along `dim`
        dim (int): axis to pad
    Returns:
        x itself when it already has length max_len, otherwise a new padded tensor
    Raises:
        ValueError: if x is longer than max_len along `dim`, or has length 0
            there (no last frame to repeat)
    """
    assert isinstance(x, torch.Tensor)
    cur_len = x.shape[dim]
    if cur_len == max_len:
        return x
    if cur_len > max_len:
        # Report the length along `dim`; the old message always printed x.shape[0].
        raise ValueError(f"Unexpected length v.s. max_len: {cur_len} v.s. {max_len}")
    if cur_len == 0:
        raise ValueError(f"Cannot repeat the last frame of an empty tensor along dim {dim}")
    x = x.transpose(0, dim)
    # expand repeats the last frame without copying; cat then allocates the
    # output, so the input is never modified and needs no clone.
    tail = x[-1:].expand(max_len - cur_len, *x.shape[1:])
    x = torch.cat([x, tail])
    return x.transpose(0, dim)
def repeat_to_max_len_dict(x_dict, max_len, dim=0):
    """Apply repeat_to_max_len to every value of x_dict; the dict is updated in place and returned."""
    for key in list(x_dict.keys()):
        x_dict[key] = repeat_to_max_len(x_dict[key], max_len, dim=dim)
    return x_dict
class Transpose(nn.Module):
    """Module wrapper around Tensor.transpose(dim1, dim2)."""

    def __init__(self, dim1, dim2):
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2

    def forward(self, x):
        """Return x with axes dim1 and dim2 swapped."""
        return torch.transpose(x, self.dim1, self.dim2)
class GaussianSmooth(nn.Module):
    """1D Gaussian smoothing along one axis, with replicate padding at both ends."""

    def __init__(self, sigma=3, dim=-1):
        super().__init__()
        # kernel radius is int(4 * sigma + 0.5)
        weights = _gaussian_kernel1d(sigma=sigma, order=0, radius=int(4 * sigma + 0.5))
        weights = torch.from_numpy(weights).float()[None, None]  # (1, 1, K)
        self.register_buffer("kernel_smooth", weights, persistent=False)
        self.dim = dim

    def forward(self, x):
        """x (..., f, ...) f at dim"""
        half = self.kernel_smooth.size(-1) // 2
        moved = x.transpose(self.dim, -1)
        lead_shape = moved.shape[:-1]
        flat = rearrange(moved, "... f -> (...) 1 f")  # (NB, 1, f)
        padded = F.pad(flat[None], (half, half, 0, 0), mode="replicate")[0]
        smoothed = F.conv1d(padded, self.kernel_smooth)
        restored = smoothed.squeeze(1).reshape(*lead_shape, -1)  # (..., f)
        return restored.transpose(-1, self.dim)
def gaussian_smooth(x, sigma=3, dim=-1):
    """Smooth x along `dim` with a 1D Gaussian kernel and replicate padding.

    The kernel radius is int(4 * sigma + 0.5); the kernel is cast to x's dtype
    and device, and the output has x's shape.
    """
    weights = _gaussian_kernel1d(sigma=sigma, order=0, radius=int(4 * sigma + 0.5))
    weights = torch.from_numpy(weights).float()[None, None].to(x)  # (1, 1, K)
    half = weights.size(-1) // 2
    moved = x.transpose(dim, -1)
    lead_shape = moved.shape[:-1]
    flat = rearrange(moved, "... f -> (...) 1 f")  # (NB, 1, f)
    padded = F.pad(flat[None], (half, half, 0, 0), mode="replicate")[0]
    smoothed = F.conv1d(padded, weights)
    restored = smoothed.squeeze(1).reshape(*lead_shape, -1)  # (..., f)
    return restored.transpose(-1, dim)
def moving_average_smooth(x, window_size=5, dim=-1):
    """Smooth x along `dim` with a box filter and replicate padding.

    Args:
        x (torch.Tensor): input with the frames on axis `dim`
        window_size (int): number of frames averaged, must be >= 1
        dim (int): axis to smooth along
    Returns:
        torch.Tensor with the same shape as x
    Raises:
        ValueError: if window_size is smaller than 1
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    kernel_smooth = torch.ones(window_size).float() / window_size
    kernel_smooth = kernel_smooth[None, None].to(x)  # (1, 1, window_size)
    # Total padding must be window_size - 1 to keep the length. Padding
    # window_size // 2 on both sides added one frame for even window sizes.
    pad_left = (window_size - 1) // 2
    pad_right = window_size // 2
    x = x.transpose(dim, -1)
    x_shape = x.shape[:-1]
    x = x.reshape(-1, 1, x.shape[-1])  # (NB, 1, f)
    x = F.pad(x[None], (pad_left, pad_right, 0, 0), mode="replicate")[0]
    x = F.conv1d(x, kernel_smooth)
    x = x.squeeze(1).reshape(*x_shape, -1)  # (..., f)
    x = x.transpose(-1, dim)
    return x
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/__init__.py
================================================
# Best-effort import of the preprocessing components: they depend on optional
# packages, so a failure here must not break importing the package itself.
try:
    from hmr4d.utils.preproc.tracker import Tracker
    from hmr4d.utils.preproc.vitfeat_extractor import Extractor
    from hmr4d.utils.preproc.vitpose import VitPoseExtractor
    from hmr4d.utils.preproc.slam import SLAMModel
except Exception:
    # `except Exception` (not a bare except) so KeyboardInterrupt and
    # SystemExit are not swallowed during import.
    pass
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/slam.py
================================================
import cv2
import time
import torch
from multiprocessing import Process, Queue
# dpvo is optional: keep this module importable without it. Using SLAMModel
# without dpvo installed will then fail with a NameError at call time.
try:
    from dpvo.utils import Timer
    from dpvo.dpvo import DPVO
    from dpvo.config import cfg
except Exception:
    # `except Exception` (not a bare except) so KeyboardInterrupt and
    # SystemExit are not swallowed during import.
    pass
from hmr4d import PROJ_ROOT
from hmr4d.utils.geo.hmr_cam import estimate_focal_length
class SLAMModel(object):
    """Runs DPVO on a video whose frames are decoded by a background process."""

    def __init__(self, video_path, width, height, intrinsics=None, stride=1, skip=0, buffer=2048, resize=0.5):
        """
        Args:
            video_path: video handed to the reader process
            width, height: frame size, used to estimate the focal length when
                no intrinsics are given
            intrinsics: [fx, fy, cx, cy]
            stride, skip, resize: forwarded to video_stream
            buffer: DPVO BUFFER_SIZE
        """
        if intrinsics is not None:
            intrinsics = intrinsics.clone()
        else:
            print("Estimating focal length")
            focal_length = estimate_focal_length(width, height)
            intrinsics = torch.tensor([focal_length, focal_length, width / 2.0, height / 2.0])

        self.dpvo_cfg = str(PROJ_ROOT / "third-party/DPVO/config/default.yaml")
        self.dpvo_ckpt = "inputs/checkpoints/dpvo/dpvo.pth"
        self.buffer = buffer
        self.times = []
        self.slam = None  # created lazily on the first tracked frame
        self.queue = Queue(maxsize=8)
        reader_args = (self.queue, video_path, intrinsics, stride, skip, resize)
        self.reader = Process(target=video_stream, args=reader_args)
        self.reader.start()

    def track(self):
        """Consume one frame from the queue; returns False once the end marker is read."""
        (t, image, intrinsics) = self.queue.get()
        if t < 0:
            return False

        image = torch.from_numpy(image).permute(2, 0, 1).cuda()
        intrinsics = intrinsics.cuda()  # [fx, fy, cx, cy]

        if self.slam is None:
            cfg.merge_from_file(self.dpvo_cfg)
            cfg.BUFFER_SIZE = self.buffer
            self.slam = DPVO(cfg, self.dpvo_ckpt, ht=image.shape[1], wd=image.shape[2], viz=False)

        with Timer("SLAM", enabled=False):
            # the wall-clock start time is what gets passed to DPVO as timestamp
            started = time.time()
            self.slam(started, image, intrinsics)
            self.times.append(time.time() - started)

        return True

    def process(self):
        """Run 12 extra DPVO updates, wait for the reader, and return element 0 of terminate()."""
        for _ in range(12):
            self.slam.update()

        self.reader.join()
        result = self.slam.terminate()
        return result[0]
def video_stream(queue, imagedir, intrinsics, stride, skip=0, resize=0.5):
    """video generator

    Reads frames from `imagedir` with OpenCV, keeps every `stride`-th frame
    after discarding `skip` frames, resizes it, crops height and width down to
    multiples of 16, and puts (index, image, scaled intrinsics) on `queue`.
    A final item with index -1 marks the end of the stream.
    """
    assert len(intrinsics) == 4, "intrinsics should be [fx, fy, cx, cy]"

    cap = cv2.VideoCapture(imagedir)

    frame_idx = 0
    for _ in range(skip):
        ok, image = cap.read()

    while True:
        # advance `stride` frames, keeping the last one read
        for _ in range(stride):
            ok, image = cap.read()
            # ok is False once no further frame can be read
            if not ok:
                break
        if not ok:
            break

        image = cv2.resize(image, None, fx=resize, fy=resize, interpolation=cv2.INTER_AREA)
        height, width, _ = image.shape
        image = image[: height - height % 16, : width - width % 16]

        scaled_intrinsics = intrinsics.clone() * resize

        queue.put((frame_idx, image, scaled_intrinsics))
        frame_idx += 1

    queue.put((-1, image, intrinsics))  # -1 will terminate the process
    cap.release()

    # wait for the queue to be empty, otherwise the process will end immediately
    while not queue.empty():
        time.sleep(1)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/tracker.py
================================================
from ultralytics import YOLO
from hmr4d import PROJ_ROOT
import torch
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from hmr4d.utils.seq_utils import (
get_frame_id_list_from_mask,
linear_interpolate_frame_ids,
frame_id_to_mask,
rearrange_by_mask,
)
from hmr4d.utils.video_io_utils import get_video_lwh
from hmr4d.utils.net_utils import moving_average_smooth
class Tracker:
def __init__(self) -> None:
    """Load the YOLOv8x weights from the project's checkpoint directory."""
    # https://docs.ultralytics.com/modes/predict/
    weights_path = PROJ_ROOT / "inputs/checkpoints/yolo/yolov8x.pt"
    self.yolo = YOLO(weights_path)
def track(self, video_path):
track_history = []
cfg = {
"device": "cuda",
"conf": 0.5, # default 0.25, wham 0.5
"classes": 0, # human
"verbose": False,
"stream": True,
}
results = self.yolo.track(video_path, **cfg)
# frame-by-frame tracking
track_history = []
for result in tqdm(results, total=get_video_lwh(video_path)[0], desc="YoloV8 Tracking"):
if result.boxes.id is not None:
track_ids = result.boxes.id.int().cpu().tolist() # (N)
bbx_xyxy = result.boxes.xyxy.cpu().numpy() # (N, 4)
result_frame = [{"id": track_ids[i], "bbx_xyxy": bbx_xyxy[i]} for i in range(len(track_ids))]
else:
result_frame = []
track_history.append(result_frame)
return track_history
@staticmethod
def sort_track_length(track_history, video_path):
"""This handles the track history from YOLO tracker."""
id_to_frame_ids = defaultdict(list)
id_to_bbx_xyxys = defaultdict(list)
# parse to {det_id : [frame_id]}
for frame_id, frame in enumerate(track_history):
for det in frame:
id_to_frame_ids[det["id"]].append(frame_id)
id_to_bbx_xyxys[det["id"]].append(det["bbx_xyxy"])
for k, v in id_to_bbx_xyxys.items():
id_to_bbx_xyxys[k] = np.array(v)
# Sort by length of each track (max to min)
id_length = {k: len(v) for k, v in id_to_frame_ids.items()}
id2length = dict(sorted(id_length.items(), key=lambda item: item[1], reverse=True))
# Sort by area sum (max to min)
id_area_sum = {}
l, w, h = get_video_lwh(video_path)
for k, v in id_to_bbx_xyxys.items():
bbx_wh = v[:, 2:] - v[:, :2]
id_area_sum[k] = (bbx_wh[:, 0] * bbx_wh[:, 1] / w / h).sum()
id2area_sum = dict(sorted(id_area_sum.items(), key=lambda item: item[1], reverse=True))
id_sorted = list(id2area_sum.keys())
return id_to_frame_ids, id_to_bbx_xyxys, id_sorted
def get_one_track(self, video_path):
# track
track_history = self.track(video_path)
# parse track_history & use top1 track
id_to_frame_ids, id_to_bbx_xyxys, id_sorted = self.sort_track_length(track_history, video_path)
track_id = id_sorted[0]
frame_ids = torch.tensor(id_to_frame_ids[track_id]) # (N,)
bbx_xyxys = torch.tensor(id_to_bbx_xyxys[track_id]) # (N, 4)
# interpolate missing frames
mask = frame_id_to_mask(frame_ids, get_video_lwh(video_path)[0])
bbx_xyxy_one_track = rearrange_by_mask(bbx_xyxys, mask) # (F, 4), missing filled with 0
missing_frame_id_list = get_frame_id_list_from_mask(~mask) # list of list
bbx_xyxy_one_track = linear_interpolate_frame_ids(bbx_xyxy_one_track, missing_frame_id_list)
assert (bbx_xyxy_one_track.sum(1) != 0).all()
bbx_xyxy_one_track = moving_average_smooth(bbx_xyxy_one_track, window_size=5, dim=0)
bbx_xyxy_one_track = moving_average_smooth(bbx_xyxy_one_track, window_size=5, dim=0)
return bbx_xyxy_one_track
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitfeat_extractor.py
================================================
import torch
from hmr4d.network.hmr2 import load_hmr2, HMR2
from hmr4d.utils.video_io_utils import read_video_np
import cv2
import numpy as np
from hmr4d.network.hmr2.utils.preproc import crop_and_resize, IMAGE_MEAN, IMAGE_STD
from tqdm import tqdm
def get_batch(input_path, bbx_xys, img_ds=0.5, img_dst_size=256, path_type="video"):
    """Load frames, crop them around the given boxes and normalise them.

    Args:
        input_path: Video path (``"video"``), image path (``"image"``) or an
            already-loaded ``np.ndarray`` of frames (``"np"``).
        bbx_xys: Tensor whose columns 0-1 are the box centre and column 2 the
            box size. ``.numpy()`` is called on a value derived from it, so it
            must be a CPU tensor.
        img_ds: Down-scale factor applied when loading; must be 1.0 for ``"np"``.
        img_dst_size: Side length of the square output crop.
        path_type: One of ``"video"``, ``"image"`` or ``"np"``.

    Returns:
        tuple: ``(imgs, bbx_xys)`` - ``imgs`` is the normalised crop tensor in
        ``(F, 3, img_dst_size, img_dst_size)`` layout and ``bbx_xys`` holds the
        boxes returned by ``crop_and_resize`` divided by ``img_ds``.

    Raises:
        ValueError: If ``path_type`` is not a supported value.
    """
    if path_type not in ("video", "image", "np"):
        # Previously an unknown value fell through and failed later with an
        # UnboundLocalError on `imgs`.
        raise ValueError(f"Unsupported path_type {path_type!r}; expected 'video', 'image' or 'np'")

    if path_type == "video":
        imgs = read_video_np(input_path, scale=img_ds)
    elif path_type == "image":
        imgs = cv2.imread(str(input_path))[..., ::-1]  # reverse the channel order
        imgs = cv2.resize(imgs, (0, 0), fx=img_ds, fy=img_ds)
        imgs = imgs[None]
    else:  # path_type == "np"
        assert isinstance(input_path, np.ndarray)
        assert img_ds == 1.0  # this is safe
        imgs = input_path

    gt_center = bbx_xys[:, :2]
    gt_bbx_size = bbx_xys[:, 2]

    # Blur image to avoid aliasing artifacts
    gt_bbx_size_ds = gt_bbx_size * img_ds
    ds_factors = ((gt_bbx_size_ds * 1.0) / img_dst_size / 2.0).numpy()
    imgs = np.stack(
        [
            # Only blur frames whose crop will be shrunk noticeably
            cv2.GaussianBlur(v, (5, 5), (d - 1) / 2) if d > 1.1 else v
            for v, d in zip(imgs, ds_factors)
        ]
    )

    # Output
    imgs_list = []
    bbx_xys_ds_list = []
    for i in range(len(imgs)):
        img, bbx_xys_ds = crop_and_resize(
            imgs[i],
            gt_center[i] * img_ds,
            gt_bbx_size[i] * img_ds,
            img_dst_size,
            enlarge_ratio=1.0,
        )
        imgs_list.append(img)
        bbx_xys_ds_list.append(bbx_xys_ds)
    imgs = torch.from_numpy(np.stack(imgs_list))  # (F, 256, 256, 3), RGB
    bbx_xys = torch.from_numpy(np.stack(bbx_xys_ds_list)) / img_ds  # (F, 3)

    imgs = ((imgs / 255.0 - IMAGE_MEAN) / IMAGE_STD).permute(0, 3, 1, 2)  # (F, 3, 256, 256)
    return imgs, bbx_xys
class Extractor:
    """Runs the HMR2 model over image crops in fixed-size batches on the GPU."""

    def __init__(self, tqdm_leave=True):
        self.extractor: HMR2 = load_hmr2().cuda().eval()
        self.tqdm_leave = tqdm_leave

    def extract_video_features(self, video_path, bbx_xys, img_ds=0.5):
        """
        Compute one feature row per frame.

        ``video_path`` is either a string path (frames are then loaded and
        cropped through ``get_batch``) or an already prepared image tensor.
        img_ds makes the image smaller, which is useful for faster processing
        """
        # Get the batch
        if not isinstance(video_path, str):
            assert isinstance(video_path, torch.Tensor)
            imgs = video_path
        else:
            imgs, bbx_xys = get_batch(video_path, bbx_xys, img_ds=img_ds)

        # Inference
        num_frames, _, _, _ = imgs.shape  # (F, 3, H, W)
        imgs = imgs.cuda()
        batch_size = 16  # 5GB GPU memory, occupies all CUDA cores of 3090
        collected = []
        progress = tqdm(range(0, num_frames, batch_size), desc="HMR2 Feature", leave=self.tqdm_leave)
        for start in progress:
            chunk = imgs[start : start + batch_size]
            with torch.no_grad():
                out = self.extractor({"img": chunk})
                collected.append(out.detach().cpu())

        stacked = torch.cat(collected, dim=0)
        return stacked.clone()  # (F, 1024)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
from .vitpose_pytorch import build_model
from .vitfeat_extractor import get_batch
from tqdm import tqdm
from hmr4d.utils.kpts.kp2d_utils import keypoints_from_heatmaps
from hmr4d.utils.geo_transform import cvt_p2d_from_pm1_to_i
from hmr4d.utils.geo.flip_utils import flip_heatmap_coco17
class VitPoseExtractor:
    """Runs the ViTPose model batch by batch and converts heatmaps to 2D keypoints."""

    def __init__(self, tqdm_leave=True):
        checkpoint = "inputs/checkpoints/vitpose/vitpose-h-multi-coco.pth"
        self.pose = build_model("ViTPose_huge_coco_256x192", checkpoint)
        self.pose.cuda().eval()

        self.flip_test = True
        self.tqdm_leave = tqdm_leave

    @torch.no_grad()
    def extract(self, video_path, bbx_xys, img_ds=0.5):
        """Return the concatenated per-frame keypoints with their confidence.

        ``video_path`` is either a string path (frames are loaded and cropped
        via ``get_batch``) or an already prepared image tensor.
        """
        # Get the batch
        if not isinstance(video_path, str):
            assert isinstance(video_path, torch.Tensor)
            imgs = video_path
        else:
            imgs, bbx_xys = get_batch(video_path, bbx_xys, img_ds=img_ds)

        # Inference
        num_frames, _, _, _ = imgs.shape  # (L, 3, H, W)
        batch_size = 16
        # The mmpose post-processing is the active path; the torch-only
        # alternative below is kept but disabled, as in the original code.
        use_mmpose_postprocess = True
        per_batch = []
        for start in tqdm(range(0, num_frames, batch_size), desc="ViTPose", leave=self.tqdm_leave):
            stop = start + batch_size

            # Heat map (only columns 32..223 of each image are fed to the model)
            crop = imgs[start:stop, :, :, 32:224].cuda()
            if not self.flip_test:
                heatmap = self.pose(crop.clone())  # (B, J, 64, 48)
            else:
                stacked_out = self.pose(torch.cat([crop, crop.flip(3)], dim=0))
                heatmap, heatmap_flipped = stacked_out.chunk(2)
                heatmap_flipped = flip_heatmap_coco17(heatmap_flipped)
                heatmap = (heatmap + heatmap_flipped) * 0.5
                del heatmap_flipped

            if use_mmpose_postprocess:  # postprocess from mmpose
                boxes = bbx_xys[start:stop]
                heatmap = heatmap.clone().cpu().numpy()
                center = boxes[:, :2].numpy()
                box_size = boxes[:, [2]]
                scale = (torch.cat((box_size * 24 / 32, box_size), dim=1) / 200).numpy()
                preds, maxvals = keypoints_from_heatmaps(heatmaps=heatmap, center=center, scale=scale, use_udp=True)
                kp2d = torch.from_numpy(np.concatenate((preds, maxvals), axis=-1))
            else:
                # Get joint
                boxes = bbx_xys[start:stop].cuda()
                method = "hard"
                if method == "hard":
                    kp2d_pm1, conf = get_heatmap_preds(heatmap)
                elif method == "soft":
                    kp2d_pm1, conf = get_heatmap_preds(heatmap, soft=True)

                # Convert 64, 48 to 64, 64
                kp2d_pm1[:, :, 0] *= 24 / 32
                kp2d = cvt_p2d_from_pm1_to_i(kp2d_pm1, boxes[:, None])
                kp2d = torch.cat([kp2d, conf], dim=-1)

            per_batch.append(kp2d.detach().cpu().clone())

        merged = torch.cat(per_batch, dim=0)
        return merged.clone()  # (F, 17, 3)
def get_heatmap_preds(heatmap, normalize_keypoints=True, thr=0.0, soft=False):
    """Locate the peak of every joint heatmap.

    Args:
        heatmap: Tensor of shape (B, J, H, W).
        normalize_keypoints: If True, map x from [0, W-1] and y from [0, H-1]
            to [-1, 1].
        thr: Peaks whose value is not greater than ``thr`` are reported at (0, 0)
            (before normalisation).
        soft: If True, refine each peak with a softmax-weighted offset computed
            on the 5x5 patch around it. Peaks too close to the border for a
            full patch are left unrefined.

    Returns:
        tuple: ``(preds, maxvals)`` with shapes (B, J, 2) as (x, y) and (B, J, 1).
    """
    assert heatmap.ndim == 4, "batch_images should be 4-ndim"

    B, J, H, W = heatmap.shape
    heatmaps_reshaped = heatmap.reshape((B, J, -1))

    maxvals, idx = torch.max(heatmaps_reshaped, 2)
    maxvals = maxvals.reshape((B, J, 1))
    idx = idx.reshape((B, J, 1))

    # Flat index -> (x, y)
    preds = idx.repeat(1, 1, 2).float()
    preds[:, :, 0] = preds[:, :, 0] % W
    preds[:, :, 1] = torch.floor(preds[:, :, 1] / W)

    pred_mask = torch.gt(maxvals, thr).repeat(1, 1, 2).float()
    preds *= pred_mask

    # soft peak
    if soft:
        patch_size = 5
        patch_half = patch_size // 2

        patches = torch.zeros((B, J, patch_size, patch_size)).to(heatmap)
        # One-hot centre patch: symmetric, so it yields a (numerically) zero offset
        default_patch = torch.zeros(patch_size, patch_size).to(heatmap)
        default_patch[patch_half, patch_half] = 1
        for b in range(B):
            for j in range(J):  # was hard-coded to 17 joints
                x, y = preds[b, j].int().tolist()
                # The slice [c - half, c + half] must lie fully inside the map;
                # the upper bound is exclusive, otherwise the slice is truncated
                # and cannot be assigned into the fixed-size patch.
                if patch_half <= x < W - patch_half and patch_half <= y < H - patch_half:
                    patches[b, j] = heatmap[
                        b, j, y - patch_half : y + patch_half + 1, x - patch_half : x + patch_half + 1
                    ]
                else:
                    patches[b, j] = default_patch

        dx, dy = soft_patch_dx_dy(patches)
        preds[:, :, 0] += dx
        preds[:, :, 1] += dy

    if normalize_keypoints:  # to [-1, 1]
        preds[:, :, 0] = preds[:, :, 0] / (W - 1) * 2 - 1
        preds[:, :, 1] = preds[:, :, 1] / (H - 1) * 2 - 1

    return preds, maxvals


def soft_patch_dx_dy(p):
    """Softmax-weighted mean offset of each patch relative to its centre.

    Args:
        p: Tensor of shape (B, J, P, P).

    Returns:
        tuple: ``(dx, dy)``, each of shape (B, J); dx runs along the last
        (column) axis and dy along the second-to-last (row) axis.
    """
    p_batch_shape = p.shape[:-2]
    patch_size = p.size(-1)
    temperature = 1.0
    score = F.softmax(p.reshape(-1, patch_size**2) * temperature, dim=-1)
    score = score.view(*p_batch_shape, patch_size, patch_size)

    # Centred coordinates, e.g. [-2, -1, 0, 1, 2] for a 5x5 patch. Broadcasting
    # them along columns / rows replaces the former meshgrid construction.
    offsets = torch.arange(patch_size).float() - (patch_size - 1) / 2
    offsets = offsets.to(p.device)
    dx = torch.sum(score * offsets.view(1, 1, 1, patch_size), dim=(-2, -1))
    dy = torch.sum(score * offsets.view(1, 1, patch_size, 1), dim=(-2, -1))

    return dx, dy
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/__init__.py
================================================
from .src.vitpose_infer.model_builder import build_model
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/__init__.py
================================================
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/__init__.py
================================================
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# from .alexnet import AlexNet
# from .cpm import CPM
# from .hourglass import HourglassNet
# from .hourglass_ae import HourglassAENet
# from .hrformer import HRFormer
# from .hrnet import HRNet
# from .litehrnet import LiteHRNet
# from .mobilenet_v2 import MobileNetV2
# from .mobilenet_v3 import MobileNetV3
# from .mspn import MSPN
# from .regnet import RegNet
# from .resnest import ResNeSt
# from .resnet import ResNet, ResNetV1d
# from .resnext import ResNeXt
# from .rsn import RSN
# from .scnet import SCNet
# from .seresnet import SEResNet
# from .seresnext import SEResNeXt
# from .shufflenet_v1 import ShuffleNetV1
# from .shufflenet_v2 import ShuffleNetV2
# from .tcn import TCN
# from .v2v_net import V2VNet
# from .vgg import VGG
# from .vipnas_mbv3 import ViPNAS_MobileNetV3
# from .vipnas_resnet import ViPNAS_ResNet
from .vit import ViT
# __all__ = [
# 'AlexNet', 'HourglassNet', 'HourglassAENet', 'HRNet', 'MobileNetV2',
# 'MobileNetV3', 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SCNet',
# 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN',
# 'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3',
# 'LiteHRNet', 'V2VNet', 'HRFormer', 'ViT'
# ]
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/alexnet.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
@BACKBONES.register_module()
class AlexNet(BaseBackbone):
    """AlexNet backbone.

    The input for AlexNet is a 224x224 RGB image.

    Args:
        num_classes (int): number of classes for classification.
            The default value is -1, which uses the backbone as
            a feature extractor without the top classifier.
    """

    def __init__(self, num_classes=-1):
        super().__init__()
        self.num_classes = num_classes

        # Convolutional trunk; layer order fixes the indices inside `features`
        feature_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        # The classifier head only exists when a positive class count is given
        if self.num_classes > 0:
            head_layers = [
                nn.Dropout(),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            ]
            self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return trunk features, or class scores when a classifier head exists."""
        feats = self.features(x)
        if self.num_classes > 0:
            flat = feats.view(feats.size(0), 256 * 6 * 6)
            return self.classifier(flat)
        return feats
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/cpm.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm
from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .utils import load_checkpoint
class CpmBlock(nn.Module):
    """Chain of ``ConvModule`` layers used as one Convolutional Pose Machine stage.

    Args:
        in_channels (int): Input channels of this block.
        channels (list): Output channels of each conv module.
        kernels (list): Kernel sizes of each conv module.
        norm_cfg (dict | None): Norm config handed to every conv module.
    """

    def __init__(self,
                 in_channels,
                 channels=(128, 128, 128),
                 kernels=(11, 11, 11),
                 norm_cfg=None):
        super().__init__()

        assert len(channels) == len(kernels)

        # Each conv consumes the previous conv's output; the first one consumes the block input
        conv_inputs = [in_channels, *channels[:-1]]
        convs = [
            ConvModule(
                c_in,
                c_out,
                kernel,
                padding=(kernel - 1) // 2,
                norm_cfg=norm_cfg)
            for c_in, c_out, kernel in zip(conv_inputs, channels, kernels)
        ]
        self.model = nn.Sequential(*convs)

    def forward(self, x):
        """Model forward function."""
        return self.model(x)
@BACKBONES.register_module()
class CPM(BaseBackbone):
    """CPM backbone (Convolutional Pose Machines).

    Args:
        in_channels (int): The input channels of the CPM. Must be 3.
        out_channels (int): The output channels of the CPM.
        feat_channels (int): Feature channel of each CPM stage.
        middle_channels (int): Feature channel of conv after the middle stage.
        num_stages (int): Number of stages.
        norm_cfg (dict): Dictionary to construct and config norm layer.

    Example:
        >>> from mmpose.models import CPM
        >>> import torch
        >>> self = CPM(3, 17)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 368, 368)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 feat_channels=128,
                 middle_channels=32,
                 num_stages=6,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        assert in_channels == 3

        self.num_stages = num_stages
        assert self.num_stages >= 1

        def conv_pool_trunk():
            # Three (9x9 conv -> stride-2 max-pool) pairs shared by `stem` and `middle`
            layers = []
            trunk_in = in_channels
            for _ in range(3):
                layers.append(
                    ConvModule(trunk_in, 128, 9, padding=4, norm_cfg=norm_cfg))
                layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
                trunk_in = 128
            return layers

        self.stem = nn.Sequential(
            *conv_pool_trunk(),
            ConvModule(128, 32, 5, padding=2, norm_cfg=norm_cfg),
            ConvModule(32, 512, 9, padding=4, norm_cfg=norm_cfg),
            ConvModule(512, 512, 1, padding=0, norm_cfg=norm_cfg),
            ConvModule(512, out_channels, 1, padding=0, act_cfg=None))

        self.middle = nn.Sequential(*conv_pool_trunk())

        # One refinement stage / middle conv / output head per stage after the first
        refine_count = num_stages - 1

        stage_blocks = []
        for _ in range(refine_count):
            stage_blocks.append(
                CpmBlock(
                    middle_channels + out_channels,
                    channels=[feat_channels] * 3,
                    kernels=[11, 11, 11],
                    norm_cfg=norm_cfg))
        self.cpm_stages = nn.ModuleList(stage_blocks)

        middle_heads = []
        for _ in range(refine_count):
            middle_heads.append(
                nn.Sequential(
                    ConvModule(
                        128, middle_channels, 5, padding=2,
                        norm_cfg=norm_cfg)))
        self.middle_conv = nn.ModuleList(middle_heads)

        output_heads = []
        for _ in range(refine_count):
            output_heads.append(
                nn.Sequential(
                    ConvModule(
                        feat_channels,
                        feat_channels,
                        1,
                        padding=0,
                        norm_cfg=norm_cfg),
                    ConvModule(feat_channels, out_channels, 1, act_cfg=None)))
        self.out_convs = nn.ModuleList(output_heads)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            load_checkpoint(
                self, pretrained, strict=False, logger=get_root_logger())
            return

        if pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        # No checkpoint given: fall back to the default initialisers
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                normal_init(module, std=0.001)
            elif isinstance(module, (_BatchNorm, nn.GroupNorm)):
                constant_init(module, 1)

    def forward(self, x):
        """Model forward function."""
        first_out = self.stem(x)
        shared_feat = self.middle(x)

        outputs = [first_out]
        for ind in range(self.num_stages - 1):
            stage = self.cpm_stages[ind]
            head = self.out_convs[ind]

            # Fuse the previous stage's prediction with the projected middle features
            fused = torch.cat(
                [outputs[-1], self.middle_conv[ind](shared_feat)], 1)
            outputs.append(head(stage(fused)))

        return outputs
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hourglass.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm
from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .resnet import BasicBlock, ResLayer
from .utils import load_checkpoint
class HourglassModule(nn.Module):
    """Hourglass Module for HourglassNet backbone.

    Generate module recursively and use BasicBlock as the base unit.

    Args:
        depth (int): Depth of current HourglassModule.
        stage_channels (list[int]): Feature channels of sub-modules in current
            and follow-up HourglassModule.
        stage_blocks (list[int]): Number of sub-modules stacked in current and
            follow-up HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer. It is
            applied to every nesting level of the hourglass.
    """

    def __init__(self,
                 depth,
                 stage_channels,
                 stage_blocks,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        self.depth = depth

        cur_block = stage_blocks[0]
        next_block = stage_blocks[1]

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        # Skip branch at the current resolution
        self.up1 = ResLayer(
            BasicBlock, cur_block, cur_channel, cur_channel, norm_cfg=norm_cfg)

        # Down-sampling branch (stride 2)
        self.low1 = ResLayer(
            BasicBlock,
            cur_block,
            cur_channel,
            next_channel,
            stride=2,
            norm_cfg=norm_cfg)

        if self.depth > 1:
            # Forward norm_cfg so nested levels use the same norm layer as this
            # one; it used to be dropped here, silently falling back to the
            # default BN config below the top level.
            self.low2 = HourglassModule(
                depth - 1,
                stage_channels[1:],
                stage_blocks[1:],
                norm_cfg=norm_cfg)
        else:
            self.low2 = ResLayer(
                BasicBlock,
                next_block,
                next_channel,
                next_channel,
                norm_cfg=norm_cfg)

        self.low3 = ResLayer(
            BasicBlock,
            cur_block,
            next_channel,
            cur_channel,
            norm_cfg=norm_cfg,
            downsample_first=False)

        self.up2 = nn.Upsample(scale_factor=2)

    def forward(self, x):
        """Model forward function."""
        up1 = self.up1(x)
        low1 = self.low1(x)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        # Merge the skip branch with the up-sampled low-resolution branch
        return up1 + up2
@BACKBONES.register_module()
class HourglassNet(BaseBackbone):
    """HourglassNet backbone.

    Stacked Hourglass Networks for Human Pose Estimation.

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104.
        stage_channels (list[int]): Feature channel of each sub-module in a
            HourglassModule.
        stage_blocks (list[int]): Number of sub-modules stacked in a
            HourglassModule.
        feat_channel (int): Feature channel of conv after a HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer. It is
            also handed to every stacked HourglassModule.

    Example:
        >>> from mmpose.models import HourglassNet
        >>> import torch
        >>> self = HourglassNet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 256, 128, 128)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 downsample_times=5,
                 num_stacks=2,
                 stage_channels=(256, 256, 384, 384, 384, 512),
                 stage_blocks=(2, 2, 2, 2, 2, 4),
                 feat_channel=256,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        assert len(stage_channels) == len(stage_blocks)
        assert len(stage_channels) > downsample_times

        cur_channel = stage_channels[0]

        self.stem = nn.Sequential(
            ConvModule(3, 128, 7, padding=3, stride=2, norm_cfg=norm_cfg),
            ResLayer(BasicBlock, 1, 128, 256, stride=2, norm_cfg=norm_cfg))

        # norm_cfg is passed on explicitly; it used to be omitted here, so the
        # hourglass bodies always used the default BN config.
        self.hourglass_modules = nn.ModuleList([
            HourglassModule(
                downsample_times,
                stage_channels,
                stage_blocks,
                norm_cfg=norm_cfg) for _ in range(num_stacks)
        ])

        # One residual block per transition between consecutive stacks
        self.inters = ResLayer(
            BasicBlock,
            num_stacks - 1,
            cur_channel,
            cur_channel,
            norm_cfg=norm_cfg)

        self.conv1x1s = nn.ModuleList([
            ConvModule(
                cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.out_convs = nn.ModuleList([
            ConvModule(
                cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
            for _ in range(num_stacks)
        ])

        self.remap_convs = nn.ModuleList([
            ConvModule(
                feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.relu = nn.ReLU(inplace=True)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Model forward function.

        Returns:
            list: One output feature map per stacked hourglass.
        """
        inter_feat = self.stem(x)
        out_feats = []

        for ind in range(self.num_stacks):
            single_hourglass = self.hourglass_modules[ind]
            out_conv = self.out_convs[ind]

            hourglass_feat = single_hourglass(inter_feat)
            out_feat = out_conv(hourglass_feat)
            out_feats.append(out_feat)

            if ind < self.num_stacks - 1:
                # Feed the next stack with the remapped input plus remapped output
                inter_feat = self.conv1x1s[ind](
                    inter_feat) + self.remap_convs[ind](
                        out_feat)
                inter_feat = self.inters[ind](self.relu(inter_feat))

        return out_feats
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hourglass_ae.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch.nn as nn
from mmcv.cnn import ConvModule, MaxPool2d, constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm
from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .utils import load_checkpoint
class HourglassAEModule(nn.Module):
    """Modified Hourglass Module for HourglassNet_AE backbone.

    Generate module recursively; every level is built from ``ConvModule`` layers.

    Args:
        depth (int): Depth of current HourglassModule.
        stage_channels (list[int]): Feature channels of sub-modules in current
            and follow-up HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer. It is
            applied to every nesting level of the hourglass.
    """

    def __init__(self,
                 depth,
                 stage_channels,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        self.depth = depth

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        # Skip branch at the current resolution
        self.up1 = ConvModule(
            cur_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)

        self.pool1 = MaxPool2d(2, 2)

        self.low1 = ConvModule(
            cur_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)

        if self.depth > 1:
            # Forward norm_cfg so nested levels match this one; it used to be
            # dropped here, so inner levels always used the default BN config.
            self.low2 = HourglassAEModule(
                depth - 1, stage_channels[1:], norm_cfg=norm_cfg)
        else:
            self.low2 = ConvModule(
                next_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)

        self.low3 = ConvModule(
            next_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)

        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)

    def forward(self, x):
        """Model forward function."""
        up1 = self.up1(x)
        pool1 = self.pool1(x)
        low1 = self.low1(pool1)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        # Merge the skip branch with the up-sampled low-resolution branch
        return up1 + up2
@BACKBONES.register_module()
class HourglassAENet(BaseBackbone):
    """Hourglass-AE Network proposed by Newell et al.

    Associative Embedding: End-to-End Learning for Joint
    Detection and Grouping.

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104.
        out_channels (int): Channels of each per-stack output map.
        stage_channels (list[int]): Feature channel of each sub-module in a
            HourglassModule.
        feat_channels (int): Feature channel of conv after a HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer.

    Example:
        >>> from mmpose.models import HourglassAENet
        >>> import torch
        >>> self = HourglassAENet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 512, 512)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 34, 128, 128)
    """

    def __init__(self,
                 downsample_times=4,
                 num_stacks=1,
                 out_channels=34,
                 stage_channels=(256, 384, 512, 640, 768),
                 feat_channels=256,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        assert len(stage_channels) > downsample_times

        cur_channels = stage_channels[0]
        transitions = num_stacks - 1  # remap layers sit between consecutive stacks

        self.stem = nn.Sequential(
            ConvModule(3, 64, 7, padding=3, stride=2, norm_cfg=norm_cfg),
            ConvModule(64, 128, 3, padding=1, norm_cfg=norm_cfg),
            MaxPool2d(2, 2),
            ConvModule(128, 128, 3, padding=1, norm_cfg=norm_cfg),
            ConvModule(128, feat_channels, 3, padding=1, norm_cfg=norm_cfg),
        )

        # Each stack: hourglass body followed by two 3x3 convs
        stacks = []
        for _ in range(num_stacks):
            stacks.append(
                nn.Sequential(
                    HourglassAEModule(
                        downsample_times, stage_channels, norm_cfg=norm_cfg),
                    ConvModule(
                        feat_channels,
                        feat_channels,
                        3,
                        padding=1,
                        norm_cfg=norm_cfg),
                    ConvModule(
                        feat_channels,
                        feat_channels,
                        3,
                        padding=1,
                        norm_cfg=norm_cfg)))
        self.hourglass_modules = nn.ModuleList(stacks)

        heads = []
        for _ in range(num_stacks):
            heads.append(
                ConvModule(
                    cur_channels,
                    out_channels,
                    1,
                    padding=0,
                    norm_cfg=None,
                    act_cfg=None))
        self.out_convs = nn.ModuleList(heads)

        out_remaps = []
        for _ in range(transitions):
            out_remaps.append(
                ConvModule(
                    out_channels,
                    feat_channels,
                    1,
                    norm_cfg=norm_cfg,
                    act_cfg=None))
        self.remap_out_convs = nn.ModuleList(out_remaps)

        feat_remaps = []
        for _ in range(transitions):
            feat_remaps.append(
                ConvModule(
                    feat_channels,
                    feat_channels,
                    1,
                    norm_cfg=norm_cfg,
                    act_cfg=None))
        self.remap_feature_convs = nn.ModuleList(feat_remaps)

        self.relu = nn.ReLU(inplace=True)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            load_checkpoint(
                self, pretrained, strict=False, logger=get_root_logger())
            return

        if pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        # No checkpoint given: fall back to the default initialisers
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                normal_init(module, std=0.001)
            elif isinstance(module, (_BatchNorm, nn.GroupNorm)):
                constant_init(module, 1)

    def forward(self, x):
        """Model forward function."""
        running_feat = self.stem(x)
        outputs = []
        last_index = self.num_stacks - 1

        for ind in range(self.num_stacks):
            stack_feat = self.hourglass_modules[ind](running_feat)
            stack_out = self.out_convs[ind](stack_feat)
            outputs.append(stack_out)

            if ind >= last_index:
                continue
            # Input of the next stack: residual plus remapped output and features
            running_feat = running_feat + self.remap_out_convs[ind](
                stack_out) + self.remap_feature_convs[ind](
                    stack_feat)

        return outputs
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/hrformer.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
# from timm.models.layers import to_2tuple, trunc_normal_
from mmcv.cnn import (build_activation_layer, build_conv_layer,
build_norm_layer, trunc_normal_init)
from mmcv.cnn.bricks.transformer import build_dropout
from mmcv.runner import BaseModule
from torch.nn.functional import pad
from ..builder import BACKBONES
from .hrnet import Bottleneck, HRModule, HRNet
def nlc_to_nchw(x, hw_shape):
    """Reshape a token sequence [N, L, C] into a feature map [N, C, H, W].

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width of output feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.
    """
    height, width = hw_shape
    assert x.ndim == 3
    batch, seq_len, channels = x.shape
    assert seq_len == height * width, 'The seq_len doesn\'t match H, W'
    # Move channels in front of the sequence axis, then unfold the sequence
    channels_first = x.permute(0, 2, 1)
    return channels_first.reshape(batch, channels, height, width)
def nchw_to_nlc(x):
    """Collapse a feature map [N, C, H, W] into a token sequence [N, L, C].

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The output tensor of shape [N, L, C] after conversion.
    """
    assert x.ndim == 4
    # Merge H and W into one axis, then put channels last
    tokens = torch.flatten(x, start_dim=2)
    tokens = tokens.permute(0, 2, 1)
    return tokens.contiguous()
def build_drop_path(drop_path_rate):
    """Build a ``DropPath`` layer with the given drop probability via ``build_dropout``."""
    drop_path_cfg = {'type': 'DropPath', 'drop_prob': drop_path_rate}
    return build_dropout(drop_path_cfg)
class WindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        with_rpe (bool, optional): If True, use relative position bias.
            Default: True.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 with_rpe=True,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.scale = qk_scale or head_embed_dims**-0.5

        self.with_rpe = with_rpe
        if self.with_rpe:
            # define a parameter table of relative position bias
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(
                    (2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                    num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # Index into the table for every (query, key) pair inside a window
            Wh, Ww = self.window_size
            rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
            rel_position_index = rel_index_coords + rel_index_coords.T
            rel_position_index = rel_position_index.flip(1).contiguous()
            self.register_buffer('relative_position_index', rel_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

        self.softmax = nn.Softmax(dim=-1)

    def init_weights(self):
        """Initialise the relative position bias table (if there is one).

        The table only exists when ``with_rpe`` is True, so the call is guarded.
        The initialiser is applied directly to the parameter: the previously
        used ``trunc_normal_init`` helper is written for modules (it looks for
        ``.weight`` / ``.bias`` attributes), which a bare ``nn.Parameter`` does
        not have.
        """
        if self.with_rpe:
            nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """
        Args:

            x (tensor): input features with shape of (B*num_windows, N, C)
            mask (tensor | None, Optional): mask with shape of (num_windows,
                Wh*Ww, Wh*Ww), value should be between (-inf, 0].
        """
        B, N, C = x.shape
        # (3, B, num_heads, N, C // num_heads)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.with_rpe:
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1],
                    self.window_size[0] * self.window_size[1],
                    -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Add the per-window mask, then fold the window axis back into the batch
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        """Return the (1, len1 * len2) row of all sums ``i * step1 + j * step2``."""
        seq1 = torch.arange(0, step1 * len1, step1)
        seq2 = torch.arange(0, step2 * len2, step2)
        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
class LocalWindowSelfAttention(BaseModule):
    r"""Local-window Self Attention (LSA) module with relative position bias.

    This module is the short-range self-attention module in the
    Interlaced Sparse Self-Attention.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int] | int): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        with_rpe (bool, optional): If True, use relative position bias.
            Default: True.
        with_pad_mask (bool, optional): If True, mask out the padded tokens in
            the attention process. Default: False.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 with_rpe=True,
                 with_pad_mask=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        if isinstance(window_size, int):
            window_size = (window_size, window_size)
        self.window_size = window_size
        self.with_pad_mask = with_pad_mask
        self.attn = WindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=window_size,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=proj_drop_rate,
            with_rpe=with_rpe,
            init_cfg=init_cfg)

    def forward(self, x, H, W, **kwargs):
        """Forward function.

        Args:
            x (tensor): token features of shape (B, N, C); it is viewed as
                (B, H, W, C), so N must equal H * W.
            H (int): Height of the feature map.
            W (int): Width of the feature map.
            **kwargs: Forwarded to the wrapped ``WindowMSA``.

        Returns:
            tensor: features of shape (B, N, C).
        """
        B, N, C = x.shape
        x = x.view(B, H, W, C)
        Wh, Ww = self.window_size

        # center-pad the feature on H and W axes
        pad_h = math.ceil(H / Wh) * Wh - H
        pad_w = math.ceil(W / Ww) * Ww - W
        x = pad(x, (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2))

        # permute
        x = x.view(B, math.ceil(H / Wh), Wh, math.ceil(W / Ww), Ww, C)
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(-1, Wh * Ww, C)  # (B*num_window, Wh*Ww, C)

        # attention
        # Padding on EITHER axis introduces padded tokens, so the mask is
        # needed as soon as one of the two pad amounts is non-zero.
        if self.with_pad_mask and (pad_h > 0 or pad_w > 0):
            # zeros on real tokens, -inf on padded ones
            pad_mask = x.new_zeros(1, H, W, 1)
            pad_mask = pad(
                pad_mask, [
                    0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2
                ],
                value=-float('inf'))
            # window-partition the mask the same way as the features
            pad_mask = pad_mask.view(1, math.ceil(H / Wh), Wh,
                                     math.ceil(W / Ww), Ww, 1)
            pad_mask = pad_mask.permute(1, 3, 0, 2, 4, 5)
            pad_mask = pad_mask.reshape(-1, Wh * Ww)
            # (num_windows, Wh*Ww) -> (num_windows, Wh*Ww, Wh*Ww): every
            # query row shares the same key mask
            pad_mask = pad_mask[:, None, :].expand([-1, Wh * Ww, -1])
            out = self.attn(x, pad_mask, **kwargs)
        else:
            out = self.attn(x, **kwargs)

        # reverse permutation
        out = out.reshape(B, math.ceil(H / Wh), math.ceil(W / Ww), Wh, Ww, C)
        out = out.permute(0, 1, 3, 2, 4, 5)
        out = out.reshape(B, H + pad_h, W + pad_w, C)

        # de-pad
        out = out[:, pad_h // 2:H + pad_h // 2, pad_w // 2:W + pad_w // 2]
        return out.reshape(B, N, C)
class CrossFFN(BaseModule):
    r"""FFN with Depthwise Conv of HRFormer.

    The token sequence is reshaped to a feature map, passed through
    1x1 conv -> depthwise 3x3 conv -> 1x1 conv (each followed by a norm and
    an activation layer), and reshaped back to a token sequence.

    Args:
        in_features (int): The feature dimension.
        hidden_features (int, optional): The hidden dimension of FFNs.
            Defaults: The same as in_features.
        out_features (int, optional): The output dimension of FFNs.
            Defaults: The same as in_features.
        act_cfg (dict, optional): Config of activation layer.
            Default: dict(type='GELU').
        dw_act_cfg (dict, optional): Config of activation layer appended
            right after DW Conv. Default: dict(type='GELU').
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='SyncBN').
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_cfg=dict(type='GELU'),
                 dw_act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # fall back to the input width when a dimension is not given
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # pointwise expansion
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
        self.act1 = build_activation_layer(act_cfg)
        self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]

        # depthwise 3x3 convolution (one group per channel)
        self.dw3x3 = nn.Conv2d(
            hidden_features,
            hidden_features,
            kernel_size=3,
            stride=1,
            groups=hidden_features,
            padding=1)
        self.act2 = build_activation_layer(dw_act_cfg)
        self.norm2 = build_norm_layer(norm_cfg, hidden_features)[1]

        # pointwise projection
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
        self.act3 = build_activation_layer(act_cfg)
        self.norm3 = build_norm_layer(norm_cfg, out_features)[1]

        # put the modules together, in execution order
        self.layers = [
            self.fc1, self.norm1, self.act1,
            self.dw3x3, self.norm2, self.act2,
            self.fc2, self.norm3, self.act3,
        ]

    def forward(self, x, H, W):
        """Forward function.

        ``x`` is converted with ``nlc_to_nchw(x, (H, W))``, run through every
        layer in ``self.layers`` and converted back with ``nchw_to_nlc``.
        """
        feat = nlc_to_nchw(x, (H, W))
        for module in self.layers:
            feat = module(feat)
        return nchw_to_nlc(feat)
class HRFormerBlock(BaseModule):
    """High-Resolution Block for HRFormer.

    A local-window self-attention followed by a depthwise-conv FFN, each
    wrapped in a residual connection with optional drop path.

    Args:
        in_features (int): The input dimension.
        out_features (int): The output dimension.
        num_heads (int): The number of head within each LSA.
        window_size (int, optional): The window size for the LSA.
            Default: 7
        mlp_ratio (int, optional): The expansion ratio of FFN.
            Default: 4
        drop_path (float, optional): The drop path rate; values not greater
            than 0 disable drop path. Default: 0.0
        act_cfg (dict, optional): Config of activation layer.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='SyncBN').
        transformer_norm_cfg (dict, optional): Config of transformer norm
            layer. Default: dict(type='LN', eps=1e-6).
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
        **kwargs: Extra keyword arguments forwarded to
            ``LocalWindowSelfAttention``.
    """

    expansion = 1

    def __init__(self,
                 in_features,
                 out_features,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.0,
                 drop_path=0.0,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN'),
                 transformer_norm_cfg=dict(type='LN', eps=1e-6),
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        # attention sub-layer
        self.norm1 = build_norm_layer(transformer_norm_cfg, in_features)[1]
        self.attn = LocalWindowSelfAttention(
            in_features,
            num_heads=num_heads,
            window_size=window_size,
            init_cfg=None,
            **kwargs)

        # feed-forward sub-layer
        self.norm2 = build_norm_layer(transformer_norm_cfg, out_features)[1]
        ffn_hidden = int(in_features * mlp_ratio)
        self.ffn = CrossFFN(
            in_features=in_features,
            hidden_features=ffn_hidden,
            out_features=out_features,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dw_act_cfg=act_cfg,
            init_cfg=None)

        # stochastic depth is only instantiated for a positive rate
        if drop_path > 0.0:
            self.drop_path = build_drop_path(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Forward function.

        Takes a (B, C, H, W) feature map and returns one of the same shape.
        """
        batch, channels, height, width = x.size()

        # (B, C, H, W) -> (B, H*W, C)
        tokens = x.view(batch, channels, -1).permute(0, 2, 1)

        # Attention
        attended = self.attn(self.norm1(tokens), height, width)
        tokens = tokens + self.drop_path(attended)

        # FFN
        transformed = self.ffn(self.norm2(tokens), height, width)
        tokens = tokens + self.drop_path(transformed)

        # (B, H*W, C) -> (B, C, H, W)
        return tokens.permute(0, 2, 1).view(batch, channels, height, width)

    def extra_repr(self):
        """(Optional) Set the extra information about this module."""
        template = 'num_heads={}, window_size={}, mlp_ratio={}'
        return template.format(self.num_heads, self.window_size,
                               self.mlp_ratio)
class HRFomerModule(HRModule):
    """High-Resolution Module for HRFormer.

    Args:
        num_branches (int): The number of branches in the HRFormerModule.
        block (nn.Module): The building block of HRFormer.
            The block should be the HRFormerBlock.
        num_blocks (tuple): The number of blocks in each branch.
            The length must be equal to num_branches.
        num_inchannels (tuple): The number of input channels in each branch.
            The length must be equal to num_branches.
        num_channels (tuple): The number of channels in each branch.
            The length must be equal to num_branches.
        num_heads (tuple): The number of heads within the LSAs.
        num_window_sizes (tuple): The window size for the LSAs.
        num_mlp_ratios (tuple): The expansion ratio for the FFNs.
        multiscale_output (bool, optional): Whether to output multi-level
            features produced by multiple branches. If False, only the first
            level feature will be output. Default: True.
        drop_paths (float | Sequence[float], optional): The drop path rates
            of the blocks in a branch, indexed by block position. A single
            number is applied to every block. Default: 0.0
        with_rpe (bool, optional): Forwarded to each block; whether the LSA
            uses relative position bias. Default: True.
        with_pad_mask (bool, optional): Forwarded to each block; whether the
            LSA masks out padded tokens. Default: False.
        conv_cfg (dict, optional): Config of the conv layers.
            Default: None.
        norm_cfg (dict, optional): Config of the norm layers appended
            right after conv. Default: dict(type='SyncBN', requires_grad=True)
        transformer_norm_cfg (dict, optional): Config of the norm layers.
            Default: dict(type='LN', eps=1e-6)
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False
        upsample_cfg(dict, optional): The config of upsample layers in fuse
            layers. Default: dict(mode='bilinear', align_corners=False)
    """

    def __init__(self,
                 num_branches,
                 block,
                 num_blocks,
                 num_inchannels,
                 num_channels,
                 num_heads,
                 num_window_sizes,
                 num_mlp_ratios,
                 multiscale_output=True,
                 drop_paths=0.0,
                 with_rpe=True,
                 with_pad_mask=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='SyncBN', requires_grad=True),
                 transformer_norm_cfg=dict(type='LN', eps=1e-6),
                 with_cp=False,
                 upsample_cfg=dict(mode='bilinear', align_corners=False)):

        # These attributes are assigned before ``super().__init__`` and are
        # read by ``_make_one_branch`` below.
        # NOTE(review): presumably the parent constructor builds the branches
        # during ``__init__`` — confirm against HRModule.
        self.transformer_norm_cfg = transformer_norm_cfg
        self.drop_paths = drop_paths
        self.num_heads = num_heads
        self.num_window_sizes = num_window_sizes
        self.num_mlp_ratios = num_mlp_ratios
        self.with_rpe = with_rpe
        self.with_pad_mask = with_pad_mask

        super().__init__(num_branches, block, num_blocks, num_inchannels,
                         num_channels, multiscale_output, with_cp, conv_cfg,
                         norm_cfg, upsample_cfg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Build one branch."""
        # HRFormerBlock does not support down sample layer yet.
        assert stride == 1 and self.in_channels[branch_index] == num_channels[
            branch_index]

        depth = num_blocks[branch_index]

        # The documented default ``drop_paths=0.0`` is a scalar and cannot be
        # indexed; broadcast a scalar rate to every block of the branch.
        drop_paths = self.drop_paths
        if isinstance(drop_paths, (int, float)):
            drop_paths = [drop_paths] * max(depth, 1)

        def _build_block(block_index):
            # ``self.in_channels`` is read at call time so that blocks after
            # the first one see the channel count updated below.
            return block(
                self.in_channels[branch_index],
                num_channels[branch_index],
                num_heads=self.num_heads[branch_index],
                window_size=self.num_window_sizes[branch_index],
                mlp_ratio=self.num_mlp_ratios[branch_index],
                drop_path=drop_paths[block_index],
                norm_cfg=self.norm_cfg,
                transformer_norm_cfg=self.transformer_norm_cfg,
                init_cfg=None,
                with_rpe=self.with_rpe,
                with_pad_mask=self.with_pad_mask)

        # the first block is always created
        layers = [_build_block(0)]

        self.in_channels[
            branch_index] = self.in_channels[branch_index] * block.expansion
        for i in range(1, depth):
            layers.append(_build_block(i))
        return nn.Sequential(*layers)

    def _make_fuse_layers(self):
        """Build fuse layers.

        ``fuse_layers[i][j]`` maps the output of branch ``j`` to the channel
        count of branch ``i``: a 1x1 conv + norm + upsample when ``j > i``,
        ``None`` when ``j == i``, and a chain of ``i - j`` stride-2 depthwise
        separable convs when ``j < i``.
        """
        if self.num_branches == 1:
            return None
        num_branches = self.num_branches
        num_inchannels = self.in_channels
        fuse_layers = []
        for i in range(num_branches if self.multiscale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_inchannels[j],
                                num_inchannels[i],
                                kernel_size=1,
                                stride=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_inchannels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i),
                                mode=self.upsample_cfg['mode'],
                                align_corners=self.
                                upsample_cfg['align_corners'])))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        # only the last step changes the channel count and
                        # it has no trailing activation
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            with_out_act = False
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            with_out_act = True
                        sub_modules = [
                            build_conv_layer(
                                self.conv_cfg,
                                num_inchannels[j],
                                num_inchannels[j],
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                groups=num_inchannels[j],
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg,
                                             num_inchannels[j])[1],
                            build_conv_layer(
                                self.conv_cfg,
                                num_inchannels[j],
                                num_outchannels_conv3x3,
                                kernel_size=1,
                                stride=1,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg,
                                             num_outchannels_conv3x3)[1]
                        ]
                        if with_out_act:
                            sub_modules.append(nn.ReLU(False))
                        conv3x3s.append(nn.Sequential(*sub_modules))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return the number of input channels."""
        return self.in_channels
@BACKBONES.register_module()
class HRFormer(HRNet):
    """HRFormer backbone.

    This backbone is the implementation of "HRFormer: High-Resolution
    Transformer for Dense Prediction".

    Args:
        extra (dict): Detailed configuration for each stage of HRNet.
            There must be 4 stages, the configuration for each stage must have
            5 keys:

                - num_modules (int): The number of HRModule in this stage.
                - num_branches (int): The number of branches in the HRModule.
                - block (str): The type of block; a key of ``blocks_dict``
                    ('BOTTLENECK' or 'HRFORMERBLOCK').
                - num_blocks (tuple): The number of blocks in each branch.
                    The length must be equal to num_branches.
                - num_channels (tuple): The number of channels in each branch.
                    The length must be equal to num_branches.

            Stages built by ``_make_stage`` additionally read the keys
            ``num_heads``, ``window_sizes`` and ``mlp_ratios``. The optional
            top-level keys are ``drop_path_rate`` (default 0.0),
            ``upsample``, ``with_rpe`` (default True) and ``with_pad_mask``
            (default False). Note that ``extra`` is updated in place with
            the per-stage ``drop_path_rates`` and the ``upsample`` config.
        in_channels (int): Number of input image channels. Normally 3.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Config of norm layer.
            Default: dict(type='BN', requires_grad=True).
        transformer_norm_cfg (dict): Config of transformer norm layer.
            Use `LN` by default.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.

    Example:
        >>> from mmpose.models import HRFormer
        >>> import torch
        >>> extra = dict(
        >>>     stage1=dict(
        >>>         num_modules=1,
        >>>         num_branches=1,
        >>>         block='BOTTLENECK',
        >>>         num_blocks=(2, ),
        >>>         num_channels=(64, )),
        >>>     stage2=dict(
        >>>         num_modules=1,
        >>>         num_branches=2,
        >>>         block='HRFORMERBLOCK',
        >>>         window_sizes=(7, 7),
        >>>         num_heads=(1, 2),
        >>>         mlp_ratios=(4, 4),
        >>>         num_blocks=(2, 2),
        >>>         num_channels=(32, 64)),
        >>>     stage3=dict(
        >>>         num_modules=4,
        >>>         num_branches=3,
        >>>         block='HRFORMERBLOCK',
        >>>         window_sizes=(7, 7, 7),
        >>>         num_heads=(1, 2, 4),
        >>>         mlp_ratios=(4, 4, 4),
        >>>         num_blocks=(2, 2, 2),
        >>>         num_channels=(32, 64, 128)),
        >>>     stage4=dict(
        >>>         num_modules=2,
        >>>         num_branches=4,
        >>>         block='HRFORMERBLOCK',
        >>>         window_sizes=(7, 7, 7, 7),
        >>>         num_heads=(1, 2, 4, 8),
        >>>         mlp_ratios=(4, 4, 4, 4),
        >>>         num_blocks=(2, 2, 2, 2),
        >>>         num_channels=(32, 64, 128, 256)))
        >>> self = HRFormer(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 32, 8, 8)
        (1, 64, 4, 4)
        (1, 128, 2, 2)
        (1, 256, 1, 1)
    """

    blocks_dict = {'BOTTLENECK': Bottleneck, 'HRFORMERBLOCK': HRFormerBlock}

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 transformer_norm_cfg=dict(type='LN', eps=1e-6),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=False,
                 frozen_stages=-1):

        # stochastic depth
        # per-stage depth = blocks of the first branch * number of modules
        depths = [
            extra[stage]['num_blocks'][0] * extra[stage]['num_modules']
            for stage in ['stage2', 'stage3', 'stage4']
        ]
        depth_s2, depth_s3, _ = depths
        # ``drop_path_rate`` is optional; a missing key disables drop path
        # instead of raising KeyError.
        drop_path_rate = extra.get('drop_path_rate', 0.0)
        # linearly increasing rate over all transformer blocks
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]
        extra['stage2']['drop_path_rates'] = dpr[0:depth_s2]
        extra['stage3']['drop_path_rates'] = dpr[depth_s2:depth_s2 + depth_s3]
        extra['stage4']['drop_path_rates'] = dpr[depth_s2 + depth_s3:]

        # HRFormer use bilinear upsample as default
        upsample_cfg = extra.get('upsample', {
            'mode': 'bilinear',
            'align_corners': False
        })
        extra['upsample'] = upsample_cfg
        self.transformer_norm_cfg = transformer_norm_cfg
        self.with_rpe = extra.get('with_rpe', True)
        self.with_pad_mask = extra.get('with_pad_mask', False)

        super().__init__(extra, in_channels, conv_cfg, norm_cfg, norm_eval,
                         with_cp, zero_init_residual, frozen_stages)

    def _make_stage(self,
                    layer_config,
                    num_inchannels,
                    multiscale_output=True):
        """Make each stage.

        Returns:
            tuple: ``(nn.Sequential of HRFomerModule, num_inchannels)`` where
            ``num_inchannels`` is the value reported by the last module.
        """
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]
        num_heads = layer_config['num_heads']
        num_window_sizes = layer_config['window_sizes']
        num_mlp_ratios = layer_config['mlp_ratios']
        drop_path_rates = layer_config['drop_path_rates']

        modules = []
        for i in range(num_modules):
            # multiscale_output is only used at the last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            modules.append(
                HRFomerModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    num_heads,
                    num_window_sizes,
                    num_mlp_ratios,
                    reset_multiscale_output,
                    # slice of rates belonging to the i-th module
                    drop_paths=drop_path_rates[num_blocks[0] *
                                               i:num_blocks[0] * (i + 1)],
                    with_rpe=self.with_rpe,
                    with_pad_mask=self.with_pad_mask,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    transformer_norm_cfg=self.transformer_norm_cfg,
                    with_cp=self.with_cp,
                    upsample_cfg=self.upsample_cfg))
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/litehrnet.py
================================================
# ------------------------------------------------------------------------------
# Adapted from https://github.com/HRNet/Lite-HRNet
# Original licence: Apache License 2.0.
# ------------------------------------------------------------------------------
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
build_conv_layer, build_norm_layer, constant_init,
normal_init)
from torch.nn.modules.batchnorm import _BatchNorm
from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .utils import channel_shuffle, load_checkpoint
class SpatialWeighting(nn.Module):
    """Spatial weighting module.

    The input is globally average-pooled, passed through two 1x1
    ``ConvModule`` layers and multiplied back onto the input.

    Args:
        channels (int): The channels of the module.
        ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
            The last ConvModule uses Sigmoid by default.
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        # a single dict is used for both conv layers
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        first_act, second_act = act_cfg[0], act_cfg[1]
        squeezed_channels = int(channels / ratio)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        # squeeze: channels -> channels / ratio
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=squeezed_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=first_act)
        # excite: channels / ratio -> channels
        self.conv2 = ConvModule(
            in_channels=squeezed_channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=second_act)

    def forward(self, x):
        """Scale ``x`` by the weights computed from its pooled features."""
        weights = self.conv2(self.conv1(self.global_avgpool(x)))
        return x * weights
class CrossResolutionWeighting(nn.Module):
    """Cross-resolution channel weighting module.

    Args:
        channels (list[int]): The channels of every input branch; they are
            summed to size the shared conv layers and used to split the
            weights back per branch.
        ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
            The last ConvModule uses Sigmoid by default.
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        # a single dict is used for both conv layers
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        self.channels = channels
        total_channel = sum(channels)
        reduced_channel = int(total_channel / ratio)

        self.conv1 = ConvModule(
            in_channels=total_channel,
            out_channels=reduced_channel,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=reduced_channel,
            out_channels=total_channel,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        """Weight every branch in ``x`` (a sequence of feature maps) and
        return the weighted maps as a list."""
        # pool all but the last branch down to the last branch's spatial size
        mini_size = x[-1].size()[-2:]
        pooled = [F.adaptive_avg_pool2d(feat, mini_size) for feat in x[:-1]]
        pooled.append(x[-1])

        # joint weights over the concatenated channels
        weights = torch.cat(pooled, dim=1)
        weights = self.conv1(weights)
        weights = self.conv2(weights)
        per_branch = torch.split(weights, self.channels, dim=1)

        # resize each weight map to its branch and apply it
        weighted = []
        for feat, weight in zip(x, per_branch):
            resized = F.interpolate(
                weight, size=feat.size()[-2:], mode='nearest')
            weighted.append(feat * resized)
        return weighted
class ConditionalChannelWeighting(nn.Module):
    """Conditional channel weighting block.

    Args:
        in_channels (list[int]): The input channels of every branch; each
            branch is split in half and only the second half is processed.
        stride (int): Stride of the 3x3 convolution layer.
        reduce_ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 stride,
                 reduce_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.stride = stride
        assert stride in [1, 2]

        branch_channels = [channel // 2 for channel in in_channels]

        self.cross_resolution_weighting = CrossResolutionWeighting(
            branch_channels,
            ratio=reduce_ratio,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        self.depthwise_convs = nn.ModuleList([
            ConvModule(
                channel,
                channel,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=channel,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None) for channel in branch_channels
        ])

        self.spatial_weighting = nn.ModuleList([
            SpatialWeighting(channels=channel, ratio=4)
            for channel in branch_channels
        ])

    def forward(self, x):
        """Forward function.

        Args:
            x (list[Tensor]): One feature map per branch.

        Returns:
            list[Tensor]: One feature map per branch.
        """

        def _inner_forward(x):
            # split every branch into an untouched half and a processed half
            x = [s.chunk(2, dim=1) for s in x]
            x1 = [s[0] for s in x]
            x2 = [s[1] for s in x]

            x2 = self.cross_resolution_weighting(x2)
            x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
            x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]

            out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
            out = [channel_shuffle(s, 2) for s in out]

            return out

        # ``x`` is a list, so it has no ``requires_grad`` attribute; inspect
        # the tensors it holds instead.
        if self.with_cp and any(s.requires_grad for s in x):

            def _checkpointed(*branches):
                # checkpoint needs the tensors as separate positional
                # arguments and a tuple of tensors as output
                return tuple(_inner_forward(list(branches)))

            out = list(cp.checkpoint(_checkpointed, *x))
        else:
            out = _inner_forward(x)

        return out
class Stem(nn.Module):
    """Stem network block.

    A stride-2 3x3 conv followed by a two-way split: one half goes through
    ``branch1`` and the other through expand -> depthwise -> linear convs;
    the two results are concatenated and channel-shuffled.

    Args:
        in_channels (int): The input channels of the block.
        stem_channels (int): Output channels of the stem layer.
        out_channels (int): The output channels of the block.
        expand_ratio (int): adjusts number of channels of the hidden layer
            in InvertedResidual by this amount.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 stem_channels,
                 out_channels,
                 expand_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp

        relu_cfg = dict(type='ReLU')

        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=relu_cfg)

        mid_channels = int(round(stem_channels * expand_ratio))
        branch_channels = stem_channels // 2
        # output widths of the two paths depend on whether the stem already
        # has the requested number of output channels
        if stem_channels == self.out_channels:
            inc_channels = self.out_channels - branch_channels
            linear_channels = branch_channels
        else:
            inc_channels = self.out_channels - stem_channels
            linear_channels = stem_channels

        # path for the first half of the channels
        self.branch1 = nn.Sequential(
            ConvModule(
                branch_channels,
                branch_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=branch_channels,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_channels,
                inc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU')),
        )

        # path for the second half of the channels
        self.expand_conv = ConvModule(
            branch_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))
        self.depthwise_conv = ConvModule(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.linear_conv = ConvModule(
            mid_channels,
            linear_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))

    def forward(self, x):
        """Forward function; optionally wrapped in gradient checkpointing."""

        def _inner_forward(inp):
            feat = self.conv1(inp)
            first_half, second_half = feat.chunk(2, dim=1)

            second_half = self.expand_conv(second_half)
            second_half = self.depthwise_conv(second_half)
            second_half = self.linear_conv(second_half)

            merged = torch.cat((self.branch1(first_half), second_half), dim=1)
            return channel_shuffle(merged, 2)

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)
class IterativeHead(nn.Module):
    """Extra iterative head for feature learning.

    Branches are processed in reversed order: each one is summed with the
    upsampled output of the previous one and then projected.

    Args:
        in_channels (list[int]): The input channels of every branch.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self, in_channels, norm_cfg=dict(type='BN')):
        super().__init__()
        num_branchs = len(in_channels)
        # stored in reversed order, matching the order used in forward
        self.in_channels = in_channels[::-1]

        projects = []
        for idx in range(num_branchs):
            is_last = idx == num_branchs - 1
            # every projection maps to the width of the next branch, except
            # the last one which keeps its own width
            if is_last:
                target_channels = self.in_channels[idx]
            else:
                target_channels = self.in_channels[idx + 1]
            projects.append(
                DepthwiseSeparableConvModule(
                    in_channels=self.in_channels[idx],
                    out_channels=target_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU'),
                    dw_act_cfg=None,
                    pw_act_cfg=dict(type='ReLU')))
        self.projects = nn.ModuleList(projects)

    def forward(self, x):
        """Return the refined branches, in the same order as the input."""
        reversed_inputs = x[::-1]

        outputs = []
        previous = None
        for idx, feat in enumerate(reversed_inputs):
            if previous is not None:
                # bring the previous output to this branch's resolution
                previous = F.interpolate(
                    previous,
                    size=feat.size()[-2:],
                    mode='bilinear',
                    align_corners=True)
                feat = feat + previous
            feat = self.projects[idx](feat)
            outputs.append(feat)
            previous = feat

        return outputs[::-1]
class ShuffleUnit(nn.Module):
    """InvertedResidual block for ShuffleNetV2 backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 convolution layer. Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super().__init__()
        self.stride = stride
        self.with_cp = with_cp

        branch_features = out_channels // 2
        # a stride-1 unit splits its input in two, so the widths must match
        if self.stride == 1:
            assert in_channels == branch_features * 2, (
                f'in_channels ({in_channels}) should equal to '
                f'branch_features * 2 ({branch_features * 2}) '
                'when stride is 1')

        if in_channels != branch_features * 2:
            assert self.stride != 1, (
                f'stride ({self.stride}) should not equal 1 when '
                f'in_channels != branch_features * 2')

        downsampling = self.stride > 1

        # branch1 only exists for down-sampling units
        if downsampling:
            self.branch1 = nn.Sequential(
                ConvModule(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=self.stride,
                    padding=1,
                    groups=in_channels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None),
                ConvModule(
                    in_channels,
                    branch_features,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
            )

        # branch2 sees the full input when down-sampling, otherwise one half
        branch2_in = in_channels if downsampling else branch_features
        self.branch2 = nn.Sequential(
            ConvModule(
                branch2_in,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=branch_features,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def forward(self, x):
        """Forward function; optionally wrapped in gradient checkpointing."""

        def _inner_forward(inp):
            if self.stride > 1:
                parts = (self.branch1(inp), self.branch2(inp))
            else:
                kept, processed = inp.chunk(2, dim=1)
                parts = (kept, self.branch2(processed))
            merged = torch.cat(parts, dim=1)
            return channel_shuffle(merged, 2)

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)
class LiteHRModule(nn.Module):
"""High-Resolution Module for LiteHRNet.
It contains conditional channel weighting blocks and
shuffle blocks.
Args:
num_branches (int): Number of branches in the module.
num_blocks (int): Number of blocks in the module.
in_channels (list(int)): Number of input image channels.
reduce_ratio (int): Channel reduction ratio.
module_type (str): 'LITE' or 'NAIVE'
multiscale_output (bool): Whether to output multi-scale features.
with_fuse (bool): Whether to use fuse layers.
conv_cfg (dict): dictionary to construct and config conv layer.
norm_cfg (dict): dictionary to construct and config norm layer.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
def __init__(
self,
num_branches,
num_blocks,
in_channels,
reduce_ratio,
module_type,
multiscale_output=False,
with_fuse=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
with_cp=False,
):
super().__init__()
self._check_branches(num_branches, in_channels)
self.in_channels = in_channels
self.num_branches = num_branches
self.module_type = module_type
self.multiscale_output = multiscale_output
self.with_fuse = with_fuse
self.norm_cfg = norm_cfg
self.conv_cfg = conv_cfg
self.with_cp = with_cp
if self.module_type.upper() == 'LITE':
self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio)
elif self.module_type.upper() == 'NAIVE':
self.layers = self._make_naive_branches(num_branches, num_blocks)
else:
raise ValueError("module_type should be either 'LITE' or 'NAIVE'.")
if self.with_fuse:
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU()
def _check_branches(self, num_branches, in_channels):
"""Check input to avoid ValueError."""
if num_branches != len(in_channels):
error_msg = f'NUM_BRANCHES({num_branches}) ' \
f'!= NUM_INCHANNELS({len(in_channels)})'
raise ValueError(error_msg)
def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1):
"""Make channel weighting blocks."""
layers = []
for i in range(num_blocks):
layers.append(
ConditionalChannelWeighting(
self.in_channels,
stride=stride,
reduce_ratio=reduce_ratio,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
with_cp=self.with_cp))
return nn.Sequential(*layers)
def _make_one_branch(self, branch_index, num_blocks, stride=1):
    """Build the ShuffleUnit stack of a single branch.

    The first unit uses ``stride``; the remaining ``num_blocks - 1`` units
    always use stride 1. All units keep the branch's channel count.

    Args:
        branch_index (int): Index into ``self.in_channels``.
        num_blocks (int): Total number of units in the branch.
        stride (int): Stride of the first unit. Default: 1.

    Returns:
        nn.Sequential: The units of this branch.
    """
    channels = self.in_channels[branch_index]

    def make_unit(unit_stride):
        return ShuffleUnit(
            channels,
            channels,
            stride=unit_stride,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='ReLU'),
            with_cp=self.with_cp)

    # The leading unit is always present, even for num_blocks < 1.
    units = [make_unit(stride)]
    for _ in range(1, num_blocks):
        units.append(make_unit(1))
    return nn.Sequential(*units)
def _make_naive_branches(self, num_branches, num_blocks):
"""Make branches."""
branches = []
for i in range(num_branches):
branches.append(self._make_one_branch(i, num_blocks))
return nn.ModuleList(branches)
def _make_fuse_layers(self):
    """Build the cross-branch fuse layers.

    Entry ``[i][j]`` of the result maps the feature of source branch ``j``
    to the channel count of output branch ``i``:

    * ``j > i``: 1x1 conv + norm, then nearest upsampling by ``2**(j - i)``.
    * ``j == i``: ``None`` (no layer is built).
    * ``j < i``: ``i - j`` stride-2 depthwise-separable steps; only the
      last step changes the channel count, and only the earlier steps end
      with a ReLU.

    Returns:
        nn.ModuleList | None: ``None`` for a single-branch module,
        otherwise one ``nn.ModuleList`` per output branch. Only one output
        branch is built when ``self.multiscale_output`` is False.
    """
    if self.num_branches == 1:
        return None

    channels = self.in_channels
    total = self.num_branches
    out_count = total if self.multiscale_output else 1

    rows = []
    for dst in range(out_count):
        row = []
        for src in range(total):
            if src == dst:
                row.append(None)
            elif src > dst:
                row.append(
                    nn.Sequential(
                        build_conv_layer(
                            self.conv_cfg,
                            channels[src],
                            channels[dst],
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            bias=False),
                        build_norm_layer(self.norm_cfg, channels[dst])[1],
                        nn.Upsample(
                            scale_factor=2**(src - dst), mode='nearest')))
            else:
                steps = dst - src
                stack = []
                for step in range(steps):
                    is_last = step == steps - 1
                    # Only the final step projects to the target width.
                    pw_out = channels[dst] if is_last else channels[src]
                    ops = [
                        build_conv_layer(
                            self.conv_cfg,
                            channels[src],
                            channels[src],
                            kernel_size=3,
                            stride=2,
                            padding=1,
                            groups=channels[src],
                            bias=False),
                        build_norm_layer(self.norm_cfg, channels[src])[1],
                        build_conv_layer(
                            self.conv_cfg,
                            channels[src],
                            pw_out,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            bias=False),
                        build_norm_layer(self.norm_cfg, pw_out)[1],
                    ]
                    if not is_last:
                        ops.append(nn.ReLU(inplace=True))
                    stack.append(nn.Sequential(*ops))
                row.append(nn.Sequential(*stack))
        rows.append(nn.ModuleList(row))
    return nn.ModuleList(rows)
def forward(self, x):
    """Forward function.

    Args:
        x (list): One feature per branch. For 'NAIVE' modules the list is
            updated in place with the per-branch results.

    Returns:
        list: Output features. A single-branch module returns a
        one-element list; otherwise the length depends on
        ``self.multiscale_output``.
    """
    if self.num_branches == 1:
        return [self.layers[0](x[0])]

    kind = self.module_type.upper()
    if kind == 'LITE':
        out = self.layers(x)
    elif kind == 'NAIVE':
        for idx in range(self.num_branches):
            x[idx] = self.layers[idx](x[idx])
        out = x

    if self.with_fuse:
        fused = []
        for i in range(len(self.fuse_layers)):
            row = self.fuse_layers[i]
            # `y = 0` will lead to decreased accuracy (0.5~1 mAP)
            # The accumulator is therefore seeded from branch 0, and the
            # j == 0 term is then added again by the loop below. For
            # i == 0 the seed is `out[0]` itself, not a copy, so the
            # augmented assignment acts on that same object.
            if i == 0:
                y = out[0]
            else:
                y = row[0](out[0])
            for j in range(self.num_branches):
                y += out[j] if i == j else row[j](out[j])
            fused.append(self.relu(y))
        out = fused

    if not self.multiscale_output:
        out = [out[0]]
    return out
@BACKBONES.register_module()
class LiteHRNet(nn.Module):
    """Lite-HRNet backbone.

    `Lite-HRNet: A Lightweight High-Resolution Network
    <https://arxiv.org/abs/2104.06403>`_.

    Code adapted from 'https://github.com/HRNet/Lite-HRNet'.

    Args:
        extra (dict): detailed configuration for each stage of HRNet.
            The keys read here are ``stem`` (``stem_channels``,
            ``out_channels``, ``expand_ratio``), ``num_stages``,
            ``stages_spec`` and ``with_head``.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.

    Example:
        >>> from mmpose.models import LiteHRNet
        >>> import torch
        >>> extra=dict(
        >>>    stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
        >>>    num_stages=3,
        >>>    stages_spec=dict(
        >>>        num_modules=(2, 4, 2),
        >>>        num_branches=(2, 3, 4),
        >>>        num_blocks=(2, 2, 2),
        >>>        module_type=('LITE', 'LITE', 'LITE'),
        >>>        with_fuse=(True, True, True),
        >>>        reduce_ratios=(8, 8, 8),
        >>>        num_channels=(
        >>>            (40, 80),
        >>>            (40, 80, 160),
        >>>            (40, 80, 160, 320),
        >>>        )),
        >>>    with_head=False)
        >>> self = LiteHRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 40, 8, 8)
    """

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=False,
                 with_cp=False):
        super().__init__()
        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.stem = Stem(
            in_channels,
            stem_channels=self.extra['stem']['stem_channels'],
            out_channels=self.extra['stem']['out_channels'],
            expand_ratio=self.extra['stem']['expand_ratio'],
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)

        self.num_stages = self.extra['num_stages']
        self.stages_spec = self.extra['stages_spec']

        # Channel counts produced by the previous stage (initially the stem).
        num_channels_last = [
            self.stem.out_channels,
        ]
        for i in range(self.num_stages):
            # Plain list copy of this stage's per-branch channel counts.
            num_channels = list(self.stages_spec['num_channels'][i])
            setattr(
                self, f'transition{i}',
                self._make_transition_layer(num_channels_last, num_channels))

            stage, num_channels_last = self._make_stage(
                self.stages_spec, i, num_channels, multiscale_output=True)
            setattr(self, f'stage{i}', stage)

        self.with_head = self.extra['with_head']
        if self.with_head:
            self.head_layer = IterativeHead(
                in_channels=num_channels_last,
                norm_cfg=self.norm_cfg,
            )

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Make transition layer.

        For a branch that already exists in the previous stage, a
        stride-1 depthwise-separable block is built only when the channel
        count changes; otherwise the entry is ``None``. Every new branch
        gets a chain of stride-2 depthwise-separable blocks that starts
        from the last previous branch; only the final block of the chain
        changes the channel count.

        Args:
            num_channels_pre_layer (list(int)): Channels of the previous
                stage's branches.
            num_channels_cur_layer (list(int)): Channels of the current
                stage's branches.

        Returns:
            nn.ModuleList: One entry (module or ``None``) per current branch.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_pre_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                groups=num_channels_pre_layer[i],
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_pre_layer[i])[1],
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU()))
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                in_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                groups=in_channels,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels)[1],
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU()))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_stage(self,
                    stages_spec,
                    stage_index,
                    in_channels,
                    multiscale_output=True):
        """Build the ``LiteHRModule`` sequence of one stage.

        Args:
            stages_spec (dict): Per-stage settings; entry ``stage_index`` of
                ``num_modules``, ``num_branches``, ``num_blocks``,
                ``reduce_ratios``, ``with_fuse`` and ``module_type`` is used.
            stage_index (int): Index of the stage to build.
            in_channels (list(int)): Channels of the stage's branches.
            multiscale_output (bool): When False, the last module of the
                stage is built with ``multiscale_output=False``.
                Default: True.

        Returns:
            tuple: ``(nn.Sequential of modules, in_channels of last module)``.
        """
        num_modules = stages_spec['num_modules'][stage_index]
        num_branches = stages_spec['num_branches'][stage_index]
        num_blocks = stages_spec['num_blocks'][stage_index]
        reduce_ratio = stages_spec['reduce_ratios'][stage_index]
        with_fuse = stages_spec['with_fuse'][stage_index]
        module_type = stages_spec['module_type'][stage_index]

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module
            is_last_module = i == num_modules - 1
            reset_multiscale_output = not (not multiscale_output
                                           and is_last_module)

            modules.append(
                LiteHRModule(
                    num_branches,
                    num_blocks,
                    in_channels,
                    reduce_ratio,
                    module_type,
                    multiscale_output=reset_multiscale_output,
                    with_fuse=with_fuse,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp))
            in_channels = modules[-1].in_channels

        return nn.Sequential(*modules), in_channels

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function.

        Returns:
            list: A one-element list holding the first (index 0) output
            feature of the last stage, or of the head when ``with_head``
            is set.
        """
        x = self.stem(x)
        y_list = [x]
        for i in range(self.num_stages):
            x_list = []
            transition = getattr(self, f'transition{i}')
            for j in range(self.stages_spec['num_branches'][i]):
                # `None` marks "no transition needed". Test identity rather
                # than truthiness: nn.Sequential defines __len__, so a bare
                # `if module:` silently depends on the container's length.
                if transition[j] is not None:
                    # New branches are derived from the last existing one.
                    source = y_list[-1] if j >= len(y_list) else y_list[j]
                    x_list.append(transition[j](source))
                else:
                    x_list.append(y_list[j])
            y_list = getattr(self, f'stage{i}')(x_list)

        x = y_list
        if self.with_head:
            x = self.head_layer(x)

        return [x[0]]

    def train(self, mode=True):
        """Convert the model into training mode.

        When ``norm_eval`` is set, batch-norm layers are switched back to
        eval mode after entering training mode.
        """
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mobilenet_v2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .utils import load_checkpoint, make_divisible
class InvertedResidual(nn.Module):
    """InvertedResidual block for MobileNetV2.

    The block is an optional 1x1 expansion conv (skipped when
    ``expand_ratio == 1``), a 3x3 depthwise conv carrying the stride, and a
    1x1 projection conv without activation. The input is added to the
    result when ``stride == 1`` and ``in_channels == out_channels``.

    Args:
        in_channels (int): The input channels of the InvertedResidual block.
        out_channels (int): The output channels of the InvertedResidual block.
        stride (int): Stride of the 3x3 depthwise convolution; 1 or 2.
        expand_ratio (int): The hidden layer has
            ``int(round(in_channels * expand_ratio))`` channels.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 expand_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 with_cp=False):
        # Work on private copies so the shared default dicts stay untouched.
        norm_cfg = copy.deepcopy(norm_cfg)
        act_cfg = copy.deepcopy(act_cfg)
        super().__init__()
        self.stride = stride
        assert stride in [1, 2], f'stride must in [1, 2]. ' \
            f'But received {stride}.'
        self.with_cp = with_cp
        self.use_res_connect = self.stride == 1 and in_channels == out_channels
        hidden_dim = int(round(in_channels * expand_ratio))

        shared = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)
        stages = []
        if expand_ratio != 1:
            # 1x1 expansion to the hidden width.
            stages.append(
                ConvModule(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    kernel_size=1,
                    act_cfg=act_cfg,
                    **shared))
        # 3x3 depthwise conv (groups == channels) carrying the stride.
        stages.append(
            ConvModule(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=hidden_dim,
                act_cfg=act_cfg,
                **shared))
        # 1x1 projection without activation.
        stages.append(
            ConvModule(
                in_channels=hidden_dim,
                out_channels=out_channels,
                kernel_size=1,
                act_cfg=None,
                **shared))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block, through ``cp.checkpoint`` when ``with_cp`` is
        set and ``x`` requires grad."""

        def _run(inp):
            if self.use_res_connect:
                return inp + self.conv(inp)
            return self.conv(inp)

        if not (self.with_cp and x.requires_grad):
            return _run(x)
        return cp.checkpoint(_run, x)
@BACKBONES.register_module()
class MobileNetV2(BaseBackbone):
    """MobileNetV2 backbone.

    The network is a stride-2 3x3 stem conv (``conv1``), seven stacks of
    ``InvertedResidual`` blocks (``layer1`` .. ``layer7``) and a final 1x1
    conv (``conv2``). ``out_indices`` indexes this sequence of eight
    stages, so index 7 is ``conv2``.

    Args:
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        out_indices (None or Sequence[int]): Output from which stages.
            Default: (7, ).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    # One row per stage, in the order:
    # expand_ratio, channel, num_blocks, stride.
    arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
                     [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
                     [6, 320, 1, 1]]

    def __init__(self,
                 widen_factor=1.,
                 out_indices=(7, ),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 norm_eval=False,
                 with_cp=False):
        # Private copies keep the shared default dicts untouched.
        norm_cfg = copy.deepcopy(norm_cfg)
        act_cfg = copy.deepcopy(act_cfg)
        super().__init__()

        self.widen_factor = widen_factor
        self.out_indices = out_indices
        for index in out_indices:
            if index not in range(0, 8):
                raise ValueError('the item in out_indices must in '
                                 f'range(0, 8). But received {index}')
        if frozen_stages not in range(-1, 8):
            raise ValueError('frozen_stages must be in range(-1, 8). '
                             f'But received {frozen_stages}')

        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        cfgs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # `self.in_channels` tracks the running channel count while the
        # network is assembled; `make_layer` advances it.
        self.in_channels = make_divisible(32 * widen_factor, 8)
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            **cfgs)

        # Attribute names of the stages, in execution order.
        self.layers = []
        for stage_idx, (expand_ratio, channel, num_blocks,
                        stride) in enumerate(self.arch_settings):
            stage = self.make_layer(
                out_channels=make_divisible(channel * widen_factor, 8),
                num_blocks=num_blocks,
                stride=stride,
                expand_ratio=expand_ratio)
            stage_name = f'layer{stage_idx + 1}'
            self.add_module(stage_name, stage)
            self.layers.append(stage_name)

        # The last conv is only widened for multipliers above 1.
        self.out_channel = int(1280 * widen_factor) \
            if widen_factor > 1.0 else 1280

        self.add_module(
            'conv2',
            ConvModule(
                in_channels=self.in_channels,
                out_channels=self.out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
                **cfgs))
        self.layers.append('conv2')

    def make_layer(self, out_channels, num_blocks, stride, expand_ratio):
        """Stack InvertedResidual blocks to build a layer for MobileNetV2.

        Only the first block uses ``stride``; the others use stride 1.
        ``self.in_channels`` is set to ``out_channels`` after each block.

        Args:
            out_channels (int): out_channels of block.
            num_blocks (int): number of blocks.
            stride (int): stride of the first block.
            expand_ratio (int): Expand the number of channels of the
                hidden layer in InvertedResidual by this ratio.

        Returns:
            nn.Sequential: The stacked blocks.
        """
        blocks = []
        for block_idx in range(num_blocks):
            block_stride = stride if block_idx == 0 else 1
            blocks.append(
                InvertedResidual(
                    self.in_channels,
                    out_channels,
                    block_stride,
                    expand_ratio=expand_ratio,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.in_channels = out_channels

        return nn.Sequential(*blocks)

    def init_weights(self, pretrained=None):
        """Load a checkpoint, or initialise conv and norm layers.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
            return
        if pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)

    def forward(self, x):
        """Run the stem and all stages, collecting the requested outputs.

        Returns:
            The single collected feature when exactly one was selected,
            otherwise a tuple of the collected features.
        """
        x = self.conv1(x)

        collected = []
        for stage_idx, stage_name in enumerate(self.layers):
            x = getattr(self, stage_name)(x)
            if stage_idx in self.out_indices:
                collected.append(x)

        return collected[0] if len(collected) == 1 else tuple(collected)

    def _freeze_stages(self):
        """Disable gradients of the stem and of the first
        ``frozen_stages`` layers; the frozen layers are also put in eval
        mode."""
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
        for stage_no in range(1, self.frozen_stages + 1):
            frozen = getattr(self, f'layer{stage_no}')
            frozen.eval()
            for param in frozen.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch mode, then re-apply stage freezing and ``norm_eval``."""
        super().train(mode)
        self._freeze_stages()
        if not (mode and self.norm_eval):
            return
        for m in self.modules():
            if isinstance(m, _BatchNorm):
                m.eval()
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mobilenet_v3.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .utils import InvertedResidual, load_checkpoint
@BACKBONES.register_module()
class MobileNetV3(BaseBackbone):
    """MobileNetV3 backbone.

    Args:
        arch (str): Architecture of mobilenetv3, from {small, big}.
            Default: small.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        out_indices (None or Sequence[int]): Output from which stages.
            Negative values index from the end.
            Default: (-1, ), which means output tensors from final stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """
    # Parameters to build each block:
    #     [kernel size, mid channels, out channels, with_se, act type, stride]
    arch_settings = {
        'small': [[3, 16, 16, True, 'ReLU', 2],
                  [3, 72, 24, False, 'ReLU', 2],
                  [3, 88, 24, False, 'ReLU', 1],
                  [5, 96, 40, True, 'HSwish', 2],
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 120, 48, True, 'HSwish', 1],
                  [5, 144, 48, True, 'HSwish', 1],
                  [5, 288, 96, True, 'HSwish', 2],
                  [5, 576, 96, True, 'HSwish', 1],
                  [5, 576, 96, True, 'HSwish', 1]],
        'big': [[3, 16, 16, False, 'ReLU', 1],
                [3, 64, 24, False, 'ReLU', 2],
                [3, 72, 24, False, 'ReLU', 1],
                [5, 72, 40, True, 'ReLU', 2],
                [5, 120, 40, True, 'ReLU', 1],
                [5, 120, 40, True, 'ReLU', 1],
                [3, 240, 80, False, 'HSwish', 2],
                [3, 200, 80, False, 'HSwish', 1],
                [3, 184, 80, False, 'HSwish', 1],
                [3, 184, 80, False, 'HSwish', 1],
                [3, 480, 112, True, 'HSwish', 1],
                [3, 672, 112, True, 'HSwish', 1],
                [5, 672, 160, True, 'HSwish', 1],
                [5, 672, 160, True, 'HSwish', 2],
                [5, 960, 160, True, 'HSwish', 1]]
    }  # yapf: disable

    def __init__(self,
                 arch='small',
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 out_indices=(-1, ),
                 frozen_stages=-1,
                 norm_eval=False,
                 with_cp=False):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        assert arch in self.arch_settings
        num_layers = len(self.arch_settings[arch])
        for index in out_indices:
            if index not in range(-num_layers, num_layers):
                # The message states the range that is actually accepted,
                # negative (from-the-end) indices included.
                raise ValueError('the item in out_indices must in '
                                 f'range(-{num_layers}, {num_layers}). '
                                 f'But received {index}')

        if frozen_stages not in range(-1, num_layers):
            raise ValueError('frozen_stages must be in range(-1, '
                             f'{num_layers}). '
                             f'But received {frozen_stages}')
        self.arch = arch
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # `self.in_channels` tracks the running channel count while the
        # layers are assembled; `_make_layer` advances it.
        self.in_channels = 16
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='HSwish'))

        self.layers = self._make_layer()
        # Output channels of the last block of the chosen architecture.
        self.feat_dim = self.arch_settings[arch][-1][2]

    def _make_layer(self):
        """Build every block of the selected architecture.

        Each block is registered as ``layer{i + 1}``.

        Returns:
            list(str): The registered attribute names, in execution order.
        """
        layers = []
        layer_setting = self.arch_settings[self.arch]
        for i, params in enumerate(layer_setting):
            (kernel_size, mid_channels, out_channels, with_se, act,
             stride) = params
            if with_se:
                se_cfg = dict(
                    channels=mid_channels,
                    ratio=4,
                    act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')))
            else:
                se_cfg = None

            layer = InvertedResidual(
                in_channels=self.in_channels,
                out_channels=out_channels,
                mid_channels=mid_channels,
                kernel_size=kernel_size,
                stride=stride,
                se_cfg=se_cfg,
                with_expand_conv=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type=act),
                with_cp=self.with_cp)
            self.in_channels = out_channels
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, layer)
            layers.append(layer_name)
        return layers

    def init_weights(self, pretrained=None):
        """Load a checkpoint, or initialise conv and norm layers.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                # Match every batch-norm variant (as `train` does) and
                # GroupNorm, not just BatchNorm2d, so that a non-default
                # norm_cfg is initialised the same way.
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem and all blocks, collecting the requested outputs.

        Returns:
            The single collected feature when exactly one was selected,
            otherwise a tuple of the collected features.
        """
        x = self.conv1(x)

        outs = []
        num_layers = len(self.layers)
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            # Accept both the non-negative and the from-the-end index.
            if i in self.out_indices or i - num_layers in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]
        return tuple(outs)

    def _freeze_stages(self):
        """Disable gradients of the stem and of the first
        ``frozen_stages`` layers; the frozen layers are also put in eval
        mode."""
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
        for i in range(1, self.frozen_stages + 1):
            layer = getattr(self, f'layer{i}')
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch mode, then re-apply stage freezing and ``norm_eval``."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/mspn.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp
from collections import OrderedDict
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (ConvModule, MaxPool2d, constant_init, kaiming_init,
normal_init)
from mmcv.runner.checkpoint import load_state_dict
from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .resnet import Bottleneck as _Bottleneck
from .utils.utils import get_state_dict
class Bottleneck(_Bottleneck):
    """Bottleneck block for MSPN.

    Thin wrapper around the ResNet bottleneck that multiplies
    ``out_channels`` by 4 before delegating to the parent constructor.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block, before the x4
            multiplication applied here.
        stride (int): stride of the block. Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
    """

    # The docstring must come first in the class body; placed after this
    # assignment it would be a discarded expression, not `__doc__`.
    expansion = 4

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(in_channels, out_channels * 4, **kwargs)
class DownsampleModule(nn.Module):
    """Downsample module for MSPN.

    Builds ``num_units`` layers (``layer1`` .. ``layer{num_units}``). The
    first keeps the resolution; every later one uses stride 2 and doubles
    the base channel count.

    Args:
        block (nn.Module): Downsample block.
        num_blocks (list): Number of blocks in each downsample unit.
        num_units (int): Numbers of downsample units. Default: 4
        has_skip (bool): Have skip connections from prior upsample
            module or not. Default:False
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        in_channels (int): Number of channels of the input feature to
            downsample module. Default: 64
    """

    def __init__(self,
                 block,
                 num_blocks,
                 num_units=4,
                 has_skip=False,
                 norm_cfg=dict(type='BN'),
                 in_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.has_skip = has_skip
        # Running channel count; `_make_layer` advances it per layer.
        self.in_channels = in_channels
        assert len(num_blocks) == num_units
        self.num_blocks = num_blocks
        self.num_units = num_units
        self.norm_cfg = norm_cfg
        self.layer1 = self._make_layer(block, in_channels, num_blocks[0])
        for i in range(1, num_units):
            module_name = f'layer{i + 1}'
            self.add_module(
                module_name,
                self._make_layer(
                    block, in_channels * pow(2, i), num_blocks[i], stride=2))

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Stack ``blocks`` instances of ``block`` into one layer.

        Only the first block carries ``stride`` and the optional 1x1
        ``downsample`` projection, which is built when the stride is not 1
        or the channel count changes. Every block receives
        ``self.norm_cfg``.

        Args:
            block (type): Block class exposing an ``expansion`` attribute.
            out_channels (int): Base output channels; the layer emits
                ``out_channels * block.expansion`` channels.
            blocks (int): Number of blocks in the layer.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            nn.Sequential: The stacked blocks.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = ConvModule(
                self.in_channels,
                out_channels * block.expansion,
                kernel_size=1,
                stride=stride,
                padding=0,
                norm_cfg=self.norm_cfg,
                act_cfg=None,
                inplace=True)

        units = list()
        units.append(
            block(
                self.in_channels,
                out_channels,
                stride=stride,
                downsample=downsample,
                norm_cfg=self.norm_cfg))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            # Pass norm_cfg here too; otherwise these blocks would fall
            # back to the block's own default norm and ignore the
            # configured one.
            units.append(
                block(
                    self.in_channels, out_channels, norm_cfg=self.norm_cfg))

        return nn.Sequential(*units)

    def forward(self, x, skip1, skip2):
        """Run every unit in turn.

        Args:
            x: Input feature.
            skip1, skip2: Per-unit skip features; only indexed (and added
                to each unit's output) when ``has_skip`` is set.

        Returns:
            tuple: The unit outputs in reverse order (last unit first).
        """
        out = list()
        for i in range(self.num_units):
            module_name = f'layer{i + 1}'
            module_i = getattr(self, module_name)
            x = module_i(x)
            if self.has_skip:
                x = x + skip1[i] + skip2[i]
            out.append(x)
        out.reverse()

        return tuple(out)
class UpsampleUnit(nn.Module):
    """Upsample unit for upsample module.

    The skip-in feature goes through a 1x1 conv (``in_skip``). For every
    unit except the first (``ind > 0``), the previous unit's output is
    bilinearly resized to the skip-in's spatial size, passed through a 1x1
    conv and added. The sum then goes through a ReLU.

    Args:
        ind (int): Indicates whether to interpolate (>0) and whether to
           generate feature map for the next hourglass-like module.
        num_units (int): Number of units that form a upsample module. Along
            with ind and gen_cross_conv, num_units is used to decide whether
            to generate feature map for the next hourglass-like module.
        in_channels (int): Channel number of the skip-in feature maps from
            the corresponding downsample unit.
        unit_channels (int): Channel number in this unit. Default:256.
        gen_skip: (bool): Whether or not to generate skips for the posterior
            downsample module. Default:False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default:False
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        out_channels (int): Number of channels of feature output by upsample
            module. Must equal to in_channels of downsample module. Default:64
    """

    def __init__(self,
                 ind,
                 num_units,
                 in_channels,
                 unit_channels=256,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64):
        # Private copy keeps the shared default dict untouched.
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.num_units = num_units
        self.norm_cfg = norm_cfg

        def conv1x1(cin, cout, **extra):
            # Every conv in this unit is a 1x1 ConvModule; only the channel
            # counts and the presence of `act_cfg=None` differ.
            return ConvModule(
                cin,
                cout,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                inplace=True,
                **extra)

        self.in_skip = conv1x1(in_channels, unit_channels, act_cfg=None)
        self.relu = nn.ReLU(inplace=True)

        self.ind = ind
        if self.ind > 0:
            self.up_conv = conv1x1(unit_channels, unit_channels, act_cfg=None)

        self.gen_skip = gen_skip
        if self.gen_skip:
            self.out_skip1 = conv1x1(in_channels, in_channels)
            self.out_skip2 = conv1x1(unit_channels, in_channels)

        self.gen_cross_conv = gen_cross_conv
        if self.ind == num_units - 1 and self.gen_cross_conv:
            self.cross_conv = conv1x1(unit_channels, out_channels)

    def forward(self, x, up_x):
        """Fuse the skip-in feature with the previous unit's output.

        Args:
            x: Skip-in feature from the downsample module.
            up_x: Output of the previous upsample unit; unused when
                ``ind == 0``.

        Returns:
            tuple: ``(out, skip1, skip2, cross_conv)``. ``skip1``/``skip2``
            are ``None`` unless ``gen_skip``; ``cross_conv`` is ``None``
            unless this is the last unit and ``gen_cross_conv`` is set.
        """
        fused = self.in_skip(x)

        if self.ind > 0:
            resized = F.interpolate(
                up_x,
                size=(x.size(2), x.size(3)),
                mode='bilinear',
                align_corners=True)
            fused = fused + self.up_conv(resized)
        fused = self.relu(fused)

        skip1 = self.out_skip1(x) if self.gen_skip else None
        skip2 = self.out_skip2(fused) if self.gen_skip else None

        is_last_unit = self.ind == self.num_units - 1
        cross_conv = None
        if is_last_unit and self.gen_cross_conv:
            cross_conv = self.cross_conv(fused)

        return fused, skip1, skip2, cross_conv
class UpsampleModule(nn.Module):
    """Upsample module for MSPN.

    Holds ``num_units`` ``UpsampleUnit`` modules (``up1`` ..
    ``up{num_units}``). Unit ``i`` takes the ``i``-th feature of the input
    sequence plus the output of unit ``i - 1``.

    Args:
        unit_channels (int): Channel number in the upsample units.
            Default:256.
        num_units (int): Numbers of upsample units. Default: 4
        gen_skip (bool): Whether to generate skip for posterior downsample
            module or not. Default:False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default:False
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        out_channels (int): Number of channels of feature output by upsample
            module. Must equal to in_channels of downsample module. Default:64
    """

    def __init__(self,
                 unit_channels=256,
                 num_units=4,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        # Skip-in channel count of each unit:
        # Bottleneck.expansion * out_channels * 2**k, largest first.
        self.in_channels = [
            Bottleneck.expansion * out_channels * pow(2, i)
            for i in range(num_units)
        ]
        self.in_channels.reverse()
        self.num_units = num_units
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.norm_cfg = norm_cfg
        for i in range(num_units):
            module_name = f'up{i + 1}'
            self.add_module(
                module_name,
                UpsampleUnit(
                    i,
                    self.num_units,
                    self.in_channels[i],
                    unit_channels,
                    self.gen_skip,
                    self.gen_cross_conv,
                    norm_cfg=self.norm_cfg,
                    # Forward the configured width instead of a literal
                    # 64, so the unit's cross conv follows `out_channels`.
                    out_channels=out_channels))

    def forward(self, x):
        """Run the units in order.

        Args:
            x (Sequence): One skip-in feature per unit.

        Returns:
            tuple: ``(out, skip1, skip2, cross_conv)`` where ``out`` lists
            the unit outputs in unit order, ``skip1``/``skip2`` list the
            per-unit skips in reverse unit order, and ``cross_conv`` is the
            fourth return value of the last unit (``None`` when
            ``num_units == 1``).
        """
        out = list()
        skip1 = list()
        skip2 = list()
        cross_conv = None
        for i in range(self.num_units):
            module_i = getattr(self, f'up{i + 1}')
            if i == 0:
                # The first unit has no previous output to merge.
                outi, skip1_i, skip2_i, _ = module_i(x[i], None)
            elif i == self.num_units - 1:
                outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1])
            else:
                outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1])
            out.append(outi)
            skip1.append(skip1_i)
            skip2.append(skip2_i)
        skip1.reverse()
        skip2.reverse()

        return out, skip1, skip2, cross_conv
class SingleStageNetwork(nn.Module):
    """Single_stage Network.

    One ``DownsampleModule`` followed by one ``UpsampleModule``.

    Args:
        unit_channels (int): Channel number in the upsample units. Default:256.
        num_units (int): Numbers of downsample/upsample units. Default: 4
        gen_skip (bool): Whether to generate skip for posterior downsample
            module or not. Default:False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default:False
        has_skip (bool): Have skip connections from prior upsample
            module or not. Default:False
        num_blocks (list): Number of blocks in each downsample unit.
            Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks)
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        in_channels (int): Number of channels of the feature from ResNetTop.
            Default: 64.
    """

    def __init__(self,
                 has_skip=False,
                 gen_skip=False,
                 gen_cross_conv=False,
                 unit_channels=256,
                 num_units=4,
                 num_blocks=[2, 2, 2, 2],
                 norm_cfg=dict(type='BN'),
                 in_channels=64):
        # Private copies keep the shared default list/dict untouched.
        norm_cfg = cp.deepcopy(norm_cfg)
        num_blocks = cp.deepcopy(num_blocks)
        super().__init__()
        assert len(num_blocks) == num_units

        self.has_skip = has_skip
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.num_units = num_units
        self.unit_channels = unit_channels
        self.num_blocks = num_blocks
        self.norm_cfg = norm_cfg

        self.downsample = DownsampleModule(
            block=Bottleneck,
            num_blocks=num_blocks,
            num_units=num_units,
            has_skip=has_skip,
            norm_cfg=norm_cfg,
            in_channels=in_channels)
        # The upsample output width is tied to the downsample input width.
        self.upsample = UpsampleModule(
            unit_channels=unit_channels,
            num_units=num_units,
            gen_skip=gen_skip,
            gen_cross_conv=gen_cross_conv,
            norm_cfg=norm_cfg,
            out_channels=in_channels)

    def forward(self, x, skip1, skip2):
        """Downsample ``x`` (with the given skips), then upsample.

        Returns:
            tuple: ``(out, skip1, skip2, cross_conv)`` exactly as produced
            by the upsample module.
        """
        pyramid = self.downsample(x, skip1, skip2)
        out, next_skip1, next_skip2, cross_conv = self.upsample(pyramid)
        return out, next_skip1, next_skip2, cross_conv
class ResNetTop(nn.Module):
    """ResNet top for MSPN.

    A 7x7 stride-2 ``ConvModule`` on a 3-channel input, followed by a 3x3
    stride-2 max pooling.

    Args:
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        channels (int): Number of channels of the feature output by ResNetTop.
    """

    def __init__(self, norm_cfg=dict(type='BN'), channels=64):
        # Private copy keeps the shared default dict untouched.
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        stem_conv = ConvModule(
            3,
            channels,
            kernel_size=7,
            stride=2,
            padding=3,
            norm_cfg=norm_cfg,
            inplace=True)
        stem_pool = MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.top = nn.Sequential(stem_conv, stem_pool)

    def forward(self, img):
        """Apply the stem conv and pooling to ``img``."""
        return self.top(img)
@BACKBONES.register_module()
class MSPN(BaseBackbone):
    """MSPN backbone. Paper ref: Li et al. "Rethinking on Multi-Stage Networks
    for Human Pose Estimation" (CVPR 2020).

    Args:
        unit_channels (int): Number of Channels in an upsample unit.
            Default: 256
        num_stages (int): Number of stages in a multi-stage MSPN. Default: 4
        num_units (int): Number of downsample/upsample units in a single-stage
            network. Default: 4
            Note: Make sure num_units == len(self.num_blocks)
        num_blocks (list): Number of bottlenecks in each
            downsample unit. Default: [2, 2, 2, 2]
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        res_top_channels (int): Number of channels of feature from ResNetTop.
            The same value is used as the input width of every stage.
            Default: 64.

    Example:
        >>> from mmpose.models import MSPN
        >>> import torch
        >>> self = MSPN(num_stages=2,num_units=2,num_blocks=[2,2])
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     for feature in level_output:
        ...         print(tuple(feature.shape))
        ...
        (1, 256, 64, 64)
        (1, 256, 128, 128)
        (1, 256, 64, 64)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 unit_channels=256,
                 num_stages=4,
                 num_units=4,
                 num_blocks=[2, 2, 2, 2],
                 norm_cfg=dict(type='BN'),
                 res_top_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        num_blocks = cp.deepcopy(num_blocks)
        super().__init__()
        self.unit_channels = unit_channels
        self.num_stages = num_stages
        self.num_units = num_units
        self.num_blocks = num_blocks
        self.norm_cfg = norm_cfg

        assert self.num_stages > 0
        assert self.num_units > 1
        assert self.num_units == len(self.num_blocks)

        # The stem has to emit exactly the channel count that every stage is
        # built to consume. Previously the stem always used its own default
        # (64), so any other ``res_top_channels`` produced a channel mismatch
        # at the first downsample layer.
        self.top = ResNetTop(norm_cfg=norm_cfg, channels=res_top_channels)

        self.multi_stage_mspn = nn.ModuleList([])
        last_stage = self.num_stages - 1
        for i in range(self.num_stages):
            # Every stage but the first receives skips from its predecessor.
            has_skip = i != 0
            # Every stage but the last produces skips / cross-conv features
            # for its successor.
            feeds_next_stage = i != last_stage
            gen_skip = feeds_next_stage
            gen_cross_conv = feeds_next_stage
            self.multi_stage_mspn.append(
                SingleStageNetwork(has_skip, gen_skip, gen_cross_conv,
                                   unit_channels, num_units, num_blocks,
                                   norm_cfg, res_top_channels))

    def forward(self, x):
        """Model forward function.

        Args:
            x: Input passed to the stem ``self.top``.

        Returns:
            list: One entry per stage, holding the first value returned by
            that stage.
        """
        out_feats = []
        skip1 = None
        skip2 = None
        x = self.top(x)
        for i in range(self.num_stages):
            # The fourth value returned by a stage becomes the next stage's
            # input; the skips are threaded through as well.
            out, skip1, skip2, x = self.multi_stage_mspn[i](x, skip1, skip2)
            out_feats.append(out)

        return out_feats

    def init_weights(self, pretrained=None):
        """Initialize model weights.

        Args:
            pretrained (str, optional): If a string, it is passed to
                ``get_state_dict`` and the result is remapped and loaded
                (non-strictly) into the stem and into the downsample module
                of every stage. Otherwise the layers are initialized in
                place. Default: None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            state_dict_tmp = get_state_dict(pretrained)
            state_dict = OrderedDict()
            state_dict['top'] = OrderedDict()
            state_dict['bottlenecks'] = OrderedDict()
            for k, v in state_dict_tmp.items():
                if k.startswith('layer'):
                    # Rename the identity-branch entries of the checkpoint.
                    if 'downsample.0' in k:
                        state_dict['bottlenecks'][k.replace(
                            'downsample.0', 'downsample.conv')] = v
                    elif 'downsample.1' in k:
                        state_dict['bottlenecks'][k.replace(
                            'downsample.1', 'downsample.bn')] = v
                    else:
                        state_dict['bottlenecks'][k] = v
                elif k.startswith('conv1'):
                    state_dict['top'][k.replace('conv1', 'top.0.conv')] = v
                elif k.startswith('bn1'):
                    state_dict['top'][k.replace('bn1', 'top.0.bn')] = v

            load_state_dict(
                self.top, state_dict['top'], strict=False, logger=logger)
            # The same bottleneck weights are loaded into every stage.
            for i in range(self.num_stages):
                load_state_dict(
                    self.multi_stage_mspn[i].downsample,
                    state_dict['bottlenecks'],
                    strict=False,
                    logger=logger)
        else:
            for m in self.multi_stage_mspn.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)

            for m in self.top.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/regnet.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import numpy as np
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from .resnet import ResNet
from .resnext import Bottleneck
@BACKBONES.register_module()
class RegNet(ResNet):
    """RegNet backbone.

    See the paper "Designing Network Design Spaces" for details.

    Args:
        arch (dict | str): The parameter of RegNets, or the name of one of
            the entries of ``arch_settings``.

            - w0 (int): initial width
            - wa (float): slope of width
            - wm (float): quantization parameter to quantize the width
            - depth (int): depth of the backbone
            - group_w (int): width of group
            - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck.
        strides (Sequence[int]): Strides of the first block of each stage.
        base_channels (int): Base channels after stem layer.
        in_channels (int): Number of input image channels. Default: 3.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer. Default: "pytorch".
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters. Default: -1.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: True.

    Example:
        >>> from mmpose.models import RegNet
        >>> import torch
        >>> self = RegNet(
        ...     arch=dict(
        ...         w0=88,
        ...         wa=26.31,
        ...         wm=2.25,
        ...         group_w=48,
        ...         depth=25,
        ...         bot_mul=1.0),
        ...     out_indices=(0, 1, 2, 3))
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 96, 8, 8)
        (1, 192, 4, 4)
        (1, 432, 2, 2)
        (1, 1008, 1, 1)
    """

    # Named presets accepted for ``arch``.
    arch_settings = {
        'regnetx_400mf':
        dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0),
        'regnetx_800mf':
        dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0),
        'regnetx_1.6gf':
        dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0),
        'regnetx_3.2gf':
        dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0),
        'regnetx_4.0gf':
        dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0),
        'regnetx_6.4gf':
        dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0),
        'regnetx_8.0gf':
        dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0),
        'regnetx_12gf':
        dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0),
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 stem_channels=32,
                 base_channels=32,
                 strides=(2, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(3, ),
                 style='pytorch',
                 deep_stem=False,
                 avg_down=False,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=True):
        # Never hand out the shared default dict.
        norm_cfg = copy.deepcopy(norm_cfg)
        # Calls the initializer that follows ResNet in the MRO, i.e.
        # ResNet.__init__ itself is bypassed and the network is built below.
        super(ResNet, self).__init__()

        # Resolve a preset name into its parameter dict.
        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'"arch": "{arch}" is not one of the arch_settings'
            arch = self.arch_settings[arch]
        elif not isinstance(arch, dict):
            raise TypeError(
                f'Expect "arch" to be either a string or a dict, got {type(arch)}'
            )

        # Per-block widths -> per-stage widths / block counts.
        widths, num_stages = self.generate_regnet(
            arch['w0'], arch['wa'], arch['wm'], arch['depth'])
        stage_widths, stage_blocks = self.get_stages_from_blocks(widths)

        # One group width and one bottleneck ratio per stage.
        group_widths = [arch['group_w']] * num_stages
        self.bottleneck_ratio = [arch['bot_mul']] * num_stages

        # Make the stage widths compatible with the group widths.
        stage_widths, group_widths = self.adjust_width_group(
            stage_widths, self.bottleneck_ratio, group_widths)

        self.stage_widths = stage_widths
        self.group_widths = group_widths
        self.depth = sum(stage_blocks)
        self.stem_channels = stem_channels
        self.base_channels = base_channels
        self.num_stages = num_stages
        assert 1 <= num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.deep_stem = deep_stem
        if self.deep_stem:
            raise NotImplementedError(
                'deep_stem has not been implemented for RegNet')
        self.avg_down = avg_down
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.zero_init_residual = zero_init_residual
        self.stage_blocks = stage_blocks[:num_stages]

        self._make_stem_layer(in_channels, stem_channels)

        # Build the residual stages; each consumes the previous stage's width.
        _in_channels = stem_channels
        self.res_layers = []
        for idx, num_blocks in enumerate(self.stage_blocks):
            out_width = self.stage_widths[idx]
            group_width = self.group_widths[idx]
            mid_width = int(round(out_width * self.bottleneck_ratio[idx]))

            res_layer = self.make_res_layer(
                block=Bottleneck,
                num_blocks=num_blocks,
                in_channels=_in_channels,
                out_channels=out_width,
                expansion=1,
                stride=self.strides[idx],
                dilation=self.dilations[idx],
                style=self.style,
                avg_down=self.avg_down,
                with_cp=self.with_cp,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                base_channels=out_width,
                groups=mid_width // group_width,
                width_per_group=group_width)
            _in_channels = out_width

            layer_name = f'layer{idx + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = stage_widths[-1]

    def _make_stem_layer(self, in_channels, base_channels):
        """Create the stem: a 3x3 stride-2 conv, a norm layer and a ReLU."""
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            base_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, base_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def generate_regnet(initial_width,
                        width_slope,
                        width_parameter,
                        depth,
                        divisor=8):
        """Generates per block width from RegNet parameters.

        Args:
            initial_width ([int]): Initial width of the backbone
            width_slope ([float]): Slope of the quantized linear function
            width_parameter ([int]): Parameter used to quantize the width.
            depth ([int]): Depth of the backbone.
            divisor (int, optional): The divisor of channels. Defaults to 8.

        Returns:
            list, int: the width of every block and the number of distinct
                widths (i.e. the number of stages).
        """
        assert width_slope >= 0
        assert initial_width > 0
        assert width_parameter > 1
        assert initial_width % divisor == 0

        # Continuous (linear) widths, one per block.
        continuous = np.arange(depth) * width_slope + initial_width
        # Snap every width to a power of ``width_parameter`` ...
        exponents = np.round(
            np.log(continuous / initial_width) / np.log(width_parameter))
        quantized = initial_width * np.power(width_parameter, exponents)
        # ... and then to a multiple of ``divisor``.
        quantized = np.round(quantized / divisor) * divisor

        num_stages = len(np.unique(quantized))
        return quantized.astype(int).tolist(), num_stages

    @staticmethod
    def quantize_float(number, divisor):
        """Rounds a number to the closest multiple of ``divisor``.

        Args:
            number (int): Original number to be quantized.
            divisor (int): Divisor used to quantize the number.

        Returns:
            int: quantized number that is divisible by divisor.
        """
        multiples = round(number / divisor)
        return int(multiples * divisor)

    def adjust_width_group(self, widths, bottleneck_ratio, groups):
        """Adjusts the compatibility of widths and groups.

        Args:
            widths (list[int]): Width of each stage.
            bottleneck_ratio (list[float]): Bottleneck ratio of each stage.
            groups (list[int]): Group width of each stage.

        Returns:
            tuple(list): The adjusted widths and groups of each stage.
        """
        adjusted_widths = []
        adjusted_groups = []
        for width, ratio, group in zip(widths, bottleneck_ratio, groups):
            bottleneck = int(width * ratio)
            # A group can never be wider than the bottleneck it splits.
            group = min(group, bottleneck)
            bottleneck = self.quantize_float(bottleneck, group)
            adjusted_widths.append(int(bottleneck / ratio))
            adjusted_groups.append(group)
        return adjusted_widths, adjusted_groups

    def get_stages_from_blocks(self, widths):
        """Gets widths/stage_blocks of network at each stage.

        Args:
            widths (list[int]): Width of each block.

        Returns:
            tuple(list): width and depth of each stage
        """
        # changed[k] is True when block k starts a new run of equal widths;
        # the extra trailing entry closes the last run.
        changed = [
            cur != prev for cur, prev in zip(widths + [0], [0] + widths)
        ]
        stage_widths = [w for w, flag in zip(widths, changed) if flag]
        boundaries = [pos for pos, flag in enumerate(changed) if flag]
        stage_blocks = np.diff(boundaries).tolist()
        return stage_widths, stage_blocks

    def forward(self, x):
        """Run the stem and all stages.

        Returns the single selected feature map when exactly one stage is
        selected by ``out_indices``, otherwise a tuple of feature maps.
        """
        x = self.relu(self.norm1(self.conv1(x)))

        outs = []
        for idx, layer_name in enumerate(self.res_layers):
            x = getattr(self, layer_name)(x)
            if idx in self.out_indices:
                outs.append(x)

        return outs[0] if len(outs) == 1 else tuple(outs)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/resnest.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResLayer, ResNetV1d
class RSoftmax(nn.Module):
    """Radix Softmax module in ``SplitAttentionConv2d``.

    With ``radix > 1`` the input is viewed as
    ``(batch, groups, radix, -1)``, soft-maxed across the radix axis and
    flattened to ``(batch, -1)``. Otherwise a plain sigmoid is applied and
    the input shape is kept.

    Args:
        radix (int): Radix of input.
        groups (int): Groups of input.
    """

    def __init__(self, radix, groups):
        super().__init__()
        self.radix = radix
        self.groups = groups

    def forward(self, x):
        """Normalize the attention logits ``x``."""
        batch = x.size(0)
        if self.radix <= 1:
            return torch.sigmoid(x)
        # Bring the radix axis to dim 1 so the softmax competes across splits.
        per_radix = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
        return F.softmax(per_radix, dim=1).reshape(batch, -1)
class SplitAttentionConv2d(nn.Module):
    """Split-Attention Conv2d.

    Args:
        in_channels (int): Same as nn.Conv2d.
        channels (int): Number of output channels; the internal conv
            produces ``channels * radix`` channels.
        kernel_size (int | tuple[int]): Same as nn.Conv2d.
        stride (int | tuple[int]): Same as nn.Conv2d.
        padding (int | tuple[int]): Same as nn.Conv2d.
        dilation (int | tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        radix (int): Radix of SplitAttentionConv2d. Default: 2
        reduction_factor (int): Reduction factor of SplitAttentionConv2d.
            Default: 4.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self,
                 in_channels,
                 channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 radix=2,
                 reduction_factor=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN')):
        super().__init__()
        self.radix = radix
        self.groups = groups
        self.channels = channels

        # Width of the attention bottleneck, never below 32.
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        expanded_channels = channels * radix

        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            expanded_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups * radix,
            bias=False)
        self.norm0_name, norm0 = build_norm_layer(
            norm_cfg, expanded_channels, postfix=0)
        self.add_module(self.norm0_name, norm0)
        self.relu = nn.ReLU(inplace=True)

        # Two 1x1 convs compute the attention logits.
        self.fc1 = build_conv_layer(
            None, channels, inter_channels, 1, groups=self.groups)
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, inter_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.fc2 = build_conv_layer(
            None, inter_channels, expanded_channels, 1, groups=self.groups)
        self.rsoftmax = RSoftmax(radix, groups)

    @property
    def norm0(self):
        """nn.Module: the norm layer registered under ``norm0_name``."""
        return getattr(self, self.norm0_name)

    @property
    def norm1(self):
        """nn.Module: the norm layer registered under ``norm1_name``."""
        return getattr(self, self.norm1_name)

    def forward(self, x):
        """Apply the conv, then re-weight its radix splits by attention."""
        x = self.relu(self.norm0(self.conv(x)))
        batch = x.shape[0]

        if self.radix > 1:
            # Separate the radix splits and pool over their sum.
            splits = x.view(batch, self.radix, -1, *x.shape[2:])
            gap = splits.sum(dim=1)
        else:
            gap = x

        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.relu(self.norm1(self.fc1(gap)))

        atten = self.rsoftmax(self.fc2(gap)).view(batch, -1, 1, 1)

        if self.radix > 1:
            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
            out = torch.sum(attens * splits, dim=1)
        else:
            out = atten * x
        return out.contiguous()
class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeSt.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        groups (int): Groups of conv2.
        width_per_group (int): Width per group of conv2. 64x4d indicates
            ``groups=64, width_per_group=4`` and 32x8d indicates
            ``groups=32, width_per_group=8``.
        base_channels (int): Divisor used when rescaling the middle channels
            for ``groups != 1``. Default: 64.
        radix (int): Radix of SplitAttentionConv2d. Default: 2
        reduction_factor (int): Reduction factor of SplitAttentionConv2d.
            Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of convolution. Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 groups=1,
                 width_per_group=4,
                 base_channels=64,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        super().__init__(in_channels, out_channels, **kwargs)

        self.groups = groups
        self.width_per_group = width_per_group

        # For ResNet bottleneck, middle channels are determined by expansion
        # and out_channels, but for ResNeXt bottleneck, it is determined by
        # groups and width_per_group and the stage it is located in.
        if groups != 1:
            assert self.mid_channels % base_channels == 0
            self.mid_channels = (
                groups * width_per_group * self.mid_channels // base_channels)

        # Average pooling only takes over when conv2 would actually stride.
        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, self.mid_channels, postfix=1)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.out_channels, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            self.mid_channels,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)

        # When pooling handles the stride, the split-attention conv keeps 1.
        conv2_stride = 1 if self.avg_down_stride else self.conv2_stride
        self.conv2 = SplitAttentionConv2d(
            self.mid_channels,
            self.mid_channels,
            kernel_size=3,
            stride=conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            radix=radix,
            reduction_factor=reduction_factor,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)
        # forward() never calls norm2, so the attribute registered under
        # ``norm2_name`` is removed.
        delattr(self, self.norm2_name)

        if self.avg_down_stride:
            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)

        self.conv3 = build_conv_layer(
            self.conv_cfg,
            self.mid_channels,
            self.out_channels,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

    def forward(self, x):
        """Forward function; the residual body is optionally checkpointed."""

        def _residual_branch(inp):
            shortcut = inp

            out = self.relu(self.norm1(self.conv1(inp)))

            out = self.conv2(out)
            if self.avg_down_stride:
                out = self.avd_layer(out)

            out = self.norm3(self.conv3(out))

            if self.downsample is not None:
                shortcut = self.downsample(inp)

            out += shortcut
            return out

        use_checkpoint = self.with_cp and x.requires_grad
        if use_checkpoint:
            out = cp.checkpoint(_residual_branch, x)
        else:
            out = _residual_branch(x)

        return self.relu(out)
@BACKBONES.register_module()
class ResNeSt(ResNetV1d):
    """ResNeSt backbone.

    Please refer to the paper "ResNeSt: Split-Attention Networks" for
    details.

    Args:
        depth (int): Network depth, from {50, 101, 152, 200, 269}.
        groups (int): Groups of conv2 in Bottleneck. Default: 1.
        width_per_group (int): Width per group of conv2 in Bottleneck.
            Default: 4.
        radix (int): Radix of SplitAttentionConv2d. Default: 2
        reduction_factor (int): Reduction factor of SplitAttentionConv2d.
            Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: True.
    """

    # depth -> (block class, number of blocks per stage)
    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
        200: (Bottleneck, (3, 24, 36, 3)),
        269: (Bottleneck, (3, 30, 48, 8))
    }

    def __init__(self,
                 depth,
                 groups=1,
                 width_per_group=4,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        # These are assigned before the parent initializer runs because
        # make_res_layer() below reads them.
        self.groups = groups
        self.width_per_group = width_per_group
        self.radix = radix
        self.reduction_factor = reduction_factor
        self.avg_down_stride = avg_down_stride
        super().__init__(depth=depth, **kwargs)

    def make_res_layer(self, **kwargs):
        """Build a ``ResLayer`` with the ResNeSt-specific block options
        added to ``kwargs``."""
        return ResLayer(
            groups=self.groups,
            width_per_group=self.width_per_group,
            base_channels=self.base_channels,
            radix=self.radix,
            reduction_factor=self.reduction_factor,
            avg_down_stride=self.avg_down_stride,
            **kwargs)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/resnext.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResLayer, ResNet
class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeXt.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        base_channels (int): Divisor used when rescaling the middle channels
            for ``groups != 1``. Default: 64.
        groups (int): Groups of conv2. Default: 32.
        width_per_group (int): Width per group of conv2. 64x4d indicates
            ``groups=64, width_per_group=4`` and 32x8d indicates
            ``groups=32, width_per_group=8``. Default: 4.
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of convolution. Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 base_channels=64,
                 groups=32,
                 width_per_group=4,
                 **kwargs):
        super().__init__(in_channels, out_channels, **kwargs)
        self.groups = groups
        self.width_per_group = width_per_group

        # For ResNet bottleneck, middle channels are determined by expansion
        # and out_channels, but for ResNeXt bottleneck, it is determined by
        # groups and width_per_group and the stage it is located in.
        if groups != 1:
            assert self.mid_channels % base_channels == 0
            self.mid_channels = (
                groups * width_per_group * self.mid_channels // base_channels)

        mid_channels = self.mid_channels

        # The layers created by the parent are rebuilt for the new width.
        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, mid_channels, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            self.norm_cfg, mid_channels, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.out_channels, postfix=3)

        # 1x1 conv into the bottleneck.
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            mid_channels,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)

        # Grouped 3x3 conv.
        self.conv2 = build_conv_layer(
            self.conv_cfg,
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            bias=False)
        self.add_module(self.norm2_name, norm2)

        # 1x1 conv out of the bottleneck.
        self.conv3 = build_conv_layer(
            self.conv_cfg,
            mid_channels,
            self.out_channels,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)
@BACKBONES.register_module()
class ResNeXt(ResNet):
    """ResNeXt backbone.

    Please refer to the paper "Aggregated Residual Transformations for Deep
    Neural Networks" for details.

    Args:
        depth (int): Network depth, from {50, 101, 152}.
        groups (int): Groups of conv2 in Bottleneck. Default: 32.
        width_per_group (int): Width per group of conv2 in Bottleneck.
            Default: 4.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: True.

    Example:
        >>> from mmpose.models import ResNeXt
        >>> import torch
        >>> self = ResNeXt(depth=50, out_indices=(0, 1, 2, 3))
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 256, 8, 8)
        (1, 512, 4, 4)
        (1, 1024, 2, 2)
        (1, 2048, 1, 1)
    """

    # depth -> (block class, number of blocks per stage)
    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
        # Assigned before the parent initializer runs because
        # make_res_layer() below reads them.
        self.groups = groups
        self.width_per_group = width_per_group
        super().__init__(depth, **kwargs)

    def make_res_layer(self, **kwargs):
        """Build a ``ResLayer`` with the grouping options added to
        ``kwargs``."""
        return ResLayer(
            groups=self.groups,
            width_per_group=self.width_per_group,
            base_channels=self.base_channels,
            **kwargs)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/rsn.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (ConvModule, MaxPool2d, constant_init, kaiming_init,
normal_init)
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
class RSB(nn.Module):
    """Residual Steps block for RSN. Paper ref: Cai et al. "Learning Delicate
    Local Representations for Multi-Person Pose Estimation" (ECCV 2020).

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        num_steps (int): Numbers of steps in RSB
        stride (int): stride of the block. Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None.
        with_cp (bool): Stored on the instance; this block does not read it
            anywhere else. Default: False.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        expand_times (int): Times by which the in_channels are expanded.
            Default:26.
        res_top_channels (int): Number of channels of feature output by
            ResNet_top. Default:64.
    """

    expansion = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_steps=4,
                 stride=1,
                 downsample=None,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 expand_times=26,
                 res_top_channels=64):
        # Never hand out the shared default dict.
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        assert num_steps > 1

        self.in_channels = in_channels
        # Width of one branch: in_channels * expand_times // res_top_channels.
        self.branch_channels = self.in_channels * expand_times
        self.branch_channels //= res_top_channels
        self.out_channels = out_channels
        self.stride = stride
        self.downsample = downsample
        self.with_cp = with_cp
        self.norm_cfg = norm_cfg
        self.num_steps = num_steps

        total_branch_channels = self.num_steps * self.branch_channels

        # 1x1 conv that produces the channels of all branches at once.
        self.conv_bn_relu1 = ConvModule(
            self.in_channels,
            total_branch_channels,
            kernel_size=1,
            stride=self.stride,
            padding=0,
            norm_cfg=self.norm_cfg,
            inplace=False)

        # Branch ``step`` owns ``step`` 3x3 convs, named
        # conv_bn_relu2_<step>_<position>.
        for step in range(1, self.num_steps + 1):
            for position in range(1, step + 1):
                self.add_module(
                    f'conv_bn_relu2_{step}_{position}',
                    ConvModule(
                        self.branch_channels,
                        self.branch_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        inplace=False))

        # 1x1 conv (no activation) that fuses the branch outputs.
        self.conv_bn3 = ConvModule(
            total_branch_channels,
            self.out_channels * self.expansion,
            kernel_size=1,
            stride=1,
            padding=0,
            act_cfg=None,
            norm_cfg=self.norm_cfg,
            inplace=False)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Forward function."""
        identity = x
        x = self.conv_bn_relu1(x)
        chunks = torch.split(x, self.branch_channels, 1)

        branch_outputs = []
        previous_row = []
        for step in range(self.num_steps):
            row = []
            feat = chunks[step]
            for position in range(step + 1):
                # After the first conv, each conv consumes its predecessor.
                if position > 0:
                    feat = row[position - 1]
                # Add the feature at the same position of the previous
                # branch, when that branch is deep enough to have one.
                if step > position:
                    feat = feat + previous_row[position]
                conv = getattr(self,
                               f'conv_bn_relu2_{step + 1}_{position + 1}')
                row.append(conv(feat))
            # Only the last feature of each branch is kept.
            branch_outputs.append(row[step])
            previous_row = row

        out = self.conv_bn3(torch.cat(tuple(branch_outputs), 1))

        if self.downsample is not None:
            identity = self.downsample(identity)
        out = out + identity

        return self.relu(out)
class Downsample_module(nn.Module):
    """Downsample module for RSN.

    Args:
        block (nn.Module): Downsample block.
        num_blocks (list): Number of blocks in each downsample unit.
        num_steps (int): Number of steps in a block. Default:4
        num_units (int): Numbers of downsample units. Default: 4
        has_skip (bool): Have skip connections from prior upsample
            module or not. Default:False
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        in_channels (int): Number of channels of the input feature to
            downsample module. Default: 64
        expand_times (int): Times by which the in_channels are expanded.
            Default:26.
    """

    def __init__(self,
                 block,
                 num_blocks,
                 num_steps=4,
                 num_units=4,
                 has_skip=False,
                 norm_cfg=dict(type='BN'),
                 in_channels=64,
                 expand_times=26):
        # Never hand out the shared default dict.
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.has_skip = has_skip
        # NOTE: _make_layer() overwrites self.in_channels as layers are built.
        self.in_channels = in_channels
        assert len(num_blocks) == num_units
        self.num_blocks = num_blocks
        self.num_units = num_units
        self.num_steps = num_steps
        self.norm_cfg = norm_cfg

        # The first unit keeps the resolution and the input width.
        self.layer1 = self._make_layer(
            block,
            in_channels,
            num_blocks[0],
            expand_times=expand_times,
            res_top_channels=in_channels)

        # Each later unit strides by 2 and doubles the width.
        for unit in range(1, num_units):
            layer = self._make_layer(
                block,
                in_channels * 2**unit,
                num_blocks[unit],
                stride=2,
                expand_times=expand_times,
                res_top_channels=in_channels)
            self.add_module(f'layer{unit + 1}', layer)

    def _make_layer(self,
                    block,
                    out_channels,
                    blocks,
                    stride=1,
                    expand_times=26,
                    res_top_channels=64):
        """Stack ``blocks`` instances of ``block`` into one ``nn.Sequential``.

        Only the first block receives ``stride`` and the optional projection
        of the identity branch. ``self.in_channels`` is updated to the output
        width of the layer.
        """
        layer_width = out_channels * block.expansion

        # Project the identity branch when its shape would not match.
        downsample = None
        if stride != 1 or self.in_channels != layer_width:
            downsample = ConvModule(
                self.in_channels,
                layer_width,
                kernel_size=1,
                stride=stride,
                padding=0,
                norm_cfg=self.norm_cfg,
                act_cfg=None,
                inplace=True)

        units = [
            block(
                self.in_channels,
                out_channels,
                num_steps=self.num_steps,
                stride=stride,
                downsample=downsample,
                norm_cfg=self.norm_cfg,
                expand_times=expand_times,
                res_top_channels=res_top_channels)
        ]
        self.in_channels = layer_width

        for _ in range(blocks - 1):
            units.append(
                block(
                    self.in_channels,
                    out_channels,
                    num_steps=self.num_steps,
                    expand_times=expand_times,
                    res_top_channels=res_top_channels))

        return nn.Sequential(*units)

    def forward(self, x, skip1, skip2):
        """Run every unit in turn and return their outputs, last unit first.

        When ``has_skip`` is set, ``skip1[i]`` and ``skip2[i]`` are added to
        the output of unit ``i``.
        """
        unit_outputs = []
        for unit in range(self.num_units):
            x = getattr(self, f'layer{unit + 1}')(x)
            if self.has_skip:
                x = x + skip1[unit] + skip2[unit]
            unit_outputs.append(x)
        return tuple(reversed(unit_outputs))
class Upsample_unit(nn.Module):
    """Single unit of an upsample module.

    Args:
        ind (int): Index of this unit inside its module. A value > 0 enables
            interpolation of the previous unit's output; together with
            ``num_units`` and ``gen_cross_conv`` it also decides whether a
            feature map for the next hourglass-like module is produced.
        num_units (int): Number of units that form an upsample module.
        in_channels (int): Channel number of the skip-in feature maps from
            the corresponding downsample unit.
        unit_channels (int): Channel number in this unit. Default: 256.
        gen_skip (bool): Whether to generate skips for the posterior
            downsample module. Default: False
        gen_cross_conv (bool): Whether to generate a feature map for the
            next hourglass-like module. Default: False
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        out_channels (int): Number of channels of the cross-conv feature.
            Must equal in_channels of the downsample module. Default: 64
    """

    def __init__(self,
                 ind,
                 num_units,
                 in_channels,
                 unit_channels=256,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64):
        # Never mutate the shared default dict.
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()

        self.num_units = num_units
        self.norm_cfg = norm_cfg

        # Settings common to every 1x1 ConvModule created below.
        pointwise = dict(
            kernel_size=1,
            stride=1,
            padding=0,
            norm_cfg=self.norm_cfg,
            inplace=True)

        self.in_skip = ConvModule(
            in_channels, unit_channels, act_cfg=None, **pointwise)
        self.relu = nn.ReLU(inplace=True)

        self.ind = ind
        if self.ind > 0:
            self.up_conv = ConvModule(
                unit_channels, unit_channels, act_cfg=None, **pointwise)

        self.gen_skip = gen_skip
        if self.gen_skip:
            self.out_skip1 = ConvModule(in_channels, in_channels, **pointwise)
            self.out_skip2 = ConvModule(unit_channels, in_channels,
                                        **pointwise)

        self.gen_cross_conv = gen_cross_conv
        if self.ind == num_units - 1 and self.gen_cross_conv:
            self.cross_conv = ConvModule(unit_channels, out_channels,
                                         **pointwise)

    def forward(self, x, up_x):
        """Fuse the skip-in feature with the previous unit's output.

        Args:
            x: Skip-in feature from the downsample path.
            up_x: Output of the previous upsample unit; only used when
                ``ind > 0``.

        Returns:
            tuple: ``(out, skip1, skip2, cross_conv)``; the last three are
            ``None`` when the corresponding branch is disabled.
        """
        fused = self.in_skip(x)
        if self.ind > 0:
            target_hw = (x.size(2), x.size(3))
            resized = F.interpolate(
                up_x, size=target_hw, mode='bilinear', align_corners=True)
            fused = fused + self.up_conv(resized)
        fused = self.relu(fused)

        if self.gen_skip:
            skip1 = self.out_skip1(x)
            skip2 = self.out_skip2(fused)
        else:
            skip1 = skip2 = None

        is_last_unit = self.ind == self.num_units - 1
        if is_last_unit and self.gen_cross_conv:
            cross_conv = self.cross_conv(fused)
        else:
            cross_conv = None

        return fused, skip1, skip2, cross_conv
class Upsample_module(nn.Module):
    """Upsample module for RSN.

    Args:
        unit_channels (int): Channel number in the upsample units.
            Default: 256.
        num_units (int): Number of upsample units. Default: 4
        gen_skip (bool): Whether to generate skips for the posterior
            downsample module. Default: False
        gen_cross_conv (bool): Whether to generate a feature map for the
            next hourglass-like module. Default: False
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        out_channels (int): Number of channels of the feature output by the
            upsample module. Must equal in_channels of the downsample
            module. Default: 64
    """

    def __init__(self,
                 unit_channels=256,
                 num_units=4,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        # Skip-in widths, widest first: the first upsample unit consumes the
        # first (widest) entry of the input sequence.
        self.in_channels = [
            RSB.expansion * out_channels * pow(2, i)
            for i in range(num_units)
        ]
        self.in_channels.reverse()
        self.num_units = num_units
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.norm_cfg = norm_cfg
        for i in range(num_units):
            module_name = f'up{i + 1}'
            self.add_module(
                module_name,
                Upsample_unit(
                    i,
                    self.num_units,
                    self.in_channels[i],
                    unit_channels,
                    self.gen_skip,
                    self.gen_cross_conv,
                    norm_cfg=self.norm_cfg,
                    # Forward the configured width instead of a literal 64 so
                    # the cross-conv feature matches the next stage's input.
                    out_channels=out_channels))

    def forward(self, x):
        """Run the units in order, feeding each the previous unit's output.

        Args:
            x: Indexable of skip-in features, one per unit.

        Returns:
            tuple: ``(out, skip1, skip2, cross_conv)`` where ``out`` is the
            list of unit outputs in execution order, ``skip1``/``skip2`` are
            the per-unit skips in reversed order, and ``cross_conv`` is the
            value produced by the last unit (``None`` if it produces none).
        """
        out = list()
        skip1 = list()
        skip2 = list()
        cross_conv = None
        for i in range(self.num_units):
            module_i = getattr(self, f'up{i + 1}')
            prev = out[i - 1] if i > 0 else None
            outi, skip1_i, skip2_i, cross_i = module_i(x[i], prev)
            # Only the last unit's cross-conv is kept; as in the original
            # branch order, unit 0 never contributes it.
            if i > 0 and i == self.num_units - 1:
                cross_conv = cross_i
            out.append(outi)
            skip1.append(skip1_i)
            skip2.append(skip2_i)
        skip1.reverse()
        skip2.reverse()

        return out, skip1, skip2, cross_conv
class Single_stage_RSN(nn.Module):
    """One stage of a Residual Steps Network: a downsample module followed
    by an upsample module.

    Args:
        has_skip (bool): Whether skip connections from a prior upsample
            module are consumed. Default: False
        gen_skip (bool): Whether to generate skips for the posterior
            downsample module. Default: False
        gen_cross_conv (bool): Whether to generate a feature map for the
            next hourglass-like module. Default: False
        unit_channels (int): Channel number in the upsample units.
            Default: 256.
        num_units (int): Number of downsample/upsample units. Default: 4
        num_steps (int): Number of steps in RSB. Default: 4
        num_blocks (list): Number of blocks in each downsample unit.
            Default: [2, 2, 2, 2]. Its length must equal ``num_units``.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        in_channels (int): Number of channels of the feature from
            ResNet_Top. Default: 64.
        expand_times (int): Times by which the in_channels are expanded in
            RSB. Default: 26.
    """

    def __init__(self,
                 has_skip=False,
                 gen_skip=False,
                 gen_cross_conv=False,
                 unit_channels=256,
                 num_units=4,
                 num_steps=4,
                 num_blocks=[2, 2, 2, 2],
                 norm_cfg=dict(type='BN'),
                 in_channels=64,
                 expand_times=26):
        # Detach from the shared mutable defaults before storing anything.
        norm_cfg = cp.deepcopy(norm_cfg)
        num_blocks = cp.deepcopy(num_blocks)
        super().__init__()
        assert len(num_blocks) == num_units

        self.has_skip = has_skip
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.num_units = num_units
        self.num_steps = num_steps
        self.unit_channels = unit_channels
        self.num_blocks = num_blocks
        self.norm_cfg = norm_cfg

        self.downsample = Downsample_module(
            RSB,
            num_blocks,
            num_steps=num_steps,
            num_units=num_units,
            has_skip=has_skip,
            norm_cfg=norm_cfg,
            in_channels=in_channels,
            expand_times=expand_times)
        self.upsample = Upsample_module(
            unit_channels=unit_channels,
            num_units=num_units,
            gen_skip=gen_skip,
            gen_cross_conv=gen_cross_conv,
            norm_cfg=norm_cfg,
            out_channels=in_channels)

    def forward(self, x, skip1, skip2):
        """Downsample ``x`` (with optional skips), then upsample the result.

        Returns:
            tuple: Whatever the upsample module returns, i.e.
            ``(out, skip1, skip2, cross_conv)``.
        """
        encoded = self.downsample(x, skip1, skip2)
        out, new_skip1, new_skip2, cross_conv = self.upsample(encoded)
        return out, new_skip1, new_skip2, cross_conv
class ResNet_top(nn.Module):
    """Stem of the RSN: a strided 7x7 ConvModule followed by max pooling.

    Args:
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        channels (int): Number of channels of the feature output by
            ResNet_top. Default: 64
    """

    def __init__(self, norm_cfg=dict(type='BN'), channels=64):
        # Keep the shared default dict untouched.
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        # The stem takes a 3-channel input.
        stem_conv = ConvModule(
            3,
            channels,
            kernel_size=7,
            stride=2,
            padding=3,
            norm_cfg=norm_cfg,
            inplace=True)
        stem_pool = MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.top = nn.Sequential(stem_conv, stem_pool)

    def forward(self, img):
        """Apply the stem to ``img`` and return the resulting feature."""
        feature = self.top(img)
        return feature
@BACKBONES.register_module()
class RSN(BaseBackbone):
    """Residual Steps Network backbone. Paper ref: Cai et al. "Learning
    Delicate Local Representations for Multi-Person Pose Estimation" (ECCV
    2020).

    Args:
        unit_channels (int): Number of channels in an upsample unit.
            Default: 256
        num_stages (int): Number of stages in a multi-stage RSN. Default: 4
        num_units (int): Number of downsample/upsample units in a
            single-stage RSN. Default: 4
            Note: Make sure num_units == len(self.num_blocks)
        num_blocks (list): Number of RSBs (Residual Steps Block) in each
            downsample unit. Default: [2, 2, 2, 2]
        num_steps (int): Number of steps in a RSB. Default: 4
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        res_top_channels (int): Number of channels of feature from
            ResNet_top. Default: 64.
        expand_times (int): Times by which the in_channels are expanded in
            RSB. Default: 26.

    Example:
        >>> from mmpose.models import RSN
        >>> import torch
        >>> self = RSN(num_stages=2,num_units=2,num_blocks=[2,2])
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     for feature in level_output:
        ...         print(tuple(feature.shape))
        ...
        (1, 256, 64, 64)
        (1, 256, 128, 128)
        (1, 256, 64, 64)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 unit_channels=256,
                 num_stages=4,
                 num_units=4,
                 num_blocks=[2, 2, 2, 2],
                 num_steps=4,
                 norm_cfg=dict(type='BN'),
                 res_top_channels=64,
                 expand_times=26):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        num_blocks = cp.deepcopy(num_blocks)
        super().__init__()
        self.unit_channels = unit_channels
        self.num_stages = num_stages
        self.num_units = num_units
        self.num_blocks = num_blocks
        self.num_steps = num_steps
        self.norm_cfg = norm_cfg

        assert self.num_stages > 0
        assert self.num_steps > 1
        assert self.num_units > 1
        assert self.num_units == len(self.num_blocks)
        # The stem must emit exactly the width the stages are sized for;
        # without ``channels=`` it always produced 64 channels regardless of
        # ``res_top_channels``.
        self.top = ResNet_top(norm_cfg=norm_cfg, channels=res_top_channels)
        self.multi_stage_rsn = nn.ModuleList([])
        last_stage = self.num_stages - 1
        for i in range(self.num_stages):
            # Every stage but the first consumes skips from its predecessor.
            has_skip = i != 0
            # Every stage but the last feeds skips and a cross-conv feature
            # to its successor.
            gen_skip = i != last_stage
            gen_cross_conv = i != last_stage
            self.multi_stage_rsn.append(
                Single_stage_RSN(has_skip, gen_skip, gen_cross_conv,
                                 unit_channels, num_units, num_steps,
                                 num_blocks, norm_cfg, res_top_channels,
                                 expand_times))

    def forward(self, x):
        """Model forward function.

        Returns:
            list: One entry per stage, each being that stage's ``out`` value.
        """
        out_feats = []
        skip1 = None
        skip2 = None
        x = self.top(x)
        for i in range(self.num_stages):
            # The cross-conv feature of a stage becomes the next stage input.
            out, skip1, skip2, x = self.multi_stage_rsn[i](x, skip1, skip2)
            out_feats.append(out)

        return out_feats

    def init_weights(self, pretrained=None):
        """Initialize model weights.

        Args:
            pretrained: Accepted for interface compatibility; this
                implementation does not read it.
        """
        for m in self.multi_stage_rsn.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
            elif isinstance(m, nn.Linear):
                normal_init(m, std=0.01)

        for m in self.top.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/scnet.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from .resnet import Bottleneck, ResNet
class SCConv(nn.Module):
    """SCConv (Self-calibrated Convolution)

    Args:
        in_channels (int): The input channels of the SCConv.
        out_channels (int): The output channels of the SCConv. Must equal
            ``in_channels``.
        stride (int): Stride of the final 3x3 convolution (``k4``).
        pooling_r (int): Kernel size and stride of the average pooling in
            the calibration branch (``k2``).
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', momentum=0.1)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 pooling_r,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', momentum=0.1)):
        # Keep the shared default dict untouched.
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        assert in_channels == out_channels

        def _conv3x3(conv_stride):
            # Bias-free 3x3 conv that preserves the channel count.
            return build_conv_layer(
                conv_cfg,
                in_channels,
                in_channels,
                kernel_size=3,
                stride=conv_stride,
                padding=1,
                bias=False)

        # k2: pooled branch producing the calibration signal.
        self.k2 = nn.Sequential(
            nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r),
            _conv3x3(1),
            build_norm_layer(norm_cfg, in_channels)[1],
        )
        # k3: feature branch that gets gated.
        self.k3 = nn.Sequential(
            _conv3x3(1),
            build_norm_layer(norm_cfg, in_channels)[1],
        )
        # k4: output conv, the only one that applies ``stride``.
        self.k4 = nn.Sequential(
            _conv3x3(stride),
            build_norm_layer(norm_cfg, out_channels)[1],
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Forward function."""
        spatial_size = x.size()[2:]
        # Bring the pooled branch back to the input resolution, then build a
        # sigmoid gate from it and the input itself.
        calibration = F.interpolate(self.k2(x), spatial_size)
        gate = torch.sigmoid(torch.add(x, calibration))
        gated = torch.mul(self.k3(x), gate)
        return self.k4(gated)
class SCBottleneck(Bottleneck):
    """SC(Self-calibrated) Bottleneck.

    The block has two parallel branches, a plain 3x3 conv (``k1``) and a
    self-calibrated conv (``scconv``), whose outputs are concatenated and
    projected by ``conv3``.

    Args:
        in_channels (int): The input channels of the SCBottleneck block.
        out_channels (int): The output channel of the SCBottleneck block.
    """

    pooling_r = 4

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(in_channels, out_channels, **kwargs)
        # Each branch gets half of the usual bottleneck width.
        self.mid_channels = out_channels // self.expansion // 2

        # Rebuild the norm layers for the new widths.
        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, self.mid_channels, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            self.norm_cfg, self.mid_channels, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, out_channels, postfix=3)

        # Branch A: 1x1 reduce -> norm1 -> relu -> k1 (3x3 conv).
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            self.mid_channels,
            kernel_size=1,
            stride=1,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        k1_conv = build_conv_layer(
            self.conv_cfg,
            self.mid_channels,
            self.mid_channels,
            kernel_size=3,
            stride=self.stride,
            padding=1,
            bias=False)
        k1_norm = build_norm_layer(self.norm_cfg, self.mid_channels)[1]
        self.k1 = nn.Sequential(k1_conv, k1_norm, nn.ReLU(inplace=True))

        # Branch B: 1x1 reduce -> norm2 -> relu -> self-calibrated conv.
        self.conv2 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            self.mid_channels,
            kernel_size=1,
            stride=1,
            bias=False)
        self.add_module(self.norm2_name, norm2)
        self.scconv = SCConv(
            self.mid_channels,
            self.mid_channels,
            self.stride,
            self.pooling_r,
            self.conv_cfg,
            self.norm_cfg)

        # Fusion: 1x1 conv over the concatenated branches.
        self.conv3 = build_conv_layer(
            self.conv_cfg,
            self.mid_channels * 2,
            out_channels,
            kernel_size=1,
            stride=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

    def forward(self, x):
        """Forward function."""

        def _residual_branches(inp):
            shortcut = inp

            branch_a = self.relu(self.norm1(self.conv1(inp)))
            branch_a = self.k1(branch_a)

            branch_b = self.relu(self.norm2(self.conv2(inp)))
            branch_b = self.scconv(branch_b)

            merged = torch.cat([branch_a, branch_b], dim=1)
            merged = self.norm3(self.conv3(merged))

            if self.downsample is not None:
                shortcut = self.downsample(inp)

            merged += shortcut
            return merged

        use_checkpoint = self.with_cp and x.requires_grad
        if use_checkpoint:
            out = cp.checkpoint(_residual_branches, x)
        else:
            out = _residual_branches(x)

        return self.relu(out)
@BACKBONES.register_module()
class SCNet(ResNet):
    """SCNet backbone.

    Improving Convolutional Networks with Self-Calibrated Convolutions,
    Jiang-Jiang Liu, Qibin Hou, Ming-Ming Cheng, Changhu Wang, Jiashi Feng,
    IEEE CVPR, 2020.
    http://mftp.mmcheng.net/Papers/20cvprSCNet.pdf

    All keyword arguments other than ``depth`` are forwarded unchanged to
    the ``ResNet`` constructor.

    Args:
        depth (int): Depth of scnet, from {50, 101}.
        in_channels (int): Number of input image channels. Normally 3.
        base_channels (int): Number of base channels of hidden layer.
        num_stages (int): SCNet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the
            stride-two layer is the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
        zero_init_residual (bool): Whether to use zero init for last norm
            layer in resblocks to let them behave as identity.

    Example:
        >>> from mmpose.models import SCNet
        >>> import torch
        >>> self = SCNet(depth=50, out_indices=(0, 1, 2, 3))
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 256, 56, 56)
        (1, 512, 28, 28)
        (1, 1024, 14, 14)
        (1, 2048, 7, 7)
    """

    arch_settings = {
        50: (SCBottleneck, [3, 4, 6, 3]),
        101: (SCBottleneck, [3, 4, 23, 3])
    }

    def __init__(self, depth, **kwargs):
        # Reject unsupported depths before the parent builds anything.
        if depth in self.arch_settings:
            super().__init__(depth, **kwargs)
            return
        raise KeyError(f'invalid depth {depth} for SCNet')
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/seresnet.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch.utils.checkpoint as cp
from ..builder import BACKBONES
from .resnet import Bottleneck, ResLayer, ResNet
from .utils.se_layer import SELayer
class SEBottleneck(Bottleneck):
    """SEBottleneck block for SEResNet.

    A regular bottleneck whose residual branch output is passed through an
    ``SELayer`` before being added to the identity.

    Args:
        in_channels (int): The input channels of the SEBottleneck block.
        out_channels (int): The output channel of the SEBottleneck block.
        se_ratio (int): Squeeze ratio in SELayer. Default: 16
    """

    def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs):
        super().__init__(in_channels, out_channels, **kwargs)
        self.se_layer = SELayer(out_channels, ratio=se_ratio)

    def forward(self, x):
        """Forward function."""

        def _residual_branch(inp):
            shortcut = inp

            feat = self.relu(self.norm1(self.conv1(inp)))
            feat = self.relu(self.norm2(self.conv2(feat)))
            feat = self.norm3(self.conv3(feat))

            # Apply the SE layer after the last norm.
            feat = self.se_layer(feat)

            if self.downsample is not None:
                shortcut = self.downsample(inp)

            feat += shortcut
            return feat

        use_checkpoint = self.with_cp and x.requires_grad
        if use_checkpoint:
            out = cp.checkpoint(_residual_branch, x)
        else:
            out = _residual_branch(x)

        return self.relu(out)
@BACKBONES.register_module()
class SEResNet(ResNet):
    """SEResNet backbone.

    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
    details.

    Args:
        depth (int): Network depth, from {50, 101, 152}.
        se_ratio (int): Squeeze ratio in SELayer. Default: 16.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the
            stride-two layer is the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm
            layer in resblocks to let them behave as identity.
            Default: True.

    Example:
        >>> from mmpose.models import SEResNet
        >>> import torch
        >>> self = SEResNet(depth=50, out_indices=(0, 1, 2, 3))
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 256, 56, 56)
        (1, 512, 28, 28)
        (1, 1024, 14, 14)
        (1, 2048, 7, 7)
    """

    arch_settings = {
        50: (SEBottleneck, (3, 4, 6, 3)),
        101: (SEBottleneck, (3, 4, 23, 3)),
        152: (SEBottleneck, (3, 8, 36, 3))
    }

    def __init__(self, depth, se_ratio=16, **kwargs):
        supported_depths = self.arch_settings
        if depth not in supported_depths:
            raise KeyError(f'invalid depth {depth} for SEResNet')
        # Stored before calling the parent constructor; presumably the parent
        # invokes ``make_res_layer`` (which reads it) while building stages.
        self.se_ratio = se_ratio
        super().__init__(depth, **kwargs)

    def make_res_layer(self, **kwargs):
        """Build a ``ResLayer`` with this network's ``se_ratio`` added."""
        layer = ResLayer(se_ratio=self.se_ratio, **kwargs)
        return layer
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/seresnext.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from .resnet import ResLayer
from .seresnet import SEBottleneck as _SEBottleneck
from .seresnet import SEResNet
class SEBottleneck(_SEBottleneck):
    """SEBottleneck block for SEResNeXt.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        base_channels (int): Middle channels of the first stage.
            Default: 64.
        groups (int): Groups of conv2. Default: 32
        width_per_group (int): Width per group of conv2. 64x4d indicates
            ``groups=64, width_per_group=4`` and 32x8d indicates
            ``groups=32, width_per_group=8``. Default: 4
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of convolution. Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None
        se_ratio (int): Squeeze ratio in SELayer. Default: 16
        style (str): `pytorch` or `caffe`. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the
            stride-two layer is the first 1x1 conv layer.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 base_channels=64,
                 groups=32,
                 width_per_group=4,
                 se_ratio=16,
                 **kwargs):
        super().__init__(in_channels, out_channels, se_ratio, **kwargs)
        self.groups = groups
        self.width_per_group = width_per_group

        # Same rationale as ResNeXt: for a plain SEResNet bottleneck the
        # middle width follows from expansion and out_channels, whereas here
        # it is derived from groups, width_per_group and the stage.
        if groups != 1:
            assert self.mid_channels % base_channels == 0
            stage_scale = groups * width_per_group * self.mid_channels
            self.mid_channels = stage_scale // base_channels

        # Norm layers are rebuilt because mid_channels may have changed.
        self.norm1_name, first_norm = build_norm_layer(
            self.norm_cfg, self.mid_channels, postfix=1)
        self.norm2_name, second_norm = build_norm_layer(
            self.norm_cfg, self.mid_channels, postfix=2)
        self.norm3_name, third_norm = build_norm_layer(
            self.norm_cfg, self.out_channels, postfix=3)

        # 1x1 reduce.
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            self.mid_channels,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, first_norm)

        # Grouped 3x3 conv.
        self.conv2 = build_conv_layer(
            self.conv_cfg,
            self.mid_channels,
            self.mid_channels,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            bias=False)
        self.add_module(self.norm2_name, second_norm)

        # 1x1 expand.
        self.conv3 = build_conv_layer(
            self.conv_cfg,
            self.mid_channels,
            self.out_channels,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, third_norm)
@BACKBONES.register_module()
class SEResNeXt(SEResNet):
    """SEResNeXt backbone.

    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
    details.

    Args:
        depth (int): Network depth, from {50, 101, 152}.
        groups (int): Groups of conv2 in Bottleneck. Default: 32.
        width_per_group (int): Width per group of conv2 in Bottleneck.
            Default: 4.
        se_ratio (int): Squeeze ratio in SELayer. Default: 16.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the
            stride-two layer is the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm
            layer in resblocks to let them behave as identity.
            Default: True.

    Example:
        >>> from mmpose.models import SEResNeXt
        >>> import torch
        >>> self = SEResNeXt(depth=50, out_indices=(0, 1, 2, 3))
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 256, 56, 56)
        (1, 512, 28, 28)
        (1, 1024, 14, 14)
        (1, 2048, 7, 7)
    """

    arch_settings = {
        50: (SEBottleneck, (3, 4, 6, 3)),
        101: (SEBottleneck, (3, 4, 23, 3)),
        152: (SEBottleneck, (3, 8, 36, 3))
    }

    def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
        # Stored before the parent constructor runs because
        # ``make_res_layer`` reads both attributes.
        self.groups, self.width_per_group = groups, width_per_group
        super().__init__(depth, **kwargs)

    def make_res_layer(self, **kwargs):
        """Build a ``ResLayer`` with the ResNeXt grouping options added."""
        grouping = dict(
            groups=self.groups,
            width_per_group=self.width_per_group,
            base_channels=self.base_channels)
        return ResLayer(**grouping, **kwargs)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/shufflenet_v1.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
normal_init)
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .utils import channel_shuffle, load_checkpoint, make_divisible
class ShuffleUnit(nn.Module):
    """ShuffleUnit block.

    ShuffleNet unit with pointwise group convolution (GConv) and channel
    shuffle.

    Args:
        in_channels (int): The input channels of the ShuffleUnit.
        out_channels (int): The output channels of the ShuffleUnit.
        groups (int, optional): The number of groups to be used in grouped
            1x1 convolutions in each ShuffleUnit. Default: 3
        first_block (bool, optional): Whether it is the first ShuffleUnit of
            a sequential ShuffleUnits. Default: True, which means not using
            the grouped 1x1 convolution.
        combine (str, optional): The ways to combine the input and output
            branches, either 'add' or 'concat'. Default: 'add'.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 groups=3,
                 first_block=True,
                 combine='add',
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        # Work on private copies of the shared default dicts.
        norm_cfg = copy.deepcopy(norm_cfg)
        act_cfg = copy.deepcopy(act_cfg)
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first_block = first_block
        self.combine = combine
        self.groups = groups
        # Bottleneck width is derived from the requested output width before
        # any 'concat' adjustment below.
        self.bottleneck_channels = self.out_channels // 4
        self.with_cp = with_cp

        if self.combine == 'concat':
            self.depthwise_stride = 2
            self._combine_func = self._concat
            # The shortcut supplies in_channels of the final output, so the
            # conv branch only produces the remainder.
            self.out_channels -= self.in_channels
            self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        elif self.combine == 'add':
            self.depthwise_stride = 1
            self._combine_func = self._add
            assert in_channels == out_channels, (
                'in_channels must be equal to out_channels when combine '
                'is add')
        else:
            raise ValueError(f'Cannot combine tensors with {self.combine}. '
                             'Only "add" and "concat" are supported')

        if first_block:
            self.first_1x1_groups = 1
        else:
            self.first_1x1_groups = self.groups

        # Arguments shared by all three ConvModules.
        shared_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)

        self.g_conv_1x1_compress = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.bottleneck_channels,
            kernel_size=1,
            groups=self.first_1x1_groups,
            act_cfg=act_cfg,
            **shared_cfg)

        self.depthwise_conv3x3_bn = ConvModule(
            in_channels=self.bottleneck_channels,
            out_channels=self.bottleneck_channels,
            kernel_size=3,
            stride=self.depthwise_stride,
            padding=1,
            groups=self.bottleneck_channels,
            act_cfg=None,
            **shared_cfg)

        self.g_conv_1x1_expand = ConvModule(
            in_channels=self.bottleneck_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            groups=self.groups,
            act_cfg=None,
            **shared_cfg)

        self.act = build_activation_layer(act_cfg)

    @staticmethod
    def _add(x, out):
        """Residual connection: element-wise sum of the two tensors."""
        return x + out

    @staticmethod
    def _concat(x, out):
        """Concatenate the two tensors along the channel axis (dim 1)."""
        return torch.cat((x, out), 1)

    def forward(self, x):
        """Forward function."""

        def _unit_forward(inp):
            shortcut = inp

            feat = self.g_conv_1x1_compress(inp)
            feat = self.depthwise_conv3x3_bn(feat)

            if self.groups > 1:
                feat = channel_shuffle(feat, self.groups)

            feat = self.g_conv_1x1_expand(feat)

            if self.combine != 'concat':
                # 'add': activate after merging with the shortcut.
                return self.act(self._combine_func(shortcut, feat))

            # 'concat': pool the shortcut, activate only the conv branch,
            # then join them.
            shortcut = self.avgpool(shortcut)
            feat = self.act(feat)
            return self._combine_func(shortcut, feat)

        use_checkpoint = self.with_cp and x.requires_grad
        if use_checkpoint:
            return cp.checkpoint(_unit_forward, x)
        return _unit_forward(x)
@BACKBONES.register_module()
class ShuffleNetV1(BaseBackbone):
    """ShuffleNetV1 backbone.

    A stride-2 3x3 stem convolution and a max-pool are followed by three
    stages holding 4, 8 and 4 ``ShuffleUnit`` blocks respectively.

    Args:
        groups (int, optional): Number of groups used by the grouped 1x1
            convolutions in each ShuffleUnit. Must be one of 1, 2, 3, 4, 8.
            Default: 3.
        widen_factor (float, optional): Width multiplier applied to the
            channel count of every layer. Default: 1.0.
        out_indices (Sequence[int]): Indices of the stages whose outputs are
            returned; each must lie in ``range(0, 3)``. Default: (2, )
        frozen_stages (int): Number of stages to freeze (the stem is frozen
            as soon as this is >= 0). Default: -1, i.e. nothing is frozen.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        norm_eval (bool): Whether to keep batch-norm layers (and variants)
            in eval mode during training, freezing their running
            statistics. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """

    def __init__(self,
                 groups=3,
                 widen_factor=1.0,
                 out_indices=(2, ),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 with_cp=False):
        # Work on private copies so the shared default dicts stay pristine.
        norm_cfg = copy.deepcopy(norm_cfg)
        act_cfg = copy.deepcopy(act_cfg)
        super().__init__()
        self.stage_blocks = [4, 8, 4]
        self.groups = groups

        for index in out_indices:
            if index not in range(0, 3):
                raise ValueError('the item in out_indices must in '
                                 f'range(0, 3). But received {index}')

        if frozen_stages not in range(-1, 3):
            raise ValueError('frozen_stages must be in range(-1, 3). '
                             f'But received {frozen_stages}')

        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # Base output width of each stage for every supported group count.
        width_table = (
            (1, (144, 288, 576)),
            (2, (200, 400, 800)),
            (3, (240, 480, 960)),
            (4, (272, 544, 1088)),
            (8, (384, 768, 1536)),
        )
        for supported_groups, widths in width_table:
            if groups == supported_groups:
                channels = widths
                break
        else:
            raise ValueError(f'{groups} groups is not supported for 1x1 '
                             'Grouped Convolutions')

        # Scale by the width multiplier and snap to a multiple of 8.
        channels = [make_divisible(ch * widen_factor, 8) for ch in channels]

        self.in_channels = int(24 * widen_factor)

        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layers = nn.ModuleList()
        for stage_idx, num_blocks in enumerate(self.stage_blocks):
            # Only the very first stage is flagged as ``first_block``.
            stage = self.make_layer(channels[stage_idx], num_blocks,
                                    stage_idx == 0)
            self.layers.append(stage)

    def _freeze_stages(self):
        """Disable gradients for the stem and the first ``frozen_stages``
        stages, and switch those stages to eval mode."""
        if self.frozen_stages >= 0:
            for stem_param in self.conv1.parameters():
                stem_param.requires_grad = False
        for stage_idx in range(self.frozen_stages):
            stage = self.layers[stage_idx]
            stage.eval()
            for stage_param in stage.parameters():
                stage_param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load weights from ``pretrained`` or initialise them from scratch.

        Args:
            pretrained (str, optional): Checkpoint path/URI. When None the
                conv and norm layers are initialised in place.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is not None and not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None. But received '
                            f'{type(pretrained)}')

        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
            return

        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                # Stem conv gets a fixed std; other convs scale the std by
                # the second dimension of their weight tensor.
                std = 0.01 if 'conv1' in name else 1.0 / m.weight.shape[1]
                normal_init(m, mean=0, std=std)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, val=1, bias=0.0001)
                if isinstance(m, _BatchNorm) and m.running_mean is not None:
                    nn.init.constant_(m.running_mean, 0)

    def make_layer(self, out_channels, num_blocks, first_block=False):
        """Stack ShuffleUnit blocks to make a layer.

        The first unit of the stack uses the ``'concat'`` combine mode and
        every following unit uses ``'add'``. ``self.in_channels`` is updated
        to ``out_channels`` after each unit is created.

        Args:
            out_channels (int): out_channels of the block.
            num_blocks (int): Number of blocks.
            first_block (bool, optional): Forwarded as ``first_block`` to the
                first ShuffleUnit only; later units always receive False.
                Default: False.

        Returns:
            nn.Sequential: The stacked units.
        """
        units = []
        for block_idx in range(num_blocks):
            leads_stage = block_idx == 0
            unit = ShuffleUnit(
                self.in_channels,
                out_channels,
                groups=self.groups,
                first_block=first_block if leads_stage else False,
                combine='concat' if leads_stage else 'add',
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                with_cp=self.with_cp)
            units.append(unit)
            self.in_channels = out_channels

        return nn.Sequential(*units)

    def forward(self, x):
        """Return the outputs of the stages listed in ``out_indices``.

        A single tensor is returned when exactly one stage is selected,
        otherwise a tuple of tensors.
        """
        x = self.maxpool(self.conv1(x))

        collected = []
        for stage_idx, stage in enumerate(self.layers):
            x = stage(x)
            if stage_idx in self.out_indices:
                collected.append(x)

        return collected[0] if len(collected) == 1 else tuple(collected)

    def train(self, mode=True):
        """Switch train/eval mode while honouring ``frozen_stages`` and
        ``norm_eval``."""
        super().train(mode)
        self._freeze_stages()
        if not (mode and self.norm_eval):
            return
        for m in self.modules():
            if isinstance(m, _BatchNorm):
                m.eval()
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/shufflenet_v2.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .utils import channel_shuffle, load_checkpoint
class InvertedResidual(nn.Module):
    """InvertedResidual block for ShuffleNetV2 backbone.

    With ``stride > 1`` the input is fed to two convolutional branches whose
    outputs are concatenated. With ``stride == 1`` the input is split in two
    along the channel axis; one half is passed through unchanged and the
    other goes through ``branch2``. In both cases the concatenated result is
    channel-shuffled with 2 groups.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 convolution layer. Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        # Work on private copies so the shared default dicts stay pristine.
        norm_cfg = copy.deepcopy(norm_cfg)
        act_cfg = copy.deepcopy(act_cfg)
        super().__init__()
        self.stride = stride
        self.with_cp = with_cp

        branch_features = out_channels // 2
        if self.stride == 1:
            assert in_channels == branch_features * 2, (
                f'in_channels ({in_channels}) should equal to '
                f'branch_features * 2 ({branch_features * 2}) '
                'when stride is 1')

        if in_channels != branch_features * 2:
            assert self.stride != 1, (
                f'stride ({self.stride}) should not equal 1 when '
                f'in_channels != branch_features * 2')

        def _pointwise(num_in):
            # 1x1 conv + norm + activation producing ``branch_features``.
            return ConvModule(
                num_in,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

        def _depthwise(num_channels):
            # 3x3 depthwise conv + norm (no activation) carrying the stride.
            return ConvModule(
                num_channels,
                num_channels,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=num_channels,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None)

        if self.stride > 1:
            # The shortcut branch only exists for down-sampling blocks.
            self.branch1 = nn.Sequential(
                _depthwise(in_channels),
                _pointwise(in_channels),
            )

        branch2_in = in_channels if (self.stride > 1) else branch_features
        self.branch2 = nn.Sequential(
            _pointwise(branch2_in),
            _depthwise(branch_features),
            _pointwise(branch_features))

    def forward(self, x):
        """Apply the block, optionally under gradient checkpointing."""

        def _run_block(inp):
            if self.stride > 1:
                pieces = (self.branch1(inp), self.branch2(inp))
            else:
                passthrough, active = inp.chunk(2, dim=1)
                pieces = (passthrough, self.branch2(active))

            return channel_shuffle(torch.cat(pieces, dim=1), 2)

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_run_block, x)
        return _run_block(x)
@BACKBONES.register_module()
class ShuffleNetV2(BaseBackbone):
    """ShuffleNetV2 backbone.

    A stride-2 3x3 stem convolution and a max-pool are followed by three
    stages of ``InvertedResidual`` blocks (4, 8 and 4 blocks) and a final
    1x1 convolution, giving four indexable outputs.

    Args:
        widen_factor (float): Width multiplier - selects the channel layout
            of the network. Must be one of 0.5, 1.0, 1.5 or 2.0.
            Default: 1.0.
        out_indices (Sequence[int]): Output from which stages. Each index
            must lie in ``range(0, 4)``; index 3 is the final 1x1 conv.
            Default: (3, ).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """

    def __init__(self,
                 widen_factor=1.0,
                 out_indices=(3, ),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 with_cp=False):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        act_cfg = copy.deepcopy(act_cfg)
        super().__init__()
        self.stage_blocks = [4, 8, 4]
        for index in out_indices:
            if index not in range(0, 4):
                raise ValueError('the item in out_indices must in '
                                 f'range(0, 4). But received {index}')

        if frozen_stages not in range(-1, 4):
            raise ValueError('frozen_stages must be in range(-1, 4). '
                             f'But received {frozen_stages}')
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # Three stage widths followed by the width of the final 1x1 conv.
        if widen_factor == 0.5:
            channels = [48, 96, 192, 1024]
        elif widen_factor == 1.0:
            channels = [116, 232, 464, 1024]
        elif widen_factor == 1.5:
            channels = [176, 352, 704, 1024]
        elif widen_factor == 2.0:
            channels = [244, 488, 976, 2048]
        else:
            raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. '
                             f'But received {widen_factor}')

        self.in_channels = 24
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layers = nn.ModuleList()
        for i, num_blocks in enumerate(self.stage_blocks):
            layer = self._make_layer(channels[i], num_blocks)
            self.layers.append(layer)

        # The last entry of ``self.layers`` is a plain 1x1 conv head.
        output_channels = channels[-1]
        self.layers.append(
            ConvModule(
                in_channels=self.in_channels,
                out_channels=output_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def _make_layer(self, out_channels, num_blocks):
        """Stack blocks to make a layer.

        The first block uses stride 2, all following blocks stride 1.
        ``self.in_channels`` is updated to ``out_channels`` after each block.

        Args:
            out_channels (int): out_channels of the block.
            num_blocks (int): number of blocks.

        Returns:
            nn.Sequential: The stacked blocks.
        """
        layers = []
        for i in range(num_blocks):
            stride = 2 if i == 0 else 1
            layers.append(
                InvertedResidual(
                    in_channels=self.in_channels,
                    out_channels=out_channels,
                    stride=stride,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.in_channels = out_channels

        return nn.Sequential(*layers)

    def _freeze_stages(self):
        """Disable gradients for the stem and the first ``frozen_stages``
        entries of ``self.layers``, and switch those entries to eval mode."""
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False

        for i in range(self.frozen_stages):
            m = self.layers[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load weights from ``pretrained`` or initialise them from scratch.

        Args:
            pretrained (str, optional): Checkpoint path/URI. When None the
                conv and norm layers are initialised in place.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for name, m in self.named_modules():
                if isinstance(m, nn.Conv2d):
                    if 'conv1' in name:
                        normal_init(m, mean=0, std=0.01)
                    else:
                        normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    # Pass the module itself: ``constant_init`` looks up
                    # ``.weight``/``.bias`` on its argument, so handing it
                    # ``m.weight`` silently initialised nothing.
                    constant_init(m, val=1, bias=0.0001)
                    if isinstance(m, _BatchNorm):
                        if m.running_mean is not None:
                            nn.init.constant_(m.running_mean, 0)
        else:
            raise TypeError('pretrained must be a str or None. But received '
                            f'{type(pretrained)}')

    def forward(self, x):
        """Return the outputs of the layers listed in ``out_indices``.

        A single tensor is returned when exactly one output is selected,
        otherwise a tuple of tensors.
        """
        x = self.conv1(x)
        x = self.maxpool(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]
        return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode while honouring ``frozen_stages`` and
        ``norm_eval``."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # Match every batch-norm variant (e.g. SyncBN), not just
                # ``nn.BatchNorm2d``, as the class docstring promises.
                if isinstance(m, _BatchNorm):
                    m.eval()
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/tcn.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch.nn as nn
from mmcv.cnn import ConvModule, build_conv_layer, constant_init, kaiming_init
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmpose.core import WeightNormClipHook
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
class BasicTemporalBlock(nn.Module):
    """Basic block for VideoPose3D.

    Two ``ConvModule`` layers (a temporal conv followed by a 1x1 conv), each
    optionally followed by dropout, plus an optional residual connection
    taken from a cropped or strided slice of the input.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        mid_channels (int): The output channels of conv1. Default: 1024.
        kernel_size (int): Size of the convolving kernel. Default: 3.
        dilation (int): Spacing between kernel elements. Default: 3.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications). Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use optimized TCN that designed
            specifically for single-frame batching, i.e. where batches have
            input length = receptive field, and output length = 1. This
            implementation replaces dilated convolutions with strided
            convolutions to avoid generating unused intermediate results.
            Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=1024,
                 kernel_size=3,
                 dilation=3,
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d')):
        # Work on private copies so the shared default dicts stay pristine.
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.dropout = dropout
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv

        # Number of frames trimmed from each side by the dilated conv.
        self.pad = (kernel_size - 1) * dilation // 2

        half_kernel = kernel_size // 2
        if use_stride_conv:
            # Strided variant: stride replaces dilation.
            self.stride = kernel_size
            self.causal_shift = half_kernel if causal else 0
            self.dilation = 1
        else:
            self.stride = 1
            self.causal_shift = half_kernel * dilation if causal else 0

        temporal_conv = ConvModule(
            in_channels,
            mid_channels,
            kernel_size=kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)
        self.conv1 = nn.Sequential(temporal_conv)

        pointwise_conv = ConvModule(
            mid_channels,
            out_channels,
            kernel_size=1,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)
        self.conv2 = nn.Sequential(pointwise_conv)

        # A 1x1 projection is only needed when the residual changes width.
        if residual and in_channels != out_channels:
            self.short_cut = build_conv_layer(conv_cfg, in_channels,
                                              out_channels, 1)
        else:
            self.short_cut = None

        # Replaces the float stored above with the actual layer (or None).
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function."""
        seq_len = x.shape[2]
        if self.use_stride_conv:
            first_tap = self.causal_shift + self.kernel_size // 2
            assert first_tap < seq_len
        else:
            crop_start = self.pad + self.causal_shift
            crop_stop = seq_len - self.pad + self.causal_shift
            assert 0 <= crop_start < crop_stop <= seq_len

        out = self.conv1(x)
        if self.dropout is not None:
            out = self.dropout(out)

        out = self.conv2(out)
        if self.dropout is not None:
            out = self.dropout(out)

        if not self.residual:
            return out

        # Slice the input so it lines up with the convolution output.
        if self.use_stride_conv:
            res = x[:, :, first_tap::self.kernel_size]
        else:
            res = x[:, :, crop_start:crop_stop]

        if self.short_cut is not None:
            res = self.short_cut(res)
        return out + res
@BACKBONES.register_module()
class TCN(BaseBackbone):
    """TCN backbone.

    Temporal Convolutional Networks.
    More details can be found in the
    `paper <https://arxiv.org/abs/1811.11742>`__ .

    An expanding convolution is followed by ``num_blocks``
    ``BasicTemporalBlock`` modules whose dilation grows by the kernel size
    of each block. The output of every block is returned.

    Args:
        in_channels (int): Number of input channels, which equals to
            num_keypoints * num_features.
        stem_channels (int): Number of feature channels. Default: 1024.
        num_blocks (int): Number of basic temporal convolutional blocks.
            Must equal ``len(kernel_sizes) - 1``. Default: 2.
        kernel_sizes (Sequence[int]): Sizes of the convolving kernel of
            the expanding conv followed by each basic block; all must be
            odd. Default: ``(3, 3, 3)``.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications).
            Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use TCN backbone optimized for
            single-frame batching, i.e. where batches have input length =
            receptive field, and output length = 1. This implementation
            replaces dilated convolutions with strided convolutions to avoid
            generating unused intermediate results. The weights are
            interchangeable with the reference implementation. Default: False
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
        max_norm (float|None): if not None, the weight of convolution layers
            will be clipped to have a maximum norm of max_norm.

    Example:
        >>> from mmpose.models import TCN
        >>> import torch
        >>> self = TCN(in_channels=34)
        >>> self.eval()
        >>> inputs = torch.rand(1, 34, 243)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 235)
        (1, 1024, 217)
    """

    def __init__(self,
                 in_channels,
                 stem_channels=1024,
                 num_blocks=2,
                 kernel_sizes=(3, 3, 3),
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 max_norm=None):
        # Work on private copies so the shared default dicts stay pristine.
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        self.in_channels = in_channels
        self.stem_channels = stem_channels
        self.num_blocks = num_blocks
        self.kernel_sizes = kernel_sizes
        self.dropout = dropout
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv
        self.max_norm = max_norm

        # One kernel size for the expanding conv plus one per block.
        assert num_blocks == len(kernel_sizes) - 1
        for width in kernel_sizes:
            assert width % 2 == 1, 'Only odd filter widths are supported.'

        stem_kernel = kernel_sizes[0]
        self.expand_conv = ConvModule(
            in_channels,
            stem_channels,
            kernel_size=stem_kernel,
            stride=stem_kernel if use_stride_conv else 1,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        # Dilation starts at the stem kernel size and is multiplied by each
        # block's kernel size after that block is built.
        dilation = stem_kernel

        self.tcn_blocks = nn.ModuleList()
        for block_idx in range(1, num_blocks + 1):
            block = BasicTemporalBlock(
                in_channels=stem_channels,
                out_channels=stem_channels,
                mid_channels=stem_channels,
                kernel_size=kernel_sizes[block_idx],
                dilation=dilation,
                dropout=dropout,
                causal=causal,
                residual=residual,
                use_stride_conv=use_stride_conv,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg)
            self.tcn_blocks.append(block)
            dilation *= kernel_sizes[block_idx]

        if self.max_norm is not None:
            # Apply weight norm clip to conv layers
            weight_clip = WeightNormClipHook(self.max_norm)
            for module in self.modules():
                if isinstance(module, nn.modules.conv._ConvNd):
                    weight_clip.register(module)

        # Replaces the float stored above with the actual layer (or None).
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function.

        Returns:
            tuple[Tensor]: The output of every temporal block, in order.
        """
        x = self.expand_conv(x)
        if self.dropout is not None:
            x = self.dropout(x)

        level_outputs = []
        for block_idx in range(self.num_blocks):
            x = self.tcn_blocks[block_idx](x)
            level_outputs.append(x)

        return tuple(level_outputs)

    def init_weights(self, pretrained=None):
        """Initialize the weights.

        Delegates to the base class first; when no checkpoint is given the
        conv layers get Kaiming init and batch-norm layers are set to 1.
        """
        super().init_weights(pretrained)
        if pretrained is not None:
            return
        for m in self.modules():
            if isinstance(m, nn.modules.conv._ConvNd):
                kaiming_init(m, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, _BatchNorm):
                constant_init(m, 1)
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/test_torch.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """Small conv net: two 5x5 conv layers followed by three linear layers.

    ``fc1`` expects ``16 * 5 * 5`` input features, which the original
    comments tie to the image dimension after the two conv/pool steps.
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel -> 6 channels -> 16 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Affine layers (y = Wx + b); the 5*5 comes from the image dimension.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10-dimensional output for a batch of images."""
        # Conv -> ReLU -> 2x2 max-pool, twice. A square window may be given
        # either as a pair or as a single number.
        feat = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        feat = F.max_pool2d(F.relu(self.conv2(feat)), 2)
        # Flatten everything except the batch dimension.
        feat = torch.flatten(feat, 1)
        hidden = F.relu(self.fc1(feat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
# Smoke-test script: builds the network and runs one SGD step on random data.
# NOTE(review): these statements run at import time of this module.
net = Net()
# print(net)
net.train()
# A single random 1-channel 32x32 input (batch size 1).
input = torch.randn(1, 1, 32, 32)
# out = net(input)
# print(out)
# First forward pass; ``output`` is reassigned below before the loss is
# computed.
output = net(input)
target = torch.randn(10) # a dummy target, for example
target = target.view(1, -1) # make it the same shape as output
criterion = nn.MSELoss()
# loss = criterion(output.cuda(), target.cuda())
import torch.optim as optim
# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)
# in your training loop:
optimizer.zero_grad() # zero the gradient buffers
output = net(input)
loss = criterion(output, target)
# Back-propagate the MSE loss and apply one optimizer update.
loss.backward()
optimizer.step()
# print(loss)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from .channel_shuffle import channel_shuffle
from .inverted_residual import InvertedResidual
from .make_divisible import make_divisible
from .se_layer import SELayer
from .utils import load_checkpoint
__all__ = [
'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
'load_checkpoint'
]
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/channel_shuffle.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def channel_shuffle(x, groups):
    """Channel Shuffle operation.

    The channel axis is split into ``groups`` groups, the group axis and the
    within-group axis are swapped, and the result is flattened back into a
    single channel axis. This interleaves channels from different groups so
    that information can flow across grouped convolutions.

    Args:
        x (Tensor): The input tensor; must be 4-dimensional, unpacked as
            (batch, channels, height, width).
        groups (int): The number of groups to divide the input tensor
            in the channel dimension.

    Returns:
        Tensor: The output tensor after channel shuffle operation, with the
        same shape as ``x``.

    Raises:
        AssertionError: If the channel count is not divisible by ``groups``.
    """
    n, c, h, w = x.size()
    assert (c % groups == 0), ('num_channels should be '
                               'divisible by groups')
    per_group = c // groups

    grouped = x.view(n, groups, per_group, h, w)
    # ``contiguous`` materialises the swapped layout so it can be viewed.
    interleaved = grouped.transpose(1, 2).contiguous()
    return interleaved.view(n, -1, h, w)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/inverted_residual.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from .se_layer import SELayer
class InvertedResidual(nn.Module):
    """Inverted Residual Block.

    Optional 1x1 expand conv, a depthwise conv, an optional SE layer and a
    1x1 linear conv without activation. The input is added back when the
    stride is 1 and the input/output widths match.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        mid_channels (int): The input channels of the depthwise convolution.
        kernel_size (int): The kernel size of the depthwise convolution.
            Default: 3.
        groups (None or int): The group number of the depthwise convolution.
            Default: None, which means group number = mid_channels.
        stride (int): The stride of the depthwise convolution; must be 1
            or 2. Default: 1.
        se_cfg (dict): Config dict for se layer. Default: None, which means
            no se layer.
        with_expand_conv (bool): Use expand conv or not. If set False,
            mid_channels must be the same with in_channels.
            Default: True.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 groups=None,
                 stride=1,
                 se_cfg=None,
                 with_expand_conv=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        # Work on private copies so the shared default dicts stay pristine.
        norm_cfg = copy.deepcopy(norm_cfg)
        act_cfg = copy.deepcopy(act_cfg)
        super().__init__()
        # The identity shortcut needs matching resolution and width.
        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
        assert stride in [1, 2]
        self.with_cp = with_cp
        self.with_se = se_cfg is not None
        self.with_expand_conv = with_expand_conv

        depthwise_groups = mid_channels if groups is None else groups

        if self.with_se:
            assert isinstance(se_cfg, dict)
        if not self.with_expand_conv:
            assert mid_channels == in_channels

        if self.with_expand_conv:
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=depthwise_groups,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        if self.with_se:
            self.se = SELayer(**se_cfg)

        # Final projection carries no activation.
        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, x):
        """Apply the block, optionally under gradient checkpointing."""

        def _run_block(inp):
            feat = self.expand_conv(inp) if self.with_expand_conv else inp
            feat = self.depthwise_conv(feat)
            if self.with_se:
                feat = self.se(feat)
            feat = self.linear_conv(feat)
            return inp + feat if self.with_res_shortcut else feat

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_run_block, x)
        return _run_block(x)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/make_divisible.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
    """Make divisible function.

    Rounds ``value`` to the nearest multiple of ``divisor`` (halves round
    up), never going below ``min_value``. If the rounded result falls under
    ``min_ratio * value``, one extra ``divisor`` is added.

    Args:
        value (int): The original channel number.
        divisor (int): The divisor to fully divide the channel number.
        min_value (int, optional): The minimum value of the output channel.
            Default: None, means that the minimum value equal to the divisor.
        min_ratio (float, optional): The minimum ratio of the rounded channel
            number to the original channel number. Default: 0.9.

    Returns:
        int: The modified output channel number
    """
    lower_bound = divisor if min_value is None else min_value
    nearest_multiple = int(value + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest_multiple)
    # Bump up one step when rounding shrank the value by more than
    # (1 - min_ratio).
    if result < min_ratio * value:
        result += divisor
    return result
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/se_layer.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch.nn as nn
from mmcv.cnn import ConvModule
class SELayer(nn.Module):
    """Squeeze-and-Excitation Module.

    The input is globally average-pooled to 1x1, passed through two 1x1
    conv modules (squeeze to ``int(channels / ratio)`` channels, then back
    to ``channels``) and the result is multiplied with the input.

    Args:
        channels (int): The input (and output) channels of the SE layer.
        ratio (int): Squeeze ratio in SELayer, the intermediate channel will
            be ``int(channels/ratio)``. Default: 16.
        conv_cfg (None or dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        act_cfg (dict or Sequence[dict]): Config dict for activation layer.
            If act_cfg is a dict, both activation layers are configured by
            this dict. If act_cfg is a sequence of dicts, the first
            activation layer is configured by the first dict and the second
            activation layer by the second dict.
            Default: (dict(type='ReLU'), dict(type='Sigmoid'))
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        # A single dict configures both activations.
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        squeezed_channels = int(channels / ratio)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=squeezed_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=squeezed_channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        """Scale ``x`` by the gate computed from its pooled features."""
        gate = self.conv2(self.conv1(self.global_avgpool(x)))
        return x * gate
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/utils/utils.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from mmcv.runner.checkpoint import _load_checkpoint, load_state_dict
# Wrapper prefixes removed from checkpoint keys. Order matters: the first
# matching prefix wins, so the longest one is listed first.
_STATE_DICT_PREFIXES = ('module.backbone.', 'module.', 'backbone.')


def _extract_state_dict(checkpoint, filename):
    """Return the prefix-stripped state dict held by a loaded checkpoint.

    Args:
        checkpoint: The object returned by ``_load_checkpoint``.
        filename (str): Checkpoint location, used only in the error message.

    Returns:
        OrderedDict: The state dict with at most one leading prefix from
        ``_STATE_DICT_PREFIXES`` removed from every key.

    Raises:
        RuntimeError: If ``checkpoint`` is not a dict.
    """
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # The weights are either nested under 'state_dict' or the checkpoint
    # itself is the state dict.
    if 'state_dict' in checkpoint:
        raw_state_dict = checkpoint['state_dict']
    else:
        raw_state_dict = checkpoint

    state_dict = OrderedDict()
    for key, value in raw_state_dict.items():
        for prefix in _STATE_DICT_PREFIXES:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        state_dict[key] = value
    return state_dict


def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    state_dict = _extract_state_dict(checkpoint, filename)
    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint


def get_state_dict(filename, map_location='cpu'):
    """Get state_dict from a file or URI.

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``.
        map_location (str): Same as :func:`torch.load`.

    Returns:
        OrderedDict: The state_dict.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    return _extract_state_dict(checkpoint, filename)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vgg.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init, normal_init
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
def make_vgg_layer(in_channels,
                   out_channels,
                   num_blocks,
                   conv_cfg=None,
                   norm_cfg=None,
                   act_cfg=dict(type='ReLU'),
                   dilation=1,
                   with_norm=False,
                   ceil_mode=False):
    """Build the modules of one VGG stage.

    The stage is ``num_blocks`` 3x3 ``ConvModule`` layers followed by a
    2x2, stride-2 max-pool.

    Args:
        in_channels (int): Input channels of the first conv.
        out_channels (int): Output channels of every conv; also the input
            channels of every conv after the first.
        num_blocks (int): Number of conv layers in the stage.
        conv_cfg (dict | None): Config dict passed to each ``ConvModule``.
        norm_cfg (dict | None): Norm config passed to each ``ConvModule``.
        act_cfg (dict): Activation config passed to each ``ConvModule``.
            Default: dict(type='ReLU').
        dilation (int): Dilation of each conv; also used as its padding.
            Default: 1.
        with_norm (bool): Accepted but not used inside this function.
        ceil_mode (bool): ``ceil_mode`` of the trailing max-pool.
            Default: False.

    Returns:
        list[nn.Module]: The conv modules followed by the pooling layer.
    """
    stage = []
    conv_in = in_channels
    for _ in range(num_blocks):
        stage.append(
            ConvModule(
                in_channels=conv_in,
                out_channels=out_channels,
                kernel_size=3,
                dilation=dilation,
                padding=dilation,
                bias=True,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))
        # Every conv after the first consumes the stage's output width.
        conv_in = out_channels

    pool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)
    stage.append(pool)
    return stage
@BACKBONES.register_module()
class VGG(BaseBackbone):
    """VGG backbone.

    Args:
        depth (int): Depth of vgg, from {11, 13, 16, 19}.
        num_classes (int): Number of classes for classification. The
            fully-connected classifier is only built when this is positive.
        num_stages (int): VGG stages, normally 5.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. When it is None, the default behavior depends on
            whether num_classes is specified. If num_classes <= 0, the default
            value is (4, ), outputting the last feature map before classifier.
            If num_classes > 0, the default value is (5, ), outputting the
            classification score. Default: None.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        conv_cfg (dict | None): Conv config forwarded to ``make_vgg_layer``.
        norm_cfg (dict | None): Norm config forwarded to ``make_vgg_layer``.
        act_cfg (dict): Activation config forwarded to ``make_vgg_layer``.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False.
        with_last_pool (bool): Whether to keep the last pooling before
            classifier. Default: True.
    """

    # Number of conv layers per stage for each supported depth. For example
    # VGG11 has 11 layers with learnable parameters:
    # 11 = (1 + 1 + 2 + 2 + 2) + 3, where 3 counts the fully-connected layers.
    arch_settings = {
        11: (1, 1, 2, 2, 2),
        13: (2, 2, 2, 2, 2),
        16: (2, 2, 3, 3, 3),
        19: (2, 2, 4, 4, 4)
    }

    def __init__(self,
                 depth,
                 num_classes=-1,
                 num_stages=5,
                 dilations=(1, 1, 1, 1, 1),
                 out_indices=None,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 ceil_mode=False,
                 with_last_pool=True):
        super().__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for vgg')
        assert 1 <= num_stages <= 5
        self.stage_blocks = self.arch_settings[depth][:num_stages]
        assert len(dilations) == num_stages

        self.num_classes = num_classes
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        with_norm = norm_cfg is not None

        if out_indices is None:
            out_indices = (5, ) if num_classes > 0 else (4, )
        assert max(out_indices) <= num_stages
        self.out_indices = out_indices

        self.in_channels = 3
        feature_modules = []
        # One [start, end) pair per stage, indexing into the flat module list.
        self.range_sub_modules = []
        cursor = 0
        for stage, conv_count in enumerate(self.stage_blocks):
            stage_dilation = dilations[stage]
            # Width doubles per stage (64, 128, 256, 512) and is capped at 512.
            stage_width = 64 * 2**stage if stage < 4 else 512
            feature_modules.extend(
                make_vgg_layer(
                    self.in_channels,
                    stage_width,
                    conv_count,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dilation=stage_dilation,
                    with_norm=with_norm,
                    ceil_mode=ceil_mode))
            self.in_channels = stage_width
            # Each stage contributes its convs plus one pooling module.
            stage_end = cursor + conv_count + 1
            self.range_sub_modules.append([cursor, stage_end])
            cursor = stage_end

        if not with_last_pool:
            # Drop the final pooling module and shrink the last stage range.
            feature_modules.pop(-1)
            self.range_sub_modules[-1][1] -= 1

        self.module_name = 'features'
        self.add_module(self.module_name, nn.Sequential(*feature_modules))

        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )

    def init_weights(self, pretrained=None):
        """Run the base-class initialisation, then init from scratch if no
        ``pretrained`` value was given."""
        super().init_weights(pretrained)
        if pretrained is not None:
            return
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kaiming_init(module)
            elif isinstance(module, _BatchNorm):
                constant_init(module, 1)
            elif isinstance(module, nn.Linear):
                normal_init(module, std=0.01)

    def forward(self, x):
        """Run the stages and collect the outputs listed in ``out_indices``.

        Returns a single tensor when exactly one output was collected,
        otherwise a tuple of tensors.
        """
        collected = []
        features = getattr(self, self.module_name)
        for stage in range(len(self.stage_blocks)):
            first, last = self.range_sub_modules[stage]
            for module_idx in range(first, last):
                x = features[module_idx](x)
            if stage in self.out_indices:
                collected.append(x)
        if self.num_classes > 0:
            # The classifier output is always appended when a head exists.
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            collected.append(x)
        return collected[0] if len(collected) == 1 else tuple(collected)

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` stages in eval mode and stop their
        gradients."""
        features = getattr(self, self.module_name)
        for stage in range(self.frozen_stages):
            first, last = self.range_sub_modules[stage]
            for module_idx in range(first, last):
                frozen = features[module_idx]
                frozen.eval()
                for param in frozen.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode while keeping frozen parts frozen."""
        super().train(mode)
        self._freeze_stages()
        if not (mode and self.norm_eval):
            return
        for module in self.modules():
            # trick: eval have effect on BatchNorm only
            if isinstance(module, _BatchNorm):
                module.eval()
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vipnas_mbv3.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import torch.nn as nn
from mmcv.cnn import ConvModule
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .utils import InvertedResidual, load_checkpoint
@BACKBONES.register_module()
class ViPNAS_MobileNetV3(BaseBackbone):
    """ViPNAS_MobileNetV3 backbone.

    "ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search"

    Entry 0 of every per-stage list describes the stem convolution; entries
    1.. describe the stages built from ``InvertedResidual`` blocks.

    Args:
        wid (list(int)): Searched width config for each stage.
        expan (list(int)): Searched expansion ratio config for each stage.
        dep (list(int)): Searched depth config for each stage.
        ks (list(int)): Searched kernel size config for each stage.
        group (list(int)): Searched group number config for each stage.
        att (list(bool)): Searched attention config for each stage.
        stride (list(int)): Stride config for each stage.
        act (list(dict)): Activation config for each stage.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """

    def __init__(self,
                 wid=[16, 16, 24, 40, 80, 112, 160],
                 expan=[None, 1, 5, 4, 5, 5, 6],
                 dep=[None, 1, 4, 4, 4, 4, 4],
                 ks=[3, 3, 7, 7, 5, 7, 5],
                 group=[None, 8, 120, 20, 100, 280, 240],
                 att=[None, True, True, False, True, True, True],
                 stride=[2, 1, 2, 2, 2, 1, 2],
                 act=[
                     'HSwish', 'ReLU', 'ReLU', 'ReLU', 'HSwish', 'HSwish',
                     'HSwish'
                 ],
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 frozen_stages=-1,
                 norm_eval=False,
                 with_cp=False):
        # Work on a private copy so the shared default dict is never touched.
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        # Searched architecture description.
        self.wid = wid
        self.expan = expan
        self.dep = dep
        self.ks = ks
        self.group = group
        self.att = att
        self.stride = stride
        self.act = act

        # Build / training options.
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        stem_kernel = self.ks[0]
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.wid[0],
            kernel_size=stem_kernel,
            stride=self.stride[0],
            padding=stem_kernel // 2,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type=self.act[0]))

        self.layers = self._make_layer()

    def _make_layer(self):
        """Create every inverted-residual block, register it as
        ``layer{n}`` and return the names in execution order."""
        names = []
        for stage, block_count in enumerate(self.dep[1:], start=1):
            width = self.wid[stage]
            ratio = self.expan[stage]
            mid_channels = width * ratio

            # Squeeze-and-excitation is only configured for stages flagged in
            # ``att``.
            se_cfg = dict(
                channels=mid_channels,
                ratio=4,
                act_cfg=(dict(type='ReLU'),
                         dict(type='HSigmoid'))) if self.att[stage] else None

            for position in range(block_count):
                # Only the first block of a stage changes stride and width.
                opens_stage = position == 0
                block = InvertedResidual(
                    in_channels=self.wid[stage - 1] if opens_stage else width,
                    out_channels=width,
                    mid_channels=mid_channels,
                    kernel_size=self.ks[stage],
                    groups=self.group[stage],
                    stride=self.stride[stage] if opens_stage else 1,
                    se_cfg=se_cfg,
                    with_expand_conv=ratio != 1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=dict(type=self.act[stage]),
                    with_cp=self.with_cp)
                name = f'layer{len(names) + 1}'
                self.add_module(name, block)
                names.append(name)
        return names

    def init_weights(self, pretrained=None):
        """Load weights from ``pretrained`` or initialise from scratch.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is not None and not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None')

        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
            return

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.001)
                # Zero the bias only when the conv registers one.
                if any(name == 'bias'
                       for name, _ in module.named_parameters()):
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Apply the stem conv followed by every registered block."""
        x = self.conv1(x)
        for name in self.layers:
            x = getattr(self, name)(x)
        return x

    def _freeze_stages(self):
        """Stop gradients of the stem and of ``layer1..layer{frozen_stages}``;
        the frozen layers are also put in eval mode."""
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
        for index in range(1, self.frozen_stages + 1):
            frozen = getattr(self, f'layer{index}')
            frozen.eval()
            for param in frozen.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode while keeping frozen parts frozen."""
        super().train(mode)
        self._freeze_stages()
        if not (mode and self.norm_eval):
            return
        for module in self.modules():
            if isinstance(module, _BatchNorm):
                module.eval()
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vipnas_resnet.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
from mmcv.cnn.bricks import ContextBlock
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
class ViPNAS_Bottleneck(nn.Module):
    """Bottleneck block for ViPNAS_ResNet.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        expansion (int): The ratio of ``out_channels/mid_channels`` where
            ``mid_channels`` is the input/output channels of conv2. Default: 4.
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of convolution. Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None.
        style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the stride-two
            layer is the first 1x1 conv layer. Default: "pytorch".
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        kernel_size (int): kernel size of conv2 searched in ViPNAS.
        groups (int): group number of conv2 searched in ViPNAS.
        attention (bool): whether to use attention module in the end of
            the block.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expansion=4,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 kernel_size=3,
                 groups=1,
                 attention=False):
        # Work on a private copy so the shared default dict is never touched.
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        assert style in ['pytorch', 'caffe']

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.expansion = expansion
        assert out_channels % expansion == 0
        self.mid_channels = out_channels // expansion
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # 'pytorch' strides the searched conv2, 'caffe' strides the 1x1 conv1.
        strides = (1, stride) if self.style == 'pytorch' else (stride, 1)
        self.conv1_stride, self.conv2_stride = strides

        narrow = self.mid_channels
        self.norm1_name, first_norm = build_norm_layer(
            norm_cfg, narrow, postfix=1)
        self.norm2_name, second_norm = build_norm_layer(
            norm_cfg, narrow, postfix=2)
        self.norm3_name, third_norm = build_norm_layer(
            norm_cfg, out_channels, postfix=3)

        # 1x1 reduce -> searched kxk grouped conv -> 1x1 expand.
        self.conv1 = build_conv_layer(
            conv_cfg,
            in_channels,
            narrow,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, first_norm)

        self.conv2 = build_conv_layer(
            conv_cfg,
            narrow,
            narrow,
            kernel_size=kernel_size,
            stride=self.conv2_stride,
            padding=kernel_size // 2,
            groups=groups,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm2_name, second_norm)

        self.conv3 = build_conv_layer(
            conv_cfg, narrow, out_channels, kernel_size=1, bias=False)
        self.add_module(self.norm3_name, third_norm)

        self.attention = ContextBlock(
            out_channels,
            max(1.0 / 16, 16.0 / out_channels)) if attention else None

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """nn.Module: the normalization layer named "norm3" """
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Forward function."""

        def _residual_sum(inp):
            # Main branch: conv/norm/relu twice, then conv/norm.
            shortcut = inp
            feat = self.relu(self.norm1(self.conv1(inp)))
            feat = self.relu(self.norm2(self.conv2(feat)))
            feat = self.norm3(self.conv3(feat))

            if self.attention is not None:
                feat = self.attention(feat)

            if self.downsample is not None:
                shortcut = self.downsample(inp)

            feat += shortcut
            return feat

        # Checkpointing is only used when the input takes part in autograd.
        if self.with_cp and x.requires_grad:
            merged = cp.checkpoint(_residual_sum, x)
        else:
            merged = _residual_sum(x)

        return self.relu(merged)
def get_expansion(block, expansion=None):
    """Get the expansion of a residual block.

    The block expansion will be obtained by the following order:

    1. If ``expansion`` is given, just return it.
    2. If ``block`` has the attribute ``expansion``, then return
       ``block.expansion``.
    3. Return the default value according to the block type:
       1 for ``ViPNAS_Bottleneck``.

    Args:
        block (class): The block class.
        expansion (int | None): The given expansion ratio.

    Returns:
        int: The expansion of the block.

    Raises:
        TypeError: If ``expansion`` is neither an int nor None, or if it is
            None and no value can be derived from ``block``.
    """
    if expansion is None:
        # Nothing given: ask the block class first, then use the fallback.
        if hasattr(block, 'expansion'):
            return block.expansion
        if issubclass(block, ViPNAS_Bottleneck):
            return 1
        raise TypeError(f'expansion is not specified for {block.__name__}')

    if not isinstance(expansion, int):
        raise TypeError('expansion must be an integer or None')

    assert expansion > 0
    return expansion
class ViPNAS_ResLayer(nn.Sequential):
    """ViPNAS_ResLayer to build ResNet style backbone.

    Args:
        block (nn.Module): Residual block used to build ViPNAS ResLayer.
        num_blocks (int): Number of blocks.
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        expansion (int, optional): The expansion of the blocks, resolved
            through ``get_expansion(block, expansion)``. Default: None.
        stride (int): stride of the first block. Default: 1.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        downsample_first (bool): Downsample at the first block or last block.
            False for Hourglass, True for ResNet. Default: True
        kernel_size (int): Kernel Size of the corresponding convolution layer
            searched in the block.
        groups (int): Group number of the corresponding convolution layer
            searched in the block.
        attention (bool): Whether to use attention module in the end of the
            block.
        **kwargs: Extra keyword arguments passed to every block.
    """

    def __init__(self,
                 block,
                 num_blocks,
                 in_channels,
                 out_channels,
                 expansion=None,
                 stride=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 downsample_first=True,
                 kernel_size=3,
                 groups=1,
                 attention=False,
                 **kwargs):
        # Work on a private copy so the shared default dict is never touched.
        norm_cfg = copy.deepcopy(norm_cfg)
        self.block = block
        self.expansion = get_expansion(block, expansion)

        # A projection shortcut is needed whenever the block that changes
        # resolution or width cannot use the plain identity.
        shortcut = None
        if stride != 1 or in_channels != out_channels:
            shortcut_parts = []
            projection_stride = stride
            if avg_down and stride != 1:
                # Pool for the spatial reduction, keep the 1x1 conv unstrided.
                projection_stride = 1
                shortcut_parts.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False))
            shortcut_parts.append(
                build_conv_layer(
                    conv_cfg,
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=projection_stride,
                    bias=False))
            shortcut_parts.append(
                build_norm_layer(norm_cfg, out_channels)[1])
            shortcut = nn.Sequential(*shortcut_parts)

        def _new_block(block_in, block_out, block_stride, **extra):
            # All blocks of the layer share everything except channels,
            # stride and the optional shortcut passed through ``extra``.
            return block(
                in_channels=block_in,
                out_channels=block_out,
                expansion=self.expansion,
                stride=block_stride,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                kernel_size=kernel_size,
                groups=groups,
                attention=attention,
                **extra,
                **kwargs)

        if downsample_first:
            stack = [
                _new_block(
                    in_channels, out_channels, stride, downsample=shortcut)
            ]
            stack.extend(
                _new_block(out_channels, out_channels, 1)
                for _ in range(1, num_blocks))
        else:  # downsample_first=False is for HourglassModule
            stack = [
                _new_block(in_channels, in_channels, 1)
                for _ in range(0, num_blocks - 1)
            ]
            stack.append(
                _new_block(
                    in_channels, out_channels, stride, downsample=shortcut))

        super().__init__(*stack)
@BACKBONES.register_module()
class ViPNAS_ResNet(BaseBackbone):
    """ViPNAS_ResNet backbone.

    "ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search"

    Entry 0 of ``wid`` and ``ks`` describes the stem; entries 1.. of every
    searched list describe the residual stages.

    Args:
        depth (int): Network depth. Must be a key of ``arch_settings``
            (only 50 is defined here).
        in_channels (int): Number of input image channels. Default: 3.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Stored on the instance; this class does
            not read it anywhere else. Default: True.
        wid (list(int)): Searched width config for each stage.
        expan (list(int)): Searched expansion ratio config for each stage.
        dep (list(int)): Searched depth config for each stage.
        ks (list(int)): Searched kernel size config for each stage.
        group (list(int)): Searched group number config for each stage.
        att (list(bool)): Searched attention config for each stage.
    """

    arch_settings = {
        50: ViPNAS_Bottleneck,
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(3, ),
                 style='pytorch',
                 deep_stem=False,
                 avg_down=False,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=True,
                 wid=[48, 80, 160, 304, 608],
                 expan=[None, 1, 1, 1, 1],
                 dep=[None, 4, 6, 7, 3],
                 ks=[7, 3, 5, 5, 5],
                 group=[None, 16, 16, 16, 16],
                 att=[None, True, False, True, True]):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        self.depth = depth
        # The stem is built with ``wid[0]`` output channels (see the
        # ``_make_stem_layer`` call below). ``dep[0]`` is only the unused
        # depth placeholder, so it must not be recorded as the stem width.
        self.stem_channels = wid[0]
        self.num_stages = num_stages
        assert 1 <= num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.deep_stem = deep_stem
        self.avg_down = avg_down
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.zero_init_residual = zero_init_residual
        self.block = self.arch_settings[depth]
        self.stage_blocks = dep[1:1 + num_stages]

        self._make_stem_layer(in_channels, wid[0], ks[0])

        self.res_layers = []
        _in_channels = wid[0]
        for i, num_blocks in enumerate(self.stage_blocks):
            # Searched per-stage settings live at index i + 1 of each list.
            expansion = get_expansion(self.block, expan[i + 1])
            _out_channels = wid[i + 1] * expansion
            stride = strides[i]
            dilation = dilations[i]
            res_layer = self.make_res_layer(
                block=self.block,
                num_blocks=num_blocks,
                in_channels=_in_channels,
                out_channels=_out_channels,
                expansion=expansion,
                stride=stride,
                dilation=dilation,
                style=self.style,
                avg_down=self.avg_down,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                kernel_size=ks[i + 1],
                groups=group[i + 1],
                attention=att[i + 1])
            _in_channels = _out_channels
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count produced by the last block of the last stage.
        self.feat_dim = res_layer[-1].out_channels

    def make_res_layer(self, **kwargs):
        """Make a ViPNAS ResLayer."""
        return ViPNAS_ResLayer(**kwargs)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels, stem_channels, kernel_size):
        """Make stem layer.

        With ``deep_stem`` the stem is three 3x3 ``ConvModule`` layers
        (the first one strided); otherwise it is a single strided conv of
        size ``kernel_size`` followed by a norm layer and ReLU. Both variants
        are followed by a 3x3 max pooling with stride 2.
        """
        if self.deep_stem:
            self.stem = nn.Sequential(
                ConvModule(
                    in_channels,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=True),
                ConvModule(
                    stem_channels // 2,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=True),
                ConvModule(
                    stem_channels // 2,
                    stem_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=True))
        else:
            self.conv1 = build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=kernel_size // 2,
                bias=False)
            self.norm1_name, norm1 = build_norm_layer(
                self.norm_cfg, stem_channels, postfix=1)
            self.add_module(self.norm1_name, norm1)
            self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Freeze parameters."""
        if self.frozen_stages >= 0:
            if self.deep_stem:
                self.stem.eval()
                for param in self.stem.parameters():
                    param.requires_grad = False
            else:
                self.norm1.eval()
                for m in [self.conv1, self.norm1]:
                    for param in m.parameters():
                        param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize model weights."""
        super().init_weights(pretrained)
        if pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight, std=0.001)
                    for name, _ in m.named_parameters():
                        if name in ['bias']:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Forward function."""
        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/backbones/vit.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
# from ..builder import BACKBONES
# from .base_backbone import BaseBackbone
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (stochastic depth per sample),
    meant for the main path of residual blocks.

    Args:
        drop_prob (float | None): Value passed to ``drop_path`` as the drop
            probability. Default: None.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Delegate to ``drop_path`` with the module's current mode."""
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        """Show the drop probability in the module repr."""
        return f'p={self.drop_prob}'
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear, activation, linear, dropout.

    Args:
        in_features (int): Input feature size.
        hidden_features (int | None): Hidden size; falls back to
            ``in_features`` when falsy.
        out_features (int | None): Output size; falls back to
            ``in_features`` when falsy.
        act_layer (callable): Factory of the activation module.
        drop (float): Dropout probability applied to the output.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_out = out_features or in_features
        width_hidden = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1 -> act -> fc2 -> dropout."""
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token sequence.

    Args:
        dim (int): Channel size of the input and output tokens.
        num_heads (int): Number of attention heads.
        qkv_bias (bool): Whether the fused qkv projection has a bias.
        qk_scale (float | None): Scale applied to the queries; when falsy,
            ``head_dim ** -0.5`` is used.
        attn_drop (float): Dropout probability on the attention weights.
        proj_drop (float): Dropout probability on the projected output.
        attn_head_dim (int | None): Overrides the per-head size, which is
            otherwise ``dim // num_heads``.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None,):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim

        per_head = dim // num_heads
        if attn_head_dim is not None:
            per_head = attn_head_dim
        inner_dim = per_head * self.num_heads
        self.scale = qk_scale or per_head ** -0.5

        # q, k and v are produced by one fused projection.
        self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended tokens, same leading shape as ``x``."""
        batch, tokens, _ = x.shape

        # (B, N, 3 * inner) -> (3, B, heads, N, head_dim)
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, -1)
        fused = fused.permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        query, key, value = fused[0], fused[1], fused[2]

        weights = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(weights.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, -1)
        return self.proj_drop(self.proj(merged))
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a
    residual connection with optional stochastic depth.

    Args:
        dim (int): Token channel size.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Hidden size of the MLP as a multiple of ``dim``.
        qkv_bias (bool): Forwarded to ``Attention``.
        qk_scale (float | None): Forwarded to ``Attention``.
        drop (float): Dropout used for the attention projection and the MLP.
        attn_drop (float): Forwarded to ``Attention``.
        drop_path (float): Stochastic depth rate; no ``DropPath`` module is
            created unless it is positive.
        act_layer (callable): Activation factory for the MLP.
        norm_layer (callable): Factory of the two normalisation layers.
        attn_head_dim (int | None): Forwarded to ``Attention``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, attn_head_dim=None
                 ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            attn_head_dim=attn_head_dim)

        # NOTE: drop path for stochastic depth, we shall see if this is
        # better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)

    def forward(self, x):
        """Apply the attention and MLP residual branches in sequence."""
        attended = self.attn(self.norm1(x))
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x))
        x = x + self.drop_path(transformed)
        return x
class PatchEmbed(nn.Module):
    """Image to patch embedding via a single strided convolution.

    Args:
        img_size (int | tuple): Input size, expanded with ``to_2tuple``.
        patch_size (int | tuple): Patch size, expanded with ``to_2tuple``.
        in_chans (int): Number of input channels.
        embed_dim (int): Channel size of the produced tokens.
        ratio (int): Divides the conv stride and scales the stored patch
            grid.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        # Patch grid of the image before ``ratio`` is applied.
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]

        self.patch_shape = (int(grid_h * ratio), int(grid_w * ratio))
        self.origin_patch_shape = (int(grid_h), int(grid_w))
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_w * grid_h * (ratio ** 2)

        conv_stride = patch_size[0] // ratio
        conv_padding = 4 + 2 * (ratio // 2 - 1)
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=conv_stride,
            padding=conv_padding)

    def forward(self, x, **kwargs):
        """Return ``(tokens, (Hp, Wp))``: tokens of shape ``(B, Hp*Wp, C)``
        and the spatial size of the conv output they were flattened from."""
        feature_map = self.proj(x)
        grid = (feature_map.shape[2], feature_map.shape[3])
        tokens = feature_map.flatten(2).transpose(1, 2)
        return tokens, grid
class HybridEmbed(nn.Module):
    """CNN feature map embedding.

    Takes the last output of a CNN backbone, flattens its spatial grid and
    projects every position to the embedding dimension.

    Args:
        backbone (nn.Module): Feature extractor; its output is indexed with
            ``[-1]``.
        img_size (int | tuple): Input size, expanded with ``to_2tuple``.
        feature_size (int | tuple | None): Spatial size of the backbone
            output. When None it is measured by running the backbone once on
            a zero image.
        in_chans (int): Channels of the zero image used for that probe.
        embed_dim (int): Output embedding size.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone

        if feature_size is not None:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        else:
            with torch.no_grad():
                # Probe in eval mode, then restore the backbone's mode.
                was_training = backbone.training
                if was_training:
                    backbone.eval()
                dummy = torch.zeros(1, in_chans, img_size[0], img_size[1])
                probe = self.backbone(dummy)[-1]
                feature_size = probe.shape[-2:]
                feature_dim = probe.shape[1]
                backbone.train(was_training)

        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        """Return the projected tokens of the last backbone feature map."""
        feature_map = self.backbone(x)[-1]
        tokens = feature_map.flatten(2).transpose(1, 2)
        return self.proj(tokens)
# @BACKBONES.register_module()
class ViT(nn.Module):
def __init__(self,
img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
frozen_stages=-1, ratio=1, last_norm=True,
patch_padding='pad', freeze_attn=False, freeze_ffn=False,
):
# Protect mutable default arguments
super(ViT, self).__init__()
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.frozen_stages = frozen_stages
self.use_checkpoint = use_checkpoint
self.patch_padding = patch_padding
self.freeze_attn = freeze_attn
self.freeze_ffn = freeze_ffn
self.depth = depth
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
num_patches = self.patch_embed.num_patches
# since the pretraining model has class token
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
)
for i in range(depth)])
self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
self._freeze_stages()
def _freeze_stages(self):
"""Freeze parameters."""
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
m = self.blocks[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
if self.freeze_attn:
for i in range(0, self.depth):
m = self.blocks[i]
m.attn.eval()
m.norm1.eval()
for param in m.attn.parameters():
param.requires_grad = False
for param in m.norm1.parameters():
param.requires_grad = False
if self.freeze_ffn:
self.pos_embed.requires_grad = False
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(0, self.depth):
m = self.blocks[i]
m.mlp.eval()
m.norm2.eval()
for param in m.mlp.parameters():
param.requires_grad = False
for param in m.norm2.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
    """Initialise the backbone weights.

    The parent class is always asked to handle ``pretrained`` first (it also
    receives ``self.patch_padding``). Only when no checkpoint path is given
    are the ``nn.Linear`` and ``nn.LayerNorm`` sub-modules re-initialised
    here.

    Args:
        pretrained (str, optional): Path to pre-trained weights.
            Defaults to None.
    """
    super().init_weights(pretrained, patch_padding=self.patch_padding)

    if pretrained is not None:
        return

    def _reset(module):
        if isinstance(module, nn.Linear):
            # Truncated-normal weights, zero bias (when a bias exists).
            trunc_normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
            return
        if isinstance(module, nn.LayerNorm):
            # Identity affine transform: bias 0, scale 1.
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    self.apply(_reset)
def get_num_layers(self):
    """Return the number of entries held in ``self.blocks``."""
    block_count = len(self.blocks)
    return block_count
@torch.jit.ignore
def no_weight_decay(self):
    """Return the fixed set of names ``{'pos_embed', 'cls_token'}``.

    Presumably consumed by an optimizer builder to exempt these parameters
    from weight decay -- the consumer is not visible here, confirm there.
    """
    exempt = set()
    exempt.update(('pos_embed', 'cls_token'))
    return exempt
def forward_features(self, x):
    """Embed ``x`` into patch tokens, run every block, and return a feature map.

    Args:
        x (Tensor): 4-D input; only its first dimension (batch size) is
            used directly, the rest is handled by ``self.patch_embed``.

    Returns:
        Tensor: contiguous tensor of shape ``(B, -1, Hp, Wp)`` where
        ``(Hp, Wp)`` is the patch grid reported by ``self.patch_embed``.
    """
    batch, _, _, _ = x.shape
    tokens, (grid_h, grid_w) = self.patch_embed(x)

    if self.pos_embed is not None:
        # The embedding table keeps a leading slot (index 0) next to the
        # per-patch entries. That slot is broadcast-added to every token;
        # the original authors note it is zero for sin-cos embeddings, so it
        # changes nothing there, and that it is kept for multi-GPU training.
        patch_pos = self.pos_embed[:, 1:]
        extra_pos = self.pos_embed[:, :1]
        tokens = tokens + patch_pos + extra_pos

    for layer in self.blocks:
        # Optionally trade compute for memory via activation checkpointing.
        tokens = checkpoint.checkpoint(layer, tokens) if self.use_checkpoint else layer(tokens)

    tokens = self.last_norm(tokens)

    # Move the channel axis in front of the token axis, then unfold the
    # token axis back into the 2-D patch grid.
    feature_map = tokens.permute(0, 2, 1).reshape(batch, -1, grid_h, grid_w)
    return feature_map.contiguous()
def forward(self, x):
    """Run the backbone on ``x``; a thin wrapper around ``forward_features``."""
    return self.forward_features(x)
def train(self, mode=True):
    """Set training/eval mode, then re-apply the configured freezing.

    The parent ``train`` switches every sub-module to ``mode``; calling
    ``_freeze_stages`` afterwards puts the frozen parts back into eval mode
    so they stay frozen while the rest of the model trains.

    Args:
        mode (bool): ``True`` for training mode, ``False`` for eval mode.

    Returns:
        The module itself, so ``model.train()`` / ``model.eval()`` can be
        chained. The original override returned ``None``. This assumes the
        base class follows the ``torch.nn.Module.train`` contract of
        returning ``self`` (the base class is not visible here).
    """
    super().train(mode)
    self._freeze_stages()
    return self
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_base_coco_256x192.py
================================================
_base_ = [
'../../../../_base_/default_runtime.py',
'../../../../_base_/datasets/coco.py'
]
evaluation = dict(interval=10, metric='mAP', save_best='AP')
# AdamW settings handed to the 'LayerDecayOptimizerConstructor' constructor.
optimizer = dict(
    type='AdamW',
    lr=5e-4,
    betas=(0.9, 0.999),
    weight_decay=0.1,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(
        num_layers=12,
        layer_decay_rate=0.75,
        custom_keys={
            # Fixed: this key was spelled `decay_multi`, unlike every other
            # entry in this table, which uses `decay_mult`.
            'bias': dict(decay_mult=0.),
            'pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
        },
    ),
)
optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[170, 200])
total_epochs = 210
target_type = 'GaussianHeatmap'
channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])
# model settings
model = dict(
type='TopDown',
pretrained=None,
backbone=dict(
type='ViT',
img_size=(256, 192),
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
ratio=1,
use_checkpoint=False,
mlp_ratio=4,
qkv_bias=True,
drop_path_rate=0.3,
),
keypoint_head=dict(
type='TopdownHeatmapSimpleHead',
in_channels=768,
num_deconv_layers=2,
num_deconv_filters=(256, 256),
num_deconv_kernels=(4, 4),
extra=dict(final_conv_kernel=1, ),
out_channels=channel_cfg['num_output_channels'],
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
train_cfg=dict(),
test_cfg=dict())
data_cfg = dict(
image_size=[192, 256],
heatmap_size=[48, 64],
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
soft_nms=False,
nms_thr=1.0,
oks_thr=0.9,
vis_thr=0.2,
use_gt_bbox=False,
det_bbox_thr=0.0,
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(
type='TopDownHalfBodyTransform',
num_joints_half_body=8,
prob_half_body=0.3),
dict(
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='TopDownGenerateTarget',
sigma=2,
encoding='UDP',
target_type=target_type),
dict(
type='Collect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'bbox_score', 'flip_pairs'
]),
]
val_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
'flip_pairs'
]),
]
test_pipeline = val_pipeline
data_root = 'data/coco'
# data = dict(
# samples_per_gpu=64,
# workers_per_gpu=4,
# val_dataloader=dict(samples_per_gpu=32),
# test_dataloader=dict(samples_per_gpu=32),
# train=dict(
# type='TopDownCocoDataset',
# ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
# img_prefix=f'{data_root}/train2017/',
# data_cfg=data_cfg,
# pipeline=train_pipeline,
# dataset_info={{_base_.dataset_info}}),
# val=dict(
# type='TopDownCocoDataset',
# ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
# img_prefix=f'{data_root}/val2017/',
# data_cfg=data_cfg,
# pipeline=val_pipeline,
# dataset_info={{_base_.dataset_info}}),
# test=dict(
# type='TopDownCocoDataset',
# ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
# img_prefix=f'{data_root}/val2017/',
# data_cfg=data_cfg,
# pipeline=test_pipeline,
# dataset_info={{_base_.dataset_info}}),
# )
def make_cfg(model=model, data_cfg=data_cfg):
    """Bundle the model and data settings into a single config dict.

    Args:
        model (dict): Model settings; defaults to this module's ``model``.
        data_cfg (dict): Data settings; defaults to this module's
            ``data_cfg``.

    Returns:
        dict: ``{'model': model, 'data_cfg': data_cfg}``. The values are the
        same objects that were passed in, not copies.
    """
    cfg = {}
    cfg['model'] = model
    cfg['data_cfg'] = data_cfg
    # Fixed: the dict was built but never returned, so callers received None.
    return cfg
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_base_simple_coco_256x192.py
================================================
_base_ = [
'../../../../_base_/default_runtime.py',
'../../../../_base_/datasets/coco.py'
]
evaluation = dict(interval=10, metric='mAP', save_best='AP')
# AdamW settings handed to the 'LayerDecayOptimizerConstructor' constructor.
optimizer = dict(
    type='AdamW',
    lr=5e-4,
    betas=(0.9, 0.999),
    weight_decay=0.1,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(
        num_layers=12,
        layer_decay_rate=0.75,
        custom_keys={
            # Fixed: this key was spelled `decay_multi`, unlike every other
            # entry in this table, which uses `decay_mult`.
            'bias': dict(decay_mult=0.),
            'pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
        },
    ),
)
optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[170, 200])
total_epochs = 210
target_type = 'GaussianHeatmap'
channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])
# model settings
model = dict(
type='TopDown',
pretrained=None,
backbone=dict(
type='ViT',
img_size=(256, 192),
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
ratio=1,
use_checkpoint=False,
mlp_ratio=4,
qkv_bias=True,
drop_path_rate=0.3,
),
keypoint_head=dict(
type='TopdownHeatmapSimpleHead',
in_channels=768,
num_deconv_layers=0,
num_deconv_filters=[],
num_deconv_kernels=[],
upsample=4,
extra=dict(final_conv_kernel=3, ),
out_channels=channel_cfg['num_output_channels'],
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
post_process='default',
shift_heatmap=False,
target_type=target_type,
modulate_kernel=11,
use_udp=True))
data_cfg = dict(
image_size=[192, 256],
heatmap_size=[48, 64],
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
soft_nms=False,
nms_thr=1.0,
oks_thr=0.9,
vis_thr=0.2,
use_gt_bbox=False,
det_bbox_thr=0.0,
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(
type='TopDownHalfBodyTransform',
num_joints_half_body=8,
prob_half_body=0.3),
dict(
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='TopDownGenerateTarget',
sigma=2,
encoding='UDP',
target_type=target_type),
dict(
type='Collect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'bbox_score', 'flip_pairs'
]),
]
val_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
'flip_pairs'
]),
]
test_pipeline = val_pipeline
data_root = 'data/coco'
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=32),
test_dataloader=dict(samples_per_gpu=32),
train=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
img_prefix=f'{data_root}/train2017/',
data_cfg=data_cfg,
pipeline=train_pipeline,
dataset_info={{_base_.dataset_info}}),
val=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=val_pipeline,
dataset_info={{_base_.dataset_info}}),
test=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=test_pipeline,
dataset_info={{_base_.dataset_info}}),
)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_huge_coco_256x192.py
================================================
_base_ = [
'../../../../_base_/default_runtime.py',
'../../../../_base_/datasets/coco.py'
]
evaluation = dict(interval=10, metric='mAP', save_best='AP')
# AdamW settings handed to the 'LayerDecayOptimizerConstructor' constructor.
optimizer = dict(
    type='AdamW',
    lr=5e-4,
    betas=(0.9, 0.999),
    weight_decay=0.1,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(
        num_layers=32,
        layer_decay_rate=0.85,
        custom_keys={
            # Fixed: this key was spelled `decay_multi`, unlike every other
            # entry in this table, which uses `decay_mult`.
            'bias': dict(decay_mult=0.),
            'pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
        },
    ),
)
optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[170, 200])
total_epochs = 210
target_type = 'GaussianHeatmap'
channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])
# model settings
model = dict(
type='TopDown',
pretrained=None,
backbone=dict(
type='ViT',
img_size=(256, 192),
patch_size=16,
embed_dim=1280,
depth=32,
num_heads=16,
ratio=1,
use_checkpoint=False,
mlp_ratio=4,
qkv_bias=True,
drop_path_rate=0.55,
),
keypoint_head=dict(
type='TopdownHeatmapSimpleHead',
in_channels=1280,
num_deconv_layers=2,
num_deconv_filters=(256, 256),
num_deconv_kernels=(4, 4),
extra=dict(final_conv_kernel=1, ),
out_channels=channel_cfg['num_output_channels'],
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
post_process='default',
shift_heatmap=False,
target_type=target_type,
modulate_kernel=11,
use_udp=True))
data_cfg = dict(
image_size=[192, 256],
heatmap_size=[48, 64],
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
soft_nms=False,
nms_thr=1.0,
oks_thr=0.9,
vis_thr=0.2,
use_gt_bbox=False,
det_bbox_thr=0.0,
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(
type='TopDownHalfBodyTransform',
num_joints_half_body=8,
prob_half_body=0.3),
dict(
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='TopDownGenerateTarget',
sigma=2,
encoding='UDP',
target_type=target_type),
dict(
type='Collect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'bbox_score', 'flip_pairs'
]),
]
val_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
'flip_pairs'
]),
]
test_pipeline = val_pipeline
data_root = 'data/coco'
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=32),
test_dataloader=dict(samples_per_gpu=32),
train=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
img_prefix=f'{data_root}/train2017/',
data_cfg=data_cfg,
pipeline=train_pipeline,
dataset_info={{_base_.dataset_info}}),
val=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=val_pipeline,
dataset_info={{_base_.dataset_info}}),
test=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=test_pipeline,
dataset_info={{_base_.dataset_info}}),
)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_huge_simple_coco_256x192.py
================================================
_base_ = [
'../../../../_base_/default_runtime.py',
'../../../../_base_/datasets/coco.py'
]
evaluation = dict(interval=10, metric='mAP', save_best='AP')
# AdamW settings handed to the 'LayerDecayOptimizerConstructor' constructor.
optimizer = dict(
    type='AdamW',
    lr=5e-4,
    betas=(0.9, 0.999),
    weight_decay=0.1,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(
        num_layers=32,
        layer_decay_rate=0.85,
        custom_keys={
            # Fixed: this key was spelled `decay_multi`, unlike every other
            # entry in this table, which uses `decay_mult`.
            'bias': dict(decay_mult=0.),
            'pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
        },
    ),
)
optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[170, 200])
total_epochs = 210
target_type = 'GaussianHeatmap'
channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])
# model settings
model = dict(
type='TopDown',
pretrained=None,
backbone=dict(
type='ViT',
img_size=(256, 192),
patch_size=16,
embed_dim=1280,
depth=32,
num_heads=16,
ratio=1,
use_checkpoint=False,
mlp_ratio=4,
qkv_bias=True,
drop_path_rate=0.55,
),
keypoint_head=dict(
type='TopdownHeatmapSimpleHead',
in_channels=1280,
num_deconv_layers=0,
num_deconv_filters=[],
num_deconv_kernels=[],
upsample=4,
extra=dict(final_conv_kernel=3, ),
out_channels=channel_cfg['num_output_channels'],
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
post_process='default',
shift_heatmap=False,
target_type=target_type,
modulate_kernel=11,
use_udp=True))
data_cfg = dict(
image_size=[192, 256],
heatmap_size=[48, 64],
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
soft_nms=False,
nms_thr=1.0,
oks_thr=0.9,
vis_thr=0.2,
use_gt_bbox=False,
det_bbox_thr=0.0,
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(
type='TopDownHalfBodyTransform',
num_joints_half_body=8,
prob_half_body=0.3),
dict(
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='TopDownGenerateTarget',
sigma=2,
encoding='UDP',
target_type=target_type),
dict(
type='Collect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'bbox_score', 'flip_pairs'
]),
]
val_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
'flip_pairs'
]),
]
test_pipeline = val_pipeline
data_root = 'data/coco'
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=32),
test_dataloader=dict(samples_per_gpu=32),
train=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
img_prefix=f'{data_root}/train2017/',
data_cfg=data_cfg,
pipeline=train_pipeline,
dataset_info={{_base_.dataset_info}}),
val=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=val_pipeline,
dataset_info={{_base_.dataset_info}}),
test=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=test_pipeline,
dataset_info={{_base_.dataset_info}}),
)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_large_coco_256x192.py
================================================
_base_ = [
'../../../../_base_/default_runtime.py',
'../../../../_base_/datasets/coco.py'
]
evaluation = dict(interval=10, metric='mAP', save_best='AP')
# AdamW settings handed to the 'LayerDecayOptimizerConstructor' constructor.
optimizer = dict(
    type='AdamW',
    lr=5e-4,
    betas=(0.9, 0.999),
    weight_decay=0.1,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(
        # Fixed: was 16, which disagreed with the backbone `depth=24`
        # declared in this file's model settings.
        # NOTE(review): if a released checkpoint must be reproduced exactly,
        # confirm which value it was trained with.
        num_layers=24,
        layer_decay_rate=0.8,
        custom_keys={
            # Fixed: this key was spelled `decay_multi`, unlike every other
            # entry in this table, which uses `decay_mult`.
            'bias': dict(decay_mult=0.),
            'pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
        },
    ),
)
optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[170, 200])
total_epochs = 210
target_type = 'GaussianHeatmap'
channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])
# model settings
model = dict(
type='TopDown',
pretrained=None,
backbone=dict(
type='ViT',
img_size=(256, 192),
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
ratio=1,
use_checkpoint=False,
mlp_ratio=4,
qkv_bias=True,
drop_path_rate=0.5,
),
keypoint_head=dict(
type='TopdownHeatmapSimpleHead',
in_channels=1024,
num_deconv_layers=2,
num_deconv_filters=(256, 256),
num_deconv_kernels=(4, 4),
extra=dict(final_conv_kernel=1, ),
out_channels=channel_cfg['num_output_channels'],
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
post_process='default',
shift_heatmap=False,
target_type=target_type,
modulate_kernel=11,
use_udp=True))
data_cfg = dict(
image_size=[192, 256],
heatmap_size=[48, 64],
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
soft_nms=False,
nms_thr=1.0,
oks_thr=0.9,
vis_thr=0.2,
use_gt_bbox=False,
det_bbox_thr=0.0,
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(
type='TopDownHalfBodyTransform',
num_joints_half_body=8,
prob_half_body=0.3),
dict(
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='TopDownGenerateTarget',
sigma=2,
encoding='UDP',
target_type=target_type),
dict(
type='Collect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'bbox_score', 'flip_pairs'
]),
]
val_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
'flip_pairs'
]),
]
test_pipeline = val_pipeline
data_root = 'data/coco'
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=32),
test_dataloader=dict(samples_per_gpu=32),
train=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
img_prefix=f'{data_root}/train2017/',
data_cfg=data_cfg,
pipeline=train_pipeline,
dataset_info={{_base_.dataset_info}}),
val=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=val_pipeline,
dataset_info={{_base_.dataset_info}}),
test=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=test_pipeline,
dataset_info={{_base_.dataset_info}}),
)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/ViTPose_large_simple_coco_256x192.py
================================================
_base_ = [
'../../../../_base_/default_runtime.py',
'../../../../_base_/datasets/coco.py'
]
evaluation = dict(interval=10, metric='mAP', save_best='AP')
# AdamW settings handed to the 'LayerDecayOptimizerConstructor' constructor.
optimizer = dict(
    type='AdamW',
    lr=5e-4,
    betas=(0.9, 0.999),
    weight_decay=0.1,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(
        num_layers=24,
        layer_decay_rate=0.8,
        custom_keys={
            # Fixed: this key was spelled `decay_multi`, unlike every other
            # entry in this table, which uses `decay_mult`.
            'bias': dict(decay_mult=0.),
            'pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
        },
    ),
)
optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[170, 200])
total_epochs = 210
target_type = 'GaussianHeatmap'
channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])
# model settings
model = dict(
type='TopDown',
pretrained=None,
backbone=dict(
type='ViT',
img_size=(256, 192),
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
ratio=1,
use_checkpoint=False,
mlp_ratio=4,
qkv_bias=True,
drop_path_rate=0.5,
),
keypoint_head=dict(
type='TopdownHeatmapSimpleHead',
in_channels=1024,
num_deconv_layers=0,
num_deconv_filters=[],
num_deconv_kernels=[],
upsample=4,
extra=dict(final_conv_kernel=3, ),
out_channels=channel_cfg['num_output_channels'],
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
post_process='default',
shift_heatmap=False,
target_type=target_type,
modulate_kernel=11,
use_udp=True))
data_cfg = dict(
image_size=[192, 256],
heatmap_size=[48, 64],
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
soft_nms=False,
nms_thr=1.0,
oks_thr=0.9,
vis_thr=0.2,
use_gt_bbox=False,
det_bbox_thr=0.0,
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(
type='TopDownHalfBodyTransform',
num_joints_half_body=8,
prob_half_body=0.3),
dict(
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='TopDownGenerateTarget',
sigma=2,
encoding='UDP',
target_type=target_type),
dict(
type='Collect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'bbox_score', 'flip_pairs'
]),
]
val_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownAffine', use_udp=True),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
'flip_pairs'
]),
]
test_pipeline = val_pipeline
data_root = 'data/coco'
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=32),
test_dataloader=dict(samples_per_gpu=32),
train=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
img_prefix=f'{data_root}/train2017/',
data_cfg=data_cfg,
pipeline=train_pipeline,
dataset_info={{_base_.dataset_info}}),
val=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=val_pipeline,
dataset_info={{_base_.dataset_info}}),
test=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=test_pipeline,
dataset_info={{_base_.dataset_info}}),
)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/configs/coco/__init__.py
================================================
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/__init__.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# from .ae_higher_resolution_head import AEHigherResolutionHead
# from .ae_multi_stage_head import AEMultiStageHead
# from .ae_simple_head import AESimpleHead
# from .deconv_head import DeconvHead
# from .deeppose_regression_head import DeepposeRegressionHead
# from .hmr_head import HMRMeshHead
# from .interhand_3d_head import Interhand3DHead
# from .temporal_regression_head import TemporalRegressionHead
from .topdown_heatmap_base_head import TopdownHeatmapBaseHead
# from .topdown_heatmap_multi_stage_head import (TopdownHeatmapMSMUHead,
# TopdownHeatmapMultiStageHead)
from .topdown_heatmap_simple_head import TopdownHeatmapSimpleHead
# from .vipnas_heatmap_simple_head import ViPNASHeatmapSimpleHead
# from .voxelpose_head import CuboidCenterHead, CuboidPoseHead
# __all__ = [
# 'TopdownHeatmapSimpleHead', 'TopdownHeatmapMultiStageHead',
# 'TopdownHeatmapMSMUHead', 'TopdownHeatmapBaseHead',
# 'AEHigherResolutionHead', 'AESimpleHead', 'AEMultiStageHead',
# 'DeepposeRegressionHead', 'TemporalRegressionHead', 'Interhand3DHead',
# 'HMRMeshHead', 'DeconvHead', 'ViPNASHeatmapSimpleHead', 'CuboidCenterHead',
# 'CuboidPoseHead'
# ]
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/deconv_head.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
constant_init, normal_init)
from mmpose.models.builder import HEADS, build_loss
from mmpose.models.utils.ops import resize
@HEADS.register_module()
class DeconvHead(nn.Module):
"""Simple deconv head.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
num_deconv_layers (int): Number of deconv layers.
num_deconv_layers should >= 0. Note that 0 means
no deconv layers.
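num_deconv_filters (list|tuple): Number of filters.
If num_deconv_layers > 0, this must be non-empty: its last entry
is used as the input width of the final layer. Presumably one
entry per deconv layer (confirm in _make_deconv_layer).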
num_deconv_kernels (list|tuple): Kernel sizes.
in_index (int|Sequence[int]): Input feature index. Default: 0
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
Default: None.
- 'resize_concat': Multiple feature maps will be resized to the
same size as the first one and then concat together.
Usually used in FCN head of HRNet.
- 'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
- None: Only one select feature map is allowed.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
loss_keypoint (dict): Config for loss. Default: None.
"""
def __init__(self,
             in_channels=3,
             out_channels=17,
             num_deconv_layers=3,
             num_deconv_filters=(256, 256, 256),
             num_deconv_kernels=(4, 4, 4),
             extra=None,
             in_index=0,
             input_transform=None,
             align_corners=False,
             loss_keypoint=None):
    """Build the loss, the deconv stack, and the final prediction layer.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        out_channels (int): Output channels of the final conv layer.
        num_deconv_layers (int): Number of deconv layers; 0 installs an
            ``nn.Identity`` in place of the deconv stack.
        num_deconv_filters (list|tuple): Filter counts passed on to
            ``_make_deconv_layer``; the last entry is the width fed to
            the final layer when ``num_deconv_layers > 0``.
        num_deconv_kernels (list|tuple): Kernel sizes passed on to
            ``_make_deconv_layer``.
        extra (dict|None): Optional settings. Keys read here are
            ``final_conv_kernel`` (0, 1 or 3; 0 means identity),
            ``num_conv_layers`` and ``num_conv_kernels``.
        in_index (int|Sequence[int]): Input feature index.
        input_transform (str|None): See ``_init_inputs``.
        align_corners (bool): Stored on the instance as ``align_corners``.
        loss_keypoint (dict): Loss config handed to ``build_loss``.

    Raises:
        TypeError: If ``extra`` is neither ``None`` nor a dict.
        ValueError: If ``num_deconv_layers`` is neither positive nor zero.
    """
    super().__init__()

    self.in_channels = in_channels
    self.loss = build_loss(loss_keypoint)

    # Validates the (in_channels, in_index, input_transform) combination
    # and may overwrite self.in_channels (e.g. with a sum for
    # 'resize_concat').
    self._init_inputs(in_channels, in_index, input_transform)
    self.in_index = in_index
    self.align_corners = align_corners

    if not (extra is None or isinstance(extra, dict)):
        raise TypeError('extra should be dict or None.')

    if num_deconv_layers > 0:
        self.deconv_layers = self._make_deconv_layer(
            num_deconv_layers, num_deconv_filters, num_deconv_kernels)
    elif num_deconv_layers == 0:
        self.deconv_layers = nn.Identity()
    else:
        raise ValueError(
            f'num_deconv_layers ({num_deconv_layers}) should >= 0.')

    # Resolve kernel size / padding of the last conv, or decide to skip it.
    skip_final_conv = False
    if extra is None or 'final_conv_kernel' not in extra:
        kernel_size, padding = 1, 0
    else:
        kernel_size = extra['final_conv_kernel']
        assert kernel_size in [0, 1, 3]
        if kernel_size == 3:
            padding = 1
        elif kernel_size == 1:
            padding = 0
        else:
            # 0 for Identity mapping.
            skip_final_conv = True

    if skip_final_conv:
        self.final_layer = nn.Identity()
        return

    # Width entering the final layer: the last deconv filter count, or the
    # (possibly transformed) input channels when there are no deconvs.
    head_width = (num_deconv_filters[-1]
                  if num_deconv_layers > 0 else self.in_channels)

    stack = []
    if extra is not None:
        n_extra = extra.get('num_conv_layers', 0)
        extra_kernels = extra.get('num_conv_kernels', [1] * n_extra)
        for idx in range(n_extra):
            k = extra_kernels[idx]
            # Each optional stage is Conv2d -> BN -> ReLU with padding
            # (k - 1) // 2.
            stack.extend([
                build_conv_layer(
                    dict(type='Conv2d'),
                    in_channels=head_width,
                    out_channels=head_width,
                    kernel_size=k,
                    stride=1,
                    padding=(k - 1) // 2),
                build_norm_layer(dict(type='BN'), head_width)[1],
                nn.ReLU(inplace=True),
            ])

    stack.append(
        build_conv_layer(
            cfg=dict(type='Conv2d'),
            in_channels=head_width,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding))

    # A lone conv is stored directly; several layers are wrapped in a
    # Sequential.
    self.final_layer = nn.Sequential(*stack) if len(stack) > 1 else stack[0]
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform is not None, in_channels and in_index must be
list or tuple, with the same length.
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
- 'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
- 'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
- None: Only one select feature map is allowed.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor] | Tensor): multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if not isinstance(inputs, list):
return inputs
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(
input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
"""Make deconv layers."""
if num_layers != len(num_filters):
error_msg = f'num_layers({num_layers}) ' \
f'!= length of num_filters({len(num_filters)})'
raise ValueError(error_msg)
if num_layers != len(num_kernels):
error_msg = f'num_layers({num_layers}) ' \
f'!= length of num_kernels({len(num_kernels)})'
raise ValueError(error_msg)
layers = []
for i in range(num_layers):
kernel, padding, output_padding = \
self._get_deconv_cfg(num_kernels[i])
planes = num_filters[i]
layers.append(
build_upsample_layer(
dict(type='deconv'),
in_channels=self.in_channels,
out_channels=planes,
kernel_size=kernel,
stride=2,
padding=padding,
output_padding=output_padding,
bias=False))
layers.append(nn.BatchNorm2d(planes))
layers.append(nn.ReLU(inplace=True))
self.in_channels = planes
return nn.Sequential(*layers)
@staticmethod
def _get_deconv_cfg(deconv_kernel):
"""Get configurations for deconv layers."""
if deconv_kernel == 4:
padding = 1
output_padding = 0
elif deconv_kernel == 3:
padding = 1
output_padding = 1
elif deconv_kernel == 2:
padding = 0
output_padding = 0
else:
raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')
return deconv_kernel, padding, output_padding
def get_loss(self, outputs, targets, masks):
    """Accumulate the bottom-up masked mse loss over all scales.

    Note:
        - batch_size: N
        - num_channels: C
        - heatmaps height: H
        - heatmaps width: W

    Args:
        outputs (List(torch.Tensor[N,C,H,W])): Multi-scale outputs.
        targets (List(torch.Tensor[N,C,H,W])): Multi-scale targets.
        masks (List(torch.Tensor[N,H,W])): Masks of multi-scale targets.

    Returns:
        dict: ``{'loss': total}`` where ``total`` is the sum of
        ``self.loss`` over the scales; an empty dict when ``targets`` is
        empty.
    """
    losses = dict()
    # The number of scales is driven by ``targets``.
    for scale, target in enumerate(targets):
        scale_loss = self.loss(outputs[scale], target, masks[scale])
        if 'loss' in losses:
            losses['loss'] += scale_loss
        else:
            losses['loss'] = scale_loss
    return losses
def forward(self, x):
    """Run the head and return its output wrapped in a list.

    The input goes through ``_transform_inputs``, then the deconv layers,
    then the final layer.

    Args:
        x (list[Tensor] | Tensor): Input features.

    Returns:
        list: A single-element list holding the final layer's output.
    """
    features = self._transform_inputs(x)
    heatmaps = self.final_layer(self.deconv_layers(features))
    return [heatmaps]
def init_weights(self):
    """Initialize model weights.

    Transposed convolutions in ``self.deconv_layers`` and convolutions in
    ``self.final_layer`` get ``normal_init`` with std 0.001 (the latter
    with an explicit ``bias=0``); every ``BatchNorm2d`` in either part gets
    ``constant_init`` with value 1.
    """

    def _reset(modules, conv_type, **conv_kwargs):
        # Apply the conv / batch-norm initialisers to matching modules.
        for module in modules:
            if isinstance(module, conv_type):
                normal_init(module, std=0.001, **conv_kwargs)
            elif isinstance(module, nn.BatchNorm2d):
                constant_init(module, 1)

    _reset((module for _, module in self.deconv_layers.named_modules()),
           nn.ConvTranspose2d)
    _reset(self.final_layer.modules(), nn.Conv2d, bias=0)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/deeppose_regression_head.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn
from mmcv.cnn import normal_init
from mmpose.core.evaluation import (keypoint_pck_accuracy,
keypoints_from_regression)
from mmpose.core.post_processing import fliplr_regression
from mmpose.models.builder import HEADS, build_loss
@HEADS.register_module()
class DeepposeRegressionHead(nn.Module):
    """Deeppose regression head with fully connected layers.

    "DeepPose: Human Pose Estimation via Deep Neural Networks".

    Args:
        in_channels (int): Number of input channels
        num_joints (int): Number of joints
        loss_keypoint (dict): Config for keypoint loss. Default: None.
        train_cfg (dict|None): Stored as ``self.train_cfg``; ``None`` is
            replaced by an empty dict.
        test_cfg (dict|None): Stored as ``self.test_cfg``; ``None`` is
            replaced by an empty dict.
    """

    def __init__(self,
                 in_channels,
                 num_joints,
                 loss_keypoint=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()

        self.in_channels = in_channels
        self.num_joints = num_joints

        self.loss = build_loss(loss_keypoint)

        self.train_cfg = {} if train_cfg is None else train_cfg
        self.test_cfg = {} if test_cfg is None else test_cfg

        # One (x, y) pair per joint.
        self.fc = nn.Linear(self.in_channels, self.num_joints * 2)

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor[N, in_channels]): Input features.

        Returns:
            torch.Tensor[N, K, 2]: The linear layer's output regrouped into
            per-joint coordinate pairs.
        """
        output = self.fc(x)
        N, C = output.shape
        return output.reshape([N, C // 2, 2])

    def get_loss(self, output, target, target_weight):
        """Calculate top-down keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K

        Args:
            output (torch.Tensor[N, K, 2]): Output keypoints.
            target (torch.Tensor[N, K, 2]): Target keypoints.
            target_weight (torch.Tensor[N, K, 2]):
                Weights across different joint types.

        Returns:
            dict: ``{'reg_loss': loss}``.
        """
        losses = dict()
        assert not isinstance(self.loss, nn.Sequential)
        assert target.dim() == 3 and target_weight.dim() == 3

        losses['reg_loss'] = self.loss(output, target, target_weight)

        return losses

    def get_accuracy(self, output, target, target_weight):
        """Calculate accuracy for top-down keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K

        Args:
            output (torch.Tensor[N, K, 2]): Output keypoints.
            target (torch.Tensor[N, K, 2]): Target keypoints.
            target_weight (torch.Tensor[N, K, 2]):
                Weights across different joint types.

        Returns:
            dict: ``{'acc_pose': avg_acc}`` as reported by
            ``keypoint_pck_accuracy`` with ``thr=0.05``.
        """
        accuracy = dict()

        N = output.shape[0]

        # A joint counts as valid when its first weight component is > 0;
        # the normalisation factor is 1 for every sample.
        _, avg_acc, _ = keypoint_pck_accuracy(
            output.detach().cpu().numpy(),
            target.detach().cpu().numpy(),
            target_weight[:, :, 0].detach().cpu().numpy() > 0,
            thr=0.05,
            normalize=np.ones((N, 2), dtype=np.float32))
        accuracy['acc_pose'] = avg_acc

        return accuracy

    def inference_model(self, x, flip_pairs=None):
        """Inference function.

        Returns:
            output_regression (np.ndarray): Output regression.

        Args:
            x (torch.Tensor[N, in_channels]): Input features.
            flip_pairs (None | list[tuple()):
                Pairs of keypoints which are mirrored. When given, the
                prediction is passed through ``fliplr_regression``.
        """
        output = self.forward(x)

        if flip_pairs is not None:
            output_regression = fliplr_regression(
                output.detach().cpu().numpy(), flip_pairs)
        else:
            output_regression = output.detach().cpu().numpy()
        return output_regression

    def decode(self, img_metas, output, **kwargs):
        """Decode the keypoints from output regression.

        Args:
            img_metas (list(dict)): Information about data augmentation
                By default this includes:

                - "image_file: path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
            output (np.ndarray[N, K, 2]): predicted regression vector.
            kwargs: dict contains 'img_size'.
                img_size (tuple(img_width, img_height)): input image size.

        Returns:
            dict: ``preds`` (N, K, 3: coordinates plus score), ``boxes``
            (N, 6: center, scale, area, score), ``image_paths`` and
            ``bbox_ids`` (``None`` when the first meta has no 'bbox_id').
        """
        batch_size = len(img_metas)

        # bbox ids are only collected when the first sample provides one.
        bbox_ids = [] if 'bbox_id' in img_metas[0] else None

        c = np.zeros((batch_size, 2), dtype=np.float32)
        s = np.zeros((batch_size, 2), dtype=np.float32)
        image_paths = []
        score = np.ones(batch_size)
        for i in range(batch_size):
            c[i, :] = img_metas[i]['center']
            s[i, :] = img_metas[i]['scale']
            image_paths.append(img_metas[i]['image_file'])

            if 'bbox_score' in img_metas[i]:
                # Extract the single score as a scalar. Assigning a
                # size-1 array with ndim > 0 to one element is deprecated
                # in recent NumPy releases.
                score[i] = np.asarray(img_metas[i]['bbox_score']).item()
            if bbox_ids is not None:
                bbox_ids.append(img_metas[i]['bbox_id'])

        preds, maxvals = keypoints_from_regression(output, c, s,
                                                   kwargs['img_size'])

        all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32)
        all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
        all_preds[:, :, 0:2] = preds[:, :, 0:2]
        all_preds[:, :, 2:3] = maxvals
        all_boxes[:, 0:2] = c[:, 0:2]
        all_boxes[:, 2:4] = s[:, 0:2]
        # Product of the two scale components, each multiplied by 200.
        all_boxes[:, 4] = np.prod(s * 200.0, axis=1)
        all_boxes[:, 5] = score

        result = {}

        result['preds'] = all_preds
        result['boxes'] = all_boxes
        result['image_paths'] = image_paths
        result['bbox_ids'] = bbox_ids

        return result

    def init_weights(self):
        """Initialize the fully connected layer (normal, std 0.01, bias 0)."""
        normal_init(self.fc, mean=0, std=0.01, bias=0)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/hmr_head.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from ..builder import HEADS
from ..utils.geometry import rot6d_to_rotmat
@HEADS.register_module()
class HMRMeshHead(nn.Module):
    """SMPL parameters regressor head of simple baseline. "End-to-end Recovery
    of Human Shape and Pose", CVPR'2018.

    Args:
        in_channels (int): Number of input channels
        smpl_mean_params (str): The file name of the mean SMPL parameters.
            When ``None``, zero pose/shape and camera ``[1, 0, 0]`` are
            used as the starting estimate.
        n_iter (int): The iterations of estimating delta parameters
    """

    def __init__(self, in_channels, smpl_mean_params=None, n_iter=3):
        super().__init__()

        self.in_channels = in_channels
        self.n_iter = n_iter

        npose = 24 * 6
        nbeta = 10
        ncam = 3
        hidden_dim = 1024

        # The regressor sees the feature vector concatenated with the
        # current pose, shape and camera estimates.
        self.fc1 = nn.Linear(in_channels + npose + nbeta + ncam, hidden_dim)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(hidden_dim, npose)
        self.decshape = nn.Linear(hidden_dim, nbeta)
        self.deccam = nn.Linear(hidden_dim, ncam)

        # Load mean SMPL parameters
        if smpl_mean_params is None:
            init_pose = torch.zeros([1, npose])
            init_shape = torch.zeros([1, nbeta])
            init_cam = torch.FloatTensor([[1, 0, 0]])
        else:
            mean_params = np.load(smpl_mean_params)
            try:
                init_pose = torch.from_numpy(
                    mean_params['pose'][:]).unsqueeze(0).float()
                init_shape = torch.from_numpy(
                    mean_params['shape'][:]).unsqueeze(0).float()
                init_cam = torch.from_numpy(
                    mean_params['cam']).unsqueeze(0).float()
            finally:
                # For .npz archives np.load returns an object that keeps
                # the file open; close it once the arrays are read. Plain
                # arrays have no close() and are left alone.
                close = getattr(mean_params, 'close', None)
                if callable(close):
                    close()
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

    def forward(self, x):
        """Forward function.

        x is the image feature map and is expected to be in shape (batch size x
        channel number x height x width)

        Returns:
            tuple: ``(pred_rotmat, pred_shape, pred_cam)`` where
            ``pred_rotmat`` is viewed as (batch size, 24, 3, 3).
        """
        batch_size = x.shape[0]
        # extract the global feature vector by average along
        # spatial dimension.
        x = x.mean(dim=-1).mean(dim=-1)

        init_pose = self.init_pose.expand(batch_size, -1)
        init_shape = self.init_shape.expand(batch_size, -1)
        init_cam = self.init_cam.expand(batch_size, -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        # Each iteration predicts a residual that is added to the current
        # estimate.
        for _ in range(self.n_iter):
            xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)
        out = (pred_rotmat, pred_shape, pred_cam)
        return out

    def init_weights(self):
        """Initialize model weights.

        Only the three decoder layers are re-initialised (Xavier with gain
        0.01).
        """
        xavier_init(self.decpose, gain=0.01)
        xavier_init(self.decshape, gain=0.01)
        xavier_init(self.deccam, gain=0.01)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/interhand_3d_head.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
constant_init, normal_init)
from mmpose.core.evaluation.top_down_eval import (
keypoints_from_heatmaps3d, multilabel_classification_accuracy)
from mmpose.core.post_processing import flip_back
from mmpose.models.builder import build_loss
from mmpose.models.necks import GlobalAveragePooling
from ..builder import HEADS
class Heatmap3DHead(nn.Module):
    """Sub-module of Interhand3DHead that produces 3D heatmaps.

    It consists of zero or more deconv stages followed by a final layer
    (optional extra conv blocks plus one conv2d, or an identity). The
    channel axis of the result is split into ``depth_size`` depth bins.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels; must be divisible
            by ``depth_size``.
        depth_size (int): Number of depth discretization size
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should >= 0. Note that 0 means no deconv layers.
        num_deconv_filters (list|tuple): Number of filters.
        num_deconv_kernels (list|tuple): Kernel sizes.
        extra (dict): Configs for extra conv layers. Default: None.
            Keys read here: 'final_conv_kernel' (0, 1 or 3; 0 means an
            identity final layer), 'num_conv_layers', 'num_conv_kernels'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 depth_size=64,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 extra=None):
        super().__init__()

        assert out_channels % depth_size == 0
        self.depth_size = depth_size
        self.in_channels = in_channels

        if not (extra is None or isinstance(extra, dict)):
            raise TypeError('extra should be dict or None.')

        if num_deconv_layers > 0:
            self.deconv_layers = self._make_deconv_layer(
                num_deconv_layers, num_deconv_filters, num_deconv_kernels)
        elif num_deconv_layers == 0:
            self.deconv_layers = nn.Identity()
        else:
            raise ValueError(
                f'num_deconv_layers ({num_deconv_layers}) should >= 0.')

        # Kernel size of the final conv; 0 selects an identity mapping.
        kernel_size = 1
        if extra is not None and 'final_conv_kernel' in extra:
            kernel_size = extra['final_conv_kernel']
            assert kernel_size in [0, 1, 3]

        if kernel_size == 0:
            self.final_layer = nn.Identity()
            return

        padding = 1 if kernel_size == 3 else 0
        if num_deconv_layers > 0:
            conv_channels = num_deconv_filters[-1]
        else:
            conv_channels = self.in_channels

        blocks = []
        if extra is not None:
            num_conv_layers = extra.get('num_conv_layers', 0)
            num_conv_kernels = extra.get('num_conv_kernels',
                                         [1] * num_conv_layers)
            for idx in range(num_conv_layers):
                conv_kernel = num_conv_kernels[idx]
                blocks.append(
                    build_conv_layer(
                        dict(type='Conv2d'),
                        in_channels=conv_channels,
                        out_channels=conv_channels,
                        kernel_size=conv_kernel,
                        stride=1,
                        padding=(conv_kernel - 1) // 2))
                blocks.append(
                    build_norm_layer(dict(type='BN'), conv_channels)[1])
                blocks.append(nn.ReLU(inplace=True))

        blocks.append(
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=conv_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding))

        # A lone conv is stored directly rather than inside a Sequential.
        if len(blocks) == 1:
            self.final_layer = blocks[0]
        else:
            self.final_layer = nn.Sequential(*blocks)

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Build the deconv -> BatchNorm -> ReLU stages.

        ``self.in_channels`` is overwritten with each stage's output
        channel count.

        Raises:
            ValueError: If the filter or kernel list length differs from
                ``num_layers``, or a kernel size is unsupported.
        """
        if num_layers != len(num_filters):
            raise ValueError(
                f'num_layers({num_layers}) '
                f'!= length of num_filters({len(num_filters)})')
        if num_layers != len(num_kernels):
            raise ValueError(
                f'num_layers({num_layers}) '
                f'!= length of num_kernels({len(num_kernels)})')

        stages = []
        for out_planes, kernel_size in zip(num_filters, num_kernels):
            kernel, pad, out_pad = self._get_deconv_cfg(kernel_size)
            stages.extend([
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=self.in_channels,
                    out_channels=out_planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=pad,
                    output_padding=out_pad,
                    bias=False),
                nn.BatchNorm2d(out_planes),
                nn.ReLU(inplace=True),
            ])
            self.in_channels = out_planes

        return nn.Sequential(*stages)

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Return ``(kernel, padding, output_padding)`` for a kernel size.

        Raises:
            ValueError: If the kernel size is not 4, 3 or 2.
        """
        supported = ((4, 1, 0), (3, 1, 1), (2, 0, 0))
        for size, padding, output_padding in supported:
            if deconv_kernel == size:
                return deconv_kernel, padding, output_padding
        raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

    def forward(self, x):
        """Forward function.

        Returns:
            Tensor: The final layer's (N, C, H, W) output reshaped to
            (N, C // depth_size, depth_size, H, W).
        """
        heatmap = self.final_layer(self.deconv_layers(x))
        batch, channels, height, width = heatmap.shape
        # reshape the 2D heatmap to 3D heatmap
        return heatmap.reshape(batch, channels // self.depth_size,
                               self.depth_size, height, width)

    def init_weights(self):
        """Initialize model weights.

        Transposed convs of the deconv stack and convs of the final layer
        use ``normal_init`` (std 0.001); all ``BatchNorm2d`` layers use
        ``constant_init`` with 1.
        """

        def _reset(modules, conv_type, **conv_kwargs):
            for module in modules:
                if isinstance(module, conv_type):
                    normal_init(module, std=0.001, **conv_kwargs)
                elif isinstance(module, nn.BatchNorm2d):
                    constant_init(module, 1)

        _reset((module for _, module in self.deconv_layers.named_modules()),
               nn.ConvTranspose2d)
        _reset(self.final_layer.modules(), nn.Conv2d, bias=0)
class Heatmap1DHead(nn.Module):
    """Sub-module of Interhand3DHead that outputs 1D heatmaps.

    A small stack of fully connected layers produces ``heatmap_size``
    scores per sample; the forward pass reduces them to one soft-argmax
    coordinate.

    Args:
        in_channels (int): Number of input channels
        heatmap_size (int): Heatmap size
        hidden_dims (list|tuple): Number of feature dimension of FC layers.
    """

    def __init__(self, in_channels=2048, heatmap_size=64, hidden_dims=(512, )):
        super().__init__()

        self.in_channels = in_channels
        self.heatmap_size = heatmap_size

        self.fc = self._make_linear_layers(
            [in_channels, *hidden_dims, heatmap_size], relu_final=False)

    def soft_argmax_1d(self, heatmap1d):
        """Return the softmax-weighted mean bin index along dim 1."""
        probs = F.softmax(heatmap1d, 1)
        bin_index = torch.arange(
            self.heatmap_size, dtype=probs.dtype, device=probs.device)
        weighted = probs * bin_index[None, :]
        return weighted.sum(dim=1)

    def _make_linear_layers(self, feat_dims, relu_final=False):
        """Make linear layers.

        A ReLU follows every linear layer except the last one, which only
        gets one when ``relu_final`` is set.
        """
        last = len(feat_dims) - 2
        modules = []
        for idx, (dim_in, dim_out) in enumerate(
                zip(feat_dims[:-1], feat_dims[1:])):
            modules.append(nn.Linear(dim_in, dim_out))
            if idx < last or relu_final:
                modules.append(nn.ReLU(inplace=True))
        return nn.Sequential(*modules)

    def forward(self, x):
        """Forward function.

        Returns:
            Tensor: Soft-argmax coordinates viewed as a column (-1, 1).
        """
        scores = self.fc(x)
        return self.soft_argmax_1d(scores).view(-1, 1)

    def init_weights(self):
        """Initialize model weights."""
        linear_layers = (
            module for module in self.fc.modules()
            if isinstance(module, nn.Linear))
        for layer in linear_layers:
            normal_init(layer, mean=0, std=0.01, bias=0)
class MultilabelClassificationHead(nn.Module):
    """MultilabelClassificationHead is a sub-module of Interhand3DHead, and
    outputs hand type classification.

    Args:
        in_channels (int): Number of input channels
        num_labels (int): Number of labels
        hidden_dims (list|tuple): Number of hidden dimension of FC layers.
    """

    def __init__(self, in_channels=2048, num_labels=2, hidden_dims=(512, )):
        super().__init__()

        self.in_channels = in_channels
        self.num_labels = num_labels
        # The attribute used to exist only under this misspelled name; it
        # is kept as an alias so code reading it keeps working.
        self.num_labesl = num_labels

        feature_dims = [in_channels, *hidden_dims, num_labels]
        self.fc = self._make_linear_layers(feature_dims, relu_final=False)

    def _make_linear_layers(self, feat_dims, relu_final=False):
        """Make linear layers.

        A ReLU follows every linear layer except the last one, which only
        gets one when ``relu_final`` is set.
        """
        layers = []
        for i in range(len(feat_dims) - 1):
            layers.append(nn.Linear(feat_dims[i], feat_dims[i + 1]))
            if i < len(feat_dims) - 2 or \
                    (i == len(feat_dims) - 2 and relu_final):
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function.

        Returns:
            Tensor: Sigmoid of the FC stack's output, one value per label.
        """
        labels = torch.sigmoid(self.fc(x))
        return labels

    def init_weights(self):
        """Initialize every linear layer (normal, std 0.01, bias 0)."""
        for m in self.fc.modules():
            if isinstance(m, nn.Linear):
                normal_init(m, mean=0, std=0.01, bias=0)
@HEADS.register_module()
class Interhand3DHead(nn.Module):
    """Interhand 3D head of paper ref: Gyeongsik Moon. "InterHand2.6M: A
    Dataset and Baseline for 3D Interacting Hand Pose Estimation from a Single
    RGB Image".

    Args:
        keypoint_head_cfg (dict): Configs of Heatmap3DHead for hand
            keypoint estimation.
        root_head_cfg (dict): Configs of Heatmap1DHead for relative
            hand root depth estimation.
        hand_type_head_cfg (dict): Configs of MultilabelClassificationHead
            for hand type classification.
        loss_keypoint (dict): Config for keypoint loss. Default: None.
        loss_root_depth (dict): Config for relative root depth loss.
            Default: None.
        loss_hand_type (dict): Config for hand type classification
            loss. Default: None.
        train_cfg (dict|None): Stored as ``self.train_cfg``; ``None`` is
            replaced by an empty dict.
        test_cfg (dict|None): Stored as ``self.test_cfg``; ``None`` is
            replaced by an empty dict. Keys read here: 'target_type'
            (default 'GaussianHeatmap') and 'shift_heatmap'.
    """

    def __init__(self,
                 keypoint_head_cfg,
                 root_head_cfg,
                 hand_type_head_cfg,
                 loss_keypoint=None,
                 loss_root_depth=None,
                 loss_hand_type=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()

        # build sub-module heads
        # Both hands use the same config but separate weights.
        self.right_hand_head = Heatmap3DHead(**keypoint_head_cfg)
        self.left_hand_head = Heatmap3DHead(**keypoint_head_cfg)
        self.root_head = Heatmap1DHead(**root_head_cfg)
        self.hand_type_head = MultilabelClassificationHead(
            **hand_type_head_cfg)
        self.neck = GlobalAveragePooling()

        # build losses
        self.keypoint_loss = build_loss(loss_keypoint)
        self.root_depth_loss = build_loss(loss_root_depth)
        self.hand_type_loss = build_loss(loss_hand_type)
        self.train_cfg = {} if train_cfg is None else train_cfg
        self.test_cfg = {} if test_cfg is None else test_cfg
        self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')

    def init_weights(self):
        """Initialize the weights of all four sub-heads."""
        self.left_hand_head.init_weights()
        self.right_hand_head.init_weights()
        self.root_head.init_weights()
        self.hand_type_head.init_weights()

    def get_loss(self, output, target, target_weight):
        """Calculate loss for hand keypoint heatmaps, relative root depth and
        hand type.

        Args:
            output (list[Tensor]): a list of outputs from multiple heads.
            target (list[Tensor]): a list of targets for multiple heads.
            target_weight (list[Tensor]): a list of targets weight for
                multiple heads.

        Returns:
            dict: Keys 'hand_loss', 'rel_root_loss' and 'hand_type_loss'.
        """
        losses = dict()

        # hand keypoint loss
        assert not isinstance(self.keypoint_loss, nn.Sequential)
        out, tar, tar_weight = output[0], target[0], target_weight[0]
        assert tar.dim() == 5 and tar_weight.dim() == 3
        losses['hand_loss'] = self.keypoint_loss(out, tar, tar_weight)

        # relative root depth loss
        assert not isinstance(self.root_depth_loss, nn.Sequential)
        out, tar, tar_weight = output[1], target[1], target_weight[1]
        assert tar.dim() == 2 and tar_weight.dim() == 2
        losses['rel_root_loss'] = self.root_depth_loss(out, tar, tar_weight)

        # hand type loss
        assert not isinstance(self.hand_type_loss, nn.Sequential)
        out, tar, tar_weight = output[2], target[2], target_weight[2]
        assert tar.dim() == 2 and tar_weight.dim() in [1, 2]
        losses['hand_type_loss'] = self.hand_type_loss(out, tar, tar_weight)

        return losses

    def get_accuracy(self, output, target, target_weight):
        """Calculate accuracy for hand type.

        Args:
            output (list[Tensor]): a list of outputs from multiple heads.
            target (list[Tensor]): a list of targets for multiple heads.
            target_weight (list[Tensor]): a list of targets weight for
                multiple heads.

        Returns:
            dict: ``{'acc_classification': float}``, computed from the
            third entry (hand type) of each list.
        """
        accuracy = dict()
        avg_acc = multilabel_classification_accuracy(
            output[2].detach().cpu().numpy(),
            target[2].detach().cpu().numpy(),
            target_weight[2].detach().cpu().numpy(),
        )
        accuracy['acc_classification'] = float(avg_acc)
        return accuracy

    def forward(self, x):
        """Forward function.

        Returns:
            list[Tensor]: ``[hand heatmaps, root depth, hand type]``. The
            heatmaps of the right and left hand are concatenated along
            dim 1; the other two heads run on the pooled features.
        """
        outputs = []
        outputs.append(
            torch.cat([self.right_hand_head(x),
                       self.left_hand_head(x)], dim=1))
        x = self.neck(x)
        outputs.append(self.root_head(x))
        outputs.append(self.hand_type_head(x))
        return outputs

    def inference_model(self, x, flip_pairs=None):
        """Inference function.

        Returns:
            output (list[np.ndarray]): list of output hand keypoint
            heatmaps, relative root depth and hand type.

        Args:
            x (torch.Tensor[N,K,H,W]): Input features.
            flip_pairs (None | list[tuple()):
                Pairs of keypoints which are mirrored. When given, all
                three outputs are flipped back.
        """

        output = self.forward(x)

        if flip_pairs is not None:
            # flip 3D heatmap
            heatmap_3d = output[0]
            N, K, D, H, W = heatmap_3d.shape
            # reshape 3D heatmap to 2D heatmap
            heatmap_3d = heatmap_3d.reshape(N, K * D, H, W)
            # 2D heatmap flip
            heatmap_3d_flipped_back = flip_back(
                heatmap_3d.detach().cpu().numpy(),
                flip_pairs,
                target_type=self.target_type)
            # reshape back to 3D heatmap
            heatmap_3d_flipped_back = heatmap_3d_flipped_back.reshape(
                N, K, D, H, W)
            # feature is not aligned, shift flipped heatmap for higher accuracy
            if self.test_cfg.get('shift_heatmap', False):
                heatmap_3d_flipped_back[...,
                                        1:] = heatmap_3d_flipped_back[..., :-1]
            output[0] = heatmap_3d_flipped_back

            # flip relative hand root depth
            output[1] = -output[1].detach().cpu().numpy()

            # flip hand type
            # Swapping the first two columns exchanges the two hand labels.
            hand_type = output[2].detach().cpu().numpy()
            hand_type_flipped_back = hand_type.copy()
            hand_type_flipped_back[:, 0] = hand_type[:, 1]
            hand_type_flipped_back[:, 1] = hand_type[:, 0]
            output[2] = hand_type_flipped_back
        else:
            output = [out.detach().cpu().numpy() for out in output]

        return output

    def decode(self, img_metas, output, **kwargs):
        """Decode hand keypoint, relative root depth and hand type.

        Args:
            img_metas (list(dict)): Information about data augmentation
                By default this includes:

                - "image_file: path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
                - "heatmap3d_depth_bound": depth bound of hand keypoint
                    3D heatmap
                - "root_depth_bound": depth bound of relative root depth
                    1D heatmap
            output (list[np.ndarray]): model predicted 3D heatmaps, relative
                root depth and hand type.

        Returns:
            dict: 'boxes', 'image_paths', 'bbox_ids', 'preds' (N, K, 4:
            x, y, depth, score), 'rel_root_depth' (one row per sample) and
            'hand_type' (boolean, thresholded at 0.5).
        """

        batch_size = len(img_metas)
        result = {}

        heatmap3d_depth_bound = np.ones(batch_size, dtype=np.float32)
        root_depth_bound = np.ones(batch_size, dtype=np.float32)
        center = np.zeros((batch_size, 2), dtype=np.float32)
        scale = np.zeros((batch_size, 2), dtype=np.float32)
        image_paths = []
        score = np.ones(batch_size, dtype=np.float32)
        # bbox ids are only collected when the first sample provides one.
        if 'bbox_id' in img_metas[0]:
            bbox_ids = []
        else:
            bbox_ids = None

        for i in range(batch_size):
            heatmap3d_depth_bound[i] = img_metas[i]['heatmap3d_depth_bound']
            root_depth_bound[i] = img_metas[i]['root_depth_bound']
            center[i, :] = img_metas[i]['center']
            scale[i, :] = img_metas[i]['scale']
            image_paths.append(img_metas[i]['image_file'])

            if 'bbox_score' in img_metas[i]:
                # Extract the single score as a scalar. Assigning a
                # size-1 array with ndim > 0 to one element is deprecated
                # in recent NumPy releases.
                score[i] = np.asarray(img_metas[i]['bbox_score']).item()
            if bbox_ids is not None:
                bbox_ids.append(img_metas[i]['bbox_id'])

        all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
        all_boxes[:, 0:2] = center[:, 0:2]
        all_boxes[:, 2:4] = scale[:, 0:2]
        # scale is defined as: bbox_size / 200.0, so we
        # need multiply 200.0 to get bbox size
        all_boxes[:, 4] = np.prod(scale * 200.0, axis=1)
        all_boxes[:, 5] = score
        result['boxes'] = all_boxes
        result['image_paths'] = image_paths
        result['bbox_ids'] = bbox_ids

        # decode 3D heatmaps of hand keypoints
        heatmap3d = output[0]
        preds, maxvals = keypoints_from_heatmaps3d(heatmap3d, center, scale)
        keypoints_3d = np.zeros((batch_size, preds.shape[1], 4),
                                dtype=np.float32)
        keypoints_3d[:, :, 0:3] = preds[:, :, 0:3]
        keypoints_3d[:, :, 3:4] = maxvals
        # transform keypoint depth to camera space
        keypoints_3d[:, :, 2] = \
            (keypoints_3d[:, :, 2] / self.right_hand_head.depth_size - 0.5) \
            * heatmap3d_depth_bound[:, np.newaxis]

        result['preds'] = keypoints_3d

        # decode relative hand root depth
        # transform relative root depth to camera space
        rel_root = output[1]
        # The root head emits one column per sample, i.e. shape (N, 1).
        # Multiplying that by the flat (N,) bound would broadcast to
        # (N, N), so the bound gets one trailing singleton axis for every
        # extra axis of the prediction and each sample is scaled by its
        # own bound only.
        per_sample_bound = root_depth_bound.reshape(
            (batch_size, ) + (1, ) * (np.ndim(rel_root) - 1))
        result['rel_root_depth'] = (rel_root / self.root_head.heatmap_size -
                                    0.5) * per_sample_bound

        # decode hand type
        result['hand_type'] = output[2] > 0.5
        return result
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/temporal_regression_head.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn
from mmcv.cnn import build_conv_layer, constant_init, kaiming_init
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmpose.core import (WeightNormClipHook, compute_similarity_transform,
fliplr_regression)
from mmpose.models.builder import HEADS, build_loss
@HEADS.register_module()
class TemporalRegressionHead(nn.Module):
"""Regression head of VideoPose3D.
"3D human pose estimation in video with temporal convolutions and
semi-supervised training", CVPR'2019.
Args:
in_channels (int): Number of input channels
num_joints (int): Number of joints
loss_keypoint (dict): Config for keypoint loss. Default: None.
max_norm (float|None): if not None, the weight of convolution layers
will be clipped to have a maximum norm of max_norm.
is_trajectory (bool): If the model only predicts root joint
position, then this arg should be set to True. In this case,
traj_loss will be calculated. Otherwise, it should be set to
False. Default: False.
"""
def __init__(self,
             in_channels,
             num_joints,
             max_norm=None,
             loss_keypoint=None,
             is_trajectory=False,
             train_cfg=None,
             test_cfg=None):
    """Set up the 1x1 Conv1d regressor and its loss.

    Args:
        in_channels (int): Number of input channels.
        num_joints (int): Number of joints; must be 1 when
            ``is_trajectory`` is set.
        max_norm (float|None): When not None, a ``WeightNormClipHook``
            built from it is registered on every conv module of the head.
        loss_keypoint (dict): Config for keypoint loss. Default: None.
        is_trajectory (bool): Whether the head predicts only the root
            joint position. Default: False.
        train_cfg (dict|None): Stored as ``self.train_cfg``; ``None`` is
            replaced by an empty dict.
        test_cfg (dict|None): Stored as ``self.test_cfg``; ``None`` is
            replaced by an empty dict.
    """
    super().__init__()

    self.in_channels = in_channels
    self.num_joints = num_joints
    self.max_norm = max_norm
    self.loss = build_loss(loss_keypoint)
    self.is_trajectory = is_trajectory
    if self.is_trajectory:
        assert self.num_joints == 1

    self.train_cfg = train_cfg if train_cfg is not None else {}
    self.test_cfg = test_cfg if test_cfg is not None else {}

    # Three output channels (x, y, z) per joint.
    self.conv = build_conv_layer(
        dict(type='Conv1d'), in_channels, num_joints * 3, 1)

    if self.max_norm is None:
        return

    # Apply weight norm clip to conv layers
    weight_clip = WeightNormClipHook(self.max_norm)
    for module in self.modules():
        if not isinstance(module, nn.modules.conv._ConvNd):
            continue
        weight_clip.register(module)
@staticmethod
def _transform_inputs(x):
"""Transform inputs for decoder.
Args:
inputs (tuple or list of Tensor | Tensor): multi-level features.
Returns:
Tensor: The transformed inputs
"""
if not isinstance(x, (list, tuple)):
return x
assert len(x) > 0
# return the top-level feature of the 1D feature pyramid
return x[-1]
def forward(self, x):
    """Forward function.

    Args:
        x (Tensor | list | tuple): Input features; after
            ``_transform_inputs`` the feature must be 3-D with a last
            dimension of size 1.

    Returns:
        Tensor: The conv output reshaped to (N, num_joints, 3).

    Raises:
        AssertionError: If the selected feature has the wrong shape.
    """
    feature = self._transform_inputs(x)

    assert feature.ndim == 3 and feature.shape[2] == 1, \
        f'Invalid shape {feature.shape}'
    regressed = self.conv(feature)
    batch = regressed.shape[0]
    return regressed.reshape(batch, self.num_joints, 3)
def get_loss(self, output, target, target_weight):
    """Calculate keypoint loss.

    Note:
        - batch_size: N
        - num_keypoints: K

    Args:
        output (torch.Tensor[N, K, 3]): Output keypoints.
        target (torch.Tensor[N, K, 3]): Target keypoints. For the
            trajectory model a 2-D target is also accepted and gets a
            keypoint axis inserted at dim 1; the caller's tensor is left
            unchanged.
        target_weight (torch.Tensor[N, K, 3]):
            Weights across different joint types.
            If self.is_trajectory is True and target_weight is None,
            target_weight will be set inversely proportional to joint
            depth. For the pose model, ``None`` means all-ones weights.

    Returns:
        dict: ``{'traj_loss': loss}`` for the trajectory model, otherwise
        ``{'reg_loss': loss}``.
    """
    losses = dict()
    assert not isinstance(self.loss, nn.Sequential)

    # trajectory model
    if self.is_trajectory:
        if target.dim() == 2:
            # Out-of-place on purpose: the in-place ``unsqueeze_`` would
            # change the shape of the tensor the caller still holds.
            target = target.unsqueeze(1)

        if target_weight is None:
            # 1 / depth (third coordinate), repeated for all coordinates.
            target_weight = (1 / target[:, :, 2:]).expand(target.shape)
        assert target.dim() == 3 and target_weight.dim() == 3

        losses['traj_loss'] = self.loss(output, target, target_weight)

    # pose model
    else:
        if target_weight is None:
            target_weight = target.new_ones(target.shape)
        assert target.dim() == 3 and target_weight.dim() == 3
        losses['reg_loss'] = self.loss(output, target, target_weight)

    return losses
def get_accuracy(self, output, target, target_weight, metas):
    """Calculate accuracy for keypoint loss.

    Note:
        - batch_size: N
        - num_keypoints: K

    Args:
        output (torch.Tensor[N, K, 3]): Output keypoints.
        target (torch.Tensor[N, K, 3]): Target keypoints.
        target_weight (torch.Tensor[N, K, 3]):
            Weights across different joint types. ``None`` means
            all-ones weights.
        metas (list(dict)): Information about data augmentation. Keys
            read here:

            - target_mean / target_std: Optional; when both are present
              in the first meta, predictions and targets are
              denormalized with them.
            - root_position: Needed when
              ``test_cfg['restore_global_position']`` is set.
            - root_position_index: Optional, taken from the first meta.
            - root_joint_weight: Optional, taken from the first meta
              (default 1.0).

    Returns:
        dict: Tensors 'mpjpe' (mean weighted joint error) and 'p_mpjpe'
        (the same error after ``compute_similarity_transform`` aligned
        each prediction to its target).
    """
    num_samples = output.shape[0]
    pred = output.detach().cpu().numpy()
    gt = target.detach().cpu().numpy()
    first_meta = metas[0]

    # Denormalize the predicted pose
    if 'target_mean' in first_meta and 'target_std' in first_meta:
        mean = np.stack([meta['target_mean'] for meta in metas])
        std = np.stack([meta['target_std'] for meta in metas])
        pred = self._denormalize_joints(pred, mean, std)
        gt = self._denormalize_joints(gt, mean, std)

    # Restore global position
    restore_global = self.test_cfg.get('restore_global_position', False)
    if restore_global:
        root_pos = np.stack([meta['root_position'] for meta in metas])
        root_idx = first_meta.get('root_position_index', None)
        pred = self._restore_global_position(pred, root_pos, root_idx)
        gt = self._restore_global_position(gt, root_pos, root_idx)

    # Get target weight
    if target_weight is None:
        weight = np.ones_like(gt)
    else:
        weight = target_weight.detach().cpu().numpy()
        if restore_global:
            root_idx = first_meta.get('root_position_index', None)
            root_weight = first_meta.get('root_joint_weight', 1.0)
            weight = self._restore_root_target_weight(
                weight, root_weight, root_idx)

    def _mean_weighted_error(estimate):
        # Mean over all joints of the weighted per-joint distance.
        return np.mean(
            np.linalg.norm((estimate - gt) * weight, axis=-1))

    mpjpe = _mean_weighted_error(pred)

    aligned = np.zeros_like(pred)
    for idx in range(num_samples):
        aligned[idx, :, :] = compute_similarity_transform(
            pred[idx, :, :], gt[idx, :, :])
    p_mpjpe = _mean_weighted_error(aligned)

    return {
        'mpjpe': output.new_tensor(mpjpe),
        'p_mpjpe': output.new_tensor(p_mpjpe),
    }
def inference_model(self, x, flip_pairs=None):
    """Run the head on ``x`` and return the regression as a numpy array.

    Args:
        x (torch.Tensor[N, K, 2]): Input features.
        flip_pairs (None | list[tuple]): Pairs of keypoints which are
            mirrored. When given, the prediction is post-processed with
            ``fliplr_regression`` using ``center_mode='static'`` and
            ``center_x=0``.

    Returns:
        np.ndarray: Output regression.
    """
    prediction = self.forward(x).detach().cpu().numpy()

    # Without mirror pairs the raw prediction is returned untouched.
    if flip_pairs is None:
        return prediction

    return fliplr_regression(
        prediction, flip_pairs, center_mode='static', center_x=0)
def decode(self, metas, output):
    """Decode the keypoints from output regression.

    Args:
        metas (list(dict)): Per-sample information. The keys read here are:

            - target_image_path (str): Optional, path to the image file.
            - target_mean: Optional, normalization parameter of the
              target pose. Only used when ``target_std`` is also present
              in the first sample.
            - target_std: Optional, normalization parameter of the
              target pose.
            - root_position: Global position of the root joint; required
              when ``test_cfg['restore_global_position']`` is truthy.
            - root_position_index: Optional, index at which the root
              joint is inserted back (taken from the first sample only).
        output (np.ndarray[N, K, 3]): predicted regression vector.

    Returns:
        dict: ``{'preds': ..., 'target_image_paths': [...]}``.
    """
    first_meta = metas[0]

    # Denormalize the predicted pose
    if 'target_mean' in first_meta and 'target_std' in first_meta:
        means = np.stack([meta['target_mean'] for meta in metas])
        stds = np.stack([meta['target_std'] for meta in metas])
        output = self._denormalize_joints(output, means, stds)

    # Restore global position
    if self.test_cfg.get('restore_global_position', False):
        root_positions = np.stack([meta['root_position'] for meta in metas])
        root_index = first_meta.get('root_position_index', None)
        output = self._restore_global_position(output, root_positions,
                                               root_index)

    image_paths = [meta.get('target_image_path', None) for meta in metas]
    return {'preds': output, 'target_image_paths': image_paths}
@staticmethod
def _denormalize_joints(x, mean, std):
"""Denormalize joint coordinates with given statistics mean and std.
Args:
x (np.ndarray[N, K, 3]): Normalized joint coordinates.
mean (np.ndarray[K, 3]): Mean value.
std (np.ndarray[K, 3]): Std value.
"""
assert x.ndim == 3
assert x.shape == mean.shape == std.shape
return x * std + mean
@staticmethod
def _restore_global_position(x, root_pos, root_idx=None):
"""Restore global position of the root-centered joints.
Args:
x (np.ndarray[N, K, 3]): root-centered joint coordinates
root_pos (np.ndarray[N,1,3]): The global position of the
root joint.
root_idx (int|None): If not none, the root joint will be inserted
back to the pose at the given index.
"""
x = x + root_pos
if root_idx is not None:
x = np.insert(x, root_idx, root_pos.squeeze(1), axis=1)
return x
@staticmethod
def _restore_root_target_weight(target_weight, root_weight, root_idx=None):
"""Restore the target weight of the root joint after the restoration of
the global position.
Args:
target_weight (np.ndarray[N, K, 1]): Target weight of relativized
joints.
root_weight (float): The target weight value of the root joint.
root_idx (int|None): If not none, the root joint weight will be
inserted back to the target weight at the given index.
"""
if root_idx is not None:
root_weight = np.full(
target_weight.shape[0], root_weight, dtype=target_weight.dtype)
target_weight = np.insert(
target_weight, root_idx, root_weight[:, None], axis=1)
return target_weight
def init_weights(self):
    """Initialize the weights.

    Convolution modules go through ``kaiming_init`` (fan-in, ReLU) and
    batch-norm modules through ``constant_init`` with value 1; every other
    module is left untouched.
    """
    for module in self.modules():
        if isinstance(module, nn.modules.conv._ConvNd):
            kaiming_init(module, mode='fan_in', nonlinearity='relu')
            continue
        if isinstance(module, _BatchNorm):
            constant_init(module, 1)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_base_head.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
import numpy as np
import torch.nn as nn
# from mmpose.core.evaluation.top_down_eval import keypoints_from_heatmaps
class TopdownHeatmapBaseHead(nn.Module, metaclass=ABCMeta):
    """Base class for top-down heatmap heads.

    All top-down heatmap heads should subclass it.
    All subclass should overwrite:

    Methods:`get_loss`, supporting to calculate loss.
    Methods:`get_accuracy`, supporting to calculate accuracy.
    Methods:`forward`, supporting to forward model.
    Methods:`inference_model`, supporting to inference model.

    The class is declared with ``metaclass=ABCMeta`` so that the
    ``@abstractmethod`` markers are actually enforced: the Python 2 style
    ``__metaclass__`` class attribute is ignored by Python 3.
    """

    @abstractmethod
    def get_loss(self, **kwargs):
        """Gets the loss."""

    @abstractmethod
    def get_accuracy(self, **kwargs):
        """Gets the accuracy."""

    @abstractmethod
    def forward(self, **kwargs):
        """Forward function."""

    @abstractmethod
    def inference_model(self, **kwargs):
        """Inference function."""

    def decode(self, img_metas, output, **kwargs):
        """Decode keypoints from heatmaps.

        The decoding implementation is disabled in this copy of the file
        (it is kept below as a comment for reference), so this method
        currently always returns ``None``.

        Args:
            img_metas (list(dict)): Information about data augmentation
                By default this includes:

                - "image_file: path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
            output (np.ndarray[N, K, H, W]): model predicted heatmaps.

        Returns:
            None
        """
        # batch_size = len(img_metas)
        # if 'bbox_id' in img_metas[0]:
        #     bbox_ids = []
        # else:
        #     bbox_ids = None
        # c = np.zeros((batch_size, 2), dtype=np.float32)
        # s = np.zeros((batch_size, 2), dtype=np.float32)
        # image_paths = []
        # score = np.ones(batch_size)
        # for i in range(batch_size):
        #     c[i, :] = img_metas[i]['center']
        #     s[i, :] = img_metas[i]['scale']
        #     image_paths.append(img_metas[i]['image_file'])
        #     if 'bbox_score' in img_metas[i]:
        #         score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1)
        #     if bbox_ids is not None:
        #         bbox_ids.append(img_metas[i]['bbox_id'])
        # preds, maxvals = keypoints_from_heatmaps(
        #     output,
        #     c,
        #     s,
        #     unbiased=self.test_cfg.get('unbiased_decoding', False),
        #     post_process=self.test_cfg.get('post_process', 'default'),
        #     kernel=self.test_cfg.get('modulate_kernel', 11),
        #     valid_radius_factor=self.test_cfg.get('valid_radius_factor',
        #                                           0.0546875),
        #     use_udp=self.test_cfg.get('use_udp', False),
        #     target_type=self.test_cfg.get('target_type', 'GaussianHeatmap'))
        # all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32)
        # all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
        # all_preds[:, :, 0:2] = preds[:, :, 0:2]
        # all_preds[:, :, 2:3] = maxvals
        # all_boxes[:, 0:2] = c[:, 0:2]
        # all_boxes[:, 2:4] = s[:, 0:2]
        # all_boxes[:, 4] = np.prod(s * 200.0, axis=1)
        # all_boxes[:, 5] = score
        # result = {}
        # result['preds'] = all_preds
        # result['boxes'] = all_boxes
        # result['image_paths'] = image_paths
        # result['bbox_ids'] = bbox_ids
        return None

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Get configurations for deconv layers.

        Args:
            deconv_kernel (int): Kernel size; one of 4, 3 or 2.

        Returns:
            tuple: ``(deconv_kernel, padding, output_padding)``.

        Raises:
            ValueError: If ``deconv_kernel`` is not 4, 3 or 2.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_multi_stage_head.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp
import torch.nn as nn
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, Linear,
build_activation_layer, build_conv_layer,
build_norm_layer, build_upsample_layer, constant_init,
kaiming_init, normal_init)
from mmpose.core.evaluation import pose_pck_accuracy
from mmpose.core.post_processing import flip_back
from mmpose.models.builder import build_loss
from ..builder import HEADS
from .topdown_heatmap_base_head import TopdownHeatmapBaseHead
@HEADS.register_module()
class TopdownHeatmapMultiStageHead(TopdownHeatmapBaseHead):
    """Top-down heatmap multi-stage head.

    TopdownHeatmapMultiStageHead is consisted of multiple branches,
    each of which has num_deconv_layers(>=0) number of deconv layers
    and a simple conv2d layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_stages (int): Number of stages.
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should >= 0. Note that 0 means
            no deconv layers.
        num_deconv_filters (list|tuple): Number of filters.
            If num_deconv_layers > 0, its length must equal
            num_deconv_layers.
        num_deconv_kernels (list|tuple): Kernel sizes. If
            num_deconv_layers > 0, its length must equal num_deconv_layers.
        extra (dict|None): Optional settings; the key
            ``final_conv_kernel`` (0, 1 or 3) selects the final layer,
            where 0 means an identity mapping.
        loss_keypoint (dict): Config for keypoint loss. Default: None.
        train_cfg (dict|None): Training config. Default: None.
        test_cfg (dict|None): Testing config. Default: None.
    """

    def __init__(self,
                 in_channels=512,
                 out_channels=17,
                 num_stages=1,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 extra=None,
                 loss_keypoint=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()

        self.in_channels = in_channels
        self.num_stages = num_stages
        self.loss = build_loss(loss_keypoint)

        self.train_cfg = {} if train_cfg is None else train_cfg
        self.test_cfg = {} if test_cfg is None else test_cfg
        self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')

        if extra is not None and not isinstance(extra, dict):
            raise TypeError('extra should be dict or None.')

        # build multi-stage deconv layers
        self.multi_deconv_layers = nn.ModuleList([])
        for _ in range(self.num_stages):
            if num_deconv_layers > 0:
                deconv_layers = self._make_deconv_layer(
                    num_deconv_layers,
                    num_deconv_filters,
                    num_deconv_kernels,
                )
            elif num_deconv_layers == 0:
                deconv_layers = nn.Identity()
            else:
                raise ValueError(
                    f'num_deconv_layers ({num_deconv_layers}) should >= 0.')
            self.multi_deconv_layers.append(deconv_layers)

        identity_final_layer = False
        if extra is not None and 'final_conv_kernel' in extra:
            assert extra['final_conv_kernel'] in [0, 1, 3]
            if extra['final_conv_kernel'] == 3:
                padding = 1
            elif extra['final_conv_kernel'] == 1:
                padding = 0
            else:
                # 0 for Identity mapping.
                identity_final_layer = True
            kernel_size = extra['final_conv_kernel']
        else:
            kernel_size = 1
            padding = 0

        # build multi-stage final layers
        self.multi_final_layers = nn.ModuleList([])
        for _ in range(self.num_stages):
            if identity_final_layer:
                final_layer = nn.Identity()
            else:
                final_layer = build_conv_layer(
                    cfg=dict(type='Conv2d'),
                    in_channels=num_deconv_filters[-1]
                    if num_deconv_layers > 0 else in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding)
            self.multi_final_layers.append(final_layer)

    def get_loss(self, output, target, target_weight):
        """Calculate top-down keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K
            - num_outputs: O
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            output (list[torch.Tensor[N,K,H,W]]): Output heatmaps, one
                entry per stage.
            target (torch.Tensor[N,K,H,W]): Target heatmaps, shared by
                every stage.
            target_weight (torch.Tensor[N,K,1]):
                Weights across different joint types.

        Returns:
            dict: ``{'heatmap_loss': ...}`` holding the sum of the
            per-stage losses (empty when ``output`` is empty).
        """
        losses = dict()

        assert isinstance(output, list)
        assert target.dim() == 4 and target_weight.dim() == 3

        if isinstance(self.loss, nn.Sequential):
            assert len(self.loss) == len(output)
        for i in range(len(output)):
            if isinstance(self.loss, nn.Sequential):
                loss_func = self.loss[i]
            else:
                loss_func = self.loss
            loss_i = loss_func(output[i], target, target_weight)
            if 'heatmap_loss' not in losses:
                losses['heatmap_loss'] = loss_i
            else:
                # Out-of-place sum so the first stage's loss tensor is not
                # modified in place.
                losses['heatmap_loss'] = losses['heatmap_loss'] + loss_i

        return losses

    def get_accuracy(self, output, target, target_weight):
        """Calculate accuracy for top-down keypoint loss.

        Only the last stage's heatmaps are evaluated, and only when
        ``target_type`` is ``'GaussianHeatmap'``.

        Note:
            - batch_size: N
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            output (list[torch.Tensor[N,K,H,W]]): Output heatmaps per
                stage; ``output[-1]`` is used.
            target (torch.Tensor[N,K,H,W]): Target heatmaps.
            target_weight (torch.Tensor[N,K,1]):
                Weights across different joint types.

        Returns:
            dict: ``{'acc_pose': float}`` or an empty dict.
        """
        accuracy = dict()

        if self.target_type == 'GaussianHeatmap':
            _, avg_acc, _ = pose_pck_accuracy(
                output[-1].detach().cpu().numpy(),
                target.detach().cpu().numpy(),
                target_weight.detach().cpu().numpy().squeeze(-1) > 0)
            accuracy['acc_pose'] = float(avg_acc)

        return accuracy

    def forward(self, x):
        """Forward function.

        Args:
            x (list[Tensor]): One feature map per stage.

        Returns:
            out (list[Tensor]): a list of heatmaps from multiple stages.
        """
        out = []
        assert isinstance(x, list)
        for i in range(self.num_stages):
            y = self.multi_deconv_layers[i](x[i])
            y = self.multi_final_layers[i](y)
            out.append(y)
        return out

    def inference_model(self, x, flip_pairs=None):
        """Inference function.

        Returns:
            output_heatmap (np.ndarray): Output heatmaps.

        Args:
            x (List[torch.Tensor[NxKxHxW]]): Input features.
            flip_pairs (None | list[tuple]):
                Pairs of keypoints which are mirrored.
        """
        output = self.forward(x)
        assert isinstance(output, list)
        output = output[-1]

        if flip_pairs is not None:
            # perform flip
            output_heatmap = flip_back(
                output.detach().cpu().numpy(),
                flip_pairs,
                target_type=self.target_type)
            # feature is not aligned, shift flipped heatmap for higher accuracy
            if self.test_cfg.get('shift_heatmap', False):
                output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
        else:
            output_heatmap = output.detach().cpu().numpy()

        return output_heatmap

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Make deconv layers for one stage.

        The running channel count is tracked in a local variable rather
        than in ``self.in_channels``: this method is called once per stage,
        and every stage must start from the head's configured input width.

        Raises:
            ValueError: If ``num_filters`` or ``num_kernels`` does not have
                ``num_layers`` entries.
        """
        if num_layers != len(num_filters):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_filters({len(num_filters)})'
            raise ValueError(error_msg)
        if num_layers != len(num_kernels):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_kernels({len(num_kernels)})'
            raise ValueError(error_msg)

        layers = []
        in_planes = self.in_channels
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=in_planes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2d(planes))
            layers.append(nn.ReLU(inplace=True))
            in_planes = planes

        return nn.Sequential(*layers)

    def init_weights(self):
        """Initialize model weights."""
        for _, m in self.multi_deconv_layers.named_modules():
            if isinstance(m, nn.ConvTranspose2d):
                normal_init(m, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
        for m in self.multi_final_layers.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)
class PredictHeatmap(nn.Module):
    """Turn a feature map into heatmaps of a fixed spatial size.

    A 1x1 ``ConvModule`` followed by a 3x3 ``ConvModule`` (no activation)
    produces ``out_channels`` maps, which are bilinearly resized to
    ``out_shape`` and optionally refined by a :class:`PRM`.

    Args:
        unit_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        out_shape (tuple): Shape of the output heatmap.
        use_prm (bool): Whether to use pose refine machine. Default: False.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
    """

    def __init__(self,
                 unit_channels,
                 out_channels,
                 out_shape,
                 use_prm=False,
                 norm_cfg=dict(type='BN')):
        # Work on a private copy so the shared default dict is never touched.
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()

        self.unit_channels = unit_channels
        self.out_channels = out_channels
        self.out_shape = out_shape
        self.use_prm = use_prm

        if use_prm:
            self.prm = PRM(out_channels, norm_cfg=norm_cfg)

        pointwise = ConvModule(
            unit_channels,
            unit_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            norm_cfg=norm_cfg,
            inplace=False)
        projection = ConvModule(
            unit_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=None,
            inplace=False)
        self.conv_layers = nn.Sequential(pointwise, projection)

    def forward(self, feature):
        """Compute heatmaps of size ``out_shape`` from ``feature``."""
        heatmap = nn.functional.interpolate(
            self.conv_layers(feature),
            size=self.out_shape,
            mode='bilinear',
            align_corners=True)
        if not self.use_prm:
            return heatmap
        return self.prm(heatmap)
class PRM(nn.Module):
    """Pose Refine Machine.

    Please refer to "Learning Delicate Local Representations
    for Multi-Person Pose Estimation" (ECCV 2020).

    The input first passes through a 3x3 conv block. Two gates are then
    computed from that result: one from its globally pooled vector
    (``middle_path``) and one from ``bottom_path``; their product
    re-weights the conv output as ``out * (1 + gate_a * gate_b)``.

    Args:
        out_channels (int): Channel number of the output. Equals to
            the number of key points.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
    """

    def __init__(self, out_channels, norm_cfg=dict(type='BN')):
        # Work on a private copy so the shared default dict is never touched.
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()

        self.out_channels = out_channels
        channels = self.out_channels

        self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))

        self.middle_path = nn.Sequential(
            Linear(channels, channels),
            build_norm_layer(dict(type='BN1d'), out_channels)[1],
            build_activation_layer(dict(type='ReLU')),
            Linear(channels, channels),
            build_norm_layer(dict(type='BN1d'), out_channels)[1],
            build_activation_layer(dict(type='ReLU')),
            build_activation_layer(dict(type='Sigmoid')))

        self.bottom_path = nn.Sequential(
            ConvModule(
                channels,
                channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=norm_cfg,
                inplace=False),
            DepthwiseSeparableConvModule(
                channels,
                1,
                kernel_size=9,
                stride=1,
                padding=4,
                norm_cfg=norm_cfg,
                inplace=False),
            build_activation_layer(dict(type='Sigmoid')))

        self.conv_bn_relu_prm_1 = ConvModule(
            channels,
            channels,
            kernel_size=3,
            stride=1,
            padding=1,
            norm_cfg=norm_cfg,
            inplace=False)

    def forward(self, x):
        """Refine ``x`` and return a tensor gated by the two paths."""
        base = self.conv_bn_relu_prm_1(x)

        # Gate from the pooled vector, reshaped to (N, C, 1, 1).
        pooled = self.global_pooling(base)
        pooled = pooled.view(pooled.size(0), -1)
        pooled_gate = self.middle_path(pooled).unsqueeze(2).unsqueeze(3)

        # Gate from the convolutional bottom path.
        conv_gate = self.bottom_path(base)

        return base * (1 + pooled_gate * conv_gate)
@HEADS.register_module()
class TopdownHeatmapMSMUHead(TopdownHeatmapBaseHead):
    """Heads for multi-stage multi-unit heads used in Multi-Stage Pose
    estimation Network (MSPN), and Residual Steps Networks (RSN).

    One :class:`PredictHeatmap` is created for every (stage, unit) pair,
    stored stage-major in ``predict_layers``.

    Args:
        unit_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        out_shape (tuple): Shape of the output heatmap.
        num_stages (int): Number of stages.
        num_units (int): Number of units in each stage.
        use_prm (bool): Whether to use pose refine machine (PRM).
            Default: False.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        loss_keypoint (dict): Config for keypoint loss. Default: None.
        train_cfg (dict|None): Training config. Default: None.
        test_cfg (dict|None): Testing config. Default: None.
    """

    def __init__(self,
                 out_shape,
                 unit_channels=256,
                 out_channels=17,
                 num_stages=4,
                 num_units=4,
                 use_prm=False,
                 norm_cfg=dict(type='BN'),
                 loss_keypoint=None,
                 train_cfg=None,
                 test_cfg=None):
        # Work on a private copy so the shared default dict is never touched.
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()

        self.train_cfg = train_cfg if train_cfg is not None else {}
        self.test_cfg = test_cfg if test_cfg is not None else {}
        self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')

        self.out_shape = out_shape
        self.unit_channels = unit_channels
        self.out_channels = out_channels
        self.num_stages = num_stages
        self.num_units = num_units

        self.loss = build_loss(loss_keypoint)

        # num_stages * num_units predictors, stage-major order.
        self.predict_layers = nn.ModuleList([])
        for _ in range(self.num_stages * self.num_units):
            self.predict_layers.append(
                PredictHeatmap(
                    unit_channels,
                    out_channels,
                    out_shape,
                    use_prm,
                    norm_cfg=norm_cfg))

    def get_loss(self, output, target, target_weight):
        """Calculate top-down keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K
            - num_outputs: O
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            output (list[torch.Tensor[N,K,H,W]]): Output heatmaps, one
                entry per output (``len(output) == O``).
            target (torch.Tensor[N,O,K,H,W]): Target heatmaps.
            target_weight (torch.Tensor[N,O,K,1]):
                Weights across different joint types.

        Returns:
            dict: ``{'heatmap_loss': ...}``, the sum over all outputs.
        """
        losses = dict()

        assert isinstance(output, list)
        assert target.dim() == 5 and target_weight.dim() == 4
        assert target.size(1) == len(output)

        per_output_loss = isinstance(self.loss, nn.Sequential)
        if per_output_loss:
            assert len(self.loss) == len(output)

        for idx, heatmap in enumerate(output):
            criterion = self.loss[idx] if per_output_loss else self.loss
            loss_i = criterion(heatmap, target[:, idx, :, :, :],
                               target_weight[:, idx, :, :])
            if 'heatmap_loss' in losses:
                losses['heatmap_loss'] += loss_i
            else:
                losses['heatmap_loss'] = loss_i

        return losses

    def get_accuracy(self, output, target, target_weight):
        """Calculate accuracy for top-down keypoint loss.

        Only the last output and the last target slice are compared, and
        only when ``target_type`` is ``'GaussianHeatmap'``.

        Note:
            - batch_size: N
            - num_keypoints: K
            - num_outputs: O
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            output (list[torch.Tensor[N,K,H,W]]): Output heatmaps.
            target (torch.Tensor[N,O,K,H,W]): Target heatmaps.
            target_weight (torch.Tensor[N,O,K,1]):
                Weights across different joint types.

        Returns:
            dict: ``{'acc_pose': float}`` or an empty dict.
        """
        accuracy = dict()
        if self.target_type != 'GaussianHeatmap':
            return accuracy

        assert isinstance(output, list)
        assert target.dim() == 5 and target_weight.dim() == 4

        predicted = output[-1].detach().cpu().numpy()
        expected = target[:, -1, ...].detach().cpu().numpy()
        visible = target_weight[:, -1,
                                ...].detach().cpu().numpy().squeeze(-1) > 0
        _, avg_acc, _ = pose_pck_accuracy(predicted, expected, visible)
        accuracy['acc_pose'] = float(avg_acc)

        return accuracy

    def forward(self, x):
        """Forward function.

        Args:
            x (list[list[Tensor]]): ``num_stages`` lists of ``num_units``
                feature maps with ``unit_channels`` channels.

        Returns:
            out (list[Tensor]): a list of heatmaps from multiple stages
                and units.
        """
        assert isinstance(x, list)
        assert len(x) == self.num_stages
        assert isinstance(x[0], list)
        assert len(x[0]) == self.num_units
        assert x[0][0].shape[1] == self.unit_channels

        out = []
        for stage in range(self.num_stages):
            for unit in range(self.num_units):
                predictor = self.predict_layers[stage * self.num_units + unit]
                out.append(predictor(x[stage][unit]))

        return out

    def inference_model(self, x, flip_pairs=None):
        """Inference function.

        Returns:
            output_heatmap (np.ndarray): Output heatmaps.

        Args:
            x (list[torch.Tensor[N,K,H,W]]): Input features.
            flip_pairs (None | list[tuple]):
                Pairs of keypoints which are mirrored.
        """
        output = self.forward(x)
        assert isinstance(output, list)
        last_heatmap = output[-1].detach().cpu().numpy()

        if flip_pairs is None:
            return last_heatmap

        output_heatmap = flip_back(
            last_heatmap, flip_pairs, target_type=self.target_type)
        # feature is not aligned, shift flipped heatmap for higher accuracy
        if self.test_cfg.get('shift_heatmap', False):
            output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
        return output_heatmap

    def init_weights(self):
        """Initialize model weights."""
        for module in self.predict_layers.modules():
            if isinstance(module, nn.Conv2d):
                kaiming_init(module)
            elif isinstance(module, nn.BatchNorm2d):
                constant_init(module, 1)
            elif isinstance(module, nn.Linear):
                normal_init(module, std=0.01)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/topdown_heatmap_simple_head.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
# from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
# constant_init, normal_init)
# from mmpose.core.evaluation import pose_pck_accuracy
# from mmpose.core.post_processing import flip_back
# from mmpose.models.builder import build_loss
# from mmpose.models.utils.ops import resize
# from ..builder import HEADS
import torch.nn.functional as F
from .topdown_heatmap_base_head import TopdownHeatmapBaseHead
def build_conv_layer(cfg, *args, **kwargs) -> nn.Module:
    """Build an ``nn.Conv2d`` from a config dict.

    Args:
        cfg (dict|None): Layer config. ``None`` is treated as
            ``dict(type='Conv2d')``. Otherwise it must contain the key
            ``"type"`` (only ``'Conv2d'`` is recognized); every other entry
            is forwarded to the layer constructor as a keyword argument.
        *args: Positional arguments for ``nn.Conv2d``.
        **kwargs: Keyword arguments for ``nn.Conv2d``.

    Returns:
        nn.Module: The created convolution layer.

    Raises:
        TypeError: If ``cfg`` is neither ``None`` nor a dict.
        KeyError: If ``"type"`` is missing or is not ``'Conv2d'``.
    """
    if cfg is None:
        options = {'type': 'Conv2d'}
    elif not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    elif 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    else:
        # Copy so that popping "type" does not modify the caller's dict.
        options = cfg.copy()

    requested_type = options.pop('type')
    if requested_type != 'Conv2d':
        raise KeyError(f'Unrecognized layer type {requested_type}')

    return nn.Conv2d(*args, **kwargs, **options)
def build_upsample_layer(cfg, *args, **kwargs) -> nn.Module:
    """Build an ``nn.ConvTranspose2d`` upsample layer from a config dict.

    Args:
        cfg (dict): Layer config. Must contain the key ``"type"`` (only
            ``'deconv'`` is recognized); every other entry is forwarded to
            the layer constructor as a keyword argument.
        *args: Positional arguments for ``nn.ConvTranspose2d``.
        **kwargs: Keyword arguments for ``nn.ConvTranspose2d``.

    Returns:
        nn.Module: The created upsample layer.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``"type"`` is missing or is not ``'deconv'``.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')
    # Copy so that popping "type" does not modify the caller's dict.
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type != 'deconv':
        raise KeyError(f'Unrecognized upsample type {layer_type}')

    # 'deconv' is the only supported type, so the layer class is fixed; the
    # former ``nn.Upsample`` special case could never be reached.
    return nn.ConvTranspose2d(*args, **kwargs, **cfg_)
# @HEADS.register_module()
class TopdownHeatmapSimpleHead(TopdownHeatmapBaseHead):
"""Top-down heatmap simple head. paper ref: Bin Xiao et al. ``Simple
Baselines for Human Pose Estimation and Tracking``.
TopdownHeatmapSimpleHead is consisted of (>=0) number of deconv layers
and a simple conv2d layer.
Args:
in_channels (int): Number of input channels
out_channels (int): Number of output channels
num_deconv_layers (int): Number of deconv layers.
num_deconv_layers should >= 0. Note that 0 means
no deconv layers.
num_deconv_filters (list|tuple): Number of filters.
If num_deconv_layers > 0, the length of
num_deconv_kernels (list|tuple): Kernel sizes.
in_index (int|Sequence[int]): Input feature index. Default: 0
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
Default: None.
- 'resize_concat': Multiple feature maps will be resized to the
same size as the first one and then concat together.
Usually used in FCN head of HRNet.
- 'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
- None: Only one select feature map is allowed.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
loss_keypoint (dict): Config for keypoint loss. Default: None.
"""
def __init__(self,
in_channels,
out_channels,
num_deconv_layers=3,
num_deconv_filters=(256, 256, 256),
num_deconv_kernels=(4, 4, 4),
extra=None,
in_index=0,
input_transform=None,
align_corners=False,
loss_keypoint=None,
train_cfg=None,
test_cfg=None,
upsample=0,):
super().__init__()
self.in_channels = in_channels
self.loss = None
self.upsample = upsample
self.train_cfg = {} if train_cfg is None else train_cfg
self.test_cfg = {} if test_cfg is None else test_cfg
self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')
self._init_inputs(in_channels, in_index, input_transform)
self.in_index = in_index
self.align_corners = align_corners
if extra is not None and not isinstance(extra, dict):
raise TypeError('extra should be dict or None.')
if num_deconv_layers > 0:
self.deconv_layers = self._make_deconv_layer(
num_deconv_layers,
num_deconv_filters,
num_deconv_kernels,
)
elif num_deconv_layers == 0:
self.deconv_layers = nn.Identity()
else:
raise ValueError(
f'num_deconv_layers ({num_deconv_layers}) should >= 0.')
identity_final_layer = False
if extra is not None and 'final_conv_kernel' in extra:
assert extra['final_conv_kernel'] in [0, 1, 3]
if extra['final_conv_kernel'] == 3:
padding = 1
elif extra['final_conv_kernel'] == 1:
padding = 0
else:
# 0 for Identity mapping.
identity_final_layer = True
kernel_size = extra['final_conv_kernel']
else:
kernel_size = 1
padding = 0
if identity_final_layer:
self.final_layer = nn.Identity()
else:
conv_channels = num_deconv_filters[
-1] if num_deconv_layers > 0 else self.in_channels
layers = []
if extra is not None:
num_conv_layers = extra.get('num_conv_layers', 0)
num_conv_kernels = extra.get('num_conv_kernels',
[1] * num_conv_layers)
for i in range(num_conv_layers):
layers.append(
build_conv_layer(
dict(type='Conv2d'),
in_channels=conv_channels,
out_channels=conv_channels,
kernel_size=num_conv_kernels[i],
stride=1,
padding=(num_conv_kernels[i] - 1) // 2))
layers.append(
nn.BatchNorm2d(conv_channels)
)
layers.append(nn.ReLU(inplace=True))
layers.append(
build_conv_layer(
cfg=dict(type='Conv2d'),
in_channels=conv_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding))
if len(layers) > 1:
self.final_layer = nn.Sequential(*layers)
else:
self.final_layer = layers[0]
def get_loss(self, output, target, target_weight):
"""Calculate top-down keypoint loss.
Note:
- batch_size: N
- num_keypoints: K
- heatmaps height: H
- heatmaps weight: W
Args:
output (torch.Tensor[N,K,H,W]): Output heatmaps.
target (torch.Tensor[N,K,H,W]): Target heatmaps.
target_weight (torch.Tensor[N,K,1]):
Weights across different joint types.
"""
losses = dict()
assert not isinstance(self.loss, nn.Sequential)
assert target.dim() == 4 and target_weight.dim() == 3
losses['heatmap_loss'] = self.loss(output, target, target_weight)
return losses
def get_accuracy(self, output, target, target_weight):
"""Calculate accuracy for top-down keypoint loss.
Note:
- batch_size: N
- num_keypoints: K
- heatmaps height: H
- heatmaps weight: W
Args:
output (torch.Tensor[N,K,H,W]): Output heatmaps.
target (torch.Tensor[N,K,H,W]): Target heatmaps.
target_weight (torch.Tensor[N,K,1]):
Weights across different joint types.
"""
accuracy = dict()
if self.target_type == 'GaussianHeatmap':
_, avg_acc, _ = pose_pck_accuracy(
output.detach().cpu().numpy(),
target.detach().cpu().numpy(),
target_weight.detach().cpu().numpy().squeeze(-1) > 0)
accuracy['acc_pose'] = float(avg_acc)
return accuracy
def forward(self, x):
"""Forward function."""
x = self._transform_inputs(x)
x = self.deconv_layers(x)
x = self.final_layer(x)
return x
def inference_model(self, x, flip_pairs=None):
"""Inference function.
Returns:
output_heatmap (np.ndarray): Output heatmaps.
Args:
x (torch.Tensor[N,K,H,W]): Input features.
flip_pairs (None | list[tuple]):
Pairs of keypoints which are mirrored.
"""
output = self.forward(x)
if flip_pairs is not None:
output_heatmap = flip_back(
output.detach().cpu().numpy(),
flip_pairs,
target_type=self.target_type)
# feature is not aligned, shift flipped heatmap for higher accuracy
if self.test_cfg.get('shift_heatmap', False):
output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
else:
output_heatmap = output.detach().cpu().numpy()
return output_heatmap
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform is not None, in_channels and in_index must be
list or tuple, with the same length.
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
- 'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
- 'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
- None: Only one select feature map is allowed.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor] | Tensor): multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if not isinstance(inputs, list):
if not isinstance(inputs, list):
if self.upsample > 0:
inputs = resize(
input=F.relu(inputs),
scale_factor=self.upsample,
mode='bilinear',
align_corners=self.align_corners
)
return inputs
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(
input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
"""Make deconv layers."""
if num_layers != len(num_filters):
error_msg = f'num_layers({num_layers}) ' \
f'!= length of num_filters({len(num_filters)})'
raise ValueError(error_msg)
if num_layers != len(num_kernels):
error_msg = f'num_layers({num_layers}) ' \
f'!= length of num_kernels({len(num_kernels)})'
raise ValueError(error_msg)
layers = []
for i in range(num_layers):
kernel, padding, output_padding = \
self._get_deconv_cfg(num_kernels[i])
planes = num_filters[i]
layers.append(
build_upsample_layer(
dict(type='deconv'),
in_channels=self.in_channels,
out_channels=planes,
kernel_size=kernel,
stride=2,
padding=padding,
output_padding=output_padding,
bias=False))
layers.append(nn.BatchNorm2d(planes))
layers.append(nn.ReLU(inplace=True))
self.in_channels = planes
return nn.Sequential(*layers)
def init_weights(self):
    """Initialize the weights of the deconv stack and of the final layer.

    Transposed convolutions in ``self.deconv_layers`` and convolutions in
    ``self.final_layer`` go through ``normal_init`` with std 0.001 (the
    final-layer convs additionally pass ``bias=0``); BatchNorm layers in
    both go through ``constant_init`` with value 1.
    """
    for layer in self.deconv_layers.modules():
        if isinstance(layer, nn.ConvTranspose2d):
            normal_init(layer, std=0.001)
        elif isinstance(layer, nn.BatchNorm2d):
            constant_init(layer, 1)

    for layer in self.final_layer.modules():
        if isinstance(layer, nn.Conv2d):
            normal_init(layer, std=0.001, bias=0)
        elif isinstance(layer, nn.BatchNorm2d):
            constant_init(layer, 1)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/vipnas_heatmap_simple_head.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
# from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
# constant_init, normal_init)
# from mmpose.core.evaluation import pose_pck_accuracy
# from mmpose.core.post_processing import flip_back
# from mmpose.models.builder import build_loss
# from mmpose.models.utils.ops import resize
# from ..builder import HEADS
# from .topdown_heatmap_base_head import TopdownHeatmapBaseHead
# @HEADS.register_module()
class ViPNASHeatmapSimpleHead(TopdownHeatmapBaseHead):
    """ViPNAS heatmap simple head.

    ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search.

    The head consists of (>=0) grouped deconv stages, each followed by
    BatchNorm and ReLU, and a final layer that is either an identity mapping
    or a stack of conv layers ending in a conv with ``out_channels`` outputs.

    NOTE(review): ``TopdownHeatmapBaseHead``, ``build_loss``,
    ``build_conv_layer``, ``build_norm_layer``, ``build_upsample_layer``,
    ``resize``, ``flip_back``, ``pose_pck_accuracy``, ``normal_init`` and
    ``constant_init`` are used below, but the imports that would provide them
    are commented out at the top of this file - confirm they are in scope.

    Args:
        in_channels (int|Sequence[int]): Number of input channels.
        out_channels (int): Number of output channels.
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should >= 0. Note that 0 means
            no deconv layers.
        num_deconv_filters (list|tuple): Number of filters per deconv layer.
            If num_deconv_layers > 0, its length must equal
            num_deconv_layers.
        num_deconv_kernels (list|tuple): Kernel sizes, one per deconv layer.
        num_deconv_groups (list|tuple): Group number, one per deconv layer.
        extra (dict|None): Optional settings for the final layer. Keys read
            here: 'final_conv_kernel' (0, 1 or 3; 0 means identity),
            'num_conv_layers' and 'num_conv_kernels'.
        in_index (int|Sequence[int]): Input feature index. Default: 0
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            Default: None.

            - 'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated.
            - 'multiple_select': Multiple feature maps will be bundled into
                a list and passed into the decode head.
            - None: Only one selected feature map is allowed.
        align_corners (bool): align_corners argument passed to ``resize``.
            Default: False.
        loss_keypoint (dict): Config for keypoint loss. Default: None.
        train_cfg (dict|None): Stored as ``self.train_cfg`` ({} when None).
        test_cfg (dict|None): Stored as ``self.test_cfg`` ({} when None);
            'target_type' and 'shift_heatmap' are read from it.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_deconv_layers=3,
                 num_deconv_filters=(144, 144, 144),
                 num_deconv_kernels=(4, 4, 4),
                 num_deconv_groups=(16, 16, 16),
                 extra=None,
                 in_index=0,
                 input_transform=None,
                 align_corners=False,
                 loss_keypoint=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()

        self.in_channels = in_channels
        self.loss = build_loss(loss_keypoint)

        self.train_cfg = train_cfg if train_cfg is not None else {}
        self.test_cfg = test_cfg if test_cfg is not None else {}
        self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')

        self._init_inputs(in_channels, in_index, input_transform)
        self.in_index = in_index
        self.align_corners = align_corners

        if not (extra is None or isinstance(extra, dict)):
            raise TypeError('extra should be dict or None.')

        # --- deconv stack -------------------------------------------------
        if num_deconv_layers > 0:
            self.deconv_layers = self._make_deconv_layer(
                num_deconv_layers, num_deconv_filters, num_deconv_kernels,
                num_deconv_groups)
        elif num_deconv_layers == 0:
            self.deconv_layers = nn.Identity()
        else:
            raise ValueError(
                f'num_deconv_layers ({num_deconv_layers}) should >= 0.')

        # --- final layer --------------------------------------------------
        # Kernel 3 -> padding 1, kernel 1 -> padding 0; any other explicitly
        # configured value (0 is the only other one the assert allows) means
        # an identity mapping.
        padding_for_kernel = {3: 1, 1: 0}
        kernel_size = 1
        if extra is not None and 'final_conv_kernel' in extra:
            assert extra['final_conv_kernel'] in [0, 1, 3]
            kernel_size = extra['final_conv_kernel']

        if kernel_size not in padding_for_kernel:
            self.final_layer = nn.Identity()
            return

        padding = padding_for_kernel[kernel_size]
        conv_channels = (
            num_deconv_filters[-1] if num_deconv_layers > 0
            else self.in_channels)

        layers = []
        if extra is not None:
            num_conv_layers = extra.get('num_conv_layers', 0)
            num_conv_kernels = extra.get('num_conv_kernels',
                                         [1] * num_conv_layers)
            # Optional channel-preserving conv -> BN -> ReLU blocks placed
            # in front of the output convolution.
            for idx in range(num_conv_layers):
                layers += [
                    build_conv_layer(
                        dict(type='Conv2d'),
                        in_channels=conv_channels,
                        out_channels=conv_channels,
                        kernel_size=num_conv_kernels[idx],
                        stride=1,
                        padding=(num_conv_kernels[idx] - 1) // 2),
                    build_norm_layer(dict(type='BN'), conv_channels)[1],
                    nn.ReLU(inplace=True),
                ]

        layers.append(
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=conv_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding))

        # A lone output conv is stored directly rather than wrapped.
        self.final_layer = (
            nn.Sequential(*layers) if len(layers) > 1 else layers[0])

    def get_loss(self, output, target, target_weight):
        """Calculate top-down keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (torch.Tensor[N,K,H,W]): Output heatmaps.
            target (torch.Tensor[N,K,H,W]): Target heatmaps.
            target_weight (torch.Tensor[N,K,1]):
                Weights across different joint types.

        Returns:
            dict: ``{'heatmap_loss': ...}``.
        """
        assert not isinstance(self.loss, nn.Sequential)
        assert target.dim() == 4 and target_weight.dim() == 3
        return {'heatmap_loss': self.loss(output, target, target_weight)}

    def get_accuracy(self, output, target, target_weight):
        """Calculate accuracy for top-down keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (torch.Tensor[N,K,H,W]): Output heatmaps.
            target (torch.Tensor[N,K,H,W]): Target heatmaps.
            target_weight (torch.Tensor[N,K,1]):
                Weights across different joint types.

        Returns:
            dict: ``{'acc_pose': float}`` when the target type is
            'GaussianHeatmap' (case-insensitive); otherwise an empty dict.
        """
        accuracy = {}
        if self.target_type.lower() != 'GaussianHeatmap'.lower():
            return accuracy

        predicted = output.detach().cpu().numpy()
        expected = target.detach().cpu().numpy()
        # Keypoints with a positive weight are the ones that are scored.
        scored = target_weight.detach().cpu().numpy().squeeze(-1) > 0
        _, avg_acc, _ = pose_pck_accuracy(predicted, expected, scored)
        accuracy['acc_pose'] = float(avg_acc)
        return accuracy

    def forward(self, x):
        """Forward function: input transform -> deconv stack -> final layer."""
        features = self._transform_inputs(x)
        return self.final_layer(self.deconv_layers(features))

    def inference_model(self, x, flip_pairs=None):
        """Inference function.

        Args:
            x (torch.Tensor[N,K,H,W]): Input features.
            flip_pairs (None | list[tuple]):
                Pairs of keypoints which are mirrored.

        Returns:
            output_heatmap (np.ndarray): Output heatmaps.
        """
        output = self.forward(x)
        if flip_pairs is None:
            return output.detach().cpu().numpy()

        output_heatmap = flip_back(
            output.detach().cpu().numpy(),
            flip_pairs,
            target_type=self.target_type)
        # feature is not aligned, shift flipped heatmap for higher accuracy
        if self.test_cfg.get('shift_heatmap', False):
            output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
        return output_heatmap

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.

                - 'resize_concat': Multiple feature maps will be resized to
                    the same size as the first one and then concatenated;
                    ``self.in_channels`` becomes ``sum(in_channels)``.
                - 'multiple_select': Multiple feature maps will be bundled
                    into a list and passed into the decode head.
                - None: Only one selected feature map is allowed.
        """
        multi_input = input_transform is not None
        if multi_input:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index

        if not multi_input:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels
            return

        assert isinstance(in_channels, (list, tuple))
        assert isinstance(in_index, (list, tuple))
        assert len(in_channels) == len(in_index)
        self.in_channels = (
            sum(in_channels) if input_transform == 'resize_concat'
            else in_channels)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor] | Tensor): multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs. A non-list input
            is returned unchanged.
        """
        if not isinstance(inputs, list):
            return inputs

        if self.input_transform == 'resize_concat':
            selected = [inputs[i] for i in self.in_index]
            resized = [
                resize(
                    input=feat,
                    size=selected[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for feat in selected
            ]
            return torch.cat(resized, dim=1)

        if self.input_transform == 'multiple_select':
            return [inputs[i] for i in self.in_index]

        return inputs[self.in_index]

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels,
                           num_groups):
        """Make grouped deconv layers.

        Each of the three sequences must have exactly ``num_layers`` entries,
        otherwise ``ValueError`` is raised. ``self.in_channels`` is updated
        to the output channels of each stage as it is built.
        ``self._get_deconv_cfg`` is not defined in this class (presumably
        inherited from the base head - confirm).
        """
        for label, values in (('num_filters', num_filters),
                              ('num_kernels', num_kernels),
                              ('num_groups', num_groups)):
            if num_layers != len(values):
                raise ValueError(
                    f'num_layers({num_layers}) '
                    f'!= length of {label}({len(values)})')

        stages = []
        for stage in range(num_layers):
            kernel, padding, output_padding = self._get_deconv_cfg(
                num_kernels[stage])
            planes = num_filters[stage]
            groups = num_groups[stage]
            stages.extend([
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=self.in_channels,
                    out_channels=planes,
                    kernel_size=kernel,
                    groups=groups,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False),
                nn.BatchNorm2d(planes),
                nn.ReLU(inplace=True),
            ])
            self.in_channels = planes

        return nn.Sequential(*stages)

    def init_weights(self):
        """Initialize model weights."""
        for layer in self.deconv_layers.modules():
            if isinstance(layer, nn.ConvTranspose2d):
                normal_init(layer, std=0.001)
            elif isinstance(layer, nn.BatchNorm2d):
                constant_init(layer, 1)
        for layer in self.final_layer.modules():
            if isinstance(layer, nn.Conv2d):
                normal_init(layer, std=0.001, bias=0)
            elif isinstance(layer, nn.BatchNorm2d):
                constant_init(layer, 1)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/heads/voxelpose_head.py
================================================
# ------------------------------------------------------------------------------
# Copyright and License Information
# https://github.com/microsoft/voxelpose-pytorch/blob/main/lib/models
# Original Licence: MIT License
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import HEADS
@HEADS.register_module()
class CuboidCenterHead(nn.Module):
    """Get results from the 3D human center heatmap. In this module, human 3D
    centers are local maximums obtained from the 3D heatmap via NMS (max-
    pooling).

    Args:
        space_size (list[3]): The size of the 3D space.
        space_center (list[3]): The coordinate of space center.
        cube_size (list[3]): The size of the heatmap volume.
        max_num (int): Maximum of human center detections.
        max_pool_kernel (int): Kernel size of the max-pool kernel in nms.
    """

    def __init__(self,
                 space_size,
                 space_center,
                 cube_size,
                 max_num=10,
                 max_pool_kernel=3):
        super().__init__()
        # Registered as buffers, so they are tracked as part of the
        # module's state rather than kept as plain attributes.
        self.register_buffer('grid_size', torch.tensor(space_size))
        self.register_buffer('cube_size', torch.tensor(cube_size))
        self.register_buffer('grid_center', torch.tensor(space_center))
        self.num_candidates = max_num
        self.max_pool_kernel = max_pool_kernel
        self.loss = nn.MSELoss()

    def _get_real_locations(self, indices):
        """
        Args:
            indices (torch.Tensor(NXP)): Indices of points in the 3D tensor

        Returns:
            real_locations (torch.Tensor(NXPx3)): Locations of points
                in the world coordinate system
        """
        # Fraction of the way along each axis, scaled to the space size and
        # shifted so the space is centred on grid_center.
        fraction = indices.float() / (self.cube_size - 1)
        return fraction * self.grid_size + \
            self.grid_center - self.grid_size / 2.0

    def _nms_by_max_pool(self, heatmap_volumes):
        """Return the top ``num_candidates`` local-maximum values per sample
        and their 3D indices in the volume."""
        suppressed = self._max_pool(heatmap_volumes)
        flattened = suppressed.reshape(heatmap_volumes.shape[0], -1)
        topk_values, topk_index = flattened.topk(self.num_candidates)
        topk_unravel_index = self._get_3d_indices(topk_index,
                                                  heatmap_volumes[0].shape)
        return topk_values, topk_unravel_index

    def _max_pool(self, inputs):
        """Zero every voxel that is not the maximum of its pooling window."""
        kernel = self.max_pool_kernel
        pooled = F.max_pool3d(
            inputs, kernel_size=kernel, stride=1, padding=(kernel - 1) // 2)
        is_peak = (inputs == pooled).float()
        return is_peak * inputs

    @staticmethod
    def _get_3d_indices(indices, shape):
        """Get indices in the 3-D tensor.

        Args:
            indices (torch.Tensor(NXp)): Indices of points in the 1D tensor
            shape (torch.Size(3)): The shape of the original 3D tensor

        Returns:
            indices: Indices of points in the original 3D tensor
        """
        batch_size, num_people = indices.shape[0], indices.shape[1]
        plane = shape[1] * shape[2]

        def as_column(values):
            return values.reshape(batch_size, num_people, -1)

        return torch.cat([
            as_column(indices // plane),
            as_column((indices % plane) // shape[2]),
            as_column(indices % shape[2]),
        ], dim=2)

    def forward(self, heatmap_volumes):
        """
        Args:
            heatmap_volumes (torch.Tensor(NXLXWXH)):
                3D human center heatmaps predicted by the network.

        Returns:
            human_centers (torch.Tensor(NXPX5)):
                Coordinates of human centers. Columns 0-2 hold the world
                location and column 4 the heatmap value; column 3 is left
                at zero here.
        """
        topk_values, topk_cells = self._nms_by_max_pool(
            heatmap_volumes.detach())
        locations = self._get_real_locations(topk_cells)

        human_centers = torch.zeros(
            heatmap_volumes.shape[0],
            self.num_candidates,
            5,
            device=heatmap_volumes.device)
        human_centers[:, :, 0:3] = locations
        human_centers[:, :, 4] = topk_values
        return human_centers

    def get_loss(self, pred_cubes, gt):
        """Return the MSE between predicted and ground-truth volumes under
        the key 'loss_center'."""
        return {'loss_center': self.loss(pred_cubes, gt)}
@HEADS.register_module()
class CuboidPoseHead(nn.Module):
    """Regress 3D keypoint coordinates from 3D pose heatmaps by taking the
    softmax-weighted sum of the grid coordinates (integral pose regression),
    instead of taking the heatmap maximum."""

    def __init__(self, beta):
        """
        Args:
            beta: Constant to adjust the magnification of soft-maxed heatmap.
        """
        super().__init__()
        self.beta = beta
        self.loss = nn.L1Loss()

    def forward(self, heatmap_volumes, grid_coordinates):
        """
        Args:
            heatmap_volumes (torch.Tensor(NxKxLxWxH)):
                3D human pose heatmaps predicted by the network.
            grid_coordinates (torch.Tensor(Nx(LxWxH)x3)):
                Coordinates of the grids in the heatmap volumes.

        Returns:
            human_poses (torch.Tensor(NxKx3)): Coordinates of human poses.
        """
        batch_size, channel = heatmap_volumes.size(0), heatmap_volumes.size(1)
        # Flatten the volume of every keypoint and normalise it into weights.
        flat = heatmap_volumes.reshape(batch_size, channel, -1, 1)
        weights = F.softmax(self.beta * flat, dim=2)
        # Broadcast the grid over the keypoint axis and take the expectation.
        weighted = torch.mul(weights, grid_coordinates.unsqueeze(1))
        return torch.sum(weighted, dim=2)

    def get_loss(self, preds, targets, weights):
        """Return the weighted L1 loss under the key 'loss_pose'."""
        return {'loss_pose': self.loss(preds * weights, targets * weights)}
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/builder/model_builder.py
================================================
import torch
# from configs.coco.ViTPose_base_coco_256x192 import model
from .heads.topdown_heatmap_simple_head import TopdownHeatmapSimpleHead
# import TopdownHeatmapSimpleHead
from .backbones import ViT
# print(model)
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
from importlib import import_module
def build_model(model_name, checkpoint=None):
    """Build a ViTPose model (ViT backbone + heatmap head) from a config module.

    The config is the ``model`` attribute of the module
    ``.configs.coco.<model_name>`` resolved relative to the package
    ``src.vitpose_infer``.

    Args:
        model_name (str): Name of the config module under ``configs.coco``.
        checkpoint (str | None): Optional path to a checkpoint whose
            ``"state_dict"`` entry is loaded into the model.

    Returns:
        nn.Module: A module with ``backbone`` and ``keypoint_head``
        sub-modules whose forward runs them in sequence.

    Raises:
        ValueError: If the config module cannot be imported or has no
            ``model`` attribute. The underlying error is chained as the cause.
    """
    try:
        mod = import_module(".configs.coco." + model_name,
                            package="src.vitpose_infer")
        model = getattr(mod, "model")
    except Exception as exc:
        # Keep ValueError for callers but preserve the real reason, and do not
        # swallow KeyboardInterrupt / SystemExit as a bare `except:` would.
        raise ValueError("not a correct config") from exc

    head_cfg = model["keypoint_head"]
    head = TopdownHeatmapSimpleHead(
        in_channels=head_cfg["in_channels"],
        out_channels=head_cfg["out_channels"],
        num_deconv_filters=head_cfg["num_deconv_filters"],
        num_deconv_kernels=head_cfg["num_deconv_kernels"],
        num_deconv_layers=head_cfg["num_deconv_layers"],
        extra=head_cfg["extra"],
    )

    backbone_cfg = model["backbone"]
    backbone = ViT(
        img_size=backbone_cfg["img_size"],
        patch_size=backbone_cfg["patch_size"],
        embed_dim=backbone_cfg["embed_dim"],
        depth=backbone_cfg["depth"],
        num_heads=backbone_cfg["num_heads"],
        ratio=backbone_cfg["ratio"],
        mlp_ratio=backbone_cfg["mlp_ratio"],
        qkv_bias=backbone_cfg["qkv_bias"],
        drop_path_rate=backbone_cfg["drop_path_rate"],
    )

    class VitPoseModel(nn.Module):
        """Backbone followed by the keypoint head."""

        def __init__(self, backbone, keypoint_head):
            super(VitPoseModel, self).__init__()
            self.backbone = backbone
            self.keypoint_head = keypoint_head

        def forward(self, x):
            x = self.backbone(x)
            x = self.keypoint_head(x)
            return x

    pose = VitPoseModel(backbone, head)
    if checkpoint is not None:
        # map_location="cpu" lets a checkpoint saved from GPU tensors load on
        # a CPU-only host; load_state_dict then copies into the model's own
        # parameters wherever they live.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        check = torch.load(checkpoint, map_location="cpu")
        pose.load_state_dict(check["state_dict"])
    return pose
# pose = build_model('ViTPose_base_coco_256x192','./models/vitpose-b-multi-coco.pth')
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/model_builder.py
================================================
import torch
# from configs.coco.ViTPose_base_coco_256x192 import model
from .builder.heads.topdown_heatmap_simple_head import TopdownHeatmapSimpleHead
# import TopdownHeatmapSimpleHead
from .builder.backbones import ViT
# print(model)
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
from importlib import import_module
# Built-in model configurations, keyed by config name. Each entry describes a
# ViT backbone and a TopdownHeatmapSimpleHead keypoint head, plus train/test
# settings.
models = {
    # ViT-Huge backbone, two deconv stages in the head.
    "ViTPose_huge_coco_256x192": {
        "type": "TopDown",
        "pretrained": None,
        "backbone": {
            "type": "ViT",
            "img_size": (256, 192),
            "patch_size": 16,
            "embed_dim": 1280,
            "depth": 32,
            "num_heads": 16,
            "ratio": 1,
            "use_checkpoint": False,
            "mlp_ratio": 4,
            "qkv_bias": True,
            "drop_path_rate": 0.55,
        },
        "keypoint_head": {
            "type": "TopdownHeatmapSimpleHead",
            "in_channels": 1280,
            "num_deconv_layers": 2,
            "num_deconv_filters": (256, 256),
            "num_deconv_kernels": (4, 4),
            "extra": {
                "final_conv_kernel": 1,
            },
            "out_channels": 17,
            "loss_keypoint": {"type": "JointsMSELoss", "use_target_weight": True},
        },
        "train_cfg": {},
        "test_cfg": {},
    },
    # ViT-Base backbone, two deconv stages in the head.
    "ViTPose_base_coco_256x192": {
        "type": "TopDown",
        "pretrained": None,
        "backbone": {
            "type": "ViT",
            "img_size": (256, 192),
            "patch_size": 16,
            "embed_dim": 768,
            "depth": 12,
            "num_heads": 12,
            "ratio": 1,
            "use_checkpoint": False,
            "mlp_ratio": 4,
            "qkv_bias": True,
            "drop_path_rate": 0.3,
        },
        "keypoint_head": {
            "type": "TopdownHeatmapSimpleHead",
            "in_channels": 768,
            "num_deconv_layers": 2,
            "num_deconv_filters": (256, 256),
            "num_deconv_kernels": (4, 4),
            "extra": {
                "final_conv_kernel": 1,
            },
            "out_channels": 17,
            "loss_keypoint": {"type": "JointsMSELoss", "use_target_weight": True},
        },
        "train_cfg": {},
        "test_cfg": {},
    },
    # ViT-Base backbone with the "simple" head: no deconv stages, an
    # `upsample` factor and a 3x3 final conv.
    "ViTPose_base_simple_coco_256x192": {
        "type": "TopDown",
        "pretrained": None,
        "backbone": {
            "type": "ViT",
            "img_size": (256, 192),
            "patch_size": 16,
            "embed_dim": 768,
            "depth": 12,
            "num_heads": 12,
            "ratio": 1,
            "use_checkpoint": False,
            "mlp_ratio": 4,
            "qkv_bias": True,
            "drop_path_rate": 0.3,
        },
        "keypoint_head": {
            "type": "TopdownHeatmapSimpleHead",
            "in_channels": 768,
            "num_deconv_layers": 0,
            "num_deconv_filters": [],
            "num_deconv_kernels": [],
            "upsample": 4,
            "extra": {
                "final_conv_kernel": 3,
            },
            "out_channels": 17,
            "loss_keypoint": {"type": "JointsMSELoss", "use_target_weight": True},
        },
        "train_cfg": {},
        "test_cfg": {
            "flip_test": True,
            "post_process": "default",
            "shift_heatmap": False,
            "target_type": "GaussianHeatmap",
            "modulate_kernel": 11,
            "use_udp": True,
        },
    },
}
def build_model(model_name, checkpoint=None):
    """Build a ViTPose model (ViT backbone + heatmap head) from ``models``.

    Args:
        model_name (str): Key into the module-level ``models`` table.
        checkpoint (str | None): Optional path to a checkpoint whose
            ``"state_dict"`` entry is loaded into the model.

    Returns:
        nn.Module: A module with ``backbone`` and ``keypoint_head``
        sub-modules whose forward runs them in sequence.

    Raises:
        ValueError: If ``model_name`` is not a key of ``models``.
    """
    try:
        model = models[model_name]
    except (KeyError, TypeError) as exc:
        # KeyError: unknown name; TypeError: unhashable name. Anything else
        # is a genuine bug and should not be masked as a config error.
        raise ValueError("not a correct config") from exc

    head_cfg = model["keypoint_head"]
    # NOTE(review): only these six settings are forwarded. The "upsample"
    # entry of "ViTPose_base_simple_coco_256x192" is not passed to the head;
    # confirm whether TopdownHeatmapSimpleHead accepts it and whether it
    # should be forwarded.
    head = TopdownHeatmapSimpleHead(
        in_channels=head_cfg["in_channels"],
        out_channels=head_cfg["out_channels"],
        num_deconv_filters=head_cfg["num_deconv_filters"],
        num_deconv_kernels=head_cfg["num_deconv_kernels"],
        num_deconv_layers=head_cfg["num_deconv_layers"],
        extra=head_cfg["extra"],
    )

    backbone_cfg = model["backbone"]
    backbone = ViT(
        img_size=backbone_cfg["img_size"],
        patch_size=backbone_cfg["patch_size"],
        embed_dim=backbone_cfg["embed_dim"],
        depth=backbone_cfg["depth"],
        num_heads=backbone_cfg["num_heads"],
        ratio=backbone_cfg["ratio"],
        mlp_ratio=backbone_cfg["mlp_ratio"],
        qkv_bias=backbone_cfg["qkv_bias"],
        drop_path_rate=backbone_cfg["drop_path_rate"],
    )

    class VitPoseModel(nn.Module):
        """Backbone followed by the keypoint head."""

        def __init__(self, backbone, keypoint_head):
            super(VitPoseModel, self).__init__()
            self.backbone = backbone
            self.keypoint_head = keypoint_head

        def forward(self, x):
            x = self.backbone(x)
            x = self.keypoint_head(x)
            return x

    pose = VitPoseModel(backbone, head)
    if checkpoint is not None:
        # map_location="cpu" lets a checkpoint saved from GPU tensors load on
        # a CPU-only host; load_state_dict then copies into the model's own
        # parameters wherever they live.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        check = torch.load(checkpoint, map_location="cpu")
        pose.load_state_dict(check["state_dict"])
    return pose
# pose = build_model('ViTPose_base_coco_256x192','./models/vitpose-b-multi-coco.pth')
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/ViTPose_trt.py
================================================
import tensorrt as trt
import torch.nn
from collections import OrderedDict, namedtuple
import numpy as np
def torch_device_from_trt(device):
    """Map a TensorRT tensor location to the matching ``torch.device``.

    Args:
        device: A ``trt.TensorLocation`` value.

    Returns:
        torch.device: ``cuda`` for ``TensorLocation.DEVICE``, ``cpu`` for
        ``TensorLocation.HOST``.

    Raises:
        TypeError: If the location is neither DEVICE nor HOST.
    """
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    if device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    # The original *returned* the exception object, which callers would then
    # have passed on as if it were a device; it must be raised.
    raise TypeError("%s is not supported by torch" % device)
def torch_dtype_from_trt(dtype):
    """Map a TensorRT data type to the matching ``torch.dtype``.

    Args:
        dtype: A TensorRT dtype (``trt.int8``, ``trt.bool``, ``trt.int32``,
            ``trt.float16`` or ``trt.float32``).

    Returns:
        torch.dtype: The corresponding torch dtype.

    Raises:
        TypeError: If the dtype has no mapping here.
    """
    if dtype == trt.int8:
        return torch.int8

    # ``trt.bool`` is only consulted on TensorRT >= 7. Compare the major
    # version numerically: as strings, '10.0' >= '7.0' is False, which made
    # bool tensors unsupported on TensorRT 10 and later.
    major_text = trt.__version__.split('.')[0]
    if major_text.isdigit():
        has_bool = int(major_text) >= 7
    else:
        # Unparseable version string: keep the original comparison.
        has_bool = trt.__version__ >= '7.0'

    if has_bool and dtype == trt.bool:
        return torch.bool
    if dtype == trt.int32:
        return torch.int32
    if dtype == trt.float16:
        return torch.float16
    if dtype == trt.float32:
        return torch.float32
    raise TypeError("%s is not supported by torch" % dtype)
class TRTModule_ViTPose(torch.nn.Module):
    """torch.nn.Module wrapper that runs a serialized TensorRT engine.

    The engine is deserialized from ``path``. The single input binding is
    expected to be named ``'images'``; every non-input binding is treated as
    an output.

    Args:
        engine: Unused; the engine is always read from ``path``.
        input_names: Unused; the input names are fixed to ``['images']``.
        output_names: Unused; output names are discovered from the engine.
        input_flattener: Optional object whose ``flatten(inputs)`` is applied
            to the forward arguments.
        output_flattener: Optional object whose ``unflatten(outputs)`` is
            applied to the list of output tensors.
        path (str): Path of the serialized engine file. Required.
        device: Device passed to ``Tensor.to`` for the placeholder tensors
            created per binding at construction time.

    Raises:
        TypeError: If ``path`` is None.
        RuntimeError: If the engine cannot be deserialized.
    """

    def __init__(self, engine=None, input_names=None, output_names=None, input_flattener=None, output_flattener=None, path=None, device=None):
        super(TRTModule_ViTPose, self).__init__()
        if path is None:
            # open(None) would raise an opaque TypeError; say what is missing.
            raise TypeError("path to a serialized TensorRT engine is required")

        logger = trt.Logger(trt.Logger.INFO)
        with open(path, 'rb') as f, trt.Runtime(logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        if self.engine is None:
            # Previously this fell through and failed later with an
            # AttributeError on None.
            raise RuntimeError("failed to deserialize TensorRT engine from %r" % (path,))
        self.context = self.engine.create_execution_context()

        self.input_names = ['images']
        self.output_names = []
        self.input_flattener = input_flattener
        self.output_flattener = output_flattener

        Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
        self.bindings = OrderedDict()
        for i in range(self.engine.num_bindings):
            name = self.engine.get_binding_name(i)
            dtype = trt.nptype(self.engine.get_binding_dtype(i))
            if self.engine.binding_is_input(i):
                if -1 in tuple(self.engine.get_binding_shape(i)):
                    # Dynamic input: size the context with index 2 of
                    # profile 0's shape triple for this binding.
                    self.context.set_binding_shape(i, tuple(self.engine.get_profile_shape(0, i)[2]))
            else:
                self.output_names.append(name)
            shape = tuple(self.context.get_binding_shape(i))
            im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
            self.bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
        self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())
        self.batch_size = self.bindings['images'].shape[0]

    def forward(self, *inputs):
        """Run the engine on ``inputs`` and return its output tensor(s).

        Returns:
            A single tensor when the engine has one output, otherwise a tuple
            of tensors; or whatever ``output_flattener.unflatten`` returns
            when an output flattener is set.
        """
        bindings = [None] * (len(self.input_names) + len(self.output_names))

        if self.input_flattener is not None:
            inputs = self.input_flattener.flatten(inputs)

        # Keep the contiguous copies referenced until the call is enqueued.
        # Taking data_ptr() of a temporary `.contiguous()` result would let
        # its memory be released (and possibly reused by the output
        # allocations below) before the engine reads it.
        live_inputs = []
        for i, input_name in enumerate(self.input_names):
            idx = self.engine.get_binding_index(input_name)
            tensor = inputs[i].contiguous()
            live_inputs.append(tensor)
            bindings[idx] = tensor.data_ptr()
            self.context.set_binding_shape(idx, tuple(tensor.shape))

        # create output tensors
        outputs = [None] * len(self.output_names)
        for i, output_name in enumerate(self.output_names):
            idx = self.engine.get_binding_index(output_name)
            dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
            shape = tuple(self.context.get_binding_shape(idx))
            device = torch_device_from_trt(self.engine.get_location(idx))
            output = torch.empty(size=shape, dtype=dtype, device=device)
            outputs[i] = output
            bindings[idx] = output.data_ptr()

        # Enqueued on the current CUDA stream.
        self.context.execute_async_v2(
            bindings, torch.cuda.current_stream().cuda_stream
        )

        if self.output_flattener is not None:
            outputs = self.output_flattener.unflatten(outputs)
        else:
            outputs = tuple(outputs)
            if len(outputs) == 1:
                outputs = outputs[0]
        return outputs
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/__init__.py
================================================
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/convert_to_trt.py
================================================
from torch2trt import TRTModule,torch2trt
from builder import build_model
import torch
# Build the 'ViTPose_base_coco_256x192' model, passing the checkpoint path
# './models/vitpose-b.pth' (presumably loaded by build_model - confirm against
# the `builder` module imported above).
pose = build_model('ViTPose_base_coco_256x192','./models/vitpose-b.pth')
# Move the model to the GPU and switch it to evaluation mode before conversion.
pose.cuda().eval()
# Example input handed to torch2trt: an all-ones tensor of shape (1, 3, 256, 192) on the GPU.
x = torch.ones(1,3,256,192).cuda()
# Convert with torch2trt, requesting a maximum batch size of 10 and fp16 mode.
net_trt = torch2trt(pose, [x],max_batch_size=10, fp16_mode=True)
# Save the converted module's state dict to 'vitpose_trt.pth' in the working directory.
torch.save(net_trt.state_dict(), 'vitpose_trt.pth')
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/general_utils.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 15 15:49:22 2022
@author: gpastal
"""
import numpy as np
import numpy.ma as ma
# import pika
import json
from collections import OrderedDict
from collections.abc import Iterable
from itertools import chain
import argparse
def make_parser():
    """Create the command-line parser holding the tracking options.

    Returns:
        argparse.ArgumentParser: Parser with ``--track_thresh``,
        ``--track_buffer``, ``--match_thresh``, ``--aspect_ratio_thresh``,
        ``--min_box_area`` and the ``--mot20`` flag.
    """
    parser = argparse.ArgumentParser("ByteTrack Demo!")

    # (flag, value type, default, help text) for every valued tracking option,
    # in the order they appear in the help output.
    valued_options = (
        ("--track_thresh", float, 0.2, "tracking confidence threshold"),
        ("--track_buffer", int, 240, "the frames for keep lost tracks"),
        ("--match_thresh", float, 0.8, "matching threshold for tracking"),
        ("--aspect_ratio_thresh", float, 1.6,
         "threshold for filtering out boxes of which aspect ratio are above the given value."),
        ("--min_box_area", float, 10, "filter out tiny boxes"),
    )
    for flag, value_type, default, help_text in valued_options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
    return parser
def jitter(tracking, temp, id1):
    """Placeholder: accepts its arguments, does nothing and returns None."""
    return None
def jitter2(tracking, temp, id1):
    """Placeholder: accepts its arguments, does nothing and returns None."""
    return None
def create_json_rabbitmq(FRAME_ID, pose):
    """Placeholder: accepts its arguments, does nothing and returns None."""
    return None
def producer_rabbitmq():
    """Placeholder: does nothing and returns None."""
    return None
def fix_head(xyz):
    """Placeholder: accepts its argument, does nothing and returns None."""
    return None
def flatten_lst(x):
    """Recursively flatten nested iterables into a flat list.

    Args:
        x: Any object. Iterables are expanded depth-first; everything else
            is treated as a leaf.

    Returns:
        list: The leaves in traversal order. A non-iterable ``x`` yields
        ``[x]``.

    Strings are treated as leaves: iterating a one-character string yields
    that same string again, so descending into them never terminates (the
    original raised RecursionError on any string element).
    """
    if isinstance(x, Iterable) and not isinstance(x, str):
        return [leaf for item in x for leaf in flatten_lst(item)]
    return [x]
def polys_from_pose(pts):
    """Turn per-person keypoints into flat coordinate lists.

    For every person in ``pts`` each keypoint ``kp`` contributes the pair
    ``[kp[1], kp[0]]`` and a flag ``[1]`` when ``kp[2] > 0.4`` (else ``[0]``).
    The pairs and flags are then handed to ``fix_list_order``, whose result is
    collected.

    Args:
        pts: Iterable of persons, each an iterable of keypoints indexable at
            positions 0, 1 and 2.

    Returns:
        list: One ``fix_list_order`` result per person.
    """
    polygons = []
    for person in pts:
        swapped_points = []
        confident_flags = []
        for keypoint in person:
            # Indices 0 and 1 are swapped when building the pair.
            swapped_points.append([keypoint[1], keypoint[0]])
            confident_flags.append([1] if keypoint[2] > 0.4 else [0])
        polygons.append(fix_list_order(swapped_points, confident_flags))
    return polygons
def fix_list_order(list_, list2):
    """Reorder 17 points and keep only the flagged ones, flattened.

    The points are visited in the order 0, 2, 4, ..., 16 followed by
    15, 13, ..., 1. A point is kept when its flag row is all ones.

    Args:
        list_: Sequence of 17 two-element points.
        list2: Sequence of 17 flags (e.g. ``[1]`` / ``[0]``) aligned with
            ``list_``.

    Returns:
        list: Flat list ``[a0, b0, a1, b1, ...]`` of the kept points'
        coordinates (NumPy scalars).
    """
    # Even indices ascending, then odd indices descending.
    visit_order = list(range(0, 17, 2)) + list(range(15, 0, -2))

    points = np.asarray([list_[k] for k in visit_order])
    flags = [list2[k] for k in visit_order]
    # Duplicate each flag so there is one mask entry per coordinate.
    keep_mask = np.column_stack((flags, flags))

    flattened = []
    for point, mask_row in zip(points, keep_mask):
        if mask_row.all() == 1:
            flattened.extend((point[0], point[1]))
    return flattened
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/inference_test.py
================================================
from builder import build_model
import torch
from ViTPose_trt import TRTModule_ViTPose
# pose = TRTModule_ViTPose(path='pose_higher_hrnet_w32_512.engine',device='cuda:0')
# Throughput benchmark: time repeated forward passes of the ViTPose-base model
# on a random batch of 10 images using CUDA events.
pose = build_model('ViTPose_base_coco_256x192','./models/vitpose-b.pth')
pose.cuda().eval()

# Report which mode the model ended up in.
print('train' if pose.training else 'eval')

device = torch.device("cuda")
dummy_input = torch.randn(10, 3, 256, 192, dtype=torch.float).to(device)

repetitions = 100
total_time = 0
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)

with torch.no_grad():
    for rep in range(repetitions):
        starter.record()
        _ = pose(dummy_input)
        ender.record()
        # Wait for the GPU so the elapsed time between the events is valid.
        torch.cuda.synchronize()
        # elapsed_time() is in milliseconds; accumulate seconds.
        curr_time = starter.elapsed_time(ender) / 1000
        total_time += curr_time

# Images per second: repetitions times the batch size of 10, over total seconds.
Throughput = repetitions * 10 / total_time
print('Final Throughput:', Throughput)
print('Total time', total_time)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/logger_helper.py
================================================
import logging
class CustomFormatter(logging.Formatter):
    """Logging formatter that wraps each line in an ANSI colour chosen by level."""

    # ANSI escape sequences.
    grey = "\x1b[38;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"

    # Line layout shared by every level. The name is rebound to the
    # ``format`` method below once FORMATS has been built from it.
    format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"

    # Level number -> fully coloured format string.
    FORMATS = {
        logging.DEBUG: grey + format + reset,
        logging.INFO: grey + format + reset,
        logging.WARNING: yellow + format + reset,
        logging.ERROR: red + format + reset,
        logging.CRITICAL: bold_red + format + reset
    }

    def format(self, record):
        """Format ``record`` with the colour template of its level.

        Levels without an entry in ``FORMATS`` give ``None`` as the template,
        which makes ``logging.Formatter`` fall back to its default format.
        """
        template = self.FORMATS.get(record.levelno)
        return logging.Formatter(template).format(record)
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/pose_utils.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 15 15:45:33 2022
@author: gpastal
"""
import torch
import torchvision
import torch.nn.functional as F
from torchvision import transforms as TR
import numpy as np
import cv2
import logging
# from simpleHRNet.models_.hrnet import HRNet
# from torch2trt import torch2trt,TRTModule
logger = logging.getLogger("Tracker !")
from .timerr import Timer
from pathlib import Path
# import gdown
timer_det = Timer()
timer_track = Timer()
timer_pose = Timer()
def pose_points_yolo5(detector,image,pose,tracker,tensorrt):
"""
Detect, track and estimate keypoints for the tracked boxes of one frame.
Pipeline visible in this function: run `detector` on `image`, keep the detections
whose column 5 equals 0, feed them to `tracker.update`, keep tracks whose
width * height exceeds 10, crop every kept box from `image`, zero-pad the crop
towards the 256/192 aspect ratio, resize/normalise it to 3x256x192, run `pose`
and convert the per-keypoint heatmap maxima into image coordinates.
Args:
detector: callable; its result must expose `.xyxy[0]` as a 2-D array
(presumably a YOLOv5 result where column 5 is the class id and 0 is "person" -- confirm).
image: array indexed as image[row, col, channel]; the channel axis is reversed
for the crop (presumably BGR -> RGB -- confirm).
pose: callable applied to the crop batch; its output is read as (N, K, 64, 48) heatmaps
(the tensorrt branch allocates (N, 17, 64, 48)).
tracker: object whose `update(dets, [h, w], image.shape)` returns targets with
`.tlwh`, `.track_id` and `.score`.
tensorrt: if truthy, `pose` is called one crop at a time instead of on the whole batch.
Returns:
(pts, online_tlwhs, online_ids, online_scores) where pts is a numpy array of
shape (N, K, 3) holding (y, x, confidence) per keypoint, or an empty
(0, 0, 3) array when nothing is tracked.
NOTE(review): the device is hard-coded to 'cuda', so this cannot run on CPU-only hosts.
NOTE(review): when correction_factor is exactly 1 neither padding branch runs, so
images[i] and boxes[i] keep whatever torch.empty returned -- confirm whether that can occur.
NOTE(review): `heatmaps` and `pts2` are allocated but never read in this function.
"""
timer_det.tic()
# starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
transform = TR.Compose([
TR.ToPILImage(),
# Padd(),
TR.Resize((256, 192)), # (height, width)
TR.ToTensor(),
TR.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
detections = detector(image)
timer_det.toc()
logger.info('DET FPS -- %s',1./timer_det.average_time)
# print(detections.shape)
dets = detections.xyxy[0]
# keep only the rows whose column 5 equals 0 (presumably class id 0 -- confirm)
dets = dets[dets[:,5] == 0.]
# dets = dets[dets[:,4] > 0.3]
# logger.warning(len(dets))
# if len(dets)>0:
# image_gpu = torch.tensor(image).cuda()/255.
# print(image_gpu.size())
timer_track.tic()
online_targets=tracker.update(dets,[image.shape[0],image.shape[1]],image.shape)
online_tlwhs = []
online_ids = []
online_scores = []
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
# vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_threshs
# discard boxes whose area (w * h) is 10 or less
if tlwh[2] * tlwh[3] > 10 :#and not vertical:
online_tlwhs.append(tlwh)
online_ids.append(tid)
online_scores.append(t.score)
# tracker.update()
timer_track.toc()
logger.info('TRACKING FPS --%s',1./timer_track.average_time)
device='cuda'
nof_people = len(online_ids) if online_ids is not None else 0
# nof_people=1
# print(dets)
# print(nof_people)
boxes = torch.empty((nof_people, 4), dtype=torch.int32,device= 'cuda')
# boxes = []
images = torch.empty((nof_people, 3, 256, 192)) # (height, width)
heatmaps = np.zeros((nof_people, 17, 64, 48),
dtype=np.float32)
# starter.record()
# print(online_tlwhs)
if len(online_tlwhs):
# each entry is (left, top, width, height); x2 / y2 start out as width / height
# and are rebound below to the right / bottom edge
for i, (x1, y1, x2, y2) in enumerate(online_tlwhs):
# for i, (x1, y1, x2, y2) in enumerate(np.array([[55,399,424-55,479-399]])):
# if i<1:
x1 = x1.astype(np.int32)
x2 = x1+x2.astype(np.int32)
y1 = y1.astype(np.int32)
y2 = y1+ y2.astype(np.int32)
# clamp the box to the image bounds
if x2>image.shape[1]:x2=image.shape[1]-1
if y2>image.shape[0]:y2=image.shape[0]-1
if y1<0: y1=0
if x1<0 : x1=0
# print([x1,x2,y1,y2])
# image = cv2.rectangle(image, (x1,y1), (x2,y2), (0,0,0), 1)
# cv2.imwrite('saved.png',image)
# # Adapt detections to match HRNet input aspect ratio (as suggested by xtyDoge in issue #14)
# NOTE(review): divides by (y2 - y1); a zero-height box after clamping would divide by zero -- confirm
correction_factor = 256 / 192 * (x2 - x1) / (y2 - y1)
if correction_factor > 1:
# increase y side
center = y1 + (y2 - y1) // 2
length = int(round((y2 - y1) * correction_factor))
y1_new = int( center - length // 2)
y2_new = int( center + length // 2)
image_crop = image[y1:y2, x1:x2, ::-1]
# print(y1,y2,x1,x2)
# zero-pad rows above / below the crop up to the enlarged box
pad = (int(abs(y1_new-y1))), int(abs(y2_new-y2))
image_crop = np.pad(image_crop,((pad), (0, 0), (0, 0)))
images[i] = transform(image_crop)
boxes[i]= torch.tensor([x1, y1_new, x2, y2_new])
elif correction_factor < 1:
# increase x side
center = x1 + (x2 - x1) // 2
length = int(round((x2 - x1) * 1 / correction_factor))
x1_new = int( center - length // 2)
x2_new = int( center + length // 2)
# images[i] = transform(image[y1:y2, x1:x2, ::-1])
image_crop = image[y1:y2, x1:x2, ::-1]
# zero-pad columns left / right of the crop up to the enlarged box
pad = (abs(x1_new-x1)), int(abs(x2_new-x2))
image_crop = np.pad(image_crop,((0, 0), (pad), (0, 0)))
images[i] = transform(image_crop)
boxes[i]= torch.tensor([x1_new, y1, x2_new, y2])
if images.shape[0] > 0:
images = images.to(device)
if tensorrt:
out = torch.zeros((images.shape[0],17,64,48),device=device)
with torch.no_grad():
timer_pose.tic()
for i in range(images.shape[0]):
# timer_pose.tic()
# print(images[i].size())
out[i] = pose(images[i].unsqueeze(0))
timer_pose.toc()
logger.info('POSE FPS -- %s',1./timer_pose.average_time)
else:
with torch.no_grad():
timer_pose.tic()
out = pose(images)
timer_pose.toc()
logger.info('POSE FPS -- %s',1./timer_pose.average_time)
pts = torch.empty((out.shape[0], out.shape[1], 3), dtype=torch.float32,device=device)
pts2 = np.empty((out.shape[0], out.shape[1], 3), dtype=np.float32)
# reduce over heatmap rows, then over columns: `indices` is the column of the maximum
(b,indices)=torch.max(out,dim=2)
(b,indices)=torch.max(b,dim=2)
# reduce over heatmap columns, then over rows: `indicesc` is the row of the maximum, `c` its value
(c,indicesc)=torch.max(out,dim=3)
(c,indicesc)=torch.max(c,dim=2)
dim1= torch.tensor(1. / 64,device=device)
dim2= torch.tensor(1. / 48,device=device)
# scale heatmap indices into each (padded) box: channel 0 = y, channel 1 = x, channel 2 = confidence
for i in range(0,out.shape[0]):
pts[i, :, 0] = indicesc[i,:] * dim1 * (boxes[i][3] - boxes[i][1]) + boxes[i][1]
pts[i, :, 1] = indices[i,:] *dim2* (boxes[i][2] - boxes[i][0]) + boxes[i][0]
pts[i, :, 2] = c[i,:]
pts=pts.cpu().numpy()
# print(pts)
else:
# nothing to estimate: return an empty keypoint array and empty track lists
pts = np.empty((0, 0, 3), dtype=np.float32)
online_tlwhs = []
online_ids = []
online_scores=[]
res = list()
res.append(pts)
# `res` always holds exactly one element here, so the second return is the one taken
if len(res) > 1:
return res,online_tlwhs,online_ids,online_scores#,pts2
else:
return res[0],online_tlwhs,online_ids,online_scores#,pts2
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/pose_viz.py
================================================
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import ffmpeg
def joints_dict():
    """Return keypoint names and skeleton edges for the COCO and MPII layouts.

    Returns:
        dict with the keys ``"coco"`` and ``"mpii"``.  Each value holds
        ``"keypoints"`` (joint index -> joint name) and ``"skeleton"``
        (list of ``[index_a, index_b]`` pairs).  A fresh structure is built
        on every call.
    """
    coco_names = (
        "nose", "left_eye", "right_eye", "left_ear", "right_ear",
        "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
        "left_wrist", "right_wrist", "left_hip", "right_hip",
        "left_knee", "right_knee", "left_ankle", "right_ankle",
    )
    # The ear-shoulder pairs [3, 5] and [4, 6] are disabled in this layout;
    # the nose-shoulder pairs [0, 5] and [0, 6] are listed instead.
    coco_skeleton = [
        [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7],
        [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4],
        [0, 5], [0, 6],
    ]

    mpii_names = (
        "right_ankle", "right_knee", "right_hip", "left_hip", "left_knee",
        "left_ankle", "pelvis", "thorax", "upper_neck", "head top",
        "right_wrist", "right_elbow", "right_shoulder", "left_shoulder",
        "left_elbow", "left_wrist",
    )
    mpii_skeleton = [
        [5, 4], [4, 3], [0, 1], [1, 2], [3, 2], [3, 6], [2, 6], [6, 7], [7, 8], [8, 9],
        [13, 7], [12, 7], [13, 14], [12, 11], [14, 15], [11, 10],
    ]

    return {
        "coco": {"keypoints": dict(enumerate(coco_names)), "skeleton": coco_skeleton},
        "mpii": {"keypoints": dict(enumerate(mpii_names)), "skeleton": mpii_skeleton},
    }
def draw_points(image, points, color_palette='tab20', palette_samples=16, confidence_threshold=0.5):
    """
    Overlay one filled circle per confident keypoint on `image`.

    Args:
        image: image in opencv format
        points: keypoints to draw.
            Shape: (nof_points, 3)
            Format: each point is (y, x, confidence)
        color_palette: name of a matplotlib color palette
            Default: 'tab20'
        palette_samples: number of colors sampled from `color_palette` when the
            palette exposes no predefined `.colors` list
            Default: 16
        confidence_threshold: points are drawn only when their confidence is strictly
            above this value
            Default: 0.5

    Returns:
        The image returned by the last `cv2.circle` call (or `image` itself when
        nothing was drawn)
    """
    try:
        # Palettes with a fixed colour list: RGB rows, reversed to BGR.
        palette = np.array(plt.get_cmap(color_palette).colors) * 255
        colors = np.round(palette).astype(np.uint8)[:, ::-1].tolist()
    except AttributeError:  # if palette has not pre-defined colors
        # Sample the palette instead: RGBA rows, alpha dropped and reversed to BGR.
        sample_at = np.linspace(0, 1, palette_samples)
        palette = np.array(plt.get_cmap(color_palette)(sample_at)) * 255
        colors = np.round(palette).astype(np.uint8)[:, -2::-1].tolist()

    # Radius grows with the shorter image side, never below one pixel.
    shorter_side = min(image.shape[:2])
    circle_size = max(1, shorter_side // 160)  # ToDo Shape it taking into account the size of the detection

    n_colors = len(colors)
    for index, pt in enumerate(points):
        if not pt[2] > confidence_threshold:
            continue
        center = (int(pt[1]), int(pt[0]))  # cv2 expects (x, y)
        image = cv2.circle(image, center, circle_size, tuple(colors[index % n_colors]), -1)

    return image
def draw_skeleton(image, points, skeleton, color_palette='Set2', palette_samples=8, person_index=0,
                  confidence_threshold=0.5):
    """
    Overlay the limbs listed in `skeleton` on `image`.

    Args:
        image: image in opencv format
        points: keypoints, indexed as `points[joint]` with a pair of indices, so an
            array type supporting that kind of indexing is required.
            Shape: (nof_points, 3)
            Format: each point is (y, x, confidence)
        skeleton: limbs to draw
            Shape: (nof_joints, 2)
            Format: each limb is (point_a, point_b), both indices into `points`
        color_palette: name of a matplotlib color palette
            Default: 'Set2'
        palette_samples: number of colors sampled from `color_palette` when the
            palette exposes no predefined `.colors` list
            Default: 8
        person_index: index of the person in `image`; selects the line color
            Default: 0
        confidence_threshold: a limb is drawn only when both end points have a confidence
            strictly above this value
            Default: 0.5

    Returns:
        The image returned by the last `cv2.line` call (or `image` itself when
        nothing was drawn)
    """
    try:
        palette = np.array(plt.get_cmap(color_palette).colors) * 255
        colors = np.round(palette).astype(np.uint8)[:, ::-1].tolist()
    except AttributeError:  # if palette has not pre-defined colors
        sample_at = np.linspace(0, 1, palette_samples)
        palette = np.array(plt.get_cmap(color_palette)(sample_at)) * 255
        colors = np.round(palette).astype(np.uint8)[:, -2::-1].tolist()

    for joint in skeleton:
        pt1, pt2 = points[joint]
        both_confident = pt1[2] > confidence_threshold and pt2[2] > confidence_threshold
        if both_confident:
            start = (int(pt1[1]), int(pt1[0]))  # cv2 expects (x, y)
            stop = (int(pt2[1]), int(pt2[0]))
            image = cv2.line(image, start, stop, tuple(colors[person_index % len(colors)]), 2)

    return image
def draw_points_and_skeleton(image, points, skeleton, points_color_palette='tab20', points_palette_samples=16,
                             skeleton_color_palette='Set2', skeleton_palette_samples=8, person_index=0,
                             confidence_threshold=0.5):
    """
    Overlay the skeleton first and the keypoints on top of it.

    Args:
        image: image in opencv format
        points: keypoints to draw.
            Shape: (nof_points, 3)
            Format: each point is (y, x, confidence)
        skeleton: limbs to draw
            Shape: (nof_joints, 2)
            Format: each limb is (point_a, point_b), both indices into `points`
        points_color_palette: name of a matplotlib color palette used for the keypoints
            Default: 'tab20'
        points_palette_samples: number of colors sampled from `points_color_palette`
            Default: 16
        skeleton_color_palette: name of a matplotlib color palette used for the limbs
            Default: 'Set2'
        skeleton_palette_samples: number of colors sampled from `skeleton_color_palette`
            Default: 8
        person_index: index of the person in `image`
            Default: 0
        confidence_threshold: only points with a confidence above this threshold are drawn
            Default: 0.5

    Returns:
        The image with limbs and keypoints overlaid
    """
    with_limbs = draw_skeleton(
        image, points, skeleton,
        color_palette=skeleton_color_palette,
        palette_samples=skeleton_palette_samples,
        person_index=person_index,
        confidence_threshold=confidence_threshold,
    )
    return draw_points(
        with_limbs, points,
        color_palette=points_color_palette,
        palette_samples=points_palette_samples,
        confidence_threshold=confidence_threshold,
    )
def save_images(images, target, joint_target, output, joint_output, joint_visibility, summary_writer=None, step=0,
                prefix=''):
    """
    Build one image grid with the ground-truth joints and one with the predicted joints.

    This is a basic function for debugging purposes only.
    If `summary_writer` is not None, the grids are also written to it under the
    names "{prefix}images" and "{prefix}predictions".

    Args:
        images (torch.Tensor): normalised images with shape (batch x channels x height x width).
        target (torch.Tensor): gt heatmaps; not used by this function.
        joint_target (torch.Tensor): gt joints with shape (batch x joints x 2); multiplied by 4 before drawing.
        output (torch.Tensor): predicted heatmaps; not used by this function.
        joint_output (torch.Tensor): predicted joints with shape (batch x joints x 2); multiplied by 4 before drawing.
        joint_visibility (torch.Tensor): joint visibility; each per-joint entry is read as `entry[0]`,
            so a trailing dimension is expected -- TODO confirm the exact shape with the caller.
        summary_writer (tb.SummaryWriter): optional writer receiving the grids.
            Default: None
        step (int): summary_writer step.
            Default: 0
        prefix (str): summary_writer name prefix.
            Default: ""

    Returns:
        A pair of images which are built from torchvision.utils.make_grid
    """
    # (std, mean) per channel, used to undo the input normalisation.
    channel_stats = ((0.229, 0.485), (0.224, 0.456), (0.225, 0.406))

    def _grid_with_joints(joint_coords):
        # Work on a private copy so the caller's batch is left untouched.
        canvas = images.detach().clone()
        for channel, (std, mean) in enumerate(channel_stats):
            canvas[:, channel].mul_(std).add_(mean)

        for sample in range(images.shape[0]):
            scaled_joints = joint_coords[sample] * 4.
            visibility = joint_visibility[sample]
            for joint, joint_vis in zip(scaled_joints, visibility):
                if not joint_vis[0]:
                    continue
                row = int(joint[1].item())
                col = int(joint[0].item())
                # Paint a small red marker: first channel on, remaining channels off.
                canvas[sample][0, row - 1:row + 1, col - 1:col + 1] = 1
                canvas[sample][1:, row - 1:row + 1, col - 1:col + 1] = 0

        return torchvision.utils.make_grid(canvas, nrow=int(canvas.shape[0] ** 0.5), padding=2, normalize=False)

    # Input images with gt
    grid_gt = _grid_with_joints(joint_target)
    if summary_writer is not None:
        summary_writer.add_image(prefix + 'images', grid_gt, global_step=step)

    # Input images with prediction
    grid_pred = _grid_with_joints(joint_output)
    if summary_writer is not None:
        summary_writer.add_image(prefix + 'predictions', grid_pred, global_step=step)

    # Heatmaps
    # ToDo

    return grid_gt, grid_pred
def check_video_rotation(filename):
    """Map the `rotate` tag of a video's first stream to an OpenCV rotation code.

    thanks to
    https://stackoverflow.com/questions/53097092/frame-from-video-is-upside-down-after-extracting/55747773#55747773

    Args:
        filename: path handed to `ffmpeg.probe`.

    Returns:
        `cv2.ROTATE_90_CLOCKWISE`, `cv2.ROTATE_180` or
        `cv2.ROTATE_90_COUNTERCLOCKWISE` for a tag of 90, 180 or 270 degrees;
        `None` when the tag is missing, the file has no streams, or the tag is
        a multiple of 360 (no rotation needed).

    Raises:
        ValueError: if the tag is not an integer or is not a multiple of 90.
    """
    # this returns meta-data of the video file in form of a dictionary
    meta_dict = ffmpeg.probe(filename)

    # meta_dict['streams'][0]['tags']['rotate'] is the key we are looking for
    try:
        rotate_tag = meta_dict['streams'][0]['tags']['rotate']
    except (KeyError, IndexError):
        return None

    # Parse once and fold equivalent angles (e.g. -90 and 270) together.
    angle = int(rotate_tag) % 360
    if angle == 0:
        # An explicit "rotate: 0" tag means the frames are already upright.
        return None

    rotation_codes = {
        90: cv2.ROTATE_90_CLOCKWISE,
        180: cv2.ROTATE_180,
        270: cv2.ROTATE_90_COUNTERCLOCKWISE,
    }
    if angle not in rotation_codes:
        raise ValueError("unsupported video rotation: {!r}".format(rotate_tag))
    return rotation_codes[angle]
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/timerr.py
================================================
import time
class Timer(object):
    """A simple timer.

    Call `tic()` to start a measurement and `toc()` to finish it.  The timer
    keeps the last interval (`diff`), the sum of all intervals (`total_time`),
    the number of finished intervals (`calls`) and their mean (`average_time`).
    """

    def __init__(self):
        self._reset()

    def _reset(self):
        # Shared by __init__ and clear(): put every counter back to zero.
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.
        self.duration = 0.

    def tic(self):
        # using time.time instead of time.clock because time time.clock
        # does not normalize for multithreading
        self.start_time = time.time()

    def toc(self, average=True):
        """Close the current interval.

        Returns the running average when `average` is true, otherwise the
        length of the interval that was just closed.
        """
        now = time.time()
        self.diff = now - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        self.duration = self.average_time if average else self.diff
        return self.duration

    def clear(self):
        """Forget all measurements."""
        self._reset()
================================================
FILE: eval/GVHMR/hmr4d/utils/preproc/vitpose_pytorch/src/vitpose_infer/pose_utils/visualizer.py
================================================
import cv2
import numpy as np
__all__ = ["vis"]
def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
    """Draw labelled detection boxes on `img` in place.

    Args:
        img: image passed straight to the `cv2` drawing calls.
        boxes: sequence of boxes; entries 0..3 of each box are used as
            (x0, y0, x1, y1) after conversion to int.
        scores: per-box scores; boxes scoring below `conf` are skipped.
        cls_ids: per-box class ids, used to index `_COLORS` and `class_names`.
        conf: minimum score for a box to be drawn.
        class_names: optional sequence mapping a class id to its label.  When it
            is None the numeric class id is used as the label (the previous code
            raised TypeError in that case even though None is the default).

    Returns:
        The same `img` object, with the boxes drawn on it.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX

    for i in range(len(boxes)):
        box = boxes[i]
        cls_id = int(cls_ids[i])
        score = scores[i]
        if score < conf:
            continue
        x0 = int(box[0])
        y0 = int(box[1])
        x1 = int(box[2])
        y1 = int(box[3])

        color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
        label = class_names[cls_id] if class_names is not None else cls_id
        text = '{}:{:.1f}%'.format(label, score * 100)
        # Dark text on light colours, light text on dark colours.
        txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)

        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)

        # Filled, slightly darker background behind the label.
        txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
        cv2.rectangle(
            img,
            (x0, y0 + 1),
            (x0 + txt_size[0] + 1, y0 + int(1.5*txt_size[1])),
            txt_bk_color,
            -1
        )
        cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)

    return img
def get_color(idx):
    """Derive a repeatable colour triple from an integer id.

    The id is tripled, multiplied by a different constant per channel and
    reduced modulo 255.
    """
    scaled = idx * 3
    return tuple((factor * scaled) % 255 for factor in (37, 17, 29))
def plot_tracking(image, tlwhs, obj_ids, scores=None, frame_id=0, fps=0., ids2=None):
    """Draw tracked boxes and their ids on a copy of `image`.

    Args:
        image: source frame; it is copied, the original is not modified.
        tlwhs: sequence of (left, top, width, height) boxes.
        obj_ids: track id per box; selects the box colour and is printed as label.
        scores: accepted for interface compatibility; not used.
        frame_id: frame number printed in the header line.
        fps: frames-per-second value printed in the header line.
        ids2: optional secondary id per box, appended to the label.

    Returns:
        The annotated copy of `image`.
    """
    # The previous version also allocated an (im_w, im_w, 3) "top_view" canvas
    # and computed a radius on every call; neither was ever used, so both are gone.
    im = np.ascontiguousarray(np.copy(image))

    #text_scale = max(1, image.shape[1] / 1600.)
    #text_thickness = 2
    #line_thickness = max(1, int(image.shape[1] / 500.))
    text_scale = 2
    text_thickness = 2
    line_thickness = 3

    cv2.putText(im, 'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)),
                (0, int(15 * text_scale)), cv2.FONT_HERSHEY_PLAIN, 2, (0, 0, 255), thickness=2)

    for i, tlwh in enumerate(tlwhs):
        x1, y1, w, h = tlwh
        # Convert (left, top, width, height) to integer corner coordinates.
        intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
        obj_id = int(obj_ids[i])
        id_text = '{}'.format(int(obj_id))
        if ids2 is not None:
            id_text = id_text + ', {}'.format(int(ids2[i]))
        color = get_color(abs(obj_id))
        cv2.rectangle(im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
        cv2.putText(im, id_text, (intbox[0], intbox[1]), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255),
                    thickness=text_thickness)
    return im
# Colour table used by vis(): a flat list of float triples in [0, 1] that is
# reshaped below into an (80, 3) float32 array and indexed by class id.
# vis() multiplies a row by 255 before handing it to cv2.
_COLORS = np.array(
[
0.000, 0.447, 0.741,
0.850, 0.325, 0.098,
0.929, 0.694, 0.125,
0.494, 0.184, 0.556,
0.466, 0.674, 0.188,
0.301, 0.745, 0.933,
0.635, 0.078, 0.184,
0.300, 0.300, 0.300,
0.600, 0.600, 0.600,
1.000, 0.000, 0.000,
1.000, 0.500, 0.000,
0.749, 0.749, 0.000,
0.000, 1.000, 0.000,
0.000, 0.000, 1.000,
0.667, 0.000, 1.000,
0.333, 0.333, 0.000,
0.333, 0.667, 0.000,
0.333, 1.000, 0.000,
0.667, 0.333, 0.000,
0.667, 0.667, 0.000,
0.667, 1.000, 0.000,
1.000, 0.333, 0.000,
1.000, 0.667, 0.000,
1.000, 1.000, 0.000,
0.000, 0.333, 0.500,
0.000, 0.667, 0.500,
0.000, 1.000, 0.500,
0.333, 0.000, 0.500,
0.333, 0.333, 0.500,
0.333, 0.667, 0.500,
0.333, 1.000, 0.500,
0.667, 0.000, 0.500,
0.667, 0.333, 0.500,
0.667, 0.667, 0.500,
0.667, 1.000, 0.500,
1.000, 0.000, 0.500,
1.000, 0.333, 0.500,
1.000, 0.667, 0.500,
1.000, 1.000, 0.500,
0.000, 0.333, 1.000,
0.000, 0.667, 1.000,
0.000, 1.000, 1.000,
0.333, 0.000, 1.000,
0.333, 0.333, 1.000,
0.333, 0.667, 1.000,
0.333, 1.000, 1.000,
0.667, 0.000, 1.000,
0.667, 0.333, 1.000,
0.667, 0.667, 1.000,
0.667, 1.000, 1.000,
1.000, 0.000, 1.000,
1.000, 0.333, 1.000,
1.000, 0.667, 1.000,
0.333, 0.000, 0.000,
0.500, 0.000, 0.000,
0.667, 0.000, 0.000,
0.833, 0.000, 0.000,
1.000, 0.000, 0.000,
0.000, 0.167, 0.000,
0.000, 0.333, 0.000,
0.000, 0.500, 0.000,
0.000, 0.667, 0.000,
0.000, 0.833, 0.000,
0.000, 1.000, 0.000,
0.000, 0.000, 0.167,
0.000, 0.000, 0.333,
0.000, 0.000, 0.500,
0.000, 0.000, 0.667,
0.000, 0.000, 0.833,
0.000, 0.000, 1.000,
0.000, 0.000, 0.000,
0.143, 0.143, 0.143,
0.286, 0.286, 0.286,
0.429, 0.429, 0.429,
0.571, 0.571, 0.571,
0.714, 0.714, 0.714,
0.857, 0.857, 0.857,
0.000, 0.447, 0.741,
0.314, 0.717, 0.741,
0.50, 0.5, 0
]
).astype(np.float32).reshape(-1, 3)
================================================
FILE: eval/GVHMR/hmr4d/utils/pylogger.py
================================================
from time import time
import logging
import torch
from colorlog import ColoredFormatter
def sync_time():
    """Return the wall-clock time taken after `torch.cuda.synchronize()` has returned."""
    torch.cuda.synchronize()
    now = time()
    return now
# `logging.getLogger()` without a name returns the root logger, so the level and
# handler configured below apply process-wide as soon as this module is imported.
Log = logging.getLogger()
# Attach the timing helpers as attributes so callers can use Log.time() / Log.sync_time().
Log.time = time
Log.sync_time = sync_time
# Set default
Log.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
# Use colorlog
formatstring = "[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] %(message)s"
datefmt = "%m/%d %H:%M:%S"
ch.setFormatter(ColoredFormatter(formatstring, datefmt=datefmt))
# NOTE(review): a new handler is added on every execution of this module body;
# re-executing it (e.g. importlib.reload) would duplicate log lines -- confirm this cannot happen.
Log.addHandler(ch)
# Log.info("Init-Logger")
def timer(sync_cuda=False, mem=False, loop=1):
    """Decorator factory that logs how long the wrapped function takes.

    Args:
        sync_cuda: bool, whether to call `torch.cuda.synchronize()` before
            starting and before stopping the clock.
        mem: bool, whether to append CUDA memory statistics to the log line.
        loop: int >= 1, how many times the function is executed per call; the
            logged time is the average over these runs and the result of the
            last run is returned.

    Returns:
        A decorator.  The wrapper it produces keeps the wrapped function's
        name and docstring.

    Raises:
        ValueError: if `loop` is smaller than 1 (such values previously failed
            inside the wrapper on every call).
    """
    import functools

    if loop < 1:
        raise ValueError(f"loop must be >= 1, got {loop}")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if mem:
                start_mem = torch.cuda.memory_allocated() / 1024**2
            if sync_cuda:
                torch.cuda.synchronize()

            start = Log.time()
            for _ in range(loop):
                result = func(*args, **kwargs)

            if sync_cuda:
                torch.cuda.synchronize()
            if loop == 1:
                message = f"{func.__name__} took {Log.time() - start:.3f} s."
            else:
                message = f"{func.__name__} took {((Log.time() - start))/loop:.3f} s. (loop={loop})"

            if mem:
                end_max_mem = torch.cuda.max_memory_allocated() / 1024**2
                message += f" Start_Mem {start_mem:.1f} Max {end_max_mem:.1f} MB"
            Log.info(message)

            return result

        return wrapper

    return decorator
def timed(fn):
    """Run `fn()` between two CUDA timing events.

    example usage: timed(lambda: model(inp))

    Returns:
        (result, seconds): the value returned by `fn` and the time between
        the two events, i.e. `elapsed_time` divided by 1000.
    """
    start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    start.record()
    result = fn()
    end.record()
    # Wait for the recorded work to finish before reading the elapsed time.
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / 1000
    return result, seconds
================================================
FILE: eval/GVHMR/hmr4d/utils/seq_utils.py
================================================
import torch
import numpy as np
# def get_frame_id_list_from_mask(mask):
# """
# Args:
# mask (F,), bool.
# Return:
# frame_id_list: List of frame_ids.
# """
# frame_id_list = []
# i = 0
# while i < len(mask):
# if not mask[i]:
# i += 1
# else:
# j = i
# while j < len(mask) and mask[j]:
# j += 1
# frame_id_list.append(torch.arange(i, j))
# i = j
# return frame_id_list
# From GPT
def get_frame_id_list_from_mask(mask):
    # batch=64, 0.13s
    """
    Split a boolean mask into its runs of consecutive True entries.

    Args:
        mask (F,), bool tensor: `True` marks a frame to be processed.
    Returns:
        frame_id_list: List of torch.Tensors, one per run, each holding the
            consecutive indices of that run.  Empty list when no entry is True.
    """
    # Pad with False on both sides so runs touching either end still produce
    # a rising and a falling edge.
    edge = torch.tensor([False], device=mask.device)
    padded_mask = torch.cat([edge, mask, edge])
    transitions = torch.diff(padded_mask.int())

    # +1 marks the first index of a run, -1 the index just past its end.
    starts = (transitions == 1).nonzero(as_tuple=False).reshape(-1)
    ends = (transitions == -1).nonzero(as_tuple=False).reshape(-1)
    if starts.numel() == 0:
        return []

    return [torch.arange(first, stop) for first, stop in zip(starts, ends)]
def get_batch_frame_id_lists_from_mask_BLC(masks):
    # batch=64, 0.10s
    """
    Extract, for every batch element and channel of a 3-D mask, the runs of
    consecutive True entries along the sequence axis.

    Args:
        masks (B, L, C), bool tensor: True marks a frame to be processed.
    Returns:
        batch_frame_id_lists: nested list indexed as [b][c]; each entry is a list
            of torch.Tensors, one per run, holding the consecutive frame ids of
            that run (empty list when the channel has no True entry).
    """
    B, L, C = masks.size()

    # Pad one False at both ends of the sequence axis
    pad = torch.zeros((B, 1, C), dtype=torch.bool, device=masks.device)
    padded_masks = torch.cat([pad, masks, pad], dim=1)

    # Differences along the sequence axis: +1 opens a run, -1 closes it
    diffs = torch.diff(padded_masks.int(), dim=1)
    start_b, start_l, start_c = (diffs == 1).nonzero(as_tuple=True)
    end_b, end_l, end_c = (diffs == -1).nonzero(as_tuple=True)

    batch_frame_id_lists = [[[] for _ in range(C)] for _ in range(B)]
    for b in range(B):
        for c in range(C):
            # Select the *sequence positions* (axis 1) of the edges that belong
            # to this (b, c) pair.  The previous code selected the batch index
            # (axis 0) instead, so every range was arange(b, b), i.e. empty.
            # nonzero() enumerates in row-major order, hence the positions come
            # out ascending and the k-th start pairs with the k-th end.
            run_starts = start_l[(start_b == b) & (start_c == c)]
            run_ends = end_l[(end_b == b) & (end_c == c)]
            batch_frame_id_lists[b][c] = [
                torch.arange(start, end) for start, end in zip(run_starts.tolist(), run_ends.tolist())
            ]

    return batch_frame_id_lists
def get_frame_id_list_from_frame_id(frame_id):
    """Group frame ids into runs of consecutive ids.

    A boolean mask of length `frame_id[-1] + 1` is built with the given ids set
    to True and handed to `get_frame_id_list_from_mask`.  The last entry of
    `frame_id` therefore has to be the largest id.
    """
    num_frames = frame_id[-1] + 1
    present = torch.zeros(num_frames, dtype=torch.bool)
    present[frame_id] = True
    return get_frame_id_list_from_mask(present)
def rearrange_by_mask(x, mask):
    """
    Scatter the rows of `x` into the True positions of `mask`, zero elsewhere.

    x (L, *)
    mask (M,), M >= L

    Returns `x` itself when M == L, otherwise a new (M, *) tensor with the same
    dtype and device as `x`.
    """
    M, L = mask.size(0), x.size(0)
    if M == L:
        return x

    assert M > L
    assert mask.sum() == L

    x_rearranged = x.new_zeros((M, *x.shape[1:]))
    x_rearranged[mask] = x
    return x_rearranged
def frame_id_to_mask(frame_id, max_len):
    """Return a bool tensor of length `max_len` that is True exactly at `frame_id`."""
    selected = torch.zeros(max_len, dtype=torch.bool)
    selected[frame_id] = True
    return selected
def mask_to_frame_id(mask):
    """Return the indices along the first axis at which `mask` is True."""
    return mask.nonzero(as_tuple=True)[0]
def linear_interpolate_frame_ids(data, frame_id_list):
    """Fill invalid frame spans of `data` from their valid neighbours.

    Args:
        data: tensor whose first axis is the frame axis; the interpolation
            weights are broadcast as (span, 1), so a 2-D (frames, channels)
            layout is expected.  The input is cloned, not modified.
        frame_id_list: iterable of index tensors, each holding one span of
            consecutive invalid frame ids.

    Returns:
        A copy of `data` where every span is replaced by
        - a linear blend between the frame before and the frame after it, or
        - a copy of the single available neighbour when the span touches the
          beginning or the end of the sequence.

    Raises:
        IndexError: if a span covers the whole sequence, since no neighbour
            exists to copy from.
    """
    data = data.clone()
    num_frames = len(data)
    for invalid_frame_ids in frame_id_list:
        first, last = invalid_frame_ids[0], invalid_frame_ids[-1]
        if first - 1 < 0:
            # Span starts at frame 0: repeat the first valid frame after it.
            data[invalid_frame_ids] = data[last + 1].clone()
        elif last + 1 >= num_frames:
            # Span runs to the last frame: repeat the last valid frame before it.
            data[invalid_frame_ids] = data[first - 1].clone()
        else:
            before = data[first - 1]
            after = data[last + 1]
            # Build the weights on data's device; linspace defaults to the CPU,
            # which made this expression fail for tensors living elsewhere.
            weights = torch.linspace(0, 1, len(invalid_frame_ids) + 2, device=data.device)[1:-1][:, None]
            data[invalid_frame_ids] = weights * (after - before)[None] + before[None]
    return data
def linear_interpolate(data, N_middle_frames):
    """
    Insert evenly spaced frames between two key frames.

    Args:
        data: (2, C), the start and end frame
        N_middle_frames: number of frames N to generate between them
    Returns:
        data_interpolated: (1+N+1, C), start frame, the N blended frames, end frame
    """
    first = data[0]
    last = data[1]
    # Weights strictly between 0 and 1, created on data's device (linspace
    # defaults to the CPU, which broke this for tensors living elsewhere).
    weights = torch.linspace(0, 1, N_middle_frames + 2, device=data.device)[1:-1][:, None]
    middle = weights * (last - first)[None] + first[None]  # (N, C)
    data_interpolated = torch.cat([first[None], middle, last[None]], dim=0)  # (1+N+1, C)
    return data_interpolated
def find_top_k_span(mask, k=3):
    """
    Return the `k` longest runs of ones in a 1-D mask.

    Args:
        mask: (L,), torch tensor or numpy array
        k: maximum number of spans to return
    Return:
        topk_span: List of tuple ordered from longest to shortest,
            usage: [start, end)
    """
    if isinstance(mask, np.ndarray):
        mask = torch.from_numpy(mask)
    if mask.sum() == 0:
        return []

    values = mask.clone().float()
    zero = values.new([0])
    # Zero padding on both sides gives every run a rising and a falling edge.
    padded = torch.cat([zero, values, zero])
    edges = torch.diff(padded)

    start = torch.where(edges == 1)[0]
    end = torch.where(edges == -1)[0]
    assert len(start) == len(end)

    _, order = (end - start).sort(descending=True)
    spans = list(zip(start[order].tolist(), end[order].tolist()))
    return spans[:k]
================================================
FILE: eval/GVHMR/hmr4d/utils/smplx_utils.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
import smplx
import pickle
from smplx import SMPL, SMPLX, SMPLXLayer
from hmr4d.utils.body_model import BodyModelSMPLH, BodyModelSMPLX
from hmr4d.utils.body_model.smplx_lite import SmplxLiteCoco17, SmplxLiteV437Coco17, SmplxLiteSmplN24
from hmr4d import PROJ_ROOT
# fmt: off
# Parent joint index for each of the 52 entries; entry 0 carries -1
# (presumably the root of the SMPL-H kinematic tree -- confirm against the body model).
SMPLH_PARENTS = torch.tensor([-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14,
16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34,
35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50])
# fmt: on
def make_smplx(type="neu_fullpose", **kwargs):
    """Create the body model registered under the name `type`.

    Args:
        type: configuration name; see the branches below for the accepted values.
        **kwargs: forwarded in full for "neu_fullpose", "supermotion",
            "supermotion_EVAL3DPW" and "smpl".  Every other configuration reads
            at most `kwargs["gender"]` and ignores the rest.

    Returns:
        The constructed model object.

    Raises:
        NotImplementedError: if `type` is not one of the known names.
    """
    # Relative to the current working directory; only "supermotion" and "smpl"
    # anchor the same folder at PROJ_ROOT instead.
    body_models_dir = "inputs/checkpoints/body_models"

    if type == "neu_fullpose":
        return smplx.create(
            model_path="inputs/models/smplx/SMPLX_NEUTRAL.npz", use_pca=False, flat_hand_mean=True, **kwargs
        )

    if type in ("supermotion", "supermotion_EVAL3DPW"):
        # SuperMotion is trained on BEDLAM dataset, the smplx config is the same except only 10 betas are used
        for_3dpw_eval = type == "supermotion_EVAL3DPW"
        bm_kwargs = {
            "model_type": "smplx",
            "gender": "neutral",
            "num_pca_comps": 12,
            "flat_hand_mean": for_3dpw_eval,
        }
        bm_kwargs.update(kwargs)
        model_path = body_models_dir if for_3dpw_eval else PROJ_ROOT / "inputs/checkpoints/body_models"
        return BodyModelSMPLX(model_path=model_path, **bm_kwargs)

    if type == "supermotion_coco17":
        # Fast but only predicts 17 joints
        return SmplxLiteCoco17()

    if type == "supermotion_v437coco17":
        # Predicts 437 verts and 17 joints
        return SmplxLiteV437Coco17()

    if type == "supermotion_smpl24":
        return SmplxLiteSmplN24()

    if type == "rich-smplx":
        # https://github.com/paulchhuang/rich_toolkit/blob/main/smplx2images.py
        # A /smplx folder should exist under the model_path
        return BodyModelSMPLX(
            model_path=body_models_dir,
            model_type="smplx",
            gender=kwargs.get("gender", "male"),
            num_pca_comps=12,
            flat_hand_mean=False,
        )

    if type in ("rich-smplh", "smplh"):
        # The two SMPL-H configurations differ only in flat_hand_mean.
        return BodyModelSMPLH(
            model_path=body_models_dir,
            model_type="smplh",
            gender=kwargs.get("gender", "male"),
            use_pca=False,
            flat_hand_mean=(type == "rich-smplh"),
        )

    if type in ("smplx-circle", "smplx-groundlink", "smplx-samp", "smplx-bedlam"):
        # don't use hand; these configurations differ only in the number of betas
        if type == "smplx-samp":
            num_betas = 10
        elif type == "smplx-bedlam":
            num_betas = 11
        else:
            num_betas = 16
        return BodyModelSMPLX(
            model_path=body_models_dir,
            model_type="smplx",
            gender=kwargs.get("gender"),
            num_betas=num_betas,
            num_expression=0,
        )

    if type == "smplx-motionx":
        # No parameter is created inside the model for this configuration.
        layer_args = {
            "create_global_orient": False,
            "create_body_pose": False,
            "create_left_hand_pose": False,
            "create_right_hand_pose": False,
            "create_jaw_pose": False,
            "create_leye_pose": False,
            "create_reye_pose": False,
            "create_betas": False,
            "create_expression": False,
            "create_transl": False,
        }
        return smplx.create(
            model_type="smplx",
            model_path=body_models_dir,
            gender="neutral",
            use_pca=False,
            use_face_contour=True,
            **layer_args,
        )

    if type in ("smplx-layer", "smplx-fit3d"):
        # Use layer
        if type == "smplx-fit3d":
            assert (
                kwargs.get("gender") == "neutral"
            ), "smplx-fit3d use neutral model: https://github.com/sminchisescu-research/imar_vision_datasets_tools/blob/e8c8f83ffac23cc36adf8ec8d0fd1c55679484ef/util/smplx_util.py#L15C34-L15C34"
        return SMPLXLayer(
            model_path="inputs/checkpoints/body_models/smplx",
            gender=kwargs.get("gender"),
            num_betas=10,
            num_expression=10,
        )

    if type == "smpl":
        bm_kwargs = {
            "model_path": PROJ_ROOT / "inputs/checkpoints/body_models",
            "model_type": "smpl",
            "gender": "neutral",
            "num_betas": 10,
            "create_body_pose": False,
            "create_betas": False,
            "create_global_orient": False,
            "create_transl": False,
        }
        bm_kwargs.update(kwargs)
        # model = SMPL(**bm_kwargs)
        return BodyModelSMPLH(**bm_kwargs)

    raise NotImplementedError
def load_parents(npz_path="models/smplx/SMPLX_NEUTRAL.npz"):
    """Read the parent index of every joint from a body-model `.npz` file.

    Args:
        npz_path: path of the archive; it must contain a `kintree_table` array
            whose first row is used as the parent indices.  (The argument was
            previously ignored and the default path was always loaded.)

    Returns:
        1-D int64 numpy array of parent indices with entry 0 set to -1.
    """
    # NOTE(security): allow_pickle=True lets numpy unpickle object arrays, which
    # can execute arbitrary code; only load model files from a trusted source.
    with np.load(npz_path, allow_pickle=True) as smplx_struct:
        # np.int64 instead of np.long: the latter was removed in NumPy 1.24.
        parents = smplx_struct["kintree_table"][0].astype(np.int64)
    parents[0] = -1
    return parents
def load_smpl_faces(npz_path="models/smplh/SMPLH_FEMALE.pkl"):
    """Load the face index array stored under key "f" of a pickled model file.

    Despite its name, `npz_path` is opened as a pickle file, not as an `.npz`
    archive.

    NOTE(security): unpickling can execute arbitrary code; only load model
    files from a trusted source.

    Returns:
        int64 numpy array holding the faces.
    """
    with open(npz_path, "rb") as handle:
        smpl_model = pickle.load(handle, encoding="latin1")
    faces_int64 = smpl_model["f"].astype(np.int64)
    return np.array(faces_int64)
def decompose_fullpose(fullpose, model_type="smplx"):
    """Slice a (..., 165) full-pose tensor into its named components.

    Only `model_type == "smplx"` is supported.  The returned dict maps each
    component name to the matching slice of the last axis.
    """
    assert model_type == "smplx"
    # (name, first column, one-past-last column) along the last axis.
    layout = (
        ("global_orient", 0, 3),
        ("body_pose", 3, 66),
        ("jaw_pose", 66, 69),
        ("leye_pose", 69, 72),
        ("reye_pose", 72, 75),
        ("left_hand_pose", 75, 120),
        ("right_hand_pose", 120, 165),
    )
    return {name: fullpose[..., lo:hi] for name, lo, hi in layout}
def compose_fullpose(fullpose_dict, model_type="smplx"):
    """Concatenate the named pose components along the last axis.

    Only `model_type == "smplx"` is supported.  The components are joined in
    the order global_orient, body_pose, jaw_pose, leye_pose, reye_pose,
    left_hand_pose, right_hand_pose.
    """
    assert model_type == "smplx"
    component_order = (
        "global_orient",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "left_hand_pose",
        "right_hand_pose",
    )
    pieces = [fullpose_dict[name] for name in component_order]
    return torch.cat(pieces, dim=-1)
def compute_R_from_kinetree(rot_mats, parents):
    """Accumulate per-joint rotations down the kinematic tree (3x3 part only).

    Parameters
    ----------
    rot_mats : torch.tensor BxNx3x3
        Rotation matrix of every joint.
    parents : indexable of length N
        ``parents[i]`` is the index of joint ``i``'s parent; entry 0 is not read.

    Returns
    -------
    R : torch.tensor BxNx3x3
        Chained rotations: joint 0 is copied, joint ``i`` is
        ``R[parents[i]] @ rot_mats[:, i]``.
    """
    accumulated = [rot_mats[:, 0]]
    num_joints = parents.shape[0]
    for joint in range(1, num_joints):
        # Parents are assumed to come before their children in the list.
        parent_rot = accumulated[parents[joint]]
        accumulated.append(torch.matmul(parent_rot, rot_mats[:, joint]))
    return torch.stack(accumulated, dim=1)
def compute_relR_from_kinetree(R, parents):
    """Recover parent-relative rotations from chained ones (3x3 part only).

    Parameters
    ----------
    R : torch.tensor BxNx4x4 or BxNx3x3
        Chained transforms; only the upper-left 3x3 block is used.
    parents : index array of length N
        Parent index of every joint.

    Returns
    -------
    rot_mats : torch.tensor BxNx3x3
        ``R_parent^T @ R`` per joint; joint 0 is copied from ``R`` unchanged.
    """
    rot_global = R[:, :, :3, :3]
    # Entry 0 of this gather is meaningless (the root has no parent); it is overwritten below.
    rot_parent = rot_global[:, parents]
    relative = torch.matmul(rot_parent.transpose(2, 3), rot_global)
    relative[:, 0] = rot_global[:, 0]
    return relative
def quat_mul(x, y):
    """
    Multiply two arrays of quaternions (component 0 is the scalar part).
    :param x: tensor of quaternions of shape (..., Nb of joints, 4)
    :param y: tensor of quaternions of shape (..., Nb of joints, 4)
    :return: the product quaternions, same layout as the inputs
    """
    xw, xi, xj, xk = x[..., 0:1], x[..., 1:2], x[..., 2:3], x[..., 3:4]
    yw, yi, yj, yk = y[..., 0:1], y[..., 1:2], y[..., 2:3], y[..., 3:4]

    scalar = yw * xw - yi * xi - yj * xj - yk * xk
    comp_i = yw * xi + yi * xw - yj * xk + yk * xj
    comp_j = yw * xj + yi * xk + yj * xw - yk * xi
    comp_k = yw * xk - yi * xj + yj * xi + yk * xw
    return torch.cat([scalar, comp_i, comp_j, comp_k], dim=-1)
def quat_inv(q):
    """
    Negate the vector part of each quaternion (the conjugate).
    No normalisation is applied, so this equals the inverse only for unit quaternions.
    :param q: quaternion tensor, scalar component first
    :return: tensor of conjugated quaternions
    """
    signs = torch.tensor([1, -1, -1, -1], device=q.device).float()
    return signs * q
def quat_mul_vec(q, x):
    """
    Rotate an array of 3D vectors by an array of quaternions.
    :param q: tensor of quaternions of shape (..., Nb of joints, 4), scalar component first
    :param x: tensor of vectors of shape (..., Nb of joints, 3)
    :return: the resulting array of rotated vectors
    """
    q_vec = q[..., 1:]
    # The cross products must run over the trailing xyz axis. Without an
    # explicit ``dim`` torch.cross uses the first axis of size 3, which is the
    # wrong one whenever a leading (batch/time/joint) dimension happens to be 3.
    t = 2.0 * torch.cross(q_vec, x, dim=-1)
    res = x + q[..., 0][..., None] * t + torch.cross(q_vec, t, dim=-1)
    return res
def inverse_kinematics_motion(
    global_pos,
    global_rot,
    parents=SMPLH_PARENTS,
):
    """Express positions and rotations relative to each joint's parent.

    Args:
        global_pos : (B, T, J-1, 3)
        global_rot (q) : (B, T, J-1, 4)
        parents : SMPLH_PARENTS
    Returns:
        local_pos : (B, T, J-1, 3)
        local_rot (q) : (B, T, J-1, 4)
    """
    J = 22
    parent_idx = parents[1:J]
    # Conjugate of the parent rotation, shared by both outputs.
    parent_rot_inv = quat_inv(global_rot[..., parent_idx, :])
    local_pos = quat_mul_vec(
        parent_rot_inv,
        global_pos - global_pos[..., parent_idx, :],
    )
    # A stray trailing comma used to wrap this tensor in a 1-tuple,
    # contradicting the documented return value.
    local_rot = quat_mul(parent_rot_inv, global_rot)
    return local_pos, local_rot
def transform_mat(R, t):
    """Assemble homogeneous transforms from rotations and translations.

    Args:
        - R: Bx3x3 array of a batch of rotation matrices
        - t: Bx3x1 array of a batch of translation vectors
    Returns:
        - T: Bx4x4 Transformation matrix
    """
    # Append a zero row under R and a one under t, then join them column-wise.
    rot_block = F.pad(R, [0, 0, 0, 1])
    trans_block = F.pad(t, [0, 0, 0, 1], value=1)
    return torch.cat([rot_block, trans_block], dim=2)
def normalize_joints(joints):
    """Root-centre the joints and rotate them into a body-aligned frame.

    The x axis is the normalised xy direction averaged from joints 1->2 and
    16->17, z stays (0, 0, 1) and y is their cross product.

    Args:
        joints: (B, *, J, 3)
    """
    planar = [0, 1]
    hips_xy = joints[..., 2, planar] - joints[..., 1, planar]
    shoulders_xy = joints[..., 17, planar] - joints[..., 16, planar]
    mean_xy = (hips_xy + shoulders_xy) / 2  # (B, *, 2)

    x_dir = F.pad(F.normalize(mean_xy, 2, -1), (0, 1), "constant", 0)  # (B, *, 3)
    z_dir = torch.zeros_like(x_dir)  # (B, *, 3)
    z_dir[..., 2] = 1
    y_dir = torch.cross(z_dir, x_dir, dim=-1)

    basis = torch.stack([x_dir, y_dir, z_dir], dim=-1)  # axes stored as columns
    centered = joints - joints[..., [0], :]
    return centered @ basis
@torch.no_grad()
def compute_Rt_af2az(joints, inverse=False):
    """Build a yaw-only frame from joints 1 -> 2, assuming z points up.

    Args:
        joints: (B, J, 3), in the start-frame
        inverse: when True, return the transposed rotation and the matching
            translation instead.
    Returns:
        R_af2az: (B, 3, 3)
        t_af2az: (B, 3)
    """
    translation = joints[:, 0, :].detach().clone()
    translation[:, 2] = 0  # the height is left untouched

    lr_xy = joints[:, 2, [0, 1]] - joints[:, 1, [0, 1]]  # (B, 2)
    # Facing direction is undefined for (near-)zero vectors: fall back to identity.
    degenerate = lr_xy.pow(2).sum(-1) < 1e-4

    x_dir = F.pad(F.normalize(lr_xy, 2, -1), (0, 1), "constant", 0)  # (B, 3)
    z_dir = torch.zeros_like(x_dir)
    z_dir[..., 2] = 1
    y_dir = torch.cross(z_dir, x_dir, dim=-1)

    rotation = torch.stack([x_dir, y_dir, z_dir], dim=-1)  # (B, 3, 3)
    rotation[degenerate] = torch.eye(3).to(rotation)

    if not inverse:
        return rotation, translation

    rotation_inv = rotation.transpose(1, 2)
    translation_inv = -(rotation_inv @ translation.unsqueeze(2)).squeeze(2)
    return rotation_inv, translation_inv
def finite_difference_forward(x, dim_t=1, dup_last=True):
    """Forward difference ``x[t+1] - x[t]`` along axis 1.

    Args:
        x: tensor whose axis 1 is the one being differenced.
        dim_t: only ``1`` is implemented; anything else raises NotImplementedError.
        dup_last: repeat the final difference so the output keeps the input length.
    """
    if dim_t != 1:
        raise NotImplementedError
    diffs = x[:, 1:] - x[:, :-1]
    if not dup_last:
        return diffs
    return torch.cat([diffs, diffs[:, [-1]]], dim=1)
def compute_joints_zero(betas, gender):
    """Joints of the zero-pose body for the given shape coefficients.

    Args:
        betas: (16)
        gender: 'male' or 'female'
    Returns:
        joints_zero: (22, 3)
    Raises:
        KeyError: if ``gender`` is neither 'male' nor 'female'.
    """
    if gender not in ("male", "female"):
        # Same exception type the former dict lookup produced.
        raise KeyError(gender)
    # Only build the model that is actually evaluated; constructing both
    # genders on every call doubled the loading cost for no benefit.
    body_model = make_smplx(type="humor", gender=gender)
    smpl_params = {
        "root_orient": torch.zeros((1, 3)),
        "pose_body": torch.zeros((1, 63)),
        "betas": betas[None],
        "trans": torch.zeros(1, 3),
    }
    joints_zero = body_model(**smpl_params).Jtr[0, :22]
    return joints_zero
================================================
FILE: eval/GVHMR/hmr4d/utils/video_io_utils.py
================================================
import imageio.v3 as iio
import numpy as np
import torch
from pathlib import Path
import shutil
import ffmpeg
from tqdm import tqdm
import cv2
def get_video_lwh(video_path):
    """Return ``(length, width, height)`` read from the video's properties."""
    props = iio.improps(video_path, plugin="pyav")
    num_frames, height, width, _ = props.shape
    return num_frames, width, height
def read_video_np(video_path, start_frame=0, end_frame=-1, scale=1.0):
    """Decode a video, optionally trimmed and rescaled, into one array.

    Args:
        video_path: str
        start_frame: first frame passed to the ``trim`` filter.
        end_frame: end frame passed to the ``trim`` filter; -1 means "until the end".
        scale: factor applied to both width and height through the ``scale`` filter.
    Returns:
        frames: np.array, (N, H, W, 3) RGB, uint8
    """
    # A missing file is reported by the decoding backend itself.
    filters = []
    expected_len = None

    # 1. Trim (skipped when the whole clip is requested)
    whole_clip = start_frame == 0 and end_frame == -1
    if not whole_clip:
        if end_frame == -1:
            filters.append(("trim", f"start_frame={start_frame}"))
        else:
            expected_len = end_frame - start_frame
            filters.append(("trim", f"start_frame={start_frame}:end_frame={end_frame}"))

    # 2. Scale
    if scale != 1.0:
        filters.append(("scale", f"iw*{scale}:ih*{scale}"))

    # Execute, then verify the frame count when an explicit range was given
    frames = iio.imread(video_path, plugin="pyav", filter_sequence=filters)
    if expected_len is not None:
        assert len(frames) == expected_len
    return frames
def get_video_reader(video_path):
    """Return a frame-by-frame iterator over the video (pyav plugin)."""
    frame_iter = iio.imiter(video_path, plugin="pyav")
    return frame_iter
def read_images_np(image_paths, verbose=False):
    """Load a list of images with OpenCV and stack them.

    Args:
        image_paths: list of str
        verbose: show a tqdm progress bar while reading.
    Returns:
        images: np.array, (N, H, W, 3) RGB, uint8
    """
    path_iter = tqdm(image_paths) if verbose else image_paths
    frames = []
    for img_path in path_iter:
        bgr = cv2.imread(str(img_path))
        frames.append(bgr[..., ::-1])  # reverse the channel axis
    return np.stack(frames, axis=0)
def save_video(images, video_path, fps=30, crf=17):
    """Encode frames to an H.264 video file.

    Args:
        images: (N, H, W, 3) RGB, uint8
        crf: 17 is visually lossless, 23 is default, +6 results in half the bitrate
        0 is lossless, https://trac.ffmpeg.org/wiki/Encode/H.264#crf
    """
    if isinstance(images, torch.Tensor):
        frames = images.cpu().numpy().astype(np.uint8)
    elif isinstance(images, list):
        frames = np.array(images).astype(np.uint8)
    else:
        frames = images  # anything else is handed to the writer unchanged

    with iio.imopen(video_path, "w", plugin="pyav") as writer:
        writer.init_video_stream("libx264", fps=fps)
        # The CRF is set through the writer's private stream handle.
        writer._video_stream.options = {"crf": str(crf)}
        writer.write(frames)
def get_writer(video_path, fps=30, crf=17):
    """Open an H.264 writer for ``video_path``; remember to .close() it."""
    stream_options = {"crf": str(crf)}
    video_writer = iio.imopen(video_path, "w", plugin="pyav")
    video_writer.init_video_stream("libx264", fps=fps)
    video_writer._video_stream.options = stream_options
    return video_writer
def copy_file(video_path, out_video_path, overwrite=True):
    """Copy ``video_path`` to ``out_video_path``.

    With ``overwrite=False`` an already existing destination is left untouched.
    """
    if overwrite or not Path(out_video_path).exists():
        shutil.copy(video_path, out_video_path)
def merge_videos_horizontal(in_video_paths: list, out_video_path: str):
    """Stack the given videos side by side (ffmpeg ``hstack``) into one file.

    Raises:
        ValueError: if fewer than two input paths are given.
    """
    if len(in_video_paths) < 2:
        raise ValueError("At least two video paths are required for merging.")
    streams = []
    for path in in_video_paths:
        streams.append(ffmpeg.input(path))
    stacked = ffmpeg.filter(streams, "hstack", inputs=len(streams))
    job = ffmpeg.output(stacked, out_video_path)
    ffmpeg.run(job, overwrite_output=True, quiet=True)
def merge_videos_vertical(in_video_paths: list, out_video_path: str):
    """Stack the given videos on top of each other (ffmpeg ``vstack``) into one file.

    Raises:
        ValueError: if fewer than two input paths are given.
    """
    if len(in_video_paths) < 2:
        raise ValueError("At least two video paths are required for merging.")
    streams = []
    for path in in_video_paths:
        streams.append(ffmpeg.input(path))
    stacked = ffmpeg.filter(streams, "vstack", inputs=len(streams))
    job = ffmpeg.output(stacked, out_video_path)
    ffmpeg.run(job, overwrite_output=True, quiet=True)
================================================
FILE: eval/GVHMR/hmr4d/utils/vis/README.md
================================================
## PyTorch3D Renderer
Example:
```python
from hmr4d.utils.vis.renderer import Renderer
import imageio
fps = 30
focal_length = data["cam_int"][0][0, 0]
width, height = img_hw
faces = smplh[data["gender"]].bm.faces
renderer = Renderer(width, height, focal_length, "cuda", faces)
writer = imageio.get_writer("tmp_debug.mp4", fps=fps, mode="I", format="FFMPEG", macro_block_size=1)
for i in tqdm(range(length)):
img = np.zeros((height, width, 3), dtype=np.uint8)
img = renderer.render_mesh(smplh_out.vertices[i].cuda(), img)
writer.append_data(img)
writer.close()
```
================================================
FILE: eval/GVHMR/hmr4d/utils/vis/cv2_utils.py
================================================
import torch
import cv2
import numpy as np
from hmr4d.utils.wis3d_utils import get_colors_by_conf
def to_numpy(x):
    """Return an independent numpy copy of ``x``.

    Args:
        x: a numpy array (copied), a list (converted with ``np.array``) or a
            torch tensor.
    Returns:
        numpy array that does not share memory with the input.
    """
    if isinstance(x, np.ndarray):
        return x.copy()
    if isinstance(x, list):
        return np.array(x)
    # detach() first: .numpy() raises on tensors that require grad.
    return x.detach().clone().cpu().numpy()
def draw_bbx_xys_on_image(bbx_xys, image, conf=True):
    """Draw a centre/size box on a copy of ``image``.

    ``bbx_xys[:2]`` is the box centre and ``bbx_xys[2:]`` its extent. The box
    is orange when ``conf == True`` and gray otherwise.
    """
    assert isinstance(bbx_xys, np.ndarray)
    assert isinstance(image, np.ndarray)
    canvas = image.copy()
    center = bbx_xys[:2]
    half_extent = bbx_xys[2:] / 2
    top_left = (center - half_extent).astype(int)
    bottom_right = (center + half_extent).astype(int)
    if conf == True:
        color = (255, 178, 102)  # orange
    else:
        color = (128, 128, 128)  # gray
    canvas = cv2.rectangle(canvas, top_left, bottom_right, color, 2)
    return canvas
def draw_bbx_xys_on_image_batch(bbx_xys_batch, image_batch, conf=None):
    """Draw one centre/size box per image.

    conf: if provided, list of bool selecting the box colour per image.
    """
    boxes = to_numpy(bbx_xys_batch)
    assert len(boxes) == len(image_batch)
    indices = range(len(boxes))
    if conf is None:
        return [draw_bbx_xys_on_image(boxes[i], image_batch[i]) for i in indices]
    return [draw_bbx_xys_on_image(boxes[i], image_batch[i], conf[i]) for i in indices]
def draw_bbx_xyxy_on_image(bbx_xys, image, conf=True):
    """Draw a corner-format box (x1, y1, x2, y2) on a copy of ``image``.

    The box is orange when ``conf == True`` and gray otherwise.
    """
    box = to_numpy(bbx_xys)
    canvas = to_numpy(image)
    if conf == True:
        color = (255, 178, 102)  # orange
    else:
        color = (128, 128, 128)  # gray
    corner_a = (int(box[0]), int(box[1]))
    corner_b = (int(box[2]), int(box[3]))
    canvas = cv2.rectangle(canvas, corner_a, corner_b, color, 2)
    return canvas
def draw_bbx_xyxy_on_image_batch(bbx_xyxy_batch, image_batch, mask=None, conf=None):
    """Draw one corner-format box per image.

    Args:
        conf: if provided, list of bool, mutually exclusive with mask
        mask: whether to draw, historically used
    """
    # At most one of the two selectors may be given.
    assert mask is None or conf is None

    boxes = to_numpy(bbx_xyxy_batch)
    images = to_numpy(image_batch)
    assert len(boxes) == len(images)

    drawn = []
    for i in range(len(boxes)):
        if conf is not None:
            drawn.append(draw_bbx_xyxy_on_image(boxes[i], images[i], conf[i]))
        elif mask is None or mask[i]:
            drawn.append(draw_bbx_xyxy_on_image(boxes[i], images[i]))
        else:
            # Masked-out frame: pass the image through undrawn.
            drawn.append(images[i])
    return drawn
def draw_kpts(frame, keypoints, color=(0, 255, 0), thickness=2):
    """Draw a filled dot for every (x, y) keypoint on a copy of ``frame``.

    ``thickness`` is used as the circle radius.
    """
    canvas = frame.copy()
    for x, y in keypoints:
        center = (int(x), int(y))
        cv2.circle(canvas, center, thickness, color, -1)
    return canvas
def draw_kpts_with_conf(frame, kp2d, conf, thickness=2):
    """Draw keypoints on a copy of ``frame``, coloured by confidence.

    Args:
        kp2d: (J, 2),
        conf: (J,)
    """
    canvas = frame.copy()
    flat_conf = conf.reshape(-1)
    palette = get_colors_by_conf(flat_conf)  # (J, 3)
    # Reverse the channel order before handing the colours to OpenCV.
    palette = palette[:, [2, 1, 0]].int().numpy().tolist()
    num_joints = kp2d.shape[0]
    for j in range(num_joints):
        x, y = kp2d[j, :2]
        cv2.circle(canvas, (int(x), int(y)), thickness, palette[j], -1)
    return canvas
def draw_kpts_with_conf_batch(frames, kp2d_batch, conf_batch, thickness=2):
    """Apply ``draw_kpts_with_conf`` to every frame.

    Args:
        kp2d_batch: (B, J, 2),
        conf_batch: (B, J)
    """
    num_frames = len(frames)
    assert num_frames == len(kp2d_batch)
    assert num_frames == len(conf_batch)
    return [
        draw_kpts_with_conf(frames[i], kp2d_batch[i], conf_batch[i], thickness)
        for i in range(num_frames)
    ]
def draw_coco17_skeleton(img, keypoints, conf_thr=0):
    """Draw the COCO-17 bones on a copy of ``img``.

    When ``keypoints`` has three columns the third is treated as confidence:
    a bone is drawn only if both ends exceed ``conf_thr`` and each confident
    end additionally gets a dot. With two columns every bone is drawn and no
    dots are added.
    """
    has_conf = keypoints.shape[1] == 3
    img = img.copy()
    green = (0, 255, 0)
    # fmt:off
    coco_skel = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
    # fmt:on
    for start, end in coco_skel:
        pt_a = keypoints[start][:2].astype(int)
        pt_b = keypoints[end][:2].astype(int)
        if not has_conf:
            img = cv2.line(img, (pt_a[0], pt_a[1]), (pt_b[0], pt_b[1]), green, 4)
            continue
        conf_a = keypoints[start][2]
        conf_b = keypoints[end][2]
        if conf_a > conf_thr and conf_b > conf_thr:
            img = cv2.line(img, (pt_a[0], pt_a[1]), (pt_b[0], pt_b[1]), green, 4)
        if conf_a > conf_thr:
            img = cv2.circle(img, (pt_a[0], pt_a[1]), 6, green, -1)
        if conf_b > conf_thr:
            img = cv2.circle(img, (pt_b[0], pt_b[1]), 6, green, -1)
    return img
def draw_coco17_skeleton_batch(imgs, keypoints_batch, conf_thr=0):
    """Draw the COCO-17 skeleton on every image of the batch."""
    assert len(imgs) == len(keypoints_batch)
    keypoints_np = to_numpy(keypoints_batch)
    return [
        draw_coco17_skeleton(imgs[i], keypoints_np[i], conf_thr)
        for i in range(len(imgs))
    ]
================================================
FILE: eval/GVHMR/hmr4d/utils/vis/renderer.py
================================================
import cv2
import torch
import numpy as np
from pytorch3d.renderer import (
PerspectiveCameras,
TexturesVertex,
PointLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
)
from pytorch3d.structures import Meshes
from pytorch3d.structures.meshes import join_meshes_as_scene
from pytorch3d.renderer.cameras import look_at_rotation
from pytorch3d.transforms import axis_angle_to_matrix
from .renderer_tools import get_colors, checkerboard_geometry
# Named RGB colours.
# NOTE(review): the two entries use different value ranges -- "gray" is written
# in [0, 1] while "green" is written in [0, 255]. Renderer.render_mesh in this
# file divides a colour list by 255 when its first component exceeds 1, so both
# forms work there; confirm any other consumer handles the mixed scales.
colors_str_map = {
    "gray": [0.8, 0.8, 0.8],
    "green": [39, 194, 128],
}
def overlay_image_onto_background(image, mask, bbox, background):
    """Paste the masked pixels of ``image`` into the ``bbox`` region of a copy of ``background``.

    Args:
        image: crop-sized picture (tensor or array).
        mask: boolean selector with the crop's spatial size (tensor or array).
        bbox: tensor whose first row holds (x0, y0, x1, y1).
        background: numpy image; it is not modified.
    Returns:
        numpy image: the background copy with the masked crop pixels replaced.
    """
    image_np = image.detach().cpu().numpy() if isinstance(image, torch.Tensor) else image
    mask_np = mask.detach().cpu().numpy() if isinstance(mask, torch.Tensor) else mask

    canvas = background.copy()
    box = bbox[0].int().cpu().numpy().copy()
    rows = slice(box[1], box[3])
    cols = slice(box[0], box[2])

    patch = canvas[rows, cols]
    patch[mask_np] = image_np[mask_np]
    canvas[rows, cols] = patch
    return canvas
def update_intrinsics_from_bbox(K_org, bbox):
    """Build 4x4 intrinsics for cameras cropped to the given boxes.

    Args:
        K_org: (N, 3, 3) intrinsics.
        bbox: iterable of (left, upper, right, lower) boxes, one per camera.
    Returns:
        K: (N, 4, 4) matrix with the principal point re-expressed for the crop.
        image_sizes: list of (height, width) int tuples, each at least 1.
    """
    K = torch.zeros((K_org.shape[0], 4, 4)).to(device=K_org.device, dtype=K_org.dtype)
    K[:, :3, :3] = K_org.clone()
    # Swap the last two diagonal entries into the off-diagonal positions.
    K[:, 2, 2] = 0
    K[:, 2, -1] = 1
    K[:, -1, 2] = 1

    image_sizes = []
    for idx, box in enumerate(bbox):
        left, upper, right, lower = box
        crop_h = max(lower - upper, 1)
        crop_w = max(right - left, 1)
        # Shift the principal point into the crop, then mirror it within the crop.
        K[idx, 0, 2] = crop_w - (K[idx, 0, 2] - left)
        K[idx, 1, 2] = crop_h - (K[idx, 1, 2] - upper)
        image_sizes.append((int(crop_h), int(crop_w)))
    return K, image_sizes
def perspective_projection(x3d, K, R=None, T=None):
    """Project 3D points to pixel coordinates with a pinhole camera.

    Args:
        x3d: (B, N, 3) points.
        K: intrinsics that matrix-multiply the depth-normalised points.
        R: optional rotation applied to the points first.
        T: optional translation added after the rotation; it is transposed
            on axes 1 and 2 before the addition.
    Returns:
        (B, N, 2) projected coordinates.
    """
    # Identity checks: comparing a tensor with ``!= None`` only works through
    # the comparison fallback and is not the intended test.
    if R is not None:
        x3d = torch.matmul(R, x3d.transpose(1, 2)).transpose(1, 2)
    if T is not None:
        x3d = x3d + T.transpose(1, 2)
    # Divide by depth, then apply the intrinsics.
    x2d = torch.div(x3d, x3d[..., 2:])
    x2d = torch.matmul(K, x2d.transpose(-1, -2)).transpose(-1, -2)[..., :2]
    return x2d
def compute_bbox_from_points(X, img_w, img_h, scaleFactor=1.2):
left = torch.clamp(X.min(1)[0][:, 0], min=0, max=img_w)
right = torch.clamp(X.max(1)[0][:, 0], min=0, max=img_w)
top = torch.clamp(X.min(1)[0][:, 1], min=0, max=img_h)
bottom = torch.clamp(X.max(1)[0][:, 1], min=0, max=img_h)
cx = (left + right) / 2
cy = (top + bottom) / 2
width = right - left
height = bottom - top
new_left = torch.clamp(cx - width / 2 * scaleFactor, min=0, max=img_w - 1)
new_right = torch.clamp(cx + width / 2 * scaleFactor, min=1, max=img_w)
new_top = torch.clamp(cy - height / 2 * scaleFactor, min=0, max=img_h - 1)
new_bottom = torch.clamp(cy + height / 2 * scaleFactor, min=1, max=img_h)
bbox = torch.stack((new_left.detach(), new_top.detach(), new_right.detach(), new_bottom.detach())).int().float().T
return bbox
class Renderer:
    """PyTorch3D mesh renderer with a single pinhole camera.

    Camera intrinsics come either from a focal length (principal point at the
    image centre) or from an explicit 3x3 ``K``. The renderer can temporarily
    crop its viewport to a bounding box (see ``update_bbox``/``reset_bbox``).
    """

    def __init__(self, width, height, focal_length=None, device="cuda", faces=None, K=None, bin_size=None):
        """set bin_size to 0 for no binning"""
        self.width = width
        self.height = height
        self.bin_size = bin_size

        has_focal = focal_length is not None
        has_K = K is not None
        assert has_focal ^ has_K, "focal_length and K are mutually exclusive"

        self.device = device
        if faces is not None:
            face_tensor = torch.from_numpy((faces).astype("int")) if isinstance(faces, np.ndarray) else faces
            self.faces = face_tensor.unsqueeze(0).to(self.device)

        # Order matters: the renderer reads self.image_sizes and self.lights.
        self.initialize_camera_params(focal_length, K)
        self.lights = PointLights(device=device, location=[[0.0, 0.0, -10.0]])
        self.create_renderer()

    def create_renderer(self):
        """(Re)build the MeshRenderer for the current image size and lights."""
        raster_settings = RasterizationSettings(
            image_size=self.image_sizes[0], blur_radius=1e-5, bin_size=self.bin_size
        )
        rasterizer = MeshRasterizer(raster_settings=raster_settings)
        shader = SoftPhongShader(device=self.device, lights=self.lights)
        self.renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)

    def create_camera(self, R=None, T=None):
        """Return a PerspectiveCameras object, optionally updating the stored extrinsics first."""
        if R is not None:
            self.R = R.clone().view(1, 3, 3).to(self.device)
        if T is not None:
            self.T = T.clone().view(1, 3).to(self.device)
        return PerspectiveCameras(
            device=self.device,
            R=self.R.mT,
            T=self.T,
            K=self.K_full,
            image_size=self.image_sizes,
            in_ndc=False,
        )

    def initialize_camera_params(self, focal_length, K):
        """Set identity extrinsics, the intrinsics and a full-image bounding box."""
        # Extrinsics
        self.R = torch.eye(3).float().to(self.device).unsqueeze(0)
        self.T = torch.zeros(1, 3).float().to(self.device)

        # Intrinsics
        if K is None:
            assert focal_length is not None, "focal_length or K should be provided"
            K = torch.tensor([[focal_length, 0, self.width / 2], [0, focal_length, self.height / 2], [0, 0, 1]])
        self.K = K.float().reshape(1, 3, 3).to(self.device)

        self.bboxes = torch.tensor([[0, 0, self.width, self.height]]).float()
        self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes)
        self.cameras = self.create_camera()

    def set_intrinsic(self, K):
        """Replace the stored 3x3 intrinsics (derived camera state is not refreshed here)."""
        self.K = K.reshape(1, 3, 3)

    def set_ground(self, length, center_x, center_z):
        """Create the checkerboard ground geometry (y-up) and keep it on the device."""
        target = self.device
        geometry = checkerboard_geometry(
            length=float(length), c1=float(center_x), c2=float(center_z), up="y"
        )
        verts, faces, vert_colors, _ = (torch.from_numpy(arr) for arr in geometry)
        self.ground_geometry = [verts.to(target), faces.to(target), vert_colors.to(target)]

    def _apply_bbox(self, bbox):
        """Store ``bbox`` and rebuild intrinsics, cameras and renderer for it."""
        self.bboxes = bbox
        self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
        self.cameras = self.create_camera()
        self.create_renderer()

    def update_bbox(self, x3d, scale=2.0, mask=None):
        """Update bbox of cameras from the given 3d points

        x3d: input 3D keypoints (or vertices), (num_frames, num_points, 3);
        inputs whose last dimension is not 3 are used as 2D points directly.
        """
        if x3d.size(-1) == 3:
            x2d = perspective_projection(x3d.unsqueeze(0), self.K, self.R, self.T.reshape(1, 3, 1))
        else:
            x2d = x3d.unsqueeze(0)

        if mask is not None:
            x2d = x2d[:, ~mask]

        self._apply_bbox(compute_bbox_from_points(x2d, self.width, self.height, scale))

    def reset_bbox(
        self,
    ):
        """Restore the full-image bounding box."""
        full_frame = torch.zeros((1, 4)).float().to(self.device)
        full_frame[0, 2] = self.width
        full_frame[0, 3] = self.height
        self._apply_bbox(full_frame)

    def render_mesh(self, vertices, background=None, colors=[0.8, 0.8, 0.8], VI=50):
        """Render one mesh over ``background`` (white when omitted).

        ``colors`` is either an RGB list (values above 1 are treated as 0-255)
        or a tensor of per-vertex colours. Every ``VI``-th vertex is used to
        estimate the crop box.
        """
        self.update_bbox(vertices[::VI], scale=1.2)
        vertices = vertices.unsqueeze(0)

        if isinstance(colors, torch.Tensor):
            # per-vertex color
            verts_features = colors.to(device=vertices.device, dtype=vertices.dtype)
            colors = [0.8, 0.8, 0.8]
        else:
            if colors[0] > 1:
                colors = [c / 255.0 for c in colors]
            flat_color = torch.tensor(colors).reshape(1, 1, 3).to(device=vertices.device, dtype=vertices.dtype)
            verts_features = flat_color.repeat(1, vertices.shape[1], 1)

        mesh = Meshes(
            verts=vertices,
            faces=self.faces,
            textures=TexturesVertex(verts_features=verts_features),
        )
        materials = Materials(device=self.device, specular_color=(colors,), shininess=0)

        rendered = self.renderer(mesh, materials=materials, cameras=self.cameras, lights=self.lights)
        results = torch.flip(rendered, [1, 2])
        image = results[0, ..., :3] * 255
        mask = results[0, ..., -1] > 1e-3

        if background is None:
            background = np.ones((self.height, self.width, 3)).astype(np.uint8) * 255

        image = overlay_image_onto_background(image, mask, self.bboxes, background.copy())
        self.reset_bbox()
        return image

    def render_with_ground(self, verts, colors, cameras, lights, faces=None):
        """Render people together with the ground set by ``set_ground``.

        :param verts (N, V, 3), potential multiple people
        :param colors (N, 3) or (N, V, 3)
        :param faces (N, F, 3), optional, otherwise self.faces is used
        """
        # Sanity check of input verts, colors and faces: (B, V, 3), (B, F, 3), (B, V, 3)
        N, V, _ = verts.shape
        if faces is None:
            faces = self.faces.clone().expand(N, -1, -1)
        else:
            assert len(faces.shape) == 3, "faces should have shape of (N, F, 3)"

        assert len(colors.shape) in [2, 3]
        if len(colors.shape) == 2:
            assert len(colors) == N, "colors of shape 2 should be (N, 3)"
            colors = colors[:, None]
        colors = colors.expand(N, V, -1)[..., :3]

        # (V, 3), (F, 3), (V, 3)
        ground_verts, ground_faces, ground_colors = self.ground_geometry
        all_verts = list(torch.unbind(verts, dim=0)) + [ground_verts]
        all_faces = list(torch.unbind(faces, dim=0)) + [ground_faces]
        all_colors = list(torch.unbind(colors, dim=0)) + [ground_colors[..., :3]]
        mesh = create_meshes(all_verts, all_faces, all_colors)

        materials = Materials(device=self.device, shininess=0)
        results = self.renderer(mesh, cameras=cameras, lights=lights, materials=materials)
        image = (results[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
        return image
def create_meshes(verts, faces, colors):
    """Build vertex-coloured meshes and merge them into a single scene mesh.

    :param verts (B, V, 3)
    :param faces (B, F, 3)
    :param colors (B, V, 3)
    """
    vertex_textures = TexturesVertex(verts_features=colors)
    batch = Meshes(verts=verts, faces=faces, textures=vertex_textures)
    scene = join_meshes_as_scene(batch)
    return scene
def get_global_cameras(verts, device="cuda", distance=5, position=(-5.0, 5.0, 0.0)):
    """This always put object at the center of view

    For every frame the camera sits ``distance`` away from the mean vertex,
    along the direction running from ``position`` to that mean.

    Returns:
        rotation, translation (one per frame) and a PointLights placed at ``position``.
    """
    num_frames = len(verts)
    seed_positions = torch.tensor([position]).repeat(num_frames, 1)
    targets = verts.mean(1)

    offsets = targets - seed_positions
    offsets = offsets / torch.norm(offsets, dim=-1).unsqueeze(-1) * distance
    positions = targets - offsets

    rotation = look_at_rotation(positions, targets).mT
    translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1)

    lights = PointLights(device=device, location=[position])
    return rotation, translation, lights
def get_global_cameras_static(
    verts, beta=4.0, cam_height_degree=30, target_center_height=1.0, use_long_axis=False, vec_rot=45, device="cuda"
):
    """Place one static camera that looks at the centre of the whole trajectory.

    Args:
        verts: (L, V, 3) vertices; y is treated as the height axis.
        beta: multiplier on the trajectory radius (at least 1.0) giving the
            horizontal camera distance.
        cam_height_degree: elevation angle of the camera, in degrees.
        target_center_height: height of the look-at point.
        use_long_axis: derive the viewing direction from the longest trajectory
            axis (rotated by 45 degrees) instead of ``vec_rot``.
        vec_rot: heading of the viewing direction in degrees when
            ``use_long_axis`` is False.
        device: device used for the returned lights.
    Returns:
        rotation (L, 3, 3), translation (L, 3) and a PointLights at the camera position.
    """
    L, V, _ = verts.shape

    # Compute target trajectory, denote as center + scale
    targets = verts.mean(1)  # (L, 3)
    targets[:, 1] = 0  # project to xz-plane
    target_center = targets.mean(0)  # (3,)
    target_scale, target_idx = torch.norm(targets - target_center, dim=-1).max(0)

    # a 45 degree vec from longest axis
    if use_long_axis:
        long_vec = targets[target_idx] - target_center  # (x, 0, z)
        long_vec = long_vec / torch.norm(long_vec)
        R = axis_angle_to_matrix(torch.tensor([0, np.pi / 4, 0])).to(long_vec)
        vec = R @ long_vec
    else:
        vec_rad = vec_rot / 180 * np.pi
        # Match the device/dtype of the trajectory; a plain CPU tensor here
        # breaks the addition below when ``verts`` lives on another device.
        vec = torch.tensor([np.sin(vec_rad), 0, np.cos(vec_rad)]).to(target_center)
        vec = vec / torch.norm(vec)

    # Compute camera position (center + scale * vec * beta) + y=4
    target_scale = max(target_scale, 1.0) * beta
    position = target_center + vec * target_scale
    position[1] = target_scale * np.tan(np.pi * cam_height_degree / 180) + target_center_height

    # Compute camera rotation and translation
    positions = position.unsqueeze(0).repeat(L, 1)
    target_centers = target_center.unsqueeze(0).repeat(L, 1)
    target_centers[:, 1] = target_center_height
    rotation = look_at_rotation(positions, target_centers).mT
    translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1)

    lights = PointLights(device=device, location=[position.tolist()])
    return rotation, translation, lights
def get_ground_params_from_points(root_points, vert_points):
    """xz-plane is the ground plane

    Args:
        root_points: (L, 3), to decide center
        vert_points: (L, V, 3), to decide scale
    Returns:
        (scale, cx, cz) as Python floats: the larger of the x/z extents of all
        vertices, and the midpoint of the root trajectory's x/z range.
    """
    root_hi = root_points.max(0)[0]  # (3,)
    root_lo = root_points.min(0)[0]  # (3,)
    cx, _, cz = (root_hi + root_lo) / 2.0

    flat_verts = vert_points.reshape(-1, 3)
    extent = flat_verts.max(0)[0] - flat_verts.min(0)[0]  # (3,)
    scale = extent[[0, 2]].max()
    return float(scale), float(cx), float(cz)
================================================
FILE: eval/GVHMR/hmr4d/utils/vis/renderer_tools.py
================================================
import os
import cv2
import numpy as np
import torch
from PIL import Image
def read_image(path, scale=1):
    """Load an image as a numpy array, optionally resized by ``scale``.

    Args:
        path: image file to open.
        scale: factor applied to width and height; 1 returns the image as stored.
    Returns:
        numpy array of the (possibly resized) image.
    """
    # The context manager releases the file handle once the pixels are copied.
    with Image.open(path) as im:
        if scale == 1:
            return np.array(im)
        W, H = im.size
        w, h = int(scale * W), int(scale * H)
        # Image.ANTIALIAS was an alias of LANCZOS and no longer exists in
        # Pillow 10+; LANCZOS is the same filter under its current name.
        return np.array(im.resize((w, h), Image.LANCZOS))
def transform_torch3d(T_c2w):
    """Split a camera-to-world transform into a sign-flipped rotation and translation.

    The rotation is right-multiplied by diag(-1, -1, 1) and the translation is
    left-multiplied by diag(1, -1, -1).

    :param T_c2w (*, 4, 4)
    returns (*, 3, 3), (*, 3)
    """
    dev = T_c2w.device
    flip_xy = torch.diag(torch.tensor([-1.0, -1.0, 1.0], device=dev))
    flip_yz = torch.diag(torch.tensor([1.0, -1.0, -1.0], device=dev))

    rotation = T_c2w[..., :3, :3]
    translation = T_c2w[..., :3, 3]

    rotation = torch.einsum("...ij,jk->...ik", rotation, flip_xy)
    translation = torch.einsum("ij,...j->...i", flip_yz, translation)
    return rotation, translation
def transform_pyrender(T_c2w):
    """Conjugate a camera-to-world transform with diag(1, -1, -1, 1).

    :param T_c2w (*, 4, 4)
    returns (*, 4, 4): ``T_vis @ T_c2w @ T_vis``
    """
    flip = torch.diag(torch.tensor([1.0, -1.0, -1.0, 1.0], device=T_c2w.device))
    left_applied = torch.einsum("ij,...jk->...ik", flip, T_c2w)
    return torch.einsum("...ij,jk->...ik", left_applied, flip)
def smpl_to_geometry(verts, faces, vis_mask=None, track_ids=None):
    """Pick a colour per track and split the sequence into per-frame lists.

    :param verts (B, T, V, 3)
    :param faces (F, 3)
    :param vis_mask (optional) (B, T) visibility of each person
    :param track_ids (optional) (B,)
    returns list of T verts (B, V, 3), faces (F, 3), colors (B, 3)
    where B is different depending on the visibility of the people
    (the result is whatever ``filter_visible_meshes`` returns)
    """
    B, T = verts.shape[:2]
    device = verts.device

    # (B, 3): track colours, or mid-gray when no track ids are given
    if track_ids is not None:
        colors = track_to_colors(track_ids)
    else:
        # ``device`` must be a keyword: as a third positional argument it is
        # parsed as another size and torch.ones raises a TypeError.
        colors = torch.ones(B, 3, device=device) * 0.5

    # list T (B, V, 3), T (B, 3), T (F, 3)
    return filter_visible_meshes(verts, colors, faces, vis_mask)
def filter_visible_meshes(verts, colors, faces, vis_mask=None, vis_opacity=False):
    """Split a batch of tracks into per-frame lists, dropping people that are out of frame.

    :param verts (B, T, V, 3)
    :param colors (B, 3)
    :param faces (F, 3)
    :param vis_mask (optional tensor, default None) (B, T) ternary mask
        -1 if not in frame
         0 if temporarily occluded
         1 if visible
    :param vis_opacity (optional bool, default False)
        if True, make occluded people alpha=0.5, otherwise alpha=1
    returns a list of T lists verts (Bi, V, 3), colors (Bi, 4), faces (F, 3)
    and, only when ``vis_mask`` is given, the result of ``get_bboxes`` as a
    fourth element. Without ``vis_mask`` the colours stay (B, 3).
    """
    T = verts.shape[1]
    faces = [faces for _ in range(T)]
    if vis_mask is None:
        verts = [verts[:, t] for t in range(T)]
        colors = [colors for _ in range(T)]
        return verts, colors, faces

    # render occluded and visible, but not removed
    keep = vis_mask >= 0
    if vis_opacity:
        # Alpha must come from the ternary values (0 -> 0.5, 1 -> 1.0). Deriving
        # it from the boolean mask made every kept person fully opaque.
        alpha = (0.5 * (vis_mask[..., None] + 1)).to(dtype=colors.dtype)
    else:
        alpha = keep[..., None].float()

    vert_list = [verts[keep[:, t], t] for t in range(T)]
    colors = [torch.cat([colors[keep[:, t]], alpha[keep[:, t], t]], dim=-1) for t in range(T)]
    bounds = get_bboxes(verts, keep)
    return vert_list, colors, faces, bounds
def get_bboxes(verts, vis_mask):
    """Bounding box over all tracks plus the mean position of one chosen track.

    The chosen track is the one with the smallest mean z among tracks kept in
    at least 80% of the frames; if no track qualifies, track 0 is used.

    return bb_min, bb_max, and mean for each track (B, 3) over entire trajectory
    :param verts (B, T, V, 3)
    :param vis_mask (B, T)
    """
    B, T, *_ = verts.shape
    bb_min, bb_max, mean = [], [], []
    for b in range(B):
        v = verts[b, vis_mask[b, :T]]  # (Tb, V, 3)
        bb_min.append(v.amin(dim=(0, 1)))
        bb_max.append(v.amax(dim=(0, 1)))
        mean.append(v.mean(dim=(0, 1)))
    bb_min = torch.stack(bb_min, dim=0)
    bb_max = torch.stack(bb_max, dim=0)
    mean = torch.stack(mean, dim=0)

    # point to a track that's long and close to the camera
    # Work on a copy: writing inf through a view of ``mean`` corrupted the
    # returned mean whenever the selected track was itself a short one.
    zs = mean[:, 2].clone()
    counts = vis_mask[:, :T].sum(dim=-1)  # (B,)
    short_track = counts < 0.8 * T
    zs[short_track] = float("inf")
    sel = torch.argmin(zs)
    return bb_min.amin(dim=0), bb_max.amax(dim=0), mean[sel]
def track_to_colors(track_ids):
    """Look up one colour per track id and scale it by 1/255.

    :param track_ids (B)
    :return (B, 3) colours
    """
    # The palette is converted to the dtype and device of ``track_ids`` before indexing.
    palette = torch.from_numpy(get_colors()).to(track_ids)
    picked = palette[track_ids]
    return picked / 255  # (B, 3)
def get_colors():
    """Return the colour table as a float32 array of RGB rows.

    The table is the content of ``colors.txt`` (looked up next to this module),
    followed by 10000 random colours and one final row; every entry equal to 0
    is replaced by 1.
    """
    color_file = os.path.abspath(os.path.join(__file__, "../colors.txt"))
    file_colors = np.loadtxt(color_file, skiprows=0)
    random_colors = np.random.uniform(0, 255, size=(10000, 3))
    palette = np.vstack([file_colors, random_colors, [[0, 0, 0]]])
    palette[palette == 0] = 1
    return palette.astype(np.float32)
def checkerboard_geometry(
    length=12.0,
    color0=[0.8, 0.9, 0.9],
    color1=[0.6, 0.7, 0.7],
    tile_width=0.5,
    alpha=1.0,
    up="y",
    c1=0.0,
    c2=0.0,
):
    """Build a two-colour checkerboard made of square tiles.

    Every tile contributes four vertices and four triangles (both windings of
    its two halves). ``c1``/``c2`` shift the board along its two in-plane axes;
    ``up`` ("y" or "z") selects which axis stays at zero.

    Returns:
        vertices (4n, 3), faces (4n, 3), vert_colors (4n, 4), face_colors (4n, 4),
        all float32, where n is the number of tiles.
    """
    assert up == "y" or up == "z"
    color0 = np.array(color0 + [alpha])
    color1 = np.array(color1 + [alpha])

    num_rows = num_cols = max(2, int(length / tile_width))
    radius = float(num_rows * tile_width) / 2.0
    # Second in-plane coordinate lands in column 2 for y-up, column 1 for z-up.
    second_axis = 2 if up == "y" else 1

    vertices, vert_colors, faces, face_colors = [], [], [], []
    for tile_idx in range(num_rows * num_cols):
        i, j = divmod(tile_idx, num_cols)
        u0 = j * tile_width - radius
        v0 = i * tile_width - radius
        us = np.array([u0, u0, u0 + tile_width, u0 + tile_width])
        vs = np.array([v0, v0 + tile_width, v0 + tile_width, v0])
        zs = np.zeros(4)
        if up == "y":
            quad = np.stack([us, zs, vs], axis=-1)  # (4, 3)
        else:
            quad = np.stack([us, vs, zs], axis=-1)  # (4, 3)
        quad[:, 0] += c1
        quad[:, second_axis] += c2

        tile_faces = np.array([[0, 1, 3], [1, 2, 3], [0, 3, 1], [1, 3, 2]], dtype=np.int64)
        tile_faces += 4 * tile_idx  # the number of previously added verts

        tile_color = color0 if i % 2 == j % 2 else color1
        tile_colors = np.array([tile_color] * 4)

        vertices.append(quad)
        faces.append(tile_faces)
        vert_colors.append(tile_colors)
        face_colors.append(tile_colors)

    vertices = np.concatenate(vertices, axis=0).astype(np.float32)
    vert_colors = np.concatenate(vert_colors, axis=0).astype(np.float32)
    faces = np.concatenate(faces, axis=0).astype(np.float32)
    face_colors = np.concatenate(face_colors, axis=0).astype(np.float32)
    return vertices, faces, vert_colors, face_colors
def camera_marker_geometry(radius, height, up):
    """Build a square-based pyramid used as a camera marker.

    Args:
        radius: half side length of the square base.
        height: distance of the apex from the base plane.
        up: "y" or "z"; selects which vertex layout is produced.
    Returns:
        vertices (5, 3), faces (6, 3) and per-face RGBA colours (6, 4).
    """
    assert up == "y" or up == "z"
    base = [(-radius, -radius), (radius, -radius), (radius, radius), (-radius, radius)]
    if up == "y":
        vertices = np.array([[a, b, 0] for a, b in base] + [[0, 0, height]])
    else:
        vertices = np.array([[a, 0, b] for a, b in base] + [[0, -height, 0]])

    # Two triangles for the base, four for the sides (vertex 4 is the apex).
    faces = np.array(
        [
            [0, 3, 1],
            [1, 3, 2],
            [0, 1, 4],
            [1, 2, 4],
            [2, 3, 4],
            [3, 0, 4],
        ]
    )

    white = [1.0, 1.0, 1.0, 1.0]
    green = [0.0, 1.0, 0.0, 1.0]
    red = [1.0, 0.0, 0.0, 1.0]
    face_colors = np.array([white, white, green, red, green, red])
    return vertices, faces, face_colors
def vis_keypoints(
    keypts_list,
    img_size,
    radius=6,
    thickness=3,
    kpt_score_thr=0.3,
    dataset="TopDownCocoDataset",
):
    """
    Visualize keypoints
    From ViTPose/mmpose/apis/inference.py

    Args:
        keypts_list: poses to draw; forwarded unchanged to `imshow_keypoints`.
        img_size: (width, height) of the canvas to draw on.
        radius: keypoint circle radius, forwarded to `imshow_keypoints`.
        thickness: link line thickness, forwarded to `imshow_keypoints`.
        kpt_score_thr: minimum keypoint score to draw (forced to 0 for the face datasets).
        dataset: dataset name that selects the skeleton and the colour layout.

    Returns:
        (img_h, img_w, 4) uint8 RGBA image. Alpha is 255 wherever any colour
        channel differs from the white background and 0 elsewhere.

    Raises:
        NotImplementedError: if `dataset` is not one of the supported names.
    """
    palette = np.array(
        [
            [255, 128, 0],
            [255, 153, 51],
            [255, 178, 102],
            [230, 230, 0],
            [255, 153, 255],
            [153, 204, 255],
            [255, 102, 255],
            [255, 51, 255],
            [102, 178, 255],
            [51, 153, 255],
            [255, 153, 153],
            [255, 102, 102],
            [255, 51, 51],
            [153, 255, 153],
            [102, 255, 102],
            [51, 255, 51],
            [0, 255, 0],
            [0, 0, 255],
            [255, 0, 0],
            [255, 255, 255],
        ]
    )

    # 19 body links shared by the COCO-style and COCO-WholeBody layouts.
    coco_body_links = [
        [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12],
        [5, 6], [5, 7], [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2],
        [1, 3], [2, 4], [3, 5], [4, 6],
    ]

    if dataset in (
        "TopDownCocoDataset",
        "BottomUpCocoDataset",
        "TopDownOCHumanDataset",
        "AnimalMacaqueDataset",
    ):
        skeleton = list(coco_body_links)
        pose_link_color = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]]
        pose_kpt_color = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]]
    elif dataset == "TopDownCocoWholeBodyDataset":
        skeleton = (
            list(coco_body_links)
            # feet
            + [[15, 17], [15, 18], [15, 19], [16, 20], [16, 21], [16, 22]]
            # first hand (keypoints 91..111)
            + [
                [91, 92], [92, 93], [93, 94], [94, 95],
                [91, 96], [96, 97], [97, 98], [98, 99],
                [91, 100], [100, 101], [101, 102], [102, 103],
                [91, 104], [104, 105], [105, 106], [106, 107],
                [91, 108], [108, 109], [109, 110], [110, 111],
            ]
            # second hand (keypoints 112..132)
            + [
                [112, 113], [113, 114], [114, 115], [115, 116],
                [112, 117], [117, 118], [118, 119], [119, 120],
                [112, 121], [121, 122], [122, 123], [123, 124],
                [112, 125], [125, 126], [126, 127], [127, 128],
                [112, 129], [129, 130], [130, 131], [131, 132],
            ]
        )
        pose_link_color = palette[
            [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]
            + [16, 16, 16, 16, 16, 16]
            + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
            + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
        ]
        pose_kpt_color = palette[
            [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0] + [19] * (68 + 42)
        ]
    elif dataset == "TopDownAicDataset":
        skeleton = [
            [2, 1], [1, 0], [0, 13], [13, 3], [3, 4], [4, 5], [8, 7],
            [7, 6], [6, 9], [9, 10], [10, 11], [12, 13], [0, 6], [3, 9],
        ]
        pose_link_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 0, 7, 7]]
        pose_kpt_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 0, 0]]
    elif dataset == "TopDownMpiiDataset":
        skeleton = [
            [0, 1], [1, 2], [2, 6], [6, 3], [3, 4], [4, 5], [6, 7], [7, 8],
            [8, 9], [8, 12], [12, 11], [11, 10], [8, 13], [13, 14], [14, 15],
        ]
        pose_link_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 9, 9, 9, 9, 9, 9]]
        pose_kpt_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 0, 9, 9, 9, 9, 9, 9]]
    elif dataset == "TopDownMpiiTrbDataset":
        skeleton = [
            [12, 13], [13, 0], [13, 1], [0, 2], [1, 3], [2, 4], [3, 5],
            [0, 6], [1, 7], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11],
            [14, 15], [16, 17], [18, 19], [20, 21], [22, 23], [24, 25], [26, 27],
            [28, 29], [30, 31], [32, 33], [34, 35], [36, 37], [38, 39],
        ]
        pose_link_color = palette[[16] * 14 + [19] * 13]
        pose_kpt_color = palette[[16] * 14 + [0] * 26]
    elif dataset in ("OneHand10KDataset", "FreiHandDataset", "PanopticDataset"):
        skeleton = [
            [0, 1], [1, 2], [2, 3], [3, 4],
            [0, 5], [5, 6], [6, 7], [7, 8],
            [0, 9], [9, 10], [10, 11], [11, 12],
            [0, 13], [13, 14], [14, 15], [15, 16],
            [0, 17], [17, 18], [18, 19], [19, 20],
        ]
        pose_link_color = palette[[0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]]
        pose_kpt_color = palette[[0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]]
    elif dataset == "InterHand2DDataset":
        skeleton = [
            [0, 1], [1, 2], [2, 3],
            [4, 5], [5, 6], [6, 7],
            [8, 9], [9, 10], [10, 11],
            [12, 13], [13, 14], [14, 15],
            [16, 17], [17, 18], [18, 19],
            [3, 20], [7, 20], [11, 20], [15, 20], [19, 20],
        ]
        pose_link_color = palette[[0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12, 16]]
        pose_kpt_color = palette[[0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 0]]
    elif dataset == "Face300WDataset":
        # face datasets: no links, every keypoint drawn regardless of score
        skeleton = []
        pose_link_color = palette[[]]
        pose_kpt_color = palette[[19] * 68]
        kpt_score_thr = 0
    elif dataset == "FaceAFLWDataset":
        skeleton = []
        pose_link_color = palette[[]]
        pose_kpt_color = palette[[19] * 19]
        kpt_score_thr = 0
    elif dataset == "FaceCOFWDataset":
        skeleton = []
        pose_link_color = palette[[]]
        pose_kpt_color = palette[[19] * 29]
        kpt_score_thr = 0
    elif dataset == "FaceWFLWDataset":
        skeleton = []
        pose_link_color = palette[[]]
        pose_kpt_color = palette[[19] * 98]
        kpt_score_thr = 0
    elif dataset == "AnimalHorse10Dataset":
        skeleton = [
            [0, 1], [1, 12], [12, 16], [16, 21], [21, 17], [17, 11], [11, 10],
            [10, 8], [8, 9], [9, 12], [2, 3], [3, 4], [5, 6], [6, 7],
            [13, 14], [14, 15], [18, 19], [19, 20],
        ]
        pose_link_color = palette[[4] * 10 + [6] * 2 + [6] * 2 + [7] * 2 + [7] * 2]
        pose_kpt_color = palette[[4, 4, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 4, 7, 7, 7, 4, 4, 7, 7, 7, 4]]
    elif dataset == "AnimalFlyDataset":
        skeleton = [
            [1, 0], [2, 0], [3, 0], [4, 3], [5, 4], [7, 6], [8, 7], [9, 8],
            [11, 10], [12, 11], [13, 12], [15, 14], [16, 15], [17, 16],
            [19, 18], [20, 19], [21, 20], [23, 22], [24, 23], [25, 24],
            [27, 26], [28, 27], [29, 28], [30, 3], [31, 3],
        ]
        pose_link_color = palette[[0] * 25]
        pose_kpt_color = palette[[0] * 32]
    elif dataset == "AnimalLocustDataset":
        skeleton = [
            [1, 0], [2, 1], [3, 2], [4, 3], [6, 5], [7, 6], [9, 8], [10, 9],
            [11, 10], [13, 12], [14, 13], [15, 14], [17, 16], [18, 17], [19, 18],
            [21, 20], [22, 21], [24, 23], [25, 24], [26, 25], [28, 27], [29, 28],
            [30, 29], [32, 31], [33, 32], [34, 33],
        ]
        pose_link_color = palette[[0] * 26]
        pose_kpt_color = palette[[0] * 35]
    elif dataset == "AnimalZebraDataset":
        skeleton = [[1, 0], [2, 1], [3, 2], [4, 2], [5, 7], [6, 7], [7, 2], [8, 7]]
        pose_link_color = palette[[0] * 8]
        pose_kpt_color = palette[[0] * 9]
    elif dataset == "AnimalPoseDataset":
        # NOTE: this used to be `dataset in "AnimalPoseDataset"`, a substring test that
        # also accepted "", "Pose", "Dataset", ... and hid typos in the dataset name.
        skeleton = [
            [0, 1], [0, 2], [1, 3], [0, 4], [1, 4], [4, 5], [5, 7], [6, 7],
            [5, 8], [8, 12], [12, 16], [5, 9], [9, 13], [13, 17],
            [6, 10], [10, 14], [14, 18], [6, 11], [11, 15], [15, 19],
        ]
        pose_link_color = palette[[0] * 20]
        pose_kpt_color = palette[[0] * 20]
    else:
        # The exception was previously constructed but never raised, so unknown
        # datasets crashed later with an unrelated unbound-variable error.
        raise NotImplementedError(f"Unsupported dataset: {dataset!r}")

    img_w, img_h = img_size
    img = 255 * np.ones((img_h, img_w, 3), dtype=np.uint8)  # white canvas
    img = imshow_keypoints(
        img,
        keypts_list,
        skeleton,
        kpt_score_thr,
        pose_kpt_color,
        pose_link_color,
        radius,
        thickness,
    )
    # Anything that is no longer pure white is treated as drawn content.
    alpha = 255 * (img != 255).any(axis=-1, keepdims=True).astype(np.uint8)
    return np.concatenate([img, alpha], axis=-1)
def imshow_keypoints(
    img,
    pose_result,
    skeleton=None,
    kpt_score_thr=0.3,
    pose_kpt_color=None,
    pose_link_color=None,
    radius=4,
    thickness=1,
    show_keypoint_weight=False,
):
    """Draw keypoints and links on an image.
    From ViTPose/mmpose/core/visualization/image.py

    Args:
        img (H, W, 3) array: drawn on in place and also returned.
        pose_result (list[kpts]): The poses to draw. Each element kpts is
            a set of K keypoints as an Kx3 numpy.ndarray (or nested sequence),
            where each keypoint is represented as x, y, score. Rows are
            re-ordered with a fixed 17-entry index list before drawing, so
            each pose must have at least 19 rows.
        skeleton (list[[int, int]], optional): pairs of (re-ordered) keypoint
            indices to connect. If None, links are not drawn.
        kpt_score_thr (float, optional): Minimum score of keypoints
            to be shown. Default: 0.3.
        pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
            the keypoint will not be drawn.
        pose_link_color (np.array[Mx3]): Color of M links. If None, the
            links will not be drawn.
        radius (int): Radius of the keypoint circles.
        thickness (int): Thickness of lines.
        show_keypoint_weight (bool): If True, opacity indicates keypoint score

    Returns:
        The same `img` array with the poses drawn on it.
    """
    img_h, img_w, _ = img.shape
    # Fixed re-ordering of the incoming keypoint rows into the 17-keypoint layout
    # expected by the skeleton / colour tables.
    idcs = [0, 16, 15, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11]
    for kpts in pose_result:
        # np.asarray avoids a copy when possible; `np.array(..., copy=False)` raises
        # on NumPy >= 2.0 whenever a copy is unavoidable (e.g. list input).
        kpts = np.asarray(kpts)[idcs]

        # draw each point on image
        if pose_kpt_color is not None:
            assert len(pose_kpt_color) == len(kpts)
            for kid, kpt in enumerate(kpts):
                x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
                if kpt_score > kpt_score_thr:
                    color = tuple(int(c) for c in pose_kpt_color[kid])
                    if show_keypoint_weight:
                        # blend a copy with the circle back into img, weighted by score
                        img_copy = img.copy()
                        cv2.circle(img_copy, (int(x_coord), int(y_coord)), radius, color, -1)
                        transparency = max(0, min(1, kpt_score))
                        cv2.addWeighted(img_copy, transparency, img, 1 - transparency, 0, dst=img)
                    else:
                        cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)

        # draw links
        if skeleton is not None and pose_link_color is not None:
            assert len(pose_link_color) == len(skeleton)
            for sk_id, sk in enumerate(skeleton):
                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
                # both end points must be strictly inside the image and confident enough
                if (
                    pos1[0] > 0
                    and pos1[0] < img_w
                    and pos1[1] > 0
                    and pos1[1] < img_h
                    and pos2[0] > 0
                    and pos2[0] < img_w
                    and pos2[1] > 0
                    and pos2[1] < img_h
                    and kpts[sk[0], 2] > kpt_score_thr
                    and kpts[sk[1], 2] > kpt_score_thr
                ):
                    color = tuple(int(c) for c in pose_link_color[sk_id])
                    if show_keypoint_weight:
                        img_copy = img.copy()
                        X = (pos1[0], pos2[0])
                        Y = (pos1[1], pos2[1])
                        mX = np.mean(X)
                        mY = np.mean(Y)
                        length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
                        angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
                        stickwidth = 2
                        # thin filled ellipse spanning the two end points
                        polygon = cv2.ellipse2Poly(
                            (int(mX), int(mY)),
                            (int(length / 2), int(stickwidth)),
                            int(angle),
                            0,
                            360,
                            1,
                        )
                        cv2.fillConvexPoly(img_copy, polygon, color)
                        transparency = max(0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2])))
                        cv2.addWeighted(img_copy, transparency, img, 1 - transparency, 0, dst=img)
                    else:
                        cv2.line(img, pos1, pos2, color, thickness=thickness)
    return img
================================================
FILE: eval/GVHMR/hmr4d/utils/vis/renderer_utils.py
================================================
from hmr4d.utils.vis.renderer import Renderer
from tqdm import tqdm
import numpy as np
def simple_render_mesh(render_dict):
    """Render a sequence of meshes without passing any background image.

    `render_dict` must provide:
        "whf":   (width, height, focal_length) handed to the Renderer,
        "faces": mesh faces handed to the Renderer,
        "verts": indexable sequence of per-frame vertices (each moved to CUDA).

    Returns the rendered frames stacked along a new leading axis.
    """
    width, height, focal_length = render_dict["whf"]
    mesh_faces = render_dict["faces"]
    frame_verts = render_dict["verts"]

    renderer = Renderer(width, height, focal_length, device="cuda", faces=mesh_faces)
    frames = [
        renderer.render_mesh(frame_verts[idx].cuda(), colors=[0.8, 0.8, 0.8])
        for idx in tqdm(range(len(frame_verts)), desc="Rendering")
    ]
    return np.stack(frames, axis=0)
def simple_render_mesh_background(render_dict, VI=50, colors=[0.8, 0.8, 0.8]):
    """Render a sequence of meshes on top of background image(s).

    `render_dict` must provide:
        "K":          camera intrinsics handed to the Renderer,
        "faces":      mesh faces handed to the Renderer,
        "verts":      indexable sequence of per-frame vertices (each moved to CUDA),
        "background": either one 3-dimensional image (reused for every frame)
                      or an indexable sequence with one image per frame.

    `VI` and `colors` are forwarded unchanged to `Renderer.render_mesh`.
    Returns the rendered frames stacked along a new leading axis.
    """
    intrinsics = render_dict["K"]
    mesh_faces = render_dict["faces"]
    frame_verts = render_dict["verts"]
    backgrounds = render_dict["background"]

    n_frames = len(frame_verts)
    single_image = len(backgrounds.shape) == 3
    if single_image:
        # same image behind every frame
        backgrounds = [backgrounds] * n_frames
    height, width = backgrounds[0].shape[:2]

    renderer = Renderer(width, height, device="cuda", faces=mesh_faces, K=intrinsics)
    frames = []
    for idx in tqdm(range(len(frame_verts)), desc="Rendering"):
        frame = renderer.render_mesh(frame_verts[idx].cuda(), colors=colors, background=backgrounds[idx], VI=VI)
        frames.append(frame)
    return np.stack(frames, axis=0)
================================================
FILE: eval/GVHMR/hmr4d/utils/vis/rich_logger.py
================================================
from pytorch_lightning.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf
import rich
import rich.tree
import rich.syntax
from hmr4d.utils.pylogger import Log
@rank_zero_only
def print_cfg(cfg: DictConfig, use_rich: bool = False):
    """Print the config, either as plain YAML through `Log.info` or as a rich tree.

    With `use_rich=True` the top-level fields are shown as tree branches: first the
    preferred fields (in a fixed order, warning about any that are missing), then
    every remaining field in the order the config yields them.
    """
    if not use_rich:
        Log.info(OmegaConf.to_yaml(cfg, resolve=False))
        return

    preferred = ("data", "model", "callbacks", "logger", "pl_trainer")
    dim = "dim"
    tree = rich.tree.Tree("CONFIG", style=dim, guide_style=dim)

    # Preferred fields first ...
    ordered_fields = []
    for key in preferred:
        if key in cfg:
            ordered_fields.append(key)
        else:
            Log.warn(f"Field '{key}' not found in config. Skipping.")
    # ... then whatever else the config contains.
    for key in cfg:
        if key not in ordered_fields:
            ordered_fields.append(key)

    # One branch per field, rendered with YAML syntax highlighting.
    for key in ordered_fields:
        branch = tree.add(key, style=dim, guide_style=dim)
        group = cfg[key]
        text = OmegaConf.to_yaml(group, resolve=False) if isinstance(group, DictConfig) else str(group)
        branch.add(rich.syntax.Syntax(text, "yaml"))

    rich.print(tree)
================================================
FILE: eval/GVHMR/hmr4d/utils/wis3d_utils.py
================================================
from wis3d import Wis3D
from pathlib import Path
from datetime import datetime
import torch
import numpy as np
from einops import einsum
from pytorch3d.transforms import axis_angle_to_matrix
def make_wis3d(output_dir="outputs/wis3d", name="debug", time_postfix=False):
    """
    Create the output directory (if needed) and return a Wis3D instance in it. e.g.:
        from hmr4d.utils.wis3d_utils import make_wis3d
        wis3d = make_wis3d(time_postfix=True)

    With `time_postfix=True` the current time ("%m%d-%H%M-%S") is appended to `name`.
    """
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    if time_postfix:
        stamp = datetime.now().strftime("%m%d-%H%M-%S")
        name = f"{name}_{stamp}"

    print(f"Creating Wis3D {name}")
    return Wis3D(out_path.absolute(), name)
# Named colour pairs, each a (first RGB, second RGB) tuple of 0-255 lists.
# get_gradient_colors interpolates from the first to the second entry;
# get_const_colors and get_colors_by_conf only use the second entry.
color_schemes = dict(
    red=([255, 168, 154], [153, 17, 1]),
    green=([183, 255, 191], [0, 171, 8]),
    blue=([183, 255, 255], [0, 0, 255]),
    cyan=([183, 255, 255], [0, 255, 255]),
    magenta=([255, 183, 255], [255, 0, 255]),
    black=([0, 0, 0], [0, 0, 0]),
    orange=([255, 183, 0], [255, 128, 0]),
    grey=([203, 203, 203], [203, 203, 203]),
)
def get_gradient_colors(scheme="red", num_points=120, alpha=1.0):
    """
    Return a (num_points, 4) RGBA tensor in [0, 1] that fades linearly from the
    first to the second colour of `color_schemes[scheme]`; alpha is constant.
    """
    alpha_channel = [255 * alpha]
    rgba_from = torch.tensor(color_schemes[scheme][0] + alpha_channel) / 255
    rgba_to = torch.tensor(color_schemes[scheme][1] + alpha_channel) / 255

    # one linspace per channel, stacked as the last axis
    channels = []
    for lo, hi in zip(rgba_from, rgba_to):
        channels.append(torch.linspace(lo, hi, steps=num_points))
    return torch.stack(channels, dim=-1)
def get_const_colors(name="red", partial_shape=(120, 5), alpha=1.0):
    """
    Return colors (partial_shape, 4): the second colour of `color_schemes[name]`
    plus `alpha`, scaled to [0, 1] and tiled over `partial_shape`.
    """
    dims = tuple(partial_shape)
    rgba = torch.tensor(color_schemes[name][1] + [255 * alpha]) / 255
    return rgba.unsqueeze(0).repeat(*dims, 1)
def get_colors_by_conf(conf, low="red", high="green"):
    """Blend per-element between two scheme colours using `conf` as the weight.

    The weight is replicated over a new trailing axis of size 3; a weight of 1 yields
    the second colour of `color_schemes[high]`, a weight of 0 the second colour of
    `color_schemes[low]`. Colours stay in the 0-255 range of `color_schemes`.
    """
    weight = torch.stack((conf, conf, conf), dim=-1)
    high_rgb = torch.tensor(color_schemes[high][1])
    low_rgb = torch.tensor(color_schemes[low][1])
    return weight * high_rgb + (1 - weight) * low_rgb
# ================== Colored Motion Sequence ================== #
# Joint-index chains per skeleton layout; consecutive joints in a chain form one bone.
# Chain order is: right-leg, left-leg, spine, right-arm, left-arm.
KINEMATIC_CHAINS = dict(
    smpl22=[
        [0, 2, 5, 8, 11],  # right-leg
        [0, 1, 4, 7, 10],  # left-leg
        [0, 3, 6, 9, 12, 15],  # spine
        [9, 14, 17, 19, 21],  # right-arm
        [9, 13, 16, 18, 20],  # left-arm
    ],
    h36m17=[
        [0, 1, 2, 3],  # right-leg
        [0, 4, 5, 6],  # left-leg
        [0, 7, 8, 9, 10],  # spine
        [8, 14, 15, 16],  # right-arm
        [8, 11, 12, 13],  # left-arm
    ],
    coco17=[
        [12, 14, 16],  # right-leg
        [11, 13, 15],  # left-leg
        [4, 2, 0, 1, 3],  # replace spine with head
        [6, 8, 10],  # right-arm
        [5, 7, 9],  # left-arm
    ],
)
def convert_motion_as_line_mesh(motion, skeleton_type="smpl22", const_color=None):
    """Turn a joint motion into per-frame skeleton meshes (one thin prism per bone).

    Args:
        motion: array or tensor indexed as [frame, joint, xyz]; numpy input is converted
            to a tensor and everything is moved to the CPU.
        skeleton_type: key into KINEMATIC_CHAINS.
        const_color: optional scheme name used for every chain instead of the
            default per-chain colours.

    Returns:
        vertices stacked over frames, plus the faces and vertex colours produced by
        `create_skeleton_mesh` for the last frame.
    """
    if isinstance(motion, np.ndarray):
        motion = torch.from_numpy(motion)
    motion = motion.detach().cpu()

    chains = KINEMATIC_CHAINS[skeleton_type]
    default_names = ["red", "green", "blue", "cyan", "magenta"]
    n_frames = motion.shape[0]
    device = motion.device

    bone_starts, bone_ends, bone_colors = [], [], []
    for chain, default_name in zip(chains, default_names):
        n_bones = len(chain) - 1
        bone_starts.append(motion[:, chain[:-1]])
        bone_ends.append(motion[:, chain[1:]])
        chosen = default_name if const_color is None else const_color
        rgba = get_const_colors(chosen, partial_shape=(n_frames, n_bones), alpha=1.0).to(device)  # (L, n_bones, 4)
        bone_colors.append(rgba[..., :3] * 255)  # (L, n_bones, 3)

    # concatenate all chains along the bone axis
    bone_starts = torch.cat(bone_starts, dim=1)  # (L, ?, 3)
    bone_ends = torch.cat(bone_ends, dim=1)
    bone_colors = torch.cat(bone_colors, dim=1)

    per_frame_vertices = []
    for f in range(n_frames):
        frame_vertices, faces, vertex_colors = create_skeleton_mesh(
            bone_starts[f], bone_ends[f], radius=0.02, color=bone_colors[f]
        )
        per_frame_vertices.append(frame_vertices)

    return torch.stack(per_frame_vertices, dim=0), faces, vertex_colors
def add_motion_as_lines(motion, wis3d, name="joints22", skeleton_type="smpl22", const_color=None, offset=0):
    """
    Add one skeleton mesh per frame to `wis3d`, using scene id `frame + offset`.

    Args:
        motion (tensor): (L, J, 3)
    """
    mesh = convert_motion_as_line_mesh(motion, skeleton_type=skeleton_type, const_color=const_color)
    vertices, faces, vertex_colors = mesh

    for frame_idx, frame_vertices in enumerate(vertices):
        wis3d.set_scene_id(frame_idx + offset)
        # Skeleton is added as a mesh; the older wis3d.add_lines(...) approach may
        # cause problems when the number of lines is large.
        wis3d.add_mesh(frame_vertices, faces, vertex_colors, name=name)
def add_prog_motion_as_lines(motion, wis3d, name="joints22", skeleton_type="smpl22"):
    """
    Add P scenes to `wis3d`; scene p contains the bones of all L frames of motion[p]
    as lines, coloured per chain with a gradient that runs along the frame axis.

    Args:
        motion (tensor): (P, L, J, 3)
    """
    if isinstance(motion, np.ndarray):
        motion = torch.from_numpy(motion)
    n_prog, n_frames, _n_joints, _ = motion.shape
    device = motion.device

    chains = KINEMATIC_CHAINS[skeleton_type]
    chain_colors = ["red", "green", "blue", "cyan", "magenta"]

    line_starts, line_ends, line_colors = [], [], []
    for chain, scheme in zip(chains, chain_colors):
        n_bones = len(chain) - 1
        line_starts.append(motion[:, :, chain[:-1]])
        line_ends.append(motion[:, :, chain[1:]])
        gradient = get_gradient_colors(scheme, n_frames, alpha=1.0).to(device)  # (L, 4)
        gradient = gradient.unsqueeze(0).unsqueeze(2).repeat(n_prog, 1, n_bones, 1)  # (P, L, n_bones, 4)
        line_colors.append(gradient[..., :3] * 255)  # (P, L, n_bones, 3)

    # merge chains along the bone axis, then flatten (frame, bone) into one line axis
    line_starts = torch.cat(line_starts, dim=-2).reshape(n_prog, -1, 3)
    line_ends = torch.cat(line_ends, dim=-2).reshape(n_prog, -1, 3)
    line_colors = torch.cat(line_colors, dim=-2).reshape(n_prog, -1, 3)

    for p in range(n_prog):
        wis3d.set_scene_id(p)
        wis3d.add_lines(line_starts[p], line_ends[p], line_colors[p], name=name)
def add_joints_motion_as_spheres(joints, wis3d, radius=0.05, name="joints", label_each_joint=False):
    """Add every joint of every frame to `wis3d` as a coloured sphere.

    Joint i gets green = 255 / NJ * i and blue = 255 / NJ * (NJ - i), red stays 0,
    so the colour encodes the joint index.

    Args:
        joints: (NF, NJ, 3)
        wis3d
        radius: radius of the spheres
        name
        label_each_joint: if True, each joint is added separately under the name
            "{name}-j{i}" (then you can interact with it, but it's slower)
    """
    n_frames, n_joints = joints.shape[0], joints.shape[1]

    colors = torch.zeros_like(joints).float()
    for j in range(n_joints):
        colors[:, j, 1] = 255 / n_joints * j
        colors[:, j, 2] = 255 / n_joints * (n_joints - j)

    for f in range(n_frames):
        wis3d.set_scene_id(f)
        if not label_each_joint:
            # one call for the whole frame
            wis3d.add_spheres(
                joints[f].float(),
                radius=radius,
                colors=colors[f],
                name=f"{name}",
            )
            continue
        for j in range(n_joints):
            wis3d.add_spheres(
                joints[f, j].float(),
                radius=radius,
                colors=colors[f, j],
                name=f"{name}-j{j}",
            )
def create_skeleton_mesh(p1, p2, radius, color, resolution=4, return_merged=True):
    """
    Create a prism-like mesh between each pair of points p1[n] -> p2[n].

    A ring of `resolution` points (angles linspace(0, 2*pi, resolution), end points
    included) is placed around each end of the segment; side faces connect
    consecutive ring points and a fan of triangles caps each end.

    Args:
        p1 (torch.Tensor): (N, 3),
        p2 (torch.Tensor): (N, 3),
        radius (float): radius,
        color (torch.Tensor): (N, 3)
        resolution (int): number of vertices in one circle, denoted as Q
        return_merged (bool): if True, flatten all segments into a single mesh
    Returns:
        vertices (torch.Tensor): (N * 2Q, 3), if return_merged is False (N, 2Q, 3)
        faces (torch.Tensor): (M', 3), if return_merged is False (N, M, 3)
        vertex_colors (torch.Tensor): (N * 2Q, 3), if return_merged is False (N, 2Q, 3)
    """
    dev = p1.device
    n_seg = p1.shape[0]

    # Unit direction of every segment
    direction = p2 - p1  # (N, 3)
    axis = direction / direction.norm(dim=-1, keepdim=True)  # (N, 3)

    # A vector orthogonal to the segment: cross with x, falling back to the cross
    # with y where the first one is (nearly) zero.
    ex = torch.tensor([1, 0, 0], device=dev).float()[None].repeat(n_seg, 1)  # (N, 3)
    ey = torch.tensor([0, 1, 0], device=dev).float()[None].repeat(n_seg, 1)
    normal_x = torch.cross(axis, ex, dim=-1)  # (N, 3)
    normal_y = torch.cross(axis, ey, dim=-1)  # (N, 3) backup
    use_x = normal_x.norm(dim=-1, keepdim=True) > 1e-3
    normal = torch.where(use_x, normal_x, normal_y)
    unit_normal = normal / normal.norm(dim=-1, keepdim=True)  # (N, 3)

    # Rotate the orthogonal vector around the segment axis to get the ring
    angles = torch.linspace(0, 2 * np.pi, resolution, device=dev)
    rotations = axis_angle_to_matrix(axis.unsqueeze(1) * angles.view(1, -1, 1))  # (N, Q, 3, 3)
    ring = einsum(rotations, unit_normal, "n q i j, n i -> n q j") * radius  # (N, Q, 3)
    ring_at_p1 = ring + p1.unsqueeze(1)  # (N, Q, 3)
    ring_at_p2 = ring + p2.unsqueeze(1)  # (N, Q, 3)
    vertices = torch.cat([ring_at_p1, ring_at_p2], dim=1)  # (N, 2Q, 3)

    # Vertex ids inside one segment: [0, Q) is the p1 ring, [Q, 2Q) the p2 ring
    lower = torch.arange(0, resolution, device=dev)
    upper = lower + resolution
    n_cap = resolution - 2
    side_a = torch.stack([lower[1:], lower[:-1], upper[:-1]], 1)  # out face
    side_b = torch.stack([lower[1:], upper[:-1], upper[1:]], 1)  # out face
    cap_lower = torch.stack([lower[:-2], lower[1:-1], lower[-1].repeat(n_cap)], 1)
    cap_upper = torch.stack([upper[1:-1], upper[:-2], upper[-1].repeat(n_cap)], 1)
    faces = torch.cat([side_a, side_b, cap_lower, cap_upper])
    faces = faces.unsqueeze(0).repeat(p1.shape[0], 1, 1)  # (N, M, 3)

    # Every vertex of a segment shares that segment's colour
    vertex_colors = color.unsqueeze(1).repeat(1, resolution * 2, 1)

    if not return_merged:
        return vertices, faces, vertex_colors

    # Merge: shift each segment's face ids by its vertex offset, then flatten
    n_seg, verts_per_seg = vertices.shape[:2]
    faces = faces + torch.arange(0, n_seg, device=dev).view(-1, 1, 1) * verts_per_seg
    return vertices.reshape(-1, 3), faces.reshape(-1, 3), vertex_colors.reshape(-1, 3)
def get_lines_of_my_frustum(frustum_points):
    """
    Return the start and end points (numpy, each (B, 12, 3)) of the 12 frustum edges.

    frustum_points: (B, 8, 3), in (near {lu ru rd ld}, far {lu ru rd ld})
    """
    # 4 near->far edges, then the near loop, then the far loop
    edge_from = [0, 1, 2, 3] + [0, 1, 2, 3] + [4, 5, 6, 7]
    edge_to = [4, 5, 6, 7] + [1, 2, 3, 0] + [5, 6, 7, 4]
    start_points = frustum_points[:, edge_from].cpu().numpy()
    end_points = frustum_points[:, edge_to].cpu().numpy()
    return start_points, end_points
def draw_colored_vec(wis3d, vec, name, radius=0.02, colors="r", starts=None, l=1.0):
    """
    Add one coloured stick per vector to `wis3d` as a single mesh.

    Args:
        vec: (3) or (L, 3), should be the same length as colors, like 'rgb'
        colors: string of 'r' / 'g' / 'b' characters, one per vector
        starts: optional start points with the shape of `vec` (defaults to the origin)
        l: length scale applied to `vec`
    """
    if len(vec.shape) != 1:
        assert len(vec.shape) == 2
    else:
        vec = vec[None]
    assert len(vec) == len(colors)

    # map each colour character to a unit RGB row
    rgb_rows = torch.eye(3)
    channel_of = {"r": 0, "g": 1, "b": 2}
    color_tensor = torch.zeros((len(colors), 3))
    for idx, ch in enumerate(colors):
        color_tensor[idx] = rgb_rows[channel_of[ch]]

    origin = torch.zeros_like(vec) if starts is None else starts
    tips = origin + vec * l
    vertices, faces, vertex_colors = create_skeleton_mesh(origin, tips, radius, color_tensor, resolution=10)
    wis3d.add_mesh(vertices, faces, vertex_colors, name=name)
def draw_T_w2c(wis3d, T_w2c, name, radius=0.01, all_in_one=True, l=0.1):
    """
    Draw a camera trajectory in world coordinate: for every pose, three sticks
    (coloured red, green, blue) start at -R^T t and point along the rows of R.

    Args:
        T_w2c: (L, 4, 4)
        all_in_one: only True is implemented (all poses go into one mesh)
        l: stick length
    """
    axis_colors = torch.eye(3)
    if not all_in_one:
        raise NotImplementedError

    rot = T_w2c[:, :3, :3]
    trans = T_w2c[:, :3, [3]]
    centers = -rot.mT @ trans  # (L, 3, 1)
    # repeat each centre three times, once per axis
    starts = centers[:, None, :, 0].expand(-1, 3, -1).reshape(-1, 3)  # (L*3, 3)
    directions = rot.reshape(-1, 3)  # (L * 3, 3)
    ends = starts + directions * l
    color_tensor = axis_colors[None].expand(T_w2c.size(0), -1, -1).reshape(-1, 3)

    vertices, faces, vertex_colors = create_skeleton_mesh(starts, ends, radius, color_tensor, resolution=10)
    wis3d.add_mesh(vertices, faces, vertex_colors, name=name)
def create_checkerboard_mesh(y=0.0, grid_size=1.0, bounds=((-3, -3), (3, 3))):
    """
    Build a double-sided checkerboard ground mesh in the x-z plane at height `y`.

    Args:
        y: height of the plane.
        grid_size: side length of one tile.
        bounds: ((min_x, min_z), (max_x, max_z)); expanded outwards to whole
            multiples of `grid_size`.

    Returns:
        vertices (n, 3) float array, faces (m, 3) int array and
        vertex_colors (n, 3) uint8 array; every tile contributes 8 vertices
        (top side + slightly lowered bottom side) and 4 triangles.

    example usage:
        vertices, faces, vertex_colors = create_checkerboard_mesh()
        wis3d.add_mesh(vertices=vertices, faces=faces, vertex_colors=vertex_colors, name="one")
    """
    color1 = np.array([236, 240, 241], np.uint8)  # light
    color2 = np.array([120, 120, 120], np.uint8)  # dark

    # Expand the bounds outwards to whole multiples of grid_size
    min_x, min_z = bounds[0]
    max_x, max_z = bounds[1]
    min_x = grid_size * np.floor(min_x / grid_size)
    min_z = grid_size * np.floor(min_z / grid_size)
    max_x = grid_size * np.ceil(max_x / grid_size)
    max_z = grid_size * np.ceil(max_z / grid_size)

    vertices = []
    faces = []
    vertex_colors = []
    eps = 1e-4  # HACK: disable smooth color & double-side color artifacts of wis3d
    offset = np.array([0, -eps, 0])  # For visualizing the down-side of the mesh
    for i, x in enumerate(np.arange(min_x, max_x, grid_size)):
        # Shift each tile by +/- eps so neighbouring tiles never share exact corners.
        # The shift is kept in locals: adding it to the loop variable `x` inside the
        # inner loop would accumulate one extra eps per tile along z.
        x0 = x + ((i % 2 * 2) - 1) * eps
        for j, z in enumerate(np.arange(min_z, max_z, grid_size)):
            z0 = z + ((j % 2 * 2) - 1) * eps

            # Right-hand rule for normal direction
            v1 = np.array([x0, y, z0])
            v2 = np.array([x0, y, z0 + grid_size])
            v3 = np.array([x0 + grid_size, y, z0 + grid_size])
            v4 = np.array([x0 + grid_size, y, z0])

            vertices.extend([v1, v2, v3, v4, v1 + offset, v2 + offset, v3 + offset, v4 + offset])
            idx = len(vertices) - 8
            faces.extend(
                [
                    [idx, idx + 1, idx + 2],
                    [idx + 2, idx + 3, idx],
                    [idx + 4, idx + 7, idx + 6],  # double-sided
                    [idx + 6, idx + 5, idx + 4],  # double-sided
                ]
            )
            vertex_color = color1 if (i + j) % 2 == 0 else color2
            vertex_colors.extend([vertex_color] * 8)

    # To numpy.array and the shape should be (n, 3)
    vertices = np.array(vertices)
    faces = np.array(faces)
    vertex_colors = np.array(vertex_colors)
    assert len(vertices.shape) == 2 and vertices.shape[1] == 3
    assert len(faces.shape) == 2 and faces.shape[1] == 3
    assert len(vertex_colors.shape) == 2 and vertex_colors.shape[1] == 3 and vertex_colors.dtype == np.uint8
    return vertices, faces, vertex_colors
def add_a_trimesh(mesh, wis3d, name):
    """Export `mesh` as "<name>.ply" into the meshes folder of wis3d's current scene.

    The mesh is transformed in place with `wis3d.three_to_world`, the folder
    <out_folder>/<sequence_name>/<scene_id:05d>/meshes is created if needed, and
    wis3d's "mesh" counter is incremented before the export.
    """
    mesh.apply_transform(wis3d.three_to_world)

    # filename = wis3d.__get_export_file_name("mesh", name)
    scene_folder = f"{wis3d.scene_id:05d}"
    meshes_dir = Path(wis3d.out_folder) / wis3d.sequence_name / scene_folder / "meshes"
    meshes_dir.mkdir(parents=True, exist_ok=True)

    assert name is not None
    target = meshes_dir / f"{name}.ply"
    wis3d.counters["mesh"] += 1
    mesh.export(target)
================================================
FILE: eval/GVHMR/pyproject.toml
================================================
[tool.black]
line-length = 120
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''
================================================
FILE: eval/GVHMR/pyrightconfig.json
================================================
{
"exclude": [
"./inputs",
"./outputs"
],
"typeCheckingMode": "off"
}
================================================
FILE: eval/GVHMR/requirements.txt
================================================
# PyTorch
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.3.0+cu121
torchvision==0.18.0+cu121
timm==0.9.12 # For HMR2.0a feature extraction
# Lightning + Hydra
lightning==2.3.0
hydra-core==1.3
hydra-zen
hydra_colorlog
rich
# Common utilities
numpy==1.23.5
jupyter
matplotlib
ipdb
setuptools>=68.0
black
tensorboardX
opencv-python
ffmpeg-python
scikit-image
termcolor
einops
imageio==2.34.1
av # imageio[pyav], improved performance over imageio[ffmpeg]
joblib
# Diffusion
# diffusers[torch]==0.19.3
# transformers==4.31.0
# 3D-Vision
pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt230/pytorch3d-0.7.6-cp310-cp310-linux_x86_64.whl
trimesh
chumpy
smplx
# open3d==0.17.0
wis3d
# 2D-Pose
ultralytics==8.2.42 # YOLO
cython_bbox
lapx
================================================
FILE: eval/GVHMR/setup.py
================================================
from setuptools import setup, find_packages
# Package metadata for the GVHMR evaluation code; installs every package
# discovered by find_packages() under the name "gvhmr".
setup(
    name="gvhmr",
    version="1.0.0",
    packages=find_packages(),
    author="Zehong Shen",
    # `description` is the one-line summary and must be a plain string; it was
    # previously passed as a one-element list.
    description="GVHMR training and inference",
    url="https://github.com/zju3dv/GVHMR",
)
================================================
FILE: eval/GVHMR/tools/demo/colab_demo.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Dv4XCJqiKtun"
},
"source": [
"# GVHMR\n",
"\n",
"## World-Grounded Human Motion Recovery via Gravity-View Coordinates\n",
"\n",
"> World-Grounded Human Motion Recovery via Gravity-View Coordinates \n",
"> [Zehong Shen](https://zehongs.github.io/)\\*,\n",
"[Huaijin Pi](https://phj128.github.io/)\\*,\n",
"[Yan Xia](https://isshikihugh.github.io/scholar),\n",
"[Zhi Cen](https://scholar.google.com/citations?user=Xyy-uFMAAAAJ),\n",
"[Sida Peng](https://pengsida.net/)†,\n",
"[Zechen Hu](https://zju3dv.github.io/gvhmr),\n",
"[Hujun Bao](http://www.cad.zju.edu.cn/home/bao/),\n",
"[Ruizhen Hu](https://csse.szu.edu.cn/staff/ruizhenhu/),\n",
"[Xiaowei Zhou](https://xzhou.me/) \n",
"> SIGGRAPH Asia 2024\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zK2VFE-CEk-l"
},
"source": [
"## Installation\n",
"\n",
"Check [INSTALL.md](https://github.com/IsshikiHugh/GVHMR/blob/main/docs/INSTALL.md) if you want to install GVHMR in your own machine.\n",
"\n",
"> Tips: you can fold the section and run the whole installation section at once."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PIPJRQZFGplJ",
"outputId": "44c58662-6b8e-4bda-fdb3-d796b6744bd6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sat Sep 14 17:16:45 2024 \n",
"+---------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |\n",
"|-----------------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 52C P8 9W / 70W | 0MiB / 15360MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
" \n",
"+---------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=======================================================================================|\n",
"| No running processes found |\n",
"+---------------------------------------------------------------------------------------+\n"
]
}
],
"source": [
"# Check that the notebook is connected to NVIDIA drivers with CUDA. If this doesn't load, check that GPU is selected as the hardware accelerator under Edit -> Notebook settings.\n",
"# Google may shut down the GPU if usage has surpassed the allocation\n",
"\n",
"!nvidia-smi"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nUM1o8lsGU0n"
},
"source": [
"### Environment Preparation ~ 15 min\n",
"\n",
"Check [INSTALL.md#Environment](https://github.com/IsshikiHugh/GVHMR/blob/main/docs/INSTALL.md#environment) for further information."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cKvUncvK-943"
},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"\n",
"# Clone the repo.\n",
"!git clone https://github.com/zju3dv/GVHMR --recursive\n",
"proj_root = str(Path('GVHMR').absolute())\n",
"\n",
"# Install GVHMR. (If Colab asks you to restart, just click cancel and rerun the block.)\n",
"%cd {proj_root}\n",
"%pip install -r requirements.txt\n",
"%pip install -e ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vvyrlqQOIz2g"
},
"outputs": [],
"source": [
"# Install DPVO.\n",
"%cd third-party/DPVO\n",
"\n",
"!wget https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.zip\n",
"!unzip -o eigen-3.4.0.zip -d thirdparty && rm -rf eigen-3.4.0.zip\n",
"\n",
"%pip install torch-scatter -f \"https://data.pyg.org/whl/torch-2.3.0+cu121.html\"\n",
"%pip install numba pypose\n",
"\n",
"if 'cuda_home' not in locals():\n",
" cuda_home = '/usr/local/cuda-12'\n",
" if not Path(cuda_home).exists():\n",
" raise FileNotFoundError('CUDA_HOME for cuda 12.x not found!')\n",
"\n",
" os.environ['CUDA_HOME'] = cuda_home\n",
" os.environ['PATH'] = os.environ['PATH'] + f':{cuda_home}/bin'\n",
"\n",
"%pip install -e .\n",
"%cd {proj_root}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vziM8bI6E5ur"
},
"source": [
"### Data Preparation ~ 1 min\n",
"\n",
"Check [INSTALL.md#Inputs&Outputs](https://github.com/IsshikiHugh/GVHMR/blob/main/docs/INSTALL.md#inputs--outputs) for further information."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3SYDPesR_1m9",
"outputId": "8658d240-9471-4779-c67e-61f9d92142dd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/content/GVHMR\n",
"mkdir: cannot create directory ‘inputs’: File exists\n",
"mkdir: cannot create directory ‘outputs’: File exists\n",
"aria2 is already the newest version (1.36.0-1).\n",
"0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.\n",
"\n",
"Download Results:\n",
"gid |stat|avg speed |path/URI\n",
"======+====+===========+=======================================================\n",
"858436|\u001b[1;32mOK\u001b[0m | 0B/s|inputs/checkpoints/body_models/smpl/SMPL_NEUTRAL.pkl\n",
"\n",
"Status Legend:\n",
"(OK):download completed.\n",
"\n",
"Download Results:\n",
"gid |stat|avg speed |path/URI\n",
"======+====+===========+=======================================================\n",
"da93ed|\u001b[1;32mOK\u001b[0m | 0B/s|inputs/checkpoints/body_models/smplx/SMPLX_NEUTRAL.npz\n",
"\n",
"Status Legend:\n",
"(OK):download completed.\n",
"\n",
"Download Results:\n",
"gid |stat|avg speed |path/URI\n",
"======+====+===========+=======================================================\n",
"b2dfd7|\u001b[1;32mOK\u001b[0m | 0B/s|inputs/checkpoints/dpvo/dpvo.pth\n",
"\n",
"Status Legend:\n",
"(OK):download completed.\n",
"\n",
"Download Results:\n",
"gid |stat|avg speed |path/URI\n",
"======+====+===========+=======================================================\n",
"5ce060|\u001b[1;32mOK\u001b[0m | 0B/s|inputs/checkpoints/gvhmr/gvhmr_siga24_release.ckpt\n",
"\n",
"Status Legend:\n",
"(OK):download completed.\n",
"\n",
"Download Results:\n",
"gid |stat|avg speed |path/URI\n",
"======+====+===========+=======================================================\n",
"d41b3e|\u001b[1;32mOK\u001b[0m | 0B/s|inputs/checkpoints/hmr2/epoch=10-step=25000.ckpt\n",
"\n",
"Status Legend:\n",
"(OK):download completed.\n",
"\n",
"Download Results:\n",
"gid |stat|avg speed |path/URI\n",
"======+====+===========+=======================================================\n",
"3fe18e|\u001b[1;32mOK\u001b[0m | 0B/s|inputs/checkpoints/vitpose/vitpose-h-multi-coco.pth\n",
"\n",
"Status Legend:\n",
"(OK):download completed.\n",
"\n",
"Download Results:\n",
"gid |stat|avg speed |path/URI\n",
"======+====+===========+=======================================================\n",
"4fe1eb|\u001b[1;32mOK\u001b[0m | 0B/s|inputs/checkpoints/yolo/yolov8x.pt\n",
"\n",
"Status Legend:\n",
"(OK):download completed.\n"
]
}
],
"source": [
"%cd {proj_root}\n",
"!mkdir inputs\n",
"!mkdir outputs\n",
"\n",
"!apt install -y -qq aria2\n",
"!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/SMPLer-X/resolve/main/SMPL_NEUTRAL.pkl -d inputs/checkpoints/body_models/smpl -o SMPL_NEUTRAL.pkl\n",
"!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/SMPLer-X/resolve/main/SMPLX_NEUTRAL.npz -d inputs/checkpoints/body_models/smplx -o SMPLX_NEUTRAL.npz\n",
"!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/GVHMR/resolve/main/dpvo/dpvo.pth -d inputs/checkpoints/dpvo -o dpvo.pth\n",
"!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/GVHMR/resolve/main/gvhmr/gvhmr_siga24_release.ckpt -d inputs/checkpoints/gvhmr -o gvhmr_siga24_release.ckpt\n",
"!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/GVHMR/resolve/main/hmr2/epoch%3D10-step%3D25000.ckpt -d inputs/checkpoints/hmr2 -o epoch=10-step=25000.ckpt\n",
"!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/GVHMR/resolve/main/vitpose/vitpose-h-multi-coco.pth -d inputs/checkpoints/vitpose -o vitpose-h-multi-coco.pth\n",
"!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/GVHMR/resolve/main/yolo/yolov8x.pt -d inputs/checkpoints/yolo -o yolov8x.pt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IxHWylQ-KJdG"
},
"source": [
"## Run Demo\n",
"\n",
"Use `-s` to skip visual odometry if you know the camera is static, otherwise the camera will be estimated by DPVO.\n",
"\n",
"We also provide a script, `demo_folder.py`, to run inference on an entire folder.\n",
"\n",
"\n",
"```shell\n",
"python tools/demo/demo.py --video=docs/example_video/tennis.mp4 -s\n",
"python tools/demo/demo_folder.py -f inputs/demo/folder_in -d outputs/demo/folder_out -s\n",
"\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"id": "YHaW4uybOQgG"
},
"outputs": [],
"source": [
"import io\n",
"import base64\n",
"from IPython.display import HTML\n",
"from hmr4d.utils.video_io_utils import get_video_lwh\n",
"\n",
"def display_video(fn):\n",
" L, W, H = get_video_lwh(fn)\n",
" scale = min(W, 1080) / W\n",
" W, H = int(W * scale), int(H * scale)\n",
" video_encoded = base64.b64encode(io.open(fn, 'rb').read())\n",
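" return HTML(data='''<video width=\"{0}\" height=\"{1}\" controls><source src=\"data:video/mp4;base64,{2}\" type=\"video/mp4\" /></video>'''.format(W, H, video_encoded.decode('ascii')))"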
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3EOGJBa_OCWu"
},
"source": [
"### Demo Tennis"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NQIVoNKMKlhb",
"outputId": "e125532f-f095-46b2-dd70-11f227b9307d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\u001b[36m09/14 17:37:05\u001b[0m][\u001b[32mINFO\u001b[0m] [Input]: /content/GVHMR/docs/example_video/tennis.mp4\u001b[0m\n",
"[\u001b[36m09/14 17:37:05\u001b[0m][\u001b[32mINFO\u001b[0m] (L, W, H) = (312, 812, 720)\u001b[0m\n",
"[\u001b[36m09/14 17:37:06\u001b[0m][\u001b[32mINFO\u001b[0m] [Output Dir]: outputs/demo/tennis\u001b[0m\n",
"[\u001b[36m09/14 17:37:06\u001b[0m][\u001b[32mINFO\u001b[0m] [Copy Video] /content/GVHMR/docs/example_video/tennis.mp4 -> outputs/demo/tennis/0_input_video.mp4\u001b[0m\n",
"Copy: 100% 312/312 [00:09<00:00, 33.08it/s]\n",
"[\u001b[36m09/14 17:37:17\u001b[0m][\u001b[32mINFO\u001b[0m] [GPU]: Tesla T4\u001b[0m\n",
"[\u001b[36m09/14 17:37:17\u001b[0m][\u001b[32mINFO\u001b[0m] [GPU]: _CudaDeviceProperties(name='Tesla T4', major=7, minor=5, total_memory=15102MB, multi_processor_count=40)\u001b[0m\n",
"[\u001b[36m09/14 17:37:17\u001b[0m][\u001b[32mINFO\u001b[0m] [Preprocess] Start!\u001b[0m\n",
"YoloV8 Tracking: 100% 312/312 [00:27<00:00, 11.14it/s]\n",
"ViTPose: 100% 20/20 [00:42<00:00, 2.13s/it]\n",
"HMR2 Feature: 100% 20/20 [00:22<00:00, 1.11s/it]\n",
"[\u001b[36m09/14 17:39:36\u001b[0m][\u001b[32mINFO\u001b[0m] [Preprocess] End. Time elapsed: 139.28s\u001b[0m\n",
"[\u001b[36m09/14 17:39:36\u001b[0m][\u001b[32mINFO\u001b[0m] [HMR4D] Predicting\u001b[0m\n",
"[\u001b[36m09/14 17:39:36\u001b[0m][\u001b[32mINFO\u001b[0m] [EnDecoder] Use MM_V1_AMASS_LOCAL_BEDLAM_CAM for statistics!\u001b[0m\n",
"[\u001b[36m09/14 17:39:38\u001b[0m][\u001b[32mINFO\u001b[0m] [PL-Trainer] Loading ckpt type: inputs/checkpoints/gvhmr/gvhmr_siga24_release.ckpt\u001b[0m\n",
"[\u001b[36m09/14 17:39:40\u001b[0m][\u001b[32mINFO\u001b[0m] [HMR4D] Elapsed: 1.32s for data-length=10.4s\u001b[0m\n",
"Rendering Incam: 100% 312/312 [00:24<00:00, 12.71it/s]\n",
"Rendering Global: 100% 312/312 [00:16<00:00, 18.68it/s]\n",
"[\u001b[36m09/14 17:40:29\u001b[0m][\u001b[32mINFO\u001b[0m] [Merge Videos]\u001b[0m\n"
]
}
],
"source": [
"# Run demo.\n",
"video_fn = f'{proj_root}/docs/example_video/tennis.mp4'\n",
"!python {proj_root}/tools/demo/demo.py --video={video_fn} -s"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 499
},
"id": "wPam0XZfQLL5",
"outputId": "02b63472-4539-42a1-b8fd-c9208f3ded72"
},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Visualize the result.\n",
"display_video('outputs/demo/tennis/tennis_3_incam_global_horiz.mp4')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CnOwTAfyRo57"
},
"source": [
"### Custom Demo\n",
"\n",
"In order to run your own video, you need to follow instructions below:\n",
"\n",
"0. Run the code block below to initialize `demo_google_drive_video()`.\n",
"1. Upload your video to Google Drive and set its accessibility to \"Anyone with link (Viewer)\".\n",
"2. Copy the link and extract the ID. The link should follow this pattern: `https://drive.google.com/file/d/<ID>/view?usp=sharing`.\n",
"3. Call `demo_google_drive_video()` with the URL, and pass `True` if the camera is static.\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jqOda6ONRm_C",
"outputId": "93672ab6-a490-407c-d6d9-a8174be9bb5a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/content/GVHMR\n"
]
}
],
"source": [
"%cd {proj_root}\n",
"!mkdir -p inputs/demo\n",
"\n",
"def demo_google_drive_video(url:str, static_camera:bool):\n",
" ''' URL should be like https://drive.google.com/file/d/xxxxxxxx/view?usp=drive_link '''\n",
"\n",
" print(f'[1/3] 📥 Downloading video...')\n",
" gdid = url.split('/')[5]\n",
" video_name = f'custom_{gdid}'\n",
" download_url = f'\\'https://drive.google.com/uc?id={gdid}&export=download&confirm=t\\''\n",
" !gdown {download_url} -O inputs/demo/{video_name}.mp4\n",
"\n",
" print(f'[2/3] 💃 Start running GVHMR...')\n",
" flag = '-s' if static_camera else ''\n",
" !python {proj_root}/tools/demo/demo.py --video=inputs/demo/{video_name}.mp4 {flag}\n",
"\n",
" print(f'[3/3] 📺 Displaying result...')\n",
" return display_video(f'outputs/demo/{video_name}/{video_name}_3_incam_global_horiz.mp4')"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "ElCtnGFQVHU1",
"outputId": "d957a268-8648-4f2e-c360-c82ee56a7b25"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1/3] 📥 Downloading video...\n",
"Downloading...\n",
"From: https://drive.google.com/uc?id=1KkTsoAuj9yq5JCJ-qg4_MnNBn56wR47F&export=download&confirm=t\n",
"To: /content/GVHMR/inputs/demo/custom_1KkTsoAuj9yq5JCJ-qg4_MnNBn56wR47F.mp4\n",
"100% 936k/936k [00:00<00:00, 9.14MB/s]\n",
"[2/3] 💃 Start running GVHMR...\n",
"[\u001b[36m09/14 18:09:12\u001b[0m][\u001b[32mINFO\u001b[0m] [Input]: inputs/demo/custom_1KkTsoAuj9yq5JCJ-qg4_MnNBn56wR47F.mp4\u001b[0m\n",
"[\u001b[36m09/14 18:09:12\u001b[0m][\u001b[32mINFO\u001b[0m] (L, W, H) = (367, 682, 666)\u001b[0m\n",
"[\u001b[36m09/14 18:09:13\u001b[0m][\u001b[32mINFO\u001b[0m] [Output Dir]: outputs/demo/custom_1KkTsoAuj9yq5JCJ-qg4_MnNBn56wR47F\u001b[0m\n",
"[\u001b[36m09/14 18:09:13\u001b[0m][\u001b[32mINFO\u001b[0m] [Copy Video] inputs/demo/custom_1KkTsoAuj9yq5JCJ-qg4_MnNBn56wR47F.mp4 -> outputs/demo/custom_1KkTsoAuj9yq5JCJ-qg4_MnNBn56wR47F/0_input_video.mp4\u001b[0m\n",
"Copy: 100% 367/367 [00:09<00:00, 39.06it/s]\n",
"[\u001b[36m09/14 18:09:23\u001b[0m][\u001b[32mINFO\u001b[0m] [GPU]: Tesla T4\u001b[0m\n",
"[\u001b[36m09/14 18:09:23\u001b[0m][\u001b[32mINFO\u001b[0m] [GPU]: _CudaDeviceProperties(name='Tesla T4', major=7, minor=5, total_memory=15102MB, multi_processor_count=40)\u001b[0m\n",
"[\u001b[36m09/14 18:09:23\u001b[0m][\u001b[32mINFO\u001b[0m] [Preprocess] Start!\u001b[0m\n",
"YoloV8 Tracking: 100% 367/367 [00:29<00:00, 12.65it/s]\n",
"ViTPose: 100% 23/23 [00:52<00:00, 2.27s/it]\n",
"HMR2 Feature: 100% 23/23 [00:25<00:00, 1.12s/it]\n",
"[\u001b[36m09/14 18:11:57\u001b[0m][\u001b[32mINFO\u001b[0m] [Preprocess] End. Time elapsed: 154.35s\u001b[0m\n",
"[\u001b[36m09/14 18:11:57\u001b[0m][\u001b[32mINFO\u001b[0m] [HMR4D] Predicting\u001b[0m\n",
"[\u001b[36m09/14 18:11:57\u001b[0m][\u001b[32mINFO\u001b[0m] [EnDecoder] Use MM_V1_AMASS_LOCAL_BEDLAM_CAM for statistics!\u001b[0m\n",
"[\u001b[36m09/14 18:11:59\u001b[0m][\u001b[32mINFO\u001b[0m] [PL-Trainer] Loading ckpt type: inputs/checkpoints/gvhmr/gvhmr_siga24_release.ckpt\u001b[0m\n",
"[\u001b[36m09/14 18:12:02\u001b[0m][\u001b[32mINFO\u001b[0m] [HMR4D] Elapsed: 1.88s for data-length=12.2s\u001b[0m\n",
"Rendering Incam: 100% 367/367 [00:21<00:00, 16.96it/s]\n",
"Rendering Global: 100% 367/367 [00:17<00:00, 21.44it/s]\n",
"[\u001b[36m09/14 18:12:47\u001b[0m][\u001b[32mINFO\u001b[0m] [Merge Videos]\u001b[0m\n",
"[3/3] 📺 Displaying result...\n"
]
},
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"demo_google_drive_video(url='https://drive.google.com/file/d/1KkTsoAuj9yq5JCJ-qg4_MnNBn56wR47F/view?usp=drive_link', static_camera=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "irhhFd76pH5c"
},
"outputs": [],
"source": [
"# Try it yourself!\n",
"demo_google_drive_video(...)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: eval/GVHMR/tools/demo/demo.py
================================================
import cv2
import torch
import pytorch_lightning as pl
import numpy as np
import argparse
from hmr4d.utils.pylogger import Log
import hydra
from hydra import initialize_config_module, compose
from pathlib import Path
from pytorch3d.transforms import quaternion_to_matrix
import os
from hmr4d.configs import register_store_gvhmr
from hmr4d.utils.video_io_utils import (
get_video_lwh,
read_video_np,
save_video,
merge_videos_horizontal,
get_writer,
get_video_reader,
)
from hmr4d.utils.vis.cv2_utils import draw_bbx_xyxy_on_image_batch, draw_coco17_skeleton_batch
from hmr4d.utils.preproc import Tracker, Extractor, VitPoseExtractor, SLAMModel
from hmr4d.utils.geo.hmr_cam import get_bbx_xys_from_xyxy, estimate_K, convert_K_to_K4, create_camera_sensor
from hmr4d.utils.geo_transform import compute_cam_angvel
from hmr4d.model.gvhmr.gvhmr_pl_demo import DemoPL
from hmr4d.utils.net_utils import detach_to_cpu, to_cuda
from hmr4d.utils.smplx_utils import make_smplx
from hmr4d.utils.vis.renderer import Renderer, get_global_cameras_static, get_ground_params_from_points
from tqdm import tqdm
from hmr4d.utils.geo_transform import apply_T_on_points, compute_T_ayfz2ay
from einops import einsum, rearrange
CRF = 23 # 17 is lossless, every +6 halves the mp4 size
def parse_args_to_cfg():
    """Parse the demo command line and compose the Hydra ``demo`` config from it.

    Side effects:
        * asserts that the input video exists and logs its (L, W, H);
        * creates ``cfg.output_dir`` and ``cfg.preprocess_dir``;
        * copies the input frames to ``cfg.video_path`` (writer opened with
          ``fps=30`` and the module-level ``CRF``) unless a file with the same
          frame count is already there.

    Returns:
        The config composed from ``hmr4d.configs`` with the CLI overrides applied.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--video", type=str, default="inputs/demo/dance_3.mp4")
    cli.add_argument("--output_root", type=str, default=None, help="by default to outputs/demo")
    cli.add_argument("-s", "--static_cam", action="store_true", help="If true, skip DPVO")
    cli.add_argument("--verbose", action="store_true", help="If true, draw intermediate results")
    args = cli.parse_args()

    # ----- input video -----
    src_video = Path(args.video)
    assert src_video.exists(), f"Video not found at {src_video}"
    n_frames, frame_w, frame_h = get_video_lwh(src_video)
    Log.info(f"[Input]: {src_video}")
    Log.info(f"(L, W, H) = ({n_frames}, {frame_w}, {frame_h})")

    # ----- hydra config -----
    hydra_overrides = [
        f"video_name={src_video.stem}",
        f"static_cam={args.static_cam}",
        f"verbose={args.verbose}",
    ]
    # The output root is only overridden when the user asked for it.
    if args.output_root is not None:
        hydra_overrides.append(f"output_root={args.output_root}")
    with initialize_config_module(version_base="1.3", config_module="hmr4d.configs"):
        register_store_gvhmr()
        cfg = compose(config_name="demo", overrides=hydra_overrides)

    # ----- output folders -----
    Log.info(f"[Output Dir]: {cfg.output_dir}")
    for folder in (cfg.output_dir, cfg.preprocess_dir):
        Path(folder).mkdir(parents=True, exist_ok=True)

    # ----- copy the raw input video to cfg.video_path -----
    Log.info(f"[Copy Video] {src_video} -> {cfg.video_path}")
    needs_copy = not Path(cfg.video_path).exists()
    if not needs_copy:
        # An existing copy with a different frame count is treated as stale.
        needs_copy = get_video_lwh(src_video)[0] != get_video_lwh(cfg.video_path)[0]
    if needs_copy:
        frame_reader = get_video_reader(src_video)
        frame_writer = get_writer(cfg.video_path, fps=30, crf=CRF)
        progress = tqdm(frame_reader, total=get_video_lwh(src_video)[0], desc="Copy")
        for frame in progress:
            frame_writer.write_frame(frame)
        frame_writer.close()
        frame_reader.close()

    return cfg
@torch.no_grad()
def run_preprocess(cfg):
    """Run the preprocessing stages of the demo and cache every result on disk.

    Stages, in order: person bounding-box track (``paths.bbx``), ViTPose 2D
    keypoints (``paths.vitpose``), ViT image features (``paths.vit_features``)
    and, unless ``cfg.static_cam`` is set, the SLAM/DPVO camera trajectory
    (``paths.slam``). A stage is skipped when its output file already exists.
    When ``cfg.verbose`` is set, overlay videos are written for the bounding
    boxes and the 2D skeleton.

    Args:
        cfg: demo config providing ``video_path``, ``paths``, ``static_cam``
            and ``verbose``.
    """
    Log.info("[Preprocess] Start!")
    t_start = Log.time()
    src_video = cfg.video_path
    paths = cfg.paths
    is_static = cfg.static_cam
    draw_debug = cfg.verbose

    # ----- 1. bounding-box track -----
    if Path(paths.bbx).exists():
        bbx_xys = torch.load(paths.bbx)["bbx_xys"]
        Log.info(f"[Preprocess] bbx (xyxy, xys) from {paths.bbx}")
    else:
        tracker = Tracker()
        boxes_xyxy = tracker.get_one_track(src_video).float()  # (L, 4) per the original annotation
        # (L, 3) per the original annotation: apply aspect ratio and enlarge
        bbx_xys = get_bbx_xys_from_xyxy(boxes_xyxy, base_enlarge=1.2).float()
        torch.save({"bbx_xyxy": boxes_xyxy, "bbx_xys": bbx_xys}, paths.bbx)
        del tracker
    if draw_debug:
        frames = read_video_np(src_video)
        boxes_xyxy = torch.load(paths.bbx)["bbx_xyxy"]
        overlay = draw_bbx_xyxy_on_image_batch(boxes_xyxy, frames)
        save_video(overlay, cfg.paths.bbx_xyxy_video_overlay)

    # ----- 2. ViTPose 2D keypoints -----
    if Path(paths.vitpose).exists():
        vitpose = torch.load(paths.vitpose)
        Log.info(f"[Preprocess] vitpose from {paths.vitpose}")
    else:
        vitpose_extractor = VitPoseExtractor()
        vitpose = vitpose_extractor.extract(src_video, bbx_xys)
        torch.save(vitpose, paths.vitpose)
        del vitpose_extractor
    if draw_debug:
        frames = read_video_np(src_video)
        overlay = draw_coco17_skeleton_batch(frames, vitpose, 0.5)
        save_video(overlay, paths.vitpose_video_overlay)

    # ----- 3. ViT image features (only logged, not loaded, when cached) -----
    if Path(paths.vit_features).exists():
        Log.info(f"[Preprocess] vit_features from {paths.vit_features}")
    else:
        extractor = Extractor()
        vit_features = extractor.extract_video_features(src_video, bbx_xys)
        torch.save(vit_features, paths.vit_features)
        del extractor

    # ----- 4. DPVO / SLAM, only needed to recover the rotation of a moving camera -----
    if not is_static:
        if Path(paths.slam).exists():
            Log.info(f"[Preprocess] slam results from {paths.slam}")
        else:
            n_frames, frame_w, frame_h = get_video_lwh(cfg.video_path)
            intrinsics = convert_K_to_K4(estimate_K(frame_w, frame_h))
            slam = SLAMModel(src_video, frame_w, frame_h, intrinsics, buffer=4000, resize=0.5)
            progress = tqdm(total=n_frames, desc="DPVO")
            # track() is polled until it returns a falsy value.
            while slam.track():
                progress.update()
            slam_results = slam.process()  # (L, 7), numpy per the original annotation
            torch.save(slam_results, paths.slam)

    Log.info(f"[Preprocess] End. Time elapsed: {Log.time()-t_start:.2f}s")
def load_data_dict(cfg):
    """Collect the cached preprocessing outputs into the dict passed to the model.

    Args:
        cfg: demo config providing ``video_path``, ``paths`` and ``static_cam``.

    Returns:
        dict with keys ``length``, ``bbx_xys``, ``kp2d``, ``K_fullimg``,
        ``cam_angvel`` and ``f_imgseq``.
    """
    paths = cfg.paths
    n_frames, frame_w, frame_h = get_video_lwh(cfg.video_path)

    if not cfg.static_cam:
        slam_traj = torch.load(cfg.paths.slam)
        # Columns 3..6 are reordered to [6, 3, 4, 5] before the conversion
        # (presumably xyzw -> wxyz, the real-first layout -- TODO confirm
        # against the SLAM output format).
        quat = torch.from_numpy(slam_traj[:, [6, 3, 4, 5]])
        R_w2c = quaternion_to_matrix(quat).mT
    else:
        # Static camera: one identity rotation per frame.
        R_w2c = torch.eye(3).repeat(n_frames, 1, 1)

    # The same estimated intrinsics are repeated for every frame.
    intrinsics = estimate_K(frame_w, frame_h).repeat(n_frames, 1, 1)

    data = {}
    data["length"] = torch.tensor(n_frames)
    data["bbx_xys"] = torch.load(paths.bbx)["bbx_xys"]
    data["kp2d"] = torch.load(paths.vitpose)
    data["K_fullimg"] = intrinsics
    data["cam_angvel"] = compute_cam_angvel(R_w2c)
    data["f_imgseq"] = torch.load(paths.vit_features)
    return data
def render_incam(cfg):
    """Render the predicted in-camera mesh on top of the input video.

    Does nothing except log when ``cfg.paths.incam_video`` already exists.
    Otherwise it loads the saved predictions, runs the SMPL-X body model on
    ``smpl_params_incam``, maps the vertices through the ``smplx2smpl`` matrix
    and writes one rendered frame per input frame (``fps=30``, module ``CRF``).

    Args:
        cfg: demo config providing ``video_path`` and ``paths``.
    """
    out_path = Path(cfg.paths.incam_video)
    if out_path.exists():
        Log.info(f"[Render Incam] Video already exists at {out_path}")
        return

    pred = torch.load(cfg.paths.hmr4d_results)
    body_model = make_smplx("supermotion").cuda()
    # NOTE: relative path, resolved against the current working directory.
    smplx_to_smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt").cuda()
    smpl_faces = make_smplx("smpl").faces

    # Body-model forward pass, then per-frame conversion with the smplx2smpl matrix.
    body_out = body_model(**to_cuda(pred["smpl_params_incam"]))
    per_frame_verts = []
    for frame_verts in body_out.vertices:
        per_frame_verts.append(torch.matmul(smplx_to_smpl, frame_verts))
    verts_incam = torch.stack(per_frame_verts)

    # ----- rendering -----
    src_video = cfg.video_path
    _n_frames, frame_w, frame_h = get_video_lwh(src_video)
    intrinsics = pred["K_fullimg"][0]
    renderer = Renderer(frame_w, frame_h, device="cuda", faces=smpl_faces, K=intrinsics)
    frame_reader = get_video_reader(src_video)  # (F, H, W, 3), uint8, numpy per the original annotation
    # Loaded as in the original; only the (disabled) bounding-box debug drawing used it.
    _bbx_xys_render = torch.load(cfg.paths.bbx)["bbx_xys"]

    frame_writer = get_writer(out_path, fps=30, crf=CRF)
    progress = tqdm(enumerate(frame_reader), total=get_video_lwh(src_video)[0], desc="Rendering Incam")
    for idx, raw_frame in progress:
        rendered = renderer.render_mesh(verts_incam[idx].cuda(), raw_frame, [0.8, 0.8, 0.8])
        frame_writer.write_frame(rendered)
    frame_writer.close()
    frame_reader.close()
def render_global(cfg):
    """Render the predicted global-space motion from a static virtual camera.

    Does nothing except log when ``cfg.paths.global_video`` already exists.
    Otherwise the SMPL-X output for ``smpl_params_global`` is converted with
    the ``smplx2smpl`` matrix, moved so that the sequence starts at the origin
    on the ground, and rendered with a ground plane (``fps=30``, module ``CRF``).

    Args:
        cfg: demo config providing ``video_path`` and ``paths``.
    """
    out_path = Path(cfg.paths.global_video)
    if out_path.exists():
        Log.info(f"[Render Global] Video already exists at {out_path}")
        return

    debug_cam = False  # when True only 8 frames are rendered
    pred = torch.load(cfg.paths.hmr4d_results)
    body_model = make_smplx("supermotion").cuda()
    # NOTE: the two .pt paths below are relative to the current working directory.
    smplx_to_smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt").cuda()
    smpl_faces = make_smplx("smpl").faces
    J_regressor = torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt").cuda()

    # Body-model forward pass, then per-frame conversion with the smplx2smpl matrix.
    body_out = body_model(**to_cuda(pred["smpl_params_global"]))
    pred_ay_verts = torch.stack([torch.matmul(smplx_to_smpl, frame_verts) for frame_verts in body_out.vertices])

    def move_to_start_point_face_z(verts):
        """XZ to origin, start from the ground, face +Z."""
        moved = verts.clone()  # (L, V, 3) per the original annotation
        # Translation: first regressed joint of frame 0, with its Y component
        # replaced by the minimum Y over all frames and vertices.
        shift = einsum(J_regressor, moved[0], "j v, v i -> j i")[0]
        shift[1] = moved[:, :, [1]].min()
        moved = moved - shift
        # Facing direction: transform derived from the joints of the first frame.
        first_joints = einsum(J_regressor, moved[[0]], "j v, l v i -> l j i")
        T_ay2ayfz = compute_T_ayfz2ay(first_joints, inverse=True)
        return apply_T_on_points(moved, T_ay2ayfz)

    verts_glob = move_to_start_point_face_z(pred_ay_verts)
    joints_glob = einsum(J_regressor, verts_glob, "j v, l v i -> l j i")  # (L, J, 3)
    global_R, global_T, global_lights = get_global_cameras_static(
        verts_glob.cpu(),
        beta=2.0,
        cam_height_degree=20,
        target_center_height=1.0,
    )

    # ----- rendering -----
    src_video = cfg.video_path
    n_frames, frame_w, frame_h = get_video_lwh(src_video)
    intrinsics = create_camera_sensor(frame_w, frame_h, 24)[2]  # render as 24mm lens
    renderer = Renderer(frame_w, frame_h, device="cuda", faces=smpl_faces, K=intrinsics)

    # Ground plane derived from joint 0 and the vertices, enlarged by 1.5x.
    ground_scale, ground_cx, ground_cz = get_ground_params_from_points(joints_glob[:, 0], verts_glob)
    renderer.set_ground(ground_scale * 1.5, ground_cx, ground_cz)
    mesh_color = torch.ones(3).float().cuda() * 0.8

    n_render = 8 if debug_cam else n_frames
    frame_writer = get_writer(out_path, fps=30, crf=CRF)
    for idx in tqdm(range(n_render), desc="Rendering Global"):
        cameras = renderer.create_camera(global_R[idx], global_T[idx])
        rendered = renderer.render_with_ground(verts_glob[[idx]], mesh_color[None], cameras, global_lights)
        frame_writer.write_frame(rendered)
    frame_writer.close()
if __name__ == "__main__":
    # Demo entry point: preprocess -> predict (cached) -> render -> merge videos.
    cfg = parse_args_to_cfg()
    paths = cfg.paths
    Log.info(f"[GPU]: {torch.cuda.get_device_name()}")
    Log.info(f'[GPU]: {torch.cuda.get_device_properties("cuda")}')

    # ===== Preprocess and save to disk ===== #
    run_preprocess(cfg)
    data = load_data_dict(cfg)

    # ===== HMR4D ===== #
    # Prediction is skipped when a result file from an earlier run exists.
    results_file = Path(paths.hmr4d_results)
    if not results_file.exists():
        Log.info("[HMR4D] Predicting")
        model: DemoPL = hydra.utils.instantiate(cfg.model, _recursive_=False)
        model.load_pretrained_model(cfg.ckpt_path)
        model = model.eval().cuda()
        tic = Log.sync_time()
        pred = detach_to_cpu(model.predict(data, static_cam=cfg.static_cam))

        # Also dump the in-camera root orientation and translation as .npy files.
        incam_params = pred["smpl_params_incam"]
        for npy_name, param_key in (("smpl_orient.npy", "global_orient"), ("smpl_transl.npy", "transl")):
            np.save(os.path.join(cfg.output_dir, npy_name), incam_params[param_key].numpy())

        data_time = data["length"] / 30  # frame count divided by 30
        Log.info(f"[HMR4D] Elapsed: {Log.sync_time() - tic:.2f}s for data-length={data_time:.1f}s")
        torch.save(pred, paths.hmr4d_results)

    # ===== Render ===== #
    render_incam(cfg)
    render_global(cfg)
    merged_video = Path(paths.incam_global_horiz_video)
    if not merged_video.exists():
        Log.info("[Merge Videos]")
        merge_videos_horizontal([paths.incam_video, paths.global_video], paths.incam_global_horiz_video)
================================================
FILE: eval/GVHMR/tools/demo/demo_folder.py
================================================
import argparse
from pathlib import Path
from tqdm import tqdm
from hmr4d.utils.pylogger import Log
import subprocess
import os
if __name__ == "__main__":
    # Batch driver: run tools/demo/demo.py once for every .mp4 located exactly
    # one directory level below --folder (<folder>/<sub_folder>/<file>.mp4).
    parser = argparse.ArgumentParser()
    # --folder is mandatory: without it os.listdir(None) would scan the current
    # directory and the later os.path.join(None, ...) would raise a TypeError.
    parser.add_argument("-f", "--folder", type=str, required=True)
    parser.add_argument("-d", "--output_root", type=str, default=None)
    parser.add_argument("-s", "--static_cam", action="store_true", help="If true, skip DPVO")
    args = parser.parse_args()
    output_root = args.output_root

    mp4_paths = []
    # Sorted so that the processing order is reproducible across runs.
    for sub_folder in sorted(os.listdir(args.folder)):
        sub_dir = os.path.join(args.folder, sub_folder)
        # Stray files next to the sub-folders must not abort the scan.
        if not os.path.isdir(sub_dir):
            continue
        for file in sorted(os.listdir(sub_dir)):
            if file.endswith(".mp4"):
                mp4_paths.append(os.path.join(sub_dir, file))

    # Run demo.py for each .mp4 file
    Log.info(f"Found {len(mp4_paths)} .mp4 files in {args.folder}")
    failed = []
    for mp4_path in tqdm(mp4_paths):
        command = ["python", "tools/demo/demo.py", "--video", str(mp4_path)]
        if output_root is not None:
            command += ["--output_root", output_root]
        if args.static_cam:
            command += ["-s"]
        Log.info(f"Running: {' '.join(command)}")
        try:
            subprocess.run(command, env=dict(os.environ), check=True)
        except (subprocess.CalledProcessError, OSError) as err:
            # Best effort: one failing video must not stop the batch, but the
            # failure is recorded. KeyboardInterrupt is deliberately not caught
            # so that Ctrl-C still aborts the whole run.
            Log.info(f"[Failed] {mp4_path}: {err}")
            failed.append(mp4_path)

    if failed:
        Log.info(f"{len(failed)} of {len(mp4_paths)} videos failed: {failed}")
================================================
FILE: eval/GVHMR/tools/eval_pose.py
================================================
# ParticleSfM
# Copyright (C) 2022 ByteDance Inc.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import os
from os.path import dirname
import numpy as np
import argparse
import tqdm
import numpy as np
import imageio
import os
from glob import glob
from multiprocessing import Pool
import json
import math
from scipy.spatial.transform import Rotation as R
def batch_rotation_matrix_angle_error(R1_batch, R2_batch):
    """Rotation angle between corresponding matrices of two batches.

    For every index ``i`` the relative rotation ``R1_batch[i].T @ R2_batch[i]``
    is formed and its rotation magnitude is stored (radians; the conversion to
    degrees is not applied).

    Args:
        R1_batch: array of shape (B, 3, 3).
        R2_batch: array of the same shape as ``R1_batch``.

    Returns:
        Array of shape (B,) with one angle per pair.

    Raises:
        AssertionError: if the shapes differ or are not (B, 3, 3).
    """
    assert R1_batch.shape == R2_batch.shape
    assert R1_batch.shape[1:] == (3, 3)

    batch_size = R1_batch.shape[0]
    errors = np.zeros(batch_size)
    for idx, (rot_a, rot_b) in enumerate(zip(R1_batch, R2_batch)):
        relative = np.dot(rot_a.T, rot_b)
        errors[idx] = R.from_matrix(relative).magnitude()
    return errors
def normalize(x):
    """Return ``x`` divided by its norm as computed by ``np.linalg.norm``.

    No guard is applied for a zero norm.
    """
    length = np.linalg.norm(x)
    return x / length
def viewmatrix(z, up, pos):
    """Build a matrix whose columns are three unit axes followed by ``pos``.

    The third axis is ``z`` normalised, the first is ``normalize(up x axis3)``
    and the second is ``normalize(axis3 x axis1)``.

    Args:
        z: direction of the third axis.
        up: approximate up vector (used as given, not normalised first).
        pos: vector stacked as the last column.

    Returns:
        The four vectors stacked along axis 1.
    """
    forward = normalize(z)
    right = normalize(np.cross(up, forward))
    true_up = normalize(np.cross(forward, right))
    columns = [right, true_up, forward, pos]
    return np.stack(columns, 1)
def matrix_to_euler_angles(matrix):
    """Extract three Euler angles from the upper-left 3x3 of ``matrix``.

    Args:
        matrix: nested sequence / array indexable as ``matrix[row][col]``.

    Returns:
        Tuple ``(x, y, z)`` of angles converted to degrees. When the
        ``sy`` term is below 1e-6 (degenerate case) ``z`` is fixed to 0.
    """
    sy = math.sqrt(matrix[0][0] * matrix[0][0] + matrix[1][0] * matrix[1][0])
    # y uses the same expression in both the regular and the degenerate case.
    y = math.atan2(-matrix[2][0], sy)
    if sy < 1e-6:
        x = math.atan2(-matrix[1][2], matrix[1][1])
        z = 0
    else:
        x = math.atan2(matrix[2][1], matrix[2][2])
        z = math.atan2(matrix[1][0], matrix[0][0])
    return tuple(math.degrees(angle) for angle in (x, y, z))
def eul2rot(theta):
    """Build a 3x3 rotation matrix from three Euler angles.

    Args:
        theta: indexable of three angles; each is passed directly to
            ``np.sin`` / ``np.cos``.

    Returns:
        The transpose of ``Rz(theta[2]) @ Ry(theta[1]) @ Rx(theta[0])``.
    """
    sx, cx = np.sin(theta[0]), np.cos(theta[0])
    sy, cy = np.sin(theta[1]), np.cos(theta[1])
    sz, cz = np.sin(theta[2]), np.cos(theta[2])
    rot = np.array([
        [cy * cz, sx * sy * cz - sz * cx, sy * cx * cz + sx * sz],
        [sz * cy, sx * sy * sz + cx * cz, sy * sz * cx - sx * cz],
        [-sy, sx * cy, cx * cy],
    ])
    return rot.T
def extract_location_rotation(data):
    """Convert a mapping of matrix strings into 4x4 pose matrices.

    Args:
        data: dict mapping a key to a matrix string accepted by ``parse_matrix``.

    Returns:
        Dict with the same keys; each value is a 4x4 array whose last column
        holds the first three entries of the parsed matrix's fourth row and
        whose upper-left 3x3 is ``eul2rot(matrix_to_euler_angles(parsed))``.
    """
    results = {}
    for name, matrix_text in data.items():
        parsed = parse_matrix(matrix_text)
        pose = np.identity(4)
        # Translation is read from the fourth row of the parsed matrix.
        pose[:3, 3] = np.array([parsed[3][0], parsed[3][1], parsed[3][2]])
        # NOTE(review): matrix_to_euler_angles returns degrees, while eul2rot
        # feeds its input straight to np.sin/np.cos — confirm this is intended.
        pose[:3, :3] = eul2rot(matrix_to_euler_angles(parsed))
        results[name] = pose
    return results
def parse_matrix(matrix_str):
    """Parse a string such as ``"[a b c] [d e f]"`` into a float array.

    Rows are separated by ``"] ["``; values inside a row are
    whitespace-separated. Remaining brackets are stripped before conversion.

    Args:
        matrix_str: textual matrix.

    Returns:
        ``np.ndarray`` built from the parsed rows.
    """
    return np.array([
        [float(token) for token in chunk.replace('[', '').replace(']', '').split()]
        for chunk in matrix_str.strip().split('] [')
    ])
def batch_axis_angle_to_rotation_matrix(r_batch):
    """Convert a batch of axis-angle vectors to rotation matrices (Rodrigues).

    Args:
        r_batch: array of shape (B, 3); each row is a rotation axis scaled by
            the rotation angle (the angle is the row's norm).

    Returns:
        Array of shape (B, 3, 3). An empty batch yields an array of shape
        (0, 3, 3) so callers can keep indexing ``[:, :3, :3]``-style.
    """
    batch_size = r_batch.shape[0]
    if batch_size == 0:
        # np.array([]) would have shape (0,), losing the trailing 3x3 dims.
        return np.zeros((0, 3, 3))

    rotation_matrices = []
    for i in range(batch_size):
        r = r_batch[i]
        theta = np.linalg.norm(r)
        if theta == 0:
            # Zero-length vector: no rotation, and the axis is undefined.
            rotation_matrices.append(np.eye(3))
            continue
        kx, ky, kz = r / theta
        # Skew-symmetric cross-product matrix of the unit axis.
        K = np.array([
            [0, -kz, ky],
            [kz, 0, -kx],
            [-ky, kx, 0],
        ])
        # Rodrigues formula: I + sin(theta) K + (1 - cos(theta)) K^2
        rotation = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * np.dot(K, K)
        rotation_matrices.append(rotation)
    return np.array(rotation_matrices)
if __name__ == '__main__':
    # Compare estimated SMPL poses against ground-truth object poses for every
    # result folder under --folder and print per-video and average errors.
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--folder", type=str)
    args = parser.parse_args()
    folder_path = args.folder

    # Drop *.txt entries. A new list is built because calling remove() on the
    # list being iterated skips the entry that follows each removed item.
    video_files = [name for name in os.listdir(folder_path) if not name.endswith('txt')]
    num_video_files = len(video_files)
    if num_video_files == 0:
        # Avoid a ZeroDivisionError when averaging below.
        raise SystemExit(f"No video result folders found in {folder_path}")

    # Ground-truth poses are read from the current working directory.
    with open('eval_gt_poses.json', 'r') as f:
        eval_gt_poses = json.load(f)

    # The evaluated frame window is the same for every video.
    start_frame_ind = 10
    sample_n_frames = 77
    frame_indices = np.linspace(
        start_frame_ind, start_frame_ind + sample_n_frames - 1, sample_n_frames, dtype=int
    )

    transl_err_all, rotat_err_all = 0, 0
    for video_file in tqdm.tqdm(sorted(video_files)):
        # assumes each ground-truth entry is a sequence of 4x4 poses covering
        # all sampled frame indices — TODO confirm against the JSON producer
        obj_poses = np.array(eval_gt_poses[video_file])
        obj_poses = obj_poses[frame_indices]

        # load smpl pose
        video_path = os.path.join(folder_path, video_file)
        smpl_poses = np.zeros_like(obj_poses)
        smpl_poses[:, 3, 3] = 1.
        obj_rotats = np.load(os.path.join(video_path, 'smpl_orient.npy'))
        smpl_poses[:, :3, :3] = batch_axis_angle_to_rotation_matrix(obj_rotats)
        smpl_poses[:, :3, 3] = np.load(os.path.join(video_path, 'smpl_transl.npy'))

        # align y-axis orientation: negate the y translation, then negate the
        # first two columns of every pose matrix
        smpl_poses[:, :3, 3][:, 1] *= -1.
        smpl_poses[:, :, :2] *= -1.

        # align pose translation so both trajectories start at the same point
        translation_bias = smpl_poses[0, :3, 3] - obj_poses[0, :3, 3]
        smpl_poses[:, :3, 3] -= translation_bias

        # evaluation: mean L2 translation error and mean rotation-angle error
        transl_err = np.linalg.norm(smpl_poses[:, :3, 3] - obj_poses[:, :3, 3], ord=2, axis=1).mean()
        rotat_err = batch_rotation_matrix_angle_error(obj_poses[:, :3, :3], smpl_poses[:, :3, :3]).mean()
        transl_err_all += transl_err
        rotat_err_all += rotat_err
        print(video_path)
        print('transl_err:{:.3f}'.format(transl_err))
        print('rotat_err:{:.3f}'.format(rotat_err))

    print('transl_err_all:{:.3f}'.format(transl_err_all / num_video_files))
    print('rotat_err_all:{:.3f}'.format(rotat_err_all / num_video_files))
================================================
FILE: eval/GVHMR/tools/train.py
================================================
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks.checkpoint import Checkpoint
from hmr4d.utils.pylogger import Log
from hmr4d.configs import register_store_gvhmr
from hmr4d.utils.vis.rich_logger import print_cfg
from hmr4d.utils.net_utils import load_pretrained_model, get_resume_ckpt_path
def get_callbacks(cfg: DictConfig) -> list:
    """Instantiate every callback declared under ``cfg.callbacks``.

    Args:
        cfg: config; reads ``cfg.callbacks`` and
            ``cfg.pl_trainer.enable_checkpointing`` (default ``True``).

    Returns:
        List of instantiated callbacks, or ``None`` when ``cfg`` has no
        ``callbacks`` entry or it is ``None``. ``Checkpoint`` callbacks are
        left out when checkpointing is disabled.
    """
    if not hasattr(cfg, "callbacks") or cfg.callbacks is None:
        return None

    checkpointing_on = cfg.pl_trainer.get("enable_checkpointing", True)

    instantiated = []
    for spec in cfg.callbacks.values():
        if spec is None:
            continue
        cb = hydra.utils.instantiate(spec, _recursive_=False)
        # Checkpoint callbacks are dropped when checkpointing is turned off.
        if isinstance(cb, Checkpoint) and not checkpointing_on:
            continue
        instantiated.append(cb)
    return instantiated
def train(cfg: DictConfig) -> None:
    """Run fitting or testing as selected by ``cfg.task``.

    Args:
        cfg: config; reads ``exp_name``, ``task``, ``seed``, ``data``,
            ``model``, ``ckpt_path``, ``callbacks``, ``logger``,
            ``pl_trainer`` and (for ``task == "fit"``) ``resume_mode``.

    Raises:
        ValueError: if ``cfg.task`` is neither ``"fit"`` nor ``"test"``.
    """
    Log.info(f"[Exp Name]: {cfg.exp_name}")
    if cfg.task == "fit":
        Log.info(f"[GPU x Batch] = {cfg.pl_trainer.devices} x {cfg.data.loader_opts.train.batch_size}")
    pl.seed_everything(cfg.seed)

    # preparation
    datamodule: pl.LightningDataModule = hydra.utils.instantiate(cfg.data, _recursive_=False)
    model: pl.LightningModule = hydra.utils.instantiate(cfg.model, _recursive_=False)
    if cfg.ckpt_path is not None:
        load_pretrained_model(model, cfg.ckpt_path)

    # PL callbacks and logger
    callbacks = get_callbacks(cfg)
    # get_callbacks returns None when no callbacks are configured; iterating
    # None would raise TypeError, so treat it as "no callbacks".
    has_ckpt_cb = any(isinstance(cb, Checkpoint) for cb in (callbacks or []))
    if not has_ckpt_cb and cfg.pl_trainer.get("enable_checkpointing", True):
        Log.warning("No checkpoint-callback found. Disabling PL auto checkpointing.")
        cfg.pl_trainer = {**cfg.pl_trainer, "enable_checkpointing": False}
    logger = hydra.utils.instantiate(cfg.logger, _recursive_=False)

    # PL-Trainer
    if cfg.task == "test":
        Log.info("Test mode forces full-precision.")
        cfg.pl_trainer = {**cfg.pl_trainer, "precision": 32}
    trainer = pl.Trainer(
        accelerator="gpu",
        logger=logger if logger is not None else False,
        callbacks=callbacks,
        **cfg.pl_trainer,
    )

    if cfg.task == "fit":
        resume_path = None
        if cfg.resume_mode is not None:
            resume_path = get_resume_ckpt_path(cfg.resume_mode, ckpt_dir=cfg.callbacks.model_checkpoint.dirpath)
            Log.info(f"Resume training from {resume_path}")
        Log.info("Start Fitiing...")
        trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader(), ckpt_path=resume_path)
    elif cfg.task == "test":
        Log.info("Start Testing...")
        trainer.test(model, datamodule.test_dataloader())
    else:
        raise ValueError(f"Unknown task: {cfg.task}")

    Log.info("End of script.")
@hydra.main(version_base="1.3", config_path="../hmr4d/configs", config_name="train")
def main(cfg) -> None:
    """Hydra entry point: print the composed config, then call ``train``."""
    print_cfg(cfg, use_rich=True)
    train(cfg)


if __name__ == "__main__":
    # register_store_gvhmr() runs before the Hydra-decorated main is invoked.
    register_store_gvhmr()
    main()
================================================
FILE: eval/GVHMR/tools/unitest/make_hydra_cfg.py
================================================
from hmr4d.configs import parse_args_to_cfg, register_store_gvhmr
from hmr4d.utils.vis.rich_logger import print_cfg
if __name__ == "__main__":
    # Build a config from the command-line arguments and pretty-print it;
    # nothing is trained or evaluated by this script.
    register_store_gvhmr()
    cfg = parse_args_to_cfg()
    print_cfg(cfg, use_rich=True)
================================================
FILE: eval/GVHMR/tools/unitest/run_dataset.py
================================================
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
def get_dataset(DATA_TYPE):
    """Instantiate the dataset selected by ``DATA_TYPE``.

    The dataset modules are imported lazily so only the requested one is loaded.

    Args:
        DATA_TYPE: ``"BEDLAM_V2"`` or ``"3DPW_TRAIN"``.

    Returns:
        A freshly constructed dataset instance.

    Raises:
        ValueError: if ``DATA_TYPE`` is not one of the supported names
            (previously this silently returned ``None``).
    """
    if DATA_TYPE == "BEDLAM_V2":
        from hmr4d.dataset.bedlam.bedlam import BedlamDatasetV2

        return BedlamDatasetV2()
    if DATA_TYPE == "3DPW_TRAIN":
        from hmr4d.dataset.threedpw.threedpw_motion_train import ThreedpwSmplDataset

        return ThreedpwSmplDataset()
    raise ValueError(f"Unknown DATA_TYPE: {DATA_TYPE!r} (expected 'BEDLAM_V2' or '3DPW_TRAIN')")
if __name__ == "__main__":
    # Smoke test: build one dataset, fetch a sample, then iterate the whole
    # loader once to surface any per-item loading error.
    DATA_TYPE = "3DPW_TRAIN"
    dataset = get_dataset(DATA_TYPE)
    print(len(dataset))
    data = dataset[0]

    from hmr4d.datamodule.mocap_trainX_testY import collate_fn

    loader_options = dict(
        shuffle=False,
        num_workers=0,
        persistent_workers=False,
        pin_memory=False,
        batch_size=1,
        collate_fn=collate_fn,
    )
    loader = DataLoader(dataset, **loader_options)

    i = 0
    for batch in tqdm(loader):
        # The counter is kept for ad-hoc debugging (e.g. stopping after N batches).
        i += 1
================================================
FILE: eval/GVHMR/tools/video/merge_folder.py
================================================
"""This script will glob two folder, check the mp4 files are one-to-one match precisely, then call merge_horizontal.py to merge them one by one"""
import os
import argparse
from pathlib import Path
def main():
    """Merge same-named mp4 files from two folders, pair by pair.

    Command-line arguments: ``input_dir1 input_dir2 output_dir [--vertical]``.
    Both input folders must contain the same number of ``*.mp4`` files with
    matching stems. Each pair is merged by invoking ``merge_horizontal.py``
    (default) or ``merge_vertical.py`` and written to ``output_dir``.

    Raises:
        AssertionError: if an input folder is missing or the mp4 files do not
            match one-to-one.
    """
    # Local import: argument lists avoid the quoting problems of a shell string
    # (paths containing spaces or shell metacharacters).
    import subprocess

    parser = argparse.ArgumentParser()
    parser.add_argument("input_dir1", type=str)
    parser.add_argument("input_dir2", type=str)
    parser.add_argument("output_dir", type=str)
    parser.add_argument("--vertical", action="store_true")  # By default use horizontal
    args = parser.parse_args()

    # Check input
    input_dir1 = Path(args.input_dir1)
    input_dir2 = Path(args.input_dir2)
    assert input_dir1.exists()
    assert input_dir2.exists()
    video_paths1 = sorted(input_dir1.glob("*.mp4"))
    video_paths2 = sorted(input_dir2.glob("*.mp4"))
    assert len(video_paths1) == len(video_paths2)
    for path1, path2 in zip(video_paths1, video_paths2):
        assert path1.stem == path2.stem

    # Merge to output
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if args.vertical:
        merge_script = "tools/video/merge_vertical.py"
    else:
        merge_script = "tools/video/merge_horizontal.py"

    for path1, path2 in zip(video_paths1, video_paths2):
        out_path = output_dir / f"{path1.stem}.mp4"
        in_paths = [str(path1), str(path2)]
        print(f"Merging {in_paths} to {out_path}")
        command = ["python", merge_script, *in_paths, "-o", str(out_path)]
        completed = subprocess.run(command)
        if completed.returncode != 0:
            # Keep going with the remaining pairs, but make the failure visible.
            print(f"Merge failed (exit code {completed.returncode}) for {out_path}")


if __name__ == "__main__":
    main()
================================================
FILE: eval/GVHMR/tools/video/merge_horizontal.py
================================================
import argparse
from hmr4d.utils.video_io_utils import merge_videos_horizontal
def parse_args():
    """Parse the command line of the horizontal merge tool.

    Example: python tools/video/merge_horizontal.py a.mp4 b.mp4 c.mp4 -o out.mp4

    Returns:
        Namespace with ``input_videos`` (list of one or more paths) and
        ``output`` (required output path).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("input_videos", nargs="+", help="Input video paths")
    cli.add_argument("-o", "--output", type=str, required=True, help="Output video path")
    parsed = cli.parse_args()
    return parsed
if __name__ == "__main__":
    # Script entry point: pass the parsed input paths and output path to
    # merge_videos_horizontal.
    args = parse_args()
    merge_videos_horizontal(args.input_videos, args.output)
================================================
FILE: eval/GVHMR/tools/video/merge_vertical.py
================================================
import argparse
from hmr4d.utils.video_io_utils import merge_videos_vertical
def parse_args():
    """Parse the command line of the vertical merge tool.

    Example: python tools/video/merge_vertical.py a.mp4 b.mp4 c.mp4 -o out.mp4

    Returns:
        Namespace with ``input_videos`` (list of one or more paths) and
        ``output`` (required output path).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("input_videos", nargs="+", help="Input video paths")
    cli.add_argument("-o", "--output", type=str, required=True, help="Output video path")
    parsed = cli.parse_args()
    return parsed
if __name__ == "__main__":
    # Script entry point: pass the parsed input paths and output path to
    # merge_videos_vertical.
    args = parse_args()
    merge_videos_vertical(args.input_videos, args.output)
================================================
FILE: eval/common_metrics_on_video_quality/.gitignore
================================================
__pycache__
================================================
FILE: eval/common_metrics_on_video_quality/README.md
================================================
# common_metrics_on_video_quality
You can easily calculate the following video quality metrics:
- **FVD**: Fréchet Video Distance
- **SSIM**: structural similarity index measure
- **LPIPS**: learned perceptual image patch similarity
- **PSNR**: peak-signal-to-noise ratio
As for FVD
1. The codebase is based on [MCVD](https://github.com/voletiv/mcvd-pytorch) and other websites and projects; I have only extracted the parts that are relevant to the calculation. This code can be used to evaluate FVD scores for generative or predictive models.
2. Now **we have supported 2 pytorch-based FVD implementations** ([videogpt](https://github.com/wilson1yan/VideoGPT) and [styleganv](https://github.com/universome/stylegan-v), see issue [#4](https://github.com/JunyaoHu/common_metrics_on_video_quality/issues/4)). Their calculations are almost identical, and the difference is negligible.
3. FVD calculates the feature distance between two sets of videos. (The I3D features of each video do not go through the softmax() function, and the size of the last dimension is 400, not 1024.)
And...
- This project supports grayscale and RGB videos.
- This project supports Ubuntu, but maybe something is wrong with Windows. If you can solve it, welcome any PR.
- **If the project cannot run correctly, please give me an issue or PR~**
- For more details see below Notice.
# Example
A batch of 8 videos, each with 30 frames, 3 channels, and 64x64 resolution.
```
import torch
from calculate_fvd import calculate_fvd
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim
from calculate_lpips import calculate_lpips
NUMBER_OF_VIDEOS = 8
VIDEO_LENGTH = 30
CHANNEL = 3
SIZE = 64
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
device = torch.device("cuda")
device = torch.device("cpu")
import json
result = {}
result['fvd'] = calculate_fvd(videos1, videos2, device, method='styleganv')
# result['fvd'] = calculate_fvd(videos1, videos2, device, method='videogpt')
result['ssim'] = calculate_ssim(videos1, videos2)
result['psnr'] = calculate_psnr(videos1, videos2)
result['lpips'] = calculate_lpips(videos1, videos2, device)
print(json.dumps(result, indent=4))
```
It means we calculate:
- `FVD-frames[:10]`, `FVD-frames[:11]`, ..., `FVD-frames[:30]`
- `avg-PSNR/SSIM/LPIPS-frame[0]`, `avg-PSNR/SSIM/LPIPS-frame[1]`, ..., `avg-PSNR/SSIM/LPIPS-frame[29]`, and their std.
We cannot calculate `FVD-frames[:8]`, so it is skipped during calculation; see item 6 under Notice.
The result shows: for an all-zero matrix and an all-one matrix, their FVD-30 (FVD[:30]) is 151.17 (styleganv method). We also calculate their standard deviation. Other metrics are the same. And we use the calculation method of styleganv.
```
{
"fvd": {
"value": {
"10": 570.07320378183,
"11": 486.1906542471159,
"12": 552.3373915075898,
"13": 146.6242330185728,
"14": 172.57268402948895,
"15": 133.88932632116126,
"16": 153.11023578170108,
"17": 357.56400892781204,
"18": 382.1335612721498,
"19": 306.7100176942531,
"20": 338.18221898178774,
"21": 77.95587603163293,
"22": 82.49997632357349,
"23": 64.41624523513073,
"24": 66.08097153313875,
"25": 314.4341061962642,
"26": 316.8616746151064,
"27": 288.884418528541,
"28": 287.8192683223724,
"29": 152.15076552354864,
"30": 151.16806952692093
},
"video_setting": [
8,
3,
30,
64,
64
],
"video_setting_name": "batch_size, channel, time, heigth, width"
},
"video_setting": [
8,
3,
30,
64,
64
],
"video_setting_name": "batch_size, channel, time, heigth, width"
},
"ssim": {
"value": {
"0": 9.999000099990664e-05,
...,
"29": 9.999000099990664e-05
},
"value_std": {
"0": 0.0,
...,
"29": 0.0
},
"video_setting": [
30,
3,
64,
64
],
"video_setting_name": "time, channel, heigth, width"
},
"psnr": {
"value": {
"0": 0.0,
...,
"29": 0.0
},
"value_std": {
"0": 0.0,
...,
"29": 0.0
},
"video_setting": [
30,
3,
64,
64
],
"video_setting_name": "time, channel, heigth, width"
},
"lpips": {
"value": {
"0": 0.8140146732330322,
...,
"29": 0.8140146732330322
},
"value_std": {
"0": 0.0,
...,
"29": 0.0
},
"video_setting": [
30,
3,
64,
64
],
"video_setting_name": "time, channel, heigth, width"
}
}
```
# Notice
1. You should `pip install lpips` first.
2. Make sure the pixel values of the videos are in [0, 1].
3. If you have trouble downloading the FVD pre-trained model, manually download either of the following and put it into the FVD folder.
- `i3d_torchscript.pt` from [here](https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt)
- `i3d_pretrained_400.pt` from [here](https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI)
4. For grayscale videos, we multiply to 3 channels [as it says](https://github.com/richzhang/PerceptualSimilarity/issues/23#issuecomment-492368812).
5. We average SSIM when images have 3 channels, ssim is the only metric extremely sensitive to gray being compared to b/w.
6. Because the i3d model downsamples in the time dimension, `frames_num` should be at least 10 when calculating FVD, so the FVD calculation begins from the 10th frame, as in the example above.
7. You had better use `scipy==1.7.3/1.9.3`, if you use 1.11.3, **you will calculate a WRONG FVD VALUE!!!**
8. If you are running demo.py on a multi-GPU machine, remember to export CUDA_VISIBLE_DEVICES=0, see [here](https://github.com/JunyaoHu/common_metrics_on_video_quality/issues/13).
# Star Trend
## Star History
[Star History Chart](https://star-history.com/#JunyaoHu/common_metrics_on_video_quality&Date)
================================================
FILE: eval/common_metrics_on_video_quality/calculate_clip.py
================================================
import cv2
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
import json
import os
from tqdm import tqdm
import torch
import clip
from PIL import Image
import cv2
import numpy as np
import os
import argparse
# Use the GPU when one is available. The CLIP model and its processor are
# loaded once at import time and reused by get_video_scores for every frame.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def get_video_scores(video_path, prompt):
    """Average the CLIP image-text logit between ``prompt`` and every frame.

    Args:
        video_path: path of the video file opened with ``cv2.VideoCapture``.
        prompt: text compared against each frame.

    Returns:
        Mean of ``logits_per_image`` over all frames that could be read.

    Raises:
        ValueError: if no frame could be read from ``video_path``
            (previously this surfaced as a ZeroDivisionError).
    """
    video = cv2.VideoCapture(video_path)
    texts = [prompt]
    clip_score_list = []
    try:
        # Inference only: skip autograd bookkeeping for every frame.
        with torch.no_grad():
            while True:
                ret, frame = video.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; convert to RGB before building the PIL image.
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                inputs = processor(
                    text=texts, images=[image], return_tensors="pt", padding=True, truncation=True
                ).to(device)
                logits_per_image = model(**inputs).logits_per_image
                clip_score_list.append(logits_per_image.item())
    finally:
        # Release the capture even if preprocessing or the model call fails.
        video.release()

    if not clip_score_list:
        raise ValueError(f"No frames could be read from video: {video_path}")
    return sum(clip_score_list) / len(clip_score_list)
# Command line: -v_f/--videos_folder is a folder of sub-folders holding .mp4
# files; each sub-folder name is the key of its prompt in the prompts JSON.
parser = argparse.ArgumentParser()
parser.add_argument("-v_f", "--videos_folder", type=str)
parser.add_argument(
    "-p",
    "--prompts_path",
    type=str,
    # The default keeps the previously hard-coded location for backward compatibility.
    default='/ytech_m2v2_hdd/fuxiao/scenectrl/common_metrics_on_video_quality/eval_prompts.json',
    help="JSON file mapping each sub-folder name to its text prompt",
)
args = parser.parse_args()
videos_folder_path = args.videos_folder
prompts_path = args.prompts_path
with open(prompts_path, "r", encoding="utf-8") as f:
    prompts_dict = json.load(f)

sub_folders = os.listdir(videos_folder_path)
videos_name = []
for sub_folder in sub_folders:
    sub_folder_path = os.path.join(videos_folder_path, sub_folder)
    if not os.path.isdir(sub_folder_path):
        # Stray files next to the sub-folders would make os.listdir fail.
        continue
    for file in os.listdir(sub_folder_path):
        if file.endswith('.mp4'):
            video_name = os.path.join(sub_folder, file)
            videos_name.append(video_name)
num_videos = len(videos_name)

prompts = []
video_paths = []
for video_name in videos_name:
    # The prompt key is the sub-folder part of the relative path; os.path is
    # used so the split matches the separator os.path.join inserted.
    prompt = prompts_dict[os.path.dirname(video_name)]
    video_path = os.path.join(videos_folder_path, video_name)
    prompts.append(prompt)
    video_paths.append(video_path)

import csv
CLIP_T = True
if CLIP_T:
    scores = []
    for i in tqdm(range(num_videos)):
        # Load the video
        video_path = video_paths[i]
        # Prepare the text
        texts = prompts[i]
        score = get_video_scores(video_path, texts)
        scores.append(score)
    if scores:
        print(f"CLIP-SIM: {sum(scores)/len(scores)/100.}")
    else:
        # Nothing to average; avoid a ZeroDivisionError.
        print(f"No .mp4 files found under {videos_folder_path}")
#### CLIP-T ####
# basemodel: 33.44
================================================
FILE: eval/common_metrics_on_video_quality/calculate_fvd.py
================================================
import numpy as np
import torch
from tqdm import tqdm
def trans(x):
    """Prepare a 5-D video batch for the I3D feature extractor.

    Args:
        x: tensor laid out as (batch, time, channel, height, width).

    Returns:
        Tensor laid out as (batch, channel, time, height, width); a
        single-channel input is first repeated to three channels.
    """
    is_grayscale = x.shape[-3] == 1
    if is_grayscale:
        # Tile the single channel three times along the channel axis.
        x = x.repeat(1, 1, 3, 1, 1)
    # BTCHW -> BCTHW
    return x.permute(0, 2, 1, 3, 4)
def calculate_fvd(videos1, videos2, device, method='styleganv'):
    """Compute the FVD between two equally shaped batches of videos.

    Args:
        videos1: tensor of shape (batch, time, channel, height, width).
        videos2: tensor with the same shape as ``videos1``.
        device: device handed to the I3D loader and feature extractor.
        method: ``'styleganv'`` (default) or ``'videogpt'``; selects which
            FVD implementation is imported.

    Returns:
        Dict with ``"value"`` (result of ``frechet_distance``),
        ``"video_setting"`` (shape after conversion to BCTHW) and
        ``"video_setting_name"``.

    Raises:
        ValueError: if ``method`` is not a supported implementation.
        AssertionError: if the two batches differ in shape.
    """
    # Reject unknown methods up front; otherwise the helper names below would
    # be unbound and fail later with an unrelated error.
    if method == 'styleganv':
        from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
    elif method == 'videogpt':
        from fvd.videogpt.fvd import load_i3d_pretrained
        from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
        from fvd.videogpt.fvd import frechet_distance
    else:
        raise ValueError(f"Unknown FVD method: {method!r} (expected 'styleganv' or 'videogpt')")

    print("calculate_fvd...")

    # videos [batch_size, timestamps, channel, h, w]
    assert videos1.shape == videos2.shape

    i3d = load_i3d_pretrained(device=device)

    # support grayscale input, if grayscale -> channel*3
    # BTCHW -> BCTHW
    videos1 = trans(videos1)
    videos2 = trans(videos2)

    # The whole clip is used; FVD needs at least 10 timestamps per clip.
    feats1 = get_fvd_feats(videos1, i3d=i3d, device=device)
    feats2 = get_fvd_feats(videos2, i3d=i3d, device=device)

    fvd_value = frechet_distance(feats1, feats2)

    result = {
        "value": fvd_value,
        "video_setting": videos1.shape,
        "video_setting_name": "batch_size, channel, time, heigth, width",
    }
    return result
# test code / using example
def main():
    """Usage example: FVD between an all-zeros and an all-ones video batch,
    computed with both supported implementations on the CUDA device."""
    import json

    batch_shape = (8, 50, 3, 64, 64)  # videos, frames, channels, height, width
    videos1 = torch.zeros(*batch_shape, requires_grad=False)
    videos2 = torch.ones(*batch_shape, requires_grad=False)
    device = torch.device("cuda")
    # device = torch.device("cpu")

    for fvd_method in ('videogpt', 'styleganv'):
        result = calculate_fvd(videos1, videos2, device, method=fvd_method)
        print(json.dumps(result, indent=4))


if __name__ == "__main__":
    main()
================================================
FILE: eval/common_metrics_on_video_quality/calculate_fvd_styleganv.py
================================================
import torch
from calculate_fvd import calculate_fvd
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim
from calculate_lpips import calculate_lpips
import argparse
import os
import cv2
import decord
import numpy as np
import tqdm
import glob
import copy
# Restrict the process to the first GPU.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# Command line: two folders with identical <sub_folder>/<name>.mp4 layouts.
parser = argparse.ArgumentParser()
parser.add_argument("-v1_f", "--videos1_folder", type=str)
parser.add_argument("-v2_f", "--videos2_folder", type=str)
args = parser.parse_args()
videos1_folder_path = args.videos1_folder
videos2_folder_path = args.videos2_folder

# Collect the relative path of every .mp4 found one level below videos1_folder.
sub_folders = os.listdir(videos1_folder_path)
videos_name = []
for sub_folder in sub_folders:
    sub_folder_path = os.path.join(videos1_folder_path, sub_folder)
    if not os.path.isdir(sub_folder_path):
        # Stray files next to the sub-folders would make os.listdir fail.
        continue
    for file in os.listdir(sub_folder_path):
        if file.endswith('.mp4'):
            video_name = os.path.join(sub_folder, file)
            videos_name.append(video_name)

# Decoded frames are also dumped as PNGs next to the second folder.
# os.path.join keeps the dump directories relative when dirname() is empty
# (a bare folder name); string formatting would have produced "/eval_1".
base_dir = os.path.dirname(videos2_folder_path)
base_name = os.path.basename(videos2_folder_path)
eval_dir_1 = os.path.join(base_dir, 'eval_1')
eval_dir_2 = os.path.join(base_dir, 'eval_2')
os.makedirs(eval_dir_1, exist_ok=True)
os.makedirs(eval_dir_2, exist_ok=True)

# ps: pixel value should be in [0, 1]!
NUMBER_OF_VIDEOS = len(videos_name)
if NUMBER_OF_VIDEOS == 0:
    raise SystemExit(f"No .mp4 files found under {videos1_folder_path}")
VIDEO_LENGTH = 77
CHANNEL = 3
H_SIZE = 384
W_SIZE = 672
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, H_SIZE, W_SIZE, requires_grad=False)
videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, H_SIZE, W_SIZE, requires_grad=False)


def _load_video_frames(video_frames_path, dump_dir, video_idx):
    """Decode the first VIDEO_LENGTH frames of one video.

    Each frame is also written to ``dump_dir`` as
    ``<video_idx>_<frame_idx>.png``. Returns an array laid out as
    (time, channel, height, width) with values divided by 255.
    """
    # OpenCV is only used to query the native frame size; release the handle
    # as soon as the size is known.
    cap = cv2.VideoCapture(video_frames_path)
    try:
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    finally:
        cap.release()
    ctx = decord.cpu(0)
    reader = decord.VideoReader(video_frames_path, ctx=ctx, height=height, width=width)
    frame_indexes = list(range(VIDEO_LENGTH))
    batch = reader.get_batch(frame_indexes)
    # The batch exposes either asnumpy() or numpy(); presumably this depends
    # on the active decord bridge — TODO confirm.
    video_chunk = batch.asnumpy() if hasattr(batch, 'asnumpy') else batch.numpy()
    for frame_idx in range(VIDEO_LENGTH):
        # Reverse the channel order for cv2.imwrite.
        cv2.imwrite(
            os.path.join(dump_dir, f'{video_idx:03d}_{frame_idx:02d}.png'),
            video_chunk[frame_idx][:,:,::-1],
        )
    return video_chunk.transpose(0,3,1,2)/255.


for video_idx, video_name in tqdm.tqdm(enumerate(videos_name)):
    print(video_name)
    chunk1 = _load_video_frames(os.path.join(videos1_folder_path, video_name), eval_dir_1, video_idx)
    videos1[video_idx] = torch.from_numpy(chunk1)
    chunk2 = _load_video_frames(os.path.join(videos2_folder_path, video_name), eval_dir_2, video_idx)
    videos2[video_idx] = torch.from_numpy(chunk2)

if NUMBER_OF_VIDEOS == 1:
    # Duplicate the single video so both batches contain two samples.
    videos1 = videos1.repeat(2,1,1,1,1)
    videos2 = videos2.repeat(2,1,1,1,1)
print('load videos done')

device = torch.device("cuda")
import json
result = {}
result['fvd_styleganv'] = calculate_fvd(videos1, videos2, device, method='styleganv')
# result['fvd_videogpt'] = calculate_fvd(videos1, videos2, device, method='videogpt')
fvd_value = result['fvd_styleganv']['value']
print(f'FVD: {fvd_value}')
================================================
FILE: eval/common_metrics_on_video_quality/calculate_lpips.py
================================================
import numpy as np
import torch
from tqdm import tqdm
import math
import torch
import lpips
spatial = True # Return a spatial map of perceptual distance.
# Linearly calibrated models (LPIPS)
# The metric network is built once at import time and shared by calculate_lpips,
# which moves it to the requested device.
loss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
def trans(x):
    """Prepare a 5-D video batch for LPIPS.

    Args:
        x: tensor laid out as (batch, time, channel, height, width) with
            values expected in [0, 1].

    Returns:
        Tensor rescaled with ``x * 2 - 1``; a single-channel input is first
        repeated to three channels.
    """
    is_grayscale = x.shape[-3] == 1
    if is_grayscale:
        # Tile the single channel three times along the channel axis.
        x = x.repeat(1, 1, 3, 1, 1)
    # value range [0, 1] -> [-1, 1]
    return x * 2 - 1
def calculate_lpips(videos1, videos2, device):
    """Compute per-frame LPIPS statistics between two video batches.

    Args:
        videos1: tensor of shape (batch, time, channel, height, width) with
            values in [0, 1].
        videos2: tensor with the same shape as ``videos1``.
        device: device on which the LPIPS network is evaluated.

    Returns:
        Dict with ``"value"`` and ``"value_std"`` (both keyed by frame index:
        mean / std over the batch), ``"video_setting"`` (shape of one video)
        and ``"video_setting_name"``.

    Raises:
        AssertionError: if the two batches differ in shape.
    """
    # image should be RGB, IMPORTANT: normalized to [-1,1]
    print("calculate_lpips...")

    assert videos1.shape == videos2.shape

    # videos [batch_size, timestamps, channel, h, w]
    # support grayscale input, if grayscale -> channel*3
    # value range [0, 1] -> [-1, 1]
    videos1 = trans(videos1)
    videos2 = trans(videos2)

    # Move the network once instead of once per frame.
    loss_fn.to(device)

    lpips_results = []
    # Evaluation only: no autograd graph is needed for the metric.
    with torch.no_grad():
        for video_num in tqdm(range(videos1.shape[0])):
            # video [timestamps, channel, h, w]
            video1 = videos1[video_num]
            video2 = videos2[video_num]

            lpips_results_of_a_video = []
            for clip_timestamp in range(len(video1)):
                # img [channel, h, w] -> [1, channel, h, w] on the target device
                img1 = video1[clip_timestamp].unsqueeze(0).to(device)
                img2 = video2[clip_timestamp].unsqueeze(0).to(device)
                # The spatial distance map is reduced to a single number per frame.
                lpips_results_of_a_video.append(
                    loss_fn.forward(img1, img2).mean().detach().cpu().tolist()
                )
            lpips_results.append(lpips_results_of_a_video)

    lpips_results = np.array(lpips_results)

    # Local names avoid shadowing the imported ``lpips`` module.
    num_frames = videos1.shape[1]
    lpips_mean = {}
    lpips_std = {}
    for clip_timestamp in range(num_frames):
        lpips_mean[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
        lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])

    result = {
        "value": lpips_mean,
        "value_std": lpips_std,
        "video_setting": videos1.shape[1:],
        "video_setting_name": "time, channel, heigth, width",
    }
    return result
# test code / using example
def main():
    """Usage example: LPIPS between an all-zeros and an all-ones video batch
    on the CUDA device."""
    import json

    batch_shape = (8, 50, 3, 64, 64)  # videos, frames, channels, height, width
    videos1 = torch.zeros(*batch_shape, requires_grad=False)
    videos2 = torch.ones(*batch_shape, requires_grad=False)
    device = torch.device("cuda")
    # device = torch.device("cpu")

    result = calculate_lpips(videos1, videos2, device)
    print(json.dumps(result, indent=4))


if __name__ == "__main__":
    main()
================================================
FILE: eval/common_metrics_on_video_quality/calculate_psnr.py
================================================
import numpy as np
import torch
from tqdm import tqdm
import math
def img_psnr(img1, img2):
    """Peak signal-to-noise ratio between two images with values in [0, 1].

    Args:
        img1: numpy array.
        img2: numpy array broadcast-compatible with ``img1``.

    Returns:
        ``100`` when the mean squared error is below 1e-10, otherwise
        ``20 * log10(1 / sqrt(mse))``.
    """
    # Dividing by 1.0 promotes integer inputs to float before subtracting.
    difference = img1 / 1.0 - img2 / 1.0
    mse = np.mean(difference ** 2)
    if mse < 1e-10:
        # Treat (near-)identical images as a fixed maximum score.
        return 100
    return 20 * math.log10(1 / math.sqrt(mse))
def trans(x):
    """Return ``x`` unchanged; PSNR applies no preprocessing to the videos."""
    return x
def calculate_psnr(videos1, videos2):
    """Compute per-frame PSNR statistics between two video batches.

    Args:
        videos1: tensor of shape (batch, time, channel, height, width).
        videos2: tensor with the same shape as ``videos1``.

    Returns:
        Dict with ``"value"`` and ``"value_std"`` (both keyed by frame index:
        mean / std over the batch), ``"video_setting"`` (shape of one video)
        and ``"video_setting_name"``.

    Raises:
        AssertionError: if the two batches differ in shape.
    """
    print("calculate_psnr...")

    # videos [batch_size, timestamps, channel, h, w]
    assert videos1.shape == videos2.shape

    videos1 = trans(videos1)
    videos2 = trans(videos2)

    per_video_scores = []
    for video_num in tqdm(range(videos1.shape[0])):
        # video [timestamps, channel, h, w]
        video1 = videos1[video_num]
        video2 = videos2[video_num]
        # One PSNR value per frame; frames are converted to numpy [channel, h, w].
        per_video_scores.append([
            img_psnr(video1[frame_idx].numpy(), video2[frame_idx].numpy())
            for frame_idx in range(len(video1))
        ])

    psnr_results = np.array(per_video_scores)

    psnr = {}
    psnr_std = {}
    # Aggregate over the batch axis, frame by frame.
    for frame_idx in range(len(video1)):
        frame_column = psnr_results[:, frame_idx]
        psnr[frame_idx] = np.mean(frame_column)
        psnr_std[frame_idx] = np.std(frame_column)

    return {
        "value": psnr,
        "value_std": psnr_std,
        "video_setting": video1.shape,
        "video_setting_name": "time, channel, heigth, width",
    }
# test code / using example
def main():
    """Usage example: PSNR between two identical all-zeros video batches."""
    import json

    batch_shape = (8, 50, 3, 64, 64)  # videos, frames, channels, height, width
    videos1 = torch.zeros(*batch_shape, requires_grad=False)
    videos2 = torch.zeros(*batch_shape, requires_grad=False)

    result = calculate_psnr(videos1, videos2)
    print(json.dumps(result, indent=4))


if __name__ == "__main__":
    main()
================================================
FILE: eval/common_metrics_on_video_quality/calculate_ssim.py
================================================
import numpy as np
import torch
from tqdm import tqdm
import cv2
def ssim(img1, img2):
    """Mean SSIM of two single-channel images.

    Uses an 11x11 Gaussian window (sigma 1.5) and discards a 5-pixel border
    so only fully covered positions contribute.

    Args:
        img1: 2-D numpy array.
        img2: 2-D numpy array of the same shape.

    Returns:
        Mean of the SSIM map.
    """
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    def _windowed_mean(image):
        # Gaussian-weighted local mean, cropped to the "valid" region.
        return cv2.filter2D(image, -1, window)[5:-5, 5:-5]

    mu1 = _windowed_mean(img1)
    mu2 = _windowed_mean(img2)
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = _windowed_mean(img1 ** 2) - mu1_sq
    sigma2_sq = _windowed_mean(img2 ** 2) - mu2_sq
    sigma12 = _windowed_mean(img1 * img2) - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator
    return ssim_map.mean()
def calculate_ssim_function(img1, img2):
    """SSIM of two images with values in [0, 1].

    Supported layouts are (h, w), (1, h, w) and (3, h, w); for three channels
    the per-channel SSIM values are averaged.

    Args:
        img1: numpy array.
        img2: numpy array with the same shape as ``img1``.

    Returns:
        The SSIM value.

    Raises:
        ValueError: if the shapes differ, or the layout is not one of the
            supported ones (such inputs could previously fall through and
            return ``None``).
    """
    # ssim is the only metric extremely sensitive to gray being compared to b/w
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        channels = img1.shape[0]
        if channels == 3:
            return np.array([ssim(img1[i], img2[i]) for i in range(3)]).mean()
        if channels == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
        raise ValueError(
            f'Wrong input image dimensions: expected 1 or 3 channels, got {channels}.'
        )
    raise ValueError('Wrong input image dimensions.')
def trans(x):
    """Return ``x`` unchanged; SSIM applies no preprocessing to the videos."""
    return x
def calculate_ssim(videos1, videos2):
    """Compute per-frame SSIM statistics between two video batches.

    Args:
        videos1: tensor of shape (batch, time, channel, height, width).
        videos2: tensor with the same shape as ``videos1``.

    Returns:
        Dict with ``"value"`` and ``"value_std"`` (both keyed by frame index:
        mean / std over the batch), ``"video_setting"`` (shape of one video)
        and ``"video_setting_name"``.

    Raises:
        AssertionError: if the two batches differ in shape.
    """
    print("calculate_ssim...")

    # videos [batch_size, timestamps, channel, h, w]
    assert videos1.shape == videos2.shape

    videos1 = trans(videos1)
    videos2 = trans(videos2)

    per_video_scores = []
    for video_num in tqdm(range(videos1.shape[0])):
        # video [timestamps, channel, h, w]
        video1 = videos1[video_num]
        video2 = videos2[video_num]
        # One SSIM value per frame; frames are converted to numpy [channel, h, w].
        per_video_scores.append([
            calculate_ssim_function(video1[frame_idx].numpy(), video2[frame_idx].numpy())
            for frame_idx in range(len(video1))
        ])

    ssim_results = np.array(per_video_scores)

    mean_by_frame = {}
    std_by_frame = {}
    # Aggregate over the batch axis, frame by frame.
    for frame_idx in range(len(video1)):
        frame_column = ssim_results[:, frame_idx]
        mean_by_frame[frame_idx] = np.mean(frame_column)
        std_by_frame[frame_idx] = np.std(frame_column)

    return {
        "value": mean_by_frame,
        "value_std": std_by_frame,
        "video_setting": video1.shape,
        "video_setting_name": "time, channel, heigth, width",
    }
# test code / using example

def main():
    """Run ``calculate_ssim`` on two all-zero video batches and print the result as JSON."""
    import json

    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64

    # Both batches are CPU tensors of shape
    # [NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE]; no device handle is needed.
    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)

    result = calculate_ssim(videos1, videos2)
    print(json.dumps(result, indent=4))


if __name__ == "__main__":
    main()
================================================
FILE: eval/common_metrics_on_video_quality/download_eval_visual.sh
================================================
# Download and unpack the archives used by the visual-quality evaluation.
# Stop at the first failing command so a failed download is not followed by
# unzip/mv/rm running against missing files.
set -e

gdown "https://drive.google.com/uc?id=1U2hd6qvwKLfp7c8yGgcTqdqrP_lKJElB"
gdown "https://drive.google.com/uc?id=1jMH2-ZC0ZBgtqej5Sp-E5ebBIX7mk3Xz"
gdown "https://drive.google.com/uc?id=1kfdCDA5koYh9g3IkCCHb4XPch2CJAwek"

unzip fvd.zip
unzip eval_sets.zip
unzip base_t2v_eval_sets.zip

# "mv <dir> eval_folder/" fails when eval_folder does not exist yet.
mkdir -p eval_folder
mv eval_sets eval_folder/
mv base_t2v_eval_sets eval_folder/

# Remove only the archives unpacked above, not every *.zip in the working directory.
rm -f fvd.zip eval_sets.zip base_t2v_eval_sets.zip
================================================
FILE: eval/common_metrics_on_video_quality/eval_prompts.json
================================================
{
"0_D_loc5_530_t1n2_0212_Hemi12_1": "a man with short spiky brown hair, athletic build, a navy blue jacket, beige cargo pants, and black sneakers is moving in the fjord",
"1_D_loc1_31_t1n6_001f_Hemi12_1": "a woman with long wavy blonde hair, petite figure, a red floral dress, white sandals, and a yellow shoulder bag is moving in the sunset beach",
"2_D_loc1_296_t1n8_0128_Hemi12_1": "a man with a shaved head, broad shoulders, a gray graphic t-shirt, dark jeans, and brown leather boots is moving in the cave",
"3_D_loc5_1100_t1n9_044c_Hemi12_1": "a woman with shoulder-length straight auburn hair, a slender figure, a green button-up blouse, black leggings, and white sneakers is moving in the snowy tundra",
"4_D_loc5_590_t1n11_024e_Hemi12_1": "a man with messy black hair, tall frame, a plaid red and black shirt, faded blue jeans, and tan hiking boots is moving in the prairie",
"5_D_loc5_1125_t1n13_0465_Hemi12_1": "a man with medium-length straight brown hair, tall and slender, a gray crew-neck t-shirt, beige trousers, and dark green sneakers is moving in the asian town",
"6_D_loc3_1158_t1n15_0486_Hemi12_1": "a woman with short curly black hair, slender build, a pink hoodie, light gray joggers, and blue sneakers is moving in the rainforest",
"7_D_loc4_1469_t1n17_05bd_Hemi12_1": "a man with short black wavy hair, lean figure, a green and yellow plaid shirt, dark brown pants, and black suede shoes is moving in the canyon",
"8_D_loc1_351_t1n18_015f_Hemi12_1": "a man with curly black hair, muscular build, a dark green hoodie, gray joggers, and white running shoes is moving in the savanna",
"9_D_loc5_1475_t1n31_05c3_Hemi12_1": "a woman with short blonde hair, slim athletic build, a red leather jacket, dark blue jeans, and white sneakers is moving in the urban rooftop garden",
"10_D_loc5_730_t1n34_02da_Hemi12_1": "a man with medium-length wavy brown hair, lean build, a black bomber jacket, olive green cargo pants, and brown hiking boots is moving in the swamp",
"11_D_loc3_1028_t1n38_0404_Hemi12_1": "a man with buzz-cut blonde hair, stocky build, a gray zip-up sweater, black shorts, and red basketball shoes is moving in the riverbank",
"12_D_loc1_01_t2n1_0001_Hemi12_1": "a man with short spiky brown hair, athletic build, a navy blue jacket, beige cargo pants, and black sneakers and a dog with a fluffy coat, wagging tail, and warm golden-brown fur, exuding a gentle and friendly charm are moving in the fjord",
"13_D_loc1_136_t2n2_0088_Hemi12_1": "a woman with long wavy blonde hair, petite figure, a red floral dress, white sandals, and a yellow shoulder bag and a tiger with vibrant orange and black stripes, piercing yellow eyes, and a powerful stance, exuding strength and grace are moving in the sunset beach",
"14_D_loc2_137_t2n3_0089_Hemi12_1": "a man with a shaved head, broad shoulders, a gray graphic t-shirt, dark jeans, and brown leather boots and a giraffe with golden-yellow fur, long legs, a tall slender neck, and patches of brown spots, exuding elegance and calm are moving in the cave",
"15_D_loc1_141_t2n6_008d_Hemi12_1": "a woman with shoulder-length straight auburn hair, a slender figure, a green button-up blouse, black leggings, and white sneakers and an alpaca with soft white wool, short legs, a thick neck, and a fluffy head of fur, radiating gentle charm are moving in the snowy tundra",
"16_D_loc2_142_t2n7_008e_Hemi12_1": "a man with messy black hair, tall frame, a plaid red and black shirt, faded blue jeans, and tan hiking boots and a zebra with black and white stripes, sturdy legs, a short neck, and a sleek mane running down its back are moving in the prairie",
"17_D_loc1_506_t2n9_01fa_Hemi12_1": "a man with medium-length straight brown hair, tall and slender, a gray crew-neck t-shirt, beige trousers, and dark green sneakers and an alpaca with soft white wool, short legs, a thick neck, and a fluffy head of fur, radiating gentle charm are moving in the asian town",
"18_D_loc1_11_t2n10_000b_Hemi12_1": "a woman with short curly black hair, slender build, a pink hoodie, light gray joggers, and blue sneakers and a gazelle with light golden fur, long slender legs, a thin neck, and short, sharp horns, embodying elegance and agility are moving in the rainforest",
"19_D_loc1_81_t2n11_0051_Hemi12_1": "a man with short black wavy hair, lean figure, a green and yellow plaid shirt, dark brown pants, and black suede shoes and a horse with chestnut brown fur, muscular legs, a slim neck, and a flowing mane, exuding strength and grace are moving in the canyon",
"20_D_loc1_16_t2n13_0010_Hemi12_1": "a man with curly black hair, muscular build, a dark green hoodie, gray joggers, and white running shoes and a sleek black panther with a smooth, glossy coat, emerald green eyes, and a powerful stance are moving in the savanna",
"21_D_loc1_86_t2n14_0056_Hemi12_1": "a woman with short blonde hair, slim athletic build, a red leather jacket, dark blue jeans, and white sneakers and a cheetah with golden fur covered in black spots, intense amber eyes, and a slender, agile body are moving in the urban rooftop garden",
"22_D_loc1_1446_t2n15_05a6_Hemi12_1": "a man with medium-length wavy brown hair, lean build, a black bomber jacket, olive green cargo pants, and brown hiking boots and a regal lion with a thick, flowing golden mane, sharp brown eyes, and a powerful muscular frame are moving in the swamp",
"23_D_loc2_92_t2n17_005c_Hemi12_1": "a man with buzz-cut blonde hair, stocky build, a gray zip-up sweater, black shorts, and red basketball shoes and a snow leopard with pale gray fur adorned with dark rosettes, icy blue eyes, and a stealthy, poised posture are moving in the riverbank",
"24_D_loc1_156_t2n19_009c_Hemi12_1": "a woman with long straight black hair, toned build, a blue denim jacket, light gray leggings, and black slip-on shoes and a jaguar with a golden-yellow coat dotted with intricate black rosettes, deep green eyes, and a muscular build are moving in the coral reef",
"25_D_loc1_101_t2n21_0065_Hemi12_1": "a man with short curly red hair, average build, a black leather jacket, dark blue cargo pants, and white sneakers and a wolf with thick silver-gray fur, alert golden eyes, and a lean yet strong body, exuding confidence and boldness are moving in the volcanic landscape",
"26_D_loc1_161_t2n22_00a1_Hemi12_1": "a woman with shoulder-length wavy brown hair, slim build, a green parka, black leggings, and gray hiking boots and a tiger with a pristine white coat marked by bold black stripes, bright blue eyes, and a graceful, poised form are moving in the wind farm",
"27_D_loc1_36_t2n25_0024_Hemi12_1": "a man with short straight black hair, tall and lean build, a navy blue sweater, khaki shorts, and brown sandals and a lynx with tufted ears, soft reddish-brown fur with faint spots, and intense yellow-green eyes are moving in the town street",
"28_D_loc1_166_t2n26_00a6_Hemi12_1": "a woman with pixie-cut blonde hair, athletic build, a red windbreaker, blue ripped jeans, and black combat boots and a bear with dark brown fur, small but fierce black eyes, and a broad and muscular build, radiating power are moving in the night city square",
"29_D_loc1_221_t2n29_00dd_Hemi12_1": "a man with medium-length wavy gray hair, muscular build, a maroon t-shirt, beige chinos, and brown loafers and a swift fox with reddish-orange fur, a bushy tail tipped with white, and sharp, intelligent amber eyes are moving in the mall lobby",
"30_D_loc1_171_t2n30_00ab_Hemi12_1": "a woman with long curly black hair, average build, a purple hoodie, black athletic shorts, and white running shoes and a bear with dark brown fur, small but fierce black eyes, and a broad and muscular build, radiating power are moving in the glacier",
"31_D_loc1_121_t2n33_0079_Hemi12_1": "a man with short spiky blonde hair, slim build, a black trench coat, blue jeans, and brown hiking shoes and a fox with sleek russet fur, a bushy tail tipped with black, and bright green and cunning eyes are moving in the seaside street",
"32_D_loc1_176_t2n34_00b0_Hemi12_1": "a man with short spiky brown hair, athletic build, a navy blue jacket, beige cargo pants, and black sneakers and a kangaroo with brown fur, powerful hind legs, and a muscular tail, showcasing its strength and agility are moving in the gymnastics room",
"33_D_loc1_126_t2n36_007e_Hemi12_1": "a woman with long wavy blonde hair, petite figure, a red floral dress, white sandals, and a yellow shoulder bag and a polar bear with thick white fur, strong paws, and a black nose, embodying the essence of the Arctic are moving in the abandoned factory",
"34_D_loc1_56_t2n38_0038_Hemi12_1": "a man with a shaved head, broad shoulders, a gray graphic t-shirt, dark jeans, and brown leather boots and a cheetah with a slender build, spotted golden fur, and sharp eyes, epitomizing speed and agility are moving in the autumn forest",
"35_D_loc1_131_t2n41_0083_Hemi12_1": "a woman with shoulder-length straight auburn hair, a slender figure, a green button-up blouse, black leggings, and white sneakers and a dolphin with sleek grey skin, a curved dorsal fin, and intelligent, playful eyes, reflecting its nature are moving in the mountain village",
"36_D_loc1_01_t2n1_0001_Hemi12_1": "a man with messy black hair, tall frame, a plaid red and black shirt, faded blue jeans, and tan hiking boots and a wolf with a body covered in thick silver fur, sharp ears, and piercing yellow eyes, showcasing its alertness are moving in the coastal harbor",
"37_D_loc1_136_t2n2_0088_Hemi12_1": "a man with medium-length straight brown hair, tall and slender, a gray crew-neck t-shirt, beige trousers, and dark green sneakers and a leopard with a body covered in golden fur, dark rosettes, and a long muscular tail, emphasizing its strength are moving in the ancient ruins",
"38_D_loc2_137_t2n3_0089_Hemi12_1": "a woman with short curly black hair, slender build, a pink hoodie, light gray joggers, and blue sneakers and a penguin with a body covered in smooth black-and-white feathers, short wings, and webbed feet are moving in the modern metropolis",
"39_D_loc1_141_t2n6_008d_Hemi12_1": "a man with short black wavy hair, lean figure, a green and yellow plaid shirt, dark brown pants, and black suede shoes and a gazelle with a body covered in sleek tan fur, long legs, and elegant curved horns, showcasing its grace are moving in the desert",
"40_D_loc2_142_t2n7_008e_Hemi12_1": "a man with curly black hair, muscular build, a dark green hoodie, gray joggers, and white running shoes and a rabbit with a body covered in soft fur, quick hops, and a playful demeanor, showcasing its energy are moving in the forest",
"41_D_loc1_506_t2n9_01fa_Hemi12_1": "a woman with short curly black hair, slender build, a pink hoodie, light gray joggers, and blue sneakers and a koala with a body covered in soft grey fur, large round ears, and a black nose, radiating cuteness are moving in the city",
"42_D_loc1_11_t2n10_000b_Hemi12_1": "a man with medium-length wavy brown hair, lean build, a black bomber jacket, olive green cargo pants, and brown hiking boots and a rhinoceros with a body covered in thick grey skin, a massive horn on its snout, and sturdy legs are moving in the snowy street",
"43_D_loc1_81_t2n11_0051_Hemi12_1": "a man with buzz-cut blonde hair, stocky build, a gray zip-up sweater, black shorts, and red basketball shoes and a flamingo with a body covered in pink feathers, long slender legs, and a gracefully curved neck are moving in the park",
"44_D_loc1_16_t2n13_0010_Hemi12_1": "a woman with long straight black hair, toned build, a blue denim jacket, light gray leggings, and black slip-on shoes and a parrot with bright red, blue, and yellow feathers, a curved beak, and sharp intelligent eyes are moving in the fjord",
"45_D_loc1_86_t2n14_0056_Hemi12_1": "a man with short curly red hair, average build, a black leather jacket, dark blue cargo pants, and white sneakers and a hippopotamus with a body covered in thick grey-brown skin, massive jaws, and a large body are moving in the sunset beach",
"46_D_loc1_1446_t2n15_05a6_Hemi12_1": "a woman with shoulder-length wavy brown hair, slim build, a green parka, black leggings, and gray hiking boots and a crocodile with a body covered in scaly green skin, a powerful tail, and sharp teeth are moving in the cave",
"47_D_loc2_92_t2n17_005c_Hemi12_1": "a man with short straight black hair, tall and lean build, a navy blue sweater, khaki shorts, and brown sandals and a moose with a body covered in thick brown fur, massive antlers, and a bulky frame are moving in the snowy tundra",
"48_D_loc1_156_t2n19_009c_Hemi12_1": "a woman with pixie-cut blonde hair, athletic build, a red windbreaker, blue ripped jeans, and black combat boots and a fluttering butterfly with intricate wing patterns, vivid colors, and graceful flight are moving in the prairie",
"49_D_loc1_101_t2n21_0065_Hemi12_1": "a man with medium-length wavy gray hair, muscular build, a maroon t-shirt, beige chinos, and brown loafers and a chameleon with a body covered in vibrant green scales, bulging eyes, and a curled tail, showcasing its unique charm are moving in the asian town",
"50_D_loc1_161_t2n22_00a1_Hemi12_1": "a woman with long curly black hair, average build, a purple hoodie, black athletic shorts, and white running shoes and a lemur with a body covered in soft grey fur, a ringed tail, and wide yellow eyes, and curious expression are moving in the rainforest",
"51_D_loc1_36_t2n25_0024_Hemi12_1": "a man with short spiky blonde hair, slim build, a black trench coat, blue jeans, and brown hiking shoes and a squirrel with a body covered in bushy red fur, large eyes, and a fluffy tail are moving in the canyon",
"52_D_loc1_166_t2n26_00a6_Hemi12_1": "a man with short spiky brown hair, athletic build, a navy blue jacket, beige cargo pants, and black sneakers and a panda with a body covered in fluffy black-and-white fur, a round face, and gentle eyes, radiating warmth are moving in the savanna",
"53_D_loc1_221_t2n29_00dd_Hemi12_1": "a woman with long wavy blonde hair, petite figure, a red floral dress, white sandals, and a yellow shoulder bag and a porcupine with a body covered in spiky brown quills, a small nose, and curious eyes are moving in the urban rooftop garden",
"54_D_loc1_171_t2n30_00ab_Hemi12_1": "a man with a shaved head, broad shoulders, a gray graphic t-shirt, dark jeans, and brown leather boots and a sedan with a sleek metallic silver body, long wheelbase, a low-profile hood, and a small rear spoiler are moving in the swamp",
"55_D_loc1_121_t2n33_0079_Hemi12_1": "a woman with shoulder-length straight auburn hair, a slender figure, a green button-up blouse, black leggings, and white sneakers and an SUV with a matte black exterior, elevated suspension, a tall roofline, and a compact rear roof rack are moving in the riverbank",
"56_D_loc1_176_t2n34_00b0_Hemi12_1": "a man with messy black hair, tall frame, a plaid red and black shirt, faded blue jeans, and tan hiking boots and a pickup truck with rugged dark green paint, extended cab, raised suspension, and a modest cargo bed cover are moving in the coral reef",
"57_D_loc1_126_t2n36_007e_Hemi12_1": "a man with medium-length straight brown hair, tall and slender, a gray crew-neck t-shirt, beige trousers, and dark green sneakers and a vintage convertible with a body covered in shiny red paint, chrome bumpers, and a stylish design are moving in the volcanic landscape",
"58_D_loc1_56_t2n38_0038_Hemi12_1": "a woman with short curly black hair, slender build, a pink hoodie, light gray joggers, and blue sneakers and a futuristic electric car with a minimalist silver design, slim LED lights, and smooth curves are moving in the wind farm",
"59_D_loc1_131_t2n41_0083_Hemi12_1": "a man with short black wavy hair, lean figure, a green and yellow plaid shirt, dark brown pants, and black suede shoes and a compact electric vehicle with a silver finish, aerodynamic profile, and efficient battery are moving in the town street",
"60_D_loc1_01_t2n1_0001_Hemi12_1": "a man with curly black hair, muscular build, a dark green hoodie, gray joggers, and white running shoes and a firefighting robot with a water cannon arm, heat sensors, and durable red-and-silver exterior are moving in the night city square",
"61_D_loc1_136_t2n2_0088_Hemi12_1": "a woman with short blonde hair, slim athletic build, a red leather jacket, dark blue jeans, and white sneakers and an industrial welding robot with articulated arms, a laser precision welder, and heat-resistant shields are moving in the mall lobby",
"62_D_loc2_137_t2n3_0089_Hemi12_1": "a woman with short blonde hair, slim athletic build, a red leather jacket, dark blue jeans, and white sneakers and a dog with a fluffy coat, wagging tail, and warm golden-brown fur, exuding a gentle and friendly charm are moving in the glacier",
"63_D_loc1_141_t2n6_008d_Hemi12_1": "a man with buzz-cut blonde hair, stocky build, a gray zip-up sweater, black shorts, and red basketball shoes and a disaster rescue robot with reinforced limbs, advanced AI, and a rugged body designed to navigate are moving in the seaside street",
"64_D_loc2_142_t2n7_008e_Hemi12_1": "a woman with long straight black hair, toned build, a blue denim jacket, light gray leggings, and black slip-on shoes and an exploration rover robot with solar panels, durable wheels, and advanced sensors for planetary exploration are moving in the gymnastics room",
"65_D_loc1_506_t2n9_01fa_Hemi12_1": "a man with short curly red hair, average build, a black leather jacket, dark blue cargo pants, and white sneakers and a dog with a fluffy coat, wagging tail, and warm golden-brown fur, exuding a gentle and friendly charm are moving in the abandoned factory",
"66_D_loc1_11_t2n10_000b_Hemi12_1": "a woman with shoulder-length wavy brown hair, slim build, a green parka, black leggings, and gray hiking boots and a tiger with vibrant orange and black stripes, piercing yellow eyes, and a powerful stance, exuding strength and grace are moving in the autumn forest",
"67_D_loc1_81_t2n11_0051_Hemi12_1": "a man with short straight black hair, tall and lean build, a navy blue sweater, khaki shorts, and brown sandals and a giraffe with golden-yellow fur, long legs, a tall slender neck, and patches of brown spots, exuding elegance and calm are moving in the mountain village",
"68_D_loc1_16_t2n13_0010_Hemi12_1": "a woman with pixie-cut blonde hair, athletic build, a red windbreaker, blue ripped jeans, and black combat boots and an alpaca with soft white wool, short legs, a thick neck, and a fluffy head of fur, radiating gentle charm are moving in the coastal harbor",
"69_D_loc1_86_t2n14_0056_Hemi12_1": "a man with medium-length wavy gray hair, muscular build, a maroon t-shirt, beige chinos, and brown loafers and a zebra with black and white stripes, sturdy legs, a short neck, and a sleek mane running down its back are moving in the ancient ruins",
"70_D_loc1_1446_t2n15_05a6_Hemi12_1": "a woman with long curly black hair, average build, a purple hoodie, black athletic shorts, and white running shoes and a deer with sleek tan fur, long slender legs, a graceful neck, and tiny antlers atop its head are moving in the modern metropolis",
"71_D_loc2_92_t2n17_005c_Hemi12_1": "a man with short spiky blonde hair, slim build, a black trench coat, blue jeans, and brown hiking shoes and a gazelle with light golden fur, long slender legs, a thin neck, and short, sharp horns, embodying elegance and agility are moving in the desert",
"72_D_loc1_156_t2n19_009c_Hemi12_1": "a man with short spiky brown hair, athletic build, a navy blue jacket, beige cargo pants, and black sneakers and a horse with chestnut brown fur, muscular legs, a slim neck, and a flowing mane, exuding strength and grace are moving in the forest",
"73_D_loc1_101_t2n21_0065_Hemi12_1": "a man with short spiky blonde hair, slim build, a black trench coat, blue jeans, and brown hiking shoes and a sleek black panther with a smooth, glossy coat, emerald green eyes, and a powerful stance are moving in the city",
"74_D_loc1_161_t2n22_00a1_Hemi12_1": "a man with a shaved head, broad shoulders, a gray graphic t-shirt, dark jeans, and brown leather boots and a cheetah with golden fur covered in black spots, intense amber eyes, and a slender, agile body are moving in the snowy street",
"75_D_loc1_36_t2n25_0024_Hemi12_1": "a woman with shoulder-length straight auburn hair, a slender figure, a green button-up blouse, black leggings, and white sneakers and a regal lion with a thick, flowing golden mane, sharp brown eyes, and a powerful muscular frame are moving in the park",
"76_D_loc1_166_t2n26_00a6_Hemi12_1": "a man with messy black hair, tall frame, a plaid red and black shirt, faded blue jeans, and tan hiking boots and a snow leopard with pale gray fur adorned with dark rosettes, icy blue eyes, and a stealthy, poised posture are moving in the fjord",
"77_D_loc1_221_t2n29_00dd_Hemi12_1": "a man with medium-length straight brown hair, tall and slender, a gray crew-neck t-shirt, beige trousers, and dark green sneakers and a jaguar with a golden-yellow coat dotted with intricate black rosettes, deep green eyes, and a muscular build are moving in the sunset beach",
"78_D_loc1_171_t2n30_00ab_Hemi12_1": "a woman with short curly black hair, slender build, a pink hoodie, light gray joggers, and blue sneakers and a wolf with thick silver-gray fur, alert golden eyes, and a lean yet strong body, exuding confidence and boldness are moving in the cave",
"79_D_loc1_121_t2n33_0079_Hemi12_1": "a man with short black wavy hair, lean figure, a green and yellow plaid shirt, dark brown pants, and black suede shoes and a tiger with a pristine white coat marked by bold black stripes, bright blue eyes, and a graceful, poised form are moving in the snowy tundra",
"80_D_loc1_176_t2n34_00b0_Hemi12_1": "a man with curly black hair, muscular build, a dark green hoodie, gray joggers, and white running shoes and a lynx with tufted ears, soft reddish-brown fur with faint spots, and intense yellow-green eyes are moving in the prairie",
"81_D_loc1_126_t2n36_007e_Hemi12_1": "a woman with short blonde hair, slim athletic build, a red leather jacket, dark blue jeans, and white sneakers and a bear with dark brown fur, small but fierce black eyes, and a broad and muscular build, radiating power are moving in the asian town",
"82_D_loc1_56_t2n38_0038_Hemi12_1": "a man with medium-length wavy brown hair, lean build, a black bomber jacket, olive green cargo pants, and brown hiking boots and a swift fox with reddish-orange fur, a bushy tail tipped with white, and sharp, intelligent amber eyes are moving in the rainforest",
"83_D_loc1_131_t2n41_0083_Hemi12_1": "a man with buzz-cut blonde hair, stocky build, a gray zip-up sweater, black shorts, and red basketball shoes and a falcon with blue-gray feathers, sharp talons, and keen yellow eyes fixed on its prey below are moving in the canyon",
"84_D_loc1_301_t3n3_012d_Hemi12_1": "a woman with long wavy blonde hair, petite figure, a red floral dress, white sandals, and a yellow shoulder bag and a squirrel with a body covered in bushy red fur, large eyes, and a fluffy tail and a penguin with a body covered in smooth black-and-white feathers, short wings, and webbed feet are moving in the sunset beach",
"85_D_loc3_133_t3n1_0085_Hemi12_1": "a woman with short curly black hair, slender build, a pink hoodie, light gray joggers, and blue sneakers and a leopard with a body covered in golden fur, dark rosettes, and a long muscular tail, emphasizing its strength and a rhinoceros with a body covered in thick grey skin, a massive horn on its snout, and sturdy legs are moving in the rainforest",
"86_D_loc1_46_t3n8_002e_Hemi12_1": "a woman with short blonde hair, slim athletic build, a red leather jacket, dark blue jeans, and white sneakers and a parrot with bright red, blue, and yellow feathers, a curved beak, and sharp intelligent eyes and a gazelle with light golden fur, long slender legs, a thin neck, and short, sharp horns, embodying elegance and agility are moving in the urban rooftop garden",
"87_D_loc2_192_t3n9_00c0_Hemi12_1": "a woman with pixie-cut blonde hair, athletic build, a red windbreaker, blue ripped jeans, and black combat boots and a fluttering butterfly with intricate wing patterns, vivid colors, and graceful flight and a sleek black panther with a smooth, glossy coat, emerald green eyes, and a powerful stance are moving in the night city square",
"88_D_loc3_93_t3n21_005d_Hemi12_1": "a man with medium-length wavy gray hair, muscular build, a maroon t-shirt, beige chinos, and brown loafers and a cheetah with golden fur covered in black spots, intense amber eyes, and a slender, agile body and a rhinoceros with a body covered in thick grey skin, a massive horn on its snout, and sturdy legs are moving in the mall lobby",
"89_D_loc1_301_t3n3_012d_Hemi12_1": "a man with short spiky blonde hair, slim build, a black trench coat, blue jeans, and brown hiking shoes and a fluttering butterfly with intricate wing patterns, vivid colors, and graceful flight and a dolphin with sleek grey skin, a curved dorsal fin, and intelligent, playful eyes, reflecting its nature are moving in the seaside street",
"90_D_loc1_46_t3n8_002e_Hemi12_1": "a woman with long wavy blonde hair, petite figure, a red floral dress, white sandals, and a yellow shoulder bag and a crocodile with a body covered in scaly green skin, a powerful tail, and sharp teeth and a dolphin with sleek grey skin, a curved dorsal fin, and intelligent, playful eyes, reflecting its nature are moving in the abandoned factory",
"91_D_loc3_93_t3n21_005d_Hemi12_1": "a man with short straight black hair, tall and lean build, a navy blue sweater, khaki shorts, and brown sandals and a koala with a body covered in soft grey fur, large round ears, and a black nose, radiating cuteness and a parrot with bright red, blue, and yellow feathers, a curved beak, and sharp intelligent eyes are moving in the snowy tundra",
"92_D_loc2_192_t3n9_00c0_Hemi12_1": "a woman with short curly black hair, slender build, a pink hoodie, light gray joggers, and blue sneakers and a koala with a body covered in soft grey fur, large round ears, and a black nose, radiating cuteness and a cheetah with golden fur covered in black spots, intense amber eyes, and a slender, agile body are moving in the wind farm",
"93_D_loc3_133_t3n1_0085_Hemi12_1": "a man with curly black hair, muscular build, a dark green hoodie, gray joggers, and white running shoes and a sedan with a sleek metallic silver body, long wheelbase, a low-profile hood, and a small rear spoiler and a koala with a body covered in soft grey fur, large round ears, and a black nose, radiating cuteness are moving in the night city square",
"94_D_loc1_301_t3n3_012d_Hemi12_1": "a woman with short blonde hair, slim athletic build, a red leather jacket, dark blue jeans, and white sneakers and a sleek black panther with a smooth, glossy coat, emerald green eyes, and a powerful stance and a parrot with bright red, blue, and yellow feathers, a curved beak, and sharp intelligent eyes are moving in the mall lobby",
"95_D_loc3_93_t3n21_005d_Hemi12_1": "a man with short curly red hair, average build, a black leather jacket, dark blue cargo pants, and white sneakers and a koala with a body covered in soft grey fur, large round ears, and a black nose, radiating cuteness and a fox with sleek russet fur, a bushy tail tipped with black, and bright green and cunning eyes are moving in the abandoned factory",
"96_D_loc1_46_t3n8_002e_Hemi12_1": "a woman with shoulder-length straight auburn hair, a slender figure, a green button-up blouse, black leggings, and white sneakers and a rhinoceros with a body covered in thick grey skin, a massive horn on its snout, and sturdy legs and a tiger with vibrant orange and black stripes, piercing yellow eyes, and a powerful stance, exuding strength and grace are moving in the park",
"97_D_loc1_871_t3n7_0367_Hemi12_1": "a woman with shoulder-length wavy brown hair, slim build, a green parka, black leggings, and gray hiking boots and a polar bear with thick white fur, strong paws, and a black nose, embodying the essence of the Arctic and a vintage convertible with a body covered in shiny red paint, chrome bumpers, and a stylish design are moving in the swamp",
"98_D_loc3_133_t3n1_0085_Hemi12_1": "a woman with long curly black hair, average build, a purple hoodie, black athletic shorts, and white running shoes and a fluttering butterfly with intricate wing patterns, vivid colors, and graceful flight and a sleek black panther with a smooth, glossy coat, emerald green eyes, and a powerful stance are moving in the wind farm",
"99_D_loc3_133_t3n1_0085_Hemi12_1": "a woman with pixie-cut blonde hair, athletic build, a red windbreaker, blue ripped jeans, and black combat boots and a vintage convertible with a body covered in shiny red paint, chrome bumpers, and a stylish design and a falcon with blue-gray feathers, sharp talons, and keen yellow eyes fixed on its prey below are moving in the fjord"
}
================================================
FILE: eval/common_metrics_on_video_quality/eval_visual.sh
================================================
# Visual-quality evaluation driver: runs FVD, FID and CLIP-SIM in turn.
eval_root=eval_folder
videos1_dir="${eval_root}/base_t2v_eval_sets"
videos2_dir="${eval_root}/eval_sets"

# FVD between the two video folders
python calculate_fvd_styleganv.py -v1_f "${videos1_dir}" -v2_f "${videos2_dir}"

# FID between eval_1 and eval_2
# NOTE(review): eval_1 / eval_2 are not created in this script; presumably
# they are written by the FVD step above -- confirm.
python -m pytorch_fid "${eval_root}/eval_1" "${eval_root}/eval_2"

# CLIP-SIM on the second folder only
python calculate_clip.py -v_f "${videos2_dir}"

# Clean up the two FID input folders
rm -rf "${eval_root}/eval_1" "${eval_root}/eval_2"