Repository: NJU-3DV/SpatialVID Branch: main Commit: 23840d4ec122 Files: 538 Total size: 3.1 MB Directory structure: gitextract__0evbu59/ ├── .gitignore ├── .gitmodules ├── Dockerfile.cuda ├── LICENSE ├── README.md ├── camera_pose_annotation/ │ ├── .gitignore │ ├── README.md │ ├── __init__.py │ ├── camera_tracking/ │ │ ├── __init__.py │ │ ├── camera_tracking.py │ │ └── inference_batch.py │ ├── cvd_opt/ │ │ ├── __init__.py │ │ ├── cvd_opt.py │ │ ├── geometry_utils.py │ │ ├── inference_batch.py │ │ └── preprocess/ │ │ ├── __init__.py │ │ ├── core/ │ │ │ ├── __init__.py │ │ │ ├── corr.py │ │ │ ├── datasets.py │ │ │ ├── extractor.py │ │ │ ├── raft.py │ │ │ ├── update.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── augmentor.py │ │ │ ├── flow_viz.py │ │ │ ├── frame_utils.py │ │ │ └── utils.py │ │ ├── inference_batch.py │ │ └── preprocess_flow.py │ ├── depth_estimation/ │ │ ├── Depth-Anything/ │ │ │ ├── __init__.py │ │ │ ├── depth_anything_v2/ │ │ │ │ ├── dinov2.py │ │ │ │ ├── dinov2_layers/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── attention.py │ │ │ │ │ ├── block.py │ │ │ │ │ ├── drop_path.py │ │ │ │ │ ├── layer_scale.py │ │ │ │ │ ├── mlp.py │ │ │ │ │ ├── patch_embed.py │ │ │ │ │ └── swiglu_ffn.py │ │ │ │ ├── dpt.py │ │ │ │ └── util/ │ │ │ │ ├── blocks.py │ │ │ │ └── transform.py │ │ │ ├── inference.py │ │ │ └── inference_batch.py │ │ ├── UniDepth/ │ │ │ ├── __init__.py │ │ │ ├── inference.py │ │ │ ├── inference_batch.py │ │ │ └── unidepth/ │ │ │ ├── datasets/ │ │ │ │ ├── _2d3ds.py │ │ │ │ ├── _4dor.py │ │ │ │ ├── __init__.py │ │ │ │ ├── a2d2.py │ │ │ │ ├── adt.py │ │ │ │ ├── aimotive.py │ │ │ │ ├── argoverse.py │ │ │ │ ├── argoverse2.py │ │ │ │ ├── arkit.py │ │ │ │ ├── ase.py │ │ │ │ ├── base_dataset.py │ │ │ │ ├── bdd.py │ │ │ │ ├── bedlam.py │ │ │ │ ├── behave.py │ │ │ │ ├── blendedmvg.py │ │ │ │ ├── cityscape.py │ │ │ │ ├── ddad.py │ │ │ │ ├── deep360.py │ │ │ │ ├── dense.py │ │ │ │ ├── diml.py │ │ │ │ ├── diode.py │ │ │ │ ├── dl3dv.py │ │ │ │ ├── driving_stereo.py │ │ │ │ ├── dtu_rmvd.py │ │ │ │ ├── dummy.py │ │ │ │ ├── dynamic_replica.py │ │ │ │ ├── eden.py │ │ │ │ ├── eth3d.py │ │ │ │ ├── eth3d_rmvd.py │ │ │ │ ├── facedepth.py │ │ │ │ ├── flsea.py │ │ │ │ ├── futurehouse.py │ │ │ │ ├── gibson.py │ │ │ │ ├── hammer.py │ │ │ │ ├── hm3d.py │ │ │ │ ├── hoi4d.py │ │ │ │ ├── hrwsi.py │ │ │ │ ├── hypersim.py │ │ │ │ ├── ibims.py │ │ │ │ ├── image_dataset.py │ │ │ │ ├── ken_burns.py │ │ │ │ ├── kitti.py │ │ │ │ ├── kitti360.py │ │ │ │ ├── kitti_multi.py │ │ │ │ ├── kitti_rmvd.py │ │ │ │ ├── lyft.py │ │ │ │ ├── mapillary.py │ │ │ │ ├── matrix_city.py │ │ │ │ ├── matterport3d.py │ │ │ │ ├── megadepth.py │ │ │ │ ├── megadepth_s.py │ │ │ │ ├── midair.py │ │ │ │ ├── mip.py │ │ │ │ ├── ms2.py │ │ │ │ ├── mvimgnet.py │ │ │ │ ├── mvsynth.py │ │ │ │ ├── nerds360.py │ │ │ │ ├── niantic_mapfree.py │ │ │ │ ├── nuscenes.py │ │ │ │ ├── nyuv2.py │ │ │ │ ├── oasis.py │ │ │ │ ├── pipelines/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── formating.py │ │ │ │ │ └── transforms.py │ │ │ │ ├── point_odyssey.py │ │ │ │ ├── proteus.py │ │ │ │ ├── samplers copy.py │ │ │ │ ├── samplers.py │ │ │ │ ├── scannet.py │ │ │ │ ├── scannetpp.py │ │ │ │ ├── sequence_dataset.py │ │ │ │ ├── sintel copy.py │ │ │ │ ├── sintel.py │ │ │ │ ├── sunrgbd.py │ │ │ │ ├── synscapes.py │ │ │ │ ├── tartanair.py │ │ │ │ ├── taskonomy.py │ │ │ │ ├── tat_rmvd.py │ │ │ │ ├── theo.py │ │ │ │ ├── unrealstereo4k.py │ │ │ │ ├── urbansyn.py │ │ │ │ ├── utils.py │ │ │ │ ├── utils_decode.py │ │ │ │ ├── vkitti.py │ │ │ │ ├── void.py │ │ │ │ ├── waymo.py │ │ │ │ 
└── wildrgbd.py │ │ │ ├── layers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activation.py │ │ │ │ ├── attention.py │ │ │ │ ├── convnext.py │ │ │ │ ├── drop_path.py │ │ │ │ ├── layer_scale.py │ │ │ │ ├── mlp.py │ │ │ │ ├── nystrom.py │ │ │ │ ├── nystrom_attention.py │ │ │ │ ├── positional_encoding.py │ │ │ │ └── upsample.py │ │ │ ├── models/ │ │ │ │ ├── __init__.py │ │ │ │ ├── backbones/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── convnext.py │ │ │ │ │ ├── convnext2.py │ │ │ │ │ ├── dinov2.py │ │ │ │ │ └── metadinov2/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── attention.py │ │ │ │ │ ├── block.py │ │ │ │ │ ├── dino_head.py │ │ │ │ │ ├── drop_path.py │ │ │ │ │ ├── layer_scale.py │ │ │ │ │ ├── mlp.py │ │ │ │ │ ├── patch_embed.py │ │ │ │ │ └── swiglu_ffn.py │ │ │ │ ├── encoder.py │ │ │ │ ├── unidepthv1/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── decoder.py │ │ │ │ │ └── unidepthv1.py │ │ │ │ └── unidepthv2/ │ │ │ │ ├── __init__.py │ │ │ │ ├── decoder.py │ │ │ │ ├── decoder_old.py │ │ │ │ ├── export.py │ │ │ │ ├── unidepthv2.py │ │ │ │ └── unidepthv2_old.py │ │ │ ├── ops/ │ │ │ │ ├── __init__.py │ │ │ │ ├── extract_patches/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── compile.sh │ │ │ │ │ ├── functions/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── extract_patches.py │ │ │ │ │ ├── modules/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── patch_extractor.py │ │ │ │ │ ├── setup.py │ │ │ │ │ ├── src/ │ │ │ │ │ │ ├── cpu/ │ │ │ │ │ │ │ ├── extract_patches_cpu.cpp │ │ │ │ │ │ │ └── extract_patches_cpu.h │ │ │ │ │ │ ├── cuda/ │ │ │ │ │ │ │ ├── extract_patches_cuda.h │ │ │ │ │ │ │ ├── extract_patches_kernel.cu │ │ │ │ │ │ │ └── extract_patches_kernel.cuh │ │ │ │ │ │ ├── extract_patches.cpp │ │ │ │ │ │ └── extract_patches.h │ │ │ │ │ └── test.py │ │ │ │ ├── knn/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── compile.sh │ │ │ │ │ ├── functions/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── knn.py │ │ │ │ │ ├── setup.py │ │ │ │ │ └── src/ │ │ │ │ │ ├── knn.cu │ │ │ │ │ ├── knn.h │ │ │ │ │ ├── knn_cpu.cpp │ │ │ │ │ ├── knn_ext.cpp │ │ │ │ │ └── utils/ │ │ │ │ │ ├── dispatch.cuh │ │ │ │ │ ├── index_utils.cuh │ │ │ │ │ ├── mink.cuh │ │ │ │ │ └── pytorch3d_cutils.h │ │ │ │ ├── losses/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── arel.py │ │ │ │ │ ├── confidence.py │ │ │ │ │ ├── distill.py │ │ │ │ │ ├── dummy.py │ │ │ │ │ ├── local_ssi.py │ │ │ │ │ ├── regression.py │ │ │ │ │ ├── silog.py │ │ │ │ │ └── utils.py │ │ │ │ └── scheduler.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── camera.py │ │ │ ├── chamfer_distance.py │ │ │ ├── constants.py │ │ │ ├── coordinate.py │ │ │ ├── distributed.py │ │ │ ├── ema_torch.py │ │ │ ├── evaluation_depth.py │ │ │ ├── geometric.py │ │ │ ├── misc.py │ │ │ ├── positional_embedding.py │ │ │ ├── sht.py │ │ │ ├── validation.py │ │ │ └── visualization.py │ │ └── __init__.py │ └── dynamic_mask/ │ ├── __init__.py │ ├── inference_batch.py │ └── sam2/ │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── benchmark.py │ ├── build_sam.py │ ├── configs/ │ │ ├── sam2/ │ │ │ ├── sam2_hiera_b+.yaml │ │ │ ├── sam2_hiera_l.yaml │ │ │ ├── sam2_hiera_s.yaml │ │ │ └── sam2_hiera_t.yaml │ │ ├── sam2.1/ │ │ │ ├── sam2.1_hiera_b+.yaml │ │ │ ├── sam2.1_hiera_l.yaml │ │ │ ├── sam2.1_hiera_s.yaml │ │ │ └── sam2.1_hiera_t.yaml │ │ └── sam2.1_training/ │ │ └── sam2.1_hiera_b+_MOSE_finetune.yaml │ ├── csrc/ │ │ └── connected_components.cu │ ├── modeling/ │ │ ├── __init__.py │ │ ├── backbones/ │ │ │ ├── __init__.py │ │ │ ├── hieradet.py │ │ │ ├── image_encoder.py │ │ │ └── utils.py │ │ ├── memory_attention.py │ │ ├── memory_encoder.py │ │ 
├── position_encoding.py │ │ ├── sam/ │ │ │ ├── __init__.py │ │ │ ├── mask_decoder.py │ │ │ ├── prompt_encoder.py │ │ │ └── transformer.py │ │ ├── sam2_base.py │ │ └── sam2_utils.py │ ├── sam2_hiera_b+.yaml │ ├── sam2_hiera_l.yaml │ ├── sam2_hiera_s.yaml │ ├── sam2_hiera_t.yaml │ ├── sam2_image_predictor.py │ ├── sam2_video_predictor.py │ ├── sam2_video_predictor_legacy.py │ └── utils/ │ ├── __init__.py │ ├── amg.py │ ├── misc.py │ └── transforms.py ├── caption/ │ ├── LLM/ │ │ ├── __init__.py │ │ ├── inference.py │ │ ├── prompt1.txt │ │ └── prompt2.txt │ ├── README.md │ ├── VQA/ │ │ ├── __init__.py │ │ ├── inference.py │ │ └── prompt.txt │ ├── __init__.py │ ├── tagging/ │ │ ├── __init__.py │ │ ├── inference.py │ │ └── prompt.txt │ └── utils/ │ ├── __init__.py │ ├── api_call.py │ └── combine.py ├── docker-entrypoint.sh ├── requirements/ │ ├── requirements.txt │ ├── requirements_annotation.txt │ └── requirements_scoring.txt ├── scoring/ │ ├── README.md │ ├── __init__.py │ ├── aesthetic/ │ │ ├── __init__.py │ │ └── inference.py │ ├── luminance/ │ │ ├── __init__.py │ │ └── inference.py │ ├── motion/ │ │ ├── INSTALL.md │ │ ├── __init__.py │ │ └── inference.py │ └── ocr/ │ ├── __init__.py │ └── inference.py ├── scripts/ │ ├── annotation.sh │ ├── caption.sh │ ├── docker_prepulls.sh │ ├── download_checkpoints.sh │ └── scoring.sh ├── utils/ │ ├── README.md │ ├── __init__.py │ ├── convert.py │ ├── cut.py │ ├── cut_fast.py │ ├── download_SpatialVID.py │ ├── download_YouTube.py │ ├── evaluation.py │ ├── expand_npz.py │ ├── extract_frames.py │ ├── filter.py │ ├── get_clip.py │ ├── get_info.py │ ├── get_instructions.py │ ├── get_instructions_enhanced.py │ ├── merge_tables.py │ ├── normalize_intrinsics.py │ ├── pack_clip_assets.py │ ├── quat_to_mat.py │ ├── read_depth.py │ ├── read_video.py │ └── scene_detect.py └── viser/ ├── .clang-format ├── .gitignore ├── .pre-commit-config.yaml ├── .prettierignore ├── LICENSE ├── README.md ├── docs/ │ ├── .gitignore │ ├── Makefile │ ├── source/ │ │ ├── _static/ │ │ │ └── css/ │ │ │ └── custom.css │ │ ├── _templates/ │ │ │ └── sidebar/ │ │ │ └── brand.html │ │ ├── camera_handles.md │ │ ├── client_handles.md │ │ ├── conf.py │ │ ├── conventions.md │ │ ├── development.md │ │ ├── events.md │ │ ├── examples/ │ │ │ ├── 00_coordinate_frames.rst │ │ │ ├── 01_image.rst │ │ │ ├── 02_gui.rst │ │ │ ├── 03_gui_callbacks.rst │ │ │ ├── 04_camera_poses.rst │ │ │ ├── 05_camera_commands.rst │ │ │ ├── 06_mesh.rst │ │ │ ├── 07_record3d_visualizer.rst │ │ │ ├── 08_smpl_visualizer.rst │ │ │ ├── 09_urdf_visualizer.rst │ │ │ ├── 10_realsense.rst │ │ │ ├── 11_colmap_visualizer.rst │ │ │ ├── 12_click_meshes.rst │ │ │ ├── 13_theming.rst │ │ │ ├── 14_markdown.rst │ │ │ ├── 15_gui_in_scene.rst │ │ │ ├── 16_modal.rst │ │ │ ├── 17_background_composite.rst │ │ │ ├── 18_splines.rst │ │ │ ├── 19_get_renders.rst │ │ │ ├── 20_scene_pointer.rst │ │ │ ├── 21_set_up_direction.rst │ │ │ ├── 22_games.rst │ │ │ ├── 23_plotly.rst │ │ │ ├── 24_notification.rst │ │ │ └── 25_smpl_visualizer_skinned.rst │ │ ├── extras.md │ │ ├── gui_api.md │ │ ├── gui_handles.md │ │ ├── icons.md │ │ ├── index.md │ │ ├── infrastructure.md │ │ ├── scene_api.md │ │ ├── scene_handles.md │ │ ├── server.md │ │ └── transforms.md │ └── update_example_docs.py ├── examples/ │ ├── 00_coordinate_frames.py │ ├── 01_image.py │ ├── 02_gui.py │ ├── 03_gui_callbacks.py │ ├── 04_camera_poses.py │ ├── 05_camera_commands.py │ ├── 06_mesh.py │ ├── 07_record3d_visualizer.py │ ├── 08_smpl_visualizer.py │ ├── 09_urdf_visualizer.py │ ├── 10_realsense.py 
│ ├── 11_colmap_visualizer.py │ ├── 12_click_meshes.py │ ├── 13_theming.py │ ├── 14_markdown.py │ ├── 15_gui_in_scene.py │ ├── 16_modal.py │ ├── 17_background_composite.py │ ├── 18_splines.py │ ├── 19_get_renders.py │ ├── 20_scene_pointer.py │ ├── 21_set_up_direction.py │ ├── 22_games.py │ ├── 23_plotly.py │ ├── 24_notification.py │ ├── 25_smpl_visualizer_skinned.py │ ├── assets/ │ │ ├── .gitignore │ │ ├── download_colmap_garden.sh │ │ ├── download_dragon_mesh.sh │ │ ├── download_record3d_dance.sh │ │ └── mdx_example.mdx │ ├── experimental/ │ │ └── gaussian_splats.py │ └── quick_save.py ├── pyproject.toml ├── src/ │ └── viser/ │ ├── __init__.py │ ├── _client_autobuild.py │ ├── _gui_api.py │ ├── _gui_handles.py │ ├── _icons.py │ ├── _icons_enum.py │ ├── _icons_enum.pyi │ ├── _icons_generate_enum.py │ ├── _messages.py │ ├── _notification_handle.py │ ├── _scene_api.py │ ├── _scene_handles.py │ ├── _tunnel.py │ ├── _viser.py │ ├── client/ │ │ ├── .eslintrc.js │ │ ├── .gitignore │ │ ├── index.html │ │ ├── package.json │ │ ├── postcss.config.cjs │ │ ├── public/ │ │ │ ├── hdri/ │ │ │ │ └── potsdamer_platz_1k.hdr │ │ │ └── manifest.json │ │ ├── src/ │ │ │ ├── App.css.ts │ │ │ ├── App.tsx │ │ │ ├── AppTheme.ts │ │ │ ├── BrowserWarning.tsx │ │ │ ├── CameraControls.tsx │ │ │ ├── ClickUtils.tsx │ │ │ ├── ControlPanel/ │ │ │ │ ├── BottomPanel.tsx │ │ │ │ ├── ControlPanel.tsx │ │ │ │ ├── FloatingPanel.tsx │ │ │ │ ├── Generated.tsx │ │ │ │ ├── GuiComponentContext.tsx │ │ │ │ ├── GuiState.tsx │ │ │ │ ├── SceneTreeTable.css.ts │ │ │ │ ├── SceneTreeTable.tsx │ │ │ │ ├── ServerControls.tsx │ │ │ │ └── SidebarPanel.tsx │ │ │ ├── FilePlayback.tsx │ │ │ ├── Markdown.tsx │ │ │ ├── MessageHandler.tsx │ │ │ ├── Modal.tsx │ │ │ ├── Outlines.tsx │ │ │ ├── SceneTree.tsx │ │ │ ├── SceneTreeState.tsx │ │ │ ├── SearchParamsUtils.tsx │ │ │ ├── Splatting/ │ │ │ │ ├── GaussianSplats.tsx │ │ │ │ ├── SplatSortWorker.ts │ │ │ │ └── WasmSorter/ │ │ │ │ ├── Sorter.mjs │ │ │ │ ├── Sorter.wasm │ │ │ │ ├── build.sh │ │ │ │ └── sorter.cpp │ │ │ ├── ThreeAssets.tsx │ │ │ ├── Titlebar.tsx │ │ │ ├── Utils.ts │ │ │ ├── WebsocketFunctions.tsx │ │ │ ├── WebsocketInterface.tsx │ │ │ ├── WebsocketMessages.tsx │ │ │ ├── WebsocketServerWorker.ts │ │ │ ├── WorldTransformUtils.ts │ │ │ ├── components/ │ │ │ │ ├── Button.tsx │ │ │ │ ├── ButtonGroup.tsx │ │ │ │ ├── Checkbox.tsx │ │ │ │ ├── ComponentStyles.css.ts │ │ │ │ ├── Dropdown.tsx │ │ │ │ ├── Folder.css.ts │ │ │ │ ├── Folder.tsx │ │ │ │ ├── Markdown.tsx │ │ │ │ ├── MultiSlider.tsx │ │ │ │ ├── MultiSliderPrimitive/ │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── Marks/ │ │ │ │ │ │ └── Marks.tsx │ │ │ │ │ ├── MultiSlider/ │ │ │ │ │ │ └── MultiSlider.tsx │ │ │ │ │ ├── Slider.context.ts │ │ │ │ │ ├── Slider.module.css │ │ │ │ │ ├── SliderRoot/ │ │ │ │ │ │ └── SliderRoot.tsx │ │ │ │ │ ├── Thumb/ │ │ │ │ │ │ └── Thumb.tsx │ │ │ │ │ ├── Track/ │ │ │ │ │ │ └── Track.tsx │ │ │ │ │ ├── index.ts │ │ │ │ │ └── utils/ │ │ │ │ │ ├── get-change-value/ │ │ │ │ │ │ └── get-change-value.ts │ │ │ │ │ ├── get-client-position/ │ │ │ │ │ │ └── get-client-position.ts │ │ │ │ │ ├── get-floating-value/ │ │ │ │ │ │ └── get-gloating-value.ts │ │ │ │ │ ├── get-position/ │ │ │ │ │ │ └── get-position.ts │ │ │ │ │ └── get-precision/ │ │ │ │ │ └── get-precision.ts │ │ │ │ ├── NumberInput.tsx │ │ │ │ ├── PlotlyComponent.tsx │ │ │ │ ├── ProgressBar.tsx │ │ │ │ ├── Rgb.tsx │ │ │ │ ├── Rgba.tsx │ │ │ │ ├── Slider.tsx │ │ │ │ ├── TabGroup.tsx │ │ │ │ ├── TextInput.tsx │ │ │ │ ├── UploadButton.tsx │ │ │ │ ├── Vector2.tsx │ │ │ │ ├── 
Vector3.tsx │ │ │ │ ├── common.tsx │ │ │ │ └── utils.tsx │ │ │ ├── index.css │ │ │ ├── index.tsx │ │ │ └── react-app-env.d.ts │ │ ├── tsconfig.json │ │ ├── vite-env.d.ts │ │ └── vite.config.mts │ ├── extras/ │ │ ├── __init__.py │ │ ├── _record3d.py │ │ ├── _record3d_customized.py │ │ ├── _record3d_customized_megasam.py │ │ ├── _urdf.py │ │ └── colmap/ │ │ ├── __init__.py │ │ └── _colmap_utils.py │ ├── infra/ │ │ ├── __init__.py │ │ ├── _async_message_buffer.py │ │ ├── _infra.py │ │ ├── _messages.py │ │ └── _typescript_interface_gen.py │ ├── py.typed │ ├── scripts/ │ │ ├── __init__.py │ │ └── dev_checks.py │ ├── theme/ │ │ ├── __init__.py │ │ └── _titlebar.py │ └── transforms/ │ ├── __init__.py │ ├── _base.py │ ├── _se2.py │ ├── _se3.py │ ├── _so2.py │ ├── _so3.py │ ├── hints/ │ │ └── __init__.py │ └── utils/ │ ├── __init__.py │ └── _utils.py ├── sync_message_defs.py ├── visualize_megasam.py └── visualize_pose.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[codz] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .buildx-cache/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py.cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # UV # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. #uv.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock #poetry.toml # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control #pdm.lock #pdm.toml .pdm-python .pdm-build/ # pixi # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. #pixi.lock # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one # in the .venv directory. It is recommended not to include this directory in version control. .pixi # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .envrc .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ # Abstra # Abstra is an AI-powered process automation framework. # Ignore directories containing user credentials, local state, and settings. # Learn more at https://abstra.io/docs .abstra/ # Visual Studio Code # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore # and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the entire vscode folder .vscode/ # Ruff stuff: .ruff_cache/ # PyPI configuration file .pypirc # Cursor # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to # exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data # refer to https://docs.cursor.com/context/ignore-files test/ checkpoints/ .cursorignore .cursorindexingignore # Marimo marimo/_static/ marimo/_lsp/ __marimo__/ .DS_Store ================================================ FILE: .gitmodules ================================================ [submodule "camera_pose_annotation/base"] path = camera_pose_annotation/base url = https://github.com/SpatialVID/base.git ================================================ FILE: Dockerfile.cuda ================================================ # This Dockerfile builds FFmpeg with NVIDIA GPU support and libvmaf from source # It uses a two-stage build to create a smaller runtime image # This file is adapted from https://github.com/Netflix/vmaf/blob/master/Dockerfile.cuda ARG CUDA_BASE_IMAGE=docker.io/nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 ARG RUN_TIME_IMG=docker.io/nvidia/cuda:12.6.3-runtime-ubuntu22.04 # ARG CUDA_BASE_IMAGE=swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 # ARG RUN_TIME_IMG=swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/nvidia/cuda:12.6.3-runtime-ubuntu22.04 FROM $CUDA_BASE_IMAGE as builder ARG VMAF_TAG=master ARG FFMPEG_TAG=master RUN DEBIAN_FRONTEND=noninteractive apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y libopenjp2-7-dev \ ninja-build cmake git python3 python3-pip nasm xxd pkg-config curl unzip nvidia-cuda-toolkit RUN git clone https://github.com/Netflix/vmaf.git && cd vmaf && git checkout $VMAF_TAG RUN git clone https://github.com/FFmpeg/FFmpeg.git && cd FFmpeg && git checkout $FFMPEG_TAG RUN git clone https://github.com/FFmpeg/nv-codec-headers.git && cd nv-codec-headers && make && make install # install vmaf RUN python3 -m pip install meson RUN cd vmaf && meson libvmaf/build libvmaf -Denable_cuda=true -Denable_avx512=true --buildtype release && \ ninja -vC libvmaf/build && \ ninja -vC libvmaf/build install ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/x86_64-linux-gnu/ RUN ldconfig # install ffmpeg RUN cd FFmpeg && ./configure \ --enable-libnpp \ --enable-nonfree \ --enable-nvdec \ --enable-nvenc \ --enable-cuvid \ --enable-cuda \ --enable-cuda-nvcc \ --enable-libvmaf \ --enable-ffnvcodec \ --disable-stripping \ --extra-cflags="-I/usr/local/cuda/include" \ --extra-ldflags="-L/usr/local/cuda/lib64 -L/usr/local/cuda/lib64/stubs/" RUN cd FFmpeg && make -j && make install RUN mkdir /data # Create a smaller runtime image FROM ${RUN_TIME_IMG} as runtime ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y --no-install-recommends \ ca-certificates python3 python3-pip python3-venv libnuma-dev libsm6 libxext6 libxrender1 libgl1 git vim && rm -rf /var/lib/apt/lists/* WORKDIR /workspace # Copy FFmpeg and libvmaf from builder (installed under /usr/local) COPY --from=builder /usr/local /usr/local # copy libraries installed by the builder stage if present COPY --from=builder /usr/lib/ /usr/lib/ # Link python RUN ln -sf /usr/bin/python3 /usr/bin/python RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel # Copy repository COPY . 
/workspace RUN apt-get update # Install Python requirements (may still fail for some packages requiring system libs) RUN python3 -m pip --no-cache-dir install -r requirements/requirements.txt RUN python3 -m pip --no-cache-dir install -r requirements/requirements_scoring.txt || true RUN python3 -m pip --no-cache-dir install -r requirements/requirements_annotation.txt || true # Entrypoint COPY docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh RUN chmod +x /usr/local/bin/docker-entrypoint.sh ENV FFMPEG_PATH=/usr/local/bin/ffmpeg ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] CMD ["bash", "ldconfig"] ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) 
The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

# SpatialVID: A Large-Scale Video Dataset with Spatial Annotations

Jiahao Wang¹*, Yufeng Yuan¹*, Rujie Zheng¹*, Youtian Lin¹, Jian Gao¹, Lin-Zhuo Chen¹,
Yajie Bao¹, Yi Zhang¹, Chang Zeng¹, Yanxi Zhou¹, Xiaoxiao Long¹, Hao Zhu¹,
Zhaoxiang Zhang², Xun Cao¹, Yao Yao¹†
¹Nanjing University  ²Institute of Automation, Chinese Academy of Sciences
*Equal Contribution  †Corresponding Author
CVPR 2026

## 🎉NEWS

+ [2026.02.21] 🎉 SpatialVID is accepted by CVPR 2026!
+ [2025.10.11] 🐳 Docker support is now available, featuring a pre-configured environment with NVIDIA GPU-accelerated FFmpeg.
+ [2025.09.29] 🚀 Depth data for the SpatialVID-HQ dataset is now officially available.
+ [2025.09.24] 🤗 Raw metadata access is now available via a [gated HuggingFace dataset](https://huggingface.co/datasets/SpatialVID/SpatialVID-RAW) to better support community research!
+ [2025.09.24] 🔭 Enhanced instructions for better camera control have been updated.
+ [2025.09.18] 🎆 The SpatialVID dataset is now available on both HuggingFace and ModelScope.
+ [2025.09.14] 📢 We have also uploaded the SpatialVID-HQ dataset to ModelScope, offering more diverse download options.
+ [2025.09.11] 🔥 Our paper, code, and the SpatialVID-HQ dataset are released!

**[✍️ Note]** Each video clip is paired with a dedicated annotation folder (named after the video's id). The folder contains 5 key files; details regarding these files can be found in [Detailed Explanation of Annotation Files](https://huggingface.co/datasets/SpatialVID/SpatialVID#3-detailed-explanation-of-annotation-files).

## Abstract

Significant progress has been made in spatial intelligence, spanning both spatial reconstruction and world exploration. However, the scalability and real-world fidelity of current models remain severely constrained by the scarcity of large-scale, high-quality training data. While several datasets provide camera pose information, they are typically limited in scale, diversity, and annotation richness, particularly for real-world dynamic scenes with ground-truth camera motion. To this end, we collect **SpatialVID**, a dataset consisting of a large corpus of in-the-wild videos with diverse scenes, camera movements and dense 3D annotations such as per-frame camera poses, depth, and motion instructions. Specifically, we collect more than **21,000 hours** of raw videos, and process them into **2.7 million clips** through a hierarchical filtering pipeline, totaling **7,089 hours** of dynamic content. A subsequent annotation pipeline enriches these clips with detailed spatial and semantic information, including camera poses, depth maps, dynamic masks, structured captions, and serialized motion instructions. Analysis of SpatialVID's data statistics reveals a richness and diversity that directly foster improved model generalization and performance, establishing it as a key asset for the video and 3D vision research community.

## Preparation

This section describes how to set up the environment manually. For a simpler, containerized setup, please refer to the **[Docker Setup and Usage](#docker-setup-and-usage)** section.

### Environment

1. Necessary packages

```bash
git clone --recursive https://github.com/NJU-3DV/SpatialVID.git
cd SpatialVID
conda create -n SpatialVID python=3.10.13
conda activate SpatialVID
pip install -r requirements/requirements.txt
```

2. Packages needed for scoring

```bash
pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
pip install -r requirements/requirements_scoring.txt
```

Ignore the warning about the `nvidia-nccl-cu12` and `numpy` versions; it is not a problem.

For FFmpeg, please refer to [`INSTALL.md`](scoring/motion/INSTALL.md) for detailed instructions on how to install it.
After installation, replace the `FFMPEG_PATH` variable in [`scoring/motion/inference.py`](scoring/motion/inference.py) and [`utils/cut.py`](utils/cut.py) with the actual path to your ffmpeg executable; the default is `/usr/local/bin/ffmpeg`.

⚠️ If your videos use the AV1 codec instead of H.264, you need to install ffmpeg (already covered by our requirements script), then run the following so that the conda environment supports the AV1 codec:

```bash
pip uninstall opencv-python
conda install -c conda-forge opencv==4.11.0
```

If your conda environment still cannot decode AV1, you can use the `--backend av` option in the scoring scripts to use PyAV as the video reading backend. Note, however, that using PyAV for frame extraction may lead to slight inaccuracies in frame positioning.

3. Packages needed for annotation

```bash
pip install -r requirements/requirements_annotation.txt
```

Compile the extensions for the camera tracking module:

```bash
cd camera_pose_annotation/base
python setup.py install
```

4. [Optional] Packages needed for visualization

```bash
pip install plotly
pip install -e viser
```

### Model Weights

Download the model weights used in our experiments:

```bash
bash scripts/download_checkpoints.sh
```

Or you can manually download the model weights from the following links and place them in the appropriate directories.

| Model               | File Name               | URL |
| ------------------- | ----------------------- | --- |
| Aesthetic Predictor | aesthetic               | [🔗](https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac+logos+ava1-l14-linearMSE.pth) |
| MegaSAM             | megasam_final           | [🔗](https://github.com/mega-sam/mega-sam/blob/main/checkpoints/megasam_final.pth) |
| RAFT                | raft-things             | [🔗](https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM) |
| Depth Anything      | Depth-Anything-V2-Large | [🔗](https://huggingface.co/depth-anything/Depth-Anything-V2-Large) |
| UniDepth            | unidepth-v2-vitl14      | [🔗](https://huggingface.co/lpiccinelli/unidepth-v2-vitl14) |
| SAM                 | sam2.1-hiera-large      | [🔗](https://huggingface.co/facebook/sam2.1-hiera-large) |

## Quick Start

The whole pipeline is illustrated in the figure below:

1. Scoring

```bash
bash scripts/scoring.sh
```

Inside the [`scoring.sh`](scripts/scoring.sh) script, you need to set the following variables:
- `ROOT_VIDEO` is the directory containing the input video files.
- `OUTPUT_DIR` is the directory where the output files will be saved.

2. Annotation

```bash
bash scripts/annotation.sh
```

Inside the [`annotation.sh`](scripts/annotation.sh) script, you need to set the following variables:
- `CSV` is the CSV file generated by the scoring script; the default is `$OUTPUT_DIR/results.csv`.
- `OUTPUT_DIR` is the directory where the output files will be saved.

3. Caption

```bash
bash scripts/caption.sh
```

Inside the [`caption.sh`](scripts/caption.sh) script, you need to set the following variables:
- `CSV` is the CSV file generated by the annotation script; the default is `$OUTPUT_DIR/results.csv`.
- `SRC_DIR` is the annotation output directory; the default is the same as the `OUTPUT_DIR` in the annotation step.
- `OUTPUT_DIR` is the directory where the output files will be saved.
- The API keys for the LLM models used in the captioning step; replace them with your own API keys.

4. Visualization

- You can visualize the `poses.npy` in the `reconstruction` folder of each annotated clip using the [`visualize_pose.py`](viser/visualize_pose.py) script.
- You can visualize the final annotation result (`sgd_cvd_hr.npz`) using the [`visualize_megasam.py`](viser/visualize_megasam.py) script.

Note that if you want to visualize any clip in our dataset, you first need to use the [`pack_clip_assets.py`](utils/pack_clip_assets.py) script to pack the depth, RGB frames, intrinsics, extrinsics, etc. of that clip into a single npz file; you can then run the visualization script on the resulting file.

## Docker Setup and Usage

We provide a Dockerfile to create a fully configured environment that includes all dependencies, including a custom-built FFmpeg with NVIDIA acceleration. This is the recommended way to ensure reproducibility and avoid environment-related issues.

Before you begin, ensure your system environment is similar to the configuration below. Version matching is crucial for a successful compilation. The GPU needs to support HEVC decoding; refer to the [NVIDIA NVDEC Support Matrix](https://en.wikipedia.org/wiki/NVIDIA_Video_Coding_Engine#NVDEC).

### Prerequisites: Setting up the Host Environment

Before building and running the Docker container, your host machine must be configured to support GPU access for Docker.

1. **NVIDIA Drivers**: Ensure you have the latest NVIDIA drivers installed. You can verify this by running `nvidia-smi`.

2. **Docker Engine**: Install Docker on your system. Follow the official instructions at [docs.docker.com/engine/install/](https://docs.docker.com/engine/install/).

3. **NVIDIA Container Toolkit**: This toolkit is required for Docker containers to access the host's NVIDIA GPU; see the [NVIDIA Container Toolkit install guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) for details. Install it using the following commands (for Debian/Ubuntu):
```bash
# Add the GPG key
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg

# Add the repository
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
  sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
  sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list

# Update package lists and install the toolkit
sudo apt-get update
sudo apt-get install -y \
    nvidia-container-toolkit=${NVIDIA_CONTAINER_TOOLKIT_VERSION} \
    nvidia-container-toolkit-base=${NVIDIA_CONTAINER_TOOLKIT_VERSION} \
    libnvidia-container-tools=${NVIDIA_CONTAINER_TOOLKIT_VERSION} \
    libnvidia-container1=${NVIDIA_CONTAINER_TOOLKIT_VERSION}

# Configure Docker to use the NVIDIA runtime
sudo nvidia-ctk runtime configure --runtime=docker

# Restart the Docker daemon to apply the changes
sudo systemctl restart docker
```

For other operating systems, please refer to the [official NVIDIA documentation](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).

4. **Docker Image Pre-pulls [optional]**: To accelerate the build process, we provide a script to pre-pull the necessary Docker images from a mirror registry.

```bash
bash scripts/docker_prepulls.sh
```

### Build and Run the Container

You can build and run the image using standard Docker commands from the root of the repository.

1. **Build the GPU image**:

```bash
docker build -f Dockerfile.cuda \
  --build-arg NUM_JOBS=8 \
  -t spatialvid-gpu .
```

2. **Run the container**:

```bash
docker run --gpus all --rm -it \
  -v $(pwd):/workspace \
  -w /workspace \
  -e NVIDIA_DRIVER_CAPABILITIES=compute,video,utility \
  spatialvid-gpu bash
```

3. **Verify the environment (inside the container)**: Once inside the container, you can verify that FFmpeg and PyTorch are correctly installed and can access the GPU.

```bash
# Check the custom FFmpeg build
/usr/local/bin/ffmpeg -version

# Check PyTorch and CUDA availability
python3 -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.version.cuda}, GPU Available: {torch.cuda.is_available()}')"
```

## Dataset Download

Our dataset is available on [HuggingFace](https://huggingface.co/SpatialVID) and [ModelScope](https://www.modelscope.cn/organization/SpatialVID). In addition to downloading the dataset with terminal commands, we provide scripts to download the SpatialVID/SpatialVID-HQ dataset from HuggingFace; please refer to the [`download_SpatialVID.py`](utils/download_SpatialVID.py) script for more details.

We also provide our script for downloading the raw videos from YouTube; refer to the [`download_YouTube.py`](utils/download_YouTube.py) script for more details.

## License

Please refer to the [LICENSE](LICENSE) file for more details about the license of our code.

⚠️ The SpatialVID dataset is released under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (CC-BY-NC-SA-4.0). Users must attribute the original source, use the resource only for non-commercial purposes, and release any modified/derived works under the same license.

If you are the copyright owner of any video in our dataset and need it to be removed, please contact us, and we will remove the video samples from our dataset / GitHub / project webpage / technical presentation as soon as possible.
## References

Thanks to the developers and contributors of the following open-source repositories, whose invaluable work has greatly inspired our project:

- [Open-Sora](https://github.com/hpcaitech/Open-Sora): An initiative dedicated to efficiently producing high-quality video.
- [MegaSaM](https://github.com/mega-sam/mega-sam): An accurate, fast, and robust system for structure and motion from casual dynamic videos.
- [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2): A model for monocular depth estimation.
- [UniDepthV2](https://github.com/lpiccinelli-eth/UniDepth): A model for universal monocular metric depth estimation.
- [SAM2](https://github.com/facebookresearch/sam2): A model for promptable visual segmentation in images and videos.
- [Viser](https://viser.studio/latest/): A library for interactive 3D visualization in Python.

Our repository is licensed under the Apache 2.0 License. However, if you use MegaSaM or other components in your work, please follow their licenses.

## Citation

```bibtex
@article{wang2025spatialvid,
  title={Spatialvid: A large-scale video dataset with spatial annotations},
  author={Wang, Jiahao and Yuan, Yufeng and Zheng, Rujie and Lin, Youtian and Gao, Jian and Chen, Lin-Zhuo and Bao, Yajie and Zhang, Yi and Zeng, Chang and Zhou, Yanxi and others},
  journal={arXiv preprint arXiv:2509.09676},
  year={2025}
}
```

================================================
FILE: camera_pose_annotation/.gitignore
================================================
# files
data/*
*.log
*.txt
*.bz2
*.zip
*.ipynb
data_videos
!requirements.txt
!requirements_megasam.txt

#python
*.pyc
__pycache__/

# dir
outputs/
outputs_303/
data_videos/
checkpoints/*
!checkpoints/megasam_final.pth
DROID-SLAM/
.vscode/

================================================
FILE: camera_pose_annotation/README.md
================================================
# Camera Pose Annotation

## Depth Estimation

Use both [Depth-Anything V2](depth_estimation/Depth-Anything) and [UniDepth V2](depth_estimation/UniDepth) to estimate depth maps from images.

Download the pre-trained models from the respective repositories. Skip this step if you have already followed the installation instructions in the [README](../README.md).
- [Depth-Anything V2](https://huggingface.co/depth-anything/Depth-Anything-V2-Large)
- [UniDepth V2](https://huggingface.co/lpiccinelli/unidepth-v2-vitl14)

To run depth inference with Depth-Anything V2, use the following command:

```bash
torchrun --standalone --nproc_per_node ${GPU_NUM} camera_pose_annotation/depth_estimation/Depth-Anything/inference_batch.py \
    ${CSV} \
    --encoder vitl \
    --checkpoints_path checkpoints \
    --OUTPUT_DIR ${OUTPUT_DIR} \
    --bs 16 \
    --num_workers ${GPU_NUM}
```

To run depth inference with UniDepth V2, use the following command:

```bash
torchrun --standalone --nproc_per_node ${GPU_NUM} camera_pose_annotation/depth_estimation/UniDepth/inference_batch.py \
    ${CSV} \
    --OUTPUT_DIR ${OUTPUT_DIR} \
    --checkpoints_path checkpoints \
    --bs 32 \
    --num_workers ${GPU_NUM}
```

## Camera Tracking

We use a DROID-SLAM-based method to track camera poses from videos.
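Before tracking, `camera_tracking.py` aligns the per-frame Depth-Anything disparity maps to the UniDepth metric estimates so that the tracker is fed approximately metric depth. Below is a minimal sketch of that per-frame scale/shift alignment, simplified from the logic in `camera_tracking.py` (the helper name is illustrative, and the actual script additionally handles sky regions and applies a global normalization scale):

```python
import numpy as np

def align_disparity(da_disp: np.ndarray, metric_depth: np.ndarray):
    """Median-based scale/shift alignment of a monocular disparity map
    to a metric depth map (simplified from camera_tracking.py)."""
    gt_disp = 1.0 / (metric_depth + 1e-8)         # metric depth -> disparity
    gt_ms = gt_disp - np.median(gt_disp) + 1e-8   # median-centered disparities
    da_ms = da_disp - np.median(da_disp) + 1e-8
    scale = np.median(gt_ms / da_ms)              # robust per-frame scale
    shift = np.median(gt_disp - scale * da_disp)  # robust per-frame shift
    return scale, shift

# The aligned disparity is then inverted to obtain the depth fed to the tracker:
#   depth ≈ 1.0 / (scale * da_disp + shift)
```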
To run inference on a single video, use the following command:

```bash
python camera_pose_annotation/camera_tracking/camera_tracking.py \
    --dir_path ${DIR_PATH} \
    --weights checkpoints/megasam_final.pth \
    --disable_vis
```

To run inference on a batch of videos, use the following command:

```bash
python camera_pose_annotation/camera_tracking/inference_batch.py ${CSV} \
    --OUTPUT_DIR ${OUTPUT_DIR} \
    --checkpoints_path checkpoints --gpu_id ${CUDA_VISIBLE_DEVICES} \
    --num_workers $((GPU_NUM * 2))
```

## CVD (Consistent Video Depth) Optimization

### Optical Flow

Infer optical flow using the RAFT model. Download the [`raft-things.pth`](https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM) checkpoint.

To run inference on a single video, use the following command:

```bash
python camera_pose_annotation/cvd_opt/preprocess/preprocess_flow.py \
    --dir_path ${DIR_PATH} \
    --model checkpoints/raft-things.pth \
    --mixed_precision
```

To run inference on a batch of videos, use the following command:

```bash
python camera_pose_annotation/cvd_opt/preprocess/inference_batch.py ${CSV} \
    --OUTPUT_DIR ${OUTPUT_DIR} \
    --checkpoints_path checkpoints --gpu_id ${CUDA_VISIBLE_DEVICES} \
    --num_workers $((GPU_NUM * 2))
```

### Optimization

Use the optical flow to optimize the estimated depth maps.

To run inference on a single video, use the following command:

```bash
python camera_pose_annotation/cvd_opt/cvd_opt.py \
    --dir_path ${DIR_PATH} \
    --w_grad 2.0 --w_normal 5.0
```

To run inference on a batch of videos, use the following command:

```bash
python camera_pose_annotation/cvd_opt/inference_batch.py ${CSV} \
    --OUTPUT_DIR ${OUTPUT_DIR} \
    --gpu_id ${CUDA_VISIBLE_DEVICES} \
    --num_workers $((GPU_NUM * 2))
```

## Dynamic Mask

Given the limitations of MegaSaM in predicting motion probabilities, we opt to enhance its performance using SAM2. Specifically, an adaptive thresholding mechanism, calibrated to the system's motion probability distribution, is first employed to generate initial masks. Subsequently, contour detection is performed to mitigate redundant segmentation of overlapping regions; for each identified contour, four evenly spaced anchor points are sampled along its perimeter to serve as dedicated prompts for the SAM2 model.

Download the pre-trained [SAM2 model](https://huggingface.co/facebook/sam2.1-hiera-large). Run the following command:

```bash
python camera_pose_annotation/dynamic_mask/inference_batch.py ${CSV} \
    --OUTPUT_DIR ${OUTPUT_DIR} \
    --checkpoints_path checkpoints --gpu_num ${GPU_NUM} \
    --num_workers $((GPU_NUM * 2))
```

================================================
FILE: camera_pose_annotation/__init__.py
================================================

================================================
FILE: camera_pose_annotation/camera_tracking/__init__.py
================================================

================================================
FILE: camera_pose_annotation/camera_tracking/camera_tracking.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================== """Test camera tracking on a single scene.""" # pylint: disable=invalid-name # pylint: disable=g-importing-member # pylint: disable=g-bad-import-order # pylint: disable=g-import-not-at-top # pylint: disable=redefined-outer-name # pylint: disable=undefined-variable # pylint: disable=undefined-loop-variable import sys sys.path.append("camera_pose_annotation/base/droid_slam") from droid import Droid from lietorch import SE3 import argparse import glob import os import cv2 import torch import numpy as np from tqdm import tqdm import torch.nn.functional as F def image_stream( image_list, mono_disp_list, scene_name, use_depth=False, aligns=None, K=None, stride=1, ): """image generator.""" del scene_name, stride fx, fy, cx, cy = ( K[0, 0], K[1, 1], K[0, 2], K[1, 2], ) # np.loadtxt(os.path.join(dir_path, 'calibration.txt')).tolist() for t, (image_file) in enumerate(image_list): image = cv2.imread(image_file) # depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH) / 5000. # depth = np.float32(np.load(depth_file)) / 300.0 # depth = 1. / pt_data["depth"] mono_disp = mono_disp_list[t] # mono_disp = np.float32(np.load(disp_file)) #/ 300.0 depth = np.clip( 1.0 / ((1.0 / aligns[2]) * (aligns[0] * mono_disp + aligns[1])), 1e-4, 1e4, ) depth[depth < 1e-2] = 0.0 # breakpoint() h0, w0, _ = image.shape h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0))) w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0))) image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_AREA) image = image[: h1 - h1 % 8, : w1 - w1 % 8] image = torch.as_tensor(image).permute(2, 0, 1) depth = torch.as_tensor(depth) depth = F.interpolate( depth[None, None], (h1, w1), mode="nearest-exact" ).squeeze() depth = depth[: h1 - h1 % 8, : w1 - w1 % 8] mask = torch.ones_like(depth) intrinsics = torch.as_tensor([fx, fy, cx, cy]) intrinsics[0::2] *= w1 / w0 intrinsics[1::2] *= h1 / h0 if use_depth: yield t, image[None], depth, intrinsics, mask else: yield t, image[None], intrinsics, mask def save_full_reconstruction( droid, full_traj, rgb_list, senor_depth_list, motion_prob, scene_name, save_path ): """Save full reconstruction.""" from pathlib import Path t = full_traj.shape[0] images = np.array(rgb_list[:t]) # droid.video.images[:t].cpu().numpy() disps = 1.0 / (np.array(senor_depth_list[:t]) + 1e-6) poses = full_traj # .cpu().numpy() intrinsics = droid.video.intrinsics[:t].cpu().numpy() Path(f"{save_path}").mkdir(parents=True, exist_ok=True) np.save(f"{save_path}/images.npy", images) np.save(f"{save_path}/disps.npy", disps) np.save(f"{save_path}/poses.npy", poses) np.save(f"{save_path}/intrinsics.npy", intrinsics * 8.0) np.save(f"{save_path}/motion_prob.npy", motion_prob) intrinsics = intrinsics[0] * 8.0 poses_th = torch.as_tensor(poses, device="cpu") cam_c2w = SE3(poses_th).inv().matrix().numpy() K = np.eye(3) K[0, 0] = intrinsics[0] K[1, 1] = intrinsics[1] K[0, 2] = intrinsics[2] K[1, 2] = intrinsics[3] max_frames = min(1000, images.shape[0]) if not os.path.exists(save_path): os.makedirs(save_path) np.savez( os.path.join(save_path, f"{scene_name}_droid.npz"), images=np.uint8(images[:max_frames, ::-1, ...].transpose(0, 2, 3, 1)), depths=np.float32(1.0 / disps[:max_frames, ...]), intrinsic=K, cam_c2w=cam_c2w[:max_frames], ) def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser() parser.add_argument("--dir_path", help="path to the dataset") parser.add_argument("--weights", default="droid.pth") parser.add_argument("--buffer", type=int, 
default=1024)
    parser.add_argument("--image_size", default=[240, 320])
    parser.add_argument("--disable_vis", action="store_true")
    parser.add_argument("--beta", type=float, default=0.3)
    parser.add_argument(
        "--filter_thresh", type=float, default=2.0
    )  # motion threshold for keyframe
    parser.add_argument("--warmup", type=int, default=8)
    parser.add_argument("--keyframe_thresh", type=float, default=2.0)
    parser.add_argument("--frontend_thresh", type=float, default=12.0)
    parser.add_argument("--frontend_window", type=int, default=25)
    parser.add_argument("--frontend_radius", type=int, default=2)
    parser.add_argument("--frontend_nms", type=int, default=1)
    parser.add_argument("--stereo", action="store_true")
    parser.add_argument("--depth", action="store_true")
    parser.add_argument("--upsample", action="store_true")
    parser.add_argument("--scene_name", help="scene_name")
    parser.add_argument("--backend_thresh", type=float, default=16.0)
    parser.add_argument("--backend_radius", type=int, default=2)
    parser.add_argument("--backend_nms", type=int, default=3)
    return parser.parse_args()


def main():
    args = parse_args()
    scene_name = os.path.basename(args.dir_path)

    rgb_list = []
    sensor_depth_list = []

    img_path = os.path.join(args.dir_path, "img")
    img_list = sorted(glob.glob(os.path.join(img_path, "*.jpg")))
    img_list += sorted(glob.glob(os.path.join(img_path, "*.png")))

    # NOTE Mono is inverse depth, but metric-depth is depth!
    mono_disp_paths = sorted(
        glob.glob(os.path.join(args.dir_path, "depth-anything", "*.npy"))
    )
    metric_depth_paths = sorted(
        glob.glob(os.path.join(args.dir_path, "unidepth", "*.npz"))
    )

    img_0 = cv2.imread(img_list[0])

    scales = []
    shifts = []
    mono_disp_list = []
    fovs = []
    for t, (mono_disp_file, metric_depth_file) in enumerate(
        zip(mono_disp_paths, metric_depth_paths)
    ):
        da_disp = np.float32(np.load(mono_disp_file))  # / 300.0
        uni_data = np.load(metric_depth_file)
        metric_depth = uni_data["depth"]
        fovs.append(uni_data["fov"])

        da_disp = cv2.resize(
            da_disp,
            (metric_depth.shape[1], metric_depth.shape[0]),
            interpolation=cv2.INTER_NEAREST_EXACT,
        )
        mono_disp_list.append(da_disp)
        gt_disp = 1.0 / (metric_depth + 1e-8)

        # avoid some bug from UniDepth
        valid_mask = (metric_depth < 2.0) & (da_disp < 0.02)
        gt_disp[valid_mask] = 1e-2

        # avoid cases where the sky dominates the entire video
        sky_ratio = np.sum(da_disp < 0.01) / (da_disp.shape[0] * da_disp.shape[1])
        if sky_ratio > 0.5:
            non_sky_mask = da_disp > 0.01
            gt_disp_ms = gt_disp[non_sky_mask] - np.median(gt_disp[non_sky_mask]) + 1e-8
            da_disp_ms = da_disp[non_sky_mask] - np.median(da_disp[non_sky_mask]) + 1e-8
            scale = np.median(gt_disp_ms / da_disp_ms)
            shift = np.median(gt_disp[non_sky_mask] - scale * da_disp[non_sky_mask])
        else:
            gt_disp_ms = gt_disp - np.median(gt_disp) + 1e-8
            da_disp_ms = da_disp - np.median(da_disp) + 1e-8
            scale = np.median(gt_disp_ms / da_disp_ms)
            shift = np.median(gt_disp - scale * da_disp)

        gt_disp_ms = gt_disp - np.median(gt_disp) + 1e-8
        da_disp_ms = da_disp - np.median(da_disp) + 1e-8
        scale = np.median(gt_disp_ms / da_disp_ms)
        shift = np.median(gt_disp - scale * da_disp)

        scales.append(scale)
        shifts.append(shift)

    print("************** UNIDEPTH FOV ", np.median(fovs))
    ff = img_0.shape[1] / (2 * np.tan(np.radians(np.median(fovs) / 2.0)))

    K = np.eye(3)
    K[0, 0] = ff * 1.0  # pp_intrinsic[0] * (img_0.shape[1] / (pp_intrinsic[1] * 2))
    K[1, 1] = ff * 1.0  # pp_intrinsic[0] * (img_0.shape[0] / (pp_intrinsic[2] * 2))
    K[0, 2] = img_0.shape[1] / 2.0  # pp_intrinsic[1] * (img_0.shape[1] / (pp_intrinsic[1] * 2))
    K[1, 2] = img_0.shape[0] / 2.0  # pp_intrinsic[2] * (img_0.shape[0] / (pp_intrinsic[2] * 2))

    ss_product = np.array(scales) * np.array(shifts)
    med_idx = np.argmin(np.abs(ss_product - np.median(ss_product)))
    align_scale = scales[med_idx]  # np.median(np.array(scales))
    align_shift = shifts[med_idx]  # np.median(np.array(shifts))
    normalize_scale = (
        np.percentile((align_scale * np.array(mono_disp_list) + align_shift), 98) / 2.0
    )
    aligns = (align_scale, align_shift, normalize_scale)

    for t, image, depth, intrinsics, mask in tqdm(
        image_stream(
            img_list,
            mono_disp_list,
            scene_name,
            use_depth=True,
            aligns=aligns,
            K=K,
        )
    ):
        rgb_list.append(image[0])
        sensor_depth_list.append(depth)
        # breakpoint()
        if t == 0:
            args.image_size = [image.shape[2], image.shape[3]]
            droid = Droid(args, device=0)

        droid.track(t, image, depth, intrinsics=intrinsics, mask=mask)

    # last frame
    droid.track_final(t, image, depth, intrinsics=intrinsics, mask=mask)

    traj_est, depth_est, motion_prob = droid.terminate(
        image_stream(
            img_list,
            mono_disp_list,
            scene_name,
            use_depth=True,
            aligns=aligns,
            K=K,
        ),
        _opt_intr=True,  # default is opt_focal
        full_ba=True,
        scene_name=scene_name,
    )

    save_full_reconstruction(
        droid,
        traj_est,
        rgb_list,
        sensor_depth_list,
        motion_prob,
        args.scene_name,
        os.path.join(args.dir_path, "reconstructions"),
    )


if __name__ == "__main__":
    main()


================================================
FILE: camera_pose_annotation/camera_tracking/inference_batch.py
================================================
"""
Batch inference for camera tracking using multiple GPUs.

This module provides functionality for:
- Parallel camera tracking processing across multiple videos
- Multi-GPU support with automatic device assignment
- Subprocess management for camera tracking pipeline
- Progress tracking and error handling
"""

import pandas as pd
import os
import argparse
import concurrent.futures
from multiprocessing import Manager
import subprocess
import queue
from tqdm import tqdm


def process_single_row(row, index, args, worker_id=0):
    """
    Process a single video for camera tracking.
    """
    dir_path = os.path.join(args.dir_path, row["id"])
    device_id = worker_id % args.gpu_num

    cmd = (
        f"CUDA_VISIBLE_DEVICES={args.gpu_id[device_id]} python camera_pose_annotation/camera_tracking/camera_tracking.py "
        f"--dir_path {dir_path} "
        f"--weights {args.checkpoints_path}/megasam_final.pth "
        f"--disable_vis"
    )

    process = subprocess.Popen(
        cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    stdout, stderr = process.communicate()
    if process.returncode != 0:
        print(f"Error tracking camera for {row['id']}: {stderr.decode()}")


def worker(task_queue, args, worker_id, pbar):
    """
    Worker function for parallel camera tracking processing.
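
    Each worker repeatedly pulls an (index, row) task from the shared queue until
    the queue is empty. Every task launches camera_tracking.py as a subprocess
    pinned to a single GPU, chosen as gpu_id[worker_id % gpu_num], and the shared
    progress bar is advanced once per finished clip.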
""" while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break process_single_row(row, index, args, worker_id) task_queue.task_done() pbar.update(1) def parse_args(): """Parse command line arguments for camera tracking batch inference.""" parser = argparse.ArgumentParser() parser.add_argument("--csv_path", type=str, help="Path to the csv file") parser.add_argument("--dir_path", type=str, default="./outputs") parser.add_argument("--checkpoints_path", type=str, default="./checkpoints") parser.add_argument( "--gpu_id", type=str, default="0", help="Comma-separated list of GPU IDs to use" ) parser.add_argument( "--num_workers", type=int, default=4, help="Number of workers for parallel processing", ) parser.add_argument( "--disable_parallel", action="store_true", help="Disable parallel processing" ) return parser.parse_args() def main(): args = parse_args() # Parse GPU configuration args.gpu_num = len(args.gpu_id.split(",")) args.gpu_id = [int(gpu) for gpu in args.gpu_id.split(",")] df = pd.read_csv(args.csv_path) if args.disable_parallel: # Sequential processing for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"): process_single_row(row, index, args) else: # Parallel processing with multiple workers manager = Manager() task_queue = manager.Queue() # Add all tasks to queue for index, row in df.iterrows(): task_queue.put((index, row)) with tqdm(total=len(df), desc="Processing rows") as pbar: with concurrent.futures.ThreadPoolExecutor( max_workers=args.num_workers ) as executor: futures = [] for id in range(args.num_workers): futures.append(executor.submit(worker, task_queue, args, id, pbar)) for future in concurrent.futures.as_completed(futures): future.result() if __name__ == "__main__": main() ================================================ FILE: camera_pose_annotation/cvd_opt/__init__.py ================================================ ================================================ FILE: camera_pose_annotation/cvd_opt/cvd_opt.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Consistent video depth optimization.""" # pylint: disable=invalid-name # pylint: disable=g-importing-member # pylint: disable=redefined-outer-name import argparse import os from pathlib import Path import pandas as pd from geometry_utils import NormalGenerator import kornia from lietorch import SE3 import numpy as np import torch import zipfile import tempfile import OpenEXR import Imath def save_depth(path, depths): with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as z: for index, depth in enumerate(depths): height, width = depth.shape header = OpenEXR.Header(width, height) header["channels"] = {"Z": Imath.Channel(Imath.PixelType(Imath.PixelType.HALF))} with tempfile.NamedTemporaryFile(suffix=".exr") as f: exr = OpenEXR.OutputFile(f.name, header) exr.writePixels({"Z": depth.astype(np.float16).tobytes()}) exr.close() z.write(f.name, f"{index:05d}.exr") def gradient_loss(gt, pred, u): """Gradient loss.""" del u diff = pred - gt v_gradient = torch.abs(diff[..., 0:-2, 1:-1] - diff[..., 2:, 1:-1]) # * mask_v h_gradient = torch.abs(diff[..., 1:-1, 0:-2] - diff[..., 1:-1, 2:]) # * mask_h pred_grad = torch.abs(pred[..., 0:-2, 1:-1] - (pred[..., 2:, 1:-1])) + torch.abs( pred[..., 1:-1, 0:-2] - pred[..., 1:-1, 2:] ) gt_grad = torch.abs(gt[..., 0:-2, 1:-1] - (gt[..., 2:, 1:-1])) + torch.abs( gt[..., 1:-1, 0:-2] - gt[..., 1:-1, 2:] ) grad_diff = torch.abs(pred_grad - gt_grad) nearby_mask = (torch.exp(gt[..., 1:-1, 1:-1]) > 1.0).float().detach() # weight = (1. - torch.exp(-(grad_diff * 5.)).detach()) weight = 1.0 - torch.exp(-(grad_diff * 5.0)).detach() weight *= nearby_mask g_loss = torch.mean(h_gradient * weight) + torch.mean(v_gradient * weight) return g_loss def si_loss(gt, pred): log_gt = torch.log(torch.clamp(gt, 1e-3, 1e3)).view(gt.shape[0], -1) log_pred = torch.log(torch.clamp(pred, 1e-3, 1e3)).view(pred.shape[0], -1) log_diff = log_gt - log_pred num_pixels = gt.shape[-2] * gt.shape[-1] data_loss = torch.sum(log_diff**2, dim=-1) / num_pixels - torch.sum( log_diff, dim=-1 ) ** 2 / (num_pixels**2) return torch.mean(data_loss) def sobel_fg_alpha(disp, mode="sobel", beta=10.0): sobel_grad = kornia.filters.spatial_gradient(disp, mode=mode, normalized=False) sobel_mag = torch.sqrt( sobel_grad[:, :, 0, Ellipsis] ** 2 + sobel_grad[:, :, 1, Ellipsis] ** 2 ) alpha = torch.exp(-1.0 * beta * sobel_mag).detach() return alpha ALPHA_MOTION = 0.25 RESIZE_FACTOR = 0.5 def consistency_loss( cam_c2w, K, K_inv, disp_data, init_disp, uncertainty, flows, flow_masks, ii, jj, compute_normals, fg_alpha, w_ratio=1.0, w_flow=0.2, w_si=1.0, w_grad=2.0, w_normal=4.0, ): """Consistency loss.""" _, H, W = disp_data.shape # mesh grid xx = torch.arange(0, W).view(1, -1).repeat(H, 1) yy = torch.arange(0, H).view(-1, 1).repeat(1, W) xx = xx.view(1, 1, H, W) # .repeat(B ,1 ,1 ,1) yy = yy.view(1, 1, H, W) # .repeat(B ,1 ,1 ,1) grid = torch.cat((xx, yy), 1).float().cuda().permute(0, 2, 3, 1) # [None, ...] loss_flow = 0.0 # flow reprojection loss loss_d_ratio = 0.0 # depth consistency loss flows_step = flows.permute(0, 2, 3, 1) flow_masks_step = flow_masks.permute(0, 2, 3, 1).squeeze(-1) cam_1to2 = torch.bmm( torch.linalg.inv(torch.index_select(cam_c2w, dim=0, index=jj)), torch.index_select(cam_c2w, dim=0, index=ii), ) # warp disp from target time pixel_locations = grid + flows_step resize_factor = torch.tensor([W - 1.0, H - 1.0]).cuda()[None, None, None, ...] 
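    # grid_sample expects sampling locations normalised to [-1, 1]; with
    # align_corners=True, -1 corresponds to pixel column 0 and +1 to column
    # W - 1 (likewise for rows). The next line applies exactly that mapping to
    # the flow-advected pixel locations, e.g. for W = 320 a target column of
    # 319 maps to 2 * (319 / 319) - 1 = +1.0.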
normalized_pixel_locations = 2 * (pixel_locations / resize_factor) - 1.0 disp_sampled = torch.nn.functional.grid_sample( torch.index_select(disp_data, dim=0, index=jj)[:, None, ...], normalized_pixel_locations, align_corners=True, ) uu = torch.index_select(uncertainty, dim=0, index=ii).squeeze(1) grid_h = torch.cat([grid, torch.ones_like(grid[..., 0:1])], dim=-1).unsqueeze(-1) # depth of reference view ref_depth = 1.0 / torch.clamp( torch.index_select(disp_data, dim=0, index=ii), 1e-3, 1e3 ) pts_3d_ref = ref_depth[..., None, None] * (K_inv[None, None, None] @ grid_h) rot = cam_1to2[:, None, None, :3, :3] trans = cam_1to2[:, None, None, :3, 3:4] pts_3d_tgt = (rot @ pts_3d_ref) + trans # [:, None, None, :, None] depth_tgt = pts_3d_tgt[:, :, :, 2:3, 0] disp_tgt = 1.0 / torch.clamp(depth_tgt, 0.1, 1e3) # flow consistency loss pts_2D_tgt = K[None, None, None] @ pts_3d_tgt flow_masks_step_ = flow_masks_step * (pts_2D_tgt[:, :, :, 2, 0] > 0.1) pts_2D_tgt = pts_2D_tgt[:, :, :, :2, 0] / torch.clamp( pts_2D_tgt[:, :, :, 2:, 0], 1e-3, 1e3 ) disp_sampled = torch.clamp(disp_sampled, 1e-3, 1e2) disp_tgt = torch.clamp(disp_tgt, 1e-3, 1e2) ratio = torch.maximum( disp_sampled.squeeze() / disp_tgt.squeeze(), disp_tgt.squeeze() / disp_sampled.squeeze(), ) ratio_error = torch.abs(ratio - 1.0) # loss_d_ratio += torch.sum( (ratio_error * uu + ALPHA_MOTION * torch.log(1.0 / uu)) * flow_masks_step_ ) / (torch.sum(flow_masks_step_) + 1e-8) flow_error = torch.abs(pts_2D_tgt - pixel_locations) loss_flow += torch.sum( (flow_error * uu[..., None] + ALPHA_MOTION * torch.log(1.0 / uu[..., None])) * flow_masks_step_[..., None] ) / (torch.sum(flow_masks_step_) * 2.0 + 1e-8) # prior mono-depth reg loss loss_prior = si_loss(init_disp, disp_data) KK = torch.inverse(K_inv) # multi gradient consistency disp_data_ds = disp_data[:, None, ...] init_disp_ds = init_disp[:, None, ...] 
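    # Surface-normal consistency term: NormalGenerator unprojects each depth map
    # with the inverse intrinsics, estimates normals from spatial gradients, and
    # the loss below averages 1 - <n_pred, n_init> weighted by fg_alpha, which
    # down-weights pixels near strong disparity edges (see sobel_fg_alpha above).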
K_rescale = KK.clone() K_inv_rescale = torch.inverse(K_rescale) pred_normal = compute_normals[0]( 1.0 / torch.clamp(disp_data_ds, 1e-3, 1e3), K_inv_rescale[None] ) init_normal = compute_normals[0]( 1.0 / torch.clamp(init_disp_ds, 1e-3, 1e3), K_inv_rescale[None] ) loss_normal = torch.mean( fg_alpha * (1.0 - torch.sum(pred_normal * init_normal, dim=1)) ) # / (1e-8 + torch.sum(fg_alpha)) loss_grad = 0.0 for scale in range(4): interval = 2**scale disp_data_ds = torch.nn.functional.interpolate( disp_data[:, None, ...], scale_factor=(1.0 / interval, 1.0 / interval), mode="nearest-exact", ) init_disp_ds = torch.nn.functional.interpolate( init_disp[:, None, ...], scale_factor=(1.0 / interval, 1.0 / interval), mode="nearest-exact", ) uncertainty_rs = torch.nn.functional.interpolate( uncertainty, scale_factor=(1.0 / interval, 1.0 / interval), mode="nearest-exact", ) loss_grad += gradient_loss( torch.log(disp_data_ds), torch.log(init_disp_ds), uncertainty_rs ) return ( w_ratio * loss_d_ratio + w_si * loss_prior + w_flow * loss_flow + w_normal * loss_normal + loss_grad * w_grad ) def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--w_grad", type=float, default=2.0, help="w_grad") parser.add_argument("--w_normal", type=float, default=6.0, help="w_normal") parser.add_argument("--dir_path", type=str, default=".", help="directory path") parser.add_argument("--only_depth", action="store_true", help="only save optimize depth") return parser.parse_args() if __name__ == "__main__": args = parse_args() scene_name = os.path.basename(args.dir_path) cache_dir = os.path.join(args.dir_path, "cache-flow") rootdir = os.path.join(args.dir_path, "reconstructions") print("***************************** ", scene_name) img_data = np.load(os.path.join(rootdir, "images.npy"))[:, ::-1, ...] 
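    # Inputs from the earlier pipeline stages: images.npy holds (N, 3, H, W)
    # frames whose channel axis is reversed here (BGR <-> RGB), disps.npy the
    # per-frame disparities, intrinsics.npy rows of [fx, fy, cx, cy], poses.npy
    # the camera poses (lietorch SE3 parameters), and motion_prob.npy a
    # per-pixel motion probability; cache-flow/ provides the precomputed
    # pairwise optical flows, their validity masks, and the (ii, jj) frame-pair
    # indices.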
disp_data = np.load(os.path.join(rootdir, "disps.npy")) + 1e-6 intrinsics = np.load(os.path.join(rootdir, "intrinsics.npy")) poses = np.load(os.path.join(rootdir, "poses.npy")) mot_prob = np.load(os.path.join(rootdir, "motion_prob.npy")) flows = np.load(os.path.join(cache_dir, "flows.npy"), allow_pickle=True) flow_masks = np.load(os.path.join(cache_dir, "flows_masks.npy"), allow_pickle=True) flow_masks = np.float32(flow_masks) iijj = np.load(os.path.join(cache_dir, "ii-jj.npy"), allow_pickle=True) intrinsics = intrinsics[0] poses_th = torch.as_tensor(poses, device="cpu").float().cuda() K = np.eye(3) K[0, 0] = intrinsics[0] K[1, 1] = intrinsics[1] K[0, 2] = intrinsics[2] K[1, 2] = intrinsics[3] img_data_pt = ( torch.from_numpy(np.ascontiguousarray(img_data)).float().cuda() / 255.0 ) flows = torch.from_numpy(np.ascontiguousarray(flows)).float().cuda() flow_masks = ( torch.from_numpy(np.ascontiguousarray(flow_masks)).float().cuda() ) # .unsqueeze(1) iijj = torch.from_numpy(np.ascontiguousarray(iijj)).float().cuda() ii = iijj[0, ...].long() jj = iijj[1, ...].long() K = torch.from_numpy(K).float().cuda() init_disp = torch.from_numpy(disp_data).float().cuda() disp_data = torch.from_numpy(disp_data).float().cuda() assert init_disp.shape == disp_data.shape init_disp = torch.nn.functional.interpolate( init_disp.unsqueeze(1), scale_factor=(RESIZE_FACTOR, RESIZE_FACTOR), mode="bilinear", ).squeeze(1) disp_data = torch.nn.functional.interpolate( disp_data.unsqueeze(1), scale_factor=(RESIZE_FACTOR, RESIZE_FACTOR), mode="bilinear", ).squeeze(1) fg_alpha = sobel_fg_alpha(init_disp[:, None, ...]) > 0.2 fg_alpha = fg_alpha.squeeze(1).float() + 0.2 cvd_prob = torch.nn.functional.interpolate( torch.from_numpy(mot_prob).unsqueeze(1).cuda(), scale_factor=(4, 4), mode="bilinear", ) cvd_prob[cvd_prob > 0.5] = 0.5 cvd_prob = torch.clamp(cvd_prob, 1e-3, 1.0) # rescale intrinsic matrix to small resolution K_o = K.clone() K[0:2, ...] 
*= RESIZE_FACTOR K_inv = torch.linalg.inv(K) disp_data.requires_grad = False poses_th.requires_grad = False uncertainty = cvd_prob # First optimize scale and shift to align them log_scale_ = torch.log(torch.ones(init_disp.shape[0]).to(disp_data.device)) shift_ = torch.zeros(init_disp.shape[0]).to(disp_data.device) log_scale_.requires_grad = True shift_.requires_grad = True uncertainty.requires_grad = True optim = torch.optim.Adam( [ {"params": log_scale_, "lr": 1e-2}, {"params": shift_, "lr": 1e-2}, {"params": uncertainty, "lr": 1e-2}, ] ) compute_normals = [] compute_normals.append(NormalGenerator(disp_data.shape[-2], disp_data.shape[-1])) init_disp = torch.clamp(init_disp, 1e-3, 1e3) for i in range(100): optim.zero_grad() cam_c2w = SE3(poses_th).inv().matrix() scale_ = torch.exp(log_scale_) loss = consistency_loss( cam_c2w, K, K_inv, torch.clamp( disp_data * scale_[..., None, None] + shift_[..., None, None], 1e-3, 1e3, ), init_disp, torch.clamp(uncertainty, 1e-4, 1e3), flows, flow_masks, ii, jj, compute_normals, fg_alpha, ) loss.backward() uncertainty.grad = torch.nan_to_num(uncertainty.grad, nan=0.0) log_scale_.grad = torch.nan_to_num(log_scale_.grad, nan=0.0) shift_.grad = torch.nan_to_num(shift_.grad, nan=0.0) optim.step() print("step ", i, loss.item()) # Then optimize depth and uncertainty disp_data = ( disp_data * torch.exp(log_scale_)[..., None, None].detach() + shift_[..., None, None].detach() ) init_disp = ( init_disp * torch.exp(log_scale_)[..., None, None].detach() + shift_[..., None, None].detach() ) init_disp = torch.clamp(init_disp, 1e-3, 1e3) disp_data.requires_grad = True uncertainty.requires_grad = True poses_th.requires_grad = False # True optim = torch.optim.Adam( [ {"params": disp_data, "lr": 5e-3}, {"params": uncertainty, "lr": 5e-3}, ] ) losses = [] for i in range(400): optim.zero_grad() cam_c2w = SE3(poses_th).inv().matrix() loss = consistency_loss( cam_c2w, K, K_inv, torch.clamp(disp_data, 1e-3, 1e3), init_disp, torch.clamp(uncertainty, 1e-4, 1e3), flows, flow_masks, ii, jj, compute_normals, fg_alpha, w_ratio=1.0, w_flow=0.2, w_si=1, w_grad=args.w_grad, w_normal=args.w_normal, ) loss.backward() disp_data.grad = torch.nan_to_num(disp_data.grad, nan=0.0) uncertainty.grad = torch.nan_to_num(uncertainty.grad, nan=0.0) optim.step() print("step ", i, loss.item()) losses.append(loss) disp_data_opt = ( torch.nn.functional.interpolate( disp_data.unsqueeze(1), scale_factor=(2, 2), mode="bilinear" ) .squeeze(1) .detach() .cpu() .numpy() ) if args.only_depth: save_depth( os.path.join(args.dir_path, "depth_opt.zip"), disp_data_opt ) else: np.savez( os.path.join(args.dir_path, "sgd_cvd_hr.npz"), images=np.uint8(img_data_pt.cpu().numpy().transpose(0, 2, 3, 1) * 255.0), depths=np.clip(np.float16(1.0 / disp_data_opt), 1e-3, 1e2), intrinsic=K_o.detach().cpu().numpy(), cam_c2w=cam_c2w.detach().cpu().numpy(), ) ================================================ FILE: camera_pose_annotation/cvd_opt/geometry_utils.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Geometry utils for MegaSaM.""" # pylint: disable=invalid-name import kornia import numpy as np import torch from torch import jit from torch import nn from torch import Tensor # pylint: disable=g-importing-member import torch.nn.functional as F @torch.jit.script def to_homogeneous(input_tensor: Tensor, dim: int = 0) -> Tensor: """Converts tensor to homogeneous coordinates by adding ones to the specified dimension.""" ones = torch.ones_like(input_tensor.select(dim, 0).unsqueeze(dim)) output_bkn = torch.cat([input_tensor, ones], dim=dim) return output_bkn class BackprojectDepth(nn.Module): """Layer that projects points from 2D camera to 3D space. The 3D points are represented in homogeneous coordinates. """ def __init__(self, height: int, width: int): super().__init__() self.height = height self.width = width xx, yy = torch.meshgrid( torch.arange(self.width), torch.arange(self.height), indexing="xy", ) pix_coords_2hw = torch.stack((xx, yy), axis=0) + 0.5 pix_coords_13N = ( to_homogeneous( pix_coords_2hw, dim=0, ) .flatten(1) .unsqueeze(0) ) # make these tensors into buffers so they are put on the correct GPU # automatically self.register_buffer("pix_coords_13N", pix_coords_13N) # @jit.script_method def forward(self, depth_b1hw: Tensor, invK_b44: Tensor) -> Tensor: """Backprojects spatial points in 2D image space to world space using invK_b44 at the depths defined in depth_b1hw.""" cam_points_b3N = torch.matmul( invK_b44[:, :3, :3], self.pix_coords_13N.float().cuda() ) cam_points_b3N = depth_b1hw.flatten(start_dim=2) * cam_points_b3N cam_points_b4N = to_homogeneous(cam_points_b3N, dim=1) return cam_points_b4N class Project3D(jit.ScriptModule): """Layer that projects 3D points into the 2D camera.""" def __init__(self, eps: float = 1e-8): super().__init__() self.register_buffer("eps", torch.tensor(eps).view(1, 1, 1)) @jit.script_method def forward( self, points_b4N: Tensor, K_b44: Tensor, cam_T_world_b44: Tensor ) -> Tensor: """Projects spatial points in 3D world space to camera image space using the extrinsics matrix cam_T_world_b44 and intrinsics K_b44.""" P_b44 = K_b44 @ cam_T_world_b44 cam_points_b3N = P_b44[:, :3] @ points_b4N # from Kornia and OpenCV: # https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/conversions.html#convert_points_from_homogeneous mask = torch.abs(cam_points_b3N[:, 2:]) > self.eps depth_b1N = cam_points_b3N[:, 2:] + self.eps scale = torch.where( mask, 1.0 / depth_b1N, torch.tensor(1.0, device=depth_b1N.device) ) pix_coords_b2N = cam_points_b3N[:, :2] * scale return torch.cat([pix_coords_b2N, depth_b1N], dim=1) class NormalGenerator(nn.Module): """Estimates normals from depth maps.""" def __init__( self, height: int, width: int, smoothing_kernel_size: int = 5, smoothing_kernel_std: float = 2.0, ): """Estimates normals from depth maps.""" super().__init__() self.height = height self.width = width self.backproject = BackprojectDepth(self.height, self.width) self.kernel_size = smoothing_kernel_size self.std = smoothing_kernel_std # @jit.script_method def forward(self, depth_b1hw: Tensor, invK_b44: Tensor) -> Tensor: """Estimates a normal at each location in the depth map.""" # First smoothes incoming depth maps with a gaussian blur, backprojects # those depth points into world space (see BackprojectDepth), estimates # the spatial gradient at those points, and finally uses 
normalized cross # correlation to estimate a normal vector at each location. depth_smooth_b1hw = kornia.filters.gaussian_blur2d( depth_b1hw, (self.kernel_size, self.kernel_size), (self.std, self.std), ) cam_points_b4N = self.backproject(depth_smooth_b1hw, invK_b44) cam_points_b3hw = cam_points_b4N[:, :3].view(-1, 3, self.height, self.width) gradients_b32hw = kornia.filters.spatial_gradient(cam_points_b3hw) return F.normalize( torch.cross( gradients_b32hw[:, :, 0], gradients_b32hw[:, :, 1], dim=1, ), dim=1, ) def get_camera_rays( world_T_cam_b44, world_points_b3N, in_camera_frame, cam_T_world_b44=None, eps=1e-4, ): """Computes camera rays for given camera data and points, optionally shifts rays to camera frame.""" del eps if in_camera_frame: batch_size = world_points_b3N.shape[0] num_points = world_points_b3N.shape[2] world_points_b4N = torch.cat( [ world_points_b3N, torch.ones(batch_size, 1, num_points).to(world_points_b3N.device), ], 1, ) camera_points_b3N = torch.matmul( cam_T_world_b44[:, :3, :4], world_points_b4N ) rays_b3N = camera_points_b3N else: rays_b3N = world_points_b3N - world_T_cam_b44[:, 0:3, 3][:, :, None].expand( world_points_b3N.shape ) rays_b3N = torch.nn.functional.normalize(rays_b3N, dim=1) return rays_b3N def pose_distance(pose_b44): """DVMVS frame pose distance.""" R = pose_b44[:, :3, :3] t = pose_b44[:, :3, 3] R_trace = R.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1) R_measure = torch.sqrt( 2 * (1 - torch.minimum(torch.ones_like(R_trace) * 3.0, R_trace) / 3) ) t_measure = torch.norm(t, dim=1) combined_measure = torch.sqrt(t_measure**2 + R_measure**2) return combined_measure, R_measure, t_measure def qvec2rotmat(qvec): """Quaternion to 3x3 rotation matrix.""" return np.array([ [ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], ], [ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], ], [ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, ], ]) def rotx(t): """3D Rotation about the x-axis.""" c = np.cos(t) s = np.sin(t) return np.array([[1, 0, 0], [0, c, -s], [0, s, c]]) def roty(t): """3D Rotation about the y-axis.""" c = np.cos(t) s = np.sin(t) return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]]) def rotz(t): """3D Rotation about the z-axis.""" c = np.cos(t) s = np.sin(t) return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) ================================================ FILE: camera_pose_annotation/cvd_opt/inference_batch.py ================================================ """ Batch inference script for CVD (Camera View Depth) optimization. Processes multiple video clips in parallel using multi-GPU setup. 
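
Example invocation (the CSV path is illustrative; all flags are defined in
parse_args below):

    python camera_pose_annotation/cvd_opt/inference_batch.py \
        --csv_path clips.csv --dir_path ./outputs --gpu_id 0,1 --num_workers 4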
""" import pandas as pd import os import argparse import concurrent.futures from multiprocessing import Manager import subprocess import queue from tqdm import tqdm def process_single_row(row, index, args, worker_id=0): """Process a single video clip for CVD optimization.""" dir_path = os.path.join(args.dir_path, row["id"]) device_id = worker_id % args.gpu_num # Build command for CVD optimization with specific GPU cmd = ( f"CUDA_VISIBLE_DEVICES={args.gpu_id[device_id]} python camera_pose_annotation/cvd_opt/cvd_opt.py " f"--dir_path {dir_path} " f"--w_grad 2.0 --w_normal 5.0 " ) if args.only_depth: cmd += "--only_depth " process = subprocess.Popen( cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) stdout, stderr = process.communicate() if process.returncode != 0: print(f"Error optimizing CVD for {row['id']}: {stderr.decode()}") def worker(task_queue, args, worker_id, pbar): """Worker function for parallel CVD optimization processing.""" while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break process_single_row(row, index, args, worker_id) task_queue.task_done() pbar.update(1) def parse_args(): """Parse command line arguments for CVD batch processing.""" parser = argparse.ArgumentParser() parser.add_argument("--csv_path", type=str, help="Path to the csv file") parser.add_argument("--dir_path", type=str, default="./outputs") parser.add_argument("--only_depth", action="store_true", help="Only save optimized depth") parser.add_argument( "--gpu_id", type=str, default="0", help="Comma-separated list of GPU IDs to use" ) parser.add_argument( "--num_workers", type=int, default=4, help="Number of workers for parallel processing", ) parser.add_argument( "--disable_parallel", action="store_true", help="Disable parallel processing" ) return parser.parse_args() def main(): args = parse_args() # Parse GPU configuration args.gpu_num = len(args.gpu_id.split(",")) args.gpu_id = [int(gpu) for gpu in args.gpu_id.split(",")] df = pd.read_csv(args.csv_path) if args.disable_parallel: # Sequential processing for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"): process_single_row(row, index, args) else: # Parallel processing with multiple workers manager = Manager() task_queue = manager.Queue() for index, row in df.iterrows(): task_queue.put((index, row)) with tqdm(total=len(df), desc="Processing rows") as pbar: with concurrent.futures.ThreadPoolExecutor( max_workers=args.num_workers ) as executor: futures = [] for id in range(args.num_workers): futures.append(executor.submit(worker, task_queue, args, id, pbar)) for future in concurrent.futures.as_completed(futures): future.result() if __name__ == "__main__": main() ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/__init__.py ================================================ ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/core/__init__.py ================================================ ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/core/corr.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Correlation block for MegaSaM.""" import torch import torch.nn.functional as F from .utils.utils import bilinear_sampler # pylint: disable=g-import-not-at-top try: import alt_cuda_corr except: # pylint: disable=bare-except # alt_cuda_corr is not compiled pass class CorrBlock: """Correlation block for MegaSaM.""" def __init__(self, fmap1, fmap2, num_levels=4, radius=4): self.num_levels = num_levels self.radius = radius self.corr_pyramid = [] # all pairs correlation corr = CorrBlock.corr(fmap1, fmap2) batch, h1, w1, dim, h2, w2 = corr.shape corr = corr.reshape(batch * h1 * w1, dim, h2, w2) self.corr_pyramid.append(corr) for _ in range(self.num_levels - 1): corr = F.avg_pool2d(corr, 2, stride=2) self.corr_pyramid.append(corr) def __call__(self, coords): r = self.radius coords = coords.permute(0, 2, 3, 1) batch, h1, w1, _ = coords.shape out_pyramid = [] for i in range(self.num_levels): corr = self.corr_pyramid[i] dx = torch.linspace(-r, r, 2 * r + 1) dy = torch.linspace(-r, r, 2 * r + 1) delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) coords_lvl = centroid_lvl + delta_lvl corr = bilinear_sampler(corr, coords_lvl) corr = corr.view(batch, h1, w1, -1) out_pyramid.append(corr) out = torch.cat(out_pyramid, dim=-1) return out.permute(0, 3, 1, 2).contiguous().float() @classmethod def corr(cls, fmap1, fmap2): del cls batch, dim, ht, wd = fmap1.shape fmap1 = fmap1.view(batch, dim, ht * wd) fmap2 = fmap2.view(batch, dim, ht * wd) corr = torch.matmul(fmap1.transpose(1, 2), fmap2) corr = corr.view(batch, ht, wd, 1, ht, wd) return corr / torch.sqrt(torch.tensor(dim).float()) class AlternateCorrBlock: """Correlation block for MegaSaM.""" def __init__(self, fmap1, fmap2, num_levels=4, radius=4): self.num_levels = num_levels self.radius = radius self.pyramid = [(fmap1, fmap2)] for _ in range(self.num_levels): fmap1 = F.avg_pool2d(fmap1, 2, stride=2) fmap2 = F.avg_pool2d(fmap2, 2, stride=2) self.pyramid.append((fmap1, fmap2)) def __call__(self, coords): coords = coords.permute(0, 2, 3, 1) # pylint: disable=invalid-name B, H, W, _ = coords.shape dim = self.pyramid[0][0].shape[1] corr_list = [] for i in range(self.num_levels): r = self.radius fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) corr_list.append(corr.squeeze(1)) corr = torch.stack(corr_list, dim=1) corr = corr.reshape(B, -1, H, W) return corr / torch.sqrt(torch.tensor(dim).float()) ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/core/datasets.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Dataset classes for MegaSaM.""" import glob import os import os.path as osp import random import numpy as np import torch from torch.utils import data from utils import frame_utils from utils.augmentor import FlowAugmentor from utils.augmentor import SparseFlowAugmentor class FlowDataset(data.Dataset): """Base class for flow datasets.""" def __init__(self, aug_params=None, sparse=False): self.augmentor = None self.sparse = sparse if aug_params is not None: if sparse: self.augmentor = SparseFlowAugmentor(**aug_params) else: self.augmentor = FlowAugmentor(**aug_params) self.is_test = False self.init_seed = False self.flow_list = [] self.image_list = [] self.extra_info = [] def __getitem__(self, index): if self.is_test: img1 = frame_utils.read_gen(self.image_list[index][0]) img2 = frame_utils.read_gen(self.image_list[index][1]) img1 = np.array(img1).astype(np.uint8)[..., :3] img2 = np.array(img2).astype(np.uint8)[..., :3] img1 = torch.from_numpy(img1).permute(2, 0, 1).float() img2 = torch.from_numpy(img2).permute(2, 0, 1).float() return img1, img2, self.extra_info[index] if not self.init_seed: worker_info = torch.utils.data.get_worker_info() if worker_info is not None: torch.manual_seed(worker_info.id) np.random.seed(worker_info.id) random.seed(worker_info.id) self.init_seed = True index = index % len(self.image_list) valid = None if self.sparse: flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) else: flow = frame_utils.read_gen(self.flow_list[index]) img1 = frame_utils.read_gen(self.image_list[index][0]) img2 = frame_utils.read_gen(self.image_list[index][1]) flow = np.array(flow).astype(np.float32) img1 = np.array(img1).astype(np.uint8) img2 = np.array(img2).astype(np.uint8) # grayscale images if len(img1.shape) == 2: img1 = np.tile(img1[..., None], (1, 1, 3)) img2 = np.tile(img2[..., None], (1, 1, 3)) else: img1 = img1[..., :3] img2 = img2[..., :3] if self.augmentor is not None: if self.sparse: img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) else: img1, img2, flow = self.augmentor(img1, img2, flow) img1 = torch.from_numpy(img1).permute(2, 0, 1).float() img2 = torch.from_numpy(img2).permute(2, 0, 1).float() flow = torch.from_numpy(flow).permute(2, 0, 1).float() if valid is not None: valid = torch.from_numpy(valid) else: valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) return img1, img2, flow, valid.float() def __rmul__(self, v): self.flow_list = v * self.flow_list self.image_list = v * self.image_list return self def __len__(self): return len(self.image_list) class MpiSintel(FlowDataset): """MpiSintel dataset.""" def __init__( self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean', ): super(MpiSintel, self).__init__(aug_params) flow_root = osp.join(root, split, 'flow') image_root = osp.join(root, split, dstype) if split == 'test': self.is_test = True for scene in os.listdir(image_root): image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) for i in range(len(image_list) - 1): self.image_list += [[image_list[i], image_list[i + 1]]] 
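                # consecutive frames (i, i + 1) of the same scene form one
                # optical-flow training sample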
self.extra_info += [(scene, i)] # scene and frame_id if split != 'test': self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) class FlyingChairs(FlowDataset): """FlyingChairs dataset.""" def __init__( self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data', ): super(FlyingChairs, self).__init__(aug_params) images = sorted(glob(osp.join(root, '*.ppm'))) flows = sorted(glob(osp.join(root, '*.flo'))) assert len(images) // 2 == len(flows) split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) for i in range(len(flows)): exid = split_list[i] if (split == 'training' and exid == 1) or ( split == 'validation' and exid == 2 ): self.flow_list += [flows[i]] self.image_list += [[images[2 * i], images[2 * i + 1]]] class FlyingThings3D(FlowDataset): """FlyingThings3D dataset.""" def __init__( self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass', ): super(FlyingThings3D, self).__init__(aug_params) for cam in ['left']: for direction in ['into_future', 'into_past']: image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) for idir, fdir in zip(image_dirs, flow_dirs): images = sorted(glob(osp.join(idir, '*.png'))) flows = sorted(glob(osp.join(fdir, '*.pfm'))) for i in range(len(flows) - 1): if direction == 'into_future': self.image_list += [[images[i], images[i + 1]]] self.flow_list += [flows[i]] elif direction == 'into_past': self.image_list += [[images[i + 1], images[i]]] self.flow_list += [flows[i + 1]] class KITTI(FlowDataset): """KITTI dataset.""" def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): super(KITTI, self).__init__(aug_params, sparse=True) if split == 'testing': self.is_test = True root = osp.join(root, split) images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) for img1, img2 in zip(images1, images2): frame_id = img1.split('/')[-1] self.extra_info += [[frame_id]] self.image_list += [[img1, img2]] if split == 'training': self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) class HD1K(FlowDataset): """HD1K dataset.""" def __init__(self, aug_params=None, root='datasets/HD1k'): super(HD1K, self).__init__(aug_params, sparse=True) seq_ix = 0 while 1: flows = sorted( glob( os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix) ) ) images = sorted( glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)) ) if not flows: break for i in range(len(flows) - 1): self.flow_list += [flows[i]] self.image_list += [[images[i], images[i + 1]]] seq_ix += 1 # pylint: disable=invalid-name def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): """Create the data loader for the corresponding training set.""" if args.stage == 'chairs': aug_params = { 'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True, } train_dataset = FlyingChairs(aug_params, split='training') elif args.stage == 'things': aug_params = { 'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True, } clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') train_dataset = clean_dataset + final_dataset elif args.stage == 'sintel': aug_params = { 'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': 
True, } things = FlyingThings3D(aug_params, dstype='frames_cleanpass') sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') sintel_final = MpiSintel(aug_params, split='training', dstype='final') if TRAIN_DS == 'C+T+K+S+H': kitti = KITTI({ 'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True, }) hd1k = HD1K({ 'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True, }) train_dataset = ( 100 * sintel_clean + 100 * sintel_final + 200 * kitti + 5 * hd1k + things ) elif TRAIN_DS == 'C+T+K/S': train_dataset = 100 * sintel_clean + 100 * sintel_final + things else: raise ValueError('Unknown split: %s' % TRAIN_DS) elif args.stage == 'kitti': aug_params = { 'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False, } train_dataset = KITTI(aug_params, split='training') else: raise ValueError('Unknown training set: %s' % args.stage) train_loader = data.DataLoader( train_dataset, batch_size=args.batch_size, pin_memory=False, shuffle=True, num_workers=4, drop_last=True, ) print('Training with %d image pairs' % len(train_dataset)) return train_loader ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/core/extractor.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Network layer classes for MegaSaM.""" import torch from torch import nn class ResidualBlock(nn.Module): """Residual block for MegaSaM.""" def __init__(self, in_planes, planes, norm_fn='group', stride=1): super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=3, padding=1, stride=stride ) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) self.relu = nn.ReLU(inplace=True) num_groups = planes // 8 if norm_fn == 'group': self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) if stride != 1: self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) elif norm_fn == 'batch': self.norm1 = nn.BatchNorm2d(planes) self.norm2 = nn.BatchNorm2d(planes) if stride != 1: self.norm3 = nn.BatchNorm2d(planes) elif norm_fn == 'instance': self.norm1 = nn.InstanceNorm2d(planes) self.norm2 = nn.InstanceNorm2d(planes) if stride != 1: self.norm3 = nn.InstanceNorm2d(planes) elif norm_fn == 'none': self.norm1 = nn.Sequential() self.norm2 = nn.Sequential() if stride != 1: self.norm3 = nn.Sequential() if stride == 1: self.downsample = None else: self.downsample = nn.Sequential( nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 ) def forward(self, x): y = x y = self.relu(self.norm1(self.conv1(y))) y = self.relu(self.norm2(self.conv2(y))) if self.downsample is not None: x = self.downsample(x) return self.relu(x + y) class BottleneckBlock(nn.Module): """Bottleneck block for MegaSaM.""" def __init__(self, in_planes, planes, norm_fn='group', stride=1): super(BottleneckBlock, self).__init__() self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) self.conv2 = nn.Conv2d( planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride ) self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) self.relu = nn.ReLU(inplace=True) num_groups = planes // 8 if norm_fn == 'group': self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) if stride != 1: self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) elif norm_fn == 'batch': self.norm1 = nn.BatchNorm2d(planes // 4) self.norm2 = nn.BatchNorm2d(planes // 4) self.norm3 = nn.BatchNorm2d(planes) if stride != 1: self.norm4 = nn.BatchNorm2d(planes) elif norm_fn == 'instance': self.norm1 = nn.InstanceNorm2d(planes // 4) self.norm2 = nn.InstanceNorm2d(planes // 4) self.norm3 = nn.InstanceNorm2d(planes) if stride != 1: self.norm4 = nn.InstanceNorm2d(planes) elif norm_fn == 'none': self.norm1 = nn.Sequential() self.norm2 = nn.Sequential() self.norm3 = nn.Sequential() if stride != 1: self.norm4 = nn.Sequential() if stride == 1: self.downsample = None else: self.downsample = nn.Sequential( nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4 ) def forward(self, x): y = x y = self.relu(self.norm1(self.conv1(y))) y = self.relu(self.norm2(self.conv2(y))) y = self.relu(self.norm3(self.conv3(y))) if self.downsample is not None: x = self.downsample(x) return self.relu(x + y) class BasicEncoder(nn.Module): """Basic encoder for MegaSaM.""" def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): super(BasicEncoder, self).__init__() self.norm_fn = norm_fn if self.norm_fn == 'group': self.norm1 = 
nn.GroupNorm(num_groups=8, num_channels=64) elif self.norm_fn == 'batch': self.norm1 = nn.BatchNorm2d(64) elif self.norm_fn == 'instance': self.norm1 = nn.InstanceNorm2d(64) elif self.norm_fn == 'none': self.norm1 = nn.Sequential() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) self.relu1 = nn.ReLU(inplace=True) self.in_planes = 64 self.layer1 = self._make_layer(64, stride=1) self.layer2 = self._make_layer(96, stride=2) self.layer3 = self._make_layer(128, stride=2) # output convolution self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) self.dropout = None if dropout > 0: self.dropout = nn.Dropout2d(p=dropout) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): if m.weight is not None: nn.init.constant_(m.weight, 1) if m.bias is not None: nn.init.constant_(m.bias, 0) def _make_layer(self, dim, stride=1): layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) layers = (layer1, layer2) self.in_planes = dim return nn.Sequential(*layers) def forward(self, x): # if input is list, combine batch dimension is_list = isinstance(x, tuple) or isinstance(x, list) if is_list: batch_dim = x[0].shape[0] x = torch.cat(x, dim=0) x = self.conv1(x) x = self.norm1(x) x = self.relu1(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.conv2(x) if self.training and self.dropout is not None: x = self.dropout(x) if is_list: x = torch.split(x, [batch_dim, batch_dim], dim=0) # pylint: disable=undefined-variable return x class SmallEncoder(nn.Module): """Small encoder for MegaSaM.""" def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): super(SmallEncoder, self).__init__() self.norm_fn = norm_fn if self.norm_fn == 'group': self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) elif self.norm_fn == 'batch': self.norm1 = nn.BatchNorm2d(32) elif self.norm_fn == 'instance': self.norm1 = nn.InstanceNorm2d(32) elif self.norm_fn == 'none': self.norm1 = nn.Sequential() self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) self.relu1 = nn.ReLU(inplace=True) self.in_planes = 32 self.layer1 = self._make_layer(32, stride=1) self.layer2 = self._make_layer(64, stride=2) self.layer3 = self._make_layer(96, stride=2) self.dropout = None if dropout > 0: self.dropout = nn.Dropout2d(p=dropout) self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): if m.weight is not None: nn.init.constant_(m.weight, 1) if m.bias is not None: nn.init.constant_(m.bias, 0) def _make_layer(self, dim, stride=1): layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) layers = (layer1, layer2) self.in_planes = dim return nn.Sequential(*layers) def forward(self, x): # if input is list, combine batch dimension is_list = isinstance(x, tuple) or isinstance(x, list) if is_list: batch_dim = x[0].shape[0] x = torch.cat(x, dim=0) x = self.conv1(x) x = self.norm1(x) x = self.relu1(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.conv2(x) if self.training and self.dropout is not None: x = self.dropout(x) if is_list: x = torch.split(x, [batch_dim, batch_dim], dim=0) # pylint: disable=undefined-variable 
return x ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/core/raft.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """RAFT network for MegaSaM.""" from .corr import AlternateCorrBlock from .corr import CorrBlock from .extractor import BasicEncoder from .extractor import SmallEncoder import torch from torch import nn import torch.nn.functional as F from .update import BasicUpdateBlock from .update import SmallUpdateBlock from .utils.utils import coords_grid from .utils.utils import upflow8 try: autocast = torch.cuda.amp.autocast except: # pylint: disable=bare-except # dummy autocast for PyTorch < 1.6 class autocast: # pylint: disable=invalid-name def __init__(self, enabled): pass def __enter__(self): pass def __exit__(self, *args): pass class RAFT(nn.Module): """RAFT network for MegaSaM.""" def __init__(self, args): super(RAFT, self).__init__() self.args = args self.mixed_precision = True if args.small: self.hidden_dim = hdim = 96 self.context_dim = cdim = 64 args.corr_levels = 4 args.corr_radius = 3 else: self.hidden_dim = hdim = 128 self.context_dim = cdim = 128 args.corr_levels = 4 args.corr_radius = 4 if 'dropout' not in self.args: self.args.dropout = 0 if 'alternate_corr' not in self.args: self.args.alternate_corr = False # feature network, context network, and update block if args.small: self.fnet = SmallEncoder( output_dim=128, norm_fn='instance', dropout=args.dropout ) self.cnet = SmallEncoder( output_dim=hdim + cdim, norm_fn='none', dropout=args.dropout ) self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) else: self.fnet = BasicEncoder( output_dim=256, norm_fn='instance', dropout=args.dropout ) self.cnet = BasicEncoder( output_dim=hdim + cdim, norm_fn='batch', dropout=args.dropout ) self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) def freeze_bn(self): for m in self.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() def initialize_flow(self, img): """Flow is represented as difference between two coordinate grids flow = coords1 - coords0.""" # pylint: disable=invalid-name N, _, H, W = img.shape coords0 = coords_grid(N, H // 8, W // 8).to(img.device) coords1 = coords_grid(N, H // 8, W // 8).to(img.device) # optical flow computed as difference: flow = coords1 - coords0 return coords0, coords1 def upsample_flow(self, flow, mask): """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination.""" # pylint: disable=invalid-name N, _, H, W = flow.shape mask = mask.view(N, 1, 9, 8, 8, H, W) mask = torch.softmax(mask, dim=2) up_flow = F.unfold(8 * flow, [3, 3], padding=1) up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) up_flow = torch.sum(mask * up_flow, dim=2) up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) return up_flow.reshape(N, 2, 8 * H, 8 * W) def forward( self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False, ): 
"""Estimate optical flow between pair of frames.""" image1 = 2 * (image1 / 255.0) - 1.0 image2 = 2 * (image2 / 255.0) - 1.0 image1 = image1.contiguous() image2 = image2.contiguous() hdim = self.hidden_dim cdim = self.context_dim # run the feature network with autocast(enabled=self.mixed_precision): fmap1, fmap2 = self.fnet([image1, image2]) fmap1 = fmap1.float() fmap2 = fmap2.float() if self.args.alternate_corr: corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) else: corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) # run the context network with autocast(enabled=self.mixed_precision): cnet = self.cnet(image1) net, inp = torch.split(cnet, [hdim, cdim], dim=1) net = torch.tanh(net) inp = torch.relu(inp) coords0, coords1 = self.initialize_flow(image1) if flow_init is not None: coords1 = coords1 + flow_init flow_predictions = [] flow_up = None for _ in range(iters): coords1 = coords1.detach() corr = corr_fn(coords1) # index correlation volume flow = coords1 - coords0 with autocast(enabled=self.mixed_precision): net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) # F(t+1) = F(t) + \Delta(t) coords1 = coords1 + delta_flow # upsample predictions if up_mask is None: flow_up = upflow8(coords1 - coords0) else: flow_up = self.upsample_flow(coords1 - coords0, up_mask) flow_predictions.append(flow_up) if test_mode: if flow_up is None: raise ValueError('flow_up is None') return coords1 - coords0, flow_up, net return flow_predictions ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/core/update.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Update block for consistent video depth optimization.""" import torch from torch import nn import torch.nn.functional as F class FlowHead(nn.Module): def __init__(self, input_dim=128, hidden_dim=256): super(FlowHead, self).__init__() self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) self.relu = nn.ReLU(inplace=True) def forward(self, x): return self.conv2(self.relu(self.conv1(x))) class ConvGRU(nn.Module): """GRU with convolution.""" def __init__(self, hidden_dim=128, input_dim=192 + 128): super(ConvGRU, self).__init__() self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) def forward(self, h, x): hx = torch.cat([h, x], dim=1) z = torch.sigmoid(self.convz(hx)) r = torch.sigmoid(self.convr(hx)) q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) h = (1 - z) * h + z * q return h class SepConvGRU(nn.Module): """GRU with separate convolution for horizontal and vertical directions.""" def __init__(self, hidden_dim=128, input_dim=192 + 128): super(SepConvGRU, self).__init__() self.convz1 = nn.Conv2d( hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) ) self.convr1 = nn.Conv2d( hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) ) self.convq1 = nn.Conv2d( hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) ) self.convz2 = nn.Conv2d( hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) ) self.convr2 = nn.Conv2d( hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) ) self.convq2 = nn.Conv2d( hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) ) def forward(self, h, x): # horizontal hx = torch.cat([h, x], dim=1) z = torch.sigmoid(self.convz1(hx)) r = torch.sigmoid(self.convr1(hx)) q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) h = (1 - z) * h + z * q # vertical hx = torch.cat([h, x], dim=1) z = torch.sigmoid(self.convz2(hx)) r = torch.sigmoid(self.convr2(hx)) q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) h = (1 - z) * h + z * q return h class SmallMotionEncoder(nn.Module): """Small motion encoder for MegaSaM.""" def __init__(self, args): super(SmallMotionEncoder, self).__init__() cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) self.convf1 = nn.Conv2d(2, 64, 7, padding=3) self.convf2 = nn.Conv2d(64, 32, 3, padding=1) self.conv = nn.Conv2d(128, 80, 3, padding=1) def forward(self, flow, corr): cor = F.relu(self.convc1(corr)) flo = F.relu(self.convf1(flow)) flo = F.relu(self.convf2(flo)) cor_flo = torch.cat([cor, flo], dim=1) out = F.relu(self.conv(cor_flo)) return torch.cat([out, flow], dim=1) class BasicMotionEncoder(nn.Module): """Basic motion encoder for MegaSaM.""" def __init__(self, args): super(BasicMotionEncoder, self).__init__() cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) self.convc2 = nn.Conv2d(256, 192, 3, padding=1) self.convf1 = nn.Conv2d(2, 128, 7, padding=3) self.convf2 = nn.Conv2d(128, 64, 3, padding=1) self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) def forward(self, flow, corr): cor = F.relu(self.convc1(corr)) cor = F.relu(self.convc2(cor)) flo = F.relu(self.convf1(flow)) flo = F.relu(self.convf2(flo)) cor_flo = torch.cat([cor, flo], dim=1) out = 
F.relu(self.conv(cor_flo)) return torch.cat([out, flow], dim=1) class SmallUpdateBlock(nn.Module): """Small update block for MegaSaM.""" def __init__(self, args, hidden_dim=96): super(SmallUpdateBlock, self).__init__() self.encoder = SmallMotionEncoder(args) self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) self.flow_head = FlowHead(hidden_dim, hidden_dim=128) def forward(self, net, inp, corr, flow): motion_features = self.encoder(flow, corr) inp = torch.cat([inp, motion_features], dim=1) net = self.gru(net, inp) delta_flow = self.flow_head(net) return net, None, delta_flow class BasicUpdateBlock(nn.Module): """Basic update block for MegaSaM.""" def __init__(self, args, hidden_dim=128, input_dim=128): super(BasicUpdateBlock, self).__init__() self.args = args self.encoder = BasicMotionEncoder(args) self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) self.flow_head = FlowHead(hidden_dim, hidden_dim=256) self.mask = nn.Sequential( nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 64 * 9, 1, padding=0), ) def forward(self, net, inp, corr, flow, upsample=True): motion_features = self.encoder(flow, corr) inp = torch.cat([inp, motion_features], dim=1) net = self.gru(net, inp) delta_flow = self.flow_head(net) # scale mask to balence gradients mask = 0.25 * self.mask(net) return net, mask, delta_flow ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/core/utils/__init__.py ================================================ ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/core/utils/augmentor.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Augmentation utils for MegaSaM.""" # pylint: disable=g-import-not-at-top # pylint: disable=g-importing-member import cv2 import numpy as np from PIL import Image cv2.setNumThreads(0) cv2.ocl.setUseOpenCL(False) from torchvision.transforms import ColorJitter class FlowAugmentor: """Augmentation for flow for MegaSaM.""" def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): # spatial augmentation params self.crop_size = crop_size self.min_scale = min_scale self.max_scale = max_scale self.spatial_aug_prob = 0.8 self.stretch_prob = 0.8 self.max_stretch = 0.2 # flip augmentation params self.do_flip = do_flip self.h_flip_prob = 0.5 self.v_flip_prob = 0.1 # photometric augmentation params self.photo_aug = ColorJitter( brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14 ) self.asymmetric_color_aug_prob = 0.2 self.eraser_aug_prob = 0.5 def color_transform(self, img1, img2): """Photometric augmentation.""" # asymmetric if np.random.rand() < self.asymmetric_color_aug_prob: img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) # symmetric else: image_stack = np.concatenate([img1, img2], axis=0) image_stack = np.array( self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 ) img1, img2 = np.split(image_stack, 2, axis=0) return img1, img2 def eraser_transform(self, img1, img2, bounds=[50, 100]): # pylint: disable=dangerous-default-value """Occlusion augmentation.""" ht, wd = img1.shape[:2] if np.random.rand() < self.eraser_aug_prob: mean_color = np.mean(img2.reshape(-1, 3), axis=0) for _ in range(np.random.randint(1, 3)): x0 = np.random.randint(0, wd) y0 = np.random.randint(0, ht) dx = np.random.randint(bounds[0], bounds[1]) dy = np.random.randint(bounds[0], bounds[1]) img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color return img1, img2 def spatial_transform(self, img1, img2, flow): """Spatial augmentation.""" # randomly sample scale ht, wd = img1.shape[:2] min_scale = np.maximum( (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd) ) scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) scale_x = scale scale_y = scale if np.random.rand() < self.stretch_prob: scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) scale_x = np.clip(scale_x, min_scale, None) scale_y = np.clip(scale_y, min_scale, None) if np.random.rand() < self.spatial_aug_prob: # rescale the images img1 = cv2.resize( img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR ) img2 = cv2.resize( img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR ) flow = cv2.resize( flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR ) flow = flow * [scale_x, scale_y] if self.do_flip: if np.random.rand() < self.h_flip_prob: # h-flip img1 = img1[:, ::-1] img2 = img2[:, ::-1] flow = flow[:, ::-1] * [-1.0, 1.0] if np.random.rand() < self.v_flip_prob: # v-flip img1 = img1[::-1, :] img2 = img2[::-1, :] flow = flow[::-1, :] * [1.0, -1.0] y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] return img1, img2, flow def __call__(self, 
img1, img2, flow): img1, img2 = self.color_transform(img1, img2) img1, img2 = self.eraser_transform(img1, img2) img1, img2, flow = self.spatial_transform(img1, img2, flow) img1 = np.ascontiguousarray(img1) img2 = np.ascontiguousarray(img2) flow = np.ascontiguousarray(flow) return img1, img2, flow class SparseFlowAugmentor: """Augmentation for sparse flow for MegaSaM.""" def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): # spatial augmentation params self.crop_size = crop_size self.min_scale = min_scale self.max_scale = max_scale self.spatial_aug_prob = 0.8 self.stretch_prob = 0.8 self.max_stretch = 0.2 # flip augmentation params self.do_flip = do_flip self.h_flip_prob = 0.5 self.v_flip_prob = 0.1 # photometric augmentation params self.photo_aug = ColorJitter( brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14 ) self.asymmetric_color_aug_prob = 0.2 self.eraser_aug_prob = 0.5 def color_transform(self, img1, img2): image_stack = np.concatenate([img1, img2], axis=0) image_stack = np.array( self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 ) img1, img2 = np.split(image_stack, 2, axis=0) return img1, img2 def eraser_transform(self, img1, img2): ht, wd = img1.shape[:2] if np.random.rand() < self.eraser_aug_prob: mean_color = np.mean(img2.reshape(-1, 3), axis=0) for _ in range(np.random.randint(1, 3)): x0 = np.random.randint(0, wd) y0 = np.random.randint(0, ht) dx = np.random.randint(50, 100) dy = np.random.randint(50, 100) img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color return img1, img2 def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): """Resize sparse flow map.""" ht, wd = flow.shape[:2] coords = np.meshgrid(np.arange(wd), np.arange(ht)) coords = np.stack(coords, axis=-1) coords = coords.reshape(-1, 2).astype(np.float32) flow = flow.reshape(-1, 2).astype(np.float32) valid = valid.reshape(-1).astype(np.float32) coords0 = coords[valid >= 1] flow0 = flow[valid >= 1] ht1 = int(round(ht * fy)) wd1 = int(round(wd * fx)) coords1 = coords0 * [fx, fy] flow1 = flow0 * [fx, fy] xx = np.round(coords1[:, 0]).astype(np.int32) yy = np.round(coords1[:, 1]).astype(np.int32) v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) xx = xx[v] yy = yy[v] flow1 = flow1[v] flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) valid_img = np.zeros([ht1, wd1], dtype=np.int32) flow_img[yy, xx] = flow1 valid_img[yy, xx] = 1 return flow_img, valid_img def spatial_transform(self, img1, img2, flow, valid): """Randomly sample scale and apply it to images and flow map.""" ht, wd = img1.shape[:2] min_scale = np.maximum( (self.crop_size[0] + 1) / float(ht), (self.crop_size[1] + 1) / float(wd) ) scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) scale_x = np.clip(scale, min_scale, None) scale_y = np.clip(scale, min_scale, None) if np.random.rand() < self.spatial_aug_prob: # rescale the images img1 = cv2.resize( img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR ) img2 = cv2.resize( img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR ) flow, valid = self.resize_sparse_flow_map( flow, valid, fx=scale_x, fy=scale_y ) if self.do_flip: if np.random.rand() < 0.5: # h-flip img1 = img1[:, ::-1] img2 = img2[:, ::-1] flow = flow[:, ::-1] * [-1.0, 1.0] valid = valid[:, ::-1] margin_y = 20 margin_x = 50 y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) x0 = np.random.randint( -margin_x, img1.shape[1] - self.crop_size[1] + margin_x ) y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) x0 = np.clip(x0, 0, img1.shape[1] - 
self.crop_size[1]) img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] valid = valid[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] return img1, img2, flow, valid def __call__(self, img1, img2, flow, valid): img1, img2 = self.color_transform(img1, img2) img1, img2 = self.eraser_transform(img1, img2) img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) img1 = np.ascontiguousarray(img1) img2 = np.ascontiguousarray(img2) flow = np.ascontiguousarray(flow) valid = np.ascontiguousarray(valid) return img1, img2, flow, valid ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/core/utils/flow_viz.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Flow visualization code. Based on https://github.com/tomrunia/OpticalFlow_Visualization """ import numpy as np def make_colorwheel(): """Generates a color wheel for optical flow visualization. Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf Code follows the original C++ source code of Daniel Scharstein. Code follows the the Matlab source code of Deqing Sun. Returns: np.ndarray: Color wheel """ # pylint: disable=invalid-name RY = 15 YG = 6 GC = 4 CB = 11 BM = 13 MR = 6 ncols = RY + YG + GC + CB + BM + MR colorwheel = np.zeros((ncols, 3)) col = 0 # RY colorwheel[0:RY, 0] = 255 colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) col = col + RY # YG colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) colorwheel[col : col + YG, 1] = 255 col = col + YG # GC colorwheel[col : col + GC, 1] = 255 colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) col = col + GC # CB colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) colorwheel[col : col + CB, 2] = 255 col = col + CB # BM colorwheel[col : col + BM, 2] = 255 colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) col = col + BM # MR colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) colorwheel[col : col + MR, 0] = 255 return colorwheel def flow_uv_to_colors(u, v, convert_to_bgr=False): """Applies the flow color wheel to (possibly clipped) flow components u and v. According to the C++ source code of Daniel Scharstein According to the Matlab source code of Deqing Sun Args: u (np.ndarray): Input horizontal flow of shape [H,W] v (np.ndarray): Input vertical flow of shape [H,W] convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
Returns: np.ndarray: Flow visualization image of shape [H,W,3] """ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) colorwheel = make_colorwheel() # shape [55x3] ncols = colorwheel.shape[0] rad = np.sqrt(np.square(u) + np.square(v)) a = np.arctan2(-v, -u) / np.pi fk = (a + 1) / 2 * (ncols - 1) k0 = np.floor(fk).astype(np.int32) k1 = k0 + 1 k1[k1 == ncols] = 0 f = fk - k0 for i in range(colorwheel.shape[1]): tmp = colorwheel[:, i] col0 = tmp[k0] / 255.0 col1 = tmp[k1] / 255.0 col = (1 - f) * col0 + f * col1 idx = rad <= 1 col[idx] = 1 - rad[idx] * (1 - col[idx]) col[~idx] = col[~idx] * 0.75 # out of range # Note the 2-i => BGR instead of RGB ch_idx = 2 - i if convert_to_bgr else i flow_image[:, :, ch_idx] = np.floor(255 * col) return flow_image def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): """Expects a two dimensional flow image of shape. Args: flow_uv (np.ndarray): Flow UV image of shape [H,W,2] clip_flow (float, optional): Clip maximum of flow values. Defaults to None. convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. Returns: np.ndarray: Flow visualization image of shape [H,W,3] """ assert flow_uv.ndim == 3, 'input flow must have three dimensions' assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' if clip_flow is not None: flow_uv = np.clip(flow_uv, 0, clip_flow) u = flow_uv[:, :, 0] v = flow_uv[:, :, 1] rad = np.sqrt(np.square(u) + np.square(v)) rad_max = np.max(rad) epsilon = 1e-5 u = u / (rad_max + epsilon) v = v / (rad_max + epsilon) return flow_uv_to_colors(u, v, convert_to_bgr) ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/core/utils/frame_utils.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Frame utils for MegaSaM.""" # pylint: disable=invalid-name # pylint: disable=g-doc-args # pylint: disable=broad-exception-raised import os import re import cv2 import numpy as np from PIL import Image cv2.setNumThreads(0) cv2.ocl.setUseOpenCL(False) TAG_CHAR = np.array([202021.25], np.float32) def readFlow(fn): """Read .flo file in Middlebury format.""" # Code adapted from: # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy # WARNING: this will work on little-endian architectures (eg Intel x86) only! # print 'fn = %s'%(fn) with open(fn, 'rb') as f: magic = np.fromfile(f, np.float32, count=1) if 202021.25 != magic: print('Magic number incorrect. 
Invalid .flo file') return None else: w = np.fromfile(f, np.int32, count=1) h = np.fromfile(f, np.int32, count=1) # print 'Reading %d x %d flo file\n' % (w, h) data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) # Reshape data into 3D array (columns, rows, bands) # The reshape here is for visualization, the original code is (w,h,2) return np.resize(data, (int(h), int(w), 2)) def readPFM(file): """Read PFM file.""" file = open(file, 'rb') header = file.readline().rstrip() if header == b'PF': color = True elif header == b'Pf': color = False else: raise Exception('Not a PFM file.') dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) if dim_match: width, height = map(int, dim_match.groups()) else: raise Exception('Malformed PFM header.') scale = float(file.readline().rstrip()) if scale < 0: # little-endian endian = '<' else: endian = '>' # big-endian data = np.fromfile(file, endian + 'f') shape = (height, width, 3) if color else (height, width) data = np.reshape(data, shape) data = np.flipud(data) return data def writeFlow(filename, uv, v=None): """Write optical flow to file. If v is None, uv is assumed to contain both u and v channels, stacked in depth. Original code by Deqing Sun, adapted from Daniel Scharstein. """ nBands = 2 if v is None: assert uv.ndim == 3 assert uv.shape[2] == 2 u = uv[:, :, 0] v = uv[:, :, 1] else: u = uv assert u.shape == v.shape height, width = u.shape f = open(filename, 'wb') # write the header f.write(TAG_CHAR) np.array(width).astype(np.int32).tofile(f) np.array(height).astype(np.int32).tofile(f) # arrange into matrix form tmp = np.zeros((height, width * nBands)) tmp[:, np.arange(width) * 2] = u tmp[:, np.arange(width) * 2 + 1] = v tmp.astype(np.float32).tofile(f) f.close() def readFlowKITTI(filename): flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) flow = flow[:, :, ::-1].astype(np.float32) flow, valid = flow[:, :, :2], flow[:, :, 2] flow = (flow - 2**15) / 64.0 return flow, valid def readDispKITTI(filename): disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 valid = disp > 0.0 flow = np.stack([-disp, np.zeros_like(disp)], -1) return flow, valid def writeFlowKITTI(filename, uv): uv = 64.0 * uv + 2**15 valid = np.ones([uv.shape[0], uv.shape[1], 1]) uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) cv2.imwrite(filename, uv[..., ::-1]) def read_gen(file_name, pil=False): """Read image or flow file.""" del pil ext = os.path.splitext(file_name)[-1] if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': return Image.open(file_name) elif ext == '.bin' or ext == '.raw': return np.load(file_name) elif ext == '.flo': return readFlow(file_name).astype(np.float32) # pylint: disable=attribute-error elif ext == '.pfm': flow = readPFM(file_name).astype(np.float32) if len(flow.shape) == 2: return flow else: return flow[:, :, :-1] return [] ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/core/utils/utils.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
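# --- Editor's note (illustrative sketch, not part of the original file) --------
# The KITTI helpers in frame_utils.py above store flow as uint16 with the
# affine encoding u16 = 64 * flow + 2**15 and invert it on read. A quick
# in-memory round trip (no image I/O; hypothetical helper name) shows the
# quantisation step of 1/64 px.
def _kitti_flow_roundtrip_sketch():
    import numpy as np

    flow = np.random.uniform(-50, 50, size=(4, 5, 2)).astype(np.float32)
    encoded = (64.0 * flow + 2**15).astype(np.uint16)        # as in writeFlowKITTI
    decoded = (encoded.astype(np.float32) - 2**15) / 64.0    # as in readFlowKITTI
    assert np.allclose(flow, decoded, atol=1.0 / 64.0)
    return decoded
# -------------------------------------------------------------------------------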
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utility functions for MegaSaM.""" # pylint: disable=invalid-name import numpy as np from scipy import interpolate import torch import torch.nn.functional as F class InputPadder: """Pads images such that dimensions are divisible by 8.""" def __init__(self, dims, mode='sintel'): self.ht, self.wd = dims[-2:] pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 if mode == 'sintel': self._pad = [ pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2, ] else: self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] def pad(self, *inputs): return [F.pad(x, self._pad, mode='replicate') for x in inputs] def unpad(self, x): ht, wd = x.shape[-2:] c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] return x[..., c[0] : c[1], c[2] : c[3]] def forward_interpolate(flow): """Interpolate flow map to match the original image size.""" flow = flow.detach().cpu().numpy() dx, dy = flow[0], flow[1] ht, wd = dx.shape x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) x1 = x0 + dx y1 = y0 + dy x1 = x1.reshape(-1) y1 = y1.reshape(-1) dx = dx.reshape(-1) dy = dy.reshape(-1) valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) x1 = x1[valid] y1 = y1[valid] dx = dx[valid] dy = dy[valid] flow_x = interpolate.griddata( (x1, y1), dx, (x0, y0), method='nearest', fill_value=0 ) flow_y = interpolate.griddata( (x1, y1), dy, (x0, y0), method='nearest', fill_value=0 ) flow = np.stack([flow_x, flow_y], axis=0) return torch.from_numpy(flow).float() def bilinear_sampler(img, coords, mode='bilinear', mask=False): """Wrapper for grid_sample, uses pixel coordinates.""" del mode H, W = img.shape[-2:] xgrid, ygrid = coords.split([1, 1], dim=-1) xgrid = 2 * xgrid / (W - 1) - 1 ygrid = 2 * ygrid / (H - 1) - 1 grid = torch.cat([xgrid, ygrid], dim=-1) img = F.grid_sample(img, grid, align_corners=True) if mask: mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) return img, mask.float() return img def coords_grid(batch, ht, wd): coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) coords = torch.stack(coords[::-1], dim=0).float() return coords[None].repeat(batch, 1, 1, 1) def upflow8(flow, mode='bilinear'): new_size = (8 * flow.shape[2], 8 * flow.shape[3]) return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/inference_batch.py ================================================ """ Batch inference script for optical flow preprocessing using RAFT model. Processes multiple video clips in parallel to generate optical flow data for CVD optimization. 
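# --- Editor's note (illustrative sketch, not part of the original file) --------
# The InputPadder defined in core/utils/utils.py above pads H and W up to the
# next multiple of 8 (RAFT's downsampling factor) and can later undo the
# padding. Minimal usage sketch with an arbitrary 100x75 input:
def _input_padder_sketch():
    import torch

    ht, wd = 75, 100
    pad_ht = (((ht // 8) + 1) * 8 - ht) % 8    # 5 -> padded height 80
    pad_wd = (((wd // 8) + 1) * 8 - wd) % 8    # 4 -> padded width 104
    x = torch.randn(1, 3, ht, wd)
    x_pad = torch.nn.functional.pad(
        x, [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2],
        mode="replicate")
    assert x_pad.shape[-2] % 8 == 0 and x_pad.shape[-1] % 8 == 0
    return x_pad.shape                          # torch.Size([1, 3, 80, 104])
# -------------------------------------------------------------------------------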
""" import pandas as pd import os import argparse import concurrent.futures from multiprocessing import Manager import subprocess import queue from tqdm import tqdm def process_single_row(row, index, args, worker_id=0): """Process a single video clip for optical flow generation.""" dir_path = os.path.join(args.dir_path, row["id"]) device_id = worker_id % args.gpu_num # Build command for optical flow preprocessing with RAFT model cmd = ( f"CUDA_VISIBLE_DEVICES={args.gpu_id[device_id]} python camera_pose_annotation/cvd_opt/preprocess/preprocess_flow.py " f"--dir_path {dir_path} " f"--model {args.checkpoints_path}/raft-things.pth " f"--mixed_precision" ) process = subprocess.Popen( cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) stdout, stderr = process.communicate() if process.returncode != 0: print(f"Error generating optical flow for {row['id']}: {stderr.decode()}") def worker(task_queue, args, worker_id, pbar): """Worker function for parallel optical flow preprocessing.""" while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break process_single_row(row, index, args, worker_id) task_queue.task_done() pbar.update(1) def parse_args(): """Parse command line arguments for optical flow preprocessing.""" parser = argparse.ArgumentParser() parser.add_argument("--csv_path", type=str, help="Path to the csv file") parser.add_argument("--dir_path", type=str, default="./outputs") parser.add_argument("--checkpoints_path", type=str, default="./checkpoints") parser.add_argument( "--gpu_id", type=str, default="0", help="Comma-separated list of GPU IDs to use" ) parser.add_argument( "--num_workers", type=int, default=4, help="Number of workers for parallel processing", ) parser.add_argument( "--disable_parallel", action="store_true", help="Disable parallel processing" ) return parser.parse_args() def main(): args = parse_args() # Parse GPU configuration args.gpu_num = len(args.gpu_id.split(",")) args.gpu_id = [int(gpu) for gpu in args.gpu_id.split(",")] df = pd.read_csv(args.csv_path) if args.disable_parallel: # Sequential processing for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"): process_single_row(row, index, args) else: # Parallel processing with multiple workers manager = Manager() task_queue = manager.Queue() for index, row in df.iterrows(): task_queue.put((index, row)) with tqdm(total=len(df), desc="Processing rows") as pbar: with concurrent.futures.ThreadPoolExecutor( max_workers=args.num_workers ) as executor: futures = [] for id in range(args.num_workers): futures.append(executor.submit(worker, task_queue, args, id, pbar)) for future in concurrent.futures.as_completed(futures): future.result() if __name__ == "__main__": main() ================================================ FILE: camera_pose_annotation/cvd_opt/preprocess/preprocess_flow.py ================================================ # Copyright 2025 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Preprocess flow for MegaSaM.""" import cv2 import tqdm import argparse from pathlib import Path # pylint: disable=g-importing-member from core.utils.utils import InputPadder from core.raft import RAFT import glob import os import sys import numpy as np import torch def warp_flow(img, flow): h, w = flow.shape[:2] flow_new = flow.copy() flow_new[:, :, 0] += np.arange(w) flow_new[:, :, 1] += np.arange(h)[:, np.newaxis] res = cv2.remap( img, flow_new, None, cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT ) return res def resize_flow(flow, img_h, img_w): # flow = np.load(flow_path) flow_h, flow_w = flow.shape[0], flow.shape[1] flow[:, :, 0] *= float(img_w) / float(flow_w) flow[:, :, 1] *= float(img_h) / float(flow_h) flow = cv2.resize(flow, (img_w, img_h), cv2.INTER_LINEAR) return flow def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--model", default="raft-things.pth", help="restore checkpoint") parser.add_argument("--small", action="store_true", help="use small model") parser.add_argument("--dir_path", help="dataset for evaluation") parser.add_argument( "--num_heads", default=1, type=int, help="number of heads in attention and aggregation", ) parser.add_argument( "--position_only", default=False, action="store_true", help="only use position-wise attention", ) parser.add_argument( "--position_and_content", default=False, action="store_true", help="use position and content-wise attention", ) parser.add_argument( "--mixed_precision", action="store_true", help="use mixed precision" ) return parser.parse_args() if __name__ == "__main__": args = parse_args() model = torch.nn.DataParallel(RAFT(args)) model.load_state_dict(torch.load(args.model)) flow_model = model.module device = torch.device("cuda" if torch.cuda.is_available() else "cpu") flow_model.to(device).eval() img_path = os.path.join(args.dir_path, "img") image_list = sorted(glob.glob(os.path.join(img_path, "*.png"))) # [::stride] image_list += sorted(glob.glob(os.path.join(img_path, "*.jpg"))) # [::stride] img_data = [] for t, (image_file) in tqdm.tqdm(enumerate(image_list)): image = cv2.imread(image_file)[..., ::-1] # rgb h0, w0, _ = image.shape h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0))) w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0))) image = cv2.resize(image, (w1, h1)) image = image[: h1 - h1 % 8, : w1 - w1 % 8].transpose(2, 0, 1) img_data.append(image) img_data = np.array(img_data) flows_low = [] flows_high = [] flow_masks_high = [] flow_init = None flows_arr_low_bwd = {} flows_arr_low_fwd = {} ii = [] jj = [] flows_arr_up = [] masks_arr_up = [] for step in [1, 2, 4, 8, 15]: flows_arr_low = [] for i in tqdm.tqdm(range(max(0, -step), img_data.shape[0] - max(0, step))): image1 = ( torch.as_tensor(np.ascontiguousarray(img_data[i : i + 1])) .float() .cuda() ) image2 = ( torch.as_tensor(np.ascontiguousarray(img_data[i + step : i + step + 1])) .float() .cuda() ) ii.append(i) jj.append(i + step) with torch.no_grad(): padder = InputPadder(image1.shape) image1, image2 = padder.pad(image1, image2) if np.abs(step) > 1: flow_init = np.stack( [flows_arr_low_fwd[i], flows_arr_low_bwd[i + step]], axis=0 ) flow_init = ( torch.as_tensor(np.ascontiguousarray(flow_init)) .float() .cuda() .permute(0, 3, 1, 2) ) else: flow_init = None flow_low, flow_up, _ = flow_model( torch.cat([image1, image2], dim=0), torch.cat([image2, image1], dim=0), iters=22, test_mode=True, flow_init=flow_init, ) flow_low_fwd = flow_low[0].cpu().numpy().transpose(1, 2, 
0) flow_low_bwd = flow_low[1].cpu().numpy().transpose(1, 2, 0) flow_up_fwd = resize_flow( flow_up[0].cpu().numpy().transpose(1, 2, 0), flow_up.shape[-2] // 2, flow_up.shape[-1] // 2, ) flow_up_bwd = resize_flow( flow_up[1].cpu().numpy().transpose(1, 2, 0), flow_up.shape[-2] // 2, flow_up.shape[-1] // 2, ) bwd2fwd_flow = warp_flow(flow_up_bwd, flow_up_fwd) fwd_lr_error = np.linalg.norm(flow_up_fwd + bwd2fwd_flow, axis=-1) fwd_mask_up = fwd_lr_error < 1.0 # flows_arr_low.append(flow_low_fwd) flows_arr_low_bwd[i + step] = flow_low_bwd flows_arr_low_fwd[i] = flow_low_fwd # masks_arr_low.append(fwd_mask_low) flows_arr_up.append(flow_up_fwd) masks_arr_up.append(fwd_mask_up) iijj = np.stack((ii, jj), axis=0) flows_high = np.array(flows_arr_up).transpose(0, 3, 1, 2) flow_masks_high = np.array(masks_arr_up)[:, None, ...] output_path = os.path.join(args.dir_path, "cache-flow") if not os.path.exists(output_path): os.makedirs(output_path) np.save(os.path.join(output_path, "flows.npy"), np.float16(flows_high)) np.save(os.path.join(output_path, "flows_masks.npy"), flow_masks_high) np.save(os.path.join(output_path, "ii-jj.npy"), iijj) ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/__init__.py ================================================ ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py from functools import partial import math import logging from typing import Sequence, Tuple, Union, Callable import torch import torch.nn as nn import torch.utils.checkpoint from torch.nn.init import trunc_normal_ from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block logger = logging.getLogger("dinov2") def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: if not depth_first and include_root: fn(module=module, name=name) for child_name, child_module in module.named_children(): child_name = ".".join((name, child_name)) if name else child_name named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) if depth_first and include_root: fn(module=module, name=name) return module class BlockChunk(nn.ModuleList): def forward(self, x): for b in self: x = b(x) return x class DinoVisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, ffn_bias=True, proj_bias=True, drop_path_rate=0.0, drop_path_uniform=False, init_values=None, # for layerscale: None or 0 => no layerscale embed_layer=PatchEmbed, act_layer=nn.GELU, block_fn=Block, ffn_layer="mlp", block_chunks=1, num_register_tokens=0, interpolate_antialias=False, interpolate_offset=0.1, ): """ Args: img_size (int, tuple): input image size patch_size (int, tuple): patch size in_chans (int): number of input channels embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden 
dim to embedding dim qkv_bias (bool): enable bias for qkv if True proj_bias (bool): enable bias for proj in attn if True ffn_bias (bool): enable bias for ffn if True drop_path_rate (float): stochastic depth rate drop_path_uniform (bool): apply uniform drop rate across blocks weight_init (str): weight init scheme init_values (float): layer-scale init values embed_layer (nn.Module): patch embedding layer act_layer (nn.Module): MLP activation layer block_fn (nn.Module): transformer block class ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" block_chunks: (int) split block sequence into block_chunks units for FSDP wrap num_register_tokens: (int) number of extra cls tokens (so-called "registers") interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings """ super().__init__() norm_layer = partial(nn.LayerNorm, eps=1e-6) self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_tokens = 1 self.n_blocks = depth self.num_heads = num_heads self.patch_size = patch_size self.num_register_tokens = num_register_tokens self.interpolate_antialias = interpolate_antialias self.interpolate_offset = interpolate_offset self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) assert num_register_tokens >= 0 self.register_tokens = ( nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None ) if drop_path_uniform is True: dpr = [drop_path_rate] * depth else: dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule if ffn_layer == "mlp": logger.info("using MLP layer as FFN") ffn_layer = Mlp elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": logger.info("using SwiGLU layer as FFN") ffn_layer = SwiGLUFFNFused elif ffn_layer == "identity": logger.info("using Identity layer as FFN") def f(*args, **kwargs): return nn.Identity() ffn_layer = f else: raise NotImplementedError blocks_list = [ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, ffn_layer=ffn_layer, init_values=init_values, ) for i in range(depth) ] if block_chunks > 0: self.chunked_blocks = True chunked_blocks = [] chunksize = depth // block_chunks for i in range(0, depth, chunksize): # this is to keep the block index consistent if we chunk the block list chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) else: self.chunked_blocks = False self.blocks = nn.ModuleList(blocks_list) self.norm = norm_layer(embed_dim) self.head = nn.Identity() self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) self.init_weights() def init_weights(self): trunc_normal_(self.pos_embed, std=0.02) nn.init.normal_(self.cls_token, std=1e-6) if self.register_tokens is not None: nn.init.normal_(self.register_tokens, std=1e-6) named_apply(init_weights_vit_timm, self) def interpolate_pos_encoding(self, x, w, h): previous_dtype = x.dtype npatch = x.shape[1] - 1 N = self.pos_embed.shape[1] - 1 if npatch == N and w == h: return self.pos_embed 
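# --- Editor's note (descriptive comments added for clarity; not original) ------
# The code below rescales the learned patch position embeddings when the input
# resolution differs from the 518x518 size used to build pos_embed. With
# patch_size 14 the trained grid is 518 / 14 = 37 per side (N = 1369 patch
# tokens); a 700x518 input, for example, needs a 50x37 grid, so the 37x37
# embedding map is resized bicubically by roughly (50.1/37, 37.1/37). The small
# interpolate_offset of 0.1 guards against floating-point rounding when the
# interpolation computes the target size.
# -------------------------------------------------------------------------------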
pos_embed = self.pos_embed.float() class_pos_embed = pos_embed[:, 0] patch_pos_embed = pos_embed[:, 1:] dim = x.shape[-1] w0 = w // self.patch_size h0 = h // self.patch_size # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset # w0, h0 = w0 + 0.1, h0 + 0.1 sqrt_N = math.sqrt(N) sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), scale_factor=(sx, sy), # (int(w0), int(h0)), # to solve the upsampling shape issue mode="bicubic", antialias=self.interpolate_antialias ) assert int(w0) == patch_pos_embed.shape[-2] assert int(h0) == patch_pos_embed.shape[-1] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) def prepare_tokens_with_masks(self, x, masks=None): B, nc, w, h = x.shape x = self.patch_embed(x) if masks is not None: x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.interpolate_pos_encoding(x, w, h) if self.register_tokens is not None: x = torch.cat( ( x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:], ), dim=1, ) return x def forward_features_list(self, x_list, masks_list): x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] for blk in self.blocks: x = blk(x) all_x = x output = [] for x, masks in zip(all_x, masks_list): x_norm = self.norm(x) output.append( { "x_norm_clstoken": x_norm[:, 0], "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } ) return output def forward_features(self, x, masks=None): if isinstance(x, list): return self.forward_features_list(x, masks) x = self.prepare_tokens_with_masks(x, masks) for blk in self.blocks: x = blk(x) x_norm = self.norm(x) return { "x_norm_clstoken": x_norm[:, 0], "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } def _get_intermediate_layers_not_chunked(self, x, n=1): x = self.prepare_tokens_with_masks(x) # If n is an int, take the n last blocks. If it's a list, take them output, total_block_len = [], len(self.blocks) blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n for i, blk in enumerate(self.blocks): x = blk(x) if i in blocks_to_take: output.append(x) assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" return output def _get_intermediate_layers_chunked(self, x, n=1): x = self.prepare_tokens_with_masks(x) output, i, total_block_len = [], 0, len(self.blocks[-1]) # If n is an int, take the n last blocks. 
If it's a list, take them blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n for block_chunk in self.blocks: for blk in block_chunk[i:]: # Passing the nn.Identity() x = blk(x) if i in blocks_to_take: output.append(x) i += 1 assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" return output def get_intermediate_layers( self, x: torch.Tensor, n: Union[int, Sequence] = 1, # Layers or n last layers to take reshape: bool = False, return_class_token: bool = False, norm=True ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: if self.chunked_blocks: outputs = self._get_intermediate_layers_chunked(x, n) else: outputs = self._get_intermediate_layers_not_chunked(x, n) if norm: outputs = [self.norm(out) for out in outputs] class_tokens = [out[:, 0] for out in outputs] outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] if reshape: B, _, w, h = x.shape outputs = [ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() for out in outputs ] if return_class_token: return tuple(zip(outputs, class_tokens)) return tuple(outputs) def forward(self, *args, is_training=False, **kwargs): ret = self.forward_features(*args, **kwargs) if is_training: return ret else: return self.head(ret["x_norm_clstoken"]) def init_weights_vit_timm(module: nn.Module, name: str = ""): """ViT weight initialization, original timm impl (for reproducibility)""" if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) def vit_small(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_base(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_large(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): """ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 """ model = DinoVisionTransformer( patch_size=patch_size, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def DINOv2(model_name): model_zoo = { "vits": vit_small, "vitb": vit_base, "vitl": vit_large, "vitg": vit_giant2 } return model_zoo[model_name]( img_size=518, patch_size=14, init_values=1.0, ffn_layer="mlp" if model_name != "vitg" else "swiglufused", block_chunks=0, num_register_tokens=0, interpolate_antialias=False, interpolate_offset=0.1 ) ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
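# --- Editor's note (illustrative reference table, not part of the original) ----
# The DINOv2(model_name) factory in dinov2.py above selects one of four ViT
# configurations (patch size 14, img_size 518). For quick reference, the
# per-variant dimensions set by vit_small / vit_base / vit_large / vit_giant2:
_DINOV2_VARIANTS_SKETCH = {
    # name: (embed_dim, depth, num_heads)
    "vits": (384, 12, 6),
    "vitb": (768, 12, 12),
    "vitl": (1024, 24, 16),
    "vitg": (1536, 40, 24),   # vitg additionally switches the FFN to SwiGLU-fused
}
# -------------------------------------------------------------------------------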
# # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from .mlp import Mlp from .patch_embed import PatchEmbed from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused from .block import NestedTensorBlock from .attention import MemEffAttention ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/attention.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py import logging from torch import Tensor from torch import nn logger = logging.getLogger("dinov2") try: from xformers.ops import memory_efficient_attention, unbind, fmha XFORMERS_AVAILABLE = True except ImportError: logger.warning("xFormers not available") XFORMERS_AVAILABLE = False class Attention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0, ) -> None: super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: Tensor) -> Tensor: B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] attn = q @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class MemEffAttention(Attention): def forward(self, x: Tensor, attn_bias=None) -> Tensor: if not XFORMERS_AVAILABLE: assert attn_bias is None, "xFormers is required for nested tensors usage" return super().forward(x) B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) q, k, v = unbind(qkv, 2) x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) x = x.reshape([B, N, C]) x = self.proj(x) x = self.proj_drop(x) return x ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/block.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
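# --- Editor's note (illustrative sketch, not part of the original file) --------
# Shape walk-through of Attention.forward above for example sizes B=2, N=197,
# C=768, num_heads=12 (head_dim 64): qkv is reshaped to (3, B, heads, N,
# head_dim), the attention matrix is (B, heads, N, N), and the output returns
# to (B, N, C). Function name and sizes are hypothetical.
def _attention_shapes_sketch():
    import torch

    B, N, C, heads = 2, 197, 768, 12
    head_dim = C // heads                                    # 64
    qkv = torch.randn(B, N, 3 * C).reshape(B, N, 3, heads, head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0] * head_dim ** -0.5, qkv[1], qkv[2]
    attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)         # (B, heads, N, N)
    out = (attn @ v).transpose(1, 2).reshape(B, N, C)        # (B, N, C)
    return out.shape
# -------------------------------------------------------------------------------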
# References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py import logging from typing import Callable, List, Any, Tuple, Dict import torch from torch import nn, Tensor from .attention import Attention, MemEffAttention from .drop_path import DropPath from .layer_scale import LayerScale from .mlp import Mlp logger = logging.getLogger("dinov2") try: from xformers.ops import fmha from xformers.ops import scaled_index_add, index_select_cat XFORMERS_AVAILABLE = True except ImportError: logger.warning("xFormers not available") XFORMERS_AVAILABLE = False class Block(nn.Module): def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = False, proj_bias: bool = True, ffn_bias: bool = True, drop: float = 0.0, attn_drop: float = 0.0, init_values=None, drop_path: float = 0.0, act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_class: Callable[..., nn.Module] = Attention, ffn_layer: Callable[..., nn.Module] = Mlp, ) -> None: super().__init__() # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") self.norm1 = norm_layer(dim) self.attn = attn_class( dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, ) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = ffn_layer( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias, ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.sample_drop_ratio = drop_path def forward(self, x: Tensor) -> Tensor: def attn_residual_func(x: Tensor) -> Tensor: return self.ls1(self.attn(self.norm1(x))) def ffn_residual_func(x: Tensor) -> Tensor: return self.ls2(self.mlp(self.norm2(x))) if self.training and self.sample_drop_ratio > 0.1: # the overhead is compensated only for a drop path rate larger than 0.1 x = drop_add_residual_stochastic_depth( x, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio, ) x = drop_add_residual_stochastic_depth( x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio, ) elif self.training and self.sample_drop_ratio > 0.0: x = x + self.drop_path1(attn_residual_func(x)) x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 else: x = x + attn_residual_func(x) x = x + ffn_residual_func(x) return x def drop_add_residual_stochastic_depth( x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, ) -> Tensor: # 1) extract subset using permutation b, n, d = x.shape sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) brange = (torch.randperm(b, device=x.device))[:sample_subset_size] x_subset = x[brange] # 2) apply residual_func to get residual residual = residual_func(x_subset) x_flat = x.flatten(1) residual = residual.flatten(1) residual_scale_factor = b / sample_subset_size # 3) add the residual x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) return x_plus_residual.view_as(x) def get_branges_scales(x, sample_drop_ratio=0.0): b, n, d = x.shape sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 
1) brange = (torch.randperm(b, device=x.device))[:sample_subset_size] residual_scale_factor = b / sample_subset_size return brange, residual_scale_factor def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): if scaling_vector is None: x_flat = x.flatten(1) residual = residual.flatten(1) x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) else: x_plus_residual = scaled_index_add( x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor ) return x_plus_residual attn_bias_cache: Dict[Tuple, Any] = {} def get_attn_bias_and_cat(x_list, branges=None): """ this will perform the index select, cat the tensors, and provide the attn_bias from cache """ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) if all_shapes not in attn_bias_cache.keys(): seqlens = [] for b, x in zip(batch_sizes, x_list): for _ in range(b): seqlens.append(x.shape[1]) attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) attn_bias._batch_sizes = batch_sizes attn_bias_cache[all_shapes] = attn_bias if branges is not None: cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) else: tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) cat_tensors = torch.cat(tensors_bs1, dim=1) return attn_bias_cache[all_shapes], cat_tensors def drop_add_residual_stochastic_depth_list( x_list: List[Tensor], residual_func: Callable[[Tensor, Any], Tensor], sample_drop_ratio: float = 0.0, scaling_vector=None, ) -> Tensor: # 1) generate random set of indices for dropping samples in the batch branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] branges = [s[0] for s in branges_scales] residual_scale_factors = [s[1] for s in branges_scales] # 2) get attention bias and index+concat the tensors attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) # 3) apply residual_func to get residual, and split the result residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore outputs = [] for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) return outputs class NestedTensorBlock(Block): def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: """ x_list contains a list of tensors to nest together and run """ assert isinstance(self.attn, MemEffAttention) if self.training and self.sample_drop_ratio > 0.0: def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: return self.attn(self.norm1(x), attn_bias=attn_bias) def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: return self.mlp(self.norm2(x)) x_list = drop_add_residual_stochastic_depth_list( x_list, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio, scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, ) x_list = drop_add_residual_stochastic_depth_list( x_list, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio, scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, ) return x_list else: def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) def ffn_residual_func(x: Tensor, attn_bias=None) -> 
Tensor: return self.ls2(self.mlp(self.norm2(x))) attn_bias, x = get_attn_bias_and_cat(x_list) x = x + attn_residual_func(x, attn_bias=attn_bias) x = x + ffn_residual_func(x) return attn_bias.split(x) def forward(self, x_or_x_list): if isinstance(x_or_x_list, Tensor): return super().forward(x_or_x_list) elif isinstance(x_or_x_list, list): assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" return self.forward_nested(x_or_x_list) else: raise AssertionError ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/drop_path.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py from torch import nn def drop_path(x, drop_prob: float = 0.0, training: bool = False): if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0: random_tensor.div_(keep_prob) output = x * random_tensor return output class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/layer_scale.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 from typing import Union import torch from torch import Tensor from torch import nn class LayerScale(nn.Module): def __init__( self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False, ) -> None: super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x: Tensor) -> Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/mlp.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
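# --- Editor's note (illustrative sketch, not part of the original file) --------
# drop_path.py above implements stochastic depth: during training, each
# sample's residual branch is zeroed with probability drop_prob and the
# survivors are rescaled by 1/keep_prob so the expected value is unchanged
# (E[mask / keep_prob] = 1). Minimal sketch with example shapes:
def _drop_path_sketch(drop_prob=0.2):
    import torch

    x = torch.randn(8, 4, 16)                                # (batch, tokens, dim)
    keep_prob = 1 - drop_prob
    mask = x.new_empty((x.shape[0], 1, 1)).bernoulli_(keep_prob) / keep_prob
    return x * mask                                          # some samples zeroed, rest scaled up
# -------------------------------------------------------------------------------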
# References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py from typing import Callable, Optional from torch import Tensor, nn class Mlp(nn.Module): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = nn.GELU, drop: float = 0.0, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) self.drop = nn.Dropout(drop) def forward(self, x: Tensor) -> Tensor: x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/patch_embed.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py from typing import Callable, Optional, Tuple, Union from torch import Tensor import torch.nn as nn def make_2tuple(x): if isinstance(x, tuple): assert len(x) == 2 return x assert isinstance(x, int) return (x, x) class PatchEmbed(nn.Module): """ 2D image to patch embedding: (B,C,H,W) -> (B,N,D) Args: img_size: Image size. patch_size: Patch token size. in_chans: Number of input image channels. embed_dim: Number of linear projection output channels. norm_layer: Normalization layer. 
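# --- Editor's note (illustrative sketch, not part of the original file) --------
# The Mlp block in mlp.py above is the standard ViT feed-forward path: fc1
# expands to hidden_features (mlp_ratio * dim, e.g. 4 * 768 = 3072 for ViT-B),
# applies GELU and dropout, then fc2 projects back to the embedding dimension.
# Example sizes below are hypothetical.
def _mlp_shapes_sketch():
    import torch
    from torch import nn

    dim, hidden = 768, 3072
    fc1, act, fc2 = nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim)
    tokens = torch.randn(2, 197, dim)
    return fc2(act(fc1(tokens))).shape                       # torch.Size([2, 197, 768])
# -------------------------------------------------------------------------------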
""" def __init__( self, img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, embed_dim: int = 768, norm_layer: Optional[Callable] = None, flatten_embedding: bool = True, ) -> None: super().__init__() image_HW = make_2tuple(img_size) patch_HW = make_2tuple(patch_size) patch_grid_size = ( image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1], ) self.img_size = image_HW self.patch_size = patch_HW self.patches_resolution = patch_grid_size self.num_patches = patch_grid_size[0] * patch_grid_size[1] self.in_chans = in_chans self.embed_dim = embed_dim self.flatten_embedding = flatten_embedding self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x: Tensor) -> Tensor: _, _, H, W = x.shape patch_H, patch_W = self.patch_size assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" x = self.proj(x) # B C H W H, W = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) # B HW C x = self.norm(x) if not self.flatten_embedding: x = x.reshape(-1, H, W, self.embed_dim) # B H W C return x def flops(self) -> float: Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/swiglu_ffn.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
from typing import Callable, Optional from torch import Tensor, nn import torch.nn.functional as F class SwiGLUFFN(nn.Module): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = None, drop: float = 0.0, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) self.w3 = nn.Linear(hidden_features, out_features, bias=bias) def forward(self, x: Tensor) -> Tensor: x12 = self.w12(x) x1, x2 = x12.chunk(2, dim=-1) hidden = F.silu(x1) * x2 return self.w3(hidden) try: from xformers.ops import SwiGLU XFORMERS_AVAILABLE = True except ImportError: SwiGLU = SwiGLUFFN XFORMERS_AVAILABLE = False class SwiGLUFFNFused(SwiGLU): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = None, drop: float = 0.0, bias: bool = True, ) -> None: out_features = out_features or in_features hidden_features = hidden_features or in_features hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 super().__init__( in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias, ) ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dpt.py ================================================ import cv2 import torch import torch.nn as nn import torch.nn.functional as F from torchvision.transforms import Compose from .dinov2 import DINOv2 from .util.blocks import FeatureFusionBlock, _make_scratch from .util.transform import Resize, NormalizeImage, PrepareForNet def _make_fusion_block(features, use_bn, size=None): return FeatureFusionBlock( features, nn.ReLU(False), deconv=False, bn=use_bn, expand=False, align_corners=True, size=size, ) class ConvBlock(nn.Module): def __init__(self, in_feature, out_feature): super().__init__() self.conv_block = nn.Sequential( nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_feature), nn.ReLU(True) ) def forward(self, x): return self.conv_block(x) class DPTHead(nn.Module): def __init__( self, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False ): super(DPTHead, self).__init__() self.use_clstoken = use_clstoken self.projects = nn.ModuleList([ nn.Conv2d( in_channels=in_channels, out_channels=out_channel, kernel_size=1, stride=1, padding=0, ) for out_channel in out_channels ]) self.resize_layers = nn.ModuleList([ nn.ConvTranspose2d( in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0), nn.ConvTranspose2d( in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0), nn.Identity(), nn.Conv2d( in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1) ]) if use_clstoken: self.readout_projects = nn.ModuleList() for _ in range(len(self.projects)): self.readout_projects.append( nn.Sequential( nn.Linear(2 * in_channels, in_channels), nn.GELU())) self.scratch = _make_scratch( out_channels, features, groups=1, expand=False, ) self.scratch.stem_transpose = None self.scratch.refinenet1 = _make_fusion_block(features, use_bn) self.scratch.refinenet2 = _make_fusion_block(features, use_bn) self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 
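        # Together, refinenet1..4 form a RefineNet-style top-down fusion path: each block
        # takes the coarser path output, fuses it with the next finer projected ViT
        # feature map, and upsamples to that feature map's resolution (see forward()).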
self.scratch.refinenet4 = _make_fusion_block(features, use_bn) head_features_1 = features head_features_2 = 32 self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) self.scratch.output_conv2 = nn.Sequential( nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), nn.ReLU(True), nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), nn.ReLU(True), nn.Identity(), ) def forward(self, out_features, patch_h, patch_w): out = [] for i, x in enumerate(out_features): if self.use_clstoken: x, cls_token = x[0], x[1] readout = cls_token.unsqueeze(1).expand_as(x) x = self.readout_projects[i](torch.cat((x, readout), -1)) else: x = x[0] x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) x = self.projects[i](x) x = self.resize_layers[i](x) out.append(x) layer_1, layer_2, layer_3, layer_4 = out layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) path_1 = self.scratch.refinenet1(path_2, layer_1_rn) out = self.scratch.output_conv1(path_1) out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) out = self.scratch.output_conv2(out) return out class DepthAnythingV2(nn.Module): def __init__( self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False ): super(DepthAnythingV2, self).__init__() self.intermediate_layer_idx = { 'vits': [2, 5, 8, 11], 'vitb': [2, 5, 8, 11], 'vitl': [4, 11, 17, 23], 'vitg': [9, 19, 29, 39] } self.encoder = encoder self.pretrained = DINOv2(model_name=encoder) self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) def forward(self, x): patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14 features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True) depth = self.depth_head(features, patch_h, patch_w) depth = F.relu(depth) return depth.squeeze(1) @torch.no_grad() def infer_image(self, raw_image, input_size=518): image, (h, w) = self.image2tensor(raw_image, input_size) depth = self.forward(image) depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0] return depth.cpu().numpy() def image2tensor(self, raw_image, input_size=518): transform = Compose([ Resize( width=input_size, height=input_size, resize_target=False, keep_aspect_ratio=True, ensure_multiple_of=14, resize_method='lower_bound', image_interpolation_method=cv2.INTER_CUBIC, ), NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), PrepareForNet(), ]) h, w = raw_image.shape[:2] image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 image = transform({'image': image})['image'] image = torch.from_numpy(image).unsqueeze(0) DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' image = image.to(DEVICE) return image, (h, w) ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/util/blocks.py ================================================ import torch.nn as nn def 
_make_scratch(in_shape, out_shape, groups=1, expand=False): scratch = nn.Module() out_shape1 = out_shape out_shape2 = out_shape out_shape3 = out_shape if len(in_shape) >= 4: out_shape4 = out_shape if expand: out_shape1 = out_shape out_shape2 = out_shape * 2 out_shape3 = out_shape * 4 if len(in_shape) >= 4: out_shape4 = out_shape * 8 scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) if len(in_shape) >= 4: scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) return scratch class ResidualConvUnit(nn.Module): """Residual convolution module. """ def __init__(self, features, activation, bn): """Init. Args: features (int): number of features """ super().__init__() self.bn = bn self.groups=1 self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) if self.bn == True: self.bn1 = nn.BatchNorm2d(features) self.bn2 = nn.BatchNorm2d(features) self.activation = activation self.skip_add = nn.quantized.FloatFunctional() def forward(self, x): """Forward pass. Args: x (tensor): input Returns: tensor: output """ out = self.activation(x) out = self.conv1(out) if self.bn == True: out = self.bn1(out) out = self.activation(out) out = self.conv2(out) if self.bn == True: out = self.bn2(out) if self.groups > 1: out = self.conv_merge(out) return self.skip_add.add(out, x) class FeatureFusionBlock(nn.Module): """Feature fusion block. """ def __init__( self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None ): """Init. Args: features (int): number of features """ super(FeatureFusionBlock, self).__init__() self.deconv = deconv self.align_corners = align_corners self.groups=1 self.expand = expand out_features = features if self.expand == True: out_features = features // 2 self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) self.resConfUnit1 = ResidualConvUnit(features, activation, bn) self.resConfUnit2 = ResidualConvUnit(features, activation, bn) self.skip_add = nn.quantized.FloatFunctional() self.size=size def forward(self, *xs, size=None): """Forward pass. Returns: tensor: output """ output = xs[0] if len(xs) == 2: res = self.resConfUnit1(xs[1]) output = self.skip_add.add(output, res) output = self.resConfUnit2(output) if (size is None) and (self.size is None): modifier = {"scale_factor": 2} elif size is None: modifier = {"size": self.size} else: modifier = {"size": size} output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) output = self.out_conv(output) return output ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/util/transform.py ================================================ import numpy as np import cv2 class Resize(object): """Resize sample to given size (width, height). """ def __init__( self, width, height, resize_target=True, keep_aspect_ratio=False, ensure_multiple_of=1, resize_method="lower_bound", image_interpolation_method=cv2.INTER_AREA, ): """Init. 
Args: width (int): desired output width height (int): desired output height resize_target (bool, optional): True: Resize the full sample (image, mask, target). False: Resize image only. Defaults to True. keep_aspect_ratio (bool, optional): True: Keep the aspect ratio of the input sample. Output sample might not have the given width and height, and resize behaviour depends on the parameter 'resize_method'. Defaults to False. ensure_multiple_of (int, optional): Output width and height is constrained to be multiple of this parameter. Defaults to 1. resize_method (str, optional): "lower_bound": Output will be at least as large as the given size. "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) "minimal": Scale as least as possible. (Output size might be smaller than given size.) Defaults to "lower_bound". """ self.__width = width self.__height = height self.__resize_target = resize_target self.__keep_aspect_ratio = keep_aspect_ratio self.__multiple_of = ensure_multiple_of self.__resize_method = resize_method self.__image_interpolation_method = image_interpolation_method def constrain_to_multiple_of(self, x, min_val=0, max_val=None): y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) if max_val is not None and y > max_val: y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) if y < min_val: y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) return y def get_size(self, width, height): # determine new height and width scale_height = self.__height / height scale_width = self.__width / width if self.__keep_aspect_ratio: if self.__resize_method == "lower_bound": # scale such that output size is lower bound if scale_width > scale_height: # fit width scale_height = scale_width else: # fit height scale_width = scale_height elif self.__resize_method == "upper_bound": # scale such that output size is upper bound if scale_width < scale_height: # fit width scale_height = scale_width else: # fit height scale_width = scale_height elif self.__resize_method == "minimal": # scale as least as possbile if abs(1 - scale_width) < abs(1 - scale_height): # fit width scale_height = scale_width else: # fit height scale_width = scale_height else: raise ValueError(f"resize_method {self.__resize_method} not implemented") if self.__resize_method == "lower_bound": new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) elif self.__resize_method == "upper_bound": new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) elif self.__resize_method == "minimal": new_height = self.constrain_to_multiple_of(scale_height * height) new_width = self.constrain_to_multiple_of(scale_width * width) else: raise ValueError(f"resize_method {self.__resize_method} not implemented") return (new_width, new_height) def __call__(self, sample): width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) # resize sample sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method) if self.__resize_target: if "depth" in sample: sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) if "mask" in sample: sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), 
interpolation=cv2.INTER_NEAREST) return sample class NormalizeImage(object): """Normlize image by given mean and std. """ def __init__(self, mean, std): self.__mean = mean self.__std = std def __call__(self, sample): sample["image"] = (sample["image"] - self.__mean) / self.__std return sample class PrepareForNet(object): """Prepare sample for usage as network input. """ def __init__(self): pass def __call__(self, sample): image = np.transpose(sample["image"], (2, 0, 1)) sample["image"] = np.ascontiguousarray(image).astype(np.float32) if "depth" in sample: depth = sample["depth"].astype(np.float32) sample["depth"] = np.ascontiguousarray(depth) if "mask" in sample: sample["mask"] = sample["mask"].astype(np.float32) sample["mask"] = np.ascontiguousarray(sample["mask"]) return sample ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/inference.py ================================================ """ Single-threaded inference script for Depth-Anything V2 model. Processes images in a directory to generate depth maps sequentially. """ import argparse import cv2 import glob import numpy as np import os import torch from depth_anything_v2.dpt import DepthAnythingV2 # Model configuration for different encoder variants model_configs = { "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]}, "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]}, "vitl": { "encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024], }, "vitg": { "encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536], }, } def parse_args(): """Parse command line arguments for depth estimation.""" parser = argparse.ArgumentParser(description="Depth Anything V2") parser.add_argument("--input-size", type=int, default=518) parser.add_argument("--dir_path", type=str, default="./vis_depth") parser.add_argument( "--encoder", type=str, default="vitl", choices=["vits", "vitb", "vitl", "vitg"] ) parser.add_argument( "--load-from", type=str, default="checkpoints/Depth-Anything/depth_anything_v2_vitl.pth", ) return parser.parse_args() if __name__ == "__main__": args = parse_args() # Auto-detect best available device DEVICE = ( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) print(f"Using device: {DEVICE}") # Initialize Depth-Anything V2 model depth_anything = DepthAnythingV2(**model_configs[args.encoder]) depth_anything.load_state_dict(torch.load(args.load_from, map_location="cpu")) depth_anything = depth_anything.to(DEVICE).eval() # Setup input and output paths img_path = os.path.join(args.dir_path, "img") out_path = os.path.join(args.dir_path, "depth-anything") if not os.path.exists(out_path): os.makedirs(out_path) # Collect all image files img_list = sorted(glob.glob(os.path.join(img_path, "*.jpg"))) img_list += sorted(glob.glob(os.path.join(img_path, "*.png"))) # Process each image sequentially for k, img in enumerate(img_list): print(f"Progress {k+1}/{len(img_list)}: {img}") # Load and process image raw_image = cv2.imread(img) # Generate depth map depth = depth_anything.infer_image(raw_image, args.input_size) # Save depth map as numpy array output_path = os.path.join( out_path, os.path.splitext(os.path.basename(img))[0] + ".npy" ) np.save(output_path, depth) ================================================ FILE: camera_pose_annotation/depth_estimation/Depth-Anything/inference_batch.py ================================================ """ Distributed 
batch inference script for Depth-Anything V2 model. Processes video frames to generate depth maps using distributed computing. """ import argparse from datetime import timedelta import cv2 import glob import numpy as np import pandas as pd import os import torch import torch.distributed as dist from torch.utils.data import Dataset, DataLoader, DistributedSampler from torchvision.transforms import Compose import torch.nn.functional as F from torchvision.transforms import ToTensor from tqdm import tqdm from depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet from depth_anything_v2.dpt import DepthAnythingV2 # Model configuration for different encoder variants model_configs = { "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]}, "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]}, "vitl": { "encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024], }, "vitg": { "encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536], }, } class ImageDataset(Dataset): """Dataset for loading and preprocessing images for depth estimation.""" def __init__(self, img_list, input_size): self.img_list = img_list self.input_size = input_size self.transform = Compose( [ Resize( width=input_size, height=input_size, resize_target=False, keep_aspect_ratio=True, ensure_multiple_of=14, resize_method="lower_bound", image_interpolation_method=cv2.INTER_CUBIC, ), NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), PrepareForNet(), ] ) def __len__(self): return len(self.img_list) def image2tensor(self, raw_image): """Convert raw image to tensor format for model input.""" h, w = raw_image.shape[:2] image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 image = self.transform({"image": image})["image"] image = torch.from_numpy(image) return image, (h, w) def __getitem__(self, idx): """Load and preprocess a single image with error handling.""" def inner_func(idx): img_path = self.img_list[idx] raw_image = cv2.imread(img_path) image, (original_h, original_w) = self.image2tensor(raw_image) data = { "image": image, "path": img_path, "original_size": (original_h, original_w), } return data while True: try: return inner_func(idx) except Exception as e: print(f"e: [{e}], path: {self.img_list[idx]}, try to get next idx") idx += 1 if idx >= len(self.img_list): raise StopIteration def parse_args(): """Parse command line arguments for depth estimation.""" parser = argparse.ArgumentParser( description="Depth Anything V2 Distributed Inference" ) parser.add_argument("--csv_path", type=str, help="Path to the csv file") parser.add_argument("--input-size", type=int, default=518) parser.add_argument("--output_dir", type=str, default="./output") parser.add_argument( "--encoder", type=str, default="vitl", choices=["vits", "vitb", "vitl", "vitg"] ) parser.add_argument("--checkpoints_path", type=str, default="./checkpoints") parser.add_argument("--bs", type=int, default=8, help="Batch size for inference") parser.add_argument( "--num_workers", type=int, default=4, help="Number of data loading workers" ) return parser.parse_args() def collate_fn(batch): """Custom collate function for batching data.""" return_batch = {} for key in batch[0].keys(): if key == "image": return_batch[key] = torch.stack([item[key] for item in batch], dim=0) else: return_batch[key] = [item[key] for item in batch] return return_batch def main(): args = parse_args() # Initialize distributed environment 
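    # init_process_group with no init_method defaults to the env:// rendezvous, so the
    # script expects RANK/WORLD_SIZE/MASTER_ADDR etc. to be set by an external launcher,
    # e.g. (assumed invocation, not taken from the repo docs):
    #   torchrun --nproc_per_node=<num_gpus> inference_batch.py --csv_path <clips.csv> --output_dir ./output
    # Note that dist.get_rank() is used as the CUDA device index below, which assumes a
    # single-node run.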
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) local_rank = dist.get_rank() torch.cuda.set_device(local_rank) DEVICE = f"cuda:{local_rank}" # Load data list from CSV df = pd.read_csv(args.csv_path) img_list = [] for index, row in tqdm( df.iterrows(), total=len(df), desc="Loading images", disable=(local_rank != 0) ): img_dir = os.path.join(args.output_dir, row["id"], "img") if not os.path.exists(img_dir): print(f"Image directory not found: {img_dir}") continue img_list += sorted(glob.glob(os.path.join(img_dir, "*.jpg"))) img_list += sorted(glob.glob(os.path.join(img_dir, "*.png"))) # Create dataset and distributed sampler dataset = ImageDataset(img_list, args.input_size) sampler = DistributedSampler( dataset, num_replicas=dist.get_world_size(), rank=local_rank, shuffle=False, drop_last=False, ) dataloader = DataLoader( dataset, batch_size=args.bs, sampler=sampler, num_workers=args.num_workers, pin_memory=True, collate_fn=collate_fn, ) # Initialize Depth-Anything V2 model depth_anything = DepthAnythingV2(**model_configs[args.encoder]) load_from = os.path.join( args.checkpoints_path, f"Depth-Anything/depth_anything_v2_{args.encoder}.pth" ) depth_anything.load_state_dict(torch.load(load_from, map_location="cpu")) depth_anything = depth_anything.to(DEVICE).eval() # Run inference and save depth maps with torch.no_grad(): for batch in tqdm( dataloader, desc="Depth inference", disable=(local_rank != 0) ): images = batch["image"].to(DEVICE) original_sizes = batch["original_size"] paths = batch["path"] # Forward pass through depth model depth = depth_anything(images) # Upsample to original image size original_h, original_w = original_sizes[0] depth = F.interpolate( depth[:, None], size=(original_h, original_w), mode="bilinear", align_corners=False, ) # Save depth maps as numpy arrays for i in range(depth.shape[0]): depth_i = depth[i, 0].cpu().numpy() img_path = paths[i] output_filename = ( os.path.splitext(os.path.basename(img_path))[0] + ".npy" ) output_dir = os.path.join( os.path.dirname(os.path.dirname(img_path)), "depth-anything" ) os.makedirs(output_dir, exist_ok=True) output_path = os.path.join(output_dir, output_filename) np.save(output_path, depth_i) dist.destroy_process_group() if __name__ == "__main__": main() ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/__init__.py ================================================ ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/inference.py ================================================ """ Single-threaded inference script for UniDepth V2 model. Processes images in a directory to generate depth maps and camera parameters sequentially. 
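Example (assumed invocation, matching the argparse defaults below):
    python inference.py --dir_path ./vis_depth --load-from checkpoints/UniDepth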
""" import argparse import glob import os import cv2 import numpy as np from PIL import Image import torch from unidepth.models import UniDepthV2 # Maximum dimension for image resizing LONG_DIM = 640 def parse_args(): """Parse command line arguments for UniDepth inference.""" parser = argparse.ArgumentParser() parser.add_argument("--dir_path", type=str, default="./vis_depth") parser.add_argument("--load-from", type=str, default="checkpoints/UniDepth") return parser.parse_args() def main(): args = parse_args() # Initialize UniDepth V2 model model = UniDepthV2.from_pretrained(args.load_from) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) # Setup input and output paths img_path = os.path.join(args.dir_path, "img") out_path = os.path.join(args.dir_path, "unidepth") if not os.path.exists(out_path): os.makedirs(out_path) # Collect all image files img_list = sorted(glob.glob(os.path.join(img_path, "*.jpg"))) img_list += sorted(glob.glob(os.path.join(img_path, "*.png"))) fovs = [] # Process each image sequentially for img_path in img_list: # Load and preprocess image rgb = np.array(Image.open(img_path))[..., :3] # Calculate target size maintaining aspect ratio if rgb.shape[1] > rgb.shape[0]: final_w, final_h = LONG_DIM, int( round(LONG_DIM * rgb.shape[0] / rgb.shape[1]) ) else: final_w, final_h = ( int(round(LONG_DIM * rgb.shape[1] / rgb.shape[0])), LONG_DIM, ) rgb = cv2.resize(rgb, (final_w, final_h), cv2.INTER_AREA) # Convert to tensor format rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1) # Predict depth and intrinsics predictions = model.infer(rgb_torch) # Calculate FOV (horizontal field of view) from predicted intrinsics fov_ = np.rad2deg( 2 * np.arctan( predictions["depth"].shape[-1] / (2 * predictions["intrinsics"][0, 0, 0].cpu().numpy()) ) ) depth = predictions["depth"][0, 0].cpu().numpy() print(fov_) fovs.append(fov_) # Save depth map and FOV np.savez( os.path.join(out_path, img_path.split("/")[-1][:-4] + ".npz"), depth=np.float32(depth), fov=fov_, ) if __name__ == "__main__": main() ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/inference_batch.py ================================================ """ Distributed batch inference script for UniDepth V2 model. Processes video frames to generate depth maps and camera intrinsics using distributed computing. 
""" import argparse from datetime import timedelta import glob import os import cv2 import pandas as pd import numpy as np from PIL import Image import torch import torch.distributed as dist from torch.utils.data import Dataset, DataLoader, DistributedSampler from tqdm import tqdm from unidepth.models import UniDepthV2 class ImageDataset(Dataset): """Dataset for loading and preprocessing images for UniDepth inference.""" def __init__(self, img_list, input_size): self.img_list = img_list self.input_size = input_size def __len__(self): return len(self.img_list) def __getitem__(self, idx): """Load and preprocess a single image with error handling.""" def inner_func(idx): img_path = self.img_list[idx] rgb = np.array(Image.open(img_path))[..., :3] h, w = rgb.shape[:2] # Calculate target size maintaining aspect ratio if w > h: final_w, final_h = self.input_size, int(round(self.input_size * h / w)) else: final_w, final_h = int(round(self.input_size * w / h)), self.input_size rgb_resized = cv2.resize(rgb, (final_w, final_h), cv2.INTER_AREA) rgb_torch = ( torch.from_numpy(rgb_resized).permute(2, 0, 1).float() ) # Convert to CHW format return { "image": rgb_torch, "path": img_path, } while True: try: return inner_func(idx) except Exception as e: print(f"e: [{e}], path: {self.img_list[idx]}, try to get next idx") idx = (idx + 1) % len(self.img_list) if idx >= len(self.img_list): raise StopIteration def collate_fn(batch): """Custom collate function for batching data.""" return_batch = {} for key in batch[0].keys(): if key == "image": return_batch[key] = torch.stack([item[key] for item in batch], dim=0) else: return_batch[key] = [item[key] for item in batch] return return_batch def parse_args(): """Parse command line arguments for UniDepth inference.""" parser = argparse.ArgumentParser() parser.add_argument("--csv_path", type=str, help="Path to the csv file") parser.add_argument("--output_dir", type=str, default="./output") parser.add_argument("--checkpoints_path", type=str, default="./checkpoints") parser.add_argument( "--input_size", type=int, default=640, help="Input size for the model" ) parser.add_argument("--bs", type=int, default=8, help="Inference batch size") parser.add_argument( "--num_workers", type=int, default=4, help="Data loading workers" ) return parser.parse_args() def main(): args = parse_args() # Initialize distributed environment dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) local_rank = dist.get_rank() torch.cuda.set_device(local_rank) DEVICE = f"cuda:{local_rank}" # Load data list from CSV df = pd.read_csv(args.csv_path) img_list = [] for index, row in tqdm( df.iterrows(), total=len(df), desc="Loading images", disable=local_rank != 0 ): img_dir = os.path.join(args.output_dir, row["id"], "img") if not os.path.exists(img_dir): print(f"Image directory not found: {img_dir}") continue img_list += sorted(glob.glob(os.path.join(img_dir, "*.jpg"))) img_list += sorted(glob.glob(os.path.join(img_dir, "*.png"))) # Create dataset and distributed sampler dataset = ImageDataset(img_list, args.input_size) sampler = DistributedSampler( dataset, num_replicas=dist.get_world_size(), rank=local_rank, shuffle=False, drop_last=False, ) dataloader = DataLoader( dataset, batch_size=args.bs, sampler=sampler, num_workers=args.num_workers, pin_memory=True, collate_fn=collate_fn, ) # Initialize UniDepth V2 model load_from = os.path.join(args.checkpoints_path, "UniDepth") model = UniDepthV2.from_pretrained(load_from) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
model = model.to(device).eval() # Run inference and save results with torch.no_grad(): for batch in tqdm( dataloader, desc="Processing batches", disable=(local_rank != 0) ): images = batch["image"].to(device) paths = batch["path"] # Model inference predictions = model.infer(images) # Process results for each sample for i in range(len(paths)): depth = predictions["depth"][i, 0].cpu().numpy() # [H, W] intrinsics = predictions["intrinsics"][i].cpu().numpy() focal_length = intrinsics[ 0, 0 ] # Assume principal point at center, take fx w = depth.shape[-1] # Width # Calculate FOV (horizontal field of view) fov = np.rad2deg(2 * np.arctan(w / (2 * focal_length))) # Save results img_path = paths[i] output_filename = ( os.path.splitext(os.path.basename(img_path))[0] + ".npz" ) output_dir = os.path.join( os.path.dirname(os.path.dirname(img_path)), "unidepth" ) os.makedirs(output_dir, exist_ok=True) output_path = os.path.join(output_dir, output_filename) np.savez(output_path, depth=np.float32(depth), fov=fov) if __name__ == "__main__": main() ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/_2d3ds.py ================================================ from typing import Any import torch from unidepth.datasets.pipelines import Compose, PanoCrop, PanoRoll from unidepth.datasets.sequence_dataset import SequenceDataset class _2D3DS(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 512.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"2D3DS.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["cam2w", "camera_params"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) self.resizer = Compose( [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer] ) def preprocess(self, results): self.resizer.ctx = None if self.test_mode: for i, seq in enumerate(results["sequence_fields"]): results[seq]["points"] = results[seq]["camera"].reconstruct( results[seq]["depth"] ) results[seq]["depth"] = results[seq]["points"][:, -1:] results[seq]["gt_fields"].add("points") return super().preprocess(results) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [False] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/_4dor.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class _4DOR(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 default_fps = 10 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["4DOR.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, 
Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [False] * self.num_frames * self.num_copies results["si"] = [False] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/__init__.py ================================================ from ._2d3ds import _2D3DS from ._4dor import _4DOR from .a2d2 import A2D2 from .adt import ADT from .aimotive import aiMotive from .argoverse import Argoverse from .argoverse2 import Argoverse2 from .arkit import ARKit from .ase import ASE from .base_dataset import BaseDataset from .bdd import BDD from .bedlam import BEDLAM from .behave import Behave from .blendedmvg import BlendedMVG from .cityscape import Cityscape from .ddad import DDAD from .deep360 import Deep360 from .dense import DENSE from .diode import DiodeIndoor, DiodeIndoor_F from .dl3dv import DL3DV from .driving_stereo import DrivingStereo from .dtu_rmvd import DTURMVD from .dummy import Dummy from .dynamic_replica import DynReplica from .eden import EDEN from .eth3d import ETH3D, ETH3D_F from .eth3d_rmvd import ETH3DRMVD from .facedepth import FaceDepth from .flsea import FLSea from .futurehouse import FutureHouse from .gibson import Gibson from .hammer import HAMMER from .hm3d import HM3D from .hoi4d import HOI4D from .hypersim import HyperSim from .ibims import IBims, IBims_F from .image_dataset import ImageDataset from .ken_burns import KenBurns from .kitti import KITTI, KITTIBenchmark from .kitti360 import KITTI360 from .kitti_multi import KITTIMulti from .kitti_rmvd import KITTIRMVD from .lyft import Lyft from .mapillary import Mapillary from .matrix_city import MatrixCity from .matterport3d import Matterport3D from .megadepth import MegaDepth from .megadepth_s import MegaDepthS from .midair import MidAir from .mip import MIP from .ms2 import MS2 from .mvimgnet import MVImgNet from .mvsynth import MVSynth from .nerds360 import NeRDS360 from .niantic_mapfree import NianticMapFree from .nuscenes import Nuscenes from .nyuv2 import NYUv2Depth from .point_odyssey import PointOdyssey from .proteus import Proteus from .samplers import DistributedSamplerNoDuplicate from .scannet import ScanNet from .scannetpp import ScanNetpp, ScanNetpp_F from .sequence_dataset import SequenceDataset from .sintel import Sintel from .sunrgbd import SUNRGBD from .synscapes import Synscapes from .tartanair import TartanAir from .taskonomy import Taskonomy from .tat_rmvd import TATRMVD from .theo import Theo from .unrealstereo4k import UnrealStereo4K from .urbansyn import UrbanSyn from .utils import ConcatDataset, collate_fn, get_weights from .vkitti import VKITTI from .void import VOID from .waymo import Waymo from .wildrgbd import WildRGBD ================================================ FILE: 
camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/a2d2.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class A2D2(ImageDataset): min_depth = 0.01 max_depth = 120.0 depth_scale = 256.0 train_split = "train_clean.txt" intrisics_file = "intrinsics.json" hdf5_paths = ["a2d2.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor( intrinsics[os.path.join(*image_filename.split("/")[:2])] ).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) # if not self.test_mode: # dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_copies results["quality"] = [1] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/adt.py ================================================ from typing import Any import torch from unidepth.datasets.sequence_dataset import SequenceDataset class ADT(SequenceDataset): min_depth = 0.01 max_depth = 20.0 depth_scale = 1000.0 test_split = "val.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"ADT.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, # if not test_mode else [*decode_fields, "points"], inplace_fields=inplace_fields, **kwargs, ) def preprocess(self, results): self.resizer.ctx = None for i, seq in enumerate(results["sequence_fields"]): # Create a mask where the distance from the center is less than H/2 H, W = results[seq]["image"].shape[-2:] x = torch.linspace(-W / 2 - 0.5, W / 2 + 0.5, W) y = torch.linspace(-H / 2 - 0.5, H / 2 + 0.5, H) xv, yv = torch.meshgrid(x, y, indexing="xy") distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W) results[seq]["validity_mask"] = distance_from_center < (H / 2) + 20 
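            # The valid image region is a disc of radius ~H/2 (plus a small margin) around
            # the image centre, presumably reflecting the circular field of view of the
            # source cameras; the same mask is reused below as the depth supervision mask.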
results[seq]["depth_mask"] = results[seq]["validity_mask"].clone() results[seq]["mask_fields"].add("depth_mask") results[seq]["mask_fields"].add("validity_mask") return super().preprocess(results) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/aimotive.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class aiMotive(SequenceDataset): min_depth = 0.01 max_depth = 100.0 depth_scale = 256.0 default_fps = 10 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["aiMotive.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_frames * self.num_copies results["synthetic"] = [False] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/argoverse.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class Argoverse(ImageDataset): min_depth = 0.05 max_depth = 120.0 depth_scale = 256.0 test_split = "argo_val.txt" train_split = "argo_train.txt" intrisics_file = "argo_intrinsics.json" hdf5_paths = ["argoverse11.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.crop = crop self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) if not 
self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/argoverse2.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class Argoverse2(SequenceDataset): min_depth = 0.05 max_depth = 120.0 depth_scale = 256.0 test_split = "val.txt" train_split = "train.txt" sequences_file = "sequences_clean.json" hdf5_paths = [f"AV2_viz.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/arkit.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class ARKit(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 test_split = "Training.txt" train_split = "Training.txt" sequences_file = "sequences.json" hdf5_paths = ["ARKitS.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/ase.py ================================================ from typing import Any import torch from unidepth.datasets.sequence_dataset import SequenceDataset class ASE(SequenceDataset): min_depth = 0.01 max_depth = 20.0 depth_scale = 1000.0 test_split = "val.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"ASE.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = 
["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def preprocess(self, results): self.resizer.ctx = None for i, seq in enumerate(results["sequence_fields"]): # Create a mask where the distance from the center is less than H/2 H, W = results[seq]["image"].shape[-2:] x = torch.linspace(-W / 2 - 0.5, W / 2 + 0.5, W) y = torch.linspace(-H / 2 - 0.5, H / 2 + 0.5, H) xv, yv = torch.meshgrid(x, y, indexing="xy") distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W) results[seq]["validity_mask"] = distance_from_center < (H / 2) + 20 results[seq]["mask_fields"].add("validity_mask") return super().preprocess(results) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/base_dataset.py ================================================ import os from abc import abstractmethod from copy import deepcopy from math import ceil, log from typing import Any, Dict, Tuple import numpy as np import torch from torch.utils.data import Dataset import unidepth.datasets.pipelines as pipelines from unidepth.utils import (eval_3d, eval_depth, identity, is_main_process, recursive_index, sync_tensor_across_gpus) from unidepth.utils.constants import (IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD, OPENAI_DATASET_MEAN, OPENAI_DATASET_STD) class BaseDataset(Dataset): min_depth = 0.01 max_depth = 1000.0 def __init__( self, image_shape: Tuple[int, int], split_file: str, test_mode: bool, benchmark: bool, normalize: bool, augmentations_db: Dict[str, Any], resize_method: str, mini: float, num_copies: int = 1, **kwargs, ) -> None: super().__init__() assert normalize in [None, "imagenet", "openai"] self.split_file = split_file self.test_mode = test_mode self.data_root = os.environ["DATAROOT"] self.image_shape = image_shape self.resize_method = resize_method self.mini = mini self.num_frames = 1 self.num_copies = num_copies self.metrics_store = {} self.metrics_count = {} if normalize == "imagenet": self.normalization_stats = { "mean": torch.tensor(IMAGENET_DATASET_MEAN), "std": torch.tensor(IMAGENET_DATASET_STD), } elif normalize == "openai": self.normalization_stats = { "mean": torch.tensor(OPENAI_DATASET_MEAN), "std": torch.tensor(OPENAI_DATASET_STD), } else: self.normalization_stats = { "mean": torch.tensor([0.0, 0.0, 0.0]), "std": torch.tensor([1.0, 1.0, 1.0]), } for k, v in augmentations_db.items(): setattr(self, k, v) if not self.test_mode: self._augmentation_space() self.masker = pipelines.AnnotationMask( min_value=0.0, max_value=self.max_depth if test_mode else None, custom_fn=identity, ) self.filler = pipelines.RandomFiller(noise_pad=True) shape_mult = self.shape_constraints["shape_mult"] self.image_shape = [ ceil(self.image_shape[0] / shape_mult) * shape_mult, ceil(self.image_shape[1] / shape_mult) * shape_mult, ] self.resizer = pipelines.ContextCrop( image_shape=self.image_shape, train_ctx_range=(1.0 / self.random_scale, 1.0 * self.random_scale), test_min_ctx=self.test_context, 
keep_original=test_mode, shape_constraints=self.shape_constraints, ) self.collecter = pipelines.Collect( keys=["image_fields", "mask_fields", "gt_fields", "camera_fields"] ) def __len__(self): return len(self.dataset) def pack_batch(self, results): results["paddings"] = [ results[x]["paddings"][0] for x in results["sequence_fields"] ] for fields_name in [ "image_fields", "gt_fields", "mask_fields", "camera_fields", ]: fields = results.get(fields_name) packed = { field: torch.cat( [results[seq][field] for seq in results["sequence_fields"]] ) for field in fields } results.update(packed) return results def unpack_batch(self, results): for fields_name in [ "image_fields", "gt_fields", "mask_fields", "camera_fields", ]: fields = results.get(fields_name) unpacked = { field: { seq: results[field][idx : idx + 1] for idx, seq in enumerate(results["sequence_fields"]) } for field in fields } results.update(unpacked) return results def _augmentation_space(self): self.augmentations_dict = { "Flip": pipelines.RandomFlip(prob=self.flip_p), "Jitter": pipelines.RandomColorJitter( (-self.random_jitter, self.random_jitter), prob=self.jitter_p ), "Gamma": pipelines.RandomGamma( (-self.random_gamma, self.random_gamma), prob=self.gamma_p ), "Blur": pipelines.GaussianBlur( kernel_size=13, sigma=(0.1, self.random_blur), prob=self.blur_p ), "Grayscale": pipelines.RandomGrayscale(prob=self.grayscale_p), } def augment(self, results): for name, aug in self.augmentations_dict.items(): results = aug(results) return results def prepare_depth_eval(self, inputs, preds): new_preds = {} keyframe_idx = getattr(self, "keyframe_idx", None) slice_idx = slice( keyframe_idx, keyframe_idx + 1 if keyframe_idx is not None else None ) new_gts = inputs["depth"][slice_idx] new_masks = inputs["depth_mask"][slice_idx].bool() for key, val in preds.items(): if "depth" in key: new_preds[key] = val[slice_idx] return new_gts, new_preds, new_masks def prepare_points_eval(self, inputs, preds): new_preds = {} new_gts = inputs["points"] new_masks = inputs["depth_mask"].bool() if "points_mask" in inputs: new_masks = inputs["points_mask"].bool() for key, val in preds.items(): if "points" in key: new_preds[key] = val return new_gts, new_preds, new_masks def add_points(self, inputs): inputs["points"] = inputs.get("camera_original", inputs["camera"]).reconstruct( inputs["depth"] ) return inputs @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def accumulate_metrics( self, inputs, preds, keyframe_idx=None, metrics=["depth", "points", "flow_fwd", "pairwise"], ): if "depth" in inputs and "points" not in inputs: inputs = self.add_points(inputs) available_metrics = [] for metric in metrics: metric_in_gt = any((metric in k for k in inputs.keys())) metric_in_pred = any((metric in k for k in preds.keys())) if metric_in_gt and metric_in_pred: available_metrics.append(metric) if keyframe_idx is not None: inputs = recursive_index(inputs, slice(keyframe_idx, keyframe_idx + 1)) preds = recursive_index(preds, slice(keyframe_idx, keyframe_idx + 1)) if "depth" in available_metrics: depth_gt, depth_pred, depth_masks = self.prepare_depth_eval(inputs, preds) self.accumulate_metrics_depth(depth_gt, depth_pred, depth_masks) if "points" in available_metrics: points_gt, points_pred, points_masks = self.prepare_points_eval( inputs, preds ) self.accumulate_metrics_3d(points_gt, points_pred, points_masks) @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def accumulate_metrics_depth(self, gts, preds, masks): for eval_type, pred in 
preds.items(): log_name = eval_type.replace("depth", "").strip("-").strip("_") if log_name not in self.metrics_store: self.metrics_store[log_name] = {} current_count = self.metrics_count.get( log_name, torch.tensor([], device=gts.device) ) new_count = masks.view(gts.shape[0], -1).sum(dim=-1) self.metrics_count[log_name] = torch.cat([current_count, new_count]) for k, v in eval_depth(gts, pred, masks, max_depth=self.max_depth).items(): current_metric = self.metrics_store[log_name].get( k, torch.tensor([], device=gts.device) ) self.metrics_store[log_name][k] = torch.cat([current_metric, v]) @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def accumulate_metrics_3d(self, gts, preds, masks): thresholds = torch.linspace( log(self.min_depth), log(self.max_depth / 20), steps=100, device=gts.device, ).exp() for eval_type, pred in preds.items(): log_name = eval_type.replace("points", "").strip("-").strip("_") if log_name not in self.metrics_store: self.metrics_store[log_name] = {} current_count = self.metrics_count.get( log_name, torch.tensor([], device=gts.device) ) new_count = masks.view(gts.shape[0], -1).sum(dim=-1) self.metrics_count[log_name] = torch.cat([current_count, new_count]) for k, v in eval_3d(gts, pred, masks, thresholds=thresholds).items(): current_metric = self.metrics_store[log_name].get( k, torch.tensor([], device=gts.device) ) self.metrics_store[log_name][k] = torch.cat([current_metric, v]) def get_evaluation(self, metrics=None): metric_vals = {} for eval_type in metrics if metrics is not None else self.metrics_store.keys(): assert self.metrics_store[eval_type] cnts = sync_tensor_across_gpus(self.metrics_count[eval_type]) for name, val in self.metrics_store[eval_type].items(): # vals_r = (sync_tensor_across_gpus(val) * cnts / cnts.sum()).sum() vals_r = sync_tensor_across_gpus(val).mean() metric_vals[f"{eval_type}_{name}".strip("_")] = np.round( vals_r.cpu().item(), 5 ) self.metrics_store[eval_type] = {} self.metrics_count = {} return metric_vals def replicate(self, results): for i in range(1, self.num_copies): results[(0, i)] = {k: deepcopy(v) for k, v in results[(0, 0)].items()} results["sequence_fields"].append((0, i)) return results def log_load_dataset(self): if is_main_process(): info = f"Loaded {self.__class__.__name__} with {len(self)} images." 
print(info) def pre_pipeline(self, results): results["image_fields"] = results.get("image_fields", set()) results["gt_fields"] = results.get("gt_fields", set()) results["mask_fields"] = results.get("mask_fields", set()) results["sequence_fields"] = results.get("sequence_fields", set()) results["camera_fields"] = results.get("camera_fields", set()) results["dataset_name"] = ( [self.__class__.__name__] * self.num_frames * self.num_copies ) results["depth_scale"] = [self.depth_scale] * self.num_frames * self.num_copies results["si"] = [False] * self.num_frames * self.num_copies results["dense"] = [False] * self.num_frames * self.num_copies results["synthetic"] = [False] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies results["valid_camera"] = [True] * self.num_frames * self.num_copies results["valid_pose"] = [True] * self.num_frames * self.num_copies return results def eval_mask(self, valid_mask): return valid_mask def chunk(self, dataset, chunk_dim=1, pct=1.0): subsampled_datasets = [ x for i in range(0, len(dataset), int(1 / pct * chunk_dim)) for x in dataset[i : i + chunk_dim] ] return subsampled_datasets @abstractmethod def preprocess(self, results): raise NotImplementedError @abstractmethod def postprocess(self, results): raise NotImplementedError @abstractmethod def get_mapper(self): raise NotImplementedError @abstractmethod def get_intrinsics(self, idx, image_name): raise NotImplementedError @abstractmethod def get_extrinsics(self, idx, image_name): raise NotImplementedError @abstractmethod def load_dataset(self): raise NotImplementedError @abstractmethod def get_single_item(self, idx, sample=None, mapper=None): raise NotImplementedError @abstractmethod def __getitem__(self, idx): raise NotImplementedError ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/bdd.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class BDD(ImageDataset): min_depth = 0.01 max_depth = 70.0 depth_scale = 256.0 test_split = "val.txt" train_split = "train_clean.txt" intrisics_file = "intrinsics.json" hdf5_paths = ["BDD.hdf5"] def __init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor( intrinsics[os.path.join(*image_filename.split("/")[:2])] ).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) if self.test_mode and not self.benchmark: 
dataset = self.chunk(dataset, chunk_dim=1, pct=0.1) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["si"] = [True] * self.num_copies results["valid_camera"] = [False] * self.num_copies results["dense"] = [False] * self.num_copies results["quality"] = [2] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/bedlam.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class BEDLAM(SequenceDataset): min_depth = 0.01 max_depth = 256.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "val.txt" sequences_file = "sequences.json" hdf5_paths = ["BEDLAM.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/behave.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class Behave(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 default_fps = 10 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["Behave.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [False] * self.num_frames * self.num_copies results["si"] = [False] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/blendedmvg.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class BlendedMVG(SequenceDataset): min_depth = 0.01 
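    # Note on the shared class attributes used by these dataset wrappers:
    # integer depth maps stored in the HDF5 archives listed in `hdf5_paths` are
    # divided by `depth_scale` when loaded, evaluation appears to be limited to
    # the [min_depth, max_depth] range via the annotation masker, and
    # `train_split` / `test_split` / `sequences_file` name the split files
    # looked up inside the archive. In the image datasets the `mini` constructor
    # argument is forwarded to BaseDataset.chunk() as `pct`, so e.g. mini=0.5
    # keeps roughly every other training sample; how SequenceDataset consumes
    # `mini` is not shown here.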
max_depth = 5000.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences_clean.json" hdf5_paths = ["BlendedMVG_.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["si"] = [False] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/cityscape.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class Cityscape(ImageDataset): min_depth = 0.05 max_depth = 80.0 depth_scale = 256.0 test_split = "val.txt" train_split = "train.txt" intrisics_file = "intrinsics.json" hdf5_paths = ["cityscape.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.crop = crop self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["quality"] = [2] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/ddad.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class DDAD(ImageDataset): min_depth = 0.05 max_depth = 120.0 depth_scale = 256.0 test_split = "val.txt" train_split = "train.txt" intrisics_file = "intrinsics.json" hdf5_paths = [f"ddad/ddad_{i}.hdf5" for i in 
range(8)] def __init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii").strip("\n") intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename, chunk_idx = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val, chunk_idx] dataset.append(sample) if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, "K": 2, "chunk_idx": 3, } def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_copies results["quality"] = [1] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/deep360.py ================================================ from typing import Any import torch from unidepth.datasets.pipelines import Compose, PanoCrop, PanoRoll from unidepth.datasets.sequence_dataset import SequenceDataset class Deep360(SequenceDataset): min_depth = 0.1 max_depth = 1000.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"Deep360.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["cam2w", "camera_params"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) self.resizer = Compose( [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer] ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/dense.py ================================================ import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class DENSE(ImageDataset): CAM_INTRINSIC = { "ALL": torch.tensor( [ [1177.8614, 0.0, 474.319027], [0.0, 1177.8614, 224.275919], [0.0, 0.0, 1.0], ] ) } 
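    # The hard-coded CAM_INTRINSIC above uses the standard 3x3 pinhole layout
    #   [[fx, 0, cx],
    #    [0, fy, cy],
    #    [0,  0,  1]]
    # with focal lengths and principal point expressed in pixels. Datasets with
    # a single fixed camera return a clone of this tensor from get_intrinsics(),
    # presumably so per-sample resizing/cropping can rescale K without mutating
    # the shared class attribute.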
min_depth = 0.05 max_depth = 80.0 depth_scale = 255.0 test_split = "train.txt" train_split = "train.txt" hdf5_paths = ["DENSE.hdf5"] def __init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.intrisics = {} self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") sample = [image_filename, depth_filename] dataset.append(sample) if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def get_intrinsics(self, idx, image_name): return self.CAM_INTRINSIC["ALL"].clone() def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, } def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_copies results["quality"] = [1] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/diml.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class DIML(ImageDataset): min_depth = 0.01 max_depth = 100.0 depth_scale = 256.0 test_split = "test.txt" train_split = "train.txt" intrisics_file = "intrinsics.json" hdf5_paths = ["DIML.hdf5"] def __init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.intrisics = {} self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor( intrinsics[image_filename.split("/")[0]] ).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_copies results["quality"] = [2] * self.num_copies return results ================================================ FILE: 
camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/diode.py ================================================ import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.sequence_dataset import SequenceDataset from unidepth.datasets.utils import DatasetFromList class DiodeIndoor(ImageDataset): CAM_INTRINSIC = { "ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]]) } min_depth = 0.01 max_depth = 25.0 depth_scale = 256.0 test_split = "val.txt" train_split = "train.txt" hdf5_paths = ["DiodeIndoor.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, mini=mini, **kwargs, ) self.test_mode = test_mode # load annotations self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") sample = [ image_filename, depth_filename, ] dataset.append(sample) if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def get_intrinsics(self, *args, **kwargs): return self.CAM_INTRINSIC["ALL"].clone() def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, } def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_copies results["quality"] = [1] * self.num_copies return results class DiodeIndoor_F(SequenceDataset): min_depth = 0.01 max_depth = 25.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["DiodeIndoor-F.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, float], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=( decode_fields if not test_mode else [*decode_fields, "points"] ), inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results class DiodeOutdoor(ImageDataset): CAM_INTRINSIC = { "ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]]) } min_depth = 0.1 max_depth = 80.0 log_mean = 0 log_std = 1 test_split = "diode_outdoor_val.txt" train_split = "diode_outdoor_train.txt" hdf5_paths = ["diode.hdf5"] def __init__( self, image_shape, split_file, test_mode, depth_scale=256, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, 
split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.depth_scale = depth_scale self.masker = AnnotationMask( min_value=self.min_depth, max_value=self.max_depth if test_mode else None, custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x, ) # load annotations self.load_dataset() def load_dataset(self): self.h5file = h5py.File( os.path.join(self.data_root, self.hdf5_path), "r", libver="latest", swmr=True, ) txt_file = np.array(self.h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] dataset = {"depth_filename": [], "image_filename": []} for line in txt_string.split("\n"): depth_filename = line.strip().split(" ")[1] img_name = line.strip().split(" ")[0] image_filename = img_name dataset["depth_filename"].append(depth_filename) dataset["image_filename"].append(image_filename) self.dataset = pl.from_dict(dataset) if not self.test_mode and self.mini: self.dataset = self.dataset[::2] class Diode(ImageDataset): CAM_INTRINSIC = { "ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]]) } log_mean = 0 log_std = 1 min_depth = 0.6 max_depth = 80.0 test_split = "diode_val.txt" train_split = "diode_train.txt" hdf5_paths = ["diode.hdf5"] def __init__( self, image_shape, split_file, test_mode, depth_scale=256, crop=None, benchmark=False, augmentations_db={}, normalize=True, mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, mini=mini, **kwargs, ) self.test_mode = test_mode self.depth_scale = depth_scale self.masker = AnnotationMask( min_value=self.min_depth, max_value=self.max_depth if test_mode else None, custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x, ) # load annotations self.load_dataset() def load_dataset(self): self.h5file = h5py.File( os.path.join(self.data_root, self.hdf5_path), "r", libver="latest", swmr=True, ) txt_file = np.array(self.h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] dataset = {"depth_filename": [], "image_filename": []} for line in txt_string.split("\n"): depth_filename = line.strip().split(" ")[1] image_filename = line.strip().split(" ")[0] dataset["depth_filename"].append(depth_filename) dataset["image_filename"].append(image_filename) self.dataset = pl.from_dict(dataset) if not self.test_mode and self.mini: self.dataset = self.dataset[::2] def get_intrinsics(self, *args, **kwargs): return self.CAM_INTRINSIC["ALL"].clone() ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/dl3dv.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class DL3DV(SequenceDataset): min_depth = 0.001 max_depth = 250.0 depth_scale = 512.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"DL3DVcv.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, 
split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["si"] = [True] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/driving_stereo.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class DrivingStereo(ImageDataset): min_depth = 0.05 max_depth = 80.0 depth_scale = 256.0 test_split = "drivingstereo_val.txt" train_split = "drivingstereo_train.txt" intrisics_file = "drivingstereo_intrinsics.json" hdf5_paths = ["DrivingStereo.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.crop = crop self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) if self.test_mode and not self.benchmark: dataset = self.chunk(dataset, chunk_dim=1, pct=1.0) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_copies results["quality"] = [1] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/dtu_rmvd.py ================================================ import json import os from typing import Any import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.pipelines import AnnotationMask, KittiCrop from unidepth.datasets.sequence_dataset import SequenceDataset from unidepth.datasets.utils import DatasetFromList from unidepth.utils import identity class DTURMVD(SequenceDataset): min_depth = 0.05 max_depth = 3.0 depth_scale = 1000.0 default_fps = 6 test_split = "test.txt" train_split = "test.txt" sequences_file = "sequences.json" hdf5_paths = ["dtu_rmvd.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, augmentations_db={}, normalize=True, resize_method="hard", mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: 
list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["si"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/dummy.py ================================================ import numpy as np import torch from torch.utils.data import Dataset class Dummy(Dataset): train_split = None test_split = None def __init__(self, *args, **kwargs): super().__init__() self.dataset = np.arange(1_000_000) def get_single_item(self, idx): # results = {} # results["cam2w"] = torch.eye(4).unsqueeze(0) # results["K"] = torch.eye(3).unsqueeze(0) # results["image"] = torch.zeros(1, 3, 1024, 1024).to(torch.uint8) # results["depth"] = torch.zeros(1, 1, 1024, 1024).to(torch.float32) return { "x": {(0, 0): torch.rand(1, 3, 1024, 1024, dtype=torch.float32)}, "img_metas": {"val": torch.rand(1, 1024, dtype=torch.float32)}, } def __getitem__(self, idx): if isinstance(idx, (list, tuple)): results = [self.get_single_item(i) for i in idx] else: results = self.get_single_item(idx) return results def __len__(self): return len(self.dataset) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/dynamic_replica.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class DynReplica(SequenceDataset): min_depth = 0.01 max_depth = 20.0 default_fps = 30.0 depth_scale = 512.0 test_split = "val.txt" train_split = "train.txt" sequences_file = "sequences_clean.json" hdf5_paths = ["DynReplica.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/eden.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class EDEN(SequenceDataset): min_depth = 0.1 max_depth = 100.0 depth_scale = 256.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = 
[f"EDEN.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/eth3d.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.sequence_dataset import SequenceDataset from unidepth.datasets.utils import DatasetFromList class ETH3D(ImageDataset): min_depth = 0.01 max_depth = 50.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" intrisics_file = "intrinsics.json" hdf5_paths = ["ETH3D.hdf5"] def __init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results class ETH3D_F(SequenceDataset): min_depth = 0.05 max_depth = 60.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["ETH3D-F.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, float], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, 
augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=( decode_fields if not test_mode else [*decode_fields, "points"] ), inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/eth3d_rmvd.py ================================================ import json import os from typing import Any import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.pipelines import AnnotationMask, KittiCrop from unidepth.datasets.sequence_dataset import SequenceDataset from unidepth.datasets.utils import DatasetFromList from unidepth.utils import identity class ETH3DRMVD(SequenceDataset): min_depth = 0.01 max_depth = 50.0 depth_scale = 1000.0 default_fps = 6 test_split = "test.txt" train_split = "test.txt" sequences_file = "sequences.json" hdf5_paths = ["eth3d_rmvd.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, augmentations_db={}, normalize=True, resize_method="hard", mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/facedepth.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class FaceDepth(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 default_fps = 10 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["FaceDepth.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/flsea.py ================================================ import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class 
FLSea(ImageDataset): CAM_INTRINSIC = { "canyons": torch.tensor( [ [1175.3913431656817, 0.0, 466.2595428966926], [0.0, 1174.2805075232263, 271.2116633091501], [0.0, 0.0, 1.0], ] ), "red_sea": torch.tensor( [ [1296.666758476217, 0.0, 501.50386149846], [0.0, 1300.831316354508, 276.161712082695], [0.0, 0.0, 1.0], ] ), } min_depth = 0.05 max_depth = 20.0 depth_scale = 1000.0 train_split = "train.txt" hdf5_paths = ["FLSea.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=False, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.crop = crop self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") sample = [image_filename, depth_filename] dataset.append(sample) if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) if self.test_mode and not self.benchmark: dataset = self.chunk(dataset, chunk_dim=1, pct=0.33) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def get_intrinsics(self, idx, image_name): return self.CAM_INTRINSIC[image_name.split("/")[0]][:, :3].clone() def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, } def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_copies results["quality"] = [2] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/futurehouse.py ================================================ from typing import Any import torch from unidepth.datasets.pipelines import Compose, PanoCrop, PanoRoll from unidepth.datasets.sequence_dataset import SequenceDataset class FutureHouse(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"FutureHouse.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["cam2w", "camera_params"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) self.resizer = Compose( [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer] ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: 
camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/gibson.py ================================================ from typing import Any import torch from unidepth.datasets.pipelines import Compose, PanoCrop, PanoRoll from unidepth.datasets.sequence_dataset import SequenceDataset class Gibson(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"Gibson.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["cam2w", "camera_params"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) self.resizer = Compose( [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer] ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/hammer.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class HAMMER(ImageDataset): min_depth = 0.005 max_depth = 10.0 depth_scale = 1000.0 train_split = "test.txt" test_split = "test.txt" intrisics_file = "intrinsics.json" hdf5_paths = ["hammer.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.crop = crop self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/hm3d.py 
================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class HM3D(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 test_split = "val.txt" train_split = "full.txt" sequences_file = "sequences.json" hdf5_paths = [f"HM3D.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/hoi4d.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class HOI4D(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 default_fps = 5 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["HOI4D.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [False] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/hrwsi.py ================================================ import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class HRWSI(ImageDataset): min_depth = 0.01 max_depth = 1000.0 depth_scale = 50.0 test_split = "val.txt" train_split = "train.txt" hdf5_paths = ["HRWSI.hdf5"] def __init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, 
self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: # f.write(txt_string) dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") sample = [ image_filename, depth_filename, ] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["ssi"] = [True] results["valid_camera"] = [False] return results def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/hypersim.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class HyperSim(ImageDataset): min_depth = 0.01 max_depth = 50.0 depth_scale = 1000.0 test_split = "val.txt" train_split = "train.txt" intrisics_file = "intrinsics.json" hdf5_paths = [f"hypersim/hypersim_{i}.hdf5" for i in range(8)] def __init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii").strip("\n") intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: # f.write(txt_string) # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f: # json.dump(intrinsics, f) dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename, chunk_idx = line.strip().split(" ") intrinsics_val = torch.tensor( intrinsics[os.path.join(*image_filename.split("/")[:2])] ).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val, chunk_idx] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) if self.test_mode and not self.benchmark: # corresponds to 712 images dataset = self.chunk(dataset, chunk_dim=1, pct=0.1) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, "K": 2, "chunk_idx": 3, } def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_copies results["synthetic"] = [True] * self.num_copies results["quality"] = [0] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/ibims.py ================================================ import json import os import h5py import numpy as np import 
torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.sequence_dataset import SequenceDataset from unidepth.datasets.utils import DatasetFromList class IBims(ImageDataset): min_depth = 0.005 max_depth = 25.0 depth_scale = 1000.0 train_split = "ibims_val.txt" test_split = "ibims_val.txt" intrisics_file = "ibims_intrinsics.json" hdf5_paths = ["ibims.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.crop = crop self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_copies results["quality"] = [1] * self.num_copies return results class IBims_F(SequenceDataset): min_depth = 0.01 max_depth = 25.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["IBims-F.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, float], resize_method: str, mini: float, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=( decode_fields if not test_mode else [*decode_fields, "points"] ), inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/image_dataset.py ================================================ import io import os from time import time from typing import Any, Dict, List, Tuple import numpy as np import tables import torch import torchvision import torchvision.transforms.v2.functional as TF from PIL import Image from unidepth.datasets.base_dataset import BaseDataset from unidepth.utils import is_main_process from unidepth.utils.camera import BatchCamera, Pinhole """ Awful class for legacy reasons, we assume only pinhole cameras And we "fake" sequences by setting sequence_fields to [(0, 0)] and cam2w as eye(4) """ class 
ImageDataset(BaseDataset): def __init__( self, image_shape: Tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: Dict[str, Any], resize_method: str, mini: float, benchmark: bool = False, **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.mapper = self.get_mapper() def get_single_item(self, idx, sample=None, mapper=None): sample = self.dataset[idx] if sample is None else sample mapper = self.mapper if mapper is None else mapper results = { (0, 0): dict( gt_fields=set(), image_fields=set(), mask_fields=set(), camera_fields=set(), ) } results = self.pre_pipeline(results) results["sequence_fields"] = [(0, 0)] chunk_idx = ( int(sample[self.mapper["chunk_idx"]]) if "chunk_idx" in self.mapper else 0 ) h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx]) with tables.File( h5_path, mode="r", libver="latest", swmr=True, ) as h5file_chunk: for key_mapper, idx_mapper in mapper.items(): if "image" not in key_mapper and "depth" not in key_mapper: continue value = sample[idx_mapper] results[(0, 0)][key_mapper] = value name = key_mapper.replace("_filename", "") value_root = "/" + value if "image" in key_mapper: results[(0, 0)]["filename"] = value file = h5file_chunk.get_node(value_root).read() image = ( torchvision.io.decode_image(torch.from_numpy(file)) .to(torch.uint8) .squeeze() ) results[(0, 0)]["image_fields"].add(name) results[(0, 0)][f"image_ori_shape"] = image.shape[-2:] results[(0, 0)][name] = image[None, ...] # collect camera information for the given image name = name.replace("image_", "") results[(0, 0)]["camera_fields"].update({"camera", "cam2w"}) K = self.get_intrinsics(idx, value) if K is None: K = torch.eye(3) K[0, 0] = K[1, 1] = 0.7 * self.image_shape[1] K[0, 2] = 0.5 * self.image_shape[1] K[1, 2] = 0.5 * self.image_shape[0] camera = Pinhole(K=K[None, ...].clone()) results[(0, 0)]["camera"] = BatchCamera.from_camera(camera) results[(0, 0)]["cam2w"] = self.get_extrinsics(idx, value)[ None, ... ] elif "depth" in key_mapper: # start = time() file = h5file_chunk.get_node(value_root).read() depth = Image.open(io.BytesIO(file)) depth = TF.pil_to_tensor(depth).squeeze().to(torch.float32) if depth.ndim == 3: depth = depth[2] + depth[1] * 255 + depth[0] * 255 * 255 results[(0, 0)]["gt_fields"].add(name) results[(0, 0)][f"depth_ori_shape"] = depth.shape depth = ( depth.view(1, 1, *depth.shape).contiguous() / self.depth_scale ) results[(0, 0)][name] = depth results = self.preprocess(results) if not self.test_mode: results = self.augment(results) results = self.postprocess(results) return results def preprocess(self, results): results = self.replicate(results) for i, seq in enumerate(results["sequence_fields"]): self.resizer.ctx = None results[seq] = self.resizer(results[seq]) num_pts = torch.count_nonzero(results[seq]["depth"] > 0) if num_pts < 50: raise IndexError(f"Too few points in depth map ({num_pts})") for key in results[seq].get("image_fields", ["image"]): results[seq][key] = results[seq][key].to(torch.float32) / 255 # update fields common in sequence for key in ["image_fields", "gt_fields", "mask_fields", "camera_fields"]: if key in results[(0, 0)]: results[key] = results[(0, 0)][key] results = self.pack_batch(results) return results def postprocess(self, results): # normalize after because color aug requires [0,255]? 
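        # Order of operations below: images were already scaled to [0, 1] in
        # preprocess, so TF.normalize only applies the dataset mean/std here;
        # the filler then runs on the packed batch, unpack_batch restores the
        # per-sequence entries, the masker appears to build validity masks from
        # the configured depth range, and the collecter keeps only the declared
        # image/gt/mask/camera field groups (see pipelines.Collect in
        # base_dataset.py above).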
for key in results.get("image_fields", ["image"]): results[key] = TF.normalize(results[key], **self.normalization_stats) results = self.filler(results) results = self.unpack_batch(results) results = self.masker(results) results = self.collecter(results) return results def __getitem__(self, idx): try: if isinstance(idx, (list, tuple)): results = [self.get_single_item(i) for i in idx] else: results = self.get_single_item(idx) except Exception as e: print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}") idx = np.random.randint(0, len(self.dataset)) results = self[idx] return results def get_intrinsics(self, idx, image_name): idx_sample = self.mapper.get("K", 1000) sample = self.dataset[idx] if idx_sample >= len(sample): return None return sample[idx_sample] def get_extrinsics(self, idx, image_name): idx_sample = self.mapper.get("cam2w", 1000) sample = self.dataset[idx] if idx_sample >= len(sample): return torch.eye(4) return sample[idx_sample] def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, "K": 2, } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/ken_burns.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class KenBurns(ImageDataset): min_depth = 0.05 max_depth = 50.0 depth_scale = 256.0 test_split = "val.txt" train_split = "train.txt" intrisics_file = "intrinsics.json" hdf5_paths = [f"3dkenburns/3DKenBurns_{i}.hdf5" for i in range(8)] def __init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii").strip("\n") intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: # f.write(txt_string) # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f: # json.dump(intrinsics, f) dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename, chunk_idx = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val, chunk_idx] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) if self.test_mode and not self.benchmark: # corresponds to 500 images dataset = self.chunk(dataset, chunk_dim=1, pct=0.25) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, "K": 2, "chunk_idx": 3, } def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_copies results["synthetic"] = [True] * self.num_copies results["quality"] = [0] * self.num_copies return results 
================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/kitti.py ================================================ import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.pipelines import AnnotationMask, KittiCrop from unidepth.datasets.utils import DatasetFromList from unidepth.utils import identity class KITTI(ImageDataset): CAM_INTRINSIC = { "2011_09_26": torch.tensor( [ [7.215377e02, 0.000000e00, 6.095593e02, 4.485728e01], [0.000000e00, 7.215377e02, 1.728540e02, 2.163791e-01], [0.000000e00, 0.000000e00, 1.000000e00, 2.745884e-03], ] ), "2011_09_28": torch.tensor( [ [7.070493e02, 0.000000e00, 6.040814e02, 4.575831e01], [0.000000e00, 7.070493e02, 1.805066e02, -3.454157e-01], [0.000000e00, 0.000000e00, 1.000000e00, 4.981016e-03], ] ), "2011_09_29": torch.tensor( [ [7.183351e02, 0.000000e00, 6.003891e02, 4.450382e01], [0.000000e00, 7.183351e02, 1.815122e02, -5.951107e-01], [0.000000e00, 0.000000e00, 1.000000e00, 2.616315e-03], ] ), "2011_09_30": torch.tensor( [ [7.070912e02, 0.000000e00, 6.018873e02, 4.688783e01], [0.000000e00, 7.070912e02, 1.831104e02, 1.178601e-01], [0.000000e00, 0.000000e00, 1.000000e00, 6.203223e-03], ] ), "2011_10_03": torch.tensor( [ [7.188560e02, 0.000000e00, 6.071928e02, 4.538225e01], [0.000000e00, 7.188560e02, 1.852157e02, -1.130887e-01], [0.000000e00, 0.000000e00, 1.000000e00, 3.779761e-03], ] ), } min_depth = 0.05 max_depth = 80.0 depth_scale = 256.0 log_mean = 2.5462 log_std = 0.5871 test_split = "kitti_eigen_test.txt" train_split = "kitti_eigen_train.txt" test_split_benchmark = "kitti_test.txt" hdf5_paths = ["kitti.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.masker = AnnotationMask( min_value=0.0, max_value=self.max_depth if test_mode else None, custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x, ) self.test_mode = test_mode self.crop = crop self.cropper_base = KittiCrop(crop_size=(352, 1216)) self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename = line.strip().split(" ")[0] depth_filename = line.strip().split(" ")[1] if depth_filename == "None": self.invalid_depth_num += 1 continue sample = [ image_filename, depth_filename, ] dataset.append(sample) if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def get_intrinsics(self, idx, image_name): return self.CAM_INTRINSIC[image_name.split("/")[0]][:, :3].clone() def preprocess(self, results): results = self.replicate(results) for i, seq in enumerate(results["sequence_fields"]): self.resizer.ctx = None results[seq] = self.cropper_base(results[seq]) results[seq] = self.resizer(results[seq]) num_pts = torch.count_nonzero(results[seq]["depth"] > 0) if num_pts < 50: raise IndexError(f"Too few points in depth map ({num_pts})") for key in 
results[seq].get("image_fields", ["image"]): results[seq][key] = results[seq][key].to(torch.float32) / 255 # update fields common in sequence for key in ["image_fields", "gt_fields", "mask_fields", "camera_fields"]: if key in results[(0, 0)]: results[key] = results[(0, 0)][key] results = self.pack_batch(results) return results def eval_mask(self, valid_mask, info={}): """Do grag_crop or eigen_crop for testing""" mask_height, mask_width = valid_mask.shape[-2:] eval_mask = torch.zeros_like(valid_mask) if "garg" in self.crop: eval_mask[ ..., int(0.40810811 * mask_height) : int(0.99189189 * mask_height), int(0.03594771 * mask_width) : int(0.96405229 * mask_width), ] = 1 elif "eigen" in self.crop: eval_mask[ ..., int(0.3324324 * mask_height) : int(0.91351351 * mask_height), int(0.03594771 * mask_width) : int(0.96405229 * mask_width), ] = 1 return torch.logical_and(valid_mask, eval_mask) def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, } def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_copies results["quality"] = [1] * self.num_copies return results import json class KITTIBenchmark(ImageDataset): min_depth = 0.05 max_depth = 80.0 depth_scale = 256.0 test_split = "test_split.txt" train_split = "val_split.txt" intrinsics_file = "intrinsics.json" hdf5_paths = ["kitti_benchmark.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=True, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.crop = crop self.masker = AnnotationMask( min_value=self.min_depth, max_value=self.max_depth if test_mode else None, custom_fn=lambda x, *args, **kwargs: x, ) self.collecter = Collect(keys=["image_fields", "mask_fields", "gt_fields"]) self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_path), "r", libver="latest", swmr=True, ) txt_file = np.array(self.h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 intrinsics = np.array(h5file[self.intrinsics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics = torch.tensor( intrinsics[os.path.join(*image_filename.split("/")[:2])] ).squeeze()[:, :3] sample = { "image_filename": image_filename, "depth_filename": depth_filename, "K": intrinsics, } dataset.append(sample) self.dataset = DatasetFromList(dataset) self.log_load_dataset() ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/kitti360.py ================================================ from typing import Any import torch from unidepth.datasets.sequence_dataset import SequenceDataset class KITTI360(SequenceDataset): min_depth = 0.01 max_depth = 80.0 depth_scale = 256.0 train_split = "train.txt" test_split = "val_split.txt" sequences_file = "sequences_split.json" hdf5_paths = [f"KITTI360.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: 
list[str] = ["image", "depth"], inplace_fields: list[str] = ["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=( decode_fields if not test_mode else [*decode_fields, "points"] ), inplace_fields=inplace_fields, **kwargs, ) def preprocess(self, results): self.resizer.ctx = None for i, seq in enumerate(results["sequence_fields"]): # Create a mask where the distance from the center is less than H/2 H, W = results[seq]["image"].shape[-2:] x = torch.linspace(-W / 2, W / 2, W) y = torch.linspace(-H / 2, H / 2, H) xv, yv = torch.meshgrid(x, y, indexing="xy") distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W) results[seq]["validity_mask"] = distance_from_center < (H / 2) return super().preprocess(results) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/kitti_multi.py ================================================ import json import os from typing import Any import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.pipelines import AnnotationMask, KittiCrop from unidepth.datasets.sequence_dataset import SequenceDataset from unidepth.datasets.utils import DatasetFromList from unidepth.utils import identity class KITTIMulti(SequenceDataset): min_depth = 0.05 max_depth = 80.0 depth_scale = 256.0 default_fps = 10.0 test_split = "val.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["KITTI_sequence.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, augmentations_db={}, normalize=True, resize_method="hard", mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) self.test_mode = test_mode self.crop = crop self.cropper_base = KittiCrop(crop_size=(352, 1216)) self.masker = AnnotationMask( min_value=0.0, max_value=self.max_depth if test_mode else None, custom_fn=self.eval_mask if test_mode else identity, ) self.eval_last = True def __len__(self): if self.test_mode: return 64 # FIXME: Hardcoded for now return len(self.dataset) def preprocess(self, results): self.resizer.ctx = None for i, seq in enumerate(results["sequence_fields"]): results[seq] = self.cropper_base(results[seq]) results[seq] = self.resizer(results[seq]) for key in results[seq].get("image_fields", ["image"]): results[seq][key] = results[seq][key].to(torch.float32) / 255 results.update({k: v for k, v in results[(0, 0)].items() if "fields" in k}) results = self.pack_batch(results) return results def eval_mask(self, valid_mask, info={}): """Do grag_crop or eigen_crop for testing""" mask_height, mask_width = valid_mask.shape[-2:] eval_mask = torch.zeros_like(valid_mask) if "garg" in self.crop: 
eval_mask[ ..., int(0.40810811 * mask_height) : int(0.99189189 * mask_height), int(0.03594771 * mask_width) : int(0.96405229 * mask_width), ] = 1 elif "eigen" in self.crop: eval_mask[ ..., int(0.3324324 * mask_height) : int(0.91351351 * mask_height), int(0.03594771 * mask_width) : int(0.96405229 * mask_width), ] = 1 else: return valid_mask return torch.logical_and(valid_mask, eval_mask) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/kitti_rmvd.py ================================================ import json import os from typing import Any import h5py import numpy as np import torch from unidepth.datasets.pipelines import AnnotationMask, Compose, KittiCrop from unidepth.datasets.sequence_dataset import SequenceDataset from unidepth.utils import identity class KITTIRMVD(SequenceDataset): min_depth = 0.05 max_depth = 80.0 depth_scale = 256.0 default_fps = 10 test_split = "test.txt" train_split = "test.txt" sequences_file = "sequences.json" hdf5_paths = ["kitti_rmvd.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, augmentations_db={}, normalize=True, resize_method="hard", mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) self.crop = crop self.resizer = Compose([KittiCrop(crop_size=(352, 1216)), self.resizer]) def eval_mask(self, valid_mask, info={}): """Do grag_crop or eigen_crop for testing""" mask_height, mask_width = valid_mask.shape[-2:] eval_mask = torch.zeros_like(valid_mask) if "garg" in self.crop: eval_mask[ ..., int(0.40810811 * mask_height) : int(0.99189189 * mask_height), int(0.03594771 * mask_width) : int(0.96405229 * mask_width), ] = 1 elif "eigen" in self.crop: eval_mask[ ..., int(0.3324324 * mask_height) : int(0.91351351 * mask_height), int(0.03594771 * mask_width) : int(0.96405229 * mask_width), ] = 1 else: return valid_mask return torch.logical_and(valid_mask, eval_mask) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/lyft.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class Lyft(ImageDataset): min_depth = 0.05 max_depth = 80.0 depth_scale = 256.0 test_split = "test.txt" train_split = "train.txt" intrisics_file = "intrinsics.json" hdf5_paths = ["Lyft2.hdf5"] def __init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: # f.write(txt_string) # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f: # json.dump(intrinsics, f) dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [ image_filename, depth_filename, intrinsics_val, ] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/mapillary.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class Mapillary(ImageDataset): min_depth = 0.01 max_depth = 70.0 depth_scale = 256.0 test_split = "mapillary_val.txt" train_split = "mapillary_train_clean.txt" intrisics_file = "intrinsics.json" hdf5_paths = ["Mapillary.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.crop = crop self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) if self.test_mode and not self.benchmark: dataset = self.chunk(dataset, chunk_dim=1, pct=0.05) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["si"] = [True] * self.num_copies results["valid_camera"] = [False] * self.num_copies results["dense"] = [False] * self.num_copies results["quality"] = [2] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/matrix_city.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class MatrixCity(SequenceDataset): min_depth = 0.01 max_depth = 200.0 depth_scale = 1000.0 test_split = "test.txt" train_split = "train_full.txt" sequences_file = "sequences.json" hdf5_paths = [f"MatrixCity.hdf5"] def __init__( self, 
image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/matterport3d.py ================================================ from typing import Any import torch from unidepth.datasets.pipelines import Compose, PanoCrop, PanoRoll from unidepth.datasets.sequence_dataset import SequenceDataset class Matterport3D(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"Matterport3D.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["cam2w", "camera_params"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) self.resizer = Compose( [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer] ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/megadepth.py ================================================ import os import h5py import numpy as np from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class MegaDepth(ImageDataset): min_depth = 0.01 max_depth = 1000.0 depth_scale = 50.0 test_split = "test.txt" train_split = "train.txt" hdf5_paths = ["MegaDepth.hdf5"] def __init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) 
txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: # f.write(txt_string) dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") sample = [ image_filename, depth_filename, ] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) else: dataset = self.chunk(dataset, chunk_dim=1, pct=0.5) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["ssi"] = [True] results["valid_camera"] = [False] results["dense"] = [False] return results def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/megadepth_s.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class MegaDepthS(SequenceDataset): min_depth = 0.001 max_depth = 10000.0 depth_scale = 512.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences_filter_clean.json" hdf5_paths = ["MegaDepthS.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["intrinsics", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["si"] = [True] * self.num_frames * self.num_copies results["dense"] = [False] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/midair.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class MidAir(SequenceDataset): min_depth = 0.1 max_depth = 1000.0 depth_scale = 1000.0 default_fps = 6 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["MidAir.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] 
= [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/mip.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class MIP(SequenceDataset): min_depth = 0.01 max_depth = 100.0 depth_scale = 1000.0 default_fps = 10 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["MIP.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [False] * self.num_frames * self.num_copies results["si"] = [True] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/ms2.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class MS2(SequenceDataset): min_depth = 0.01 max_depth = 100.0 depth_scale = 256.0 default_fps = 5 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["MS2.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_frames * self.num_copies results["synthetic"] = [False] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/mvimgnet.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset INVALID_SEQUENCES = [ "1/000121f2-0", "15/1600ae56-0", "26/000000f3-0", "33/1d00e677-0", "43/22008925-0", "49/000147db-0", "51/23002a43-0", "51/23000916-0", "108/000133ae-0", "129/000037f2-0", "141/17012545-0", "141/1700f3de-0", "152/1b00e061-0", "154/1d00decb-0", "154/1d017c1c-0", "154/1d0019a5-0", "154/1d00334d-0", "154/1d012ed6-0", "154/1d016b8a-0", "154/1d016cc1-0", "154/1d008d5f-0", 
"159/000157f9-0", "159/00000b96-0", "159/000075c0-0", "159/0000445c-0", "159/000056a0-0", "159/00010c68-0", "159/0000573b-0", "159/00002698-0", "159/00008fca-0", "159/00009ef8-0", "159/00015f05-0", "159/0000c6df-0", "159/0000ee59-0", "163/290159d2-0", "163/29016c7c-0", "163/2900239c-0", "163/29002f7b-0", "163/29014b05-0", "163/29000196-0", "163/2901750f-0", "164/1b0145cf-0", "164/1b00eb1d-0", "164/1b00c28b-0", "164/1b0110d0-0", "164/1b00dd20-0", "165/2600e15a-0", "165/26008444-0", "165/260145c5-0", "165/26003a0c-0", "165/260106ba-0", "165/26001548-0", "167/2a0092b0-0", "167/2a014dbe-0", "167/2a003ce6-0", "169/1800c645-0", "171/2500014d-0", "176/1d0021c2-0", "176/1d014abf-0", "176/1d00e714-0", "176/1d0159cb-0", "176/1e016629-0", "178/000102b8-0", "191/23008fdb-0", "191/2300187f-0", "191/2300ae68-0", "191/230076dd-0", "191/24007d7e-0", "192/000107b5-0", "195/1f012359-0", "195/1f00f751-0", "195/1f011331-0", "195/1e00d999-0", "196/1c01304e-0", "198/1a00e02f-0", "198/050084ac-0", "198/1a0075fa-0", "199/1e001742-0", "199/1e00116a-0", "199/1e011d00-0", "199/1e018040-0", "199/1e001107-0", ] class MVImgNet(SequenceDataset): min_depth = 0.005 max_depth = 10.0 # weird scale issue, should be 1000, but avg depth is ~10meters... depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["MVImgNet.hdf5"] invalid_sequences = INVALID_SEQUENCES def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["intrinsics", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["si"] = [True] * self.num_frames * self.num_copies results["dense"] = [False] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/mvsynth.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class MVSynth(SequenceDataset): min_depth = 0.1 max_depth = 1000.0 depth_scale = 256.0 test_split = "val.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"MVSynth.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * 
self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["si"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/nerds360.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class NeRDS360(SequenceDataset): min_depth = 0.01 max_depth = 1000.0 depth_scale = 1000.0 test_split = "val.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["NeRDS360.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/niantic_mapfree.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class NianticMapFree(SequenceDataset): min_depth = 0.1 max_depth = 250.0 depth_scale = 512.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"NianticMapFree.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["si"] = [True] * self.num_frames * self.num_copies results["dense"] = [False] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/nuscenes.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class Nuscenes(ImageDataset): min_depth = 0.05 max_depth = 80.0 depth_scale = 256.0 test_split = "val.txt" train_split = "train.txt" intrisics_file = "intrinsics.json" # hdf5_paths = ["Nuscenes2.hdf5"] hdf5_paths = [f"nuscenes/nuscenes_{i}.hdf5" for i in range(8)] def 
__init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii").strip("\n") intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename, chunk_idx = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val, chunk_idx] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=6, pct=self.mini) if self.test_mode and not self.benchmark: dataset = self.chunk(dataset, chunk_dim=6, pct=0.1) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, "K": 2, "chunk_idx": 3, } def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_copies results["quality"] = [1] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/nyuv2.py ================================================ import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.pipelines import AnnotationMask from unidepth.datasets.utils import DatasetFromList from unidepth.utils import identity class NYUv2Depth(ImageDataset): CAM_INTRINSIC = { "ALL": torch.tensor( [ [5.1885790117450188e02, 0, 3.2558244941119034e02], [0, 5.1946961112127485e02, 2.5373616633400465e02], [0, 0, 1], ] ) } min_depth = 0.005 max_depth = 10.0 depth_scale = 1000.0 log_mean = 0.9140 log_std = 0.4825 test_split = "nyu_test.txt" train_split = "nyu_train.txt" hdf5_paths = ["nyuv2.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.masker = AnnotationMask( min_value=0.0, max_value=self.max_depth if test_mode else None, custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename, _ = line.strip().split(" ") sample = [ image_filename, depth_filename, ] dataset.append(sample) if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) 
self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_copies return results def get_intrinsics(self, idx, image_name): return self.CAM_INTRINSIC["ALL"].clone() def eval_mask(self, valid_mask, info={}): border_mask = torch.zeros_like(valid_mask) border_mask[..., 45:-9, 41:-39] = 1 return torch.logical_and(valid_mask, border_mask) def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, } def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_copies results["quality"] = [2] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/oasis.py ================================================ import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class OASISv2(ImageDataset): min_depth = 0.01 max_depth = 400.0 depth_scale = 1000.0 test_split = "val.txt" train_split = "train.txt" hdf5_paths = ["Oasis2.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 dataset = [] # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: # f.write(txt_string) for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") sample = [image_filename, depth_filename] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["ssi"] = [True] results["valid_camera"] = [False] return results def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/pipelines/__init__.py ================================================ from .formating import AnnotationMask, Collect from .transforms import (Compose, ContextCrop, Crop, GaussianBlur, KittiCrop, PanoCrop, PanoRoll, RandomAutoContrast, RandomBrightness, RandomColor, RandomColorJitter, RandomContrast, RandomEqualize, RandomFiller, RandomFlip, RandomGamma, RandomGrayscale, RandomInvert, RandomMasking, RandomPosterize, RandomSaturation, RandomSharpness, RandomShear, RandomSolarize, RandomTranslate, Rotate) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/pipelines/formating.py ================================================ from collections.abc import Sequence import numpy as np import torch class Collect(object): def __init__( self, keys, meta_keys=( "filename", "keyframe_idx", "sequence_name", "image_filename", "depth_filename", "image_ori_shape", 
"camera", "original_camera", "sfm", "image_shape", "resized_shape", "scale_factor", "rotation", "resize_factor", "flip", "flip_direction", "dataset_name", "paddings", "max_value", "log_mean", "log_std", "image_rescale", "focal_rescale", "depth_rescale", ), ): self.keys = keys self.meta_keys = meta_keys def __call__(self, results): data_keys = [key for field in self.keys for key in results.get(field, [])] data = { key: { sequence_key: results[key][sequence_key] for sequence_key in results["sequence_fields"] } for key in data_keys } data["img_metas"] = { key: value for key, value in results.items() if key not in data_keys } return data def __repr__(self): return ( self.__class__.__name__ + f"(keys={self.keys}, meta_keys={self.meta_keys})" ) class AnnotationMask(object): def __init__(self, min_value, max_value, custom_fn=lambda x: x): self.min_value = min_value self.max_value = max_value self.custom_fn = custom_fn def __call__(self, results): for key in results.get("gt_fields", []): if key + "_mask" in results["mask_fields"]: if "flow" in key: for sequence_idx in results.get("sequence_fields", []): boundaries = (results[key][sequence_idx] >= -1) & ( results[key][sequence_idx] <= 1 ) boundaries = boundaries[:, :1] & boundaries[:, 1:] results[key + "_mask"][sequence_idx] = ( results[key + "_mask"][sequence_idx] & boundaries ) continue for sequence_idx in results.get("sequence_fields", []): mask = results[key][sequence_idx] > self.min_value if self.max_value is not None: mask = mask & (results[key][sequence_idx] < self.max_value) mask = self.custom_fn(mask, info=results) if key + "_mask" not in results: results[key + "_mask"] = {} results[key + "_mask"][sequence_idx] = mask.bool() results["mask_fields"].add(key + "_mask") return results def __repr__(self): return ( self.__class__.__name__ + f"(min_value={self.min_value}, max_value={ self.max_value})" ) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/pipelines/transforms.py ================================================ import os import random from copy import deepcopy from math import ceil, exp, log, log2, log10, tanh from typing import Dict, List, Tuple import numpy as np import torch import torch.nn.functional as F import torchvision.transforms.v2.functional as TF from unidepth.utils.geometric import downsample class PanoCrop: def __init__(self, crop_v=0.1): self.crop_v = crop_v def _crop_data(self, results, crop_size): """Function to randomly crop images, bounding boxes, masks, semantic segmentation maps. Args: results (dict): Result dict from loading pipeline. crop_size (tuple): Expected absolute size after cropping, (h, w). allow_negative_crop (bool): Whether to allow a crop that does not contain any bbox area. Default to False. Returns: dict: Randomly cropped results, 'image_shape' key in result dict is updated according to crop size. 
""" offset_w, offset_h = crop_size left, top, right, bottom = offset_w[0], offset_h[0], offset_w[1], offset_h[1] H, W = results["image"].shape[-2:] for key in results.get("image_fields", ["image"]): img = results[key][..., top : H - bottom, left : W - right] results[key] = img results["image_shape"] = tuple(img.shape) for key in results.get("gt_fields", []): results[key] = results[key][..., top : H - bottom, left : W - right] for key in results.get("mask_fields", []): results[key] = results[key][..., top : H - bottom, left : W - right] results["camera"].crop(left, top, right, bottom) return results def __call__(self, results): H, W = results["image"].shape[-2:] crop_w = (0, 0) crop_h = (int(H * self.crop_v), int(H * self.crop_v)) results = self._crop_data(results, (crop_w, crop_h)) return results class PanoRoll: def __init__(self, roll=[-0.5, 0.5]): self.roll = roll def __call__(self, results): W = results["image"].shape[-1] roll = random.randint(int(W * self.roll[0]), int(W * self.roll[1])) for key in results.get("image_fields", ["image"]): img = results[key] img = torch.roll(img, roll, dims=-1) results[key] = img for key in results.get("gt_fields", []): results[key] = torch.roll(results[key], roll, dims=-1) for key in results.get("mask_fields", []): results[key] = torch.roll(results[key], roll, dims=-1) return results class RandomFlip: """Flip the points & bbox. If the input dict contains the key "flip", then the flag will be used, otherwise it will be randomly decided by a ratio specified in the init method. Args: flip_ratio_bev_horizontal (float, optional): The flipping probability in horizontal direction. Defaults to 0.0. flip_ratio_bev_vertical (float, optional): The flipping probability in vertical direction. Defaults to 0.0. """ def __init__(self, direction="horizontal", prob=0.5, **kwargs): self.flip_ratio = prob valid_directions = ["horizontal", "vertical", "diagonal"] if isinstance(direction, str): assert direction in valid_directions elif isinstance(direction, list): assert set(direction).issubset(set(valid_directions)) else: raise ValueError("direction must be either str or list of str") self.direction = direction def __call__(self, results): """Call function to flip points, values in the ``bbox3d_fields`` and also flip 2D image and its annotations. Args: results (dict): Result dict from loading pipeline. 
Returns: dict: Flipped results, 'flip', 'flip_direction', """ if "flip" not in results: if isinstance(self.direction, list): # None means non-flip direction_list = self.direction + [None] else: # None means non-flip direction_list = [self.direction, None] if isinstance(self.flip_ratio, list): non_flip_ratio = 1 - sum(self.flip_ratio) flip_ratio_list = self.flip_ratio + [non_flip_ratio] else: non_flip_ratio = 1 - self.flip_ratio # exclude non-flip single_ratio = self.flip_ratio / (len(direction_list) - 1) flip_ratio_list = [single_ratio] * (len(direction_list) - 1) + [ non_flip_ratio ] cur_dir = np.random.choice(direction_list, p=flip_ratio_list) results["flip"] = cur_dir is not None if "flip_direction" not in results: results["flip_direction"] = cur_dir if results["flip"]: # flip image if results["flip_direction"] != "vertical": for key in results.get("image_fields", ["image"]): results[key] = TF.hflip(results[key]) for key in results.get("mask_fields", []): results[key] = TF.hflip(results[key]) for key in results.get("gt_fields", []): results[key] = TF.hflip(results[key]) if "flow" in key: # flip u direction results[key][:, 0] = -results[key][:, 0] H, W = results["image"].shape[-2:] results["camera"] = results["camera"].flip( H=H, W=W, direction="horizontal" ) # results["K"][..., 0, 2] = results["image"].shape[-1] - results["K"][..., 0, 2] # flip: - t_x rotate around y by: pi - angle_y * 2 flip_transform = torch.tensor( [[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=torch.float32, ).unsqueeze(0) repeats = (results["cam2w"].shape[0],) + (1,) * ( results["cam2w"].ndim - 1 ) results["cam2w"] = flip_transform.repeat(*repeats) @ results["cam2w"] if results["flip_direction"] != "horizontal": for key in results.get("image_fields", ["image"]): results[key] = TF.vflip(results[key]) for key in results.get("mask_fields", []): results[key] = TF.vflip(results[key]) for key in results.get("gt_fields", []): results[key] = TF.vflip(results[key]) results["K"][..., 1, 2] = ( results["image"].shape[-2] - results["K"][..., 1, 2] ) results["flip"] = [results["flip"]] * len(results["image"]) return results def __repr__(self): """str: Return a string that describes the module.""" repr_str = self.__class__.__name__ repr_str += f" flip_ratio={self.flip_ratio})" return repr_str class Crop: def __init__( self, crop_size, crop_type="absolute", crop_offset=(0, 0), ): if crop_type not in [ "relative_range", "relative", "absolute", "absolute_range", ]: raise ValueError(f"Invalid crop_type {crop_type}.") if crop_type in ["absolute", "absolute_range"]: assert crop_size[0] > 0 and crop_size[1] > 0 assert isinstance(crop_size[0], int) and isinstance(crop_size[1], int) else: assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1 self.crop_size = crop_size self.crop_type = crop_type self.offset_h, self.offset_w = ( crop_offset[: len(crop_offset) // 2], crop_offset[len(crop_offset) // 2 :], ) def _get_crop_size(self, image_shape): h, w = image_shape if self.crop_type == "absolute": return (min(self.crop_size[0], h), min(self.crop_size[1], w)) elif self.crop_type == "absolute_range": assert self.crop_size[0] <= self.crop_size[1] crop_h = np.random.randint( min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1 ) crop_w = np.random.randint( min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1 ) return crop_h, crop_w elif self.crop_type == "relative": crop_h, crop_w = self.crop_size return int(h * crop_h + 0.5), int(w * crop_w + 0.5) elif self.crop_type == "relative_range": crop_size = 
np.asarray(self.crop_size, dtype=np.float32) crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size) return int(h * crop_h + 0.5), int(w * crop_w + 0.5) def _crop_data(self, results, crop_size): assert crop_size[0] > 0 and crop_size[1] > 0 for key in results.get("image_fields", ["image"]): img = results[key] img = TF.crop( img, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1] ) results[key] = img results["image_shape"] = tuple(img.shape) for key in results.get("gt_fields", []): gt = results[key] results[key] = TF.crop( gt, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1] ) # crop semantic seg for key in results.get("mask_fields", []): mask = results[key] results[key] = TF.crop( mask, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1] ) results["K"][..., 0, 2] = results["K"][..., 0, 2] - self.offset_w[0] results["K"][..., 1, 2] = results["K"][..., 1, 2] - self.offset_h[0] return results def __call__(self, results): image_shape = results["image"].shape[-2:] crop_size = self._get_crop_size(image_shape) results = self._crop_data(results, crop_size) return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f"(crop_size={self.crop_size}, " repr_str += f"crop_type={self.crop_type}, " return repr_str class KittiCrop: def __init__(self, crop_size): self.crop_size = crop_size def _crop_data(self, results, crop_size): """Function to randomly crop images, bounding boxes, masks, semantic segmentation maps. Args: results (dict): Result dict from loading pipeline. crop_size (tuple): Expected absolute size after cropping, (h, w). allow_negative_crop (bool): Whether to allow a crop that does not contain any bbox area. Default to False. Returns: dict: Randomly cropped results, 'image_shape' key in result dict is updated according to crop size. """ assert crop_size[0] > 0 and crop_size[1] > 0 for key in results.get("image_fields", ["image"]): img = results[key] h, w = img.shape[-2:] offset_h, offset_w = int(h - self.crop_size[0]), int( (w - self.crop_size[1]) / 2 ) # crop the image img = TF.crop(img, offset_h, offset_w, crop_size[0], crop_size[1]) results[key] = img results["image_shape"] = tuple(img.shape) for key in results.get("gt_fields", []): gt = results[key] results[key] = TF.crop(gt, offset_h, offset_w, crop_size[0], crop_size[1]) # crop semantic seg for key in results.get("mask_fields", []): mask = results[key] results[key] = TF.crop(mask, offset_h, offset_w, crop_size[0], crop_size[1]) results["camera"].crop(offset_w, offset_h) return results def __call__(self, results): """Call function to randomly crop images, bounding boxes, masks, semantic segmentation maps. Args: results (dict): Result dict from loading pipeline. Returns: dict: Randomly cropped results, 'image_shape' key in result dict is updated according to crop size. 
""" results = self._crop_data(results, self.crop_size) return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f"(crop_size={self.crop_size}, " return repr_str class RandomMasking: def __init__( self, mask_ratio, mask_patch=16, prob=0.5, warmup_steps=50000, sampling="random", curriculum=False, ): self.mask_patch = mask_patch self.prob = prob self.mask_ratio = mask_ratio self.warmup_steps = max(1, warmup_steps) self.hard_bound = 1 self.idx = 0 self.curriculum = curriculum self.sampling = sampling self.low_bound = 0.0 self.up_bound = 0.0 def __call__(self, results): B, _, H, W = results["image"].shape device = results["image"].device down_size = H // self.mask_patch, W // self.mask_patch if np.random.random() > self.prob: # fill with dummy return self._nop(results, down_size, device) validity_mask = results["validity_mask"].float().reshape(B, -1, H, W) validity_mask = F.interpolate(validity_mask, size=down_size).bool() validity_mask = validity_mask.reshape(B, 1, *down_size) is_random = self.is_warmup or results.get("guidance") is None if not is_random: guidance = F.interpolate(results["guidance"], size=(H, W), mode="bilinear") results["guidance"] = -F.max_pool2d( -guidance, kernel_size=self.mask_patch, stride=self.mask_patch ) if is_random and self.sampling == "inverse": sampling = self.inverse_sampling elif is_random and self.sampling == "random": sampling = self.random_sampling else: sampling = self.guided_sampling mask_ratio = np.random.uniform(self.low_bound, self.up_bound) for key in results.get("image_fields", ["image"]): mask = sampling(results, mask_ratio, down_size, validity_mask, device) results[key + "_mask"] = mask return results def _nop(self, results, down_size, device): B = results["image"].shape[0] for key in results.get("image_fields", ["image"]): mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device) results[key + "_mask"] = mask_blocks return results def random_sampling(self, results, mask_ratio, down_size, validity_mask, device): B = results["image"].shape[0] prob_blocks = torch.rand(size=(B, 1, *down_size), device=device) mask_blocks = torch.logical_and(prob_blocks < mask_ratio, validity_mask) return mask_blocks def inverse_sampling(self, results, mask_ratio, down_size, validity_mask, device): # from PIL import Image # from unidepth.utils import colorize def area_sample(depth, fx, fy): dtype = depth.dtype B = depth.shape[0] H, W = down_size depth = downsample(depth, depth.shape[-2] // H) depth[depth > 200] = 50 # set sky as if depth 50 meters pixel_area3d = depth / torch.sqrt(fx * fy) # Set invalid as -1 (no div problem) -> then clip to 0.0 pixel_area3d[depth == 0.0] = -1 prob_density = (1 / pixel_area3d).clamp(min=0.0).square() prob_density = prob_density / prob_density.sum( dim=(-1, -2), keepdim=True ).clamp(min=1e-5) # Image.fromarray((prob_density[0] * 255 * 100).clamp(min=0.0, max=255.0).squeeze().cpu().byte().numpy()).save("prob_density.png") # Sample locations based on prob_density prob_density_flat = prob_density.view(B, -1) # Get the avgerage valid locations, of those we mask self.mask_ratio valid_locations = (prob_density_flat > 0).to(dtype).sum(dim=1) masks = [] for i in range(B): num_samples = int(valid_locations[i] * mask_ratio) mask = torch.zeros_like(prob_density_flat[i]) # Sample indices if num_samples > 0: sampled_indices_flat = torch.multinomial( prob_density_flat[i], num_samples, replacement=False ) mask.scatter_(0, sampled_indices_flat, 1) masks.append(mask) return torch.stack(masks).bool().view(B, 1, H, W) def 
random_sample(validity_mask): prob_blocks = torch.rand( size=(validity_mask.shape[0], 1, *down_size), device=device ) mask = torch.logical_and(prob_blocks < mask_ratio, validity_mask) return mask fx = results["K"][..., 0, 0].view(-1, 1, 1, 1) / self.mask_patch fy = results["K"][..., 1, 1].view(-1, 1, 1, 1) / self.mask_patch valid = ~results["ssi"] & ~results["si"] & results["valid_camera"] mask_blocks = torch.zeros_like(validity_mask) if valid.any(): out = area_sample(results["depth"][valid], fx[valid], fy[valid]) mask_blocks[valid] = out if (~valid).any(): mask_blocks[~valid] = random_sample(validity_mask[~valid]) # mask_blocks_ = (mask_blocks.float() * 255).squeeze(1).byte().cpu().numpy() # Image.fromarray(mask_blocks_[0]).save("mask1.png") # Image.fromarray(mask_blocks_[-1]).save("mask2.png") # dd = results["depth"] # Image.fromarray(colorize(dd[0].squeeze().cpu().numpy())).save("depth1_p.png") # Image.fromarray(colorize(dd[-1].squeeze().cpu().numpy())).save("depth2_p.png") # dd = downsample(dd, dd.shape[-2] // down_size[0]) # Image.fromarray(colorize(dd[0].squeeze().cpu().numpy())).save("depth1.png") # Image.fromarray(colorize(dd[-1].squeeze().cpu().numpy())).save("depth2.png") # raise ValueError return mask_blocks def guided_sampling(self, results, mask_ratio, down_size, validity_mask, device): # get the lowest (based on guidance) "mask_ratio" quantile of the patches that are in the validity mask B = results["image"].shape[0] guidance = results["guidance"] mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device) for b in range(B): low_bound = torch.quantile( guidance[b][validity_mask[b]], max(0.0, self.hard_bound - mask_ratio) ) up_bound = torch.quantile( guidance[b][validity_mask[b]], min(1.0, self.hard_bound) ) mask_blocks[b] = torch.logical_and( guidance[b] < up_bound, guidance[b] > low_bound ) mask_blocks = torch.logical_and(mask_blocks, validity_mask) return mask_blocks def step(self): self.idx += 1 # schedule hard_bound from 1.0 down to self.mask_ratio if self.curriculum: step = max(0, self.idx / self.warmup_steps / 2 - 0.5) self.hard_bound = 1 - (1 - self.mask_ratio) * tanh(step) self.up_bound = self.mask_ratio * tanh(step) self.low_bound = 0.2 * tanh(step) @property def is_warmup(self): return self.idx < self.warmup_steps class Rotate: def __init__( self, angle, center=None, img_fill_val=(123.68, 116.28, 103.53), prob=0.5 ): if isinstance(img_fill_val, (float, int)): img_fill_val = tuple([float(img_fill_val)] * 3) elif isinstance(img_fill_val, tuple): assert len(img_fill_val) == 3, ( "img_fill_val as tuple must " f"have 3 elements, got {len(img_fill_val)}." ) img_fill_val = tuple([float(val) for val in img_fill_val]) else: raise ValueError("img_fill_val must be float or tuple with 3 elements.") assert np.all( [0 <= val <= 255 for val in img_fill_val] ), f"all elements of img_fill_val should be in range [0,255], got {img_fill_val}." assert 0 <= prob <= 1.0, f"The probability should be in range [0,1], got {prob}." 
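# Note on the `angle` argument (convention follows __call__ below): a scalar is used as a
# fixed magnitude with a random sign, e.g. angle=30 rotates by +30 or -30 degrees with equal
# probability, while a [low, high] pair is sampled uniformly, e.g. angle=[-10, 10].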
self.center = center self.img_fill_val = img_fill_val self.prob = prob self.random = not isinstance(angle, (float, int)) self.angle = angle def _rotate(self, results, angle, center=None, fill_val=0.0): for key in results.get("image_fields", ["image"]): img = results[key] img_rotated = TF.rotate( img, angle, center=center, interpolation=TF.InterpolationMode.NEAREST_EXACT, fill=self.img_fill_val, ) results[key] = img_rotated.to(img.dtype) results["image_shape"] = results[key].shape for key in results.get("mask_fields", []): results[key] = TF.rotate( results[key], angle, center=center, interpolation=TF.InterpolationMode.NEAREST_EXACT, fill=fill_val, ) for key in results.get("gt_fields", []): results[key] = TF.rotate( results[key], angle, center=center, interpolation=TF.InterpolationMode.NEAREST_EXACT, fill=fill_val, ) def __call__(self, results): if np.random.random() > self.prob: return results angle = ( (self.angle[1] - self.angle[0]) * np.random.rand() + self.angle[0] if self.random else np.random.choice([-1, 1], size=1) * self.angle ) self._rotate(results, angle, None, fill_val=0.0) results["rotation"] = angle return results class RandomColor: def __init__(self, level, prob=0.5): self.random = not isinstance(level, (float, int)) self.level = level self.prob = prob def _adjust_color_img(self, results, factor=1.0): for key in results.get("image_fields", ["image"]): results[key] = TF.adjust_hue(results[key], factor) # .to(img.dtype) def __call__(self, results): if np.random.random() > self.prob: return results factor = ( ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) if self.random else self.level ) self._adjust_color_img(results, factor) return results class RandomSaturation: def __init__(self, level, prob=0.5): self.random = not isinstance(level, (float, int)) self.level = level self.prob = prob def _adjust_saturation_img(self, results, factor=1.0): for key in results.get("image_fields", ["image"]): # NOTE defaultly the image should be BGR format results[key] = TF.adjust_saturation(results[key], factor) # .to(img.dtype) def __call__(self, results): if np.random.random() > self.prob: return results factor = ( 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) if self.random else 2**self.level ) self._adjust_saturation_img(results, factor) return results class RandomSharpness: def __init__(self, level, prob=0.5): self.random = not isinstance(level, (float, int)) self.level = level self.prob = prob def _adjust_sharpeness_img(self, results, factor=1.0): for key in results.get("image_fields", ["image"]): # NOTE defaultly the image should be BGR format results[key] = TF.adjust_sharpness(results[key], factor) # .to(img.dtype) def __call__(self, results): if np.random.random() > self.prob: return results factor = ( 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) if self.random else 2**self.level ) self._adjust_sharpeness_img(results, factor) return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f"(level={self.level}, " repr_str += f"prob={self.prob})" return repr_str class RandomSolarize: def __init__(self, level, prob=0.5): self.random = not isinstance(level, (float, int)) self.level = level self.prob = prob def _adjust_solarize_img(self, results, factor=255.0): for key in results.get("image_fields", ["image"]): results[key] = TF.solarize(results[key], factor) # .to(img.dtype) def __call__(self, results): if np.random.random() > self.prob: return results factor = ( ((self.level[1] - self.level[0]) * 
np.random.rand() + self.level[0]) if self.random else self.level ) self._adjust_solarize_img(results, factor) return results class RandomPosterize: def __init__(self, level, prob=0.5): self.random = not isinstance(level, (float, int)) self.level = level self.prob = prob def _posterize_img(self, results, factor=1.0): for key in results.get("image_fields", ["image"]): results[key] = TF.posterize(results[key], int(factor)) # .to(img.dtype) def __call__(self, results): if np.random.random() > self.prob: return results factor = ( ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) if self.random else self.level ) self._posterize_img(results, factor) return results class RandomEqualize: def __init__(self, prob=0.5): assert 0 <= prob <= 1.0, "The probability should be in range [0,1]." self.prob = prob def _imequalize(self, results): for key in results.get("image_fields", ["image"]): results[key] = TF.equalize(results[key]) # .to(img.dtype) def __call__(self, results): if np.random.random() > self.prob: return results self._imequalize(results) return results class RandomBrightness: def __init__(self, level, prob=0.5): self.random = not isinstance(level, (float, int)) self.level = level self.prob = prob def _adjust_brightness_img(self, results, factor=1.0): for key in results.get("image_fields", ["image"]): results[key] = TF.adjust_brightness(results[key], factor) # .to(img.dtype) def __call__(self, results, level=None): if np.random.random() > self.prob: return results factor = ( 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) if self.random else 2**self.level ) self._adjust_brightness_img(results, factor) return results class RandomContrast: def __init__(self, level, prob=0.5): self.random = not isinstance(level, (float, int)) self.level = level self.prob = prob def _adjust_contrast_img(self, results, factor=1.0): for key in results.get("image_fields", ["image"]): results[key] = TF.adjust_contrast(results[key], factor) # .to(img.dtype) def __call__(self, results, level=None): if np.random.random() > self.prob: return results factor = ( 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) if self.random else 2**self.level ) self._adjust_contrast_img(results, factor) return results class RandomGamma: def __init__(self, level, prob=0.5): self.random = not isinstance(level, (float, int)) self.level = level self.prob = prob def __call__(self, results, level=None): if np.random.random() > self.prob: return results factor = (self.level[1] - self.level[0]) * np.random.rand() + self.level[0] for key in results.get("image_fields", ["image"]): if "original" not in key: results[key] = TF.adjust_gamma(results[key], 1 + factor) return results class RandomInvert: def __init__(self, prob=0.5): self.prob = prob def __call__(self, results): if np.random.random() > self.prob: return results for key in results.get("image_fields", ["image"]): if "original" not in key: results[key] = TF.invert(results[key]) # .to(img.dtype) return results class RandomAutoContrast: def __init__(self, prob=0.5): self.prob = prob def _autocontrast_img(self, results): for key in results.get("image_fields", ["image"]): img = results[key] results[key] = TF.autocontrast(img) # .to(img.dtype) def __call__(self, results): if np.random.random() > self.prob: return results self._autocontrast_img(results) return results class RandomShear(object): def __init__( self, level, prob=0.5, direction="horizontal", ): self.random = not isinstance(level, (float, int)) self.level = level 
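# `level` follows the same convention as the other augmentations here: a scalar is a fixed
# shear magnitude (random sign drawn in __call__) and a [low, high] pair is sampled uniformly;
# `direction` selects the shear axis, e.g. level=5 with direction="horizontal" shears by
# roughly ±5 degrees along x.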
self.prob = prob self.direction = direction def _shear_img(self, results, magnitude): for key in results.get("image_fields", ["image"]): img_sheared = TF.affine( results[key], angle=0.0, translate=[0.0, 0.0], scale=1.0, shear=magnitude, interpolation=TF.InterpolationMode.BILINEAR, fill=0.0, ) results[key] = img_sheared def _shear_masks(self, results, magnitude): for key in results.get("mask_fields", []): mask_sheared = TF.affine( results[key], angle=0.0, translate=[0.0, 0.0], scale=1.0, shear=magnitude, interpolation=TF.InterpolationMode.NEAREST_EXACT, fill=0.0, ) results[key] = mask_sheared def _shear_gt( self, results, magnitude, ): for key in results.get("gt_fields", []): mask_sheared = TF.affine( results[key], angle=0.0, translate=[0.0, 0.0], scale=1.0, shear=magnitude, interpolation=TF.InterpolationMode.NEAREST_EXACT, fill=0.0, ) results[key] = mask_sheared def __call__(self, results): if np.random.random() > self.prob: return results magnitude = ( ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) if self.random else np.random.choice([-1, 1], size=1) * self.level ) if self.direction == "horizontal": magnitude = [magnitude, 0.0] else: magnitude = [0.0, magnitude] self._shear_img(results, magnitude) self._shear_masks(results, magnitude) self._shear_gt(results, magnitude) return results class RandomTranslate(object): def __init__( self, range, prob=0.5, direction="horizontal", ): self.range = range self.prob = prob self.direction = direction def _translate_img(self, results, magnitude): """Shear the image. Args: results (dict): Result dict from loading pipeline. magnitude (int | float): The magnitude used for shear. direction (str): The direction for shear, either "horizontal" or "vertical". interpolation (str): Same as in :func:`mmcv.imshear`. """ for key in results.get("image_fields", ["image"]): img_sheared = TF.affine( results[key], angle=0.0, translate=magnitude, scale=1.0, shear=[0.0, 0.0], interpolation=TF.InterpolationMode.BILINEAR, fill=(123.68, 116.28, 103.53), ) results[key] = img_sheared def _translate_mask(self, results, magnitude): """Shear the masks.""" for key in results.get("mask_fields", []): mask_sheared = TF.affine( results[key], angle=0.0, translate=magnitude, scale=1.0, shear=[0.0, 0.0], interpolation=TF.InterpolationMode.NEAREST_EXACT, fill=0.0, ) results[key] = mask_sheared def _translate_gt( self, results, magnitude, ): """Shear the segmentation maps.""" for key in results.get("gt_fields", []): mask_sheared = TF.affine( results[key], angle=0.0, translate=magnitude, scale=1.0, shear=[0.0, 0.0], interpolation=TF.InterpolationMode.NEAREST_EXACT, fill=0.0, ) results[key] = mask_sheared def __call__(self, results): """Call function to shear images, bounding boxes, masks and semantic segmentation maps. Args: results (dict): Result dict from loading pipeline. Returns: dict: Sheared results. 
""" if np.random.random() > self.prob: return results magnitude = (self.range[1] - self.range[0]) * np.random.rand() + self.range[0] if self.direction == "horizontal": magnitude = [magnitude * results["image"].shape[1], 0] else: magnitude = [0, magnitude * results["image"].shape[0]] self._translate_img(results, magnitude) self._translate_mask(results, magnitude) self._translate_gt(results, magnitude) results["K"][..., 0, 2] = results["K"][..., 0, 2] + magnitude[0] results["K"][..., 1, 2] = results["K"][..., 1, 2] + magnitude[1] return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += f"(range={self.range}, " repr_str += f"prob={self.prob}, " repr_str += f"direction={self.direction}, " return repr_str class RandomColorJitter: def __init__(self, level, prob=0.9): self.level = level self.prob = prob self.list_transform = [ self._adjust_brightness_img, # self._adjust_sharpness_img, self._adjust_contrast_img, self._adjust_saturation_img, self._adjust_color_img, ] def _adjust_contrast_img(self, results, factor=1.0): """Adjust the image contrast.""" for key in results.get("image_fields", ["image"]): if "original" not in key: img = results[key] results[key] = TF.adjust_contrast(img, factor) def _adjust_sharpness_img(self, results, factor=1.0): """Adjust the image contrast.""" for key in results.get("image_fields", ["image"]): if "original" not in key: img = results[key] results[key] = TF.adjust_sharpness(img, factor) def _adjust_brightness_img(self, results, factor=1.0): """Adjust the brightness of image.""" for key in results.get("image_fields", ["image"]): if "original" not in key: img = results[key] results[key] = TF.adjust_brightness(img, factor) def _adjust_saturation_img(self, results, factor=1.0): """Apply Color transformation to image.""" for key in results.get("image_fields", ["image"]): if "original" not in key: img = results[key] results[key] = TF.adjust_saturation(img, factor / 2.0) def _adjust_color_img(self, results, factor=1.0): """Apply Color transformation to image.""" for key in results.get("image_fields", ["image"]): if "original" not in key: img = results[key] results[key] = TF.adjust_hue(img, (factor - 1.0) / 4.0) def __call__(self, results): """Call function for color transformation. Args: results (dict): Results dict from loading pipeline. Returns: dict: Results after the transformation. """ random.shuffle(self.list_transform) for op in self.list_transform: if np.random.random() < self.prob: factor = 1.0 + ( (self.level[1] - self.level[0]) * np.random.random() + self.level[0] ) op(results, factor) return results class RandomGrayscale: def __init__(self, prob=0.1, num_output_channels=3): super().__init__() self.prob = prob self.num_output_channels = num_output_channels def __call__(self, results): if np.random.random() > self.prob: return results for key in results.get("image_fields", ["image"]): if "original" not in key: results[key] = TF.rgb_to_grayscale( results[key], num_output_channels=self.num_output_channels ) return results def masked_nearest_interpolation(input, mask, target_size): """ Resize the depth map using bilinear interpolation, considering only valid pixels within NxN neighbors. Args: depth (torch.Tensor): The depth map tensor of shape (H, W). mask (torch.Tensor): The mask tensor of shape (H, W) where 1 indicates valid depth and 0 indicates missing depth. target_size (tuple): The desired output size (target_H, target_W). Returns: torch.Tensor: The resized depth map. 
""" B, C, H, W = input.shape target_H, target_W = target_size mask = mask.float() # Generate a grid of coordinates in the target space grid_y, grid_x = torch.meshgrid( torch.linspace(0, H - 1, target_H), torch.linspace(0, W - 1, target_W), indexing="ij", ) grid_y = grid_y.to(input.device) grid_x = grid_x.to(input.device) # Calculate the floor and ceil of the grid coordinates to get the bounding box x0 = torch.floor(grid_x).long().clamp(0, W - 1) x1 = (x0 + 1).clamp(0, W - 1) y0 = torch.floor(grid_y).long().clamp(0, H - 1) y1 = (y0 + 1).clamp(0, H - 1) # Gather depth values at the four corners Ia = input[..., y0, x0] Ib = input[..., y1, x0] Ic = input[..., y0, x1] Id = input[..., y1, x1] # Gather corresponding mask values ma = mask[..., y0, x0] mb = mask[..., y1, x0] mc = mask[..., y0, x1] md = mask[..., y1, x1] # Calculate distances to each neighbor # The distances are calculated from the center (grid_x, grid_y) to each corner dist_a = (grid_x - x0.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-left dist_b = (grid_x - x0.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-left dist_c = (grid_x - x1.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-right dist_d = (grid_x - x1.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-right # Stack the neighbors, their masks, and distances stacked_values = torch.stack( [Ia, Ib, Ic, Id], dim=-1 ) # Shape: (B, C, target_H, target_W, 4) stacked_masks = torch.stack( [ma, mb, mc, md], dim=-1 ) # Shape: (B, 1, target_H, target_W, 4) stacked_distances = torch.stack( [dist_a, dist_b, dist_c, dist_d], dim=-1 ) # Shape: (target_H, target_W, 4) stacked_distances = ( stacked_distances.unsqueeze(0).unsqueeze(1).repeat(B, 1, 1, 1, 1) ) # Shape: (B, 1, target_H, target_W, 4) # Set distances to infinity for invalid neighbors (so that invalid neighbors are never chosen) stacked_distances[stacked_masks == 0] = float("inf") # Find the index of the nearest valid neighbor (the one with the smallest distance) nearest_indices = stacked_distances.argmin(dim=-1, keepdim=True)[ ..., :1 ] # Shape: (B, 1, target_H, target_W, 1) # Select the corresponding depth value using the nearest valid neighbor index interpolated_depth = torch.gather( stacked_values, dim=-1, index=nearest_indices.repeat(1, C, 1, 1, 1) ).squeeze(-1) # Set depth to zero where no valid neighbors were found interpolated_depth = interpolated_depth * stacked_masks.sum(dim=-1).clip( min=0.0, max=1.0 ) return interpolated_depth class ContextCrop: def __init__( self, image_shape, keep_original=False, test_min_ctx=1.0, train_ctx_range=[0.5, 1.5], shape_constraints={}, ): self.image_shape = image_shape self.keep_original = keep_original self.test_min_ctx = test_min_ctx self.train_ctx_range = train_ctx_range self.shape_mult = shape_constraints["shape_mult"] self.sample = shape_constraints["sample"] self.ratio_bounds = shape_constraints["ratio_bounds"] pixels_min = shape_constraints["pixels_min"] / ( self.shape_mult * self.shape_mult ) pixels_max = shape_constraints["pixels_max"] / ( self.shape_mult * self.shape_mult ) self.pixels_bounds = (pixels_min, pixels_max) self.ctx = None def _transform_img(self, results, shapes): for key in results.get("image_fields", ["image"]): img = self.crop(results[key], **shapes) img = TF.resize( img, results["resized_shape"], interpolation=TF.InterpolationMode.BICUBIC, antialias=True, ) results[key] = img def _transform_masks(self, results, shapes): for key in results.get("mask_fields", []): mask = self.crop(results[key].float(), **shapes).byte() mask = masked_nearest_interpolation( 
mask, mask > 0, results["resized_shape"] ) results[key] = mask def _transform_gt(self, results, shapes): for key in results.get("gt_fields", []): gt = self.crop(results[key], **shapes) gt = masked_nearest_interpolation(gt, gt > 0, results["resized_shape"]) results[key] = gt @staticmethod def crop(img, height, width, top, left) -> torch.Tensor: h, w = img.shape[-2:] right = left + width bottom = top + height padding_ltrb = [ max(-left + min(0, right), 0), max(-top + min(0, bottom), 0), max(right - max(w, left), 0), max(bottom - max(h, top), 0), ] image_cropped = img[..., max(top, 0) : bottom, max(left, 0) : right] return TF.pad(image_cropped, padding_ltrb) def test_closest_shape(self, image_shape): h, w = image_shape input_ratio = w / h if self.sample: input_pixels = int(ceil(h / self.shape_mult * w / self.shape_mult)) pixels = max( min(input_pixels, self.pixels_bounds[1]), self.pixels_bounds[0] ) ratio = min(max(input_ratio, self.ratio_bounds[0]), self.ratio_bounds[1]) h = round((pixels / ratio) ** 0.5) w = h * ratio self.image_shape[0] = int(h) * self.shape_mult self.image_shape[1] = int(w) * self.shape_mult def _get_crop_shapes(self, image_shape, ctx=None): h, w = image_shape input_ratio = w / h if self.keep_original: self.test_closest_shape(image_shape) ctx = 1.0 elif ctx is None: ctx = float( torch.empty(1) .uniform_(self.train_ctx_range[0], self.train_ctx_range[1]) .item() ) output_ratio = self.image_shape[1] / self.image_shape[0] if output_ratio <= input_ratio: # out like 4:3 in like kitti if ( ctx >= 1 ): # fully in -> use just max_length with sqrt(ctx), here max is width new_w = w * ctx**0.5 # sporge un po in una sola dim # we know that in_width will stick out before in_height, partial overshoot (sporge) # new_h > old_h via area -> new_h ** 2 * ratio_new = old_h ** 2 * ratio_old * ctx elif output_ratio / input_ratio * ctx > 1: new_w = w * ctx else: # fully contained -> use area new_w = w * (ctx * output_ratio / input_ratio) ** 0.5 new_h = new_w / output_ratio else: if ctx >= 1: new_h = h * ctx**0.5 elif input_ratio / output_ratio * ctx > 1: new_h = h * ctx else: new_h = h * (ctx * input_ratio / output_ratio) ** 0.5 new_w = new_h * output_ratio return (int(ceil(new_h - 0.5)), int(ceil(new_w - 0.5))), ctx def __call__(self, results): h, w = results["image"].shape[-2:] results["image_ori_shape"] = (h, w) results.get("mask_fields", set()).add("validity_mask") if "validity_mask" not in results: results["validity_mask"] = torch.ones( (results["image"].shape[0], 1, h, w), dtype=torch.uint8, device=results["image"].device, ) n_iter = 1 if self.keep_original or not self.sample else 100 min_valid_area = 0.5 results["camera_fields"].add("camera_original") results["camera_original"] = results["camera"].clone() max_hfov, max_vfov = results["camera"].max_fov[0] # it is a 1-dim list ctx = None for ii in range(n_iter): (height, width), ctx = self._get_crop_shapes((h, w), ctx=self.ctx or ctx) margin_h = h - height margin_w = w - width # keep it centered in y direction top = margin_h // 2 left = margin_w // 2 if not self.keep_original: left = left + np.random.randint( -self.shape_mult // 2, self.shape_mult // 2 + 1 ) top = top + np.random.randint( -self.shape_mult // 2, self.shape_mult // 2 + 1 ) right = left + width bottom = top + height x_zoom = self.image_shape[0] / height paddings = [ max(-left + min(0, right), 0), max(bottom - max(h, top), 0), max(right - max(w, left), 0), max(-top + min(0, bottom), 0), ] valid_area = ( h * w / (h + paddings[1] + paddings[3]) / (w + paddings[0] + paddings[2]) ) 
new_hfov, new_vfov = results["camera_original"].get_new_fov( new_shape=(height, width), original_shape=(h, w) )[0] if ( valid_area >= min_valid_area and new_hfov < max_hfov and new_vfov < max_vfov ): results["camera"] = results["camera"].crop( left, top, right=w - right, bottom=h - bottom ) results["camera"] = results["camera"].resize(x_zoom) break ctx = ( ctx * 0.96 ) # if not enough valid area, try again with less ctx (more zoom) # save ctx for next iteration of sequences? self.ctx = ctx results["resized_shape"] = self.image_shape results["paddings"] = paddings # left ,top ,right, bottom results["image_rescale"] = x_zoom results["scale_factor"] = results.get("scale_factor", 1.0) * x_zoom shapes = dict(height=height, width=width, top=top, left=left) self._transform_img(results, shapes) if not self.keep_original: self._transform_gt(results, shapes) self._transform_masks(results, shapes) else: # only validity_mask (rgb's masks follows rgb transform) #FIXME mask = results["validity_mask"].float() mask = self.crop(mask, **shapes).byte() mask = TF.resize( mask, results["resized_shape"], interpolation=TF.InterpolationMode.NEAREST, ) results["validity_mask"] = mask # keep original images before photo-augment results["image_original"] = results["image"].clone() results["image_fields"].add( *[ field.replace("image", "image_original") for field in results["image_fields"] ] ) # repeat for batch resized shape and paddings results["paddings"] = [results["paddings"]] * results["image"].shape[0] results["resized_shape"] = [results["resized_shape"]] * results["image"].shape[ 0 ] return results class RandomFiller: def __init__(self, *args, **kwargs): super().__init__() def _transform(self, results): def fill_noise(size, device): return torch.normal(0, 2.0, size=size, device=device) def fill_black(size, device): return -4 * torch.ones(size, device=device, dtype=torch.float32) def fill_white(size, device): return 4 * torch.ones(size, device=device, dtype=torch.float32) def fill_zero(size, device): return torch.zeros(size, device=device, dtype=torch.float32) B, C = results["image"].shape[:2] mismatch = B // results["validity_mask"].shape[0] if mismatch: results["validity_mask"] = results["validity_mask"].repeat( mismatch, 1, 1, 1 ) validity_mask = results["validity_mask"].repeat(1, C, 1, 1).bool() filler_fn = np.random.choice([fill_noise, fill_black, fill_white, fill_zero]) for key in results.get("image_fields", ["image"]): results[key][~validity_mask] = filler_fn( size=results[key][~validity_mask].shape, device=results[key].device ) def __call__(self, results): # generate mask for filler if "validity_mask" not in results: paddings = results.get("padding_size", [0] * 4) height, width = results["image"].shape[-2:] results.get("mask_fields", []).add("validity_mask") results["validity_mask"] = torch.zeros_like(results["image"][:, :1]) results["validity_mask"][ ..., paddings[1] : height - paddings[3], paddings[0] : width - paddings[2], ] = 1.0 self._transform(results) return results class GaussianBlur: def __init__(self, kernel_size, sigma=(0.1, 2.0), prob=0.9): super().__init__() self.kernel_size = kernel_size self.sigma = sigma self.prob = prob self.padding = kernel_size // 2 def apply(self, x, kernel): # Pad the input tensor x = F.pad( x, (self.padding, self.padding, self.padding, self.padding), mode="reflect" ) # Apply the convolution with the Gaussian kernel return F.conv2d(x, kernel, stride=1, padding=0, groups=x.size(1)) def _create_kernel(self, sigma): # Create a 1D Gaussian kernel kernel_1d = torch.exp( 
-torch.arange(-self.padding, self.padding + 1) ** 2 / (2 * sigma**2) ) kernel_1d = kernel_1d / kernel_1d.sum() # Expand the kernel to 2D and match size of the input kernel_2d = kernel_1d.unsqueeze(0) * kernel_1d.unsqueeze(1) kernel_2d = kernel_2d.view(1, 1, self.kernel_size, self.kernel_size).expand( 3, 1, -1, -1 ) return kernel_2d def __call__(self, results): if np.random.random() > self.prob: return results sigma = (self.sigma[1] - self.sigma[0]) * np.random.rand() + self.sigma[0] kernel = self._create_kernel(sigma) for key in results.get("image_fields", ["image"]): if "original" not in key: results[key] = self.apply(results[key], kernel) return results class Compose: def __init__(self, transforms): self.transforms = deepcopy(transforms) def __call__(self, results): for t in self.transforms: results = t(results) return results def __setattr__(self, name: str, value) -> None: super().__setattr__(name, value) for t in self.transforms: setattr(t, name, value) def __repr__(self): format_string = self.__class__.__name__ + "(" for t in self.transforms: format_string += f"\n {t}" format_string += "\n)" return format_string ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/point_odyssey.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class PointOdyssey(SequenceDataset): min_depth = 0.01 max_depth = 250.0 depth_scale = 1000.0 test_split = "test.txt" train_split = "train.txt" sequences_file = "sequences_clean.json" hdf5_paths = [f"PointOdyssey.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/proteus.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class Proteus(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 default_fps = 5 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["Proteus.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, 
mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/samplers copy.py ================================================ import itertools import warnings from operator import itemgetter from typing import Any, Optional import numpy as np import torch from torch.utils.data import Sampler from unidepth.utils import get_dist_info def _get_numpy_dtype(size: int) -> Any: return np.int32 if size <= 2**31 else np.int64 def _get_torch_dtype(size: int) -> Any: return torch.int32 if size <= 2**31 else torch.int64 def _generate_randperm_indices(*, size: int, generator: torch.Generator): """Generate the indices of a random permutation.""" dtype = _get_torch_dtype(size) # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921 perm = torch.arange(size, dtype=dtype) for i in range(size): j = torch.randint(i, size, size=(1,), generator=generator).item() # Always swap even if no-op value = perm[j].item() perm[j] = perm[i].item() perm[i] = value yield value # The following function is somewhat equivalent to _new_shuffle_tensor_slice below, # but avoids a full in-place random permutation generation. def _shuffle_tensor_slice( *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator ) -> np.ndarray: stop = len(tensor) count = stop // step drop_count = stop - step * count if drop_count: warnings.warn(f"# of dropped samples: {drop_count}") dtype = _get_numpy_dtype(stop) result = np.empty(count, dtype=dtype) for i in range(count): j = ( torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0 ) result[i] = result[j] result[j] = tensor[start + i * step].item() return result def _new_shuffle_tensor_slice( *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator ) -> np.ndarray: stop = len(tensor) count = stop // step dtype = torch.int64 # Needed for using randperm result as indices count = stop // step drop_count = stop - step * count if drop_count: warnings.warn(f"# of dropped samples: {drop_count}") indices = torch.randperm(count, dtype=dtype, generator=generator) return tensor[start::step][indices].numpy() def _make_seed(seed: int, start: int, iter_count: int) -> int: # NOTE: Tried a few variants (including iter_count << 32), this one worked best. 
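# Combines the base seed with the shard's start offset and the pass counter so each shard and
# each pass over the permutation gets a distinct, reproducible stream,
# e.g. seed=0, start=3 (rank 3), iter_count=2 -> 0 + 3 + (2 << 24) = 33_554_435.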
return seed + start + (iter_count << 24) class ShardedInfiniteSampler(Sampler): def __init__( self, *, sample_count: int, shuffle: bool = False, seed: int = 0, start: Optional[int] = None, step: Optional[int] = None, advance: int = 0, use_new_shuffle_tensor_slice: bool = False, ): self._sample_count = sample_count self._seed = seed self._shuffle = shuffle rank, world_size = get_dist_info() self._start = rank if start is None else start self._step = world_size if step is None else step self._advance = advance self._iter_count = 0 self._shuffle_tensor_slice_fn = ( _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice ) def __iter__(self): iter_count = self._advance // self._sample_count if iter_count > 0: self._advance -= iter_count * self._sample_count self._iter_count += iter_count if self._shuffle: iterator = self._shuffled_iterator() else: iterator = self._iterator() yield from itertools.islice(iterator, self._advance, None) def _iterator(self): assert not self._shuffle while True: iterable = range(self._sample_count) yield from itertools.islice(iterable, self._start, None, self._step) def _shuffled_iterator(self): assert self._shuffle # Instantiate a generator here (rather than in the ctor) to be keep the class # picklable (requirement of mp.spawn) generator = torch.Generator() # Always shuffle everything first generator.manual_seed(self._seed) dtype = _get_torch_dtype(self._sample_count) perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator) while True: # Re-seed on each iteration to allow skipping whole permutations seed = _make_seed(self._seed, self._start, self._iter_count) generator.manual_seed(seed) iterable = self._shuffle_tensor_slice_fn( tensor=perm, start=self._start, step=self._step, generator=generator ) yield from iterable self._iter_count += 1 class DistributedSamplerNoDuplicate(torch.utils.data.DistributedSampler): """A distributed sampler that doesn't add duplicates. Arguments are the same as DistributedSampler""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if not self.drop_last and len(self.dataset) % self.num_replicas != 0: # some ranks may have less samples, that's fine if self.rank >= len(self.dataset) % self.num_replicas: self.num_samples -= 1 self.total_size = len(self.dataset) class DatasetFromSampler(torch.utils.data.Dataset): """Dataset to create indexes from `Sampler`. Args: sampler: PyTorch sampler """ def __init__(self, sampler: Sampler): """Initialisation for DatasetFromSampler.""" self.sampler = sampler self.sampler_list = None def __getitem__(self, index: int): """Gets element of the dataset. Args: index: index of the element in the dataset Returns: Single element by index """ if self.sampler_list is None: self.sampler_list = list(self.sampler) return self.sampler_list[index] def __len__(self) -> int: """ Returns: int: length of the dataset """ return len(self.sampler) class DistributedSamplerWrapper(torch.utils.data.DistributedSampler): """ Wrapper over `Sampler` for distributed training Allows you to use any sampler in distributed mode. It is especially useful in conjunction with `torch.nn.parallel.DistributedDataParallel`. In such case, each process can pass a DistributedSamplerWrapper instance as a DataLoader sampler, and load a subset of subsampled data of the original dataset that is exclusive to it. .. note:: Sampler is assumed to be of constant size. 
""" def __init__( self, sampler, num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = True, ): """ Args: sampler: Sampler used for subsampling num_replicas (int, optional): Number of processes participating in distributed training rank (int, optional): Rank of the current process within ``num_replicas`` shuffle (bool, optional): If true (default), sampler will shuffle the indices """ super(DistributedSamplerWrapper, self).__init__( DatasetFromSampler(sampler), num_replicas=num_replicas, rank=rank, shuffle=shuffle, ) self.sampler = sampler def __iter__(self): self.dataset = DatasetFromSampler(self.sampler) indexes_of_indexes = super().__iter__() subsampler_indexes = self.dataset return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/samplers.py ================================================ import torch class DistributedSamplerNoDuplicate(torch.utils.data.DistributedSampler): """A distributed sampler that doesn't add duplicates. Arguments are the same as DistributedSampler""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if not self.drop_last and len(self.dataset) % self.num_replicas != 0: # some ranks may have less samples, that's fine if self.rank >= len(self.dataset) % self.num_replicas: self.num_samples -= 1 self.total_size = len(self.dataset) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/scannet.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class ScanNet(SequenceDataset): min_depth = 0.005 max_depth = 10.0 depth_scale = 1000.0 test_split = "test.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["ScanNetS.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/scannetpp.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class ScanNetpp(SequenceDataset): min_depth = 0.001 max_depth = 10.0 depth_scale = 1000.0 test_split = "val_iphone.txt" train_split = "train_iphone.txt" sequences_file = "sequences_iphone_clean.json" hdf5_paths = [f"ScanNetpp_viz.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], 
inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results class ScanNetpp_F(SequenceDataset): min_depth = 0.001 max_depth = 10.0 depth_scale = 1000.0 train_split = "train.txt" test_split = "val_split.txt" sequences_file = "sequences_split.json" hdf5_paths = [f"ScanNetpp_F.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=( decode_fields if not test_mode else [*decode_fields, "points"] ), inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/sequence_dataset.py ================================================ import json import os from functools import partial from typing import Any, Dict, Tuple import h5py import numpy as np import tables import torch import torchvision.transforms.v2.functional as TF from unidepth.datasets.base_dataset import BaseDataset from unidepth.datasets.utils import DatasetFromList from unidepth.datasets.utils_decode import (decode_camera, decode_depth, decode_flow, decode_K, decode_mask, decode_numpy, decode_rgb, decode_tensor) from unidepth.utils.distributed import is_main_process class SequenceDataset(BaseDataset): DECODE_FNS = { "image": partial(decode_rgb, name="image"), "points": partial(decode_numpy, name="points"), "K": partial(decode_K, name="camera"), "camera_params": partial(decode_camera, name="camera"), "cam2w": partial(decode_tensor, name="cam2w"), "depth": partial(decode_depth, name="depth"), "flow_fwd": partial(decode_flow, name="flow_fwd"), "flow_bwd": partial(decode_flow, name="flow_bwd"), "flow_fwd_mask": partial(decode_mask, name="flow_fwd_mask"), "flow_bwd_mask": partial(decode_mask, name="flow_bwd_mask"), } default_fps = 5 def __init__( self, image_shape: Tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: Dict[str, Any], resize_method: str, mini: float, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) 
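# `decode_fields` are per-frame payloads decoded from the HDF5 chunk (images, depths, optical
# flow, ...), while `inplace_fields` (intrinsics / cam2w poses) are taken directly from the
# per-sequence metadata; both are dispatched through DECODE_FNS in get_single_sequence below.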
self.num_frames = num_frames self.original_num_frames = num_frames self.decode_fields = decode_fields self.inplace_fields = inplace_fields self.fps = self.default_fps self.fps_range = kwargs.get("fps_range", None) if self.fps_range is not None: self.fps_range[1] = min(self.default_fps, self.fps_range[1]) self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii").strip() sequences = np.array(h5file[self.sequences_file]).tostring().decode("ascii") sequences = json.loads(sequences) h5file.close() dataset = [] for line in txt_string.split("\n"): if len(line.strip().split(" ")) == 1: print(line) continue sequence_name, num_samples = line.strip().split(" ") dataset.append( { "sequence_name": sequence_name, "num_samples": int(num_samples), "chunk_idx": 0, } ) # filter dataset based on attr "invalid_sequences" invalid_sequences = getattr(self, "invalid_sequences", []) dataset = [ sample for sample in dataset if sample["sequence_name"] not in invalid_sequences ] self.dataset = DatasetFromList(dataset) self.sequences = DatasetFromList( [sequences[sample["sequence_name"]] for sample in dataset] ) self.log_load_dataset() def get_random_idxs(self, num_samples_sequence): if self.num_frames == 1: return [np.random.randint(0, num_samples_sequence)], 0 # Check if we can satisfy the required number of frames if self.num_frames > num_samples_sequence: raise ValueError( "Cannot sample more frames than available in the sequence." ) # Restrict FPS range to be within default FPS min_fps, max_fps = self.fps_range max_fps = min(max_fps, self.default_fps) if min_fps > self.default_fps: sampled_fps = self.default_fps else: # Compute minimal viable FPS min_required_fps = ( self.num_frames / num_samples_sequence ) * self.default_fps min_fps = max(min_fps, min_required_fps) # Sample an FPS from the viable range sampled_fps = np.random.uniform(min_fps, max_fps) # Compute the stride based on the sampled FPS stride = self.default_fps / sampled_fps max_start_index = num_samples_sequence - int(stride * (self.num_frames - 1)) # Ensure a valid starting position if max_start_index <= 0: raise ValueError( "No valid start position allows sampling num_frames with the chosen FPS." 
) start_index = np.random.randint(0, max_start_index + 1) # Compute indices based on the sampled FPS indices = [int(start_index + i * stride) for i in range(self.num_frames)] return indices, np.random.randint(0, len(indices)) def get_test_idxs(self, num_samples_sequence, keyframe_idx): if self.num_frames == 1: return [ keyframe_idx if keyframe_idx is not None else num_samples_sequence // 2 ], 0 if self.num_frames == -1: cap_idxs = min(32, num_samples_sequence) # CAP 32 images idxs = list( range(max(0, num_samples_sequence - cap_idxs), num_samples_sequence, 1) ) return idxs, keyframe_idx # pick closest keyframe_idx st they are around it or capped by the 0 and max num_samples_sequence keyframe_idx = ( keyframe_idx if keyframe_idx is not None else num_samples_sequence - 1 ) excess_tail = 0 - min(0, keyframe_idx - self.num_frames // 2) excess_head = ( max(num_samples_sequence, keyframe_idx + (self.num_frames - 1) // 2) - num_samples_sequence ) start = keyframe_idx - self.num_frames // 2 + excess_tail - excess_head end = keyframe_idx + (self.num_frames - 1) // 2 + excess_head - excess_tail idxs = list(range(start, 1 + end)) return idxs, idxs.index(keyframe_idx) def get_single_sequence(self, idx): self.num_frames = self.original_num_frames # sequence_name = self.dataset[idx]["sequence_name"] sample = self.sequences[idx] chunk_idx = int(sample.get("chunk_idx", 0)) h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx]) num_samples_sequence = len(sample["image"]) if self.num_frames > 0 and num_samples_sequence < self.num_frames: raise IndexError(f"Sequence {idx} has less than {self.num_frames} frames") keyframe_idx = None if not self.test_mode: idxs, keyframe_idx = self.get_random_idxs(num_samples_sequence) else: idxs, keyframe_idx = self.get_test_idxs( num_samples_sequence, sample.get("keyframe_idx", None) ) self.num_frames = len(idxs) results = {} results = self.pre_pipeline(results) results["sequence_fields"] = [(i, 0) for i in range(self.num_frames)] results["keyframe_idx"] = keyframe_idx with tables.File( h5_path, mode="r", libver="latest", swmr=True, ) as h5file_chunk: for i, j in enumerate(idxs): results[(i, 0)] = { k: v.copy() for k, v in results.items() if "fields" in k } for inplace_field in self.inplace_fields: inplace_field_ = inplace_field.replace("intrinsics", "K").replace( "extrinsics", "cam2w" ) results = self.DECODE_FNS[inplace_field_]( results, sample[inplace_field][j], idx=i, sample=sample, j=j ) for i, j in enumerate(idxs): for decode_field in self.decode_fields: results = self.DECODE_FNS[decode_field]( results, h5file_chunk, sample[decode_field][j], idx=i, depth_scale=self.depth_scale, ) results["filename"] = sample["image"][j] results = self.preprocess(results) if not self.test_mode: results = self.augment(results) results = self.postprocess(results) return results def preprocess(self, results): results = self.replicate(results) for i, seq in enumerate(results["sequence_fields"]): results[seq] = self.resizer(results[seq]) self.resizer.ctx = None if self.num_copies > 1 else self.resizer.ctx num_pts = torch.count_nonzero(results[seq]["depth"] > 0) if num_pts < 50: raise IndexError(f"Too few points in depth map ({num_pts})") for key in results[seq].get("image_fields", ["image"]): results[seq][key] = results[seq][key].to(torch.float32) / 255 # update fields common in sequence for key in [ "image_fields", "gt_fields", "mask_fields", "camera_fields", ]: if key in results[(0, 0)]: results[key] = results[(0, 0)][key] results = self.pack_batch(results) return results def 
postprocess(self, results): # # normalize after because color aug requires [0,255]? for key in results.get("image_fields", ["image"]): results[key] = TF.normalize(results[key], **self.normalization_stats) results = self.filler(results) results = self.unpack_batch(results) results = self.masker(results) results = self.collecter(results) return results def __getitem__(self, idx): try: if isinstance(idx, (list, tuple)): results = [self.get_single_sequence(i) for i in idx] else: results = self.get_single_sequence(idx) except Exception as e: print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}") idx = np.random.randint(0, len(self.dataset)) results = self[idx] return results def log_load_dataset(self): if is_main_process(): info = f"Loaded {self.__class__.__name__} with {sum([len(x['image']) for x in self.sequences])} images in {len(self)} sequences." print(info) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/sintel copy.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class Sintel(SequenceDataset): min_depth = 0.001 max_depth = 1000.0 depth_scale = 1000.0 test_split = "training.txt" train_split = "training.txt" sequences_file = "sequences.json" hdf5_paths = ["Sintel.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/sintel.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class Sintel(SequenceDataset): min_depth = 0.001 max_depth = 1000.0 depth_scale = 1000.0 test_split = "training.txt" train_split = "training.txt" sequences_file = "sequences.json" hdf5_paths = ["Sintel.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames results["synthetic"] = 
[True] * self.num_frames return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/sunrgbd.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class SUNRGBD(ImageDataset): min_depth = 0.005 max_depth = 8.0 depth_scale = 1000.0 test_split = "alltest.txt" train_split = "alltrain.txt" intrisics_file = "intrinsics.json" hdf5_paths = ["SUNRGB.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.crop = crop self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/synscapes.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class Synscapes(SequenceDataset): min_depth = 0.1 max_depth = 1000.0 depth_scale = 256.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"Synscapes.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/tartanair.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class TartanAir(SequenceDataset): min_depth = 0.01 max_depth = 512.0 
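# TartanAir frames are fully synthetic: pre_pipeline below marks every frame as dense/synthetic with quality flag 0, and the decoded integer depth is divided by depth_scale to recover metric depth.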
depth_scale = 1000.0 default_fps = 15 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["TartanAir.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/taskonomy.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class Taskonomy(ImageDataset): min_depth = 0.005 max_depth = 15.0 depth_scale = 512.0 test_split = "val.txt" train_split = "train_clean.txt" intrisics_file = "intrinsics.json" hdf5_paths = ["Taskonomy.hdf5"] def __init__( self, image_shape, split_file, test_mode, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: # f.write(txt_string) # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f: # json.dump(intrinsics, f) dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename, chunk_idx = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val, chunk_idx] dataset.append(sample) h5file.close() if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) if self.test_mode and not self.benchmark: dataset = self.chunk(dataset, chunk_dim=1, pct=0.01) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def get_mapper(self): return { "image_filename": 0, "depth_filename": 1, "K": 2, } def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: 
camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/tat_rmvd.py ================================================ import json import os from copy import deepcopy from typing import Any import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.pipelines import AnnotationMask, KittiCrop from unidepth.datasets.sequence_dataset import SequenceDataset from unidepth.datasets.utils import DatasetFromList from unidepth.utils import identity class TATRMVD(SequenceDataset): min_depth = 0.001 max_depth = 50.0 depth_scale = 1000.0 default_fps = 6 test_split = "test.txt" train_split = "test.txt" sequences_file = "sequences.json" hdf5_paths = ["tanks_and_temples_rmvd.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, augmentations_db={}, normalize=True, resize_method="hard", mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_frames * self.num_copies results["si"] = [True] * self.num_frames * self.num_copies results["quality"] = [2] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/theo.py ================================================ from typing import Any import torch from unidepth.datasets.sequence_dataset import SequenceDataset class Theo(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 default_fps = 5 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["THEO.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["camera_params", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def preprocess(self, results): self.resizer.ctx = None for i, seq in enumerate(results["sequence_fields"]): # Create a mask where the distance from the center is less than H/2 H, W = results[seq]["image"].shape[-2:] x = torch.linspace(-(W - 1) / 2, (W - 1) / 2, W) y = torch.linspace(-(H - 1) / 2, (H - 1) / 2, H) xv, yv = torch.meshgrid(x, y, indexing="xy") distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W) results[seq]["validity_mask"] = distance_from_center < (H - 1) / 2 return super().preprocess(results) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results 
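The dataset wrappers in this directory (Sintel, Synscapes, TartanAir, Theo, and the files that follow) all specialise SequenceDataset in the same way: class-level metadata plus a pre_pipeline override. A minimal sketch of that recipe, using a hypothetical NewSyntheticDataset and chunk name that are not part of the repository, might look roughly like this:

from typing import Any

from unidepth.datasets.sequence_dataset import SequenceDataset


class NewSyntheticDataset(SequenceDataset):  # hypothetical example, not a repo file
    min_depth = 0.01              # metric range used downstream to clamp/mask ground truth
    max_depth = 100.0
    depth_scale = 1000.0          # stored integer depth / depth_scale -> metres
    test_split = "val.txt"
    train_split = "train.txt"
    sequences_file = "sequences.json"
    hdf5_paths = ["NewSyntheticDataset.hdf5"]   # HDF5 chunk(s) under data_root

    # the real classes also forward image_shape, split_file, augmentations_db, etc.
    # to SequenceDataset.__init__; omitted here for brevity
    def pre_pipeline(self, results):
        results = super().pre_pipeline(results)
        # per-frame flags consumed by the loss / masking pipeline
        results["dense"] = [True] * self.num_frames * self.num_copies
        results["synthetic"] = [True] * self.num_frames * self.num_copies
        results["quality"] = [0] * self.num_frames * self.num_copies
        return results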
================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/unrealstereo4k.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class UnrealStereo4K(SequenceDataset): min_depth = 0.01 max_depth = 200.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"UnrealStereo4K.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/urbansyn.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class UrbanSyn(SequenceDataset): min_depth = 0.1 max_depth = 1000.0 depth_scale = 256.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = [f"UrbanSyn.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/utils.py ================================================ import copy import multiprocessing as mp import pickle from collections import defaultdict from typing import Any, Dict, List import numpy as np import torch import torch.utils.data from unidepth.utils.distributed import (all_gather, get_local_rank, get_local_size, get_rank, get_world_size) class ConcatDataset(torch.utils.data.ConcatDataset): def __init__(self, datasets, shape_constraints: dict[str, list[int]] = {}): super().__init__(datasets) self.sample = shape_constraints["sample"] self.shape_mult = shape_constraints["shape_mult"] self.ratio_bounds = 
shape_constraints["ratio_bounds"] self.pixels_max = shape_constraints["pixels_max"] self.pixels_min = shape_constraints["pixels_min"] self.height_min = shape_constraints["height_min"] self.width_min = shape_constraints["width_min"] def sample_shape(self): if not self.sample: return # 1: sample image ratio ratio = np.random.uniform(*self.ratio_bounds) pixels_min = self.pixels_min // (self.shape_mult * self.shape_mult) pixels_max = self.pixels_max // (self.shape_mult * self.shape_mult) # 2: sample image height or width, if ratio > 1 or < 1 if ratio > 1: height_min = max(self.height_min, np.sqrt(pixels_min / ratio)) height = np.random.uniform(height_min, np.sqrt(pixels_max / ratio)) width = height * ratio else: width_min = max(self.width_min, np.sqrt(pixels_min * ratio)) width = np.random.uniform(width_min, np.sqrt(pixels_max * ratio)) height = width / ratio # 3: get final shape based on the shape_mult shape = [int(height) * self.shape_mult, int(width) * self.shape_mult] for dataset in self.datasets: setattr(dataset, "image_shape", shape) setattr(dataset.resizer, "image_shape", shape) def __getitem__(self, idxs): self.sample_shape() return [super(ConcatDataset, self).__getitem__(idx) for idx in idxs] def _paddings(image_shape, network_shape): cur_h, cur_w = image_shape h, w = network_shape pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2 pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2 return pad_left, pad_right, pad_top, pad_bottom def collate_fn(in_data: List[List[Dict[str, Any]]], is_batched: bool = True): out_data = defaultdict(list) img_metas = [] in_data = in_data[0] if is_batched else in_data # get max_shape and paddings shapes = [tensor.shape[-2:] for x in in_data for tensor in x["depth"].values()] max_shape_tuple = tuple(max(elements) for elements in zip(*shapes)) paddings = [ [ _paddings(tensor.shape[-2:], max_shape_tuple) for tensor in x["depth"].values() ] for x in in_data ] for x in in_data: # here iter over batches padding = paddings.pop(0) for k, v in x.items(): if "img_metas" not in k: values = list(v.values()) v = torch.cat(values) out_data[k].append(v) else: v["depth_paddings"] = padding img_metas.append(v) output_dict = { "data": {k: torch.stack(v, dim=0) for k, v in out_data.items()}, "img_metas": img_metas, } # camera are always flattened and the stack/cat so if list of B times (T, 3, 3) cameras # it goes to (B * T, 3, 3), to be consistent with the image shape -> reshape if "camera" in output_dict["data"]: output_dict["data"]["camera"] = output_dict["data"]["camera"].reshape( *output_dict["data"]["image"].shape[:2] ) return output_dict def local_scatter(array: list[Any]): """ Scatter an array from local leader to all local workers. The i-th local worker gets array[i]. Args: array: Array with same size of #local workers. """ if get_world_size() == 1: return array[0] if get_local_rank() == 0: assert len(array) == get_local_size() all_gather(array) else: all_data = all_gather(None) array = all_data[get_rank() - get_local_rank()] return array[get_local_rank()] class DatasetFromList(torch.utils.data.Dataset): # type: ignore """Wrap a list to a torch Dataset. We serialize and wrap big python objects in a torch.Dataset due to a memory leak when dealing with large python objects using multiple workers. See: https://github.com/pytorch/pytorch/issues/13246 """ def __init__(self, lst: List[Any], deepcopy: bool = False, serialize: bool = True): """Creates an instance of the class. Args: lst: a list which contains elements to produce. 
deepcopy: whether to deepcopy the element when producing it, s.t. the result can be modified in place without affecting the source in the list. serialize: whether to hold memory using serialized objects. When enabled, data loader workers can use shared RAM from master process instead of making a copy. """ self._copy = deepcopy self._serialize = serialize def _serialize(data: Any): buffer = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) return torch.frombuffer(buffer, dtype=torch.uint8) if self._serialize: # load only on 0th rank if get_local_rank() == 0: _lst = [_serialize(x) for x in lst] self._addr = torch.cumsum( torch.tensor([len(x) for x in _lst], dtype=torch.int64), dim=0 ) self._lst = torch.concatenate(_lst) # Move data to shared memory, obtain a handle to send to each local worker. handles = [None] + [ bytes(mp.reduction.ForkingPickler.dumps((self._addr, self._lst))) for _ in range(get_local_size() - 1) ] else: handles = None # Each worker receives the handle from local leader (rank 0) # then materialize the tensor from shared memory handle = local_scatter(handles) if get_local_rank() > 0: self._addr, self._lst = mp.reduction.ForkingPickler.loads(handle) else: self._lst = lst def __len__(self) -> int: """Return len of list.""" if self._serialize: return len(self._addr) return len(self._lst) def __getitem__(self, idx: int) -> Any: """Return item of list at idx.""" if self._serialize: start_addr = 0 if idx == 0 else self._addr[idx - 1] end_addr = self._addr[idx] bytes_ = memoryview(self._lst[start_addr:end_addr].numpy()) return pickle.loads(bytes_) if self._copy: return copy.deepcopy(self._lst[idx]) return self._lst[idx] def get_weights( train_datasets: dict[str, torch.utils.data.Dataset], sampling: dict[str, float] ) -> torch.Tensor: from .image_dataset import ImageDataset from .sequence_dataset import SequenceDataset weights = [] num_samples = 0 info_weights = {} for dataset_name, dataset in train_datasets.items(): assert ( dataset_name in sampling ), f"Dataset {dataset_name} not found in {sampling.keys()}" if isinstance(dataset, ImageDataset): # sum of all samples has weight as in sampling s.t. sampling dataset in general is as in sampling # inside is uniform weight = sampling[dataset_name] / len(dataset) weights.append(torch.full((len(dataset),), weight).double()) num_samples += len(dataset) elif isinstance(dataset, SequenceDataset): # local weight is num_samples, but global must be as in sampling # hence is num_samples / (sum num_samples / sampling[dataset_name]) # s.t. 
sampling anything from the dataset is # sum(num_samples / (sum num_samples / sampling[dataset_name])) # -> sampling[dataset_name] numerator = [int(data["num_samples"]) for data in dataset.dataset] weights.append( sampling[dataset_name] * torch.tensor(numerator).double() / sum(numerator) ) num_samples += sum(numerator) else: weight = sampling[dataset_name] / len(dataset) weights.append(torch.full((len(dataset),), weight).double()) info_weights[dataset_name] = weights[-1][-1] return torch.cat(weights), num_samples ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/utils_decode.py ================================================ import io import cv2 import numpy as np import torch import torchvision import torchvision.transforms.v2.functional as TF from PIL import Image from unidepth.utils.camera import (EUCM, MEI, BatchCamera, Fisheye624, Pinhole, Spherical) def decode_depth(results, h5file, value, idx, depth_scale, name="depth", **kwargs): file = h5file.get_node("/" + value).read() decoded_data = Image.open(io.BytesIO(file)) decoded_data = TF.pil_to_tensor(decoded_data).squeeze() if decoded_data.ndim == 3: # 24 channel loading decoded_channels = [ (decoded_data[0] & 0xFF).to(torch.int32), (decoded_data[1] & 0xFF).to(torch.int32), (decoded_data[2] & 0xFF).to(torch.int32), ] # Reshape and extract the original depth map decoded_data = ( decoded_channels[0] | (decoded_channels[1] << 8) | (decoded_channels[2] << 16) ) decoded_data = decoded_data.to(torch.float32) results.get("gt_fields", set()).add(name) results[(idx, 0)].get("gt_fields", set()).add(name) results[f"{name}_ori_shape"] = decoded_data.shape results[(idx, 0)][name] = ( decoded_data.view(1, 1, *decoded_data.shape).contiguous() / depth_scale ) return results def decode_numpy(results, h5file, value, idx, name="points", **kwargs): file = h5file.get_node("/" + value).read() decoded_data = np.load(io.BytesIO(file), allow_pickle=False) decoded_data = torch.from_numpy(decoded_data).to(torch.float32) if decoded_data.ndim > 2: decoded_data = decoded_data.permute(2, 0, 1) results.get("gt_fields", set()).add(name) results[(idx, 0)].get("gt_fields", set()).add(name) results[(idx, 0)][name] = decoded_data.unsqueeze(0) return results def decode_tensor(results, value, idx, name, **kwargs): results.get("camera_fields", set()).add(name) results[(idx, 0)].get("camera_fields", set()).add(name) results[(idx, 0)][name] = torch.tensor(value).unsqueeze(0) return results def decode_camera(results, value, idx, name, sample, j, **kwargs): results.get("camera_fields", set()).add(name) results[(idx, 0)].get("camera_fields", set()).add(name) camera = eval(sample["camera_model"][j])(params=torch.tensor(value).unsqueeze(0)) results[(idx, 0)][name] = BatchCamera.from_camera(camera) return results def decode_K(results, value, idx, name, **kwargs): results.get("camera_fields", set()).add(name) results[(idx, 0)].get("camera_fields", set()).add(name) camera = Pinhole(K=torch.tensor(value).unsqueeze(0)) results[(idx, 0)][name] = BatchCamera.from_camera(camera) return results def decode_mask(results, h5file, value, idx, name, **kwargs): file = h5file.get_node("/" + value).read() mask = torchvision.io.decode_image(torch.from_numpy(file)).bool().squeeze() results.get("mask_fields", set()).add(name) results[(idx, 0)].get("mask_fields", set()).add(name) results[f"{name}_ori_shape"] = mask.shape[-2:] results[(idx, 0)][name] = mask.view(1, 1, *mask.shape).contiguous() return results def decode_rgb(results, 
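# Shared convention across the decode_* helpers: h5file is an open PyTables file, `value` is the node path of the compressed blob inside the chunk, and `idx` selects the (idx, 0) frame slot filled in `results`.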
h5file, value, idx, name="image", **kwargs): file = h5file.get_node("/" + value).read() image = ( torchvision.io.decode_image(torch.from_numpy(file)).to(torch.uint8).squeeze() ) results.get("image_fields", set()).add(name) results[(idx, 0)].get("image_fields", set()).add(name) results[f"{name}_ori_shape"] = image.shape[-2:] if image.ndim == 2: image = image.unsqueeze(0).repeat(3, 1, 1) results[(idx, 0)][name] = image.unsqueeze(0) return results def decode_flow(results, h5file, value, idx, name, **kwargs): file = h5file.get_node("/" + value).read() image = ( torchvision.io.decode_image(torch.from_numpy(file)).to(torch.uint8).squeeze() ) decoded_channels = [ (image[0] & 0xFF).to(torch.int16), (image[1] & 0xFF).to(torch.int16), (image[2] & 0xFF).to(torch.int16), ] # Reshape and extract the original 2-channel flow map flow = torch.zeros((2, image.shape[1], image.shape[2]), dtype=torch.int16) flow[0] = (decoded_channels[0] | decoded_channels[1] << 8) & 0xFFF flow[1] = (decoded_channels[1] >> 4 | decoded_channels[2] << 4) & 0xFFF results.get("gt_fields", set()).add(name) results[(idx, 0)].get("gt_fields", set()).add(name) results[f"{name}_ori_shape"] = flow.shape[-2:] flow = flow.unsqueeze(0).contiguous().float() results[(idx, 0)][name] = (0.5 + flow) / 4095.0 * 2 - 1 return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/vkitti.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class VKITTI(SequenceDataset): min_depth = 0.01 max_depth = 255.0 depth_scale = 256.0 test_split = "training.txt" train_split = "training.txt" sequences_file = "sequences.json" hdf5_paths = ["VKITTI2.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["synthetic"] = [True] * self.num_frames * self.num_copies results["quality"] = [0] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/void.py ================================================ import json import os import h5py import numpy as np import torch from unidepth.datasets.image_dataset import ImageDataset from unidepth.datasets.utils import DatasetFromList class VOID(ImageDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 256.0 test_split = "void_val.txt" train_split = "void_train.txt" intrisics_file = "void_intrinsics.json" hdf5_paths = ["void.hdf5"] def __init__( self, image_shape, split_file, test_mode, crop=None, benchmark=False, augmentations_db={}, normalize=True, resize_method="hard", mini=1.0, **kwargs, ): super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, 
normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, **kwargs, ) self.test_mode = test_mode self.crop = crop self.load_dataset() def load_dataset(self): h5file = h5py.File( os.path.join(self.data_root, self.hdf5_paths[0]), "r", libver="latest", swmr=True, ) txt_file = np.array(h5file[self.split_file]) txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") intrinsics = json.loads(intrinsics) h5file.close() dataset = [] for line in txt_string.split("\n"): image_filename, depth_filename = line.strip().split(" ") intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] sample = [image_filename, depth_filename, intrinsics_val] dataset.append(sample) if not self.test_mode: dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) self.dataset = DatasetFromList(dataset) self.log_load_dataset() def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_copies results["quality"] = [2] * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/waymo.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class Waymo(SequenceDataset): min_depth = 0.05 max_depth = 70.0 depth_scale = 256.0 test_split = "validation.txt" train_split = "training.txt" sequences_file = "sequences.json" hdf5_paths = [f"Waymo_viz.hdf5"] def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [False] * self.num_frames * self.num_copies results["synthetic"] = [False] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/wildrgbd.py ================================================ from typing import Any from unidepth.datasets.sequence_dataset import SequenceDataset class WildRGBD(SequenceDataset): min_depth = 0.01 max_depth = 10.0 depth_scale = 1000.0 test_split = "train.txt" train_split = "train.txt" sequences_file = "sequences.json" hdf5_paths = ["WildRGBD.hdf5"] default_fps = 30 def __init__( self, image_shape: tuple[int, int], split_file: str, test_mode: bool, normalize: bool, augmentations_db: dict[str, Any], resize_method: str, mini: float = 1.0, num_frames: int = 1, benchmark: bool = False, decode_fields: list[str] = ["image", "depth"], inplace_fields: list[str] = ["K", "cam2w"], **kwargs, ) -> None: super().__init__( image_shape=image_shape, split_file=split_file, test_mode=test_mode, benchmark=benchmark, normalize=normalize, augmentations_db=augmentations_db, resize_method=resize_method, mini=mini, 
num_frames=num_frames, decode_fields=decode_fields, inplace_fields=inplace_fields, **kwargs, ) def pre_pipeline(self, results): results = super().pre_pipeline(results) results["dense"] = [True] * self.num_frames * self.num_copies results["quality"] = [1] * self.num_frames * self.num_copies return results ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/__init__.py ================================================ from .activation import GEGLU, SwiGLU from .attention import AttentionBlock, AttentionDecoderBlock, AttentionLayer from .convnext import CvnxtBlock from .mlp import MLP from .nystrom_attention import NystromBlock from .positional_encoding import PositionEmbeddingSine from .upsample import (ConvUpsample, ConvUpsampleShuffle, ConvUpsampleShuffleResidual, ResUpsampleBil) __all__ = [ "SwiGLU", "GEGLU", "CvnxtBlock", "AttentionBlock", "NystromBlock", "PositionEmbeddingSine", "ConvUpsample", "MLP", "ConvUpsampleShuffle", "AttentionDecoderBlock", "ConvUpsampleShuffleResidual", ] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/activation.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class SwiGLU(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: x, gates = x.chunk(2, dim=-1) return x * F.silu(gates) class GEGLU(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: x, gates = x.chunk(2, dim=-1) return x * F.gelu(gates) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/attention.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from .layer_scale import LayerScale from .mlp import MLP class SimpleAttention(nn.Module): def __init__( self, dim: int, num_heads: int = 4, dropout: float = 0.0, cosine: bool = False, context_dim: int | None = None, ): super().__init__() self.dropout = dropout self.num_heads = num_heads self.hidden_dim = dim context_dim = context_dim or dim self.kv = nn.Linear(context_dim, dim * 2, bias=False) self.q = nn.Linear(dim, dim, bias=False) self.norm_attnx = nn.LayerNorm(dim) self.norm_attnctx = nn.LayerNorm(context_dim) self.cosine = cosine self.out = nn.Linear(dim, dim) def forward( self, x: torch.Tensor, attn_bias: torch.Tensor | None = None, context: torch.Tensor | None = None, pos_embed: torch.Tensor | None = None, pos_embed_context: torch.Tensor | None = None, rope: nn.Module | None = None, ) -> torch.Tensor: context = x if context is None else context x = self.norm_attnx(x) context = self.norm_attnctx(context) k, v = rearrange( self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 ).unbind(dim=-1) q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads) if rope is not None: q = rope(q) k = rope(k) else: if pos_embed is not None: pos_embed = rearrange( pos_embed, "b n (h d) -> b h n d", h=self.num_heads ) q = q + pos_embed if pos_embed_context is not None: pos_embed_context = rearrange( pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads ) k = k + pos_embed_context if self.cosine: q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim x = F.scaled_dot_product_attention( q, 
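# fused attention: softmax(q @ k^T / sqrt(d_head) + attn_bias) @ v; dropout_p is applied to the attention weights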
k, v, dropout_p=self.dropout, attn_mask=attn_bias ) x = rearrange(x, "b h n d -> b n (h d)") x = self.out(x) return x class AttentionBlock(nn.Module): def __init__( self, dim: int, num_heads: int = 4, expansion: int = 4, dropout: float = 0.0, cosine: bool = False, gated: bool = False, layer_scale: float = 1.0, context_dim: int | None = None, use_bias: bool = True, ): super().__init__() self.dropout = dropout self.num_heads = num_heads self.hidden_dim = dim context_dim = context_dim or dim self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated) self.kv = nn.Linear(context_dim, dim * 2, bias=use_bias) self.q = nn.Linear(dim, dim, bias=use_bias) self.norm_attnx = nn.LayerNorm(dim) self.norm_attnctx = nn.LayerNorm(context_dim) self.cosine = cosine self.out = nn.Linear(dim, dim, bias=use_bias) self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() def attn( self, x: torch.Tensor, attn_bias: torch.Tensor | None = None, context: torch.Tensor | None = None, pos_embed: torch.Tensor | None = None, pos_embed_context: torch.Tensor | None = None, ) -> torch.Tensor: x = self.norm_attnx(x) context = self.norm_attnctx(context) k, v = rearrange( self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 ).unbind(dim=-1) q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads) if pos_embed is not None: pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads) q = q + pos_embed if pos_embed_context is not None: pos_embed_context = rearrange( pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads ) k = k + pos_embed_context if self.cosine: q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim x = F.scaled_dot_product_attention( q, k, v, dropout_p=self.dropout, attn_mask=attn_bias ) x = rearrange(x, "b h n d -> b n (h d)") x = self.out(x) return x def forward( self, x: torch.Tensor, attn_bias: torch.Tensor | None = None, context: torch.Tensor | None = None, pos_embed: torch.Tensor | None = None, pos_embed_context: torch.Tensor | None = None, ) -> torch.Tensor: context = x if context is None else context x = ( self.ls1( self.attn( x, attn_bias=attn_bias, context=context, pos_embed=pos_embed, pos_embed_context=pos_embed_context, ) ) + x ) x = self.ls2(self.mlp(x)) + x return x class AttentionLayer(nn.Module): def __init__( self, num_blocks: int, dim: int, num_heads: int = 4, expansion: int = 4, dropout: float = 0.0, cosine: bool = False, gated: bool = False, layer_scale: float = 1.0, context_dim: int | None = None, use_bias: bool = True, ): super().__init__() self.layers = nn.ModuleList( [ AttentionBlock( dim=dim, num_heads=num_heads, expansion=expansion, dropout=dropout, cosine=cosine, gated=gated, layer_scale=layer_scale, context_dim=context_dim, use_bias=use_bias, ) for _ in range(num_blocks) ] ) def forward( self, x: torch.Tensor, context: torch.Tensor | None = None, pos_embed: torch.Tensor | None = None, pos_embed_context: torch.Tensor | None = None, attn_bias: torch.Tensor | None = None, ) -> torch.Tensor: for layer in self.layers: x = layer( x, context=context, pos_embed=pos_embed, pos_embed_context=pos_embed_context, attn_bias=attn_bias, ) return x class AttentionDecoderBlock(nn.Module): def __init__( self, dim: int, num_heads: int = 4, expansion: int = 4, dropout: float = 0.0, cosine: bool = False, gated: bool = False, layer_scale: float = 1.0, context_dim: int | None = None, single_head_ca: bool = True, ): super().__init__() 
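# Decoder block layout: cross-attention to an external context (single-head when single_head_ca is set), then self-attention, then an optionally gated MLP; each branch is residual and scaled by its own LayerScale (ls1 / ls2 / ls3).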
self.dropout = dropout self.num_heads = num_heads self.hidden_dim = dim self.single_head_ca = single_head_ca context_dim = context_dim or dim self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated) self.kv_ca = nn.Linear(context_dim, dim * 2) self.q_ca = nn.Linear(dim, dim) self.kv_sa = nn.Linear(dim, dim * 2) self.q_sa = nn.Linear(dim, dim) self.norm_x_sa = nn.LayerNorm(dim) self.norm_x_ca = nn.LayerNorm(dim) self.norm_ctx_ca = nn.LayerNorm(context_dim) self.cosine = cosine self.out_ca = nn.Linear(dim, dim) self.out_sa = nn.Linear(dim, dim) self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() self.ls3 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() def cross_attn( self, x: torch.Tensor, attn_bias: torch.Tensor | None = None, context: torch.Tensor | None = None, pos_embed: torch.Tensor | None = None, pos_embed_context: torch.Tensor | None = None, rope: nn.Module | None = None, ) -> torch.Tensor: num_heads = 1 if self.single_head_ca else self.num_heads x = self.norm_x_ca(x) context = self.norm_ctx_ca(context) k, v = rearrange( self.kv_ca(context), "b n (kv h d) -> b h n d kv", h=num_heads, kv=2 ).unbind(dim=-1) q = rearrange(self.q_ca(x), "b n (h d) -> b h n d", h=num_heads) if rope is not None: q = rope(q) k = rope(k) else: if pos_embed is not None: pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=num_heads) q = q + pos_embed if pos_embed_context is not None: pos_embed_context = rearrange( pos_embed_context, "b n (h d) -> b h n d", h=num_heads ) k = k + pos_embed_context if self.cosine: q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim x = F.scaled_dot_product_attention( q, k, v, dropout_p=self.dropout, attn_mask=attn_bias ) x = rearrange(x, "b h n d -> b n (h d)") x = self.out_ca(x) return x def self_attn( self, x: torch.Tensor, attn_bias: torch.Tensor | None = None, pos_embed: torch.Tensor | None = None, rope: nn.Module | None = None, ) -> torch.Tensor: x = self.norm_x_sa(x) k, v = rearrange( self.kv_sa(x), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 ).unbind(dim=-1) q = rearrange(self.q_sa(x), "b n (h d) -> b h n d", h=self.num_heads) if rope is not None: q = rope(q) k = rope(k) elif pos_embed is not None: pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads) q = q + pos_embed if self.cosine: q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim x = F.scaled_dot_product_attention( q, k, v, dropout_p=self.dropout, attn_mask=attn_bias ) x = rearrange(x, "b h n d -> b n (h d)") x = self.out_sa(x) return x def forward( self, x: torch.Tensor, attn_bias: torch.Tensor | None = None, context: torch.Tensor | None = None, pos_embed: torch.Tensor | None = None, pos_embed_context: torch.Tensor | None = None, rope: nn.Module | None = None, ) -> torch.Tensor: context = x if context is None else context x = ( self.ls1( self.cross_attn( x, rope=rope, attn_bias=attn_bias, context=context, pos_embed=pos_embed, pos_embed_context=pos_embed_context, ) ) + x ) x = ( self.ls2( self.self_attn(x, rope=rope, attn_bias=attn_bias, pos_embed=pos_embed) ) + x ) x = self.ls3(self.mlp(x)) + x return x ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/convnext.py ================================================ import torch import torch.nn as nn class CvnxtBlock(nn.Module): def __init__( self, dim, kernel_size=7, layer_scale=1.0, 
expansion=4, dilation=1, padding_mode: str = "zeros", ): super().__init__() self.dwconv = nn.Conv2d( dim, dim, kernel_size=kernel_size, padding=dilation * (kernel_size - 1) // 2, groups=dim, dilation=dilation, padding_mode=padding_mode, ) # depthwise conv self.norm = nn.LayerNorm(dim) self.pwconv1 = nn.Linear(dim, expansion * dim) self.act = nn.GELU() self.pwconv2 = nn.Linear(expansion * dim, dim) self.gamma = ( nn.Parameter(layer_scale * torch.ones((dim))) if layer_scale > 0.0 else 1.0 ) def forward(self, x): input = x x = self.dwconv(x) x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) x = self.gamma * x x = input + x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) return x ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/drop_path.py ================================================ import torch import torch.nn as nn def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False): if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * ( x.ndim - 1 ) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0: random_tensor.div_(keep_prob) output = x * random_tensor return output class DropPath(nn.Module): def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/layer_scale.py ================================================ import torch import torch.nn as nn class LayerScale(nn.Module): def __init__( self, dim: int, init_values: float | torch.Tensor = 1e-5, inplace: bool = False, ) -> None: super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/mlp.py ================================================ import torch import torch.nn as nn from unidepth.utils.misc import default from .activation import SwiGLU class MLP(nn.Module): def __init__( self, input_dim: int, expansion: int = 4, dropout: float = 0.0, gated: bool = False, output_dim: int | None = None, ): super().__init__() if gated: expansion = int(expansion * 2 / 3) hidden_dim = int(input_dim * expansion) output_dim = default(output_dim, input_dim) self.norm = nn.LayerNorm(input_dim) self.proj1 = nn.Linear(input_dim, hidden_dim) self.proj2 = nn.Linear(hidden_dim, output_dim) self.act = nn.GELU() if not gated else SwiGLU() self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.norm(x) x = self.proj1(x) x = self.act(x) x = self.proj2(x) x = self.dropout(x) return x ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/nystrom.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. 
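# Nystrom attention (Nystromformer) never materialises the full n x n attention matrix:
# with m << n landmark queries/keys Q_l, K_l it computes
#   softmax(Q K_l^T) @ pinv(softmax(Q_l K_l^T)) @ (softmax(Q_l K^T) V)
# so time and memory scale with n * m instead of n^2.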
import logging from dataclasses import dataclass from typing import Optional import torch import torch.nn as nn from xformers.components.attention import Attention, AttentionConfig, register_attention from xformers.components.attention.core import ( scaled_dot_product_attention, scaled_query_key_softmax, ) from xformers.components.attention.utils import ( bool_mask_to_additive, iterative_pinv, reshape_key_padding_mask, ) logger = logging.getLogger("xformers") @dataclass class NystromSelfAttentionConfig(AttentionConfig): """ num_heads Number of heads. num_landmarks Number of landmarks to use for softmax approximation. 64 often sufficient for a good approximation according to https://arxiv.org/pdf/2102.03902.pdf. causal Apply a causal mask, in that the attention cannot be applied to the future. use_razavi_pinverse If true, use iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose inverse, otherwise use standard torch inverse. pinverse_original_init True if using original initialization when calculating Moore-Penrose pseudo inverse using method from (Razavi et al. 2014). False if using exact coefficient computation (leads to faster convergence). inv_iterations Number of iterations for calculating the Moore-Penrose pseudo inverse. v_skip_connection A module that will take V as input and will be added as a skip connection to the softmax approximation. A skip connection is added in the paper to help with training. conv_kernel_size Kernel size for convolution optionally added to help in training. If v_skip_connection is not specified, this will be used to define the default depth wise convolution used as a skip connection. If both conv_kernel_size and v_skip_connection are None, no skip connection will be added. landmark_pooling Which module to use when computing landmarks. Default is AdaptiveAvgPool2d. """ num_heads: int num_landmarks: Optional[int] landmark_pooling: Optional[nn.Module] causal: Optional[bool] pinverse_original_init: Optional[bool] inv_iterations: Optional[int] v_skip_connection: Optional[nn.Module] conv_kernel_size: Optional[int] use_razavi_pinverse: Optional[bool] class AvgPool(nn.Module): def __init__(self, n: int): super().__init__() self.n = n def forward(self, x: torch.Tensor): # Average independently for every segment in the sequence dimension seq_len = x.shape[1] head_dim = x.shape[2] segments = seq_len // self.n assert segments > 0, "num_landmarks should be smaller than the sequence length" # Dimensions are a match if seq_len % self.n == 0: return x.reshape( -1, self.n, segments, head_dim, ).mean(dim=-2) # Handle the last segment boundary being off n_round = self.n - seq_len % self.n x_avg_round = ( x[:, : n_round * segments, :] .reshape(-1, n_round, segments, head_dim) .mean(dim=-2) ) x_avg_off = ( x[:, n_round * segments :, :] .reshape(-1, self.n - n_round, segments + 1, head_dim) .mean(dim=-2) ) return torch.cat((x_avg_round, x_avg_off), dim=-2) @register_attention("nystrom", NystromSelfAttentionConfig) class NystromAttention(Attention): # TODO: update defaults for use_razavi_pinverse and inv_iterations def __init__( self, dropout: float, num_heads: int, num_landmarks: int = 64, landmark_pooling: Optional[nn.Module] = None, causal: bool = False, use_razavi_pinverse: bool = True, pinverse_original_init: bool = False, inv_iterations: int = 6, # recommended default in paper was 6. v_skip_connection: Optional[nn.Module] = None, conv_kernel_size: Optional[int] = None, *args, **kwargs, ): """ Nystrom attention mechanism, from Nystromformer_. 
:: "A Nystrom-based Algorithm for Approximating Self-Attention." Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V. (2021) Reference codebase: https://github.com/mlpen/Nystromformer .. _Nystromformer: https://arxiv.org/pdf/2102.03902.pdf """ super().__init__() # merged key padding mask and attention mask is not accepted self.requires_separate_masks = True self.num_landmarks = num_landmarks # TODO: should be able to not have to pass in num_heads self.num_heads = num_heads self.use_razavi_pinverse = use_razavi_pinverse self.pinverse_original_init = pinverse_original_init self.inv_iterations = inv_iterations self.attn_drop = nn.Dropout(dropout) self.skip_connection = v_skip_connection self.causal = causal if self.skip_connection is None and conv_kernel_size is not None: self.skip_connection = nn.Conv2d( in_channels=self.num_heads, out_channels=self.num_heads, kernel_size=(conv_kernel_size, 1), padding=(conv_kernel_size // 2, 0), bias=False, groups=self.num_heads, ) if landmark_pooling is not None: self.landmark_pooling = landmark_pooling else: self.landmark_pooling = AvgPool(n=self.num_landmarks) # Optional lower triangular masks for causal attention self.causal_mask_1: Optional[torch.Tensor] = None self.causal_mask_2: Optional[torch.Tensor] = None self.causal_mask_3: Optional[torch.Tensor] = None # This attention does not support attention masks self.supports_attention_mask = False self.supports_key_padding_mask = True def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None, *args, **kwargs, ): r""" key_padding_mask Only a key padding mask is accepted here. The size must be (batch size, sequence length) or (batch size * num_heads, 1, sequence length). If dimensions are not correct, the mask will be ignored. An additive mask is expected, meaning float values using "-inf" to mask values """ batched_dim = k.size(0) seq_len = k.size(-2) tt = {"dtype": q.dtype, "device": q.device} if key_padding_mask is not None: if key_padding_mask.dtype == torch.bool: logger.warning( "Bool mask found, but an additive mask is expected. Converting but this is slow" ) key_padding_mask = bool_mask_to_additive(key_padding_mask) if key_padding_mask.ndim == 2: key_padding_mask = reshape_key_padding_mask( key_padding_mask, batched_dim ) zeros = torch.zeros_like(key_padding_mask) ones = torch.ones_like(key_padding_mask) is_masked = torch.isinf(-key_padding_mask) # _mask takes 1 if the token is not padded, otherwise 0. _mask = torch.where(is_masked, zeros, ones) _mask = _mask.transpose(2, 1) assert _mask.shape == (batched_dim, q.shape[1], 1) # Mask q and k before pooling # https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py#L31 q = q * _mask k = k * _mask assert key_padding_mask.size() == (batched_dim, 1, seq_len), ( f"key_padding_mask has invalid dimensions {key_padding_mask.size()}." f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})." 
) if self.num_landmarks >= seq_len: mask: Optional[torch.Tensor] = None if self.causal: mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt) if key_padding_mask is not None: mask = key_padding_mask if mask is None else mask + key_padding_mask x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask) else: q_landmarks = self.landmark_pooling(q) k_landmarks = self.landmark_pooling(k) if self.causal and ( self.causal_mask_1 is None or (batched_dim, seq_len, self.num_landmarks) != self.causal_mask_1.size() ): self.causal_mask_1 = self._triu_mask( batched_dim, seq_len, self.num_landmarks, **tt ) self.causal_mask_2 = self._triu_mask( batched_dim, self.num_landmarks, self.num_landmarks, **tt ) self.causal_mask_3 = self._triu_mask( batched_dim, self.num_landmarks, seq_len, **tt ) mask_3: Optional[torch.Tensor] = self.causal_mask_3 if key_padding_mask is not None: mask_3 = ( key_padding_mask if mask_3 is None else mask_3 + key_padding_mask ) kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=None) kernel_2 = scaled_query_key_softmax( q=q_landmarks, k=k_landmarks, att_mask=None ) kernel_3 = scaled_dot_product_attention( q=q_landmarks, k=k, v=v, att_mask=mask_3 ) kernel_2_inv = ( iterative_pinv( kernel_2, self.inv_iterations, self.pinverse_original_init ) if self.use_razavi_pinverse else torch.linalg.pinv(kernel_2) ) x = torch.matmul( torch.matmul( kernel_1, kernel_2_inv, ), kernel_3, ) if self.skip_connection: # Assumption here is that v is 3D. v_conv = self.skip_connection( v.reshape(-1, self.num_heads, v.size(-2), v.size(-1)) ) x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1)) x = self.attn_drop(x) return x def _triu_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor: device = kwargs["device"] dtype = kwargs["dtype"] return torch.triu( torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf"), diagonal=1, ).expand( dim_1, -1, -1 ) # micro optim, save memory on the batch dimension ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/nystrom_attention.py ================================================ from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from .nystrom import NystromAttention from .attention import AttentionBlock class NystromBlock(AttentionBlock): def __init__( self, dim: int, num_heads: int = 4, expansion: int = 4, dropout: float = 0.0, cosine: bool = False, gated: bool = False, layer_scale: float = 1.0, context_dim: int | None = None, ): super().__init__( dim=dim, num_heads=num_heads, expansion=expansion, dropout=dropout, cosine=cosine, gated=gated, layer_scale=layer_scale, context_dim=context_dim, ) self.attention_fn = NystromAttention( num_landmarks=128, num_heads=num_heads, dropout=dropout ) def attn( self, x: torch.Tensor, attn_bias: torch.Tensor | None = None, context: torch.Tensor | None = None, pos_embed: torch.Tensor | None = None, pos_embed_context: torch.Tensor | None = None, rope: nn.Module | None = None, ) -> torch.Tensor: x = self.norm_attnx(x) context = self.norm_attnctx(context) k, v = rearrange( self.kv(context), "b n (kv h d) -> b n h d kv", h=self.num_heads, kv=2 ).unbind(dim=-1) q = rearrange(self.q(x), "b n (h d) -> b n h d", h=self.num_heads) if rope is not None: q = rope(q) k = rope(k) else: if pos_embed is not None: pos_embed = rearrange( pos_embed, "b n (h d) -> b n h d", h=self.num_heads ) q = q + pos_embed if pos_embed_context is not None: 
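# keys get their own positional offset here; queries were offset just above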
pos_embed_context = rearrange( pos_embed_context, "b n (h d) -> b n h d", h=self.num_heads ) k = k + pos_embed_context if self.cosine: q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim x = self.attention_fn(q, k, v, key_padding_mask=attn_bias) x = rearrange(x, "b n h d -> b n (h d)") x = self.out(x) return x ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/positional_encoding.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ from math import pi from typing import Optional import torch import torch.nn as nn from einops import rearrange, repeat class PositionEmbeddingSine(nn.Module): def __init__( self, num_pos_feats=64, temperature=10000, normalize=False, scale=None ): super().__init__() self.num_pos_feats = num_pos_feats self.temperature = temperature self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") if scale is None: scale = 2 * pi self.scale = scale def forward( self, x: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: if mask is None: mask = torch.zeros( (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool ) not_mask = ~mask y_embed = not_mask.cumsum(1, dtype=torch.float32) x_embed = not_mask.cumsum(2, dtype=torch.float32) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = self.temperature ** ( 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats ) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack( (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos_y = torch.stack( (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) return pos def __repr__(self, _repr_indent=4): head = "Positional encoding " + self.__class__.__name__ body = [ "num_pos_feats: {}".format(self.num_pos_feats), "temperature: {}".format(self.temperature), "normalize: {}".format(self.normalize), "scale: {}".format(self.scale), ] # _repr_indent = 4 lines = [head] + [" " * _repr_indent + line for line in body] return "\n".join(lines) class LearnedSinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() assert (dim % 2) == 0 half_dim = dim // 2 self.weights = nn.Parameter(torch.randn(half_dim)) def forward(self, x): x = rearrange(x, "b -> b 1") freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) fouriered = torch.cat((x, fouriered), dim=-1) return fouriered def generate_fourier_features(x, max_freq=64, num_bands=16): x = x.unsqueeze(-1) device, dtype, orig_x = x.device, x.dtype, x scales = torch.linspace( -max_freq / 2, max_freq / 2, num_bands, device=device, dtype=dtype ) scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] x = x * scales * pi x = torch.cat([x.sin(), x.cos()], dim=-1) x = torch.cat((x, orig_x), dim=-1) return x.flatten(-2) def broadcat(tensors, dim=-1): num_tensors = len(tensors) shape_lens = set(list(map(lambda t: len(t.shape), tensors))) assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" shape_len = 
list(shape_lens)[0] dim = (dim + shape_len) if dim < 0 else dim dims = list(zip(*map(lambda t: list(t.shape), tensors))) expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] assert all( [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] ), "invalid dimensions for broadcastable concatentation" max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) expanded_dims.insert(dim, (dim, dims[dim])) expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) return torch.cat(tensors, dim=dim) def rotate_half(x): x = rearrange(x, "... (d r) -> ... d r", r=2) x1, x2 = x.unbind(dim=-1) x = torch.stack((-x2, x1), dim=-1) return rearrange(x, "... d r -> ... (d r)") class VisionRotaryEmbedding(nn.Module): def __init__( self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, ): super().__init__() if custom_freqs: freqs = custom_freqs elif freqs_for == "lang": freqs = 1.0 / ( theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) ) elif freqs_for == "pixel": freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi elif freqs_for == "constant": freqs = torch.ones(num_freqs).float() else: raise ValueError(f"unknown modality {freqs_for}") if ft_seq_len is None: ft_seq_len = pt_seq_len t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len freqs_h = torch.einsum("..., f -> ... f", t, freqs) freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) freqs_w = torch.einsum("..., f -> ... f", t, freqs) freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) self.register_buffer("freqs_cos", freqs.cos()) self.register_buffer("freqs_sin", freqs.sin()) print("======== shape of rope freq", self.freqs_cos.shape, "========") def forward(self, t, start_index=0): rot_dim = self.freqs_cos.shape[-1] end_index = start_index + rot_dim assert ( rot_dim <= t.shape[-1] ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" t_left, t, t_right = ( t[..., :start_index], t[..., start_index:end_index], t[..., end_index:], ) t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) return torch.cat((t_left, t, t_right), dim=-1) class VisionRotaryEmbeddingFast(nn.Module): def __init__( self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, ): super().__init__() if custom_freqs: freqs = custom_freqs elif freqs_for == "lang": freqs = 1.0 / ( theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) ) elif freqs_for == "pixel": freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi elif freqs_for == "constant": freqs = torch.ones(num_freqs).float() else: raise ValueError(f"unknown modality {freqs_for}") if ft_seq_len is None: ft_seq_len = pt_seq_len t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len freqs = torch.einsum("..., f -> ... f", t, freqs) freqs = repeat(freqs, "... n -> ... 
(n r)", r=2) freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) self.register_buffer("freqs_cos", freqs_cos) self.register_buffer("freqs_sin", freqs_sin) def forward(self, t): return t * self.freqs_cos + rotate_half(t) * self.freqs_sin ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/upsample.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import torch import torch.nn as nn from einops import rearrange from .convnext import CvnxtBlock class ConvUpsample(nn.Module): def __init__( self, hidden_dim, num_layers: int = 2, expansion: int = 4, layer_scale: float = 1.0, kernel_size: int = 7, **kwargs, ): super().__init__() self.convs = nn.ModuleList([]) for _ in range(num_layers): self.convs.append( CvnxtBlock( hidden_dim, kernel_size=kernel_size, expansion=expansion, layer_scale=layer_scale, ) ) self.up = nn.Sequential( nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0), nn.UpsamplingBilinear2d(scale_factor=2), nn.Conv2d(hidden_dim // 2, hidden_dim // 2, kernel_size=3, padding=1), ) def forward(self, x: torch.Tensor): for conv in self.convs: x = conv(x) x = self.up(x) x = rearrange(x, "b c h w -> b (h w) c") return x class ConvUpsampleShuffle(nn.Module): def __init__( self, hidden_dim, num_layers: int = 2, expansion: int = 4, layer_scale: float = 1.0, kernel_size: int = 7, **kwargs, ): super().__init__() self.convs = nn.ModuleList([]) for _ in range(num_layers): self.convs.append( CvnxtBlock( hidden_dim, kernel_size=kernel_size, expansion=expansion, layer_scale=layer_scale, ) ) self.up = nn.Sequential( nn.PixelShuffle(2), nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1), ) def forward(self, x: torch.Tensor): for conv in self.convs: x = conv(x) x = self.up(x) x = rearrange(x, "b c h w -> b (h w) c") return x class ConvUpsampleShuffleResidual(nn.Module): def __init__( self, hidden_dim, num_layers: int = 2, expansion: int = 4, layer_scale: float = 1.0, kernel_size: int = 7, padding_mode: str = "zeros", **kwargs, ): super().__init__() self.convs = nn.ModuleList([]) for _ in range(num_layers): self.convs.append( CvnxtBlock( hidden_dim, kernel_size=kernel_size, expansion=expansion, layer_scale=layer_scale, padding_mode=padding_mode, ) ) self.up = nn.Sequential( nn.PixelShuffle(2), nn.Conv2d( hidden_dim // 4, hidden_dim // 4, kernel_size=7, padding=3, padding_mode=padding_mode, groups=hidden_dim // 4, ), nn.ReLU(), nn.Conv2d( hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1, padding_mode=padding_mode, ), ) self.residual = nn.Sequential( nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0), nn.UpsamplingBilinear2d(scale_factor=2), ) def forward(self, x: torch.Tensor): for conv in self.convs: x = conv(x) x = self.up(x) + self.residual(x) x = rearrange(x, "b c h w -> b (h w) c") return x class ResidualConvUnit(nn.Module): def __init__( self, dim, kernel_size: int = 3, padding_mode: str = "zeros", dilation: int = 1, layer_scale: float = 1.0, use_norm: bool = False, ): super().__init__() self.conv1 = nn.Conv2d( dim, dim, kernel_size=kernel_size, padding=dilation * (kernel_size - 1) // 2, dilation=dilation, padding_mode=padding_mode, ) self.conv2 = nn.Conv2d( dim, dim, kernel_size=kernel_size, padding=dilation * (kernel_size - 1) // 2, 
dilation=dilation, padding_mode=padding_mode, ) self.activation = nn.LeakyReLU() self.gamma = ( nn.Parameter(layer_scale * torch.ones(1, dim, 1, 1)) if layer_scale > 0.0 else 1.0 ) self.norm1 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity() self.norm2 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity() def forward(self, x): out = self.activation(x) out = self.conv1(out) out = self.norm1(out) out = self.activation(out) out = self.conv2(out) out = self.norm2(out) return self.gamma * out + x class ResUpsampleBil(nn.Module): def __init__( self, hidden_dim, output_dim: int = None, num_layers: int = 2, kernel_size: int = 3, layer_scale: float = 1.0, padding_mode: str = "zeros", use_norm: bool = False, **kwargs, ): super().__init__() output_dim = output_dim if output_dim is not None else hidden_dim // 2 self.convs = nn.ModuleList([]) for _ in range(num_layers): self.convs.append( ResidualConvUnit( hidden_dim, kernel_size=kernel_size, layer_scale=layer_scale, padding_mode=padding_mode, use_norm=use_norm, ) ) self.up = nn.Sequential( nn.Conv2d( hidden_dim, output_dim, kernel_size=1, padding=0, padding_mode=padding_mode, ), nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), ) def forward(self, x: torch.Tensor): for conv in self.convs: x = conv(x) x = self.up(x) return x ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/__init__.py ================================================ from .unidepthv1 import UniDepthV1 from .unidepthv2 import UniDepthV2, UniDepthV2old __all__ = [ "UniDepthV1", "UniDepthV2old", "UniDepthV2", ] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/__init__.py ================================================ from .convnext import ConvNeXt from .convnext2 import ConvNeXtV2 from .dinov2 import _make_dinov2_model __all__ = [ "ConvNeXt", "ConvNeXtV2", "_make_dinov2_model", ] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/convnext.py ================================================ from collections import OrderedDict from functools import partial from typing import Callable, Optional, Sequence, Tuple, Union import torch import torch.nn as nn from timm.layers import (AvgPool2dSame, DropPath, GlobalResponseNormMlp, LayerNorm, LayerNorm2d, Mlp, create_conv2d, get_act_layer, make_divisible, to_ntuple, trunc_normal_) from torch.utils.checkpoint import checkpoint def get_num_layer_for_convnext(var_name): """ Divide [3, 3, 27, 3] layers into 12 groups; each group is three consecutive blocks, including possible neighboring downsample layers; adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py """ if var_name.startswith("downsample_layers"): stage_id = int(var_name.split(".")[1]) if stage_id == 0: layer_id = 0 elif stage_id == 1 or stage_id == 2: layer_id = stage_id + 1 elif stage_id == 3: layer_id = 12 elif var_name.startswith("stages"): stage_id = int(var_name.split(".")[1]) block_id = int(var_name.split(".")[3]) if stage_id == 0 or stage_id == 1: layer_id = stage_id + 1 elif stage_id == 2: layer_id = 3 + block_id // 3 elif stage_id == 3: layer_id = 12 elif var_name.startswith("stem"): return 0 else: layer_id = 12 return layer_id + 1 def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=None): parameter_group_names = {} parameter_group_vars = {} skip = set() if skip_list is not None: 
skip = skip_list if hasattr(model, "no_weight_decay"): skip.update(model.no_weight_decay()) num_layers = 12 layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2)) for name, param in model.named_parameters(): if not param.requires_grad: continue # frozen weights if len(param.shape) == 1 or name.endswith(".bias") or name in skip: group_name = "no_decay" this_wd = 0.0 else: group_name = "decay" this_wd = wd layer_id = get_num_layer_for_convnext(name) group_name = "layer_%d_%s" % (layer_id, group_name) if group_name not in parameter_group_names: scale = layer_scale[layer_id] cur_lr = lr * scale parameter_group_names[group_name] = { "weight_decay": this_wd, "weight_decay_init": this_wd, "weight_decay_base": this_wd, "params": [], "lr_init": cur_lr, "lr_base": lr, "lr": cur_lr, } parameter_group_vars[group_name] = { "weight_decay": this_wd, "weight_decay_init": this_wd, "weight_decay_base": this_wd, "params": [], "lr_init": cur_lr, "lr_base": lr, "lr": cur_lr, } if this_wd == 0.0: parameter_group_names[group_name]["weight_decay_final"] = 0.0 parameter_group_vars[group_name]["weight_decay_final"] = 0.0 parameter_group_vars[group_name]["params"].append(param) parameter_group_names[group_name]["params"].append(name) # from unidepth.utils import is_main_process # import json # if is_main_process(): # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) return list(parameter_group_vars.values()), [ v["lr"] for k, v in parameter_group_vars.items() ] class Downsample(nn.Module): def __init__(self, in_chs, out_chs, stride=1, dilation=1): super().__init__() avg_stride = stride if dilation == 1 else 1 if stride > 1 or dilation > 1: avg_pool_fn = ( AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d ) self.pool = avg_pool_fn( 2, avg_stride, ceil_mode=True, count_include_pad=False ) else: self.pool = nn.Identity() if in_chs != out_chs: self.conv = create_conv2d(in_chs, out_chs, 1, stride=1) else: self.conv = nn.Identity() def forward(self, x): x = self.pool(x) x = self.conv(x) return x class ConvNeXtBlock(nn.Module): """ConvNeXt Block There are two equivalent implementations: (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. """ def __init__( self, in_chs: int, out_chs: Optional[int] = None, kernel_size: int = 7, stride: int = 1, dilation: Union[int, Tuple[int, int]] = (1, 1), mlp_ratio: float = 4, conv_mlp: bool = False, conv_bias: bool = True, use_grn: bool = False, ls_init_value: Optional[float] = 1e-6, act_layer: Union[str, Callable] = "gelu", norm_layer: Optional[Callable] = None, drop_path: float = 0.0, ): """ Args: in_chs: Block input channels. out_chs: Block output channels (same as in_chs if None). kernel_size: Depthwise convolution kernel size. stride: Stride of depthwise convolution. dilation: Tuple specifying input and output dilation of block. mlp_ratio: MLP expansion ratio. conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True. conv_bias: Apply bias for all convolution (linear) layers. 
use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2) ls_init_value: Layer-scale init values, layer-scale applied if not None. act_layer: Activation layer. norm_layer: Normalization layer (defaults to LN if not specified). drop_path: Stochastic depth probability. """ super().__init__() out_chs = out_chs or in_chs dilation = to_ntuple(2)(dilation) act_layer = get_act_layer(act_layer) if not norm_layer: norm_layer = LayerNorm2d if conv_mlp else LayerNorm mlp_layer = partial( GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp ) self.use_conv_mlp = conv_mlp self.conv_dw = create_conv2d( in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation[0], depthwise=True, bias=conv_bias, ) self.norm = norm_layer(out_chs) self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer) self.gamma = ( nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None ) if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: self.shortcut = Downsample( in_chs, out_chs, stride=stride, dilation=dilation[0] ) else: self.shortcut = nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x): shortcut = x x = self.conv_dw(x.contiguous()) if self.use_conv_mlp: x = self.norm(x) x = self.mlp(x) else: x = x.permute(0, 2, 3, 1).contiguous() x = self.norm(x) x = self.mlp(x) x = x.permute(0, 3, 1, 2).contiguous() if self.gamma is not None: x = x.mul(self.gamma.reshape(1, -1, 1, 1)) x = self.drop_path(x) + self.shortcut(shortcut) return x.contiguous() class ConvNeXtStage(nn.Module): def __init__( self, in_chs, out_chs, kernel_size=7, stride=2, depth=2, dilation=(1, 1), drop_path_rates=None, ls_init_value=1.0, conv_mlp=False, conv_bias=True, use_grn=False, act_layer="gelu", norm_layer=None, norm_layer_cl=None, ): super().__init__() self.grad_checkpointing = False if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]: ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1 pad = ( "same" if dilation[1] > 1 else 0 ) # same padding needed if dilation used self.downsample = nn.Sequential( norm_layer(in_chs), create_conv2d( in_chs, out_chs, kernel_size=ds_ks, stride=stride, dilation=dilation[0], padding=pad, bias=conv_bias, ), ) in_chs = out_chs else: self.downsample = nn.Identity() drop_path_rates = drop_path_rates or [0.0] * depth stage_blocks = [] for i in range(depth): stage_blocks.append( ConvNeXtBlock( in_chs=in_chs, out_chs=out_chs, kernel_size=kernel_size, dilation=dilation[1], drop_path=drop_path_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, conv_bias=conv_bias, use_grn=use_grn, act_layer=act_layer, norm_layer=norm_layer if conv_mlp else norm_layer_cl, ) ) in_chs = out_chs self.blocks = nn.ModuleList(stage_blocks) def forward(self, x): xs = [] x = self.downsample(x) for block in self.blocks: if self.grad_checkpointing: x = checkpoint(block, x) else: x = block(x) xs.append(x) return xs class ConvNeXt(nn.Module): def __init__( self, in_chans: int = 3, output_stride: int = 32, depths: Tuple[int, ...] = (3, 3, 9, 3), dims: Tuple[int, ...] 
= (96, 192, 384, 768), kernel_sizes: Union[int, Tuple[int, ...]] = 7, ls_init_value: Optional[float] = 1e-6, stem_type: str = "patch", patch_size: int = 4, conv_mlp: bool = False, conv_bias: bool = True, use_grn: bool = False, act_layer: Union[str, Callable] = "gelu", norm_layer: Optional[Union[str, Callable]] = None, norm_eps: Optional[float] = None, drop_path_rate: float = 0.0, output_idx=[], use_checkpoint=False, ): """ Args: in_chans: Number of input image channels. num_classes: Number of classes for classification head. global_pool: Global pooling type. output_stride: Output stride of network, one of (8, 16, 32). depths: Number of blocks at each stage. dims: Feature dimension at each stage. kernel_sizes: Depthwise convolution kernel-sizes for each stage. ls_init_value: Init value for Layer Scale, disabled if None. stem_type: Type of stem. patch_size: Stem patch size for patch stem. head_init_scale: Init scaling value for classifier weights and biases. head_norm_first: Apply normalization before global pool + head. head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False. conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last. conv_bias: Use bias layers w/ all convolutions. use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP. act_layer: Activation layer type. norm_layer: Normalization layer type. drop_rate: Head pre-classifier dropout rate. drop_path_rate: Stochastic depth drop rate. """ super().__init__() self.num_layers = len(depths) self.depths = output_idx self.embed_dims = [ int(dim) for i, dim in enumerate(dims) for _ in range(depths[i]) ] self.embed_dim = dims[0] assert output_stride in (8, 16, 32) kernel_sizes = to_ntuple(4)(kernel_sizes) if norm_layer is None: norm_layer = LayerNorm2d norm_layer_cl = norm_layer if conv_mlp else LayerNorm if norm_eps is not None: norm_layer = partial(norm_layer, eps=norm_eps) norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) else: assert ( conv_mlp ), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input" norm_layer_cl = norm_layer if norm_eps is not None: norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) self.feature_info = [] assert stem_type in ("patch", "overlap", "overlap_tiered") if stem_type == "patch": # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 self.stem = nn.Sequential( nn.Conv2d( in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias, ), norm_layer(dims[0]), ) stem_stride = patch_size else: mid_chs = make_divisible(dims[0] // 2) if "tiered" in stem_type else dims[0] self.stem = nn.Sequential( nn.Conv2d( in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias, ), nn.Conv2d( mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias ), norm_layer(dims[0]), ) stem_stride = 4 self.stages = nn.Sequential() dp_rates = [ x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths) ] stages = [] prev_chs = dims[0] curr_stride = stem_stride dilation = 1 # 4 feature resolution stages, each consisting of multiple residual blocks for i in range(4): stride = 2 if curr_stride == 2 or i > 0 else 1 if curr_stride >= output_stride and stride > 1: dilation *= stride stride = 1 curr_stride *= stride first_dilation = 1 if dilation in (1, 2) else 2 out_chs = dims[i] stages.append( ConvNeXtStage( prev_chs, out_chs, kernel_size=kernel_sizes[i], stride=stride, dilation=(first_dilation, dilation), depth=depths[i], 
drop_path_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, conv_bias=conv_bias, use_grn=use_grn, act_layer=act_layer, norm_layer=norm_layer, norm_layer_cl=norm_layer_cl, ) ) prev_chs = out_chs # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 self.feature_info += [ dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}") ] self.stages = nn.ModuleList(stages) self.mask_token = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1)) self.num_features = prev_chs self.apply(self._init_weights) self.set_grad_checkpointing(use_checkpoint) def _init_weights(self, module): if isinstance(module, nn.Conv2d): trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Linear): trunc_normal_(module.weight, std=0.02) nn.init.zeros_(module.bias) def forward(self, x, masks=None): outs = [] x = self.stem(x) if masks is not None: masks = torch.nn.functional.interpolate( masks.float(), size=x.shape[-2:], mode="nearest" ) x = torch.where(masks.bool(), self.mask_token.to(x.dtype), x).contiguous() for stage in self.stages: xs = stage(x) outs.extend([x.permute(0, 2, 3, 1).contiguous() for x in xs]) x = xs[-1] return outs, [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs] @torch.jit.ignore def group_matcher(self, coarse=False): return dict( stem=r"^stem", blocks=( r"^stages\.(\d+)" if coarse else [ (r"^stages\.(\d+)\.downsample", (0,)), # blocks (r"^stages\.(\d+)\.blocks\.(\d+)", None), (r"^norm_pre", (99999,)), ] ), ) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): for s in self.stages: s.grad_checkpointing = enable def freeze(self) -> None: for module in self.modules(): module.eval() for parameters in self.parameters(): parameters.requires_grad = False def get_params(self, lr, wd, ld, *args, **kwargs): encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) return encoder_p, encoder_lr def no_weight_decay(self): return {"mask_token"} @classmethod def build(cls, config): obj = globals()[config["model"]["encoder"]["name"]](config) return obj def checkpoint_filter_fn(state_dict, model): """Remap FB checkpoints -> timm""" if "head.norm.weight" in state_dict or "norm_pre.weight" in state_dict: return state_dict # non-FB checkpoint if "model" in state_dict: state_dict = state_dict["model"] out_dict = {} if "visual.trunk.stem.0.weight" in state_dict: out_dict = { k.replace("visual.trunk.", ""): v for k, v in state_dict.items() if k.startswith("visual.trunk.") } if "visual.head.proj.weight" in state_dict: out_dict["head.fc.weight"] = state_dict["visual.head.proj.weight"] out_dict["head.fc.bias"] = torch.zeros( state_dict["visual.head.proj.weight"].shape[0] ) elif "visual.head.mlp.fc1.weight" in state_dict: out_dict["head.pre_logits.fc.weight"] = state_dict[ "visual.head.mlp.fc1.weight" ] out_dict["head.pre_logits.fc.bias"] = state_dict["visual.head.mlp.fc1.bias"] out_dict["head.fc.weight"] = state_dict["visual.head.mlp.fc2.weight"] out_dict["head.fc.bias"] = torch.zeros( state_dict["visual.head.mlp.fc2.weight"].shape[0] ) return out_dict import re for k, v in state_dict.items(): k = k.replace("downsample_layers.0.", "stem.") k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k) k = re.sub( r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k ) k = k.replace("dwconv", "conv_dw") k = k.replace("pwconv", "mlp.fc") if "grn" in k: k = k.replace("grn.beta", "mlp.grn.bias") k = k.replace("grn.gamma", "mlp.grn.weight") v = v.reshape(v.shape[-1]) k 
= k.replace("head.", "head.fc.") if k.startswith("norm."): k = k.replace("norm", "head.norm") if v.ndim == 2 and "head" not in k: model_shape = model.state_dict()[k].shape v = v.reshape(model_shape) out_dict[k] = v return out_dict HF_URL = { "convnext_xxlarge_pt": ( "laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup", "open_clip_pytorch_model.bin", ), "convnext_large_pt": ( "laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup", "open_clip_pytorch_model.bin", ), "convnext_large": ( "timm/convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384", "pytorch_model.bin", ), } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/convnext2.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from timm.models.layers import DropPath, trunc_normal_ def get_num_layer_for_convnext_single(var_name, depths): """ Each layer is assigned distinctive layer ids """ if var_name.startswith("downsample_layers"): stage_id = int(var_name.split(".")[1]) layer_id = sum(depths[:stage_id]) + 1 return layer_id elif var_name.startswith("stages"): stage_id = int(var_name.split(".")[1]) block_id = int(var_name.split(".")[2]) layer_id = sum(depths[:stage_id]) + block_id + 1 return layer_id else: return sum(depths) + 1 def get_num_layer_for_convnext(var_name): """ Divide [3, 3, 27, 3] layers into 12 groups; each group is three consecutive blocks, including possible neighboring downsample layers; adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py """ num_max_layer = 12 if var_name.startswith("downsample_layers"): stage_id = int(var_name.split(".")[1]) if stage_id == 0: layer_id = 0 elif stage_id == 1 or stage_id == 2: layer_id = stage_id + 1 elif stage_id == 3: layer_id = 12 return layer_id elif var_name.startswith("stages"): stage_id = int(var_name.split(".")[1]) block_id = int(var_name.split(".")[2]) if stage_id == 0 or stage_id == 1: layer_id = stage_id + 1 elif stage_id == 2: layer_id = 3 + block_id // 3 elif stage_id == 3: layer_id = 12 return layer_id else: return num_max_layer + 1 def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()): parameter_group_names = {} parameter_group_vars = {} skip = {} if skip_list is not None: skip = skip_list elif hasattr(model, "no_weight_decay"): skip = model.no_weight_decay() num_layers = 12 # sum(model.depths) layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2)) for name, param in model.named_parameters(): if not param.requires_grad: continue # frozen weights if ( len(param.shape) == 1 or name.endswith(".bias") or name in skip or name.endswith(".gamma") or name.endswith(".beta") ): group_name = "no_decay" this_weight_decay = 0.0 else: group_name = "decay" this_weight_decay = wd # layer_id = get_num_layer_for_convnext_single(name, model.depths) layer_id = get_num_layer_for_convnext(name) group_name = "layer_%d_%s" % (layer_id, group_name) if group_name not in parameter_group_names: scale = layer_scale[layer_id] cur_lr = lr * scale parameter_group_names[group_name] = { "weight_decay": this_weight_decay, "params": [], "lr_scale": scale, "lr": cur_lr, } parameter_group_vars[group_name] = { "weight_decay": this_weight_decay, "params": [], "lr_scale": scale, "lr": cur_lr, } parameter_group_vars[group_name]["params"].append(param) parameter_group_names[group_name]["params"].append(name) # if is_main_process(): # print("Param groups = %s" % 
json.dumps(parameter_group_names, indent=2)) return list(parameter_group_vars.values()), [ v["lr"] for k, v in parameter_group_vars.items() ] class LayerNorm(nn.Module): """LayerNorm that supports two data formats: channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). """ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): super().__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) self.eps = eps self.data_format = data_format if self.data_format not in ["channels_last", "channels_first"]: raise NotImplementedError self.normalized_shape = (normalized_shape,) def forward(self, x): if self.data_format == "channels_last": return F.layer_norm( x, self.normalized_shape, self.weight, self.bias, self.eps ) elif self.data_format == "channels_first": u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x class GRN(nn.Module): """GRN (Global Response Normalization) layer""" def __init__(self, dim): super().__init__() self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) def forward(self, x): Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) return self.gamma * (x * Nx) + self.beta + x class Block(nn.Module): """ConvNeXtV2 Block. Args: dim (int): Number of input channels. drop_path (float): Stochastic depth rate. Default: 0.0 """ def __init__(self, dim, drop_path=0.0, mult=4, use_checkpoint=False): super().__init__() self.dwconv = nn.Conv2d( dim, dim, kernel_size=7, padding=3, groups=dim ) # depthwise conv self.norm = LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear( dim, mult * dim ) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.grn = GRN(mult * dim) self.pwconv2 = nn.Linear(mult * dim, dim) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.use_checkpoint = use_checkpoint def forward(self, x): input = x x = self.dwconv(x) x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.grn(x) x = self.pwconv2(x) x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) x = input + self.drop_path(x) return x class ConvNeXtV2(nn.Module): """ConvNeXt V2 Args: in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] drop_path_rate (float): Stochastic depth rate. Default: 0. head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
""" def __init__( self, in_chans=3, depths=[3, 3, 9, 3], dims=96, drop_path_rate=0.0, output_idx=[], use_checkpoint=False, ): super().__init__() self.num_layers = len(depths) self.depths = output_idx self.embed_dims = [ int(dim) for i, dim in enumerate(dims) for _ in range(depths[i]) ] self.embed_dim = dims[0] self.downsample_layers = ( nn.ModuleList() ) # stem and 3 intermediate downsampling conv layers stem = nn.Sequential( nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), ) self.downsample_layers.append(stem) for i in range(3): downsample_layer = nn.Sequential( LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), ) self.downsample_layers.append(downsample_layer) self.stages = ( nn.ModuleList() ) # 4 feature resolution stages, each consisting of multiple residual blocks self.out_norms = nn.ModuleList() dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] cur = 0 for i in range(4): stage = nn.ModuleList( [ Block( dim=dims[i], drop_path=dp_rates[cur + j], use_checkpoint=use_checkpoint, ) for j in range(depths[i]) ] ) self.stages.append(stage) cur += depths[i] self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): trunc_normal_(m.weight, std=0.02) nn.init.constant_(m.bias, 0) def forward(self, x): outs = [] for i in range(4): x = self.downsample_layers[i](x) for stage in self.stages[i]: x = stage(x) outs.append(x.permute(0, 2, 3, 1)) cls_tokens = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs] return outs, cls_tokens def get_params(self, lr, wd, ld, *args, **kwargs): encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) return encoder_p, encoder_lr def freeze(self) -> None: for module in self.modules(): module.eval() for parameters in self.parameters(): parameters.requires_grad = False @classmethod def build(cls, config): obj = globals()[config["model"]["encoder"]["name"]](config) return obj ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/dinov2.py ================================================ import contextlib import logging import math from functools import partial from typing import Callable, Sequence import torch import torch.nn as nn from torch.nn.init import trunc_normal_ from .metadinov2 import Attention, MemEffAttention, Mlp from .metadinov2 import NestedTensorBlock as Block from .metadinov2 import PatchEmbed, SwiGLUFFNFused _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" logger = logging.getLogger("dinov2") def named_apply( fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False ) -> nn.Module: if not depth_first and include_root: fn(module=module, name=name) for child_name, child_module in module.named_children(): child_name = ".".join((name, child_name)) if name else child_name named_apply( fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True, ) if depth_first and include_root: fn(module=module, name=name) return module def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()): parameter_group_names = {} parameter_group_vars = {} skip = {} if skip_list is not None: skip = skip_list elif hasattr(model, "no_weight_decay"): skip = model.no_weight_decay() num_layers = model.n_blocks layer_scale = list(ld ** (num_layers - i) for i in range(num_layers)) for name, param in model.named_parameters(): if not 
param.requires_grad: continue if len(param.shape) == 1: # norm group_name = "no_decay" this_wd = 0.0 # layer scale, bias beta? elif ( name in skip or name.endswith(".gamma") or name.endswith(".beta") or name.endswith(".bias") ): group_name = "no_decay" this_wd = 0.0 elif "cls_token" in name or "pos_embed" in name or "mask_token" in name: group_name = "no_decay" this_wd = 0.0 else: group_name = "decay" this_wd = wd if name.startswith("blocks"): layer_id = int(name.split(".")[1]) elif name.startswith("patch_embed"): layer_id = 0 else: layer_id = 0 group_name = f"layer_{layer_id}_{group_name}" if group_name not in parameter_group_names: scale = layer_scale[layer_id] cur_lr = lr * scale parameter_group_names[group_name] = { "weight_decay": this_wd, "params": [], "lr_init": cur_lr, "lr_base": lr, "lr": cur_lr, } parameter_group_vars[group_name] = { "weight_decay": this_wd, "params": [], "lr_init": cur_lr, "lr_base": lr, "lr": cur_lr, } parameter_group_vars[group_name]["params"].append(param) parameter_group_names[group_name]["params"].append(name) return list(parameter_group_vars.values()), [ v["lr"] for k, v in parameter_group_vars.items() ] class BlockChunk(nn.ModuleList): def forward(self, x): for b in self: x = b(x) return x class DinoVisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, ffn_bias=True, proj_bias=True, drop_path_rate=0.0, drop_path_uniform=False, init_values=None, # for layerscale: None or 0 => no layerscale embed_layer=PatchEmbed, act_layer=nn.GELU, block_fn=Block, ffn_layer="mlp", block_chunks=1, output_idx=[5, 12, 18, 24], checkpoint: bool = False, num_register_tokens=0, interpolate_antialias=False, interpolate_offset=0.0, use_norm=False, frozen_stages=0, ): """ Args: img_size (int, tuple): input image size patch_size (int, tuple): patch size in_chans (int): number of input channels embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True proj_bias (bool): enable bias for proj in attn if True ffn_bias (bool): enable bias for ffn if True drop_path_rate (float): stochastic depth rate drop_path_uniform (bool): apply uniform drop rate across blocks weight_init (str): weight init scheme init_values (float): layer-scale init values embed_layer (nn.Module): patch embedding layer act_layer (nn.Module): MLP activation layer block_fn (nn.Module): transformer block class ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" block_chunks: (int) split block sequence into block_chunks units for FSDP wrap """ super().__init__() norm_layer = partial(nn.LayerNorm, eps=1e-6) self.num_features = self.embed_dim = ( embed_dim # num_features for consistency with other models ) self.frozen_stages = frozen_stages self.embed_dims = [embed_dim] * output_idx[-1] self.num_tokens = 1 self.n_blocks = depth self.num_heads = num_heads self.patch_size = patch_size self.depths = output_idx self.checkpoint = checkpoint self.num_register_tokens = num_register_tokens self.interpolate_antialias = interpolate_antialias self.interpolate_offset = interpolate_offset self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter( torch.zeros(1, num_patches + 
self.num_tokens, embed_dim) ) assert num_register_tokens >= 0 self.register_tokens = nn.Parameter( torch.zeros(1, max(1, num_register_tokens), embed_dim) ) if drop_path_uniform is True: dpr = [drop_path_rate] * depth else: dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, depth) ] # stochastic depth decay rule if ffn_layer == "mlp": logger.info("using MLP layer as FFN") ffn_layer = Mlp elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": logger.info("using SwiGLU layer as FFN") ffn_layer = SwiGLUFFNFused elif ffn_layer == "identity": logger.info("using Identity layer as FFN") def f(*args, **kwargs): return nn.Identity() ffn_layer = f else: raise NotImplementedError blocks_list = [ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, ffn_layer=ffn_layer, init_values=init_values, ) for i in range(depth) ] if block_chunks > 0: self.chunked_blocks = True chunked_blocks = [] chunksize = depth // block_chunks for i in range(0, depth, chunksize): # this is to keep the block index consistent if we chunk the block list chunked_blocks.append( [nn.Identity()] * i + blocks_list[i : i + chunksize] ) self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) else: self.chunked_blocks = False self.blocks = nn.ModuleList(blocks_list) self.norm = nn.LayerNorm(embed_dim) self.use_norm = use_norm self.head = nn.Identity() self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) self.init_weights() def init_weights(self): trunc_normal_(self.pos_embed, std=0.02) nn.init.normal_(self.cls_token, std=1e-6) if self.num_register_tokens: nn.init.normal_(self.register_tokens, std=1e-6) named_apply(init_weights_vit_timm, self) def interpolate_pos_encoding(self, x, w, h): previous_dtype = x.dtype npatch = x.shape[1] - 1 N = self.pos_embed.shape[1] - 1 if npatch == N and w == h: return self.pos_embed pos_embed = self.pos_embed.float() class_pos_embed = pos_embed[:, 0] patch_pos_embed = pos_embed[:, 1:] dim = x.shape[-1] w0 = w // self.patch_size h0 = h // self.patch_size M = int(math.sqrt(N)) # Recover the number of patches in each dimension assert N == M * M kwargs = {} if self.interpolate_offset: # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors sx = float(w0 + self.interpolate_offset) / M sy = float(h0 + self.interpolate_offset) / M kwargs["scale_factor"] = (sx, sy) else: # Simply specify an output size instead of a scale factor kwargs["size"] = (w0, h0) patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), mode="bicubic", antialias=self.interpolate_antialias, **kwargs, ) assert (w0, h0) == patch_pos_embed.shape[-2:] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( previous_dtype ) def prepare_tokens_with_masks(self, x, masks=None): B, nc, w, h = x.shape with torch.no_grad() if self.frozen_stages > -1 else contextlib.nullcontext(): x = self.patch_embed(x) if masks is not None: masks = masks.bool().view(B, -1, 1) x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.interpolate_pos_encoding(x, w, h) if 
self.num_register_tokens: x = torch.cat( (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1, ) return x def forward(self, x, masks=None): shapes = [val // self.patch_size for val in x.shape[-2:]] batch_size = x.shape[0] x = self.prepare_tokens_with_masks(x, masks) outputs = [] for i, blk in enumerate(self.blocks): with ( torch.no_grad() if i < self.frozen_stages else contextlib.nullcontext() ): x = blk(x) outputs.append(x) if self.use_norm: with ( torch.no_grad() if self.frozen_stages >= len(self.blocks) else contextlib.nullcontext() ): outputs = [self.norm(out) for out in outputs] class_tokens = [out[:, :1] for out in outputs] outputs = [out[:, self.num_register_tokens + 1 :] for out in outputs] outputs = [out.reshape(batch_size, *shapes, -1) for out in outputs] return (outputs, class_tokens) def get_params(self, lr, wd, ld, *args, **kwargs): encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) return encoder_p, encoder_lr def freeze(self) -> None: for module in self.modules(): module.eval() for parameters in self.parameters(): parameters.requires_grad = False def train(self, mode=True): super().train(mode) if self.frozen_stages > -1: for p in self.patch_embed.parameters(): p.requires_grad = False for i, blk in enumerate(self.blocks): if i < self.frozen_stages: blk.eval() for p in blk.parameters(): p.requires_grad = False for p in self.norm.parameters(): p.requires_grad = self.frozen_stages <= len(self.blocks) and self.use_norm self.cls_token.requires_grad = self.frozen_stages < 1 self.pos_embed.requires_grad = self.frozen_stages < 1 self.mask_token.requires_grad = False self.register_tokens.requires_grad = False def init_weights_vit_timm(module: nn.Module, name: str = ""): """ViT weight initialization, original timm impl (for reproducibility)""" if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) def vit_small(patch_size=16, num_register_tokens=0, export=False, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, num_register_tokens=num_register_tokens, block_fn=partial(Block, attn_class=Attention if export else MemEffAttention), **kwargs, ) return model def vit_base(patch_size=16, num_register_tokens=0, export=False, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, num_register_tokens=num_register_tokens, block_fn=partial(Block, attn_class=Attention if export else MemEffAttention), **kwargs, ) return model def vit_large(patch_size=16, num_register_tokens=0, export=False, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, num_register_tokens=num_register_tokens, block_fn=partial(Block, attn_class=Attention if export else MemEffAttention), **kwargs, ) return model def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str: compact_arch_name = arch_name.replace("_", "")[:4] return f"dinov2_{compact_arch_name}{patch_size}" def _make_dinov2_model( *, arch_name: str = "vit_large", img_size: int = 518, patch_size: int = 14, init_values: float = 1.0, ffn_layer: str = "mlp", block_chunks: int = 0, pretrained: str = "", output_idx: Sequence[int] = [], num_register_tokens: int = 0, drop_path_rate: float = 0.0, use_norm: bool = False, export: bool = False, interpolate_offset: float = 0.0, frozen_stages: int = 0, **kwargs, ): model_name = _make_dinov2_model_name(arch_name, patch_size) 
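    # [Editorial note, added for clarity] The code below assembles the ViT
    # constructor kwargs, instantiates the requested architecture
    # (vit_small / vit_base / vit_large) by name, and then either downloads the
    # official DINOv2 checkpoint from _DINOV2_BASE_URL (with a "_reg4" suffix
    # when register tokens are enabled) or loads a local checkpoint passed via
    # `pretrained`; both paths call load_state_dict(..., strict=False) so that
    # missing or unexpected keys are only reported, not fatal.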
vit_kwargs = dict( img_size=img_size, patch_size=patch_size, init_values=init_values, ffn_layer=ffn_layer, block_chunks=block_chunks, output_idx=output_idx, drop_path_rate=drop_path_rate, num_register_tokens=num_register_tokens, use_norm=use_norm, export=export, interpolate_offset=interpolate_offset, frozen_stages=frozen_stages, ) vit_kwargs.update(**kwargs) model = eval(arch_name)(**vit_kwargs) if pretrained == "": url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}" if num_register_tokens > 0: url += "_reg4" url += "_pretrain.pth" state_dict = torch.hub.load_state_dict_from_url( url, map_location="cpu", progress=False ) info = model.load_state_dict(state_dict, strict=False) print(info) elif pretrained is not None: state_dict = torch.load(pretrained, map_location="cpu") info = model.load_state_dict(state_dict, strict=False) print(f"loading from {pretrained} with:", info) else: print("Not loading pretrained weights for backbone") return model ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from .attention import Attention, MemEffAttention from .block import NestedTensorBlock from .dino_head import DINOHead from .mlp import Mlp from .patch_embed import PatchEmbed from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/attention.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
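# ---------------------------------------------------------------------------
# [Editorial sketch, not part of the upstream DINOv2 sources] The Attention /
# MemEffAttention classes defined below implement standard multi-head
# self-attention over (batch, tokens, channels) tensors. A minimal, hedged
# usage example; the shapes are illustrative only:
def _example_attention_usage():
    import torch

    attn = Attention(dim=768, num_heads=12, qkv_bias=True)  # defined below
    tokens = torch.randn(2, 197, 768)  # (batch, num_tokens, embed_dim)
    out = attn(tokens)                 # output keeps the input shape
    assert out.shape == tokens.shape
    return out
# ---------------------------------------------------------------------------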
# References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py import logging import torch import torch.nn as nn import torch.nn.functional as F logger = logging.getLogger("dinov2") try: from xformers.ops import fmha, memory_efficient_attention, unbind XFORMERS_AVAILABLE = True except ImportError: logger.warning("xFormers not available") XFORMERS_AVAILABLE = False XFORMERS_AVAILABLE = XFORMERS_AVAILABLE and torch.cuda.is_available() class Attention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0, ) -> None: super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape qkv = ( self.qkv(x) .reshape(B, N, 3, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) x = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2]) x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class MemEffAttention(Attention): def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: # new pytorch have good attn efficient, no need for xformers if not XFORMERS_AVAILABLE or x.device.type == "cpu": assert attn_bias is None, "xFormers is required for nested tensors usage" return super().forward(x) B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) q, k, v = unbind(qkv, 2) x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) x = x.reshape([B, N, C]) x = self.proj(x) x = self.proj_drop(x) return x ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/block.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
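# ---------------------------------------------------------------------------
# [Editorial sketch, not part of the upstream DINOv2 sources] The Block /
# NestedTensorBlock classes below wrap pre-norm attention and MLP residual
# branches, optionally with LayerScale and stochastic depth. A minimal,
# hedged example of running a single block on a token sequence:
def _example_block_usage():
    import torch

    block = Block(dim=384, num_heads=6, mlp_ratio=4.0, init_values=1e-5)
    tokens = torch.randn(2, 257, 384)  # (batch, tokens, embed_dim)
    out = block(tokens)                # residual updates preserve the shape
    assert out.shape == tokens.shape
    return out
# ---------------------------------------------------------------------------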
# References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py import logging from typing import Any, Callable, Dict, List, Tuple import torch import torch.nn as nn from .attention import Attention, MemEffAttention from .drop_path import DropPath from .layer_scale import LayerScale from .mlp import Mlp logger = logging.getLogger("dinov2") try: from xformers.ops import fmha, index_select_cat, scaled_index_add XFORMERS_AVAILABLE = True except ImportError: logger.warning("xFormers not available") XFORMERS_AVAILABLE = False class Block(nn.Module): def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = False, proj_bias: bool = True, ffn_bias: bool = True, drop: float = 0.0, attn_drop: float = 0.0, init_values=None, drop_path: float = 0.0, act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_class: Callable[..., nn.Module] = Attention, ffn_layer: Callable[..., nn.Module] = Mlp, ) -> None: super().__init__() # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") self.norm1 = norm_layer(dim) self.attn = attn_class( dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, ) self.ls1 = ( LayerScale(dim, init_values=init_values) if init_values else nn.Identity() ) self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = ffn_layer( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias, ) self.ls2 = ( LayerScale(dim, init_values=init_values) if init_values else nn.Identity() ) self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.sample_drop_ratio = drop_path def forward(self, x: torch.Tensor) -> torch.Tensor: def attn_residual_func(x: torch.Tensor) -> torch.Tensor: return self.ls1(self.attn(self.norm1(x))) def ffn_residual_func(x: torch.Tensor) -> torch.Tensor: return self.ls2(self.mlp(self.norm2(x))) if self.training and self.sample_drop_ratio > 0.1: # the overhead is compensated only for a drop path rate larger than 0.1 x = drop_add_residual_stochastic_depth( x, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio, ) x = drop_add_residual_stochastic_depth( x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio, ) elif self.training and self.sample_drop_ratio > 0.0: x = x + self.drop_path1(attn_residual_func(x)) x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 else: x = x + attn_residual_func(x) x = x + ffn_residual_func(x) return x def drop_add_residual_stochastic_depth( x: torch.Tensor, residual_func: Callable[[torch.Tensor], torch.Tensor], sample_drop_ratio: float = 0.0, ) -> torch.Tensor: # 1) extract subset using permutation b, n, d = x.shape sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) brange = (torch.randperm(b, device=x.device))[:sample_subset_size] x_subset = x[brange] # 2) apply residual_func to get residual residual = residual_func(x_subset) x_flat = x.flatten(1) residual = residual.flatten(1) residual_scale_factor = b / sample_subset_size # 3) add the residual x_plus_residual = torch.index_add( x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor ) return x_plus_residual.view_as(x) def get_branges_scales(x, sample_drop_ratio=0.0): b, n, d = x.shape sample_subset_size 
= max(int(b * (1 - sample_drop_ratio)), 1) brange = (torch.randperm(b, device=x.device))[:sample_subset_size] residual_scale_factor = b / sample_subset_size return brange, residual_scale_factor def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): if scaling_vector is None: x_flat = x.flatten(1) residual = residual.flatten(1) x_plus_residual = torch.index_add( x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor ) else: x_plus_residual = scaled_index_add( x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor, ) return x_plus_residual attn_bias_cache: Dict[Tuple, Any] = {} def get_attn_bias_and_cat(x_list, branges=None): """ this will perform the index select, cat the tensors, and provide the attn_bias from cache """ batch_sizes = ( [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] ) all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) if all_shapes not in attn_bias_cache.keys(): seqlens = [] for b, x in zip(batch_sizes, x_list): for _ in range(b): seqlens.append(x.shape[1]) attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) attn_bias._batch_sizes = batch_sizes attn_bias_cache[all_shapes] = attn_bias if branges is not None: cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( 1, -1, x_list[0].shape[-1] ) else: tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) cat_tensors = torch.cat(tensors_bs1, dim=1) return attn_bias_cache[all_shapes], cat_tensors def drop_add_residual_stochastic_depth_list( x_list: List[torch.Tensor], residual_func: Callable[[torch.Tensor, Any], torch.Tensor], sample_drop_ratio: float = 0.0, scaling_vector=None, ) -> torch.Tensor: # 1) generate random set of indices for dropping samples in the batch branges_scales = [ get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list ] branges = [s[0] for s in branges_scales] residual_scale_factors = [s[1] for s in branges_scales] # 2) get attention bias and index+concat the tensors attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) # 3) apply residual_func to get residual, and split the result residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore outputs = [] for x, brange, residual, residual_scale_factor in zip( x_list, branges, residual_list, residual_scale_factors ): outputs.append( add_residual( x, brange, residual, residual_scale_factor, scaling_vector ).view_as(x) ) return outputs class NestedTensorBlock(Block): def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: """ x_list contains a list of tensors to nest together and run """ assert isinstance(self.attn, MemEffAttention) if self.training and self.sample_drop_ratio > 0.0: def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: return self.attn(self.norm1(x), attn_bias=attn_bias) def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: return self.mlp(self.norm2(x)) x_list = drop_add_residual_stochastic_depth_list( x_list, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio, scaling_vector=( self.ls1.gamma if isinstance(self.ls1, LayerScale) else None ), ) x_list = drop_add_residual_stochastic_depth_list( x_list, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio, scaling_vector=( self.ls2.gamma if isinstance(self.ls1, LayerScale) else None ), ) return x_list else: def attn_residual_func(x: torch.Tensor, attn_bias=None) 
-> torch.Tensor: return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: return self.ls2(self.mlp(self.norm2(x))) attn_bias, x = get_attn_bias_and_cat(x_list) x = x + attn_residual_func(x, attn_bias=attn_bias) x = x + ffn_residual_func(x) return attn_bias.split(x) def forward(self, x_or_x_list): if isinstance(x_or_x_list, torch.Tensor): return super(NestedTensorBlock, self).forward(x_or_x_list) elif isinstance(x_or_x_list, list): assert ( XFORMERS_AVAILABLE ), "Please install xFormers for nested tensors usage" return self.forward_nested(x_or_x_list) else: raise AssertionError ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/dino_head.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn from torch.nn.init import trunc_normal_ from torch.nn.utils import weight_norm class DINOHead(nn.Module): def __init__( self, in_dim, out_dim, use_bn=False, nlayers=3, hidden_dim=2048, bottleneck_dim=256, mlp_bias=True, ): super().__init__() nlayers = max(nlayers, 1) self.mlp = _build_mlp( nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias, ) self.apply(self._init_weights) self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) self.last_layer.weight_g.data.fill_(1) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x): x = self.mlp(x) eps = 1e-6 if x.dtype == torch.float16 else 1e-12 x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) x = self.last_layer(x) return x def _build_mlp( nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True ): if nlayers == 1: return nn.Linear(in_dim, bottleneck_dim, bias=bias) else: layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] if use_bn: layers.append(nn.BatchNorm1d(hidden_dim)) layers.append(nn.GELU()) for _ in range(nlayers - 2): layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) if use_bn: layers.append(nn.BatchNorm1d(hidden_dim)) layers.append(nn.GELU()) layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) return nn.Sequential(*layers) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/drop_path.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
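# ---------------------------------------------------------------------------
# [Editorial sketch, not part of the upstream DINOv2 sources] DropPath below
# implements per-sample stochastic depth: in training mode each sample's
# residual branch is zeroed with probability drop_prob and the survivors are
# rescaled by 1 / (1 - drop_prob). A minimal, hedged example:
def _example_drop_path_usage():
    import torch

    dp = DropPath(drop_prob=0.1)   # defined below
    dp.train()                     # stochastic only in training mode
    residual = torch.randn(8, 196, 768)
    out = dp(residual)             # some samples zeroed, rest scaled by 1/0.9
    return out
# ---------------------------------------------------------------------------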
# References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py import torch.nn as nn def drop_path(x, drop_prob: float = 0.0, training: bool = False): if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * ( x.ndim - 1 ) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0: random_tensor.div_(keep_prob) output = x * random_tensor return output class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/layer_scale.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 from typing import Union import torch import torch.nn as nn from torch import Tensor class LayerScale(nn.Module): def __init__( self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False, ) -> None: super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x: Tensor) -> Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/mlp.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py from typing import Callable, Optional from torch import Tensor, nn class Mlp(nn.Module): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = nn.GELU, drop: float = 0.0, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) self.drop = nn.Dropout(drop) def forward(self, x: Tensor) -> Tensor: x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/patch_embed.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
# References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py from typing import Callable, Optional, Tuple, Union import torch.nn as nn from torch import Tensor def make_2tuple(x): if isinstance(x, tuple): assert len(x) == 2 return x assert isinstance(x, int) return (x, x) class PatchEmbed(nn.Module): """ 2D image to patch embedding: (B,C,H,W) -> (B,N,D) Args: img_size: Image size. patch_size: Patch token size. in_chans: Number of input image channels. embed_dim: Number of linear projection output channels. norm_layer: Normalization layer. """ def __init__( self, img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, embed_dim: int = 768, norm_layer: Optional[Callable] = None, flatten_embedding: bool = True, ) -> None: super().__init__() image_HW = make_2tuple(img_size) patch_HW = make_2tuple(patch_size) patch_grid_size = ( image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1], ) self.img_size = image_HW self.patch_size = patch_HW self.patches_resolution = patch_grid_size self.num_patches = patch_grid_size[0] * patch_grid_size[1] self.in_chans = in_chans self.embed_dim = embed_dim self.flatten_embedding = flatten_embedding self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW ) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x: Tensor) -> Tensor: _, _, H, W = x.shape patch_H, patch_W = self.patch_size assert ( H % patch_H == 0 ), f"Input image height {H} is not a multiple of patch height {patch_H}" assert ( W % patch_W == 0 ), f"Input image width {W} is not a multiple of patch width: {patch_W}" x = self.proj(x) # B C H W H, W = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) # B HW C x = self.norm(x) if not self.flatten_embedding: x = x.reshape(-1, H, W, self.embed_dim) # B H W C return x def flops(self) -> float: Ho, Wo = self.patches_resolution flops = ( Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) ) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/swiglu_ffn.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
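# A minimal, hypothetical usage sketch of the PatchEmbed defined in
# patch_embed.py above, reproduced with a bare Conv2d so it runs standalone;
# the sizes (224, 16, 768) are illustrative, not taken from a UniDepth config.
import torch
import torch.nn as nn

img_size, patch_size, in_chans, embed_dim = 224, 16, 3, 768
proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(2, in_chans, img_size, img_size)   # (B, C, H, W)
x = proj(x)                                        # (B, D, H/ps, W/ps) = (2, 768, 14, 14)
x = x.flatten(2).transpose(1, 2)                   # (B, N, D) = (2, 196, 768)
assert x.shape == (2, (img_size // patch_size) ** 2, embed_dim)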
from typing import Callable, Optional import torch.nn.functional as F from torch import Tensor, nn class SwiGLUFFN(nn.Module): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = None, drop: float = 0.0, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) self.w3 = nn.Linear(hidden_features, out_features, bias=bias) def forward(self, x: Tensor) -> Tensor: x12 = self.w12(x) x1, x2 = x12.chunk(2, dim=-1) hidden = F.silu(x1) * x2 return self.w3(hidden) try: from xformers.ops import SwiGLU XFORMERS_AVAILABLE = True except ImportError: SwiGLU = SwiGLUFFN XFORMERS_AVAILABLE = False class SwiGLUFFNFused(SwiGLU): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = None, drop: float = 0.0, bias: bool = True, ) -> None: out_features = out_features or in_features hidden_features = hidden_features or in_features hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 super().__init__( in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias, ) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/encoder.py ================================================ import torch import torch.nn as nn from unidepth.models.backbones import ConvNeXt, ConvNeXtV2, _make_dinov2_model class ModelWrap(nn.Module): def __init__(self, model) -> None: super().__init__() self.backbone = model def forward(self, x, *args, **kwargs): features = [] for layer in self.backbone.features: x = layer(x) features.append(x) return features def convnextv2_base(config, **kwargs): model = ConvNeXtV2( depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], output_idx=config.get("output_idx", [3, 6, 33, 36]), use_checkpoint=config.get("use_checkpoint", False), **kwargs, ) url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt" state_dict = torch.hub.load_state_dict_from_url( url, map_location="cpu", progress=False )["model"] info = model.load_state_dict(state_dict, strict=False) print(info) return model def convnextv2_large(config, **kwargs): model = ConvNeXtV2( depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], output_idx=config.get("output_idx", [3, 6, 33, 36]), use_checkpoint=config.get("use_checkpoint", False), **kwargs, ) url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt" state_dict = torch.hub.load_state_dict_from_url( url, map_location="cpu", progress=False )["model"] info = model.load_state_dict(state_dict, strict=False) print(info) return model def convnextv2_large_mae(config, **kwargs): model = ConvNeXtV2( depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], output_idx=config.get("output_idx", [3, 6, 33, 36]), use_checkpoint=config.get("use_checkpoint", False), **kwargs, ) url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt" state_dict = torch.hub.load_state_dict_from_url( url, map_location="cpu", progress=False )["model"] info = model.load_state_dict(state_dict, strict=False) print(info) return model def convnextv2_huge(config, **kwargs): model = ConvNeXtV2( depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], output_idx=config.get("output_idx", [3, 6, 33, 
36]), use_checkpoint=config.get("use_checkpoint", False), **kwargs, ) url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt" state_dict = torch.hub.load_state_dict_from_url( url, map_location="cpu", progress=False )["model"] info = model.load_state_dict(state_dict, strict=False) print(info) return model def convnextv2_huge_mae(config, **kwargs): model = ConvNeXtV2( depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], output_idx=config.get("output_idx", [3, 6, 33, 36]), use_checkpoint=config.get("use_checkpoint", False), **kwargs, ) url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt" state_dict = torch.hub.load_state_dict_from_url( url, map_location="cpu", progress=False )["model"] info = model.load_state_dict(state_dict, strict=False) print(info) return model def convnext_large_pt(config, **kwargs): model = ConvNeXt( depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], output_idx=config.get("output_idx", [3, 6, 33, 36]), use_checkpoint=config.get("use_checkpoint", False), **kwargs, ) from huggingface_hub import hf_hub_download from huggingface_hub.utils import disable_progress_bars from unidepth.models.backbones.convnext import HF_URL, checkpoint_filter_fn disable_progress_bars() repo_id, filename = HF_URL["convnext_large_pt"] state_dict = torch.load(hf_hub_download(repo_id=repo_id, filename=filename)) state_dict = checkpoint_filter_fn(state_dict, model) info = model.load_state_dict(state_dict, strict=False) print(info) return model def convnext_large(config, **kwargs): model = ConvNeXt( depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], output_idx=config.get("output_idx", [3, 6, 33, 36]), use_checkpoint=config.get("use_checkpoint", False), drop_path_rate=config.get("drop_path", 0.0), **kwargs, ) return model def dinov2_vits14(config, pretrained: bool = True, **kwargs): """ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. """ vit = _make_dinov2_model( arch_name="vit_small", pretrained=config["pretrained"], output_idx=config.get("output_idx", [3, 6, 9, 12]), checkpoint=config.get("use_checkpoint", False), drop_path_rate=config.get("drop_path", 0.0), num_register_tokens=config.get("num_register_tokens", 0), use_norm=config.get("use_norm", False), export=config.get("export", False), interpolate_offset=config.get("interpolate_offset", 0.0), **kwargs, ) return vit def dinov2_vitb14(config, pretrained: bool = True, **kwargs): """ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. """ vit = _make_dinov2_model( arch_name="vit_base", pretrained=config["pretrained"], output_idx=config.get("output_idx", [3, 6, 9, 12]), checkpoint=config.get("use_checkpoint", False), drop_path_rate=config.get("drop_path", 0.0), num_register_tokens=config.get("num_register_tokens", 0), use_norm=config.get("use_norm", False), export=config.get("export", False), interpolate_offset=config.get("interpolate_offset", 0.0), **kwargs, ) return vit def dinov2_vitl14(config, pretrained: str = "", **kwargs): """ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. 
""" vit = _make_dinov2_model( arch_name="vit_large", pretrained=config["pretrained"], output_idx=config.get("output_idx", [5, 12, 18, 24]), checkpoint=config.get("use_checkpoint", False), drop_path_rate=config.get("drop_path", 0.0), num_register_tokens=config.get("num_register_tokens", 0), use_norm=config.get("use_norm", False), export=config.get("export", False), interpolate_offset=config.get("interpolate_offset", 0.0), **kwargs, ) return vit ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv1/__init__.py ================================================ from .unidepthv1 import UniDepthV1 __all__ = [ "UniDepthV1", ] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv1/decoder.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ from typing import List, Tuple import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from timm.models.layers import trunc_normal_ from unidepth.layers import (MLP, AttentionBlock, ConvUpsample, NystromBlock, PositionEmbeddingSine) from unidepth.utils.geometric import flat_interpolate, generate_rays from unidepth.utils.misc import max_stack from unidepth.utils.sht import rsh_cart_8 class ListAdapter(nn.Module): def __init__(self, input_dims: List[int], hidden_dim: int): super().__init__() self.input_adapters = nn.ModuleList([]) self.num_chunks = len(input_dims) for input_dim in input_dims: self.input_adapters.append( nn.Sequential( nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim), nn.GELU() ) ) def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor: xs = torch.split(x, splits.int().tolist(), dim=-1) xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)] return torch.cat(xs, dim=-1) class CameraHead(nn.Module): def __init__( self, input_dim: int, hidden_dim: int, num_heads: int = 8, expansion: int = 4, depth: int = 4, dropout: float = 0.0, layer_scale: float = 1.0, **kwargs, ): super().__init__() self.aggregate = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout, layer_scale=layer_scale, ) self.latents_pos = nn.Parameter( torch.randn(1, 4, hidden_dim), requires_grad=True ) self.layers = nn.ModuleList([]) self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout) for _ in range(depth): blk = AttentionBlock( hidden_dim, num_heads=num_heads, expansion=expansion, dropout=dropout, layer_scale=layer_scale, ) self.layers.append(blk) self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1) self.cls_project = nn.Sequential( nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim // 2), nn.GELU(), nn.Linear(hidden_dim // 2, hidden_dim), ) def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor: features = features.unbind(dim=-1) cls_tokens = self.cls_project(cls_tokens) features_stack = torch.cat(features, dim=1) features_stack = features_stack + pos_embed latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1) features_stack = self.in_features(features_stack) features = torch.cat((features_stack, cls_tokens), dim=1) cls_tokens = self.aggregate(cls_tokens, context=features, pos_embed=latents_pos) for i, layer in enumerate(self.layers): cls_tokens = layer(cls_tokens, pos_embed=latents_pos) # project x = self.out(cls_tokens).squeeze(-1) camera_intrinsics = 
torch.zeros( x.shape[0], 3, 3, device=x.device, requires_grad=False ) camera_intrinsics[:, 0, 0] = x[:, 0].exp() camera_intrinsics[:, 1, 1] = x[:, 1].exp() camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid() camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid() camera_intrinsics[:, 2, 2] = 1.0 return camera_intrinsics def set_shapes(self, shapes: Tuple[int, int]): self.shapes = shapes class DepthHead(nn.Module): def __init__( self, hidden_dim: int, num_heads: int = 8, expansion: int = 4, depths: int | list[int] = 4, camera_dim: int = 256, num_resolutions: int = 4, dropout: float = 0.0, layer_scale: float = 1.0, **kwargs, ) -> None: super().__init__() if isinstance(depths, int): depths = [depths] * 3 assert len(depths) == 3 self.project_rays16 = MLP( camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim ) self.project_rays8 = MLP( camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2 ) self.project_rays4 = MLP( camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4 ) self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout) self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim) self.up8 = ConvUpsample( hidden_dim, expansion=expansion, layer_scale=layer_scale ) self.up4 = ConvUpsample( hidden_dim // 2, expansion=expansion, layer_scale=layer_scale ) self.up2 = ConvUpsample( hidden_dim // 4, expansion=expansion, layer_scale=layer_scale ) self.layers_16 = nn.ModuleList([]) self.layers_8 = nn.ModuleList([]) self.layers_4 = nn.ModuleList([]) self.aggregate_16 = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout, layer_scale=layer_scale, context_dim=hidden_dim, ) self.prompt_camera = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout, layer_scale=layer_scale, context_dim=hidden_dim, ) for i, (blk_lst, depth) in enumerate( zip([self.layers_16, self.layers_8, self.layers_4], depths) ): attn_cls = AttentionBlock if i == 0 else NystromBlock for _ in range(depth): blk_lst.append( attn_cls( hidden_dim // (2**i), num_heads=num_heads // (2**i), expansion=expansion, dropout=dropout, layer_scale=layer_scale, ) ) self.out2 = nn.Conv2d(hidden_dim // 8, 1, 3, padding=1) self.out4 = nn.Conv2d(hidden_dim // 4, 1, 3, padding=1) self.out8 = nn.Conv2d(hidden_dim // 2, 1, 3, padding=1) def set_original_shapes(self, shapes: Tuple[int, int]): self.original_shapes = shapes def set_shapes(self, shapes: Tuple[int, int]): self.shapes = shapes def forward( self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed ) -> torch.Tensor: features = features.unbind(dim=-1) shapes = self.shapes rays_hr = rays_hr.detach() # camera_embedding rays_embedding_16 = F.normalize( flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1 ) rays_embedding_8 = F.normalize( flat_interpolate( rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes] ), dim=-1, ) rays_embedding_4 = F.normalize( flat_interpolate( rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes] ), dim=-1, ) rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16)) rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8)) rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4)) features_tokens = torch.cat(features, dim=1) features_tokens_pos = pos_embed + level_embed # Generate latents with init as pooled features features_channels = torch.cat(features, dim=-1) features_16 = self.features_channel_cat(features_channels) latents_16 = self.to_latents( 
flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False) ) # Aggregate features: F -> D latents_16 = self.aggregate_16( latents_16, context=features_tokens, pos_embed_context=features_tokens_pos ) # Aggregate camera: D- > D|E latents_16 = self.prompt_camera(latents_16, context=rays_embedding_16) # Block 16 - Out 8 for layer in self.layers_16: latents_16 = layer(latents_16, pos_embed=rays_embedding_16) latents_8 = self.up8( rearrange( latents_16 + rays_embedding_16, "b (h w) c -> b c h w", h=shapes[0], w=shapes[1], ).contiguous() ) out8 = self.out8( rearrange( latents_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2 ) ) # Block 8 - Out 4 for layer in self.layers_8: latents_8 = layer(latents_8, pos_embed=rays_embedding_8) latents_4 = self.up4( rearrange( latents_8 + rays_embedding_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2, ).contiguous() ) out4 = self.out4( rearrange( latents_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4 ) ) # Block 4 - Out 2 for layer in self.layers_4: latents_4 = layer(latents_4, pos_embed=rays_embedding_4) latents_2 = self.up2( rearrange( latents_4 + rays_embedding_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4, ).contiguous() ) out2 = self.out2( rearrange( latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8 ) ) # Depth features proj_latents_16 = rearrange( latents_16, "b (h w) c -> b c h w", h=shapes[0], w=shapes[1] ).contiguous() # MS Outputs out2 = out2.clamp(-10.0, 10.0).exp() out4 = out4.clamp(-10.0, 10.0).exp() out8 = out8.clamp(-10.0, 10.0).exp() return out8, out4, out2, proj_latents_16 class Decoder(nn.Module): def __init__( self, config, *args, **kwargs, ): super().__init__() self.build(config) self.apply(self._init_weights) self.test_fixed_camera = False self.skip_camera = False def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Conv2d): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def get_adapted_features(self, features_flat, splits): features_flat_cat = torch.cat(features_flat, dim=-1) features_projected = self.input_adapter( features_flat_cat, splits ) # list [b hw c] shapes features = torch.chunk(features_projected, len(splits), dim=-1) return features def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays): # get cls tokens projections cls_tokens_splits = torch.tensor( [x.shape[-1] for x in cls_tokens], device=features.device, requires_grad=False, dtype=features.dtype, ) cls_tokens = torch.cat(cls_tokens, dim=-1) cls_tokens = self.token_adapter(cls_tokens, cls_tokens_splits) cls_tokens = torch.cat( torch.chunk(cls_tokens, len(cls_tokens_splits), dim=-1), dim=1 ) # camera layer intrinsics = self.camera_layer( features=features, cls_tokens=cls_tokens, pos_embed=pos_embed ) intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0] intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1] intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1] intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0] if not self.test_fixed_camera: rays, _ = generate_rays(intrinsics, original_shapes, noisy=False) return intrinsics, rays def forward(self, inputs, image_metas) -> torch.Tensor: B, _, H, W = inputs["image"].shape device = inputs["image"].device # make stride happy? 
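# Hypothetical standalone sketch of the intrinsics de-normalization performed
# by Decoder.run_camera above: CameraHead predicts focal lengths as exp()
# factors of max(H, W) / 2 and the principal point as sigmoid() values in
# [0, 1]; they are mapped back to pixel units like this (numbers are made up).
import torch

H, W = 480, 640
x = torch.tensor([[0.1, 0.2, 0.5, 0.5]])        # raw head outputs: fx, fy, cx, cy

K = torch.zeros(x.shape[0], 3, 3)
K[:, 0, 0] = x[:, 0].exp() * max(H, W) / 2      # fx in pixels
K[:, 1, 1] = x[:, 1].exp() * max(H, W) / 2      # fy in pixels
K[:, 0, 2] = x[:, 2].sigmoid() * W              # cx in pixels
K[:, 1, 2] = x[:, 3].sigmoid() * H              # cy in pixels
K[:, 2, 2] = 1.0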
original_encoder_outputs = [x.contiguous() for x in inputs["encoder_outputs"]] cls_tokens = [x.contiguous() for x in inputs["cls_tokens"]] # collect features and tokens original_encoder_outputs = [ max_stack(original_encoder_outputs[i:j]) for i, j in self.slices_encoder_range ] # detach tokens for camera cls_tokens = [ cls_tokens[-i - 1].detach() for i in range(len(self.slices_encoder_range)) ] # get features in b n d format # level shapes, the shape per level, for swin like [[128, 128], [64, 64],...], for vit [[32,32]] -> mult times resolutions resolutions = [ tuple(sorted([x.shape[1], x.shape[2]])) for x in original_encoder_outputs ] level_shapes = sorted(list(set(resolutions)))[::-1] if len(level_shapes) == 1: level_shapes = level_shapes * self.num_resolutions input_shapes = [ level_shapes[i] for i, (start, end) in enumerate(self.slices_encoder) for _ in range(end - start) ] common_shape = level_shapes[-2] # input shapes repeat shapes for each level, times the amount of the layers: features_flat = [ flat_interpolate( rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape ) for x, input_shape in zip(original_encoder_outputs, input_shapes) ] features_splits = torch.tensor( [x.shape[-1] for x in features_flat], device=device, requires_grad=False, dtype=torch.float32, ) # input adapter, then do mean of features in same blocks features = self.get_adapted_features(features_flat, features_splits) features = torch.stack(features, dim=-1) # positional embeddings, spatial and level level_embed = torch.cat( [ self.level_embed_layer(self.level_embeds)[i : i + 1] .unsqueeze(0) .repeat(B, common_shape[0] * common_shape[1], 1) for i in range(self.num_resolutions) ], dim=1, ) pos_embed = self.pos_embed( torch.zeros( B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False, ) ) pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat( 1, self.num_resolutions, 1 ) self.camera_layer.set_shapes(common_shape) intrinsics, rays = ( self.run_camera( cls_tokens, features=features, pos_embed=pos_embed + level_embed, original_shapes=(H, W), rays=inputs.get("rays", None), ) if not self.skip_camera else (inputs["K"], inputs["rays"]) ) # run bulk of the model self.depth_layer.set_shapes(common_shape) self.depth_layer.set_original_shapes((H, W)) out8, out4, out2, depth_features = self.depth_layer( features=features, rays_hr=rays, pos_embed=pos_embed, level_embed=level_embed, ) return intrinsics, [out8, out4, out2], depth_features @torch.jit.ignore def no_weight_decay_keywords(self): return {"latents_pos", "level_embeds"} def build(self, config): depth = config["model"]["pixel_decoder"]["depths"] input_dims = config["model"]["pixel_encoder"]["embed_dims"] hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"] num_heads = config["model"]["num_heads"] expansion = config["model"]["expansion"] dropout = config["model"]["pixel_decoder"]["dropout"] depths_encoder = config["model"]["pixel_encoder"]["depths"] layer_scale = 1.0 self.depth = depth self.dim = hidden_dim self.downsample = 4 self.num_heads = num_heads self.num_resolutions = len(depths_encoder) self.depths_encoder = depths_encoder self.slices_encoder_single = list( zip([d - 1 for d in self.depths_encoder], self.depths_encoder) ) self.slices_encoder_range = list( zip([0, *self.depths_encoder[:-1]], self.depths_encoder) ) cls_token_input_dims = [input_dims[-i - 1] for i in range(len(depths_encoder))] input_dims = [input_dims[d - 1] for d in depths_encoder] self.slices_encoder = self.slices_encoder_single # adapt from 
encoder features, just project self.input_adapter = ListAdapter(input_dims, hidden_dim) self.token_adapter = ListAdapter(cls_token_input_dims, hidden_dim) # camera layer self.camera_layer = CameraHead( input_dim=hidden_dim, hidden_dim=hidden_dim, num_heads=num_heads, expansion=expansion, depth=2, dropout=dropout, layer_scale=layer_scale, ) self.depth_layer = DepthHead( hidden_dim=hidden_dim, num_heads=num_heads, expansion=expansion, depths=depth, dropout=dropout, camera_dim=81, num_resolutions=self.num_resolutions, layer_scale=layer_scale, ) # transformer part self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True) self.level_embeds = nn.Parameter( torch.randn(len(input_dims), hidden_dim), requires_grad=True ) self.level_embed_layer = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), ) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv1/unidepthv1.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import importlib from copy import deepcopy from math import ceil import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms.functional as TF from einops import rearrange from huggingface_hub import PyTorchModelHubMixin from unidepth.models.unidepthv1.decoder import Decoder from unidepth.utils.constants import (IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD) from unidepth.utils.distributed import is_main_process from unidepth.utils.geometric import (generate_rays, spherical_zbuffer_to_euclidean) from unidepth.utils.misc import (get_params, match_gt, match_intrinsics, profile_method) VERBOSE = False # inference helpers def _paddings(image_shape, network_shape): cur_h, cur_w = image_shape h, w = network_shape pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2 pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2 return pad_left, pad_right, pad_top, pad_bottom def _shapes(image_shape, network_shape): h, w = image_shape input_ratio = w / h output_ratio = network_shape[1] / network_shape[0] if output_ratio > input_ratio: ratio = network_shape[0] / h elif output_ratio <= input_ratio: ratio = network_shape[1] / w return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes): (pad_left, pad_right, pad_top, pad_bottom) = pads rgbs = F.interpolate( rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True ) rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant") if intrinsics is not None: intrinsics = intrinsics.clone() intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top return rgbs, intrinsics return rgbs, None def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes): (pad_left, pad_right, pad_top, pad_bottom) = pads # pred mean, trim paddings, and upsample to input dim predictions = sum( [ F.interpolate( x.clone(), size=shapes, mode="bilinear", align_corners=False, antialias=True, ) for x in predictions ] ) / len(predictions) predictions = predictions[ ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right ] predictions = F.interpolate( 
predictions, size=original_shapes, mode="bilinear", align_corners=False, antialias=True, ) intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio return predictions, intrinsics class UniDepthV1( nn.Module, PyTorchModelHubMixin, library_name="UniDepth", repo_url="https://github.com/lpiccinelli-eth/UniDepth", tags=["monocular-metric-depth-estimation"], ): def __init__( self, config, eps: float = 1e-6, **kwargs, ): super().__init__() self.build(config) self.build_losses(config) self.eps = eps @profile_method(verbose=VERBOSE) def forward_train(self, inputs, image_metas): inputs, outputs = self.encode_decode(inputs, image_metas) losses = self.compute_losses(outputs, inputs, image_metas) return outputs, losses @profile_method(verbose=VERBOSE) def forward_test(self, inputs, image_metas): inputs, outputs = self.encode_decode(inputs, image_metas) depth_gt = inputs["depth"] test_outputs = {} test_outputs["depth"] = match_gt( outputs["depth"], depth_gt, padding1=inputs["paddings"], padding2=None ) test_outputs["points"] = match_gt( outputs["points"], depth_gt, padding1=inputs["paddings"], padding2=None ) test_outputs["confidence"] = match_gt( outputs["confidence"], depth_gt, padding1=inputs["paddings"], padding2=None ) test_outputs["rays"] = match_gt( outputs["rays"], depth_gt, padding1=inputs["paddings"], padding2=None ) test_outputs["rays"] = outputs["rays"] / torch.norm( outputs["rays"], dim=1, keepdim=True ).clip(min=1e-5) test_outputs["intrinsics"] = match_intrinsics( outputs["intrinsics"], inputs["image"], depth_gt, padding1=inputs["paddings"], padding2=None, ) return test_outputs def forward(self, inputs, image_metas): if self.training: return self.forward_train(inputs, image_metas) else: return self.forward_test(inputs, image_metas) def encode_decode(self, inputs, image_metas): rgbs = inputs["image"] B, _, H, W = rgbs.shape cameras = inputs["camera"] # shortcut eval should avoid errors if len(image_metas) and "paddings" in image_metas[0]: inputs["paddings"] = torch.tensor( [image_meta["paddings"] for image_meta in image_metas], device=self.device, )[ ..., [0, 2, 1, 3] ] # lrtb inputs["depth_paddings"] = torch.tensor( [image_meta["depth_paddings"] for image_meta in image_metas], device=self.device, ) if ( self.training ): # at inference we do not have image paddings on top of depth ones (we have not "crop" on gt in ContextCrop) inputs["depth_paddings"] = inputs["depth_paddings"] + inputs["paddings"] # Get camera rays for supervision, all in unit sphere if inputs.get("camera", None) is not None: inputs["rays"] = rearrange( inputs["camera"].get_rays(shapes=(B, H, W)), "b c h w -> b (h w) c" ) # Encode encoder_outputs, cls_tokens = self.pixel_encoder(rgbs) if "dino" in self.pixel_encoder.__class__.__name__.lower(): encoder_outputs = [ (x + y.unsqueeze(1)).contiguous() for x, y in zip(encoder_outputs, cls_tokens) ] inputs["encoder_outputs"] = encoder_outputs inputs["cls_tokens"] = cls_tokens # Decode pred_intrinsics, predictions, depth_features = self.pixel_decoder(inputs, {}) predictions = sum( [ F.interpolate( x.clone(), size=(H, W), mode="bilinear", align_corners=False, antialias=True, ) for x in predictions ] ) / len(predictions) # Final 3D points backprojection pred_rays, pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False) # You may want to use inputs["angles"] if available? 
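# Hypothetical mini-sketch of the multi-scale fusion used just above in
# encode_decode: each decoder output (out8, out4, out2) is resized to the
# input resolution and the per-pixel mean is taken as the fused prediction.
# Shapes here are arbitrary dummies.
import torch
import torch.nn.functional as F

H, W = 64, 96
preds = [torch.rand(1, 1, H // s, W // s) for s in (8, 4, 2)]
fused = sum(
    F.interpolate(p, size=(H, W), mode="bilinear", align_corners=False, antialias=True)
    for p in preds
) / len(preds)
assert fused.shape == (1, 1, H, W)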
pred_angles = rearrange(pred_angles, "b (h w) c -> b c h w", h=H, w=W) points_3d = torch.cat((pred_angles, predictions), dim=1) points_3d = spherical_zbuffer_to_euclidean( points_3d.permute(0, 2, 3, 1) ).permute(0, 3, 1, 2) # Output data, use for loss computation outputs = { "angles": pred_angles, "rays": pred_rays, "intrinsics": pred_intrinsics, "points": points_3d, "depth": predictions[:, -1:], "cond_features": depth_features, } self.pixel_decoder.test_fixed_camera = False outputs["rays"] = rearrange(outputs["rays"], "b (h w) c -> b c h w", h=H, w=W) if "rays" in inputs: inputs["rays"] = rearrange(inputs["rays"], "b (h w) c -> b c h w", h=H, w=W) return inputs, outputs def compute_losses(self, outputs, inputs, image_metas): B, _, H, W = inputs["image"].shape losses = {"opt": {}, "stat": {}} if ( not self.training ): # only compute losses during training, avoid issues for mismatch size of pred and GT return losses losses_to_be_computed = list(self.losses.keys()) # depth loss si = torch.tensor( [x.get("si", False) for x in image_metas], device=self.device ).reshape(B) loss = self.losses["depth"] depth_losses = loss( outputs["depth"], target=inputs["depth"], mask=inputs["depth_mask"].clone(), si=si, ) losses["opt"][loss.name] = loss.weight * depth_losses.mean() losses_to_be_computed.remove("depth") # camera loss, here we apply to rays for simplicity # in the original training was on angles # however, we saw no difference (see supplementary) loss = self.losses["camera"] camera_losses = loss(outputs["rays"], target=inputs["rays"]) losses["opt"][loss.name] = loss.weight * camera_losses.mean() losses_to_be_computed.remove("camera") # invariance loss flips = torch.tensor( [x.get("flip", False) for x in image_metas], device=self.device ).reshape(B) loss = self.losses["invariance"] invariance_losses = loss( outputs["cond_features"], intrinsics=inputs["camera"].K, mask=inputs["depth_mask"], flips=flips, ) losses["opt"][loss.name] = loss.weight * invariance_losses.mean() losses_to_be_computed.remove("invariance") # remaining losses, we expect no more losses to be computed assert ( not losses_to_be_computed ), f"Losses {losses_to_be_computed} not computed, revise `compute_loss` method" return losses @torch.no_grad() def infer(self, rgbs: torch.Tensor, intrinsics=None, skip_camera=False): if rgbs.ndim == 3: rgbs = rgbs.unsqueeze(0) if intrinsics is not None and intrinsics.ndim == 2: intrinsics = intrinsics.unsqueeze(0) B, _, H, W = rgbs.shape rgbs = rgbs.to(self.device) if intrinsics is not None: intrinsics = intrinsics.to(self.device) # process image and intrinsiscs (if any) to match network input (slow?) 
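# Hypothetical sketch of the input preparation done right below in
# UniDepthV1.infer: a uint8 image becomes a float tensor in [0, 1] and is
# normalized with ImageNet statistics (assumed here to be the usual
# (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225) values behind
# IMAGENET_DATASET_MEAN / IMAGENET_DATASET_STD).
import torch
import torchvision.transforms.functional as TF

rgb_uint8 = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # dummy CHW image
rgb = rgb_uint8.to(torch.float32).div(255)
rgb = TF.normalize(rgb, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
rgbs = rgb.unsqueeze(0)  # add the batch dimension, as infer() does for 3-D inputs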
if rgbs.max() > 5 or rgbs.dtype == torch.uint8: rgbs = rgbs.to(torch.float32).div(255) if rgbs.min() >= 0.0 and rgbs.max() <= 1.0: rgbs = TF.normalize( rgbs, mean=IMAGENET_DATASET_MEAN, std=IMAGENET_DATASET_STD, ) (h, w), ratio = _shapes((H, W), self.image_shape) pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), self.image_shape) rgbs, gt_intrinsics = _preprocess( rgbs, intrinsics, (h, w), (pad_left, pad_right, pad_top, pad_bottom), ratio, self.image_shape, ) # run encoder encoder_outputs, cls_tokens = self.pixel_encoder(rgbs) if "dino" in self.pixel_encoder.__class__.__name__.lower(): encoder_outputs = [ (x + y.unsqueeze(1)).contiguous() for x, y in zip(encoder_outputs, cls_tokens) ] # get data for decoder and adapt to given camera inputs = {} inputs["encoder_outputs"] = encoder_outputs inputs["cls_tokens"] = cls_tokens inputs["image"] = rgbs if gt_intrinsics is not None: rays, angles = generate_rays( gt_intrinsics, self.image_shape, noisy=self.training ) inputs["rays"] = rays inputs["angles"] = angles inputs["K"] = gt_intrinsics self.pixel_decoder.test_fixed_camera = True self.pixel_decoder.skip_camera = skip_camera # decode all pred_intrinsics, predictions, _ = self.pixel_decoder(inputs, {}) # undo the reshaping and get original image size (slow) predictions, pred_intrinsics = _postprocess( predictions, pred_intrinsics, self.image_shape, (pad_left, pad_right, pad_top, pad_bottom), ratio, (H, W), ) # final 3D points backprojection intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics angles = generate_rays(intrinsics, (H, W), noisy=False)[-1] angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W) points_3d = torch.cat((angles, predictions), dim=1) points_3d = spherical_zbuffer_to_euclidean( points_3d.permute(0, 2, 3, 1) ).permute(0, 3, 1, 2) # output data outputs = { "intrinsics": pred_intrinsics, "points": points_3d, "depth": predictions[:, -1:], } self.pixel_decoder.test_fixed_camera = False self.pixel_decoder.skip_camera = False return outputs def load_pretrained(self, model_file): device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) dict_model = torch.load(model_file, map_location=device) if "model" in dict_model: dict_model = dict_model["model"] new_state_dict = deepcopy( {k.replace("module.", ""): v for k, v in dict_model.items()} ) info = self.load_state_dict(new_state_dict, strict=False) if is_main_process(): print( f"Loaded from {model_file} for {self.__class__.__name__} results in:", info, ) def get_params(self, config): if hasattr(self.pixel_encoder, "get_params"): encoder_p, encoder_lr = self.pixel_encoder.get_params( config["model"]["pixel_encoder"]["lr"], config["training"]["wd"], config["training"]["ld"], ) else: encoder_p, encoder_lr = get_params( self.pixel_encoder, config["model"]["pixel_encoder"]["lr"], config["training"]["wd"], ) decoder_p, decoder_lr = get_params( self.pixel_decoder, config["training"]["lr"], config["training"]["wd"] ) return [*encoder_p, *decoder_p] @property def device(self): return next(self.parameters()).device def build(self, config): mod = importlib.import_module("unidepth.models.encoder") pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"]) pixel_encoder_config = { **config["training"], **config["data"], **config["model"]["pixel_encoder"], "interpolate_offset": 0.1, } pixel_encoder = pixel_encoder_factory(pixel_encoder_config) config["model"]["pixel_encoder"]["patch_size"] = ( 14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16 ) 
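# Worked example (hypothetical sizes) of the resize-and-pad arithmetic in
# _shapes and _paddings above, which infer() uses to fit an arbitrary image
# into the fixed network input shape while keeping its aspect ratio.
from math import ceil

H, W = 720, 1280          # original image
net_h, net_w = 462, 616   # assumed network input shape

ratio = net_w / W if (net_w / net_h) <= (W / H) else net_h / H        # 0.48125
h, w = ceil(H * ratio - 0.5), ceil(W * ratio - 0.5)                   # 346, 616
pad_top, pad_bottom = (net_h - h) // 2, net_h - h - (net_h - h) // 2  # 58, 58
pad_left, pad_right = (net_w - w) // 2, net_w - w - (net_w - w) // 2  # 0, 0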
pixel_encoder_embed_dims = ( pixel_encoder.embed_dims if hasattr(pixel_encoder, "embed_dims") else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)] ) config["model"]["pixel_encoder"]["embed_dim"] = getattr( pixel_encoder, "embed_dim" ) config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths self.pixel_encoder = pixel_encoder self.pixel_decoder = Decoder(config) self.image_shape = config["data"]["image_shape"] def build_losses(self, config): self.losses = {} for loss_name, loss_config in config["training"].get("losses", {}).items(): mod = importlib.import_module("unidepth.ops.losses") loss_factory = getattr(mod, loss_config["name"]) self.losses[loss_name] = loss_factory.build(loss_config) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/__init__.py ================================================ from .unidepthv2 import UniDepthV2 from .unidepthv2_old import UniDepthV2old __all__ = [ "UniDepthV2", "UniDepthV2old", ] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/decoder.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from timm.models.layers import trunc_normal_ from unidepth.layers import (MLP, AttentionBlock, AttentionLayer, PositionEmbeddingSine, ResUpsampleBil) from unidepth.utils.coordinate import coords_grid from unidepth.utils.geometric import flat_interpolate from unidepth.utils.positional_embedding import generate_fourier_features def orthonormal_init(num_tokens, dims): pe = torch.randn(num_tokens, dims) # Apply Gram-Schmidt process to make the matrix orthonormal for i in range(num_tokens): for j in range(i): # Subtract the projection of current row onto previous row pe[i] -= torch.dot(pe[i], pe[j]) * pe[j] # Normalize the current row pe[i] = F.normalize(pe[i], p=2, dim=0) return pe class ListAdapter(nn.Module): def __init__(self, input_dims: list[int], hidden_dim: int): super().__init__() self.input_adapters = nn.ModuleList([]) self.num_chunks = len(input_dims) for input_dim in input_dims: self.input_adapters.append(nn.Linear(input_dim, hidden_dim)) def forward(self, xs: torch.Tensor) -> list[torch.Tensor]: outs = [self.input_adapters[i](x) for i, x in enumerate(xs)] return outs class CameraHead(nn.Module): def __init__( self, hidden_dim: int, num_heads: int = 8, expansion: int = 4, dropout: float = 0.0, layer_scale: float = 1.0, **kwargs, ): super().__init__() self.num_params = 4 self.aggregate1 = AttentionBlock( hidden_dim, num_heads=num_heads, expansion=expansion, dropout=dropout, layer_scale=layer_scale, use_bias=False, ) self.aggregate2 = AttentionBlock( hidden_dim, num_heads=num_heads, expansion=expansion, dropout=dropout, layer_scale=layer_scale, use_bias=False, ) self.latents_pos = nn.Parameter( torch.randn(1, self.num_params, hidden_dim), requires_grad=True ) self.project = MLP( hidden_dim, expansion=1, dropout=dropout, output_dim=hidden_dim ) self.out_pinhole = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=1) def fill_intrinsics(self, x): fx, fy, cx, cy = x.unbind(dim=-1) fx = torch.exp(fx) fy = torch.exp(fy) cx = torch.sigmoid(cx) cy = torch.sigmoid(cy) diagonal = (self.shapes[0] ** 
2 + self.shapes[1] ** 2) ** 0.5 correction_tensor = torch.tensor( [0.7 * diagonal, 0.7 * diagonal, self.shapes[1], self.shapes[0]], device=x.device, dtype=x.dtype, ) intrinsics = torch.stack([fx, fy, cx, cy], dim=1) intrinsics = correction_tensor.unsqueeze(0) * intrinsics return intrinsics def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor: features = features.unbind(dim=-1) tokens = self.project(cls_tokens) latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1) tokens = self.aggregate1(tokens, pos_embed=latents_pos) tokens = self.aggregate2(tokens, pos_embed=latents_pos) x = self.out_pinhole(tokens.clone()).squeeze(-1) camera_intrinsics = self.fill_intrinsics(x) return camera_intrinsics def set_shapes(self, shapes: tuple[int, int]): self.shapes = shapes class DepthHead(nn.Module): def __init__( self, hidden_dim: int, num_heads: int = 8, expansion: int = 4, depths: int | list[int] = 4, camera_dim: int = 256, dropout: float = 0.0, kernel_size: int = 7, layer_scale: float = 1.0, out_dim: int = 1, use_norm=False, num_prompt_blocks=1, **kwargs, ) -> None: super().__init__() self.camera_dim = camera_dim self.out_dim = out_dim self.hidden_dim = hidden_dim self.ups = nn.ModuleList([]) self.depth_mlp = nn.ModuleList([]) self.process_features = nn.ModuleList([]) self.project_features = nn.ModuleList([]) self.prompt_camera = nn.ModuleList([]) mult = 2 self.to_latents = nn.Linear(hidden_dim, hidden_dim) for _ in range(4): self.prompt_camera.append( AttentionLayer( num_blocks=num_prompt_blocks, dim=hidden_dim, num_heads=num_heads, expansion=expansion, dropout=dropout, layer_scale=-1.0, context_dim=hidden_dim, use_bias=False, ) ) for i, depth in enumerate(depths): current_dim = min(hidden_dim, mult * hidden_dim // int(2**i)) next_dim = mult * hidden_dim // int(2 ** (i + 1)) output_dim = max(next_dim, out_dim) self.process_features.append( nn.ConvTranspose2d( hidden_dim, current_dim, kernel_size=max(1, 2 * i), stride=max(1, 2 * i), padding=0, ) ) self.ups.append( ResUpsampleBil( current_dim, output_dim=output_dim, expansion=expansion, layer_scale=layer_scale, kernel_size=kernel_size, num_layers=depth, use_norm=use_norm, ) ) depth_mlp = nn.Identity() if i == len(depths) - 1: depth_mlp = nn.Sequential( nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim) ) self.depth_mlp.append(depth_mlp) self.confidence_mlp = nn.Sequential( nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim) ) self.to_depth_lr = nn.Conv2d( output_dim, output_dim // 2, kernel_size=3, padding=1, padding_mode="reflect", ) self.to_confidence_lr = nn.Conv2d( output_dim, output_dim // 2, kernel_size=3, padding=1, padding_mode="reflect", ) self.to_depth_hr = nn.Sequential( nn.Conv2d( output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect" ), nn.LeakyReLU(), nn.Conv2d(32, 1, kernel_size=1), ) self.to_confidence_hr = nn.Sequential( nn.Conv2d( output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect" ), nn.LeakyReLU(), nn.Conv2d(32, 1, kernel_size=1), ) def set_original_shapes(self, shapes: tuple[int, int]): self.original_shapes = shapes def set_shapes(self, shapes: tuple[int, int]): self.shapes = shapes def embed_rays(self, rays): rays_embedding = flat_interpolate( rays, old=self.original_shapes, new=self.shapes, antialias=True ) rays_embedding = rays_embedding / torch.norm( rays_embedding, dim=-1, keepdim=True ).clip(min=1e-4) x, y, z = rays_embedding[..., 0], rays_embedding[..., 1], rays_embedding[..., 2] polar = torch.acos(z) x_clipped = x.abs().clip(min=1e-3) * (2 * (x >= 0).int() - 
1) azimuth = torch.atan2(y, x_clipped) rays_embedding = torch.stack([polar, azimuth], dim=-1) rays_embedding = generate_fourier_features( rays_embedding, dim=self.hidden_dim, max_freq=max(self.shapes) // 2, use_log=True, cat_orig=False, ) return rays_embedding def condition(self, feat, rays_embeddings): conditioned_features = [ prompter(rearrange(feature, "b h w c -> b (h w) c"), rays_embeddings) for prompter, feature in zip(self.prompt_camera, feat) ] return conditioned_features def process(self, features_list, rays_embeddings): conditioned_features = self.condition(features_list, rays_embeddings) init_latents = self.to_latents(conditioned_features[0]) init_latents = rearrange( init_latents, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1] ).contiguous() conditioned_features = [ rearrange( x, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1] ).contiguous() for x in conditioned_features ] latents = init_latents out_features = [] for i, up in enumerate(self.ups): latents = latents + self.process_features[i](conditioned_features[i + 1]) latents = up(latents) out_features.append(latents) return out_features, init_latents def depth_proj(self, out_features): h_out, w_out = out_features[-1].shape[-2:] # aggregate output and project to depth for i, (layer, features) in enumerate(zip(self.depth_mlp, out_features)): out_depth_features = layer(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) out_depth_features = F.interpolate( out_depth_features, size=(h_out, w_out), mode="bilinear", align_corners=True, ) if i == len(self.depth_mlp) - 1: logdepth = out_depth_features logdepth = self.to_depth_lr(logdepth) logdepth = F.interpolate( logdepth, size=self.original_shapes, mode="bilinear", align_corners=True ) logdepth = self.to_depth_hr(logdepth) return logdepth def confidence_proj(self, out_features): highres_features = out_features[-1].permute(0, 2, 3, 1) confidence = self.confidence_mlp(highres_features).permute(0, 3, 1, 2) confidence = self.to_confidence_lr(confidence) confidence = F.interpolate( confidence, size=self.original_shapes, mode="bilinear", align_corners=True ) confidence = self.to_confidence_hr(confidence) return confidence def decode(self, out_features): logdepth = self.depth_proj(out_features) confidence = self.confidence_proj(out_features) return logdepth, confidence def forward( self, features: list[torch.Tensor], rays_hr: torch.Tensor, pos_embed, level_embed, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B = features[0].shape[0] rays_embeddings = self.embed_rays(rays_hr) features, proj_latents_16 = self.process(features, rays_embeddings) logdepth, logconf = self.decode(features) return logdepth, logconf, proj_latents_16 class Decoder(nn.Module): def __init__( self, config, ): super().__init__() self.build(config) self.apply(self._init_weights) self.test_gt_camera = False def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Conv2d): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): if m.bias is not None: nn.init.constant_(m.bias, 0) if m.weight is not None: nn.init.constant_(m.weight, 1.0) def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays_gt): H, W = original_shapes # camera layer intrinsics = self.camera_layer( features=features, cls_tokens=cls_tokens, pos_embed=pos_embed ) B, N = intrinsics.shape device = intrinsics.device dtype = intrinsics.dtype 
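# Hypothetical sketch of the unprojection performed just below in run_camera:
# pixel coordinates are mapped through the inverse intrinsics,
# (u - cx) / fx and (v - cy) / fy, then normalized to unit-length rays.
# Intrinsics and image size here are made-up values.
import torch

fx, fy, cx, cy = 500.0, 500.0, 320.0, 240.0
H, W = 480, 640

ys, xs = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                        torch.arange(W, dtype=torch.float32), indexing="ij")
rays = torch.stack(((xs - cx) / fx, (ys - cy) / fy, torch.ones_like(xs)), dim=0)
rays = rays / rays.norm(dim=0, keepdim=True).clamp(min=1e-5)  # unit rays, shape (3, H, W)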
id_coords = coords_grid(B, H, W, device=features.device, homogeneous=True) intrinsics_matrix_inverse = torch.eye(3, device=device, dtype=dtype).repeat( B, 1, 1 ) intrinsics_matrix_inverse[:, 0, 0] = 1.0 / intrinsics[:, 0] intrinsics_matrix_inverse[:, 1, 1] = 1.0 / intrinsics[:, 1] intrinsics_matrix_inverse[:, 0, 2] = -intrinsics[:, 2] / intrinsics[:, 0] intrinsics_matrix_inverse[:, 1, 2] = -intrinsics[:, 3] / intrinsics[:, 1] intrinsics_matrix = torch.eye(3, device=device, dtype=dtype).repeat(B, 1, 1) intrinsics_matrix[:, 0, 0] = intrinsics[:, 0] intrinsics_matrix[:, 1, 1] = intrinsics[:, 1] intrinsics_matrix[:, 0, 2] = intrinsics[:, 2] intrinsics_matrix[:, 1, 2] = intrinsics[:, 3] rays_pred = intrinsics_matrix_inverse @ id_coords.reshape(B, 3, -1) rays_pred = rays_pred.reshape(B, 3, H, W) rays_pred = rays_pred / torch.norm(rays_pred, dim=1, keepdim=True).clamp( min=1e-5 ) ### LEGACY CODE FOR TRAINING # if self.training and rays_gt is not None: # prob = -1.0 # 0.8 * (1 - tanh(self.steps / 100000)) + 0.2 # where_use_gt_rays = torch.rand(B, 1, 1, device=device, dtype=dtype) < prob # where_use_gt_rays = where_use_gt_rays.int() # rays = rays_gt * where_use_gt_rays + rays_pred * (1 - where_use_gt_rays) rays = rays_pred if rays_gt is None else rays_gt rays = rearrange(rays, "b c h w -> b (h w) c") return intrinsics_matrix, rays def forward( self, inputs: dict[str, torch.Tensor], image_metas: list[dict[str, torch.Tensor]], ) -> dict[str, torch.Tensor]: B, C, H, W = inputs["image"].shape device = inputs["image"].device dtype = inputs["features"][0].dtype # get features in b n d format common_shape = inputs["features"][0].shape[1:3] # input shapes repeat shapes for each level, times the amount of the layers: features = self.input_adapter(inputs["features"]) # positional embeddings, spatial and level level_embed = self.level_embeds.repeat( B, common_shape[0] * common_shape[1], 1, 1 ) level_embed = rearrange(level_embed, "b n l d -> b (n l) d") dummy_tensor = torch.zeros( B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False ) pos_embed = self.pos_embed(dummy_tensor) pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat( 1, self.num_resolutions, 1 ) # get cls tokens projections camera_tokens = inputs["tokens"] camera_tokens = self.camera_token_adapter(camera_tokens) self.camera_layer.set_shapes((H, W)) intrinsics, rays = self.run_camera( torch.cat(camera_tokens, dim=1), features=torch.stack(features, dim=-1).detach(), pos_embed=(pos_embed + level_embed).detach(), original_shapes=(H, W), rays_gt=inputs.get("rays", None), ) # run bulk of the model self.depth_layer.set_shapes(common_shape) self.depth_layer.set_original_shapes((H, W)) logdepth, logconfidence, depth_features = self.depth_layer( features=features, rays_hr=rays, pos_embed=pos_embed, level_embed=level_embed, ) return { "radius": torch.exp(logdepth.clip(min=-8.0, max=8.0) + 2.0), "depth_features": depth_features, "confidence": torch.exp(logconfidence.clip(min=-8.0, max=8.0)), "intrinsics": intrinsics, "rays": rays, } @torch.jit.ignore def no_weight_decay_keywords(self): return {"latents_pos", "level_embeds"} def build(self, config): input_dims = config["model"]["pixel_encoder"]["embed_dims"] hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"] expansion = config["model"]["expansion"] num_heads = config["model"]["num_heads"] dropout = config["model"]["pixel_decoder"]["dropout"] depths_encoder = config["model"]["pixel_encoder"]["depths"] layer_scale = config["model"]["layer_scale"] depth = 
config["model"]["pixel_decoder"]["depths"] self.downsample = 4 depths_encoder = config["model"]["pixel_encoder"]["depths"] self.num_resolutions = len(depths_encoder) self.test_fixed_camera = False out_dim = config["model"]["pixel_decoder"]["out_dim"] kernel_size = config["model"]["pixel_decoder"].get("kernel_size", 7) self.slices_encoder = list(zip([d - 1 for d in depths_encoder], depths_encoder)) input_dims = [input_dims[d - 1] for d in depths_encoder] # # adapt from encoder features, just project camera_dims = input_dims self.input_adapter = ListAdapter(input_dims, hidden_dim) self.camera_token_adapter = ListAdapter(camera_dims, hidden_dim) # # camera layer self.camera_layer = CameraHead( hidden_dim=hidden_dim, num_heads=num_heads, expansion=expansion, dropout=dropout, layer_scale=layer_scale, ) self.depth_layer = DepthHead( hidden_dim=hidden_dim, num_heads=num_heads, expansion=expansion, depths=depth, dropout=dropout, camera_dim=96, num_resolutions=self.num_resolutions, layer_scale=layer_scale, out_dim=out_dim, kernel_size=kernel_size, num_prompt_blocks=1, use_norm=False, ) self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True) self.level_embeds = nn.Parameter( orthonormal_init(len(input_dims), hidden_dim).reshape( 1, 1, len(input_dims), hidden_dim ), requires_grad=False, ) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/decoder_old.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from timm.models.layers import trunc_normal_ from unidepth.layers import (MLP, AttentionBlock, ConvUpsampleShuffleResidual, NystromBlock, PositionEmbeddingSine) from unidepth.utils.geometric import flat_interpolate, generate_rays from unidepth.utils.positional_embedding import generate_fourier_features class ListAdapter(nn.Module): def __init__(self, input_dims: list[int], hidden_dim: int): super().__init__() self.input_adapters = nn.ModuleList([]) self.num_chunks = len(input_dims) self.checkpoint = True for input_dim in input_dims: self.input_adapters.append( nn.Sequential( nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim), nn.GELU() ) ) def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor: xs = torch.split(x, splits.int().tolist(), dim=-1) xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)] return torch.cat(xs, dim=-1) class CameraHead(nn.Module): def __init__( self, hidden_dim: int, num_heads: int = 8, expansion: int = 4, dropout: float = 0.0, **kwargs, ): super().__init__() self.aggregate1 = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout ) self.aggregate2 = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout ) self.latents_pos = nn.Parameter( torch.randn(1, 4, hidden_dim), requires_grad=True ) self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout) self.project_cls = MLP(hidden_dim, dropout=dropout) self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1) def fill_intrinsics(self, x): camera_intrinsics = torch.zeros( x.shape[0], 3, 3, device=x.device, requires_grad=False ) camera_intrinsics[:, 0, 0] = x[:, 0].exp() camera_intrinsics[:, 1, 1] = x[:, 1].exp() camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid() camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid() camera_intrinsics[:, 2, 2] = 1.0 return camera_intrinsics def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor: features = 
features.unbind(dim=-1) cls_tokens = self.project_cls(cls_tokens) latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1) features = self.in_features(torch.cat(features, dim=1) + pos_embed) features = torch.cat((features, cls_tokens), dim=1) cls_tokens = self.aggregate1( cls_tokens, context=features, pos_embed=latents_pos ) cls_tokens = self.aggregate2( cls_tokens, context=features, pos_embed=latents_pos ) # project to intrinsics x = self.out(cls_tokens).squeeze(-1) camera_intrinsics = self.fill_intrinsics(x) return camera_intrinsics def set_shapes(self, shapes: tuple[int, int]): self.shapes = shapes class GlobalHead(nn.Module): def __init__( self, hidden_dim: int, camera_dim: int, expansion: int = 4, dropout: float = 0.0, **kwargs, ): super().__init__() self.camera_dim = camera_dim self.in_features = nn.Linear(hidden_dim, hidden_dim) self.project_rays = nn.Linear(camera_dim + 3, hidden_dim) self.aggregate1 = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout ) self.aggregate2 = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout ) self.project_cls = MLP(hidden_dim, dropout=dropout) self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1) def embed_rays(self, rays, shapes): rays_embedding = flat_interpolate(rays, old=self.original_shapes, new=shapes) rays_embedding = F.normalize(rays_embedding, dim=-1) rays_embedding = generate_fourier_features( rays_embedding, dim=self.camera_dim, max_freq=max(shapes) // 2, use_log=True, cat_orig=True, ) return rays_embedding def set_original_shapes(self, shapes: tuple[int, int]): self.original_shapes = shapes def set_shapes(self, shapes: tuple[int, int]): self.shapes = shapes def get_scaleshift(self, x): scale, shift = torch.chunk(x, 2, dim=1) scale = scale.exp().reshape(-1, 1, 1, 1) shift = shift.reshape(-1, 1, 1, 1) return scale, shift def forward(self, features, cls_tokens, rays) -> torch.Tensor: features = features.unbind(dim=-1) cls_tokens = self.project_cls(cls_tokens) rays_embedding = self.project_rays(self.embed_rays(rays, self.shapes)) rays_embedding = rays_embedding.repeat(1, len(features), 1) features = self.in_features(torch.cat(features, dim=1) + rays_embedding) features = torch.cat((features, cls_tokens), dim=1) cls_tokens = self.aggregate1(cls_tokens, context=features) cls_tokens = self.aggregate2(cls_tokens, context=features) x = self.out(cls_tokens).squeeze(-1) scale, shift = self.get_scaleshift(x) return scale, shift class DepthHead(nn.Module): def __init__( self, hidden_dim: int, num_heads: int = 8, expansion: int = 4, depths: int | list[int] = 4, checkpoint: bool = True, camera_dim: int = 256, num_resolutions: int = 4, dropout: float = 0.0, **kwargs, ) -> None: super().__init__() self.checkpoint = checkpoint self.camera_dim = camera_dim self.skip_depth = False self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout) self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim) self.aggregate_16 = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout, context_dim=hidden_dim, ) self.prompt_camera = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout, context_dim=hidden_dim, ) self.rays_layers = nn.ModuleList([]) self.ups = nn.ModuleList([]) self.process_layers = nn.ModuleList([]) self.norms, self.out_layers = nn.ModuleList([]), nn.ModuleList([]) self.depth_mlp, self.confidence_mlp = nn.ModuleList([]), nn.ModuleList([]) for i, depth in enumerate(depths): blk_lst = nn.ModuleList([]) for _ in 
range(depth): blk_lst.append( NystromBlock( hidden_dim // int(2**i), num_heads=num_heads // int(2**i), expansion=expansion, dropout=dropout, ) ) self.process_layers.append(blk_lst) self.rays_layers.append(nn.Linear(camera_dim + 3, hidden_dim // int(2**i))) self.ups.append( ConvUpsampleShuffleResidual( hidden_dim // int(2**i), expansion=expansion, kernel_size=7, num_layers=2, ) ) self.depth_mlp.append( MLP( input_dim=hidden_dim // int(2 ** (i + 1)), output_dim=16, expansion=1, ) ) self.confidence_mlp.append( MLP( input_dim=hidden_dim // int(2 ** (i + 1)), output_dim=16, expansion=1, ) ) self.to_depth = nn.Conv2d( 16 * len(depths), 1, 7, padding=3, padding_mode="reflect" ) self.to_confidence = nn.Conv2d( 16 * len(depths), 1, 7, padding=3, padding_mode="reflect" ) def set_original_shapes(self, shapes: tuple[int, int]): self.original_shapes = shapes def set_shapes(self, shapes: tuple[int, int]): self.shapes = shapes def embed_rays(self, rays, shapes): rays_embedding = flat_interpolate(rays, old=self.original_shapes, new=shapes) rays_embedding = F.normalize(rays_embedding, dim=-1) rays_embedding = generate_fourier_features( rays_embedding, dim=self.camera_dim, max_freq=max(shapes) // 2, use_log=True, cat_orig=True, ) return rays_embedding def project_rays(self, rays, shapes): embedded_rays = [] for i, layer in enumerate(self.rays_layers): embedded_rays.append( layer(self.embed_rays(rays, [(2**i) * x for x in shapes])) ) return embedded_rays def decode_depth(self, latents_16, rays, shapes): latents = latents_16 out_features, depths, confidences = [], [], [] for i, (up, layers, rays_embedding) in enumerate( zip(self.ups, self.process_layers, rays) ): for layer in layers: latents = layer(latents, pos_embed=rays_embedding) latents = up( rearrange( latents + rays_embedding, "b (h w) c -> b c h w", h=shapes[0] * int(2**i), w=shapes[1] * int(2**i), ).contiguous() ) out = rearrange( latents, "b (h w) c -> b h w c", h=shapes[0] * int(2 ** (1 + i)), w=shapes[1] * int(2 ** (1 + i)), ) out_features.append(out) # aggregate output and project to depth for i, (layer, features) in enumerate( zip(self.depth_mlp[::-1], out_features[::-1]) ): out_depth_features = layer(features).permute(0, 3, 1, 2) out_depth_features = F.interpolate( out_depth_features, size=self.original_shapes, mode="bilinear" ) depths.append(out_depth_features) logdepth = self.to_depth(torch.cat(depths, dim=1)) # aggregate output and project to confidences for i, (layer, features) in enumerate( zip(self.confidence_mlp[::-1], out_features[::-1]) ): out_conf_features = layer(features).permute(0, 3, 1, 2) out_conf_features = F.interpolate( out_conf_features, size=self.original_shapes, mode="bilinear" ) confidences.append(out_conf_features) confidence = self.to_confidence(torch.cat(confidences, dim=1)) # apply sigmoid ot get conf in [0, 1] confidence = torch.sigmoid(confidence) return logdepth, confidence def init_latents(self, features, shapes): # Generate latents with init as pooled features features_channels = torch.cat(features, dim=-1) features_16 = self.features_channel_cat(features_channels) latents_16 = features_16 + self.to_latents( flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False) ) return latents_16 def forward( self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed ) -> torch.Tensor: B = features.shape[0] features = features.unbind(dim=-1) shapes = self.shapes # camera_embedding rays_embeddings = self.project_rays(rays_hr, shapes) # Init latents init_latents_16 = 
self.init_latents(features, shapes) # Aggregate features: F -> D latents_16 = self.aggregate_16( init_latents_16, context=torch.cat(features, dim=1), pos_embed_context=pos_embed + level_embed, ) # Aggregate camera: D -> D|E latents_16 = self.prompt_camera(latents_16, context=rays_embeddings[0]) # Decode depth logdepth, confidence = self.decode_depth(latents_16, rays_embeddings, shapes) return logdepth, confidence, latents_16 class Decoder(nn.Module): def __init__( self, config, ): super().__init__() self.build(config) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Conv2d): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): if m.bias is not None: nn.init.constant_(m.bias, 0) if m.weight is not None: nn.init.constant_(m.weight, 1.0) def get_adapted_features(self, features_flat, splits): features_flat_cat = torch.cat(features_flat, dim=-1) features_projected = self.input_adapter( features_flat_cat, splits ) # list [b hw c] shapes features = torch.chunk(features_projected, splits.shape[0], dim=-1) return features def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays_gt): # get cls tokens projections cls_tokens_splits = torch.tensor( [x.shape[-1] for x in cls_tokens], device=features.device, requires_grad=False, dtype=features.dtype, ) cls_tokens = torch.cat(cls_tokens, dim=-1) cls_tokens = self.camera_token_adapter(cls_tokens, cls_tokens_splits) cls_tokens = torch.cat( torch.chunk(cls_tokens, cls_tokens_splits.shape[0], dim=-1), dim=1 ) # camera layer intrinsics = self.camera_layer( features=features, cls_tokens=cls_tokens, pos_embed=pos_embed ) intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0] intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1] intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1] intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0] rays = ( rays_gt if rays_gt is not None else generate_rays(intrinsics, original_shapes)[0] ) return intrinsics, rays def run_global(self, cls_tokens, features, rays): # get cls tokens projections cls_tokens_splits = torch.tensor( [x.shape[-1] for x in cls_tokens], device=features.device, requires_grad=False, dtype=torch.float32, ) cls_tokens = torch.cat(cls_tokens, dim=-1) cls_tokens = self.global_token_adapter(cls_tokens, cls_tokens_splits) cls_tokens = torch.cat( torch.chunk(cls_tokens, cls_tokens_splits.shape[0], dim=-1), dim=1 ) scale, shift = self.global_layer( features=features, rays=rays, cls_tokens=cls_tokens ) return scale, shift def forward(self, inputs, image_metas) -> torch.Tensor: B, C, H, W = inputs["image"].shape device = inputs["image"].device dtype = inputs["image"].dtype # get features in b n d format # level shapes, the shape per level, for swin like [[128, 128], [64, 64],...], for vit [[32,32]] -> mult times resolutions level_shapes = sorted( list(set([tuple([x.shape[1], x.shape[2]]) for x in inputs["features"]])) )[::-1] if len(level_shapes) == 1: level_shapes = level_shapes * self.num_resolutions input_shapes = [ level_shapes[i] for i, (start, end) in enumerate(self.slices_encoder) for _ in range(end - start) ] common_shape = level_shapes[-2] # input shapes repeat shapes for each level, times the amount of the layers: features_flat = [ flat_interpolate( rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape ) for x, 
input_shape in zip(inputs["features"], input_shapes) ] features_splits = torch.tensor( [x.shape[-1] for x in features_flat], device=device, requires_grad=False, dtype=torch.float32, ) features = self.get_adapted_features(features_flat, features_splits) features = torch.stack(features, dim=-1) # positional embeddings, spatial and level level_embed = torch.cat( [ self.level_embed_layer(self.level_embeds)[i : i + 1] .unsqueeze(0) .repeat(B, common_shape[0] * common_shape[1], 1) for i in range(self.num_resolutions) ], dim=1, ) dummy_tensor = torch.zeros( B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False ) pos_embed = self.pos_embed(dummy_tensor) pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat( 1, self.num_resolutions, 1 ) self.camera_layer.set_shapes(common_shape) intrinsics, rays = self.run_camera( inputs["camera_tokens"], features=features, pos_embed=pos_embed + level_embed, original_shapes=(H, W), rays_gt=inputs.get("rays"), ) self.global_layer.set_shapes(common_shape) self.global_layer.set_original_shapes((H, W)) scale, shift = self.run_global( inputs["global_tokens"], features=features, rays=rays ) # run bulk of the model self.depth_layer.set_shapes(common_shape) self.depth_layer.set_original_shapes((H, W)) logdepth, confidence, depth_features = self.depth_layer( features=features, rays_hr=rays, pos_embed=pos_embed, level_embed=level_embed, ) logdepth = logdepth.to(torch.float32, non_blocking=True) # norm in log space, why performs better? shapes = [int(x) for x in logdepth.shape[-2:]] depth_normalized = F.layer_norm(logdepth, shapes).exp() depth = ( depth_normalized + shift ) * scale # shift is scale invariant if we do (x + mu) * sigma depth = F.softplus(depth, beta=10.0).to(dtype, non_blocking=True) outputs = { "depth": depth, "confidence": confidence, "depth_features": depth_features, "K": intrinsics, } return outputs @torch.jit.ignore def no_weight_decay_keywords(self): return {"latents_pos", "level_embeds"} def build(self, config): input_dims = config["model"]["pixel_encoder"]["embed_dims"] hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"] expansion = config["model"]["expansion"] num_heads = config["model"]["num_heads"] dropout = config["model"]["pixel_decoder"]["dropout"] depths_encoder = config["model"]["pixel_encoder"]["depths"] depth = config["model"]["pixel_decoder"]["depths"] depths_encoder = config["model"]["pixel_encoder"]["depths"] self.downsample = 4 self.num_resolutions = len(depths_encoder) self.slices_encoder = list(zip([d - 1 for d in depths_encoder], depths_encoder)) cls_token_input_dims = [input_dims[i] for i in [-1, -2, -3, -4]] input_dims = [input_dims[d - 1] for d in depths_encoder] # # camera layer self.camera_layer = CameraHead( hidden_dim=hidden_dim, num_heads=num_heads, expansion=expansion, dropout=dropout, ) # # scale shift layer self.global_layer = GlobalHead( hidden_dim=hidden_dim, camera_dim=96, num_heads=num_heads, expansion=expansion, dropout=dropout, ) # # adapt from encoder features, just project self.input_adapter = ListAdapter(input_dims, hidden_dim) self.camera_token_adapter = ListAdapter(cls_token_input_dims, hidden_dim) self.global_token_adapter = ListAdapter(cls_token_input_dims[:2], hidden_dim) self.depth_layer = DepthHead( hidden_dim=hidden_dim, num_heads=num_heads, expansion=expansion, depths=depth, dropout=dropout, camera_dim=96, num_resolutions=self.num_resolutions, ) self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True) self.level_embeds = nn.Parameter( 
torch.randn(len(input_dims), hidden_dim), requires_grad=True ) self.level_embed_layer = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), ) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/export.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import argparse import json import os from math import ceil import huggingface_hub import torch.nn.functional as F import torch.onnx from unidepth.models.unidepthv2 import UniDepthV2 class UniDepthV2ONNX(UniDepthV2): def __init__( self, config, eps: float = 1e-6, **kwargs, ): super().__init__(config, eps) def forward(self, rgbs): B, _, H, W = rgbs.shape features, tokens = self.pixel_encoder(rgbs) inputs = {} inputs["image"] = rgbs inputs["features"] = [ self.stacking_fn(features[i:j]).contiguous() for i, j in self.slices_encoder_range ] inputs["tokens"] = [ self.stacking_fn(tokens[i:j]).contiguous() for i, j in self.slices_encoder_range ] outputs = self.pixel_decoder(inputs, []) outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W) pts_3d = outputs["rays"] * outputs["radius"] return pts_3d, outputs["confidence"], outputs["intrinsics"] class UniDepthV2ONNXcam(UniDepthV2): def __init__( self, config, eps: float = 1e-6, **kwargs, ): super().__init__(config, eps) def forward(self, rgbs, rays): B, _, H, W = rgbs.shape features, tokens = self.pixel_encoder(rgbs) inputs = {} inputs["image"] = rgbs inputs["rays"] = rays inputs["features"] = [ self.stacking_fn(features[i:j]).contiguous() for i, j in self.slices_encoder_range ] inputs["tokens"] = [ self.stacking_fn(tokens[i:j]).contiguous() for i, j in self.slices_encoder_range ] outputs = self.pixel_decoder(inputs, []) outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W) pts_3d = outputs["rays"] * outputs["radius"] return pts_3d, outputs["confidence"], outputs["intrinsics"] def export(model, path, shape=(462, 630), with_camera=False): model.eval() image = torch.rand(1, 3, *shape) dynamic_axes_in = {"rgbs": {0: "batch"}} inputs = [image] if with_camera: rays = torch.rand(1, 3, *shape) inputs.append(rays) dynamic_axes_in["rays"] = {0: "batch"} dynamic_axes_out = { "pts_3d": {0: "batch"}, "confidence": {0: "batch"}, "intrinsics": {0: "batch"}, } torch.onnx.export( model, tuple(inputs), path, input_names=list(dynamic_axes_in.keys()), output_names=list(dynamic_axes_out.keys()), opset_version=14, dynamic_axes={**dynamic_axes_in, **dynamic_axes_out}, ) print(f"Model exported to {path}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Export UniDepthV2 model to ONNX") parser.add_argument( "--version", type=str, default="v2", choices=["v2"], help="UniDepth version" ) parser.add_argument( "--backbone", type=str, default="vitl", choices=["vits", "vitb", "vitl"], help="Backbone model", ) parser.add_argument( "--shape", type=int, nargs=2, default=(462, 630), help="Input shape. 
No dyamic shape supported!", ) parser.add_argument( "--output-path", type=str, default="unidepthv2.onnx", help="Output ONNX file" ) parser.add_argument( "--with-camera", action="store_true", help="Export model that expects GT camera as unprojected rays at inference", ) args = parser.parse_args() version = args.version backbone = args.backbone shape = args.shape output_path = args.output_path with_camera = args.with_camera # force shape to be multiple of 14 shape_rounded = [14 * ceil(x // 14 - 0.5) for x in shape] if list(shape) != list(shape_rounded): print(f"Shape {shape} is not multiple of 14. Rounding to {shape_rounded}") shape = shape_rounded # assumes command is from root of repo with open(os.path.join("configs", f"config_{version}_{backbone}14.json")) as f: config = json.load(f) # tell DINO not to use efficient attention: not exportable config["training"]["export"] = True model = UniDepthV2ONNX(config) if not with_camera else UniDepthV2ONNXcam(config) path = huggingface_hub.hf_hub_download( repo_id=f"lpiccinelli/unidepth-{version}-{backbone}14", filename=f"pytorch_model.bin", repo_type="model", ) info = model.load_state_dict(torch.load(path), strict=False) print(f"UniDepth_{version}_{backbone} is loaded with:") print(f"\t missing keys: {info.missing_keys}") print(f"\t additional keys: {info.unexpected_keys}") export( model=model, path=os.path.join(os.environ.get("TMPDIR", "."), output_path), shape=shape, with_camera=with_camera, ) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/unidepthv2.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import importlib from copy import deepcopy from math import ceil import warnings import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms.v2.functional as TF from einops import rearrange from huggingface_hub import PyTorchModelHubMixin from unidepth.models.unidepthv2.decoder import Decoder from unidepth.utils.camera import BatchCamera, Camera, Pinhole from unidepth.utils.constants import (IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD) from unidepth.utils.distributed import is_main_process from unidepth.utils.misc import (first_stack, get_params, last_stack, match_gt, match_intrinsics, max_stack, mean_stack, softmax_stack) STACKING_FNS = { "max": max_stack, "mean": mean_stack, "first": first_stack, "last": last_stack, "softmax": softmax_stack, } def get_paddings(original_shape, aspect_ratio_range): # Original dimensions H_ori, W_ori = original_shape orig_aspect_ratio = W_ori / H_ori # Determine the closest aspect ratio within the range min_ratio, max_ratio = aspect_ratio_range target_aspect_ratio = min(max_ratio, max(min_ratio, orig_aspect_ratio)) if orig_aspect_ratio > target_aspect_ratio: # Too wide W_new = W_ori H_new = int(W_ori / target_aspect_ratio) pad_top = (H_new - H_ori) // 2 pad_bottom = H_new - H_ori - pad_top pad_left, pad_right = 0, 0 else: # Too tall H_new = H_ori W_new = int(H_ori * target_aspect_ratio) pad_left = (W_new - W_ori) // 2 pad_right = W_new - W_ori - pad_left pad_top, pad_bottom = 0, 0 return (pad_left, pad_right, pad_top, pad_bottom), (H_new, W_new) def get_resize_factor(original_shape, pixels_range, shape_multiplier=14): # Original dimensions H_ori, W_ori = original_shape n_pixels_ori = W_ori * H_ori # Determine the closest number of pixels within the range min_pixels, max_pixels = 
pixels_range target_pixels = min(max_pixels, max(min_pixels, n_pixels_ori)) # Calculate the resize factor resize_factor = (target_pixels / n_pixels_ori) ** 0.5 new_width = int(W_ori * resize_factor) new_height = int(H_ori * resize_factor) new_height = ceil(new_height / shape_multiplier) * shape_multiplier new_width = ceil(new_width / shape_multiplier) * shape_multiplier return resize_factor, (new_height, new_width) def _postprocess(tensor, shapes, paddings, interpolation_mode="bilinear"): # interpolate to original size tensor = F.interpolate( tensor, size=shapes, mode=interpolation_mode, align_corners=False ) # remove paddings pad1_l, pad1_r, pad1_t, pad1_b = paddings tensor = tensor[..., pad1_t : shapes[0] - pad1_b, pad1_l : shapes[1] - pad1_r] return tensor def _postprocess_intrinsics(K, resize_factors, paddings): batch_size = K.shape[0] K_new = K.clone() for i in range(batch_size): scale = resize_factors[i] pad_l, _, pad_t, _ = paddings[i] K_new[i, 0, 0] /= scale # fx K_new[i, 1, 1] /= scale # fy K_new[i, 0, 2] /= scale # cx K_new[i, 1, 2] /= scale # cy K_new[i, 0, 2] -= pad_l # cx K_new[i, 1, 2] -= pad_t # cy return K_new class UniDepthV2( nn.Module, PyTorchModelHubMixin, library_name="UniDepth", repo_url="https://github.com/lpiccinelli-eth/UniDepth", tags=["monocular-metric-depth-estimation"], ): def __init__( self, config, eps: float = 1e-6, **kwargs, ): super().__init__() self.eps = eps self.build(config) self.build_losses(config) def forward_train(self, inputs, image_metas): inputs, outputs = self.encode_decode(inputs, image_metas) losses = self.compute_losses(outputs, inputs, image_metas) return outputs, losses def forward_test(self, inputs, image_metas): inputs, outputs = self.encode_decode(inputs, image_metas) depth_gt = inputs["depth"] test_outputs = {} test_outputs["depth"] = match_gt( outputs["depth"], depth_gt, padding1=inputs["paddings"], padding2=None ) test_outputs["points"] = match_gt( outputs["points"], depth_gt, padding1=inputs["paddings"], padding2=None ) test_outputs["confidence"] = match_gt( outputs["confidence"], depth_gt, padding1=inputs["paddings"], padding2=None ) test_outputs["rays"] = match_gt( outputs["rays"], depth_gt, padding1=inputs["paddings"], padding2=None ) test_outputs["rays"] = outputs["rays"] / torch.norm( outputs["rays"], dim=1, keepdim=True ).clip(min=1e-5) test_outputs["intrinsics"] = match_intrinsics( outputs["intrinsics"], inputs["image"], depth_gt, padding1=inputs["paddings"], padding2=None, ) return test_outputs def forward(self, inputs, image_metas): if self.training: return self.forward_train(inputs, image_metas) else: return self.forward_test(inputs, image_metas) def compute_losses(self, outputs, inputs, image_metas): B, _, H, W = inputs["image"].shape losses = {"opt": {}, "stat": {}} losses_to_be_computed = list(self.losses.keys()) # depth loss si = torch.tensor( [x.get("si", False) for x in image_metas], device=self.device ).reshape(B) loss = self.losses["depth"] depth_losses = loss( outputs["depth"], target=inputs["depth"], mask=inputs["depth_mask"].clone(), si=si, ) losses["opt"][loss.name] = loss.weight * depth_losses.mean() losses_to_be_computed.remove("depth") # camera loss, here we apply to rays for simplicity # in the original training was on angles # however, we saw no difference (see supplementary) loss = self.losses["camera"] camera_losses = loss(outputs["rays"], target=inputs["rays"]) losses["opt"][loss.name] = loss.weight * camera_losses.mean() losses_to_be_computed.remove("camera") # invariance loss on output depth flips = 
torch.tensor( [x.get("flip", False) for x in image_metas], device=self.device ).reshape(B) loss = self.losses["invariance"] invariance_losses = loss( outputs["depth"], intrinsics=inputs["camera"].K, mask=inputs["depth_mask"], flips=flips, downsample_ratio=1, ) losses["opt"][loss.name] = loss.weight * invariance_losses.mean() losses_to_be_computed.remove("invariance") # edge guided ssi loss = self.losses["ssi"] ssi_losses = loss( outputs["depth"], target=inputs["depth"], mask=inputs["depth_mask"].clone(), image=inputs["image"], validity_mask=inputs["validity_mask"], ) losses["opt"][loss.name] = loss.weight * ssi_losses.mean() losses_to_be_computed.remove("ssi") # remaining losses, we expect no more losses to be computed loss = self.losses["confidence"] conf_losses = loss( outputs["confidence"].log(), target_gt=inputs["depth"], target_pred=outputs["depth"], mask=inputs["depth_mask"].clone(), ) losses["opt"][loss.name + "_conf"] = loss.weight * conf_losses.mean() losses_to_be_computed.remove("confidence") assert ( not losses_to_be_computed ), f"Losses {losses_to_be_computed} not computed, revise `compute_loss` method" return losses @torch.no_grad() @torch.autocast(device_type="cuda", enabled=True, dtype=torch.float16) def infer( self, rgb: torch.Tensor, camera: torch.Tensor | Camera | None = None, normalize=True, ): ratio_bounds = self.shape_constraints["ratio_bounds"] pixels_bounds = [ self.shape_constraints["pixels_min"], self.shape_constraints["pixels_max"], ] if hasattr(self, "resolution_level"): assert ( self.resolution_level >= 0 and self.resolution_level < 10 ), "resolution_level should be in [0, 10)" pixels_range = pixels_bounds[1] - pixels_bounds[0] interval = pixels_range / 10 new_lowbound = self.resolution_level * interval + pixels_bounds[0] new_upbound = (self.resolution_level + 1) * interval + pixels_bounds[0] pixels_bounds = (new_lowbound, new_upbound) else: warnings.warn("!! 
self.resolution_level not set, using default bounds !!") # houskeeping on cpu/cuda and batchify if rgb.ndim == 3: rgb = rgb.unsqueeze(0) if camera is not None: if isinstance(camera, torch.Tensor): assert ( camera.shape[-1] == 3 and camera.shape[-2] == 3 ), "camera tensor should be of shape (..., 3, 3): assume pinhole" camera = Pinhole(K=camera) camera = BatchCamera.from_camera(camera) camera = camera.to(self.device) B, _, H, W = rgb.shape rgb = rgb.to(self.device) if camera is not None: camera = camera.to(self.device) # preprocess paddings, (padded_H, padded_W) = get_paddings((H, W), ratio_bounds) (pad_left, pad_right, pad_top, pad_bottom) = paddings resize_factor, (new_H, new_W) = get_resize_factor( (padded_H, padded_W), pixels_bounds ) # -> rgb preprocess (input std-ized and resized) if normalize: rgb = TF.normalize( rgb.float() / 255.0, mean=IMAGENET_DATASET_MEAN, std=IMAGENET_DATASET_STD, ) rgb = F.pad(rgb, (pad_left, pad_right, pad_top, pad_bottom), value=0.0) rgb = F.interpolate( rgb, size=(new_H, new_W), mode="bilinear", align_corners=False ) # -> camera preprocess if camera is not None: camera = camera.crop( left=-pad_left, top=-pad_top, right=-pad_right, bottom=-pad_bottom ) camera = camera.resize(resize_factor) # run model _, model_outputs = self.encode_decode( inputs={"image": rgb, "camera": camera}, image_metas=[] ) # collect outputs out = {} out["confidence"] = _postprocess( model_outputs["confidence"], (padded_H, padded_W), paddings=paddings, interpolation_mode=self.interpolation_mode, ) points = _postprocess( model_outputs["points"], (padded_H, padded_W), paddings=paddings, interpolation_mode=self.interpolation_mode, ) rays = _postprocess( model_outputs["rays"], (padded_H, padded_W), paddings=paddings, interpolation_mode=self.interpolation_mode, ) out["intrinsics"] = _postprocess_intrinsics( model_outputs["intrinsics"], [resize_factor] * B, [paddings] * B ) out["radius"] = points.norm(dim=1, keepdim=True) out["depth"] = points[:, -1:] out["points"] = points out["rays"] = rays / torch.norm(rays, dim=1, keepdim=True).clip(min=1e-5) out["depth_features"] = model_outputs["depth_features"] return out def encode_decode(self, inputs, image_metas=[]): B, _, H, W = inputs["image"].shape # shortcut eval should avoid errors if len(image_metas) and "paddings" in image_metas[0]: inputs["paddings"] = torch.tensor( [image_meta["paddings"] for image_meta in image_metas], device=self.device, )[ ..., [0, 2, 1, 3] ] # lrtb inputs["depth_paddings"] = torch.tensor( [image_meta["depth_paddings"] for image_meta in image_metas], device=self.device, ) if ( self.training ): # at inference we do not have image paddings on top of depth ones (we have not "crop" on gt in ContextCrop) inputs["depth_paddings"] = inputs["depth_paddings"] + inputs["paddings"] if inputs.get("camera", None) is not None: inputs["rays"] = inputs["camera"].get_rays(shapes=(B, H, W)) features, tokens = self.pixel_encoder(inputs["image"]) inputs["features"] = [ self.stacking_fn(features[i:j]).contiguous() for i, j in self.slices_encoder_range ] inputs["tokens"] = [ self.stacking_fn(tokens[i:j]).contiguous() for i, j in self.slices_encoder_range ] outputs = self.pixel_decoder(inputs, image_metas) outputs["rays"] = rearrange(outputs["rays"], "b (h w) c -> b c h w", h=H, w=W) pts_3d = outputs["rays"] * outputs["radius"] outputs.update({"points": pts_3d, "depth": pts_3d[:, -1:]}) return inputs, outputs def load_pretrained(self, model_file): device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) 
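# Checkpoints are loaded onto the GPU when available, otherwise the CPU.
# DDP checkpoints store parameters under a "module." prefix, which is stripped below;
# strict=False tolerates missing/unexpected keys, and the mismatch report is printed via `info`.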
dict_model = torch.load(model_file, map_location=device, weights_only=False) if "model" in dict_model: dict_model = dict_model["model"] dict_model = {k.replace("module.", ""): v for k, v in dict_model.items()} info = self.load_state_dict(dict_model, strict=False) if is_main_process(): print( f"Loaded from {model_file} for {self.__class__.__name__} results in:", info, ) def get_params(self, config): if hasattr(self.pixel_encoder, "get_params"): encoder_p, encoder_lr = self.pixel_encoder.get_params( config["model"]["pixel_encoder"]["lr"], config["training"]["wd"], config["training"]["ld"], ) else: encoder_p, encoder_lr = get_params( self.pixel_encoder, config["model"]["pixel_encoder"]["lr"], config["training"]["wd"], ) decoder_p, decoder_lr = get_params( self.pixel_decoder, config["training"]["lr"], config["training"]["wd"] ) return [*encoder_p, *decoder_p] @property def device(self): return next(self.parameters()).device def build(self, config): mod = importlib.import_module("unidepth.models.encoder") pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"]) pixel_encoder_config = { **config["training"], **config["model"]["pixel_encoder"], **config["data"], } pixel_encoder = pixel_encoder_factory(pixel_encoder_config) config["model"]["pixel_encoder"]["patch_size"] = ( 14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16 ) pixel_encoder_embed_dims = ( pixel_encoder.embed_dims if hasattr(pixel_encoder, "embed_dims") else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)] ) config["model"]["pixel_encoder"]["embed_dim"] = getattr( pixel_encoder, "embed_dim" ) config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths config["model"]["pixel_encoder"]["cls_token_embed_dims"] = getattr( pixel_encoder, "cls_token_embed_dims", pixel_encoder_embed_dims ) pixel_decoder = Decoder(config) self.pixel_encoder = pixel_encoder self.pixel_decoder = pixel_decoder self.slices_encoder_range = list( zip([0, *self.pixel_encoder.depths[:-1]], self.pixel_encoder.depths) ) stacking_fn = config["model"]["pixel_encoder"]["stacking_fn"] assert ( stacking_fn in STACKING_FNS ), f"Stacking function {stacking_fn} not found in {STACKING_FNS.keys()}" self.stacking_fn = STACKING_FNS[stacking_fn] self.shape_constraints = config["data"]["augmentations"]["shape_constraints"] self.interpolation_mode = "bilinear" def build_losses(self, config): self.losses = {} for loss_name, loss_config in config["training"]["losses"].items(): mod = importlib.import_module("unidepth.ops.losses") loss_factory = getattr(mod, loss_config["name"]) self.losses[loss_name] = loss_factory.build(loss_config) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/unidepthv2_old.py ================================================ import importlib import warnings from copy import deepcopy from math import ceil import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms.functional as TF from einops import rearrange from huggingface_hub import PyTorchModelHubMixin from unidepth.models.unidepthv2.decoder_old import Decoder from unidepth.utils.constants import (IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD) from unidepth.utils.distributed import is_main_process from unidepth.utils.geometric import (generate_rays, spherical_zbuffer_to_euclidean) from unidepth.utils.misc import (first_stack, last_stack, max_stack, mean_stack, 
softmax_stack) STACKING_FNS = { "max": max_stack, "mean": mean_stack, "first": first_stack, "last": last_stack, "softmax": softmax_stack, } RESOLUTION_LEVELS = 10 # inference helpers def _check_ratio(image_ratio, ratio_bounds): ratio_bounds = sorted(ratio_bounds) if ratio_bounds is not None and ( image_ratio < ratio_bounds[0] or image_ratio > ratio_bounds[1] ): warnings.warn( f"Input image ratio ({image_ratio:.3f}) is out of training " f"distribution: {ratio_bounds}. This may lead to unexpected results. " f"Consider resizing/padding the image to match the training distribution." ) def _check_resolution(shape_constraints, resolution_level): if resolution_level is None: warnings.warn( "Resolution level is not set. Using max resolution. " "You can tradeoff resolution for speed by setting a number in [0,10]. " "This can be achieved by setting model's `resolution_level` attribute." ) resolution_level = RESOLUTION_LEVELS pixel_bounds = sorted(shape_constraints["pixels_bounds_ori"]) pixel_range = pixel_bounds[-1] - pixel_bounds[0] clipped_resolution_level = min(max(resolution_level, 0), RESOLUTION_LEVELS) if clipped_resolution_level != resolution_level: warnings.warn( f"Resolution level {resolution_level} is out of bounds ([0,{RESOLUTION_LEVELS}]). " f"Clipping to {clipped_resolution_level}." ) shape_constraints["pixels_bounds"] = [ pixel_bounds[0] + ceil(pixel_range * clipped_resolution_level / RESOLUTION_LEVELS), pixel_bounds[0] + ceil(pixel_range * clipped_resolution_level / RESOLUTION_LEVELS), ] return shape_constraints def _get_closes_num_pixels(image_shape, pixels_bounds): h, w = image_shape num_pixels = h * w pixels_bounds = sorted(pixels_bounds) num_pixels = max(min(num_pixels, pixels_bounds[1]), pixels_bounds[0]) return num_pixels def _shapes(image_shape, shape_constraints): h, w = image_shape image_ratio = w / h _check_ratio(image_ratio, shape_constraints["ratio_bounds"]) num_pixels = _get_closes_num_pixels( (h / shape_constraints["patch_size"], w / shape_constraints["patch_size"]), shape_constraints["pixels_bounds"], ) h = ceil((num_pixels / image_ratio) ** 0.5 - 0.5) w = ceil(h * image_ratio - 0.5) ratio = h / image_shape[0] * shape_constraints["patch_size"] return ( h * shape_constraints["patch_size"], w * shape_constraints["patch_size"], ), ratio def _preprocess(rgbs, intrinsics, shapes, ratio): rgbs = F.interpolate(rgbs, size=shapes, mode="bilinear", antialias=True) if intrinsics is not None: intrinsics = intrinsics.clone() intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio return rgbs, intrinsics return rgbs, None def _postprocess(outs, ratio, original_shapes, mode="nearest-exact"): outs["depth"] = F.interpolate(outs["depth"], size=original_shapes, mode=mode) outs["confidence"] = F.interpolate( outs["confidence"], size=original_shapes, mode="bilinear", antialias=True ) outs["K"][:, 0, 0] = outs["K"][:, 0, 0] / ratio outs["K"][:, 1, 1] = outs["K"][:, 1, 1] / ratio outs["K"][:, 0, 2] = outs["K"][:, 0, 2] / ratio outs["K"][:, 1, 2] = outs["K"][:, 1, 2] / ratio return outs class UniDepthV2old( nn.Module, PyTorchModelHubMixin, library_name="UniDepth", repo_url="https://github.com/lpiccinelli-eth/UniDepth", tags=["monocular-metric-depth-estimation"], ): def __init__( self, config, **kwargs, ): super().__init__() self.build(config) def forward(self, inputs, image_metas): H, W = inputs["depth"].shape[-2:] if "K" in inputs: rays, angles = 
generate_rays(inputs["K"], (H, W)) inputs["rays"] = rays inputs["angles"] = angles features, tokens = self.pixel_encoder(inputs[f"image"]) cls_tokens = [x.contiguous() for x in tokens] features = [ self.stacking_fn(features[i:j]).contiguous() for i, j in self.slices_encoder_range ] tokens = [ self.stacking_fn(tokens[i:j]).contiguous() for i, j in self.slices_encoder_range ] global_tokens = [cls_tokens[i] for i in [-2, -1]] camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]] inputs["features"] = features inputs["tokens"] = tokens inputs["global_tokens"] = global_tokens inputs["camera_tokens"] = camera_tokens outs = self.pixel_decoder(inputs, image_metas) angles = rearrange( generate_rays(outs["K"], (H, W), noisy=False)[-1], "b (h w) c -> b c h w", h=H, w=W, ) predictions = F.interpolate( outs["depth"], size=(H, W), mode="bilinear", align_corners=False, antialias=True, ) confidence = F.interpolate( outs["confidence"], size=(H, W), mode="bilinear", align_corners=False, antialias=True, ) predictions_3d = torch.cat((angles, predictions), dim=1) predictions_3d = spherical_zbuffer_to_euclidean( predictions_3d.permute(0, 2, 3, 1) ).permute(0, 3, 1, 2) outputs = { "K": outs["K"], "depth": predictions, "confidence": confidence, "points": predictions_3d, "depth_features": outs["depth_features"], } return outputs @torch.no_grad() def infer(self, rgbs: torch.Tensor, intrinsics=None): shape_constraints = self.shape_constraints if rgbs.ndim == 3: rgbs = rgbs.unsqueeze(0) if intrinsics is not None and intrinsics.ndim == 2: intrinsics = intrinsics.unsqueeze(0) B, _, H, W = rgbs.shape rgbs = rgbs.to(self.device) if intrinsics is not None: intrinsics = intrinsics.to(self.device) # process image and intrinsiscs (if any) to match network input (slow?) if rgbs.max() > 5 or rgbs.dtype == torch.uint8: rgbs = rgbs.to(torch.float32).div(255) if rgbs.min() >= 0.0 and rgbs.max() <= 1.0: rgbs = TF.normalize( rgbs, mean=IMAGENET_DATASET_MEAN, std=IMAGENET_DATASET_STD, ) # check resolution constraints: tradeoff resolution and speed shape_constraints = _check_resolution(shape_constraints, self.resolution_level) # get image shape (h, w), ratio = _shapes((H, W), shape_constraints) rgbs, gt_intrinsics = _preprocess( rgbs, intrinsics, (h, w), ratio, ) # run encoder features, tokens = self.pixel_encoder(rgbs) cls_tokens = [x.contiguous() for x in tokens] features = [ self.stacking_fn(features[i:j]).contiguous() for i, j in self.slices_encoder_range ] tokens = [ self.stacking_fn(tokens[i:j]).contiguous() for i, j in self.slices_encoder_range ] global_tokens = [cls_tokens[i] for i in [-2, -1]] camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]] # get data fro decoder and adapt to given camera inputs = {} inputs["features"] = features inputs["tokens"] = tokens inputs["global_tokens"] = global_tokens inputs["camera_tokens"] = camera_tokens inputs["image"] = rgbs if gt_intrinsics is not None: rays, angles = generate_rays(gt_intrinsics, (h, w)) inputs["rays"] = rays inputs["angles"] = angles inputs["K"] = gt_intrinsics outs = self.pixel_decoder(inputs, {}) # undo the reshaping and get original image size (slow) outs = _postprocess(outs, ratio, (H, W), mode=self.interpolation_mode) pred_intrinsics = outs["K"] depth = outs["depth"] confidence = outs["confidence"] # final 3D points backprojection intrinsics = intrinsics if intrinsics is not None else pred_intrinsics angles = generate_rays(intrinsics, (H, W))[-1] angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W) points_3d = 
torch.cat((angles, depth), dim=1) points_3d = spherical_zbuffer_to_euclidean( points_3d.permute(0, 2, 3, 1) ).permute(0, 3, 1, 2) outputs = { "intrinsics": pred_intrinsics, "points": points_3d, "depth": depth, "confidence": confidence, } return outputs def load_pretrained(self, model_file): device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) dict_model = torch.load(model_file, map_location=device) if "model" in dict_model: dict_model = dict_model["model"] dict_model = deepcopy( {k.replace("module.", ""): v for k, v in dict_model.items()} ) info = self.load_state_dict(dict_model, strict=False) if is_main_process(): print( f"Loaded from {model_file} for {self.__class__.__name__} results in:", info, ) @property def device(self): return next(self.parameters()).device def build(self, config): mod = importlib.import_module("unidepth.models.encoder") pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"]) pixel_encoder_config = { **config["training"], **config["model"]["pixel_encoder"], **config["data"], } pixel_encoder = pixel_encoder_factory(pixel_encoder_config) config["model"]["pixel_encoder"]["patch_size"] = ( 14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16 ) pixel_encoder_embed_dims = ( pixel_encoder.embed_dims if hasattr(pixel_encoder, "embed_dims") else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)] ) config["model"]["pixel_encoder"]["embed_dim"] = getattr( pixel_encoder, "embed_dim" ) config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths pixel_decoder = Decoder(config) self.pixel_encoder = pixel_encoder self.pixel_decoder = pixel_decoder stacking_fn = config["model"]["pixel_encoder"]["stacking_fn"] assert ( stacking_fn in STACKING_FNS ), f"Stacking function {stacking_fn} not found in {STACKING_FNS.keys()}" self.stacking_fn = STACKING_FNS[stacking_fn] self.slices_encoder_range = list( zip([0, *pixel_encoder.depths[:-1]], pixel_encoder.depths) ) self.shape_constraints = config["data"]["shape_constraints"] self.shape_constraints["pixels_bounds_ori"] = self.shape_constraints.get( "pixels_bounds", [1400, 2400] ) self.interpolation_mode = "bilinear" self.eps = 1e-6 self.resolution_level = None def build_losses(self, config): self.losses = {} for loss_name, loss_config in config["training"]["losses"].items(): mod = importlib.import_module("unidepth.ops.losses") loss_factory = getattr(mod, loss_config["name"]) self.losses[loss_name] = loss_factory.build(loss_config) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/__init__.py ================================================ from .losses import (ARel, Confidence, Dummy, EdgeGuidedLocalSSI, LocalSSI, Regression, SelfDistill, SILog, TeacherDistill) from .scheduler import CosineScheduler, PlainCosineScheduler ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/__init__.py ================================================ from .functions import ExtractPatchesFunction from .modules import RandomPatchExtractor __all__ = ["ExtractPatchesFunction", "RandomPatchExtractor"] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/compile.sh ================================================ #!/usr/bin/env bash if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then export 
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6+PTX" fi python setup.py build install ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/functions/__init__.py ================================================ from .extract_patches import ExtractPatchesFunction __all__ = ["ExtractPatchesFunction"] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/functions/extract_patches.py ================================================ import RandomPatchExtraction import torch from torch.autograd import Function class ExtractPatchesFunction(Function): @staticmethod def forward(ctx, input, centers, h, w): # Save variables for backward pass. inputs for shapes ctx.save_for_backward(input, centers) return RandomPatchExtraction.extract_patches_forward(input, centers, h, w) @staticmethod def backward(ctx, grad_output): input, centers = ctx.saved_tensors (grad_input,) = RandomPatchExtraction.extract_patches_backward( grad_output, centers, input.shape[2], input.shape[3] ) # breakpoint() # Return gradients with respect to inputs only return grad_input, None, None, None # Test if __name__ == "__main__": B, C, H, W = 1, 1, 10, 10 N = 2 h, w = 3, 3 input = torch.arange( B * C * H * W, device="cuda", dtype=torch.float32, requires_grad=True ).view(B, C, H, W) centers = torch.tensor([[[4, 4], [6, 6]]], device="cuda", dtype=torch.int32) output = ExtractPatchesFunction.apply(input, centers, h, w) output.mean().backward() ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/modules/__init__.py ================================================ from .patch_extractor import RandomPatchExtractor __all__ = ["RandomPatchExtractor"] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/modules/patch_extractor.py ================================================ from __future__ import absolute_import, division, print_function import torch import torch.nn.functional as F from torch import nn from ..functions import ExtractPatchesFunction class RandomPatchExtractor(nn.Module): def __init__( self, ): super().__init__() def forward( self, tensor: torch.Tensor, centers: torch.Tensor, patch_size: tuple[int, int] ): device = tensor.device dtype = tensor.dtype patch_width, patch_height = patch_size pad_width = patch_width // 2 pad_height = patch_height // 2 dtype = tensor.dtype # Pad input to avoid out-of-bounds tensor_padded = F.pad( tensor, (pad_width, pad_width, pad_height, pad_height), mode="constant", value=0.0, ) # Adjust edge coordinates to account for padding centers_padded = centers + torch.tensor( [pad_height, pad_width], dtype=dtype, device=device ).reshape(1, 1, 2) output = ExtractPatchesFunction.apply( tensor_padded.float(), centers_padded.int(), patch_height, patch_width ) return output.to(dtype) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/setup.py ================================================ import glob import os import torch from setuptools import find_packages, setup from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension requirements = ["torch", "torchvision"] def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "src") main_file = 
glob.glob(os.path.join(extensions_dir, "*.cpp")) source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) sources = main_file + source_cpu extension = CppExtension extra_compile_args = {"cxx": ["-O2"]} define_macros = [] if torch.cuda.is_available() and CUDA_HOME is not None: extension = CUDAExtension sources += source_cuda define_macros += [("WITH_CUDA", None)] extra_compile_args["nvcc"] = [ "-O2", ] else: raise NotImplementedError("CUDA is not available") sources = list(set([os.path.join(extensions_dir, s) for s in sources])) include_dirs = [extensions_dir] ext_modules = [ extension( "RandomPatchExtraction", sources, include_dirs=include_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args, ) ] return ext_modules setup( name="RandomPatchExtraction", version="0.1", author="Luigi Piccinelli", ext_modules=get_extensions(), packages=find_packages( exclude=( "configs", "tests", ) ), cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, ) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/cpu/extract_patches_cpu.cpp ================================================ #include #include #include torch::Tensor extract_patches_cpu_forward( const torch::Tensor &input, const torch::Tensor &centers, int h, int w ) { AT_ERROR("Not implemented on the CPU"); } std::vector extract_patches_cpu_backward( const torch::Tensor &grad_patches, const torch::Tensor &coords, int H, int W ) { AT_ERROR("Not implemented on the CPU"); } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/cpu/extract_patches_cpu.h ================================================ #pragma once #include #include torch::Tensor extract_patches_cpu_forward( const torch::Tensor &input, const torch::Tensor &centers, int h, int w ); std::vector extract_patches_cpu_backward( const torch::Tensor &grad_patches, const torch::Tensor &coords, int H, int W ); ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/cuda/extract_patches_cuda.h ================================================ #ifndef EXTRACT_PATCHES_CUDA_H #define EXTRACT_PATCHES_CUDA_H #include #include #include #include // Function prototypes for the CUDA functions torch::Tensor extract_patches_cuda_forward( const torch::Tensor &input, const torch::Tensor &centers, int h, int w ); std::vector extract_patches_cuda_backward( const torch::Tensor &grad_output, const torch::Tensor &centers, int H, int W ); #endif // EXTRACT_PATCHES_CUDA_H ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/cuda/extract_patches_kernel.cu ================================================ #include #include #include "cuda/extract_patches_kernel.cuh" #include "cuda/extract_patches_cuda.h" // Need to templatize these two to get fp16 working, but there are problems in compilation...
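// Launch layout used below: one CUDA block per (batch, patch) pair, i.e. grid = dim3(B, N),
// with one thread per channel (threads = C), so the per-block thread limit caps the channel count.
// Each thread copies (forward) or atomically accumulates (backward) its own h x w window centered
// at centers[b][n] = (y, x). The kernels do no bounds checking; the Python-side RandomPatchExtractor
// pads the input by half the patch size and offsets the centers so every window stays in range.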
torch::Tensor extract_patches_cuda_forward( const torch::Tensor &input, const torch::Tensor &centers, int h, int w ) { int B = input.size(0); int C = input.size(1); int H = input.size(2); int W = input.size(3); int N = centers.size(1); auto output = torch::zeros({B, C, N, h, w}, input.options()); const int threads = C; const dim3 blocks(B, N); extract_patches_cuda_forward_kernel<<<blocks, threads>>>( input.data_ptr(), output.data_ptr(), centers.data_ptr(), B, C, H, W, N, h, w); return {output}; } std::vector extract_patches_cuda_backward( const torch::Tensor &grad_output, const torch::Tensor &centers, int H, int W ) { int B = grad_output.size(0); int C = grad_output.size(1); int N = centers.size(1); int h = grad_output.size(3); int w = grad_output.size(4); auto grad_input = torch::zeros({B, C, H, W}, grad_output.options()); const int threads = C; const dim3 blocks(B, N); extract_patches_cuda_backward_kernel<<<blocks, threads>>>( grad_output.data_ptr(), grad_input.data_ptr(), centers.data_ptr(), B, C, H, W, N, h, w); return {grad_input}; } template <typename T> __global__ void extract_patches_cuda_forward_kernel( const T* __restrict__ input, T* __restrict__ output, const int* __restrict__ centers, int B, int C, int H, int W, int N, int h, int w) { // Calculate thread indices int batch_idx = blockIdx.x; int patch_idx = blockIdx.y; int channel_idx = threadIdx.x; // Extract center coordinates int center_y = centers[(batch_idx * N + patch_idx) * 2]; int center_x = centers[(batch_idx * N + patch_idx) * 2 + 1]; // Calculate half patch size int half_h = h / 2; int half_w = w / 2; // Extract patch for (int i = 0; i < h; ++i) { for (int j = 0; j < w; ++j) { int y = center_y - half_h + i; int x = center_x - half_w + j; output[batch_idx * C * N * h * w + patch_idx * C * h * w + channel_idx * h * w + i * w + j] = input[batch_idx * C * H * W + channel_idx * H * W + y * W + x]; } } } template __global__ void extract_patches_cuda_forward_kernel( const float* __restrict__ input, float* __restrict__ output, const int* __restrict__ centers, int B, int C, int H, int W, int N, int h, int w); template __global__ void extract_patches_cuda_forward_kernel<__half>( const __half* __restrict__ input, __half* __restrict__ output, const int* __restrict__ centers, int B, int C, int H, int W, int N, int h, int w); template <typename T> __global__ void extract_patches_cuda_backward_kernel( const T* __restrict__ grad_output, T* __restrict__ grad_input, const int* __restrict__ centers, int B, int C, int H, int W, int N, int h, int w) { // Calculate thread indices int batch_idx = blockIdx.x; int patch_idx = blockIdx.y; int channel_idx = threadIdx.x; // Extract center coordinates int center_y = centers[(batch_idx * N + patch_idx) * 2]; int center_x = centers[(batch_idx * N + patch_idx) * 2 + 1]; // Calculate half patch size int half_h = h / 2; int half_w = w / 2; // Compute gradients with respect to input tensor using chain rule for (int i = 0; i < h; ++i) { for (int j = 0; j < w; ++j) { int y = center_y - half_h + i; int x = center_x - half_w + j; atomicAdd( &grad_input[batch_idx * C * H * W + channel_idx * H * W + y * W + x], grad_output[batch_idx * C * N * h * w + patch_idx * C * h * w + channel_idx * h * w + i * w + j] ); } } } template __global__ void extract_patches_cuda_backward_kernel( const float* __restrict__ grad_output, float* __restrict__ grad_input, const int* __restrict__ centers, int B, int C, int H, int W, int N, int h, int w); template __global__ void extract_patches_cuda_backward_kernel<__half>( const __half* __restrict__ grad_output, __half* __restrict__ grad_input, const
int* __restrict__ centers, int B, int C, int H, int W, int N, int h, int w); ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/cuda/extract_patches_kernel.cuh ================================================ #ifndef EXTRACT_PATCHES_KERNEL_CUH #define EXTRACT_PATCHES_KERNEL_CUH #include #include #include #include #include // should contain __half // Declare the forward CUDA kernel function template __global__ void extract_patches_cuda_forward_kernel( const T* __restrict__ input, T* __restrict__ output, const int* __restrict__ centers, int B, int C, int H, int W, int N, int h, int w); // Declare the backward CUDA kernel function template __global__ void extract_patches_cuda_backward_kernel( const T* __restrict__ grad_output, T* __restrict__ grad_input, const int* __restrict__ centers, int B, int C, int H, int W, int N, int h, int w); #endif // EXTRACT_PATCHES_KERNEL_CUH ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/extract_patches.cpp ================================================ #include "extract_patches.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("extract_patches_forward", &extract_patches_forward, "Extract patches forward (CUDA)"); m.def("extract_patches_backward", &extract_patches_backward, "Extract patches backward (CUDA)"); } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/extract_patches.h ================================================ #pragma once #include "cpu/extract_patches_cpu.h" #ifdef WITH_CUDA #include "cuda/extract_patches_cuda.h" #endif #include #include #include torch::Tensor extract_patches_forward( const torch::Tensor &images, const torch::Tensor &coords, int patch_height, int patch_width) { if (images.type().is_cuda()) { #ifdef WITH_CUDA return extract_patches_cuda_forward(images, coords, patch_height, patch_width); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } std::vector extract_patches_backward( const torch::Tensor &grad_patches, const torch::Tensor &coords, int H, int W) { if (grad_patches.type().is_cuda()) { #ifdef WITH_CUDA return extract_patches_cuda_backward(grad_patches, coords, H, W); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/test.py ================================================ import RandomPatchExtraction import torch def extract_patches(input, centers, patch_size): h, w = patch_size output = RandomPatchExtraction.extract_patches_forward(input, centers, h, w) breakpoint() return output # Example usage if __name__ == "__main__": B, C, H, W = 1, 1, 10, 10 N = 2 h, w = 3, 3 input = torch.arange( B * C * H * W, device="cuda", dtype=torch.float32, requires_grad=True ).view(B, C, H, W) centers = torch.tensor([[[4, 4], [6, 6]]], device="cuda", dtype=torch.int32) patches = extract_patches(input, centers, (h, w)) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/__init__.py ================================================ from .functions.knn import knn_gather, knn_points __all__ = [ "knn_points", "knn_gather", ] ================================================ FILE: 
camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/compile.sh ================================================ #!/usr/bin/env bash export TORCH_CUDA_ARCH_LIST="6.1 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" # export FORCE_CUDA=1 #if you do not actually have cuda, workaround python setup.py build install ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/functions/__init__.py ================================================ from .knn import knn_gather, knn_points __all__ = [ "knn_points", "knn_gather", ] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/functions/knn.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe from collections import namedtuple from typing import Union import torch from KNN import knn_points_backward, knn_points_idx from torch.autograd import Function from torch.autograd.function import once_differentiable _KNN = namedtuple("KNN", "dists idx knn") class _knn_points(Function): """ Torch autograd Function wrapper for KNN C++/CUDA implementations. """ @staticmethod # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. def forward( ctx, p1, p2, lengths1, lengths2, K, version, norm: int = 2, return_sorted: bool = True, ): """ K-Nearest neighbors on point clouds. Args: p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each containing up to P1 points of dimension D. p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each containing up to P2 points of dimension D. lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the length of each pointcloud in p1. Or None to indicate that every cloud has length P1. lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the length of each pointcloud in p2. Or None to indicate that every cloud has length P2. K: Integer giving the number of nearest neighbors to return. version: Which KNN implementation to use in the backend. If version=-1, the correct implementation is selected based on the shapes of the inputs. norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2). return_sorted: (bool) whether to return the nearest neighbors sorted in ascending order of distance. Returns: p1_dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest neighbors. This is padded with zeros both where a cloud in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points. p1_idx: LongTensor of shape (N, P1, K) giving the indices of the K nearest neighbors from points in p1 to points in p2. Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points. 
""" if not ((norm == 1) or (norm == 2)): raise ValueError("Support for 1 or 2 norm.") idx, dists = knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version) # sort KNN in ascending order if K > 1 if K > 1 and return_sorted: if lengths2.min() < K: P1 = p1.shape[1] mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None] # mask has shape [N, K], true where dists irrelevant mask = mask[:, None].expand(-1, P1, -1) # mask has shape [N, P1, K], true where dists irrelevant dists[mask] = float("inf") dists, sort_idx = dists.sort(dim=2) dists[mask] = 0 else: dists, sort_idx = dists.sort(dim=2) idx = idx.gather(2, sort_idx) ctx.save_for_backward(p1, p2, lengths1, lengths2, idx) ctx.mark_non_differentiable(idx) ctx.norm = norm return dists, idx @staticmethod @once_differentiable def backward(ctx, grad_dists, grad_idx): p1, p2, lengths1, lengths2, idx = ctx.saved_tensors norm = ctx.norm # TODO(gkioxari) Change cast to floats once we add support for doubles. if not (grad_dists.dtype == torch.float32): grad_dists = grad_dists.float() if not (p1.dtype == torch.float32): p1 = p1.float() if not (p2.dtype == torch.float32): p2 = p2.float() grad_p1, grad_p2 = knn_points_backward( p1, p2, lengths1, lengths2, idx, norm, grad_dists ) return grad_p1, grad_p2, None, None, None, None, None, None def knn_points( p1: torch.Tensor, p2: torch.Tensor, lengths1: Union[torch.Tensor, None] = None, lengths2: Union[torch.Tensor, None] = None, norm: int = 2, K: int = 1, version: int = -1, return_nn: bool = False, return_sorted: bool = True, ) -> _KNN: """ K-Nearest neighbors on point clouds. Args: p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each containing up to P1 points of dimension D. p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each containing up to P2 points of dimension D. lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the length of each pointcloud in p1. Or None to indicate that every cloud has length P1. lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the length of each pointcloud in p2. Or None to indicate that every cloud has length P2. norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2. K: Integer giving the number of nearest neighbors to return. version: Which KNN implementation to use in the backend. If version=-1, the correct implementation is selected based on the shapes of the inputs. return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1. return_sorted: (bool) whether to return the nearest neighbors sorted in ascending order of distance. Returns: dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest neighbors. This is padded with zeros both where a cloud in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points. idx: LongTensor of shape (N, P1, K) giving the indices of the K nearest neighbors from points in p1 to points in p2. Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points. nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor for `p1[n, i]`. Returned if `return_nn` is True. The nearest neighbors are collected using `knn_gather` .. 
code-block:: p2_nn = knn_gather(p2, p1_idx, lengths2) which is a helper function that allows indexing any tensor of shape (N, P2, U) with the indices `p1_idx` returned by `knn_points`. The output is a tensor of shape (N, P1, K, U). """ if p1.shape[0] != p2.shape[0]: raise ValueError("pts1 and pts2 must have the same batch dimension.") if p1.shape[2] != p2.shape[2]: raise ValueError("pts1 and pts2 must have the same point dimension.") p1 = p1.contiguous() p2 = p2.contiguous() P1 = p1.shape[1] P2 = p2.shape[1] if lengths1 is None: lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device) if lengths2 is None: lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device) p1_dists, p1_idx = _knn_points.apply( p1, p2, lengths1, lengths2, K, version, norm, return_sorted ) p2_nn = None if return_nn: p2_nn = knn_gather(p2, p1_idx, lengths2) return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn if return_nn else None) def knn_gather( x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None ): """ A helper function for knn that allows indexing a tensor x with the indices `idx` returned by `knn_points`. For example, if `dists, idx = knn_points(p, x, lengths_p, lengths, K)` where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D), then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`. It can also be applied for any tensor x of shape (N, M, U) where U != D. Args: x: Tensor of shape (N, M, U) containing U-dimensional features to be gathered. idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`. lengths: LongTensor of shape (N,) of values in the range [0, M], giving the length of each example in the batch in x. Or None to indicate that every example has length M. Returns: x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`. If `k > lengths[n]` then `x_out[n, l, k]` is filled with 0.0. 
""" N, M, U = x.shape _N, L, K = idx.shape if N != _N: raise ValueError("x and idx must have same batch dimension.") if lengths is None: lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device) idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U) # idx_expanded has shape [N, L, K, U] x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded) # p2_nn has shape [N, L, K, U] needs_mask = lengths.min() < K if needs_mask: # mask has shape [N, K], true where idx is irrelevant because # there is less number of points in p2 than K mask = lengths[:, None] <= torch.arange(K, device=x.device)[None] # expand mask to shape [N, L, K, U] mask = mask[:, None].expand(-1, L, -1) mask = mask[:, :, :, None].expand(-1, -1, -1, U) x_out[mask] = 0.0 return x_out ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/setup.py ================================================ import glob import os import torch from setuptools import find_packages, setup from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension requirements = ["torch", "torchvision"] def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "src") main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) source_cpu = glob.glob(os.path.join(extensions_dir, "*.cpp")) source_cuda = glob.glob(os.path.join(extensions_dir, "*.cu")) sources = main_file + source_cpu extension = CppExtension extra_compile_args = {"cxx": ["-O3"]} define_macros = [] if torch.cuda.is_available() and CUDA_HOME is not None: extension = CUDAExtension sources += source_cuda define_macros += [("WITH_CUDA", None)] extra_compile_args["nvcc"] = [ "-O3", ] else: raise NotImplementedError("Cuda is not available") sources = list(set([os.path.join(extensions_dir, s) for s in sources])) include_dirs = [extensions_dir] ext_modules = [ extension( "KNN", sources, include_dirs=include_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args, ) ] return ext_modules setup( name="KNN", version="0.1", author="Luigi Piccinelli", ext_modules=get_extensions(), packages=find_packages( exclude=( "configs", "tests", ) ), cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, ) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/knn.cu ================================================ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include "utils/dispatch.cuh" #include "utils/mink.cuh" // A chunk of work is blocksize-many points of P1. // The number of potential chunks to do is N*(1+(P1-1)/blocksize) // call (1+(P1-1)/blocksize) chunks_per_cloud // These chunks are divided among the gridSize-many blocks. // In block b, we work on chunks b, b+gridSize, b+2*gridSize etc . // In chunk i, we work on cloud i/chunks_per_cloud on points starting from // blocksize*(i%chunks_per_cloud). 
template __global__ void KNearestNeighborKernelV0( const scalar_t* __restrict__ points1, const scalar_t* __restrict__ points2, const int64_t* __restrict__ lengths1, const int64_t* __restrict__ lengths2, scalar_t* __restrict__ dists, int64_t* __restrict__ idxs, const size_t N, const size_t P1, const size_t P2, const size_t D, const size_t K, const size_t norm) { // Store both dists and indices for knn in global memory. const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); const int64_t chunks_to_do = N * chunks_per_cloud; for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { const int64_t n = chunk / chunks_per_cloud; const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); int64_t p1 = start_point + threadIdx.x; if (p1 >= lengths1[n]) continue; int offset = n * P1 * K + p1 * K; int64_t length2 = lengths2[n]; MinK mink(dists + offset, idxs + offset, K); for (int p2 = 0; p2 < length2; ++p2) { // Find the distance between points1[n, p1] and points[n, p2] scalar_t dist = 0; for (int d = 0; d < D; ++d) { scalar_t coord1 = points1[n * P1 * D + p1 * D + d]; scalar_t coord2 = points2[n * P2 * D + p2 * D + d]; scalar_t diff = coord1 - coord2; scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff); dist += norm_diff; } mink.add(dist, p2); } } } template __global__ void KNearestNeighborKernelV1( const scalar_t* __restrict__ points1, const scalar_t* __restrict__ points2, const int64_t* __restrict__ lengths1, const int64_t* __restrict__ lengths2, scalar_t* __restrict__ dists, int64_t* __restrict__ idxs, const size_t N, const size_t P1, const size_t P2, const size_t K, const size_t norm) { // Same idea as the previous version, but hoist D into a template argument // so we can cache the current point in a thread-local array. We still store // the current best K dists and indices in global memory, so this should work // for very large K and fairly large D. scalar_t cur_point[D]; const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); const int64_t chunks_to_do = N * chunks_per_cloud; for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { const int64_t n = chunk / chunks_per_cloud; const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); int64_t p1 = start_point + threadIdx.x; if (p1 >= lengths1[n]) continue; for (int d = 0; d < D; ++d) { cur_point[d] = points1[n * P1 * D + p1 * D + d]; } int offset = n * P1 * K + p1 * K; int64_t length2 = lengths2[n]; MinK mink(dists + offset, idxs + offset, K); for (int p2 = 0; p2 < length2; ++p2) { // Find the distance between cur_point and points[n, p2] scalar_t dist = 0; for (int d = 0; d < D; ++d) { scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d]; scalar_t norm_diff = (norm == 2) ? 
(diff * diff) : abs(diff); dist += norm_diff; } mink.add(dist, p2); } } } // This is a shim functor to allow us to dispatch using DispatchKernel1D template struct KNearestNeighborV1Functor { static void run( size_t blocks, size_t threads, const scalar_t* __restrict__ points1, const scalar_t* __restrict__ points2, const int64_t* __restrict__ lengths1, const int64_t* __restrict__ lengths2, scalar_t* __restrict__ dists, int64_t* __restrict__ idxs, const size_t N, const size_t P1, const size_t P2, const size_t K, const size_t norm) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); KNearestNeighborKernelV1<<>>( points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm); } }; template __global__ void KNearestNeighborKernelV2( const scalar_t* __restrict__ points1, const scalar_t* __restrict__ points2, const int64_t* __restrict__ lengths1, const int64_t* __restrict__ lengths2, scalar_t* __restrict__ dists, int64_t* __restrict__ idxs, const int64_t N, const int64_t P1, const int64_t P2, const size_t norm) { // Same general implementation as V2, but also hoist K into a template arg. scalar_t cur_point[D]; scalar_t min_dists[K]; int min_idxs[K]; const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); const int64_t chunks_to_do = N * chunks_per_cloud; for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { const int64_t n = chunk / chunks_per_cloud; const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); int64_t p1 = start_point + threadIdx.x; if (p1 >= lengths1[n]) continue; for (int d = 0; d < D; ++d) { cur_point[d] = points1[n * P1 * D + p1 * D + d]; } int64_t length2 = lengths2[n]; MinK mink(min_dists, min_idxs, K); for (int p2 = 0; p2 < length2; ++p2) { scalar_t dist = 0; for (int d = 0; d < D; ++d) { int offset = n * P2 * D + p2 * D + d; scalar_t diff = cur_point[d] - points2[offset]; scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff); dist += norm_diff; } mink.add(dist, p2); } for (int k = 0; k < mink.size(); ++k) { idxs[n * P1 * K + p1 * K + k] = min_idxs[k]; dists[n * P1 * K + p1 * K + k] = min_dists[k]; } } } // This is a shim so we can dispatch using DispatchKernel2D template struct KNearestNeighborKernelV2Functor { static void run( size_t blocks, size_t threads, const scalar_t* __restrict__ points1, const scalar_t* __restrict__ points2, const int64_t* __restrict__ lengths1, const int64_t* __restrict__ lengths2, scalar_t* __restrict__ dists, int64_t* __restrict__ idxs, const int64_t N, const int64_t P1, const int64_t P2, const size_t norm) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); KNearestNeighborKernelV2<<>>( points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm); } }; template __global__ void KNearestNeighborKernelV3( const scalar_t* __restrict__ points1, const scalar_t* __restrict__ points2, const int64_t* __restrict__ lengths1, const int64_t* __restrict__ lengths2, scalar_t* __restrict__ dists, int64_t* __restrict__ idxs, const size_t N, const size_t P1, const size_t P2, const size_t norm) { // Same idea as V2, but use register indexing for thread-local arrays. // Enabling sorting for this version leads to huge slowdowns; I suspect // that it forces min_dists into local memory rather than registers. // As a result this version is always unsorted. 
scalar_t cur_point[D]; scalar_t min_dists[K]; int min_idxs[K]; const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); const int64_t chunks_to_do = N * chunks_per_cloud; for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { const int64_t n = chunk / chunks_per_cloud; const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); int64_t p1 = start_point + threadIdx.x; if (p1 >= lengths1[n]) continue; for (int d = 0; d < D; ++d) { cur_point[d] = points1[n * P1 * D + p1 * D + d]; } int64_t length2 = lengths2[n]; RegisterMinK mink(min_dists, min_idxs); for (int p2 = 0; p2 < length2; ++p2) { scalar_t dist = 0; for (int d = 0; d < D; ++d) { int offset = n * P2 * D + p2 * D + d; scalar_t diff = cur_point[d] - points2[offset]; scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff); dist += norm_diff; } mink.add(dist, p2); } for (int k = 0; k < mink.size(); ++k) { idxs[n * P1 * K + p1 * K + k] = min_idxs[k]; dists[n * P1 * K + p1 * K + k] = min_dists[k]; } } } // This is a shim so we can dispatch using DispatchKernel2D template struct KNearestNeighborKernelV3Functor { static void run( size_t blocks, size_t threads, const scalar_t* __restrict__ points1, const scalar_t* __restrict__ points2, const int64_t* __restrict__ lengths1, const int64_t* __restrict__ lengths2, scalar_t* __restrict__ dists, int64_t* __restrict__ idxs, const size_t N, const size_t P1, const size_t P2, const size_t norm) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); KNearestNeighborKernelV3<<>>( points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm); } }; constexpr int V1_MIN_D = 1; constexpr int V1_MAX_D = 32; constexpr int V2_MIN_D = 1; constexpr int V2_MAX_D = 8; constexpr int V2_MIN_K = 1; constexpr int V2_MAX_K = 32; constexpr int V3_MIN_D = 1; constexpr int V3_MAX_D = 8; constexpr int V3_MIN_K = 1; constexpr int V3_MAX_K = 4; bool InBounds(const int64_t min, const int64_t x, const int64_t max) { return min <= x && x <= max; } bool KnnCheckVersion(int version, const int64_t D, const int64_t K) { if (version == 0) { return true; } else if (version == 1) { return InBounds(V1_MIN_D, D, V1_MAX_D); } else if (version == 2) { return InBounds(V2_MIN_D, D, V2_MAX_D) && InBounds(V2_MIN_K, K, V2_MAX_K); } else if (version == 3) { return InBounds(V3_MIN_D, D, V3_MAX_D) && InBounds(V3_MIN_K, K, V3_MAX_K); } return false; } int ChooseVersion(const int64_t D, const int64_t K) { for (int version = 3; version >= 1; version--) { if (KnnCheckVersion(version, D, K)) { return version; } } return 0; } std::tuple KNearestNeighborIdxCuda( const at::Tensor& p1, const at::Tensor& p2, const at::Tensor& lengths1, const at::Tensor& lengths2, const int norm, const int K, int version) { // Check inputs are on the same device at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2}, lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4}; at::CheckedFrom c = "KNearestNeighborIdxCuda"; at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t}); at::checkAllSameType(c, {p1_t, p2_t}); // Set the device for the kernel launch based on the device of the input at::cuda::CUDAGuard device_guard(p1.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const auto N = p1.size(0); const auto P1 = p1.size(1); const auto P2 = p2.size(1); const auto D = p2.size(2); const int64_t K_64 = K; TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2."); TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension"); auto long_dtype = 
lengths1.options().dtype(at::kLong); auto idxs = at::zeros({N, P1, K}, long_dtype); auto dists = at::zeros({N, P1, K}, p1.options()); if (idxs.numel() == 0) { AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(idxs, dists); } if (version < 0) { version = ChooseVersion(D, K); } else if (!KnnCheckVersion(version, D, K)) { int new_version = ChooseVersion(D, K); std::cout << "WARNING: Requested KNN version " << version << " is not compatible with D = " << D << "; K = " << K << ". Falling back to version = " << new_version << std::endl; version = new_version; } // At this point we should have a valid version no matter what data the user // gave us. But we can check once more to be sure; however this time // assert fail since failing at this point means we have a bug in our version // selection or checking code. AT_ASSERTM(KnnCheckVersion(version, D, K), "Invalid version"); const size_t threads = 256; const size_t blocks = 256; if (version == 0) { AT_DISPATCH_FLOATING_TYPES( p1.scalar_type(), "knn_kernel_cuda", ([&] { KNearestNeighborKernelV0<<>>( p1.contiguous().data_ptr(), p2.contiguous().data_ptr(), lengths1.contiguous().data_ptr(), lengths2.contiguous().data_ptr(), dists.data_ptr(), idxs.data_ptr(), N, P1, P2, D, K, norm); })); } else if (version == 1) { AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { DispatchKernel1D< KNearestNeighborV1Functor, scalar_t, V1_MIN_D, V1_MAX_D>( D, blocks, threads, p1.contiguous().data_ptr(), p2.contiguous().data_ptr(), lengths1.contiguous().data_ptr(), lengths2.contiguous().data_ptr(), dists.data_ptr(), idxs.data_ptr(), N, P1, P2, K, norm); })); } else if (version == 2) { AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { DispatchKernel2D< KNearestNeighborKernelV2Functor, scalar_t, V2_MIN_D, V2_MAX_D, V2_MIN_K, V2_MAX_K>( D, K_64, blocks, threads, p1.contiguous().data_ptr(), p2.contiguous().data_ptr(), lengths1.contiguous().data_ptr(), lengths2.contiguous().data_ptr(), dists.data_ptr(), idxs.data_ptr(), N, P1, P2, norm); })); } else if (version == 3) { AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { DispatchKernel2D< KNearestNeighborKernelV3Functor, scalar_t, V3_MIN_D, V3_MAX_D, V3_MIN_K, V3_MAX_K>( D, K_64, blocks, threads, p1.contiguous().data_ptr(), p2.contiguous().data_ptr(), lengths1.contiguous().data_ptr(), lengths2.contiguous().data_ptr(), dists.data_ptr(), idxs.data_ptr(), N, P1, P2, norm); })); } AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(idxs, dists); } // ------------------------------------------------------------- // // Backward Operators // // ------------------------------------------------------------- // // TODO(gkioxari) support all data types once AtomicAdd supports doubles. // Currently, support is for floats only. 
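The backward kernel that follows scatters these gradients with atomicAdd. As a reference for what it computes, here is a minimal, unbatched Python sketch of the same gradient rule (illustrative only; it ignores length padding and the -1 pad index that the CUDA code has to handle):

import torch

def knn_backward_reference(p1, p2, idx, grad_dists, norm=2):
    # p1: (P1, D), p2: (P2, D), idx: (P1, K), grad_dists: (P1, K)
    grad_p1 = torch.zeros_like(p1)
    grad_p2 = torch.zeros_like(p2)
    for i in range(p1.shape[0]):
        for k in range(idx.shape[1]):
            j = idx[i, k]
            diff = p1[i] - p2[j]
            if norm == 1:
                # d|x|/dx is +1 or -1 depending on the sign of the difference
                sign = torch.where(diff > 0, torch.ones_like(diff), -torch.ones_like(diff))
                g = grad_dists[i, k] * sign
            else:
                # d(x^2)/dx = 2x for the squared L2 distance
                g = 2.0 * grad_dists[i, k] * diff
            grad_p1[i] += g
            grad_p2[j] -= g
    return grad_p1, grad_p2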
__global__ void KNearestNeighborBackwardKernel( const float* __restrict__ p1, // (N, P1, D) const float* __restrict__ p2, // (N, P2, D) const int64_t* __restrict__ lengths1, // (N,) const int64_t* __restrict__ lengths2, // (N,) const int64_t* __restrict__ idxs, // (N, P1, K) const float* __restrict__ grad_dists, // (N, P1, K) float* __restrict__ grad_p1, // (N, P1, D) float* __restrict__ grad_p2, // (N, P2, D) const size_t N, const size_t P1, const size_t P2, const size_t K, const size_t D, const size_t norm) { const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = gridDim.x * blockDim.x; for (size_t i = tid; i < N * P1 * K * D; i += stride) { const size_t n = i / (P1 * K * D); // batch index size_t rem = i % (P1 * K * D); const size_t p1_idx = rem / (K * D); // index of point in p1 rem = rem % (K * D); const size_t k = rem / D; // k-th nearest neighbor const size_t d = rem % D; // d-th dimension in the feature vector const size_t num1 = lengths1[n]; // number of valid points in p1 in batch const size_t num2 = lengths2[n]; // number of valid points in p2 in batch if ((p1_idx < num1) && (k < num2)) { const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k]; // index of point in p2 corresponding to the k-th nearest neighbor const int64_t p2_idx = idxs[n * P1 * K + p1_idx * K + k]; // If the index is the pad value of -1 then ignore it if (p2_idx == -1) { continue; } float diff = 0.0; if (norm == 1) { float sign = (p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d]) ? 1.0 : -1.0; diff = grad_dist * sign; } else { // norm is 2 diff = 2.0 * grad_dist * (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]); } atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff); atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff); } } } std::tuple KNearestNeighborBackwardCuda( const at::Tensor& p1, const at::Tensor& p2, const at::Tensor& lengths1, const at::Tensor& lengths2, const at::Tensor& idxs, int norm, const at::Tensor& grad_dists) { // Check inputs are on the same device at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2}, lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4}, idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6}; at::CheckedFrom c = "KNearestNeighborBackwardCuda"; at::checkAllSameGPU( c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t}); at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t}); // This is nondeterministic because atomicAdd at::globalContext().alertNotDeterministic("KNearestNeighborBackwardCuda"); // Set the device for the kernel launch based on the device of the input at::cuda::CUDAGuard device_guard(p1.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const auto N = p1.size(0); const auto P1 = p1.size(1); const auto P2 = p2.size(1); const auto D = p2.size(2); const auto K = idxs.size(2); TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension"); TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension"); TORCH_CHECK( idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1"); TORCH_CHECK(grad_dists.size(0) == N); TORCH_CHECK(grad_dists.size(1) == P1); TORCH_CHECK(grad_dists.size(2) == K); auto grad_p1 = at::zeros({N, P1, D}, p1.options()); auto grad_p2 = at::zeros({N, P2, D}, p2.options()); if (grad_p1.numel() == 0 || grad_p2.numel() == 0) { AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_p1, grad_p2); } const int blocks = 64; const int threads = 512; 
KNearestNeighborBackwardKernel<<>>( p1.contiguous().data_ptr(), p2.contiguous().data_ptr(), lengths1.contiguous().data_ptr(), lengths2.contiguous().data_ptr(), idxs.contiguous().data_ptr(), grad_dists.contiguous().data_ptr(), grad_p1.data_ptr(), grad_p2.data_ptr(), N, P1, P2, K, D, norm); AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_p1, grad_p2); } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/knn.h ================================================ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include #include #include "utils/pytorch3d_cutils.h" // Compute indices of K nearest neighbors in pointcloud p2 to points // in pointcloud p1. // // Args: // p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each // containing P1 points of dimension D. // p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each // containing P2 points of dimension D. // lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud. // lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud. // norm: int specifying the norm for the distance (1 for L1, 2 for L2) // K: int giving the number of nearest points to return. // version: Integer telling which implementation to use. // // Returns: // p1_neighbor_idx: LongTensor of shape (N, P1, K), where // p1_neighbor_idx[n, i, k] = j means that the kth nearest // neighbor to p1[n, i] in the cloud p2[n] is p2[n, j]. // It is padded with zeros so that it can be used easily in a later // gather() operation. // // p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared // distance from each point p1[n, p, :] to its K neighbors // p2[n, p1_neighbor_idx[n, p, k], :]. // CPU implementation. std::tuple KNearestNeighborIdxCpu( const at::Tensor& p1, const at::Tensor& p2, const at::Tensor& lengths1, const at::Tensor& lengths2, const int norm, const int K); // CUDA implementation std::tuple KNearestNeighborIdxCuda( const at::Tensor& p1, const at::Tensor& p2, const at::Tensor& lengths1, const at::Tensor& lengths2, const int norm, const int K, const int version); // Implementation which is exposed. std::tuple KNearestNeighborIdx( const at::Tensor& p1, const at::Tensor& p2, const at::Tensor& lengths1, const at::Tensor& lengths2, const int norm, const int K, const int version) { if (p1.is_cuda() || p2.is_cuda()) { #ifdef WITH_CUDA CHECK_CUDA(p1); CHECK_CUDA(p2); return KNearestNeighborIdxCuda( p1, p2, lengths1, lengths2, norm, K, version); #else AT_ERROR("Not compiled with GPU support."); #endif } return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K); } // Compute gradients with respect to p1 and p2 // // Args: // p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each // containing P1 points of dimension D. // p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each // containing P2 points of dimension D. // lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud. // lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud. // p1_neighbor_idx: LongTensor of shape (N, P1, K), where // p1_neighbor_idx[n, i, k] = j means that the kth nearest // neighbor to p1[n, i] in the cloud p2[n] is p2[n, j]. 
// It is padded with zeros so that it can be used easily in a later // gather() operation. This is computed from the forward pass. // norm: int specifying the norm for the distance (1 for L1, 2 for L2) // grad_dists: FLoatTensor of shape (N, P1, K) which contains the input // gradients. // // Returns: // grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients // wrt p1. // grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients // wrt p2. // CPU implementation. std::tuple KNearestNeighborBackwardCpu( const at::Tensor& p1, const at::Tensor& p2, const at::Tensor& lengths1, const at::Tensor& lengths2, const at::Tensor& idxs, const int norm, const at::Tensor& grad_dists); // CUDA implementation std::tuple KNearestNeighborBackwardCuda( const at::Tensor& p1, const at::Tensor& p2, const at::Tensor& lengths1, const at::Tensor& lengths2, const at::Tensor& idxs, const int norm, const at::Tensor& grad_dists); // Implementation which is exposed. std::tuple KNearestNeighborBackward( const at::Tensor& p1, const at::Tensor& p2, const at::Tensor& lengths1, const at::Tensor& lengths2, const at::Tensor& idxs, const int norm, const at::Tensor& grad_dists) { if (p1.is_cuda() || p2.is_cuda()) { #ifdef WITH_CUDA CHECK_CUDA(p1); CHECK_CUDA(p2); return KNearestNeighborBackwardCuda( p1, p2, lengths1, lengths2, idxs, norm, grad_dists); #else AT_ERROR("Not compiled with GPU support."); #endif } return KNearestNeighborBackwardCpu( p1, p2, lengths1, lengths2, idxs, norm, grad_dists); } // Utility to check whether a KNN version can be used. // // Args: // version: Integer in the range 0 <= version <= 3 indicating one of our // KNN implementations. // D: Number of dimensions for the input and query point clouds // K: Number of neighbors to be found // // Returns: // Whether the indicated KNN version can be used. bool KnnCheckVersion(int version, const int64_t D, const int64_t K); ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/knn_cpu.cpp ================================================ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include std::tuple KNearestNeighborIdxCpu( const at::Tensor& p1, const at::Tensor& p2, const at::Tensor& lengths1, const at::Tensor& lengths2, const int norm, const int K) { const int N = p1.size(0); const int P1 = p1.size(1); const int D = p1.size(2); auto long_opts = lengths1.options().dtype(torch::kInt64); torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts); torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options()); auto p1_a = p1.accessor(); auto p2_a = p2.accessor(); auto lengths1_a = lengths1.accessor(); auto lengths2_a = lengths2.accessor(); auto idxs_a = idxs.accessor(); auto dists_a = dists.accessor(); for (int n = 0; n < N; ++n) { const int64_t length1 = lengths1_a[n]; const int64_t length2 = lengths2_a[n]; for (int64_t i1 = 0; i1 < length1; ++i1) { // Use a priority queue to store (distance, index) tuples. 
std::priority_queue> q; for (int64_t i2 = 0; i2 < length2; ++i2) { float dist = 0; for (int d = 0; d < D; ++d) { float diff = p1_a[n][i1][d] - p2_a[n][i2][d]; if (norm == 1) { dist += abs(diff); } else { // norm is 2 (default) dist += diff * diff; } } int size = static_cast(q.size()); if (size < K || dist < std::get<0>(q.top())) { q.emplace(dist, i2); if (size >= K) { q.pop(); } } } while (!q.empty()) { auto t = q.top(); q.pop(); const int k = q.size(); dists_a[n][i1][k] = std::get<0>(t); idxs_a[n][i1][k] = std::get<1>(t); } } } return std::make_tuple(idxs, dists); } // ------------------------------------------------------------- // // Backward Operators // // ------------------------------------------------------------- // std::tuple KNearestNeighborBackwardCpu( const at::Tensor& p1, const at::Tensor& p2, const at::Tensor& lengths1, const at::Tensor& lengths2, const at::Tensor& idxs, const int norm, const at::Tensor& grad_dists) { const int N = p1.size(0); const int P1 = p1.size(1); const int D = p1.size(2); const int P2 = p2.size(1); const int K = idxs.size(2); torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options()); torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options()); auto p1_a = p1.accessor(); auto p2_a = p2.accessor(); auto lengths1_a = lengths1.accessor(); auto lengths2_a = lengths2.accessor(); auto idxs_a = idxs.accessor(); auto grad_dists_a = grad_dists.accessor(); auto grad_p1_a = grad_p1.accessor(); auto grad_p2_a = grad_p2.accessor(); for (int n = 0; n < N; ++n) { const int64_t length1 = lengths1_a[n]; int64_t length2 = lengths2_a[n]; length2 = (length2 < K) ? length2 : K; for (int64_t i1 = 0; i1 < length1; ++i1) { for (int64_t k = 0; k < length2; ++k) { const int64_t i2 = idxs_a[n][i1][k]; // If the index is the pad value of -1 then ignore it if (i2 == -1) { continue; } for (int64_t d = 0; d < D; ++d) { float diff = 0.0; if (norm == 1) { float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0; diff = grad_dists_a[n][i1][k] * sign; } else { // norm is 2 (default) diff = 2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]); } grad_p1_a[n][i1][d] += diff; grad_p2_a[n][i2][d] += -1.0f * diff; } } } } return std::make_tuple(grad_p1, grad_p2); } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/knn_ext.cpp ================================================ #include #include "knn.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #ifdef WITH_CUDA m.def("knn_check_version", &KnnCheckVersion); #endif m.def("knn_points_idx", &KNearestNeighborIdx); m.def("knn_points_backward", &KNearestNeighborBackward); } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/utils/dispatch.cuh ================================================ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ // This file provides utilities for dispatching to specialized versions of // functions. This is especially useful for CUDA kernels, since specializing // them to particular input sizes can often allow the compiler to unroll loops // and place arrays into registers, which can give huge performance speedups. 
// // As an example, suppose we have the following function which is specialized // based on a compile-time int64_t value: // // template // struct SquareOffset { // static void run(T y) { // T val = x * x + y; // std::cout << val << std::endl; // } // } // // This function takes one compile-time argument x, and one run-time argument y. // We might want to compile specialized versions of this for x=0, x=1, etc and // then dispatch to the correct one based on the runtime value of x. // One simple way to achieve this is with a lookup table: // // template // void DispatchSquareOffset(const int64_t x, T y) { // if (x == 0) { // SquareOffset::run(y); // } else if (x == 1) { // SquareOffset::run(y); // } else if (x == 2) { // SquareOffset::run(y); // } // } // // This function takes both x and y as run-time arguments, and dispatches to // different specialized versions of SquareOffset based on the run-time value // of x. This works, but it's tedious and error-prone. If we want to change the // set of x values for which we provide compile-time specializations, then we // will need to do a lot of tedius editing of the dispatch function. Also, if we // want to provide compile-time specializations for another function other than // SquareOffset, we will need to duplicate the entire lookup table. // // To solve these problems, we can use the DispatchKernel1D function provided by // this file instead: // // template // void DispatchSquareOffset(const int64_t x, T y) { // constexpr int64_t xmin = 0; // constexpr int64_t xmax = 2; // DispatchKernel1D(x, y); // } // // DispatchKernel1D uses template metaprogramming to compile specialized // versions of SquareOffset for all values of x with xmin <= x <= xmax, and // then dispatches to the correct one based on the run-time value of x. If we // want to change the range of x values for which SquareOffset is specialized // at compile-time, then all we have to do is change the values of the // compile-time constants xmin and xmax. // // This file also allows us to similarly dispatch functions that depend on two // compile-time int64_t values, using the DispatchKernel2D function like this: // // template // struct Sum { // static void run(T z, T w) { // T val = x + y + z + w; // std::cout << val << std::endl; // } // } // // template // void DispatchSum(const int64_t x, const int64_t y, int z, int w) { // constexpr int64_t xmin = 1; // constexpr int64_t xmax = 3; // constexpr int64_t ymin = 2; // constexpr int64_t ymax = 5; // DispatchKernel2D(x, y, z, w); // } // // Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to // compile specialized versions of sum for all values of (x, y) with // xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct // specialized version based on the runtime values of x and y. // Define some helper structs in an anonymous namespace. namespace { // 1D dispatch: general case. // Kernel is the function we want to dispatch to; it should take a typename and // an int64_t as template args, and it should define a static void function // run which takes any number of arguments of any type. // In order to dispatch, we will take an additional template argument curN, // and increment it via template recursion until it is equal to the run-time // argument N. template < template class Kernel, typename T, int64_t minN, int64_t maxN, int64_t curN, typename... Args> struct DispatchKernelHelper1D { static void run(const int64_t N, Args... 
args) { if (curN == N) { // The compile-time value curN is equal to the run-time value N, so we // can dispatch to the run method of the Kernel. Kernel::run(args...); } else if (curN < N) { // Increment curN via template recursion DispatchKernelHelper1D::run( N, args...); } // We shouldn't get here -- throw an error? } }; // 1D dispatch: Specialization when curN == maxN // We need this base case to avoid infinite template recursion. template < template class Kernel, typename T, int64_t minN, int64_t maxN, typename... Args> struct DispatchKernelHelper1D { static void run(const int64_t N, Args... args) { if (N == maxN) { Kernel::run(args...); } // We shouldn't get here -- throw an error? } }; // 2D dispatch, general case. // This is similar to the 1D case: we take additional template args curN and // curM, and increment them via template recursion until they are equal to // the run-time values of N and M, at which point we dispatch to the run // method of the kernel. template < template class Kernel, typename T, int64_t minN, int64_t maxN, int64_t curN, int64_t minM, int64_t maxM, int64_t curM, typename... Args> struct DispatchKernelHelper2D { static void run(const int64_t N, const int64_t M, Args... args) { if (curN == N && curM == M) { Kernel::run(args...); } else if (curN < N && curM < M) { // Increment both curN and curM. This isn't strictly necessary; we could // just increment one or the other at each step. But this helps to cut // on the number of recursive calls we make. DispatchKernelHelper2D< Kernel, T, minN, maxN, curN + 1, minM, maxM, curM + 1, Args...>::run(N, M, args...); } else if (curN < N) { // Increment curN only DispatchKernelHelper2D< Kernel, T, minN, maxN, curN + 1, minM, maxM, curM, Args...>::run(N, M, args...); } else if (curM < M) { // Increment curM only DispatchKernelHelper2D< Kernel, T, minN, maxN, curN, minM, maxM, curM + 1, Args...>::run(N, M, args...); } } }; // 2D dispatch, specialization for curN == maxN template < template class Kernel, typename T, int64_t minN, int64_t maxN, int64_t minM, int64_t maxM, int64_t curM, typename... Args> struct DispatchKernelHelper2D< Kernel, T, minN, maxN, maxN, minM, maxM, curM, Args...> { static void run(const int64_t N, const int64_t M, Args... args) { if (maxN == N && curM == M) { Kernel::run(args...); } else if (curM < maxM) { DispatchKernelHelper2D< Kernel, T, minN, maxN, maxN, minM, maxM, curM + 1, Args...>::run(N, M, args...); } // We should not get here -- throw an error? } }; // 2D dispatch, specialization for curM == maxM template < template class Kernel, typename T, int64_t minN, int64_t maxN, int64_t curN, int64_t minM, int64_t maxM, typename... Args> struct DispatchKernelHelper2D< Kernel, T, minN, maxN, curN, minM, maxM, maxM, Args...> { static void run(const int64_t N, const int64_t M, Args... args) { if (curN == N && maxM == M) { Kernel::run(args...); } else if (curN < maxN) { DispatchKernelHelper2D< Kernel, T, minN, maxN, curN + 1, minM, maxM, maxM, Args...>::run(N, M, args...); } // We should not get here -- throw an error? } }; // 2D dispatch, specialization for curN == maxN, curM == maxM template < template class Kernel, typename T, int64_t minN, int64_t maxN, int64_t minM, int64_t maxM, typename... Args> struct DispatchKernelHelper2D< Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Args...> { static void run(const int64_t N, const int64_t M, Args... args) { if (maxN == N && maxM == M) { Kernel::run(args...); } // We should not get here -- throw an error? 
} }; } // namespace // This is the function we expect users to call to dispatch to 1D functions template < template class Kernel, typename T, int64_t minN, int64_t maxN, typename... Args> void DispatchKernel1D(const int64_t N, Args... args) { if (minN <= N && N <= maxN) { // Kick off the template recursion by calling the Helper with curN = minN DispatchKernelHelper1D::run( N, args...); } // Maybe throw an error if we tried to dispatch outside the allowed range? } // This is the function we expect users to call to dispatch to 2D functions template < template class Kernel, typename T, int64_t minN, int64_t maxN, int64_t minM, int64_t maxM, typename... Args> void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) { if (minN <= N && N <= maxN && minM <= M && M <= maxM) { // Kick off the template recursion by calling the Helper with curN = minN // and curM = minM DispatchKernelHelper2D< Kernel, T, minN, maxN, minN, minM, maxM, minM, Args...>::run(N, M, args...); } // Maybe throw an error if we tried to dispatch outside the specified range? } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/utils/index_utils.cuh ================================================ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ // This converts dynamic array lookups into static array lookups, for small // arrays up to size 32. // // Suppose we have a small thread-local array: // // float vals[10]; // // Ideally we should only index this array using static indices: // // for (int i = 0; i < 10; ++i) vals[i] = i * i; // // If we do so, then the CUDA compiler may be able to place the array into // registers, which can have a big performance improvement. However if we // access the array dynamically, the the compiler may force the array into // local memory, which has the same latency as global memory. // // These functions convert dynamic array access into static array access // using a brute-force lookup table. It can be used like this: // // float vals[10]; // int idx = 3; // float val = 3.14f; // RegisterIndexUtils::set(vals, idx, val); // float val2 = RegisterIndexUtils::get(vals, idx); // // The implementation is based on fbcuda/RegisterUtils.cuh: // https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh // To avoid depending on the entire library, we just reimplement these two // functions. The fbcuda implementation is a bit more sophisticated, and uses // the preprocessor to generate switch statements that go up to N for each // value of N. We are lazy and just have a giant explicit switch statement. // // We might be able to use a template metaprogramming approach similar to // DispatchKernel1D for this. However DispatchKernel1D is intended to be used // for dispatching to the correct CUDA kernel on the host, while this is // is intended to run on the device. I was concerned that a metaprogramming // approach for this might lead to extra function calls at runtime if the // compiler fails to optimize them away, which could be very slow on device. // However I didn't actually benchmark or test this. 
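The DispatchKernel1D/2D helpers above and the RegisterIndexUtils struct that follows rely on the same idea: enumerate a small range of compile-time cases and pick one at run time. As a rough Python analogue of the 1D dispatch (illustrative only; Python has no compile-time specialization, so a prebuilt table of per-size functions stands in for the template recursion, and all names here are hypothetical):

MIN_D, MAX_D = 1, 32  # mirrors the [minN, maxN] template bounds

def make_specialization(d):
    # Stand-in for a kernel compiled with D fixed at compile time,
    # e.g. KNearestNeighborKernelV1 specialized for this d.
    def run(*args):
        print(f"running the specialization built for D={d}")
    return run

SPECIALIZATIONS = {d: make_specialization(d) for d in range(MIN_D, MAX_D + 1)}

def dispatch_kernel_1d(D, *args):
    # Out-of-range D falls through silently, like the C++ helper.
    if MIN_D <= D <= MAX_D:
        SPECIALIZATIONS[D](*args)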
template struct RegisterIndexUtils { __device__ __forceinline__ static T get(const T arr[N], int idx) { if (idx < 0 || idx >= N) return T(); switch (idx) { case 0: return arr[0]; case 1: return arr[1]; case 2: return arr[2]; case 3: return arr[3]; case 4: return arr[4]; case 5: return arr[5]; case 6: return arr[6]; case 7: return arr[7]; case 8: return arr[8]; case 9: return arr[9]; case 10: return arr[10]; case 11: return arr[11]; case 12: return arr[12]; case 13: return arr[13]; case 14: return arr[14]; case 15: return arr[15]; case 16: return arr[16]; case 17: return arr[17]; case 18: return arr[18]; case 19: return arr[19]; case 20: return arr[20]; case 21: return arr[21]; case 22: return arr[22]; case 23: return arr[23]; case 24: return arr[24]; case 25: return arr[25]; case 26: return arr[26]; case 27: return arr[27]; case 28: return arr[28]; case 29: return arr[29]; case 30: return arr[30]; case 31: return arr[31]; }; return T(); } __device__ __forceinline__ static void set(T arr[N], int idx, T val) { if (idx < 0 || idx >= N) return; switch (idx) { case 0: arr[0] = val; break; case 1: arr[1] = val; break; case 2: arr[2] = val; break; case 3: arr[3] = val; break; case 4: arr[4] = val; break; case 5: arr[5] = val; break; case 6: arr[6] = val; break; case 7: arr[7] = val; break; case 8: arr[8] = val; break; case 9: arr[9] = val; break; case 10: arr[10] = val; break; case 11: arr[11] = val; break; case 12: arr[12] = val; break; case 13: arr[13] = val; break; case 14: arr[14] = val; break; case 15: arr[15] = val; break; case 16: arr[16] = val; break; case 17: arr[17] = val; break; case 18: arr[18] = val; break; case 19: arr[19] = val; break; case 20: arr[20] = val; break; case 21: arr[21] = val; break; case 22: arr[22] = val; break; case 23: arr[23] = val; break; case 24: arr[24] = val; break; case 25: arr[25] = val; break; case 26: arr[26] = val; break; case 27: arr[27] = val; break; case 28: arr[28] = val; break; case 29: arr[29] = val; break; case 30: arr[30] = val; break; case 31: arr[31] = val; break; } } }; ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/utils/mink.cuh ================================================ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #define MINK_H #include "index_utils.cuh" // A data structure to keep track of the smallest K keys seen so far as well // as their associated values, intended to be used in device code. // This data structure doesn't allocate any memory; keys and values are stored // in arrays passed to the constructor. // // The implementation is generic; it can be used for any key type that supports // the < operator, and can be used with any value type. // // Example usage: // // float keys[K]; // int values[K]; // MinK mink(keys, values, K); // for (...) { // // Produce some key and value from somewhere // mink.add(key, value); // } // mink.sort(); // // Now keys and values store the smallest K keys seen so far and the values // associated to these keys: // // for (int k = 0; k < K; ++k) { // float key_k = keys[k]; // int value_k = values[k]; // } template class MinK { public: // Constructor. 
// // Arguments: // keys: Array in which to store keys // values: Array in which to store values // K: How many values to keep track of __device__ MinK(key_t* keys, value_t* vals, int K) : keys(keys), vals(vals), K(K), _size(0) {} // Try to add a new key and associated value to the data structure. If the key // is one of the smallest K seen so far then it will be kept; otherwise it // it will not be kept. // // This takes O(1) operations if the new key is not kept, or if the structure // currently contains fewer than K elements. Otherwise this takes O(K) time. // // Arguments: // key: The key to add // val: The value associated to the key __device__ __forceinline__ void add(const key_t& key, const value_t& val) { if (_size < K) { keys[_size] = key; vals[_size] = val; if (_size == 0 || key > max_key) { max_key = key; max_idx = _size; } _size++; } else if (key < max_key) { keys[max_idx] = key; vals[max_idx] = val; max_key = key; for (int k = 0; k < K; ++k) { key_t cur_key = keys[k]; if (cur_key > max_key) { max_key = cur_key; max_idx = k; } } } } // Get the number of items currently stored in the structure. // This takes O(1) time. __device__ __forceinline__ int size() { return _size; } // Sort the items stored in the structure using bubble sort. // This takes O(K^2) time. __device__ __forceinline__ void sort() { for (int i = 0; i < _size - 1; ++i) { for (int j = 0; j < _size - i - 1; ++j) { if (keys[j + 1] < keys[j]) { key_t key = keys[j]; value_t val = vals[j]; keys[j] = keys[j + 1]; vals[j] = vals[j + 1]; keys[j + 1] = key; vals[j + 1] = val; } } } } private: key_t* keys; value_t* vals; int K; int _size; key_t max_key; int max_idx; }; // This is a version of MinK that only touches the arrays using static indexing // via RegisterIndexUtils. If the keys and values are stored in thread-local // arrays, then this may allow the compiler to place them in registers for // fast access. // // This has the same API as RegisterMinK, but doesn't support sorting. // We found that sorting via RegisterIndexUtils gave very poor performance, // and suspect it may have prevented the compiler from placing the arrays // into registers. template class RegisterMinK { public: __device__ RegisterMinK(key_t* keys, value_t* vals) : keys(keys), vals(vals), _size(0) {} __device__ __forceinline__ void add(const key_t& key, const value_t& val) { if (_size < K) { RegisterIndexUtils::set(keys, _size, key); RegisterIndexUtils::set(vals, _size, val); if (_size == 0 || key > max_key) { max_key = key; max_idx = _size; } _size++; } else if (key < max_key) { RegisterIndexUtils::set(keys, max_idx, key); RegisterIndexUtils::set(vals, max_idx, val); max_key = key; for (int k = 0; k < K; ++k) { key_t cur_key = RegisterIndexUtils::get(keys, k); if (cur_key > max_key) { max_key = cur_key; max_idx = k; } } } } __device__ __forceinline__ int size() { return _size; } private: key_t* keys; value_t* vals; int _size; key_t max_key; int max_idx; }; ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/utils/pytorch3d_cutils.h ================================================ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
*/ #pragma once #include #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.") #define CHECK_CONTIGUOUS(x) \ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.") #define CHECK_CONTIGUOUS_CUDA(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/__init__.py ================================================ from .arel import ARel from .confidence import Confidence from .distill import SelfDistill, TeacherDistill from .dummy import Dummy from .local_ssi import EdgeGuidedLocalSSI, LocalSSI from .regression import Regression from .silog import SILog ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/arel.py ================================================ import torch import torch.nn as nn from .utils import FNS, masked_mean class ARel(nn.Module): def __init__( self, weight: float, output_fn: str = "sqrt", input_fn: str = "linear", eps: float = 1e-5, ): super().__init__() self.name: str = self.__class__.__name__ self.weight: float = weight self.dims = [-2, -1] self.output_fn = FNS[output_fn] self.input_fn = FNS[input_fn] self.eps: float = eps @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def forward( self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, **kwargs ) -> torch.Tensor: mask = mask.bool().clone() input = self.input_fn(input.float()) target = self.input_fn(target.float()) error = (input - target).norm(dim=1) / target.norm(dim=1).clip(min=0.05) mask = mask.squeeze(1) error_image = masked_mean(data=error, mask=mask, dim=self.dims).squeeze(1, 2) error_image = self.output_fn(error_image) return error_image @classmethod def build(cls, config): obj = cls( weight=config["weight"], output_fn=config["output_fn"], input_fn=config["input_fn"], ) return obj ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/confidence.py ================================================ import torch import torch.nn as nn from .utils import FNS, masked_mean class Confidence(nn.Module): def __init__( self, weight: float, output_fn: str = "sqrt", input_fn: str = "linear", rescale: bool = True, eps: float = 1e-5, ): super(Confidence, self).__init__() self.name: str = self.__class__.__name__ self.weight = weight self.rescale = rescale self.eps = eps self.output_fn = FNS[output_fn] self.input_fn = FNS[input_fn] @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def forward( self, input: torch.Tensor, target_pred: torch.Tensor, target_gt: torch.Tensor, mask: torch.Tensor, ): B, C = target_gt.shape[:2] mask = mask.bool() target_gt = target_gt.float().reshape(B, C, -1) target_pred = target_pred.float().reshape(B, C, -1) input = input.float().reshape(B, -1) mask = mask.reshape(B, -1) if self.rescale: target_pred = torch.stack( [ p * torch.median(gt[:, m]) / torch.median(p[:, m]) for p, gt, m in zip(target_pred, target_gt, mask) ] ) error = torch.abs( (self.input_fn(target_pred) - self.input_fn(target_gt)).norm(dim=1) - input ) losses = masked_mean(error, dim=[-1], mask=mask).squeeze(dim=-1) losses = self.output_fn(losses) return losses @classmethod def build(cls, config): obj = cls( weight=config["weight"], output_fn=config["output_fn"], input_fn=config["input_fn"], rescale=config.get("rescale", True), ) return obj ================================================ FILE: 
camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/distill.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from .utils import FNS, masked_mean class SelfDistill(nn.Module): def __init__(self, weight: float, output_fn: str = "sqrt", eps: float = 1e-5): super().__init__() self.name: str = self.__class__.__name__ self.weight: float = weight self.dims = (-2, -1) self.output_fn = FNS[output_fn] self.eps: float = eps @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def forward( self, input: torch.Tensor, intrinsics: torch.Tensor, mask: torch.Tensor, flips: torch.Tensor, downsample_ratio=14, ) -> torch.Tensor: chunks = input.shape[0] // 2 mask = F.interpolate(mask.float(), size=input.shape[-2:], mode="nearest") iters = zip( input.chunk(chunks), mask.chunk(chunks), intrinsics.chunk(chunks), flips.chunk(chunks), ) inputs0, inputs1, masks = [], [], [] for i, (pair_input, pair_mask, pair_cam, pair_flip) in enumerate(iters): mask0, mask1 = pair_mask input0, input1 = pair_input cam0, cam1 = pair_cam flip0, flip1 = pair_flip fx_0 = cam0[0, 0] / downsample_ratio fx_1 = cam1[0, 0] / downsample_ratio cx_0 = cam0[0, 2] / downsample_ratio cx_1 = cam1[0, 2] / downsample_ratio cy_0 = cam0[1, 2] / downsample_ratio cy_1 = cam1[1, 2] / downsample_ratio # flip image if flip0 ^ flip1: input0 = torch.flip(input0, dims=(2,)) mask0 = torch.flip(mask0, dims=(2,)) cx_0 = input0.shape[-1] - cx_0 # calc zoom zoom_x = float(fx_1 / fx_0) # apply zoom input0 = F.interpolate( input0.unsqueeze(0), scale_factor=zoom_x, mode="bilinear" ).squeeze(0) mask0 = F.interpolate( mask0.unsqueeze(0), scale_factor=zoom_x, mode="nearest" ).squeeze(0) # calc translation change_left = int(cx_1 - (cx_0 - 0.5) * zoom_x - 0.5) change_top = int(cy_1 - (cy_0 - 0.5) * zoom_x - 0.5) change_right = input1.shape[-1] - change_left - input0.shape[-1] change_bottom = input1.shape[-2] - change_top - input0.shape[-2] # apply translation pad_left = max(0, change_left) pad_right = max(0, change_right) pad_top = max(0, change_top) pad_bottom = max(0, change_bottom) crop_left = max(0, -change_left) crop_right = max(0, -change_right) crop_top = max(0, -change_top) crop_bottom = max(0, -change_bottom) input0 = F.pad( input0, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=0, ) mask0 = F.pad( mask0, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=0, ) input0 = input0[ :, crop_top : input0.shape[-2] - crop_bottom, crop_left : input0.shape[-1] - crop_right, ] mask0 = mask0[ :, crop_top : mask0.shape[-2] - crop_bottom, crop_left : mask0.shape[-1] - crop_right, ] mask = torch.logical_and(mask0, mask1) inputs0.append(input0) inputs1.append(input1) masks.append(mask) inputs0 = torch.stack(inputs0, dim=0) inputs1 = torch.stack(inputs1, dim=0) masks = torch.stack(masks, dim=0) loss1 = self.loss(inputs0, inputs1.detach(), masks) loss2 = self.loss(inputs1, inputs0.detach(), masks) return torch.cat([loss1, loss2], dim=0) def loss( self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, ) -> torch.Tensor: loss = masked_mean( (input - target).square().mean(dim=1), mask=mask, dim=[-2, -1] ) return self.output_fn(loss + self.eps) @classmethod def build(cls, config): obj = cls( weight=config["weight"], output_fn=config["output_fn"], ) return obj class TeacherDistill(nn.Module): def __init__( self, weight: float, output_fn: str = "sqrt", cross: bool = False, eps: float = 1e-5, ): 
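        # Descriptive note: this loss matches the student's dense feature maps
        # and tokens to the teacher's. Channels are split into 64-dim chunks
        # (self.head_dim, hardcoded for ViT) before per-chunk L2 distances are
        # averaged, and the token term is weighted by 0.01 in forward().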
super().__init__() assert output_fn in FNS self.name: str = self.__class__.__name__ self.weight: float = weight self.dims = (-2, -1) self.output_fn = FNS[output_fn] self.eps: float = eps self.cross = cross self.threshold = 0.05 self.head_dim = 64 # hardcoded for vit @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def forward( self, student_features: torch.Tensor, teacher_features: torch.Tensor, student_tokens: torch.Tensor, teacher_tokens: torch.Tensor, mask: torch.Tensor, # metas: List[Dict[str, torch.Tensor]], ) -> torch.Tensor: B = student_features.shape[0] device = student_features.device chunks = student_features.shape[0] // 2 mask = ( F.interpolate( mask.float() + 1e-3, size=student_features.shape[-2:], mode="nearest" ) > 0.5 ) # chunk features as self.head_dim student_features = rearrange( student_features, "b (n c) h w -> b c h w n", c=self.head_dim ) teacher_features = rearrange( teacher_features, "b (n c) h w -> b c h w n", c=self.head_dim ) student_tokens = rearrange( student_tokens, "b t (n c) -> b t c n", c=self.head_dim ) teacher_tokens = rearrange( teacher_tokens, "b t (n c) -> b t c n", c=self.head_dim ) distance = ( (student_features - teacher_features) .square() .sum(dim=1, keepdim=True) .sqrt() .mean(dim=-1) ) loss_features = masked_mean(distance, mask=mask, dim=[-2, -1]) loss_features = self.output_fn(loss_features.clamp(min=self.eps)).squeeze( 1, 2, 3 ) distance = ( (student_tokens - teacher_tokens).square().sum(dim=-2).sqrt().mean(dim=-1) ) loss_tokens = self.output_fn(distance.clamp(min=self.eps)).squeeze(1) return loss_features + 0.01 * loss_tokens @classmethod def build(cls, config): obj = cls( weight=config["weight"], output_fn=config["output_fn"], cross=config["cross"], ) return obj ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/dummy.py ================================================ import torch import torch.nn as nn class Dummy(nn.Module): def __init__(self, *args, **kwargs): super().__init__() self.name: str = self.__class__.__name__ self.weight = 1.0 def forward(self, dummy: torch.Tensor, *args, **kwargs) -> torch.Tensor: return torch.tensor([0.0] * dummy.shape[0], device=dummy.device) @classmethod def build(cls, config): obj = cls() return obj ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/local_ssi.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from unidepth.utils.geometric import erode from .utils import FNS, ind2sub, masked_mean, masked_quantile, ssi def sample_strong_edges(edges_img, quantile=0.95, reshape=8): # flat edges_img = F.interpolate( edges_img, scale_factor=1 / reshape, mode="bilinear", align_corners=False ) edges_img_flat = edges_img.flatten(1) # Find strong edges edges_mask = edges_img_flat > torch.quantile( edges_img_flat, quantile, dim=-1, keepdim=True ) num_samples = edges_mask.sum(dim=-1) if (num_samples < 10).any(): # sample random edges where num_samples < 2 random = torch.rand_like(edges_img_flat[num_samples < 10, :]) > quantile edges_mask[num_samples < 10, :] = torch.logical_or( edges_mask[num_samples < 10, :], random ) num_samples = edges_mask.sum(dim=-1) min_samples = num_samples.min() # Compute the coordinates of the strong edges as B, N, 2 edges_coords = torch.stack( [torch.nonzero(x, as_tuple=False)[:min_samples].squeeze() for x in edges_mask] ) edges_coords = ( 
torch.stack(ind2sub(edges_coords, edges_img.shape[-1]), dim=-1) * reshape ) return edges_coords @torch.jit.script def extract_patches(tensor, sample_coords, patch_size: tuple[int, int] = (32, 32)): N, _, H, W = tensor.shape device = tensor.device dtype = tensor.dtype patch_width, patch_height = patch_size pad_width = patch_width // 2 pad_height = patch_height // 2 # Pad the RGB images for both sheep tensor_padded = F.pad( tensor, (pad_width, pad_width, pad_height, pad_height), mode="constant", value=0.0, ) # Adjust edge coordinates to account for padding sample_coords_padded = sample_coords + torch.tensor( [pad_height, pad_width], dtype=dtype, device=device ).reshape(1, 1, 2) # Calculate the indices for gather operation x_centers = sample_coords_padded[:, :, 1].int() y_centers = sample_coords_padded[:, :, 0].int() all_patches = [] for tensor_i, x_centers_i, y_centers_i in zip(tensor_padded, x_centers, y_centers): patches = [] for x_center, y_center in zip(x_centers_i, y_centers_i): y_start, y_end = y_center - pad_height, y_center + pad_height + 1 x_start, x_end = x_center - pad_width, x_center + pad_width + 1 patches.append(tensor_i[..., y_start:y_end, x_start:x_end]) all_patches.append(torch.stack(patches, dim=0)) return torch.stack(all_patches, dim=0).reshape(N, -1, patch_height * patch_width) class LocalSSI(nn.Module): def __init__( self, weight: float, output_fn: str = "sqrt", patch_size: tuple[int, int] = (32, 32), min_samples: int = 4, num_levels: int = 4, input_fn: str = "linear", eps: float = 1e-5, ): super(LocalSSI, self).__init__() self.name: str = self.__class__.__name__ self.weight = weight self.output_fn = FNS[output_fn] self.input_fn = FNS[input_fn] self.min_samples = min_samples self.eps = eps patch_logrange = np.linspace( start=np.log2(min(patch_size)), stop=np.log2(max(patch_size)), endpoint=True, num=num_levels + 1, ) self.patch_logrange = [ (x, y) for x, y in zip(patch_logrange[:-1], patch_logrange[1:]) ] self.rescale_fn = ssi @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def forward( self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, *args, **kwargs, ) -> torch.Tensor: mask = mask.bool() input = self.input_fn(input.float()) target = self.input_fn(target.float()) B, C, H, W = input.shape total_errors = [] for ii, patch_logrange in enumerate(self.patch_logrange): log_kernel = ( np.random.uniform(*patch_logrange) if self.training else np.mean(patch_logrange) ) kernel_size = int( (2**log_kernel) * min(input.shape[-2:]) ) # always smaller than min_shape kernel_size = (kernel_size, kernel_size) stride = (int(kernel_size[0] * 0.9), int(kernel_size[1] * 0.9)) # unfold is always exceeding right/bottom, roll image only negative # to have them back in the unfolding window max_roll = ( (W - kernel_size[1]) % stride[1], (H - kernel_size[0]) % stride[0], ) roll_x, roll_y = np.random.randint(-max_roll[0], 1), np.random.randint( -max_roll[1], 1 ) input_fold = torch.roll(input, shifts=(roll_y, roll_x), dims=(2, 3)) target_fold = torch.roll(target, shifts=(roll_y, roll_x), dims=(2, 3)) mask_fold = torch.roll(mask.float(), shifts=(roll_y, roll_x), dims=(2, 3)) # unfold in patches input_fold = F.unfold( input_fold, kernel_size=kernel_size, stride=stride ).permute( 0, 2, 1 ) # B N C*H_p*W_p target_fold = F.unfold( target_fold, kernel_size=kernel_size, stride=stride ).permute(0, 2, 1) mask_fold = ( F.unfold(mask_fold, kernel_size=kernel_size, stride=stride) .bool() .permute(0, 2, 1) ) # calculate error patchwise, then mean over patch, then over 
image based if sample size is significant input_fold, target_fold, _ = self.rescale_fn( input_fold, target_fold, mask_fold, dim=[-1] ) error = (input_fold - target_fold).abs() # calculate elements more then 95 percentile and lower than 5percentile of error valid_patches = mask_fold.sum(dim=-1) >= self.min_samples error_mean_patch = masked_mean(error, mask_fold, dim=[-1]).squeeze(-1) error_mean_image = self.output_fn(error_mean_patch.clamp(min=self.eps)) error_mean_image = masked_mean( error_mean_image, mask=valid_patches, dim=[-1] ) total_errors.append(error_mean_image.squeeze(-1)) # global input_rescale = input.reshape(B, C, -1) target_rescale = target.reshape(B, C, -1) mask = mask.reshape(B, 1, -1).clone() input, target, mask = self.rescale_fn( input_rescale, target_rescale, mask, dim=[-1] ) error = (input - target).abs().squeeze(1) mask = mask.squeeze(1) error_mean_image = masked_mean(error, mask, dim=[-1]).squeeze(-1) error_mean_image = self.output_fn(error_mean_image.clamp(min=self.eps)) total_errors.append(error_mean_image) errors = torch.stack(total_errors).mean(dim=0) return errors @classmethod def build(cls, config): obj = cls( weight=config["weight"], patch_size=config["patch_size"], output_fn=config["output_fn"], min_samples=config["min_samples"], num_levels=config["num_levels"], input_fn=config["input_fn"], ) return obj class EdgeGuidedLocalSSI(nn.Module): def __init__( self, weight: float, output_fn: str = "sqrt", min_samples: int = 4, input_fn: str = "linear", use_global: bool = True, eps: float = 1e-5, ): super(EdgeGuidedLocalSSI, self).__init__() self.name: str = self.__class__.__name__ self.weight = weight self.output_fn = FNS[output_fn] self.input_fn = FNS[input_fn] self.min_samples = min_samples self.eps = eps self.use_global = use_global self.rescale_fn = ssi delta_x = torch.tensor( [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], requires_grad=False ) delta_y = torch.tensor( [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]], requires_grad=False ) self.delta_x = delta_x.reshape(1, 1, 3, 3) self.delta_y = delta_y.reshape(1, 1, 3, 3) try: from unidepth.ops.extract_patches import RandomPatchExtractor self.random_patch_extractor = RandomPatchExtractor() except Exception as e: self.random_patch_extractor = extract_patches print( "EdgeGuidedLocalSSI reverts to a non cuda-optimized operation, " "you will experince large slowdown, " "please install it: ", "`cd ./unidepth/ops/extract_patches && bash compile.sh`", ) def get_edge(self, image, mask): channels = image.shape[1] device = image.device delta_x = self.delta_x.to(device).repeat(channels, 1, 1, 1) delta_y = self.delta_y.to(device).repeat(channels, 1, 1, 1) image_Gx = F.conv2d(image, delta_x, groups=channels, padding="same") / 8 image_Gy = F.conv2d(image, delta_y, groups=channels, padding="same") / 8 image_Gx = ( image_Gx.square().mean(dim=1, keepdim=True).sqrt() ) # RMSE over color dim image_Gy = image_Gy.square().mean(dim=1, keepdim=True).sqrt() edges = torch.sqrt(image_Gx**2 + image_Gy**2) edges[:, :, :3, :] = 0 edges[:, :, -3:, :] = 0 edges[:, :, :, :3] = 0 edges[:, :, :, -3:] = 0 edges[~mask.bool()] = 0 return edges def compute_sample_patch_error( self, input, target, mask, sampling_coords, kernel_size, image_size ): B, C, H, W = input.shape patch_size = kernel_size[0] * kernel_size[1] input = self.random_patch_extractor( input, sampling_coords, kernel_size ).reshape(B, -1, patch_size) target = self.random_patch_extractor( target, sampling_coords, kernel_size ).reshape(B, -1, patch_size) mask = ( 
self.random_patch_extractor(mask.float(), sampling_coords, kernel_size) .bool() .reshape(B, -1, patch_size) ) input, target, mask = self.rescale_fn(input, target, mask, dim=[-1]) error = (input - target).abs().clamp(min=self.eps) valid_patches = mask.sum(dim=-1) >= self.min_samples error_mean_patch = masked_mean(error, mask, dim=[-1]).squeeze(-1) error_mean_image = self.output_fn(error_mean_patch.clamp(min=self.eps)) error_mean_image = masked_mean(error_mean_image, mask=valid_patches, dim=[-1]) return error_mean_image def compute_image_error(self, input, target, mask, image_size): H, W = image_size input = input.reshape(-1, 1, H * W) target = target.reshape(-1, 1, H * W) mask = mask.reshape(-1, 1, H * W) input, target, mask = self.rescale_fn(input, target, mask, dim=[-1]) error = (input - target).abs().clamp(min=self.eps) error_mean_image = masked_mean(error, mask, dim=[-1]).squeeze(-1) error_mean_image = self.output_fn(error_mean_image.clamp(min=self.eps)) return error_mean_image @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def forward( self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, image: torch.Tensor | None = None, validity_mask: torch.Tensor | None = None, *args, **kwargs, ) -> torch.Tensor: mask = mask.bool() input = self.input_fn(input.float()) target = self.input_fn(target.float()) B, _, H, W = input.shape total_errors = [] # remove border and black border if validity_mask is not None: validity_mask = erode(validity_mask.float(), kernel_size=3) edges = self.get_edge(image, validity_mask) # quantile was 0.95? edges_coords = sample_strong_edges(edges, quantile=0.9, reshape=14) log_kernel = np.random.uniform(0.04, 0.08) if self.training else 0.05 kernel_size = int( log_kernel * min(input.shape[-2:]) ) # always smaller than min_shape kernel_size = kernel_size + int(kernel_size % 2 == 0) # odd num kernel_size = (kernel_size, kernel_size) error_mean_image = self.compute_sample_patch_error( input, target, mask, edges_coords, kernel_size, (H, W) ) total_errors.append(error_mean_image.squeeze(-1)) if self.use_global: error_mean_image = self.compute_image_error(input, target, mask, (H, W)) total_errors.append(error_mean_image.squeeze(-1)) errors = torch.stack(total_errors).mean(dim=0) return errors @classmethod def build(cls, config): obj = cls( weight=config["weight"], output_fn=config["output_fn"], input_fn=config["input_fn"], use_global=config["use_global"], min_samples=config.get("min_samples", 6), ) return obj ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/regression.py ================================================ import torch import torch.nn as nn from .utils import FNS, REGRESSION_DICT, masked_mean, masked_quantile class Regression(nn.Module): def __init__( self, weight: float, input_fn: str, output_fn: str, alpha: float, gamma: float, fn: str, dims: list[int] = [-1], quantile: float = 0.0, **kwargs, ): super().__init__() self.name = self.__class__.__name__ self.output_fn = FNS[output_fn] self.input_fn = FNS[input_fn] self.weight = weight self.dims = dims self.quantile = quantile self.alpha = alpha self.gamma = gamma self.fn = REGRESSION_DICT[fn] @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def forward( self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: if mask is not None: # usually it is just repeated mask = mask[:, 0] input = self.input_fn(input.float()) target = 
self.input_fn(target.float()) error = self.fn(input - target, gamma=self.gamma, alpha=self.alpha).mean(dim=1) mean_error = masked_mean(data=error, mask=mask, dim=self.dims).squeeze( self.dims ) mean_error = self.output_fn(mean_error) return mean_error @classmethod def build(cls, config): obj = cls( weight=config["weight"], output_fn=config["output_fn"], input_fn=config["input_fn"], dims=config.get("dims", (-1,)), alpha=config["alpha"], gamma=config["gamma"], fn=config["fn"], ) return obj ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/silog.py ================================================ import torch import torch.nn as nn from .utils import (FNS, REGRESSION_DICT, masked_mean, masked_mean_var, masked_quantile) class SILog(nn.Module): def __init__( self, weight: float, input_fn: str = "linear", output_fn: str = "sqrt", integrated: float = 0.15, dims: list[int] = [-3, -2, -1], eps: float = 1e-5, ): super().__init__() self.name: str = self.__class__.__name__ self.weight: float = weight self.dims = dims self.input_fn = FNS[input_fn] self.output_fn = FNS[output_fn] self.eps: float = eps self.integrated = integrated @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def forward( self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, si: torch.Tensor, **kwargs, ) -> torch.Tensor: mask = mask.bool() error = self.input_fn(input.float()) - self.input_fn(target.float()) mean_error, var_error = masked_mean_var( data=error, mask=mask, dim=self.dims, keepdim=False ) if var_error.ndim > 1: var_error = var_error.mean(dim=-1) if self.integrated > 0.0: scale_error = mean_error**2 var_error = var_error + self.integrated * scale_error * (1 - si.int()) out_loss = self.output_fn(var_error) return out_loss @classmethod def build(cls, config): obj = cls( weight=config["weight"], dims=config["dims"], output_fn=config["output_fn"], input_fn=config["input_fn"], integrated=config.get("integrated", 0.15), ) return obj ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/utils.py ================================================ from math import prod from typing import Any, Dict, List, Optional, Tuple import torch FNS = { "sqrt": lambda x: torch.sqrt(x + 1e-4), "log": lambda x: torch.log(x + 1e-4), "log1": lambda x: torch.log(x + 1), # if x -> 0 : log(1/x) # if x -> inf : log(1+1/x) -> 1/x + hot "log1i": lambda x: torch.log(1 + 50 / (1e-4 + x)), "linear": lambda x: x, "square": torch.square, "disp": lambda x: 1 / (x + 1e-4), "disp1": lambda x: 1 / (1 + x), } FNS_INV = { "sqrt": torch.square, "log": torch.exp, "log1": lambda x: torch.exp(x) - 1, "linear": lambda x: x, "square": torch.sqrt, "disp": lambda x: 1 / x, } def masked_mean_var( data: torch.Tensor, mask: torch.Tensor, dim: List[int], keepdim: bool = True ): if mask is None: return data.mean(dim=dim, keepdim=keepdim), data.var(dim=dim, keepdim=keepdim) mask = mask.float() mask_sum = torch.sum(mask, dim=dim, keepdim=True) # data = torch.nan_to_num(data, nan=0.0) mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( mask_sum, min=1.0 ) mask_var = torch.sum( mask * (data - mask_mean) ** 2, dim=dim, keepdim=True ) / torch.clamp(mask_sum, min=1.0) if not keepdim: mask_mean, mask_var = mask_mean.squeeze(dim), mask_var.squeeze(dim) return mask_mean, mask_var def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]): if mask is None: return 
data.mean(dim=dim, keepdim=True) mask = mask.float() mask_sum = torch.sum(mask, dim=dim, keepdim=True) mask_mean = torch.sum( torch.nan_to_num(data, nan=0.0) * mask, dim=dim, keepdim=True ) / mask_sum.clamp(min=1.0) return mask_mean def masked_quantile( data: torch.Tensor, mask: torch.Tensor | None, dims: List[int], q: float ): """ Compute the quantile of the data only where the mask is 1 along specified dimensions. Args: data (torch.Tensor): The input data tensor. mask (torch.Tensor): The mask tensor with the same shape as data, containing 1s where data should be considered. dims (list of int): The dimensions to compute the quantile over. q (float): The quantile to compute, must be between 0 and 1. Returns: torch.Tensor: The quantile computed over the specified dimensions, ignoring masked values. """ masked_data = data * mask if mask is not None else data # Get a list of all dimensions all_dims = list(range(masked_data.dim())) # Revert negative dimensions dims = [d % masked_data.dim() for d in dims] # Find the dimensions to keep (not included in the `dims` list) keep_dims = [d for d in all_dims if d not in dims] # Permute dimensions to bring `dims` to the front permute_order = dims + keep_dims permuted_data = masked_data.permute(permute_order) # Reshape into 2D: (-1, remaining_dims) collapsed_shape = ( -1, prod([permuted_data.size(d) for d in range(len(dims), permuted_data.dim())]), ) reshaped_data = permuted_data.reshape(collapsed_shape) if mask is None: return torch.quantile(reshaped_data, q, dim=0) permuted_mask = mask.permute(permute_order) reshaped_mask = permuted_mask.reshape(collapsed_shape) # Calculate quantile along the first dimension where mask is true quantiles = [] for i in range(reshaped_data.shape[1]): valid_data = reshaped_data[:, i][reshaped_mask[:, i]] if valid_data.numel() == 0: # print("Warning: No valid data found for quantile calculation.") quantiles.append(reshaped_data[:, i].min() * 0.99) else: quantiles.append(torch.quantile(valid_data, q, dim=0)) # Stack back into a tensor with reduced dimensions quantiles = torch.stack(quantiles) quantiles = quantiles.reshape( [permuted_data.size(d) for d in range(len(dims), permuted_data.dim())] ) return quantiles def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): ndim = data.ndim data = data.flatten(ndim - len(dim)) mask = mask.flatten(ndim - len(dim)) mask_median = torch.median(data[..., mask], dim=-1).values return mask_median def masked_median_mad(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): ndim = data.ndim data = data.flatten(ndim - len(dim)) mask = mask.flatten(ndim - len(dim)) mask_median = torch.median(data[mask], dim=-1, keepdim=True).values mask_mad = masked_mean((data - mask_median).abs(), mask, dim=[-1]) return mask_median, mask_mad def masked_weighted_mean_var( data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...] 
): if mask is None: return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) mask = mask.float() mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum( mask * weights, dim=dim, keepdim=True ).clamp(min=1.0) # V1**2 - V2, V1: sum w_i, V2: sum w_i**2 denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum( (mask * weights).square(), dim=dim, keepdim=True ) # correction is V1 / (V1**2 - V2), if w_i=1 => N/(N**2 - N) => 1/(N-1) (unbiased estimator of variance, cvd) correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp( min=1.0 ) mask_var = correction_factor * torch.sum( weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True ) return mask_mean, mask_var def ssi( input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int] ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # recalculate mask with points in 95% confidence interval # the statistics are calculated on the stable points and # are similar ot median/MAD, but median/MAD gradients # are really weird, so this is a workaround input_detach = input.detach() input_mean, input_var = masked_mean_var(input_detach, mask=mask, dim=dim) target_mean, target_var = masked_mean_var(target, mask=mask, dim=dim) input_std = (input_var).clip(min=1e-6).sqrt() target_std = (target_var).clip(min=1e-6).sqrt() stable_points_input = torch.logical_and( input_detach > input_mean - 1.96 * input_std, input_detach < input_mean + 1.96 * input_std, ) stable_points_target = torch.logical_and( target > target_mean - 1.96 * target_std, target < target_mean + 1.96 * target_std, ) stable_mask = stable_points_target & stable_points_input & mask input_mean, input_var = masked_mean_var(input, mask=stable_mask, dim=dim) target_mean, target_var = masked_mean_var(target, mask=stable_mask, dim=dim) target_normalized = (target - target_mean) / FNS["sqrt"](target_var) input_normalized = (input - input_mean) / FNS["sqrt"](input_var) return input_normalized, target_normalized, stable_mask def ind2sub(idx, cols): r = idx // cols c = idx % cols return r, c def sub2ind(r, c, cols): idx = r * cols + c return idx def l2(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor: return gamma * (input_tensor / gamma) ** 2 def l1(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor: return torch.abs(input_tensor) def charbonnier( input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs ) -> torch.Tensor: return torch.sqrt(torch.square(input_tensor) + gamma**2) - gamma def cauchy( input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs ) -> torch.Tensor: return gamma * torch.log(torch.square(input_tensor) / gamma + 1) def geman_mcclure( input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs ) -> torch.Tensor: return gamma * torch.square(input_tensor) / (torch.square(input_tensor) + gamma) def robust_loss( input_tensor: torch.Tensor, alpha: float, gamma: float = 1.0, *args, **kwargs ) -> torch.Tensor: coeff = abs(alpha - 2) / alpha power = torch.square(input_tensor) / abs(alpha - 2) / (gamma**2) + 1 return ( gamma * coeff * (torch.pow(power, alpha / 2) - 1) ) # mult gamma to keep grad magnitude invariant wrt gamma REGRESSION_DICT = { "l2": l2, "l1": l1, "cauchy": cauchy, "charbonnier": charbonnier, "geman_mcclure": geman_mcclure, "robust_loss": robust_loss, } ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/scheduler.py 
================================================ import weakref import numpy as np class PlainCosineScheduler(object): def __init__( self, klass, key, warmup_iters, total_iters, overwrite=False, init_value=None, base_value=None, final_value=None, step_init=-1, ): super().__init__() self.iter = step_init self.overwrite = overwrite self.base_value = base_value self.init_value = init_value if init_value is not None else base_value self.final_value = final_value self.total_iters = total_iters self.warmup_iters = warmup_iters self.key = key self.klass = klass self.schedulers = [self.get_scheduler()] def get_scheduler(self): init_value = self.init_value base_value = self.base_value final_value = self.final_value warmup_iters = self.warmup_iters total_iters = self.total_iters # normalize in 0,1, then apply function (power) and denormalize normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True) normalized_schedule = np.power(normalized_schedule, 1) warmup_schedule = (base_value - init_value) * normalized_schedule + init_value # main scheduling iters = np.arange(total_iters - warmup_iters + 1) schedule = final_value + 0.5 * (base_value - final_value) * ( 1 + np.cos(np.pi * iters / (len(iters) - 1)) ) return np.concatenate((warmup_schedule, schedule)) def step(self): self.iter = self.iter + 1 vals = self[self.iter] for i, val in enumerate(vals): setattr(self.klass, self.key, val) def __getitem__(self, it): it = min(it, self.total_iters) return [scheduler[it] for scheduler in self.schedulers] class CosineScheduler(object): def __init__( self, optimizer, warmup_iters, total_iters, key, overwrite=False, init_value=None, base_value=None, final_value=None, step_init=-1, ): super().__init__() self.iter = step_init self.overwrite = overwrite self.optimizer = optimizer self.base_value = base_value self.init_value = init_value self.final_value = final_value self.total_iters = total_iters self.warmup_iters = warmup_iters self.key = key self.schedulers = [ self.get_schedulers(group) for group in optimizer.param_groups ] def get_schedulers(self, group): init_value = group.get(self.key + "_init", self.init_value) base_value = group.get(self.key + "_base", self.base_value) final_value = group.get(self.key + "_final", self.final_value) warmup_iters = self.warmup_iters total_iters = self.total_iters if self.overwrite: final_value = self.final_value # normalize in 0,1, then apply function (power) and denormalize normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True) normalized_schedule = np.power(normalized_schedule, 1) warmup_schedule = (base_value - init_value) * normalized_schedule + init_value # main scheduling iters = np.arange(total_iters - warmup_iters + 1) schedule = final_value + 0.5 * (base_value - final_value) * ( 1 + np.cos(np.pi * iters / (len(iters) - 1)) ) return np.concatenate((warmup_schedule, schedule)) def step(self): self.iter = self.iter + 1 vals = self[self.iter] for group, val in zip(self.optimizer.param_groups, vals): if isinstance(group[self.key], (tuple, list)): val = (val, *group[self.key][1:]) group[self.key] = val def __getitem__(self, it): it = min(it, self.total_iters) return [scheduler[it] for scheduler in self.schedulers] def get(self): return [group[self.key] for group in self.optimizer.param_groups] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/__init__.py ================================================ from .camera import invert_pinhole # from .validation import validate from 
.coordinate import coords_grid, normalize_coords from .distributed import (barrier, get_dist_info, get_rank, is_main_process, setup_multi_processes, setup_slurm, sync_tensor_across_gpus) from .evaluation_depth import (DICT_METRICS, DICT_METRICS_3D, eval_3d, eval_depth) from .geometric import spherical_zbuffer_to_euclidean, unproject_points from .misc import (format_seconds, get_params, identity, recursive_index, remove_padding, to_cpu) from .visualization import colorize, image_grid, log_train_artifacts ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/camera.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ from copy import deepcopy import numpy as np import torch import torch.nn.functional as F from .coordinate import coords_grid from .misc import recursive_to, squeeze_list def invert_pinhole(K): fx = K[..., 0, 0] fy = K[..., 1, 1] cx = K[..., 0, 2] cy = K[..., 1, 2] K_inv = torch.zeros_like(K) K_inv[..., 0, 0] = 1.0 / fx K_inv[..., 1, 1] = 1.0 / fy K_inv[..., 0, 2] = -cx / fx K_inv[..., 1, 2] = -cy / fy K_inv[..., 2, 2] = 1.0 return K_inv class Camera: """ This is meant to be an abstract parent class, please use the others as actual cameras. Pinhole, FIsheye624, MEI, OPENCV, EUCM, Spherical (Equirectangular). """ def __init__(self, params=None, K=None): if params.ndim == 1: params = params.unsqueeze(0) if K is None: K = ( torch.eye(3, device=params.device, dtype=params.dtype) .unsqueeze(0) .repeat(params.shape[0], 1, 1) ) K[..., 0, 0] = params[..., 0] K[..., 1, 1] = params[..., 1] K[..., 0, 2] = params[..., 2] K[..., 1, 2] = params[..., 3] self.params = params self.K = K self.overlap_mask = None self.projection_mask = None def project(self, xyz): raise NotImplementedError def unproject(self, uv): raise NotImplementedError def get_projection_mask(self): return self.projection_mask def get_overlap_mask(self): return self.overlap_mask def reconstruct(self, depth): id_coords = coords_grid( 1, depth.shape[-2], depth.shape[-1], device=depth.device ) rays = self.unproject(id_coords) return ( rays / rays[:, -1:].clamp(min=1e-4) * depth.clamp(min=1e-4) ) # assumption z>0!!! 
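    # Illustrative usage sketch (assumes a (1, 3, 3) intrinsics tensor `K` and a
    # (1, 1, H, W) depth map `depth`): with the Pinhole subclass defined later in
    # this file,
    #
    #   cam = Pinhole(K=K)
    #   xyz = cam.reconstruct(depth)  # (1, 3, H, W) point map in the camera frame
    #
    # every pixel is back-projected as depth * K^-1 @ (u, v, 1)^T: unproject()
    # normalizes the ray so that z == 1 and reconstruct() scales it by the depth.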
def resize(self, factor): self.K[..., :2, :] *= factor self.params[..., :4] *= factor return self def to(self, device, non_blocking=False): self.params = self.params.to(device, non_blocking=non_blocking) self.K = self.K.to(device, non_blocking=non_blocking) return self def get_rays(self, shapes, noisy=False): b, h, w = shapes uv = coords_grid(1, h, w, device=self.K.device, noisy=noisy) rays = self.unproject(uv) return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4) def get_pinhole_rays(self, shapes, noisy=False): b, h, w = shapes uv = coords_grid(b, h, w, device=self.K.device, homogeneous=True, noisy=noisy) rays = (invert_pinhole(self.K) @ uv.reshape(b, 3, -1)).reshape(b, 3, h, w) return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4) def flip(self, H, W, direction="horizontal"): new_cx = ( W - self.params[:, 2] if direction == "horizontal" else self.params[:, 2] ) new_cy = H - self.params[:, 3] if direction == "vertical" else self.params[:, 3] self.params = torch.stack( [self.params[:, 0], self.params[:, 1], new_cx, new_cy], dim=1 ) self.K[..., 0, 2] = new_cx self.K[..., 1, 2] = new_cy return self def clone(self): return deepcopy(self) def crop(self, left, top, right=None, bottom=None): self.K[..., 0, 2] -= left self.K[..., 1, 2] -= top self.params[..., 2] -= left self.params[..., 3] -= top return self # helper function to get how fov changes based on new original size and new size def get_new_fov(self, new_shape, original_shape): new_hfov = 2 * torch.atan( self.params[..., 2] / self.params[..., 0] * new_shape[1] / original_shape[1] ) new_vfov = 2 * torch.atan( self.params[..., 3] / self.params[..., 1] * new_shape[0] / original_shape[0] ) return new_hfov, new_vfov def mask_overlap_projection(self, projected): B, _, H, W = projected.shape id_coords = coords_grid(B, H, W, device=projected.device) # check for mask where flow would overlap with other part of the image # eleemtns coming from the border are then masked out flow = projected - id_coords gamma = 0.1 sample_grid = gamma * flow + id_coords # sample along the flow sample_grid[:, 0] = sample_grid[:, 0] / (W - 1) * 2 - 1 sample_grid[:, 1] = sample_grid[:, 1] / (H - 1) * 2 - 1 sampled_flow = F.grid_sample( flow, sample_grid.permute(0, 2, 3, 1), mode="bilinear", align_corners=False, padding_mode="border", ) mask = ( (1 - gamma) * torch.norm(flow, dim=1, keepdim=True) < torch.norm(sampled_flow, dim=1, keepdim=True) ) | (torch.norm(flow, dim=1, keepdim=True) < 1) return mask def _pad_params(self): # Ensure params are padded to length 16 if self.params.shape[1] < 16: padding = torch.zeros( 16 - self.params.shape[1], device=self.params.device, dtype=self.params.dtype, ) padding = padding.unsqueeze(0).repeat(self.params.shape[0], 1) return torch.cat([self.params, padding], dim=1) return self.params @staticmethod def flatten_cameras(cameras): # -> list[Camera]: # Recursively flatten BatchCamera into primitive cameras flattened_cameras = [] for camera in cameras: if isinstance(camera, BatchCamera): flattened_cameras.extend(BatchCamera.flatten_cameras(camera.cameras)) elif isinstance(camera, list): flattened_cameras.extend(camera) else: flattened_cameras.append(camera) return flattened_cameras @staticmethod def _stack_or_cat_cameras(cameras, func, **kwargs): # Generalized method to handle stacking or concatenation flat_cameras = BatchCamera.flatten_cameras(cameras) K_matrices = [camera.K for camera in flat_cameras] padded_params = [camera._pad_params() for camera in flat_cameras] stacked_K = func(K_matrices, **kwargs) 
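        # Because Camera implements __torch_function__, torch.stack / torch.cat
        # on a list of cameras is routed to this helper; _pad_params() pads each
        # params vector to length 16 so models with different numbers of
        # distortion parameters can share a single stacked tensor.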
stacked_params = func(padded_params, **kwargs) # Keep track of the original classes original_class = [x.__class__.__name__ for x in flat_cameras] return BatchCamera(stacked_params, stacked_K, original_class, flat_cameras) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if func is torch.cat: return Camera._stack_or_cat_cameras(args[0], func, **kwargs) if func is torch.stack: return Camera._stack_or_cat_cameras(args[0], func, **kwargs) if func is torch.flatten: return Camera._stack_or_cat_cameras(args[0], torch.cat, **kwargs) return super().__torch_function__(func, types, args, kwargs) @property def device(self): return self.K.device # here we assume that cx,cy are more or less H/2 and W/2 @property def hfov(self): return 2 * torch.atan(self.params[..., 2] / self.params[..., 0]) @property def vfov(self): return 2 * torch.atan(self.params[..., 3] / self.params[..., 1]) @property def max_fov(self): return 150.0 / 180.0 * np.pi, 150.0 / 180.0 * np.pi class Pinhole(Camera): def __init__(self, params=None, K=None): assert params is not None or K is not None if params is None: params = torch.stack( [K[..., 0, 0], K[..., 1, 1], K[..., 0, 2], K[..., 1, 2]], dim=-1 ) super().__init__(params=params, K=K) @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def project(self, pcd): b, _, h, w = pcd.shape pcd_flat = pcd.reshape(b, 3, -1) # [B, 3, H*W] cam_coords = self.K @ pcd_flat pcd_proj = cam_coords[:, :2] / cam_coords[:, -1:].clamp(min=0.01) pcd_proj = pcd_proj.reshape(b, 2, h, w) invalid = ( (pcd_proj[:, 0] >= 0) & (pcd_proj[:, 0] < w) & (pcd_proj[:, 1] >= 0) & (pcd_proj[:, 1] < h) ) self.projection_mask = (~invalid).unsqueeze(1) return pcd_proj @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def unproject(self, uv): b, _, h, w = uv.shape uv_flat = uv.reshape(b, 2, -1) # [B, 2, H*W] uv_homogeneous = torch.cat( [uv_flat, torch.ones(b, 1, h * w, device=uv.device)], dim=1 ) # [B, 3, H*W] K_inv = torch.inverse(self.K.float()) xyz = K_inv @ uv_homogeneous xyz = xyz / xyz[:, -1:].clip(min=1e-4) xyz = xyz.reshape(b, 3, h, w) self.unprojection_mask = xyz[:, -1:] > 1e-4 return xyz @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def reconstruct(self, depth): b, _, h, w = depth.shape uv = coords_grid(b, h, w, device=depth.device) xyz = self.unproject(uv) * depth.clip(min=0.0) return xyz class EUCM(Camera): def __init__(self, params): super().__init__(params=params, K=None) @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def project(self, xyz): H, W = xyz.shape[-2:] fx, fy, cx, cy, alpha, beta = self.params[:6].unbind(dim=1) x, y, z = xyz.unbind(dim=1) d = torch.sqrt(beta * (x**2 + y**2) + z**2) x = x / (alpha * d + (1 - alpha) * z).clip(min=1e-3) y = y / (alpha * d + (1 - alpha) * z).clip(min=1e-3) Xnorm = fx * x + cx Ynorm = fy * y + cy coords = torch.stack([Xnorm, Ynorm], dim=1) invalid = ( (coords[:, 0] < 0) | (coords[:, 0] > W) | (coords[:, 1] < 0) | (coords[:, 1] > H) | (z < 0) ) self.projection_mask = (~invalid).unsqueeze(1) return coords @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def unproject(self, uv): u, v = uv.unbind(dim=1) fx, fy, cx, cy, alpha, beta = self.params.unbind(dim=1) mx = (u - cx) / fx my = (v - cy) / fy r_square = mx**2 + my**2 valid_mask = r_square < torch.where( alpha < 0.5, 1e6, 1 / (beta * (2 * alpha - 1)) ) sqrt_val = 1 - (2 * alpha - 1) * beta * r_square mz = (1 - beta * (alpha**2) * r_square) / ( alpha 
* torch.sqrt(sqrt_val.clip(min=1e-5)) + (1 - alpha) ) coeff = 1 / torch.sqrt(mx**2 + my**2 + mz**2 + 1e-5) x = coeff * mx y = coeff * my z = coeff * mz self.unprojection_mask = valid_mask & (z > 1e-3) xnorm = torch.stack((x, y, z.clamp(1e-3)), dim=1) return xnorm class Spherical(Camera): def __init__(self, params): # Hfov and Vofv are in radians and halved! super().__init__(params=params, K=None) def resize(self, factor): self.K[..., :2, :] *= factor self.params[..., :6] *= factor return self def crop(self, left, top, right, bottom): self.K[..., 0, 2] -= left self.K[..., 1, 2] -= top self.params[..., 2] -= left self.params[..., 3] -= top W, H = self.params[..., 4], self.params[..., 5] angle_ratio_W = (W - left - right) / W angle_ratio_H = (H - top - bottom) / H self.params[..., 4] -= left + right self.params[..., 5] -= top + bottom # rescale hfov and vfov self.params[..., 6] *= angle_ratio_W self.params[..., 7] *= angle_ratio_H return self @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def project(self, xyz): width, height = self.params[..., 4], self.params[..., 5] hfov, vfov = 2 * self.params[..., 6], 2 * self.params[..., 7] longitude = torch.atan2(xyz[:, 0], xyz[:, 2]) latitude = torch.asin(xyz[:, 1] / torch.norm(xyz, dim=1).clamp(min=1e-5)) u = longitude / hfov * (width - 1) + (width - 1) / 2 v = latitude / vfov * (height - 1) + (height - 1) / 2 return torch.stack([u, v], dim=1) @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def unproject(self, uv): u, v = uv.unbind(dim=1) width, height = self.params[..., 4], self.params[..., 5] hfov, vfov = 2 * self.params[..., 6], 2 * self.params[..., 7] longitude = (u - (width - 1) / 2) / (width - 1) * hfov latitude = (v - (height - 1) / 2) / (height - 1) * vfov x = torch.cos(latitude) * torch.sin(longitude) z = torch.cos(latitude) * torch.cos(longitude) y = torch.sin(latitude) unit_sphere = torch.stack([x, y, z], dim=1) unit_sphere = unit_sphere / torch.norm(unit_sphere, dim=1, keepdim=True).clip( min=1e-5 ) return unit_sphere def reconstruct(self, depth): id_coords = coords_grid( 1, depth.shape[-2], depth.shape[-1], device=depth.device ) return self.unproject(id_coords) * depth def get_new_fov(self, new_shape, original_shape): new_hfov = 2 * self.params[..., 6] * new_shape[1] / original_shape[1] new_vfov = 2 * self.params[..., 7] * new_shape[0] / original_shape[0] return new_hfov, new_vfov @property def hfov(self): return 2 * self.params[..., 6] @property def vfov(self): return 2 * self.params[..., 7] @property def max_fov(self): return 2 * np.pi, 0.9 * np.pi # avoid strong distortion on tops class OPENCV(Camera): def __init__(self, params): super().__init__(params=params, K=None) self.use_radial = self.params[..., 4:10].abs().sum() > 1e-6 assert ( self.params[..., 7:10].abs().sum() == 0.0 ), "Do not support poly division model" self.use_tangential = self.params[..., 10:12].abs().sum() > 1e-6 self.use_thin_prism = self.params[..., 12:].abs().sum() > 1e-6 @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def project(self, xyz): eps = 1e-9 B, _, H, W = xyz.shape N = H * W xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3) # Radial correction. z = xyz[:, :, 2].reshape(B, N, 1) z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z) ab = xyz[:, :, :2] / z r = torch.norm(ab, dim=-1, p=2, keepdim=True) th = r # Create powers of th (th^3, th^5, ...) 
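        # Note: in this OPENCV rational-polynomial model th == r, so the powers
        # built below are even (th^2, th^4, th^6); the odd-power wording in the
        # comment above appears to be carried over from the fisheye variant.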
th_pow = torch.cat([torch.pow(th, 2 + i * 2) for i in range(3)], dim=-1) distortion_coeffs_num = self.params[:, 4:7].reshape(B, 1, 3) distortion_coeffs_den = self.params[:, 7:10].reshape(B, 1, 3) th_num = 1 + torch.sum(th_pow * distortion_coeffs_num, dim=-1, keepdim=True) th_den = 1 + torch.sum(th_pow * distortion_coeffs_den, dim=-1, keepdim=True) xr_yr = ab * th_num / th_den uv_dist = xr_yr # Tangential correction. p0 = self.params[..., -6].reshape(B, 1) p1 = self.params[..., -5].reshape(B, 1) xr = xr_yr[:, :, 0].reshape(B, N) yr = xr_yr[:, :, 1].reshape(B, N) xr_yr_sq = torch.square(xr_yr) xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) rd_sq = xr_sq + yr_sq uv_dist_tu = uv_dist[:, :, 0] + ( (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 ) uv_dist_tv = uv_dist[:, :, 1] + ( (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 ) uv_dist = torch.stack( [uv_dist_tu, uv_dist_tv], dim=-1 ) # Avoids in-place complaint. # Thin Prism correction. s0 = self.params[..., -4].reshape(B, 1) s1 = self.params[..., -3].reshape(B, 1) s2 = self.params[..., -2].reshape(B, 1) s3 = self.params[..., -1].reshape(B, 1) rd_4 = torch.square(rd_sq) uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4) uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4) # Finally, apply standard terms: focal length and camera centers. if self.params.shape[-1] == 15: fx_fy = self.params[..., 0].reshape(B, 1, 1) cx_cy = self.params[..., 1:3].reshape(B, 1, 2) else: fx_fy = self.params[..., 0:2].reshape(B, 1, 2) cx_cy = self.params[..., 2:4].reshape(B, 1, 2) result = uv_dist * fx_fy + cx_cy result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2) invalid = ( (result[:, 0] < 0) | (result[:, 0] > W) | (result[:, 1] < 0) | (result[:, 1] > H) ) self.projection_mask = (~invalid).unsqueeze(1) self.overlap_mask = self.mask_overlap_projection(result) return result @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def unproject(self, uv, max_iters: int = 10): eps = 1e-3 B, _, H, W = uv.shape N = H * W uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2) if self.params.shape[-1] == 15: fx_fy = self.params[..., 0].reshape(B, 1, 1) cx_cy = self.params[..., 1:3].reshape(B, 1, 2) else: fx_fy = self.params[..., 0:2].reshape(B, 1, 2) cx_cy = self.params[..., 2:4].reshape(B, 1, 2) uv_dist = (uv - cx_cy) / fx_fy # Compute xr_yr using Newton's method. xr_yr = uv_dist.clone() # Initial guess. max_iters_tanprism = ( max_iters if self.use_thin_prism or self.use_tangential else 0 ) for _ in range(max_iters_tanprism): uv_dist_est = xr_yr.clone() xr = xr_yr[..., 0].reshape(B, N) yr = xr_yr[..., 1].reshape(B, N) xr_yr_sq = torch.square(xr_yr) xr_sq = xr_yr_sq[..., 0].reshape(B, N) yr_sq = xr_yr_sq[..., 1].reshape(B, N) rd_sq = xr_sq + yr_sq if self.use_tangential: # Tangential terms. p0 = self.params[..., -6].reshape(B, 1) p1 = self.params[..., -5].reshape(B, 1) uv_dist_est[..., 0] = uv_dist_est[..., 0] + ( (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 ) uv_dist_est[..., 1] = uv_dist_est[..., 1] + ( (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 ) if self.use_thin_prism: # Thin Prism terms. s0 = self.params[..., -4].reshape(B, 1) s1 = self.params[..., -3].reshape(B, 1) s2 = self.params[..., -2].reshape(B, 1) s3 = self.params[..., -1].reshape(B, 1) rd_4 = torch.square(rd_sq) uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4) uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4) # Compute the derivative of uv_dist w.r.t. xr_yr. 
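            # The 2x2 Jacobian of the tangential + thin-prism distortion is
            # inverted in closed form (adjugate / determinant) below, and the
            # Newton update is xr_yr += J^-1 @ (uv_dist - uv_dist_est).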
duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2) if self.use_tangential: duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1 offdiag = 2.0 * (xr * p1 + yr * p0) duv_dist_dxr_yr[..., 0, 1] = offdiag duv_dist_dxr_yr[..., 1, 0] = offdiag duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0 if self.use_thin_prism: xr_yr_sq_norm = xr_sq + yr_sq temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm) duv_dist_dxr_yr[..., 0, 0] = duv_dist_dxr_yr[..., 0, 0] + (xr * temp1) duv_dist_dxr_yr[..., 0, 1] = duv_dist_dxr_yr[..., 0, 1] + (yr * temp1) temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm) duv_dist_dxr_yr[..., 1, 0] = duv_dist_dxr_yr[..., 1, 0] + (xr * temp2) duv_dist_dxr_yr[..., 1, 1] = duv_dist_dxr_yr[..., 1, 1] + (yr * temp2) mat = duv_dist_dxr_yr.reshape(-1, 2, 2) a = mat[:, 0, 0].reshape(-1, 1, 1) b = mat[:, 0, 1].reshape(-1, 1, 1) c = mat[:, 1, 0].reshape(-1, 1, 1) d = mat[:, 1, 1].reshape(-1, 1, 1) det = 1.0 / ((a * d) - (b * c)) top = torch.cat([d, -b], dim=-1) bot = torch.cat([-c, a], dim=-1) inv = det * torch.cat([top, bot], dim=-2) inv = inv.reshape(B, N, 2, 2) diff = uv_dist - uv_dist_est a = inv[..., 0, 0] b = inv[..., 0, 1] c = inv[..., 1, 0] d = inv[..., 1, 1] e = diff[..., 0] f = diff[..., 1] step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) # Newton step. xr_yr = xr_yr + step # Compute theta using Newton's method. xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1) th = xr_yr_norm.clone() max_iters_radial = max_iters if self.use_radial else 0 c = ( torch.tensor([2.0 * i + 3 for i in range(3)], device=self.device) .reshape(1, 1, 3) .repeat(B, 1, 1) ) radial_params_num = self.params[..., 4:7].reshape(B, 1, 3) # Trust region parameters delta = torch.full((B, N, 1), 0.1, device=self.device) # Initial trust radius delta_max = torch.tensor(1.0, device=self.device) # Maximum trust radius eta = 0.1 # Acceptable reduction threshold for i in range(max_iters_radial): th_sq = th * th # th^2 # Compute powers of th^2 up to th^(12) theta_powers = torch.cat( [th_sq ** (i + 1) for i in range(3)], dim=-1 ) # Shape: (B, N, 6) # Compute th_radial: radial distortion model applied to th th_radial = 1.0 + torch.sum( theta_powers * radial_params_num, dim=-1, keepdim=True ) th_radial = th_radial * th # Multiply by th at the end # Compute derivative dthd_th dthd_th = 1.0 + torch.sum( c * radial_params_num * theta_powers, dim=-1, keepdim=True ) dthd_th = dthd_th # Already includes derivative terms # Compute residual residual = th_radial - xr_yr_norm # Shape: (B, N, 1) residual_norm = torch.norm(residual, dim=2, keepdim=True) # For each pixel # Check for convergence if torch.max(torch.abs(residual)) < eps: break # Avoid division by zero by adding a small epsilon safe_dthd_th = dthd_th.clone() zero_derivative_mask = dthd_th.abs() < eps safe_dthd_th[zero_derivative_mask] = eps # Compute Newton's step step = -residual / safe_dthd_th # Compute predicted reduction predicted_reduction = -(residual * step).sum(dim=2, keepdim=True) # Adjust step based on trust region step_norm = torch.norm(step, dim=2, keepdim=True) over_trust_mask = step_norm > delta # Scale step if it exceeds trust radius step_scaled = step.clone() step_scaled[over_trust_mask] = step[over_trust_mask] * ( delta[over_trust_mask] / step_norm[over_trust_mask] ) # Update theta th_new = th + step_scaled # Compute new residual th_sq_new = th_new * th_new theta_powers_new = torch.cat( [th_sq_new ** (j + 1) for j in range(3)], dim=-1 ) th_radial_new = 1.0 + torch.sum( theta_powers_new * radial_params_num, dim=-1, keepdim=True ) th_radial_new 
= th_radial_new * th_new residual_new = th_radial_new - xr_yr_norm residual_new_norm = torch.norm(residual_new, dim=2, keepdim=True) # Compute actual reduction actual_reduction = residual_norm - residual_new_norm # Compute ratio of actual to predicted reduction # predicted_reduction[predicted_reduction.abs() < eps] = eps #* torch.sign(predicted_reduction[predicted_reduction.abs() < eps]) rho = actual_reduction / predicted_reduction rho[(actual_reduction == 0) & (predicted_reduction == 0)] = 1.0 # Update trust radius delta delta_update_mask = rho > 0.5 delta[delta_update_mask] = torch.min( 2.0 * delta[delta_update_mask], delta_max ) delta_decrease_mask = rho < 0.2 delta[delta_decrease_mask] = 0.25 * delta[delta_decrease_mask] # Accept or reject the step accept_step_mask = rho > eta th = torch.where(accept_step_mask, th_new, th) # Compute the ray direction using theta and xr_yr. close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps) ray_dir = torch.where(close_to_zero, xr_yr, th / xr_yr_norm * xr_yr) ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2) ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2) return ray class Fisheye624(Camera): def __init__(self, params): super().__init__(params=params, K=None) self.use_radial = self.params[..., 4:10].abs().sum() > 1e-6 self.use_tangential = self.params[..., 10:12].abs().sum() > 1e-6 self.use_thin_prism = self.params[..., 12:].abs().sum() > 1e-6 @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def project(self, xyz): eps = 1e-9 B, _, H, W = xyz.shape N = H * W xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3) # Radial correction. z = xyz[:, :, 2].reshape(B, N, 1) z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z) ab = xyz[:, :, :2] / z r = torch.norm(ab, dim=-1, p=2, keepdim=True) th = torch.atan(r) th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r) th_pow = torch.cat( [torch.pow(th, 3 + i * 2) for i in range(6)], dim=-1 ) # Create powers of th (th^3, th^5, ...) distortion_coeffs = self.params[:, 4:10].reshape(B, 1, 6) th_k = th + torch.sum(th_pow * distortion_coeffs, dim=-1, keepdim=True) xr_yr = th_k * th_divr uv_dist = xr_yr # Tangential correction. p0 = self.params[..., -6].reshape(B, 1) p1 = self.params[..., -5].reshape(B, 1) xr = xr_yr[:, :, 0].reshape(B, N) yr = xr_yr[:, :, 1].reshape(B, N) xr_yr_sq = torch.square(xr_yr) xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) rd_sq = xr_sq + yr_sq uv_dist_tu = uv_dist[:, :, 0] + ( (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 ) uv_dist_tv = uv_dist[:, :, 1] + ( (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 ) uv_dist = torch.stack( [uv_dist_tu, uv_dist_tv], dim=-1 ) # Avoids in-place complaint. # Thin Prism correction. s0 = self.params[..., -4].reshape(B, 1) s1 = self.params[..., -3].reshape(B, 1) s2 = self.params[..., -2].reshape(B, 1) s3 = self.params[..., -1].reshape(B, 1) rd_4 = torch.square(rd_sq) uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4) uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4) # Finally, apply standard terms: focal length and camera centers. 
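        # A 15-element params vector stores one shared focal length (fx == fy)
        # followed by cx, cy; a 16-element vector stores fx and fy separately,
        # which is why the two reshape branches below differ.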
if self.params.shape[-1] == 15: fx_fy = self.params[..., 0].reshape(B, 1, 1) cx_cy = self.params[..., 1:3].reshape(B, 1, 2) else: fx_fy = self.params[..., 0:2].reshape(B, 1, 2) cx_cy = self.params[..., 2:4].reshape(B, 1, 2) result = uv_dist * fx_fy + cx_cy result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2) invalid = ( (result[:, 0] < 0) | (result[:, 0] > W) | (result[:, 1] < 0) | (result[:, 1] > H) ) self.projection_mask = (~invalid).unsqueeze(1) self.overlap_mask = self.mask_overlap_projection(result) return result @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def unproject(self, uv, max_iters: int = 10): eps = 1e-3 B, _, H, W = uv.shape N = H * W uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2) if self.params.shape[-1] == 15: fx_fy = self.params[..., 0].reshape(B, 1, 1) cx_cy = self.params[..., 1:3].reshape(B, 1, 2) else: fx_fy = self.params[..., 0:2].reshape(B, 1, 2) cx_cy = self.params[..., 2:4].reshape(B, 1, 2) uv_dist = (uv - cx_cy) / fx_fy # Compute xr_yr using Trust-region method. xr_yr = uv_dist.clone() max_iters_tanprism = ( max_iters if self.use_thin_prism or self.use_tangential else 0 ) for _ in range(max_iters_tanprism): uv_dist_est = xr_yr.clone() xr = xr_yr[..., 0].reshape(B, N) yr = xr_yr[..., 1].reshape(B, N) xr_yr_sq = torch.square(xr_yr) xr_sq = xr_yr_sq[..., 0].reshape(B, N) yr_sq = xr_yr_sq[..., 1].reshape(B, N) rd_sq = xr_sq + yr_sq if self.use_tangential: # Tangential terms. p0 = self.params[..., -6].reshape(B, 1) p1 = self.params[..., -5].reshape(B, 1) uv_dist_est[..., 0] = uv_dist_est[..., 0] + ( (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 ) uv_dist_est[..., 1] = uv_dist_est[..., 1] + ( (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 ) if self.use_thin_prism: # Thin Prism terms. s0 = self.params[..., -4].reshape(B, 1) s1 = self.params[..., -3].reshape(B, 1) s2 = self.params[..., -2].reshape(B, 1) s3 = self.params[..., -1].reshape(B, 1) rd_4 = torch.square(rd_sq) uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4) uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4) # Compute the derivative of uv_dist w.r.t. xr_yr. duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2) if self.use_tangential: duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1 offdiag = 2.0 * (xr * p1 + yr * p0) duv_dist_dxr_yr[..., 0, 1] = offdiag duv_dist_dxr_yr[..., 1, 0] = offdiag duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0 if self.use_thin_prism: xr_yr_sq_norm = xr_sq + yr_sq temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm) duv_dist_dxr_yr[..., 0, 0] = duv_dist_dxr_yr[..., 0, 0] + (xr * temp1) duv_dist_dxr_yr[..., 0, 1] = duv_dist_dxr_yr[..., 0, 1] + (yr * temp1) temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm) duv_dist_dxr_yr[..., 1, 0] = duv_dist_dxr_yr[..., 1, 0] + (xr * temp2) duv_dist_dxr_yr[..., 1, 1] = duv_dist_dxr_yr[..., 1, 1] + (yr * temp2) mat = duv_dist_dxr_yr.reshape(-1, 2, 2) a = mat[:, 0, 0].reshape(-1, 1, 1) b = mat[:, 0, 1].reshape(-1, 1, 1) c = mat[:, 1, 0].reshape(-1, 1, 1) d = mat[:, 1, 1].reshape(-1, 1, 1) det = 1.0 / ((a * d) - (b * c)) top = torch.cat([d, -b], dim=-1) bot = torch.cat([-c, a], dim=-1) inv = det * torch.cat([top, bot], dim=-2) inv = inv.reshape(B, N, 2, 2) diff = uv_dist - uv_dist_est a = inv[..., 0, 0] b = inv[..., 0, 1] c = inv[..., 1, 0] d = inv[..., 1, 1] e = diff[..., 0] f = diff[..., 1] step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) # Newton step. xr_yr = xr_yr + step # Compute theta using Newton's method. 
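        # The per-pixel scalar equation solved here is
        #   th + k1*th^3 + k2*th^5 + ... + k6*th^13 == ||xr_yr||,
        # and the Newton iteration is damped by a simple trust region: steps
        # longer than `delta` are scaled down, and `delta` grows or shrinks with
        # the ratio `rho` of actual to predicted reduction.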
xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1) th = xr_yr_norm.clone() max_iters_radial = max_iters if self.use_radial else 0 c = ( torch.tensor([2.0 * i + 3 for i in range(6)], device=self.device) .reshape(1, 1, 6) .repeat(B, 1, 1) ) radial_params = self.params[..., 4:10].reshape(B, 1, 6) # Trust region parameters delta = torch.full((B, N, 1), 0.1, device=self.device) # Initial trust radius delta_max = torch.tensor(1.0, device=self.device) # Maximum trust radius eta = 0.1 # Acceptable reduction threshold for i in range(max_iters_radial): th_sq = th * th # Compute powers of th^2 up to th^(12) theta_powers = torch.cat( [th_sq ** (i + 1) for i in range(6)], dim=-1 ) # Shape: (B, N, 6) # Compute th_radial: radial distortion model applied to th th_radial = 1.0 + torch.sum( theta_powers * radial_params, dim=-1, keepdim=True ) th_radial = th_radial * th # Compute derivative dthd_th dthd_th = 1.0 + torch.sum( c * radial_params * theta_powers, dim=-1, keepdim=True ) # Compute residual residual = th_radial - xr_yr_norm # Shape: (B, N, 1) residual_norm = torch.norm(residual, dim=2, keepdim=True) # Check for convergence if torch.max(torch.abs(residual)) < eps: break # Avoid division by zero by adding a small epsilon safe_dthd_th = dthd_th.clone() zero_derivative_mask = dthd_th.abs() < eps safe_dthd_th[zero_derivative_mask] = eps # Compute Newton's step step = -residual / safe_dthd_th # Compute predicted reduction predicted_reduction = -(residual * step).sum(dim=2, keepdim=True) # Adjust step based on trust region step_norm = torch.norm(step, dim=2, keepdim=True) over_trust_mask = step_norm > delta # Scale step if it exceeds trust radius step_scaled = step.clone() step_scaled[over_trust_mask] = step[over_trust_mask] * ( delta[over_trust_mask] / step_norm[over_trust_mask] ) # Update theta th_new = th + step_scaled # Compute new residual th_sq_new = th_new * th_new theta_powers_new = torch.cat( [th_sq_new ** (j + 1) for j in range(6)], dim=-1 ) th_radial_new = 1.0 + torch.sum( theta_powers_new * radial_params, dim=-1, keepdim=True ) th_radial_new = th_radial_new * th_new residual_new = th_radial_new - xr_yr_norm residual_new_norm = torch.norm(residual_new, dim=2, keepdim=True) # Compute actual reduction actual_reduction = residual_norm - residual_new_norm # Compute ratio of actual to predicted reduction rho = actual_reduction / predicted_reduction rho[(actual_reduction == 0) & (predicted_reduction == 0)] = 1.0 # Update trust radius delta delta_update_mask = rho > 0.5 delta[delta_update_mask] = torch.min( 2.0 * delta[delta_update_mask], delta_max ) delta_decrease_mask = rho < 0.2 delta[delta_decrease_mask] = 0.25 * delta[delta_decrease_mask] # Accept or reject the step accept_step_mask = rho > eta th = torch.where(accept_step_mask, th_new, th) # Compute the ray direction using theta and xr_yr. 
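        # th is the angle between the ray and the optical axis, so the undistorted
        # normalized coordinate is tan(th) along the unit direction xr_yr/||xr_yr||.
        # When both th and ||xr_yr|| are (near) zero the point lies on the optical
        # axis and xr_yr is used as-is to avoid dividing by zero.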
close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps) ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr) ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2) ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2) return ray class MEI(Camera): def __init__(self, params): super().__init__(params=params, K=None) # fx fy cx cy k1 k2 p1 p2 xi self.use_radial = self.params[..., 4:6].abs().sum() > 1e-6 self.use_tangential = self.params[..., 6:8].abs().sum() > 1e-6 @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def unproject(self, uv, max_iters: int = 20): eps = 1e-6 B, _, H, W = uv.shape N = H * W uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2) k1, k2, p0, p1, xi = self.params[..., 4:9].unbind(dim=1) fx_fy = self.params[..., 0:2].reshape(B, 1, 2) cx_cy = self.params[..., 2:4].reshape(B, 1, 2) uv_dist = (uv - cx_cy) / fx_fy # Compute xr_yr using Newton's method. xr_yr = uv_dist.clone() # Initial guess. max_iters_tangential = max_iters if self.use_tangential else 0 for _ in range(max_iters_tangential): uv_dist_est = xr_yr.clone() # Tangential terms. xr = xr_yr[..., 0] yr = xr_yr[..., 1] xr_yr_sq = xr_yr**2 xr_sq = xr_yr_sq[..., 0] yr_sq = xr_yr_sq[..., 1] rd_sq = xr_sq + yr_sq uv_dist_est[..., 0] = uv_dist_est[..., 0] + ( (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 ) uv_dist_est[..., 1] = uv_dist_est[..., 1] + ( (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 ) # Compute the derivative of uv_dist w.r.t. xr_yr. duv_dist_dxr_yr = torch.ones((B, N, 2, 2), device=uv.device) duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1 offdiag = 2.0 * (xr * p1 + yr * p0) duv_dist_dxr_yr[..., 0, 1] = offdiag duv_dist_dxr_yr[..., 1, 0] = offdiag duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0 mat = duv_dist_dxr_yr.reshape(-1, 2, 2) a = mat[:, 0, 0].reshape(-1, 1, 1) b = mat[:, 0, 1].reshape(-1, 1, 1) c = mat[:, 1, 0].reshape(-1, 1, 1) d = mat[:, 1, 1].reshape(-1, 1, 1) det = 1.0 / ((a * d) - (b * c)) top = torch.cat([d, -b], dim=-1) bot = torch.cat([-c, a], dim=-1) inv = det * torch.cat([top, bot], dim=-2) inv = inv.reshape(B, N, 2, 2) diff = uv_dist - uv_dist_est a = inv[..., 0, 0] b = inv[..., 0, 1] c = inv[..., 1, 0] d = inv[..., 1, 1] e = diff[..., 0] f = diff[..., 1] step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) # Newton step. xr_yr = xr_yr + step # Compute theta using Newton's method. xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1) th = xr_yr_norm.clone() max_iters_radial = max_iters if self.use_radial else 0 for _ in range(max_iters_radial): th_radial = 1.0 + k1 * torch.pow(th, 2) + k2 * torch.pow(th, 4) dthd_th = 1.0 + 3.0 * k1 * torch.pow(th, 2) + 5.0 * k2 * torch.pow(th, 4) th_radial = th_radial * th step = (xr_yr_norm - th_radial) / dthd_th # handle dthd_th close to 0. step = torch.where( torch.abs(dthd_th) > eps, step, torch.sign(step) * eps * 10.0 ) th = th + step # Compute the ray direction using theta and xr_yr. close_to_zero = (torch.abs(th) < eps) & (torch.abs(xr_yr_norm) < eps) ray_dir = torch.where(close_to_zero, xr_yr, th * xr_yr / xr_yr_norm) # Compute the 3D projective ray rho2_u = ( ray_dir.norm(p=2, dim=2, keepdim=True) ** 2 ) # B N 1 # x_c * x_c + y_c * y_c xi = xi.reshape(B, 1, 1) sqrt_term = torch.sqrt(1.0 + (1.0 - xi * xi) * rho2_u) P_z = 1.0 - xi * (rho2_u + 1.0) / (xi + sqrt_term) # Special case when xi is 1.0 (unit sphere projection ??) 
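        # With xi == 1.0 the square-root term above equals 1 and the expression
        # reduces to (1 - rho2_u) / 2, which the branch below computes directly.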
P_z = torch.where(xi == 1.0, (1.0 - rho2_u) / 2.0, P_z) ray = torch.cat([ray_dir, P_z], dim=-1) ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2) return ray @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def project(self, xyz): is_flat = xyz.ndim == 3 B, N = xyz.shape[:2] if not is_flat: B, _, H, W = xyz.shape N = H * W xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3) k1, k2, p0, p1, xi = self.params[..., 4:].unbind(dim=1) fx_fy = self.params[..., 0:2].reshape(B, 1, 2) cx_cy = self.params[..., 2:4].reshape(B, 1, 2) norm = xyz.norm(p=2, dim=-1, keepdim=True) ab = xyz[..., :-1] / (xyz[..., -1:] + xi.reshape(B, 1, 1) * norm) # radial correction r = ab.norm(dim=-1, p=2, keepdim=True) k1 = self.params[..., 4].reshape(B, 1, 1) k2 = self.params[..., 5].reshape(B, 1, 1) # ab / r * th * (1 + k1 * (th ** 2) + k2 * (th**4)) # but here r = th, no spherical distortion xr_yr = ab * (1 + k1 * (r**2) + k2 * (r**4)) # Tangential correction. uv_dist = xr_yr p0 = self.params[:, -3].reshape(B, 1) p1 = self.params[:, -2].reshape(B, 1) xr = xr_yr[..., 0].reshape(B, N) yr = xr_yr[..., 1].reshape(B, N) xr_yr_sq = torch.square(xr_yr) xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) rd_sq = xr_sq + yr_sq uv_dist_tu = uv_dist[:, :, 0] + ( (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 ) uv_dist_tv = uv_dist[:, :, 1] + ( (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 ) uv_dist = torch.stack( [uv_dist_tu, uv_dist_tv], dim=-1 ) # Avoids in-place complaint. result = uv_dist * fx_fy + cx_cy if not is_flat: result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2) invalid = ( (result[:, 0] < 0) | (result[:, 0] > W) | (result[:, 1] < 0) | (result[:, 1] > H) ) self.projection_mask = (~invalid).unsqueeze(1) # creates hole in the middle... ?? # self.overlap_mask = self.mask_overlap_projection(result) return result class BatchCamera(Camera): """ This is not to be used directly, but to be used as a wrapper around multiple cameras. It should expose only the `from_camera` method as it the only way to create a BatchCamera. 
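    Calls such as `project`, `unproject`, `reconstruct`, `crop` and `resize` are
    delegated to the wrapped per-sample cameras and their results are concatenated
    along the batch dimension.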
""" def __init__(self, params, K, original_class, cameras): super().__init__(params, K) self.original_class = original_class self.cameras = cameras # Delegate these methods to original camera @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def project(self, points_3d): return torch.cat( [ camera.project(points_3d[i : i + 1]) for i, camera in enumerate(self.cameras) ] ) @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def unproject(self, points_2d): val = torch.cat( [camera.unproject(points_2d) for i, camera in enumerate(self.cameras)] ) return val def crop(self, left, top, right=None, bottom=None): val = torch.cat( [ camera.crop(left, top, right, bottom) for i, camera in enumerate(self.cameras) ] ) return val def resize(self, ratio): val = torch.cat([camera.resize(ratio) for i, camera in enumerate(self.cameras)]) return val def reconstruct(self, depth): val = torch.cat( [ camera.reconstruct(depth[i : i + 1]) for i, camera in enumerate(self.cameras) ] ) return val def get_projection_mask(self): return torch.cat( [camera.projection_mask for i, camera in enumerate(self.cameras)] ) def to(self, device, non_blocking=False): self = super().to(device, non_blocking=non_blocking) self.cameras = recursive_to( self.cameras, device, non_blocking=non_blocking, cls=Camera ) return self def reshape(self, *shape): # Reshape the intrinsic matrix (K) and params # we know that the shape of K is (..., 3, 3) and params is (..., 16) reshaped_K = self.K.reshape(*shape, 3, 3) reshaped_params = self.params.reshape(*shape, self.params.shape[-1]) self.cameras = np.array(self.cameras, dtype=object).reshape(shape).tolist() self.original_class = ( np.array(self.original_class, dtype=object).reshape(shape).tolist() ) # Create a new BatchCamera with reshaped K and params return BatchCamera( reshaped_params, reshaped_K, self.original_class, self.cameras ) def get_new_fov(self, new_shape, original_shape): return [ camera.get_new_fov(new_shape, original_shape) for i, camera in enumerate(self.cameras) ] def squeeze(self, dim): return BatchCamera( self.params.squeeze(dim), self.K.squeeze(dim), squeeze_list(self.original_class, dim=dim), squeeze_list(self.cameras, dim=dim), ) def __getitem__(self, idx): if isinstance(idx, int): return self.cameras[idx] elif isinstance(idx, slice): return BatchCamera( self.params[idx], self.K[idx], self.original_class[idx], self.cameras[idx], ) raise TypeError(f"Invalid index type: {type(idx)}") def __setitem__(self, idx, value): # If it's an integer index, return a single camera if isinstance(idx, int): self.cameras[idx] = value self.params[idx, :] = 0.0 self.params[idx, : value.params.shape[1]] = value.params[0] self.K[idx] = value.K[0] self.original_class[idx] = getattr( value, "original_class", value.__class__.__name__ ) # If it's a slice, return a new BatchCamera with sliced cameras elif isinstance(idx, slice): # Update each internal attribute using the slice self.params[idx] = value.params self.K[idx] = value.K self.original_class[idx] = value.original_class self.cameras[idx] = value.cameras def __len__(self): return len(self.cameras) @classmethod def from_camera(cls, camera): return cls(camera.params, camera.K, [camera.__class__.__name__], [camera]) @property def is_perspective(self): return [isinstance(camera, Pinhole) for camera in self.cameras] @property def is_spherical(self): return [isinstance(camera, Spherical) for camera in self.cameras] @property def is_eucm(self): return [isinstance(camera, EUCM) for camera in self.cameras] 
@property def is_fisheye(self): return [isinstance(camera, Fisheye624) for camera in self.cameras] @property def is_pinhole(self): return [isinstance(camera, Pinhole) for camera in self.cameras] @property def hfov(self): return [camera.hfov for camera in self.cameras] @property def vfov(self): return [camera.vfov for camera in self.cameras] @property def max_fov(self): return [camera.max_fov for camera in self.cameras] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/chamfer_distance.py ================================================ import warnings from typing import Union import torch try: from unidepth.ops.knn import knn_points except ImportError as e: warnings.warn( "!! To run evaluation you need KNN. Please compile KNN: " "`cd unidepth/ops/knn with && bash compile.sh`." ) knn_points = lambda x : x def _validate_chamfer_reduction_inputs( batch_reduction: Union[str, None], point_reduction: str ): """Check the requested reductions are valid. Args: batch_reduction: Reduction operation to apply for the loss across the batch, can be one of ["mean", "sum"] or None. point_reduction: Reduction operation to apply for the loss across the points, can be one of ["mean", "sum"]. """ if batch_reduction is not None and batch_reduction not in ["mean", "sum"]: raise ValueError('batch_reduction must be one of ["mean", "sum"] or None') if point_reduction not in ["mean", "sum"]: raise ValueError('point_reduction must be one of ["mean", "sum"]') def _handle_pointcloud_input( points: torch.Tensor, lengths: Union[torch.Tensor, None], normals: Union[torch.Tensor, None], ): """ If points is an instance of Pointclouds, retrieve the padded points tensor along with the number of points per batch and the padded normals. Otherwise, return the input points (and normals) with the number of points per cloud set to the size of the second dimension of `points`. """ if points.ndim != 3: raise ValueError("Expected points to be of shape (N, P, D)") X = points if lengths is not None and (lengths.ndim != 1 or lengths.shape[0] != X.shape[0]): raise ValueError("Expected lengths to be of shape (N,)") if lengths is None: lengths = torch.full( (X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device ) if normals is not None and normals.ndim != 3: raise ValueError("Expected normals to be of shape (N, P, 3") return X, lengths, normals class ChamferDistance(torch.nn.Module): def forward( self, x, y, x_lengths=None, y_lengths=None, x_normals=None, y_normals=None, weights=None, batch_reduction: Union[str, None] = "mean", point_reduction: str = "mean", ): """ Chamfer distance between two pointclouds x and y. Args: x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing a batch of point clouds with at most P1 points in each batch element, batch size N and feature dimension D. y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing a batch of point clouds with at most P2 points in each batch element, batch size N and feature dimension D. x_lengths: Optional LongTensor of shape (N,) giving the number of points in each cloud in x. y_lengths: Optional LongTensor of shape (N,) giving the number of points in each cloud in x. x_normals: Optional FloatTensor of shape (N, P1, D). y_normals: Optional FloatTensor of shape (N, P2, D). weights: Optional FloatTensor of shape (N,) giving weights for batch elements for reduction operation. 
batch_reduction: Reduction operation to apply for the loss across the batch, can be one of ["mean", "sum"] or None. point_reduction: Reduction operation to apply for the loss across the points, can be one of ["mean", "sum"]. Returns: 2-element tuple containing - **loss**: Tensor giving the reduced distance between the pointclouds in x and the pointclouds in y. - **loss_normals**: Tensor giving the reduced cosine distance of normals between pointclouds in x and pointclouds in y. Returns None if x_normals and y_normals are None. """ _validate_chamfer_reduction_inputs(batch_reduction, point_reduction) x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals) y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals) return_normals = x_normals is not None and y_normals is not None N, P1, D = x.shape P2 = y.shape[1] # Check if inputs are heterogeneous and create a lengths mask. is_x_heterogeneous = (x_lengths != P1).any() is_y_heterogeneous = (y_lengths != P2).any() x_mask = ( torch.arange(P1, device=x.device)[None] >= x_lengths[:, None] ) # shape [N, P1] y_mask = ( torch.arange(P2, device=y.device)[None] >= y_lengths[:, None] ) # shape [N, P2] if y.shape[0] != N or y.shape[2] != D: raise ValueError("y does not have the correct shape.") if weights is not None: if weights.size(0) != N: raise ValueError("weights must be of shape (N,).") if not (weights >= 0).all(): raise ValueError("weights cannot be negative.") if weights.sum() == 0.0: weights = weights.view(N, 1) if batch_reduction in ["mean", "sum"]: return ( (x.sum((1, 2)) * weights).sum() * 0.0, (x.sum((1, 2)) * weights).sum() * 0.0, ) return ( (x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0, ) x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1) y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1) cham_x = x_nn.dists[..., 0] # (N, P1) cham_y = y_nn.dists[..., 0] # (N, P2) if is_x_heterogeneous: cham_x[x_mask] = 0.0 if is_y_heterogeneous: cham_y[y_mask] = 0.0 if weights is not None: cham_x *= weights.view(N, 1) cham_y *= weights.view(N, 1) return cham_x, cham_y, x_nn.idx[..., -1], y_nn.idx[..., -1] ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/constants.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import math import torch OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406) IMAGENET_DATASET_STD = (0.229, 0.224, 0.225) DEPTH_BINS = torch.cat( ( torch.logspace(math.log10(0.1), math.log10(180.0), steps=512), torch.tensor([260.0]), ), dim=0, ) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/coordinate.py ================================================ import torch def coords_grid(b, h, w, homogeneous=False, device=None, noisy=False): pixel_coords_x = torch.linspace(0.5, w - 0.5, w, device=device) pixel_coords_y = torch.linspace(0.5, h - 0.5, h, device=device) if noisy: # \pm 0.5px noise pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5 pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5 stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()] if homogeneous: ones = torch.ones_like(stacks[0]) # [H, W] stacks.append(ones) grid = torch.stack(stacks, dim=0).float() # 
[2, H, W] or [3, H, W] grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] if device is not None: grid = grid.to(device) return grid def normalize_coords(coords, h, w): c = torch.tensor([(w - 1) / 2.0, (h - 1) / 2.0], device=coords.device).view( 1, 2, 1, 1 ) return (coords - c) / c ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/distributed.py ================================================ import os import pickle import platform import subprocess import warnings import cv2 import torch import torch.utils.data.distributed from torch import distributed as dist from torch import multiprocessing as mp _LOCAL_PROCESS_GROUP = None def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_rank(): if not is_dist_avail_and_initialized(): return 0 return dist.get_rank() def get_local_rank() -> int: """ Returns: The rank of the current process within the local (per-machine) process group. """ if not is_dist_avail_and_initialized(): return 0 assert _LOCAL_PROCESS_GROUP is not None return dist.get_rank(group=_LOCAL_PROCESS_GROUP) def get_local_size() -> int: """ Returns: The size of the per-machine process group, i.e. the number of processes per machine. """ if not is_dist_avail_and_initialized(): return 1 assert _LOCAL_PROCESS_GROUP is not None return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) def get_world_size(): if not is_dist_avail_and_initialized(): return 1 return dist.get_world_size() def barrier(): if not is_dist_avail_and_initialized(): return dist.barrier() def is_main_process(): return get_rank() == 0 def is_rank_zero(args): return args.rank == 0 def get_dist_info(): if dist.is_available() and dist.is_initialized(): rank = dist.get_rank() world_size = dist.get_world_size() else: rank = 0 world_size = 1 return rank, world_size def setup_multi_processes(cfg): """Setup multi-processing environment variables.""" # set multi-process start method as `fork` to speed up the training if platform.system() != "Windows": mp_start_method = cfg.get("mp_start_method", "fork") current_method = mp.get_start_method(allow_none=True) if current_method is not None and current_method != mp_start_method: warnings.warn( f"Multi-processing start method `{mp_start_method}` is " f"different from the previous setting `{current_method}`." f"It will be force set to `{mp_start_method}`. You can change " f"this behavior by changing `mp_start_method` in your config." 
) mp.set_start_method(mp_start_method, force=True) # disable opencv multithreading to avoid system being overloaded # opencv_num_threads = cfg.get('opencv_num_threads', 0) # cv2.setNumThreads(opencv_num_threads) # setup OMP threads # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa # workers_per_gpu = cfg.get('workers_per_gpu', 4) # if 'OMP_NUM_THREADS' not in os.environ and workers_per_gpu > 1: # omp_num_threads = 1 # warnings.warn( # f'Setting OMP_NUM_THREADS environment variable for each process ' # f'to be {omp_num_threads} in default, to avoid your system being ' # f'overloaded, please further tune the variable for optimal ' # f'performance in your application as needed.') # os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) # setup MKL threads # if 'MKL_NUM_THREADS' not in os.environ and workers_per_gpu > 1: # mkl_num_threads = os.environ.get('OMP_NUM_THREADS', 1) # warnings.warn( # f'Setting MKL_NUM_THREADS environment variable for each process ' # f'to be {mkl_num_threads} in default, to avoid your system being ' # f'overloaded, please further tune the variable for optimal ' # f'performance in your application as needed.') # os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) def setup_slurm(backend: str, port: str) -> None: proc_id = int(os.environ["SLURM_PROCID"]) ntasks = int(os.environ["SLURM_NTASKS"]) node_list = os.environ["SLURM_NODELIST"] num_gpus = torch.cuda.device_count() torch.cuda.set_device(proc_id % num_gpus) addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") os.environ["MASTER_PORT"] = str(port) os.environ["MASTER_ADDR"] = addr os.environ["WORLD_SIZE"] = str(ntasks) os.environ["LOCAL_RANK"] = str(proc_id % num_gpus) os.environ["RANK"] = str(proc_id) print( proc_id, ntasks, num_gpus, proc_id % num_gpus, node_list, addr, os.environ["MASTER_PORT"], os.system("nvidia-smi -L"), ) dist.init_process_group(backend, rank=proc_id, world_size=ntasks) def sync_tensor_across_gpus(t, dim=0, cat=True): if t is None or not (dist.is_available() and dist.is_initialized()): return t t = torch.atleast_1d(t) group = dist.group.WORLD group_size = torch.distributed.get_world_size(group) local_size = torch.tensor(t.size(dim), device=t.device) all_sizes = [torch.zeros_like(local_size) for _ in range(group_size)] dist.all_gather(all_sizes, local_size) max_size = max(all_sizes) size_diff = max_size.item() - local_size.item() if size_diff: padding = torch.zeros(size_diff, device=t.device, dtype=t.dtype) t = torch.cat((t, padding)) gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)] dist.all_gather(gather_t_tensor, t) all_ts = [] for t, size in zip(gather_t_tensor, all_sizes): all_ts.append(t[:size]) if cat: return torch.cat(all_ts, dim=0) return all_ts def sync_string_across_gpus(keys: list[str], device, dim=0): keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL) keys_serialized_tensor = ( torch.frombuffer(keys_serialized, dtype=torch.uint8).clone().to(device) ) keys_serialized_tensor = sync_tensor_across_gpus( keys_serialized_tensor, dim=0, cat=False ) keys = [ key for keys in keys_serialized_tensor for key in pickle.loads(bytes(keys.cpu().tolist())) ] return keys def create_local_process_group() -> None: num_workers_per_machine = torch.cuda.device_count() global _LOCAL_PROCESS_GROUP assert _LOCAL_PROCESS_GROUP is None assert get_world_size() % num_workers_per_machine == 0 num_machines = get_world_size() // num_workers_per_machine machine_rank = get_rank() // 
num_workers_per_machine for i in range(num_machines): ranks_on_i = list( range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine) ) pg = dist.new_group(ranks_on_i) if i == machine_rank: _LOCAL_PROCESS_GROUP = pg def _get_global_gloo_group(): if dist.get_backend() == "nccl": return dist.new_group(backend="gloo") else: return dist.group.WORLD def all_gather(data, group=None): if get_world_size() == 1: return [data] if group is None: group = ( _get_global_gloo_group() ) # use CPU group by default, to reduce GPU RAM usage. world_size = dist.get_world_size(group) if world_size == 1: return [data] output = [None for _ in range(world_size)] dist.all_gather_object(output, data, group=group) return output def local_broadcast_process_authkey(): if get_local_size() == 1: return local_rank = get_local_rank() authkey = bytes(mp.current_process().authkey) all_keys = all_gather(authkey) local_leader_key = all_keys[get_rank() - local_rank] if authkey != local_leader_key: # print("Process authkey is different from the key of local leader! workers are launched independently ??") # print("Overwriting local authkey ...") mp.current_process().authkey = local_leader_key ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/ema_torch.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ from __future__ import division, unicode_literals import contextlib import copy import weakref from math import tanh from typing import Iterable, Optional import torch class DummyExponentialMovingAverage: def __init__(self, *args, **kwargs): pass def _get_parameters(self, *args, **kwargs): pass def get_current_decay(self, *args, **kwargs): pass def update(self, *args, **kwargs): pass def copy_to(self, *args, **kwargs): pass def store(self, *args, **kwargs): return def restore(self, *args, **kwargs): return @contextlib.contextmanager def average_parameters(self, *args, **kwargs): try: yield finally: pass def to(self, *args, **kwargs): pass def state_dict(self, *args, **kwargs): pass def load_state_dict(self, *args, **kwargs): pass class ExponentialMovingAverage: """ Maintains (exponential) moving average of a set of parameters. Args: parameters: Iterable of `torch.nn.Parameter` (typically from `model.parameters()`). Note that EMA is computed on *all* provided parameters, regardless of whether or not they have `requires_grad = True`; this allows a single EMA object to be consistantly used even if which parameters are trainable changes step to step. If you want to some parameters in the EMA, do not pass them to the object in the first place. For example: ExponentialMovingAverage( parameters=[p for p in model.parameters() if p.requires_grad], decay=0.9 ) will ignore parameters that do not require grad. decay: The exponential decay. use_num_updates: Whether to use number of updates when computing averages. 
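        update_after_step: Number of `update()` calls before the averaged
            weights start to move; the effective decay is 0 until then.
        tau: Time constant of the tanh warm-up; after `update_after_step` the
            effective decay is roughly tanh(steps_since_warmup / tau) * decay.
        switch: If True, `average_parameters()` keeps the averaged weights in
            the model on exit instead of restoring the originals.

    Typical usage (illustrative; `model`, `loader`, `optimizer` and `validate`
    are placeholders):

        ema = ExponentialMovingAverage(model.parameters(), decay=0.9995)
        for batch in loader:
            ...
            optimizer.step()
            ema.update()
        with ema.average_parameters():
            validate(model)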
""" def __init__( self, parameters: Iterable[torch.nn.Parameter], decay: float, use_num_updates: bool = True, update_after_step: int = 10000, tau: int = 20000, switch: bool = False, ): if decay < 0.0 or decay > 1.0: raise ValueError("Decay must be between 0 and 1") self.decay = decay self.switch = switch # fi keeping EMA params in model after epochs self.num_updates = 0 if use_num_updates else None parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters] self.collected_params = None # By maintaining only a weakref to each parameter, # we maintain the old GC behaviour of ExponentialMovingAverage: # if the model goes out of scope but the ExponentialMovingAverage # is kept, no references to the model or its parameters will be # maintained, and the model will be cleaned up. self._params_refs = [weakref.ref(p) for p in parameters] self.update_after_step = update_after_step self.tau = tau def _get_parameters( self, parameters: Optional[Iterable[torch.nn.Parameter]] ) -> Iterable[torch.nn.Parameter]: if parameters is None: parameters = [p() for p in self._params_refs] if any(p is None for p in parameters): raise ValueError( "(One of) the parameters with which this ExponentialMovingAverage was initialized no longer exists (was garbage collected);" " please either provide `parameters` explicitly or keep the model to which they belong from being garbage collected." ) return parameters else: parameters = list(parameters) if len(parameters) != len(self.shadow_params): raise ValueError( "Number of parameters passed as argument is different " "from number of shadow parameters maintained by this " "ExponentialMovingAverage" ) return parameters def get_current_decay(self): epoch = max(self.num_updates - self.update_after_step - 1, 0.0) if epoch <= 0: return 0.0 value = tanh(epoch / self.tau) * self.decay return value def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: """ Update currently maintained parameters. Call this every time the parameters are updated, such as the result of the `optimizer.step()` call. Args: parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to initialize this object. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) decay = self.get_current_decay() if self.num_updates is not None: self.num_updates += 1 one_minus_decay = 1.0 - decay with torch.no_grad(): for s_param, param in zip(self.shadow_params, parameters): tmp = s_param - param # tmp will be a new tensor so we can do in-place tmp.mul_(one_minus_decay) s_param.sub_(tmp) def copy_to( self, parameters: Optional[Iterable[torch.nn.Parameter]] = None ) -> None: """ Copy current averaged parameters into given collection of parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving averages. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) for s_param, param in zip(self.shadow_params, parameters): param.data.copy_(s_param.data) def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: """ Save the current parameters for restoring later. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. If `None`, the parameters of with which this `ExponentialMovingAverage` was initialized will be used. 
""" parameters = self._get_parameters(parameters) self.collected_params = [param.detach().clone() for param in parameters] def restore( self, parameters: Optional[Iterable[torch.nn.Parameter]] = None ) -> None: """ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without affecting the original optimization process. Store the parameters before the `copy_to` method. After validation (or model saving), use this to restore the former parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ if self.collected_params is None: raise RuntimeError( "This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`" ) parameters = self._get_parameters(parameters) for c_param, param in zip(self.collected_params, parameters): param.data.copy_(c_param.data) @contextlib.contextmanager def average_parameters( self, parameters: Optional[Iterable[torch.nn.Parameter]] = None ): r""" Context manager for validation/inference with averaged parameters. Equivalent to: ema.store() ema.copy_to() try: ... finally: ema.restore() Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) self.store(parameters) self.copy_to(parameters) try: yield finally: if not self.switch: self.restore(parameters) def to(self, device=None, dtype=None) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. Args: device: like `device` argument to `torch.Tensor.to` """ # .to() on the tensors handles None correctly self.shadow_params = [ ( p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) ) for p in self.shadow_params ] if self.collected_params is not None: self.collected_params = [ ( p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) ) for p in self.collected_params ] return def state_dict(self) -> dict: r"""Returns the state of the ExponentialMovingAverage as a dict.""" # Following PyTorch conventions, references to tensors are returned: # "returns a reference to the state and not its copy!" - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict return { "decay": self.decay, "num_updates": self.num_updates, "shadow_params": self.shadow_params, "collected_params": self.collected_params, } def load_state_dict(self, state_dict: dict) -> None: r"""Loads the ExponentialMovingAverage state. Args: state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`. 
""" # deepcopy, to be consistent with module API state_dict = copy.deepcopy(state_dict) self.decay = state_dict["decay"] if self.decay < 0.0 or self.decay > 1.0: raise ValueError("Decay must be between 0 and 1") self.num_updates = state_dict["num_updates"] assert self.num_updates is None or isinstance( self.num_updates, int ), "Invalid num_updates" self.shadow_params = state_dict["shadow_params"] assert isinstance(self.shadow_params, list), "shadow_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.shadow_params ), "shadow_params must all be Tensors" self.collected_params = state_dict["collected_params"] if self.collected_params is not None: assert isinstance( self.collected_params, list ), "collected_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.collected_params ), "collected_params must all be Tensors" assert len(self.collected_params) == len( self.shadow_params ), "collected_params and shadow_params had different lengths" if len(self.shadow_params) == len(self._params_refs): # Consistant with torch.optim.Optimizer, cast things to consistant # device and dtype with the parameters params = [p() for p in self._params_refs] # If parameters have been garbage collected, just load the state # we were given without change. if not any(p is None for p in params): # ^ parameter references are still good for i, p in enumerate(params): self.shadow_params[i] = self.shadow_params[i].to( device=p.device, dtype=p.dtype ) if self.collected_params is not None: self.collected_params[i] = self.collected_params[i].to( device=p.device, dtype=p.dtype ) else: raise ValueError( "Tried to `load_state_dict()` with the wrong number of " "parameters in the saved state." ) ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/evaluation_depth.py ================================================ from collections import defaultdict from functools import partial import torch import torch.nn.functional as F from unidepth.utils.chamfer_distance import ChamferDistance chamfer_cls = ChamferDistance() def chamfer_dist(tensor1, tensor2): x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) dist1, dist2, idx1, idx2 = chamfer_cls( tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths ) return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2 def auc(tensor1, tensor2, thresholds): x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) dist1, dist2, idx1, idx2 = chamfer_cls( tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths ) # compute precision recall precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] auc_value = torch.trapz( torch.tensor(precisions, device=tensor1.device), torch.tensor(recalls, device=tensor1.device), ) return auc_value def delta(tensor1, tensor2, exponent): inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1)) return (inlier < 1.25**exponent).to(torch.float32).mean() def tau(tensor1, tensor2, perc): inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1)) return (inlier < (1.0 + perc)).to(torch.float32).mean() def ssi(tensor1, tensor2): stability_mat = 1e-9 * torch.eye(2, device=tensor1.device) tensor2_one = torch.stack( [tensor2.detach(), torch.ones_like(tensor2).detach()], dim=1 ) 
scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ ( tensor2_one.T @ tensor1.unsqueeze(1) ) scale, shift = scale_shift.squeeze().chunk(2, dim=0) return tensor2 * scale + shift def si(tensor1, tensor2): return tensor2 * torch.median(tensor1) / torch.median(tensor2) def arel(tensor1, tensor2): tensor2 = tensor2 * torch.median(tensor1) / torch.median(tensor2) return (torch.abs(tensor1 - tensor2) / tensor1).mean() def d_auc(tensor1, tensor2): exponents = torch.linspace(0.01, 5.0, steps=100, device=tensor1.device) deltas = [delta(tensor1, tensor2, exponent) for exponent in exponents] return torch.trapz(torch.tensor(deltas, device=tensor1.device), exponents) / 5.0 def f1_score(tensor1, tensor2, thresholds): x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) dist1, dist2, idx1, idx2 = chamfer_cls( tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths ) # compute precision recall precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] precisions = torch.tensor(precisions, device=tensor1.device) recalls = torch.tensor(recalls, device=tensor1.device) f1_thresholds = 2 * precisions * recalls / (precisions + recalls) f1_thresholds = torch.where( torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds ) f1_value = torch.trapz(f1_thresholds) / len(thresholds) return f1_value DICT_METRICS = { "d1": partial(delta, exponent=1.0), "d2": partial(delta, exponent=2.0), "d3": partial(delta, exponent=3.0), "rmse": lambda gt, pred: torch.sqrt(((gt - pred) ** 2).mean()), "rmselog": lambda gt, pred: torch.sqrt( ((torch.log(gt) - torch.log(pred)) ** 2).mean() ), "arel": lambda gt, pred: (torch.abs(gt - pred) / gt).mean(), "sqrel": lambda gt, pred: (((gt - pred) ** 2) / gt).mean(), "log10": lambda gt, pred: torch.abs(torch.log10(pred) - torch.log10(gt)).mean(), "silog": lambda gt, pred: 100 * torch.std(torch.log(pred) - torch.log(gt)).mean(), "medianlog": lambda gt, pred: 100 * (torch.log(pred) - torch.log(gt)).median().abs(), "d_auc": d_auc, "tau": partial(tau, perc=0.03), } DICT_METRICS_3D = { "MSE_3d": lambda gt, pred, thresholds: torch.norm(gt - pred, dim=0, p=2), "chamfer": lambda gt, pred, thresholds: chamfer_dist( gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1) ), "F1": lambda gt, pred, thresholds: f1_score( gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1), thresholds=thresholds, ), } DICT_METRICS_D = { "a1": lambda gt, pred: (torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0).to( torch.float32 ), "abs_rel": lambda gt, pred: (torch.abs(gt - pred) / gt), } def eval_depth( gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, max_depth=None ): summary_metrics = defaultdict(list) preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear") for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): if max_depth is not None: mask = mask & (gt <= max_depth) for name, fn in DICT_METRICS.items(): if name in ["tau", "d1", "arel"]: for rescale_fn in ["ssi", "si"]: summary_metrics[f"{name}_{rescale_fn}"].append( fn(gt[mask], eval(rescale_fn)(gt[mask], pred[mask])) ) summary_metrics[name].append(fn(gt[mask], pred[mask]).mean()) return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} def eval_3d( gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None ): summary_metrics = 
defaultdict(list) ratio = min( 1.0, (240 * 320 / masks.sum()) ** 0.5 ) # rescale to avoid OOM during eval, FIXME h_max, w_max = int(gts.shape[-2] * ratio), int(gts.shape[-1] * ratio) gts = F.interpolate(gts, size=(h_max, w_max), mode="nearest-exact") preds = F.interpolate(preds, size=(h_max, w_max), mode="nearest-exact") masks = F.interpolate( masks.float(), size=(h_max, w_max), mode="nearest-exact" ).bool() for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): if not torch.any(mask): continue for name, fn in DICT_METRICS_3D.items(): summary_metrics[name].append( fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean() ) return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/geometric.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ from typing import Tuple import torch from torch.nn import functional as F @torch.jit.script def generate_rays( camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False ): batch_size, device, dtype = ( camera_intrinsics.shape[0], camera_intrinsics.device, camera_intrinsics.dtype, ) height, width = image_shape # Generate grid of pixel coordinates pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype) pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype) if noisy: pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5 pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5 pixel_coords = torch.stack( [pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2 ) # (H, W, 2) pixel_coords = pixel_coords + 0.5 # Calculate ray directions intrinsics_inv = torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1) intrinsics_inv[:, 0, 0] = 1.0 / camera_intrinsics[:, 0, 0] intrinsics_inv[:, 1, 1] = 1.0 / camera_intrinsics[:, 1, 1] intrinsics_inv[:, 0, 2] = -camera_intrinsics[:, 0, 2] / camera_intrinsics[:, 0, 0] intrinsics_inv[:, 1, 2] = -camera_intrinsics[:, 1, 2] / camera_intrinsics[:, 1, 1] homogeneous_coords = torch.cat( [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2 ) # (H, W, 3) ray_directions = torch.matmul( intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1) ) # (3, H*W) ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W) ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3) theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1]) phi = torch.acos(ray_directions[..., 1]) # pitch = torch.asin(ray_directions[..., 1]) # roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1]) angles = torch.stack([theta, phi], dim=-1) return ray_directions, angles @torch.jit.script def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor: theta = spherical_tensor[..., 0] # Extract polar angle phi = spherical_tensor[..., 1] # Extract azimuthal angle z = spherical_tensor[..., 2] # Extract zbuffer depth # y = r * cos(phi) # x = r * sin(phi) * sin(theta) # z = r * sin(phi) * cos(theta) # => # r = z / sin(phi) / cos(theta) # y = z / (sin(phi) / cos(phi)) / cos(theta) # x = z * sin(theta) / cos(theta) x = z * torch.tan(theta) y = z / torch.tan(phi) / torch.cos(theta) euclidean_tensor = torch.stack((x, y, z), dim=-1) return euclidean_tensor @torch.jit.script def 
spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor: theta = spherical_tensor[..., 0] # Extract polar angle phi = spherical_tensor[..., 1] # Extract azimuthal angle r = spherical_tensor[..., 2] # Extract radius # y = r * cos(phi) # x = r * sin(phi) * sin(theta) # z = r * sin(phi) * cos(theta) x = r * torch.sin(phi) * torch.sin(theta) y = r * torch.cos(phi) z = r * torch.cos(theta) * torch.sin(phi) euclidean_tensor = torch.stack((x, y, z), dim=-1) return euclidean_tensor @torch.jit.script def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor: x = spherical_tensor[..., 0] # Extract polar angle y = spherical_tensor[..., 1] # Extract azimuthal angle z = spherical_tensor[..., 2] # Extract radius # y = r * cos(phi) # x = r * sin(phi) * sin(theta) # z = r * sin(phi) * cos(theta) r = torch.sqrt(x**2 + y**2 + z**2) theta = torch.atan2(x / r, z / r) phi = torch.acos(y / r) euclidean_tensor = torch.stack((theta, phi, r), dim=-1) return euclidean_tensor @torch.jit.script def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor: pitch = torch.asin(euclidean_tensor[..., 1]) yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1]) z = euclidean_tensor[..., 2] # Extract zbuffer depth euclidean_tensor = torch.stack((pitch, yaw, z), dim=-1) return euclidean_tensor @torch.jit.script def unproject_points( depth: torch.Tensor, camera_intrinsics: torch.Tensor ) -> torch.Tensor: """ Unprojects a batch of depth maps to 3D point clouds using camera intrinsics. Args: depth (torch.Tensor): Batch of depth maps of shape (B, 1, H, W). camera_intrinsics (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3). Returns: torch.Tensor: Batch of 3D point clouds of shape (B, 3, H, W). """ batch_size, _, height, width = depth.shape device = depth.device # Create pixel grid y_coords, x_coords = torch.meshgrid( torch.arange(height, device=device), torch.arange(width, device=device), indexing="ij", ) pixel_coords = torch.stack((x_coords, y_coords), dim=-1) # (H, W, 2) # Get homogeneous coords (u v 1) pixel_coords_homogeneous = torch.cat( (pixel_coords, torch.ones((height, width, 1), device=device)), dim=-1 ) pixel_coords_homogeneous = pixel_coords_homogeneous.permute(2, 0, 1).flatten( 1 ) # (3, H*W) # Apply K^-1 @ (u v 1): [B, 3, 3] @ [3, H*W] -> [B, 3, H*W] unprojected_points = torch.matmul( torch.inverse(camera_intrinsics), pixel_coords_homogeneous ) # (B, 3, H*W) unprojected_points = unprojected_points.view( batch_size, 3, height, width ) # (B, 3, H, W) unprojected_points = unprojected_points * depth # (B, 3, H, W) return unprojected_points @torch.jit.script def project_points( points_3d: torch.Tensor, intrinsic_matrix: torch.Tensor, image_shape: Tuple[int, int], ) -> torch.Tensor: # Project 3D points onto the image plane via intrinsics (u v w) = (x y z) @ K^T points_2d = torch.matmul(points_3d, intrinsic_matrix.transpose(1, 2)) # Normalize projected points: (u v w) -> (u / w, v / w, 1) points_2d = points_2d[..., :2] / points_2d[..., 2:] points_2d = points_2d.int() # points need to be inside the image (can it diverge onto all points out???) 
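    # Keep only points that project inside the image, then splat their depths with
    # scatter_add_: depths landing on the same pixel are summed and divided by the
    # per-pixel hit count to yield a mean depth per pixel.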
valid_mask = ( (points_2d[..., 0] >= 0) & (points_2d[..., 0] < image_shape[1]) & (points_2d[..., 1] >= 0) & (points_2d[..., 1] < image_shape[0]) ) # Calculate the flat indices of the valid pixels flat_points_2d = points_2d[..., 0] + points_2d[..., 1] * image_shape[1] flat_indices = flat_points_2d.long() # Create depth maps and counts using scatter_add, (B, H, W) depth_maps = torch.zeros( [points_3d.shape[0], *image_shape], device=points_3d.device ) counts = torch.zeros([points_3d.shape[0], *image_shape], device=points_3d.device) # Loop over batches to apply masks and accumulate depth/count values for i in range(points_3d.shape[0]): valid_indices = flat_indices[i, valid_mask[i]] depth_maps[i].view(-1).scatter_add_( 0, valid_indices, points_3d[i, valid_mask[i], 2] ) counts[i].view(-1).scatter_add_( 0, valid_indices, torch.ones_like(points_3d[i, valid_mask[i], 2]) ) # Calculate mean depth for each pixel in each batch mean_depth_maps = depth_maps / counts.clamp(min=1.0) return mean_depth_maps.reshape(-1, 1, *image_shape) # (B, 1, H, W) @torch.jit.script def downsample(data: torch.Tensor, downsample_factor: int = 2): N, _, H, W = data.shape data = data.view( N, H // downsample_factor, downsample_factor, W // downsample_factor, downsample_factor, 1, ) data = data.permute(0, 1, 3, 5, 2, 4).contiguous() data = data.view(-1, downsample_factor * downsample_factor) data_tmp = torch.where(data == 0.0, 1e5 * torch.ones_like(data), data) data = torch.min(data_tmp, dim=-1).values data = data.view(N, 1, H // downsample_factor, W // downsample_factor) data = torch.where(data > 1000, torch.zeros_like(data), data) return data @torch.jit.script def flat_interpolate( flat_tensor: torch.Tensor, old: Tuple[int, int], new: Tuple[int, int], antialias: bool = True, mode: str = "bilinear", ) -> torch.Tensor: if old[0] == new[0] and old[1] == new[1]: return flat_tensor tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute( 0, 3, 1, 2 ) # b c h w tensor_interp = F.interpolate( tensor, size=(new[0], new[1]), mode=mode, align_corners=False, antialias=antialias, ) flat_tensor_interp = tensor_interp.view( flat_tensor.shape[0], -1, new[0] * new[1] ).permute( 0, 2, 1 ) # b (h w) c return flat_tensor_interp.contiguous() @torch.jit.script def dilate(image, kernel_size: int | tuple[int, int]): if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) device, dtype = image.device, image.dtype padding = (kernel_size[0] // 2, kernel_size[1] // 2) kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device) dilated_image = F.conv2d(image.float(), kernel, padding=padding, stride=1) dilated_image = torch.where( dilated_image > 0, torch.tensor(1.0, device=device), torch.tensor(0.0, device=device), ) return dilated_image.to(dtype) @torch.jit.script def erode(image, kernel_size: int | tuple[int, int]): if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) device, dtype = image.device, image.dtype padding = (kernel_size[0] // 2, kernel_size[1] // 2) kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device) eroded_image = F.conv2d(image.float(), kernel, padding=padding, stride=1) eroded_image = torch.where( eroded_image == (kernel_size[0] * kernel_size[1]), torch.tensor(1.0, device=device), torch.tensor(0.0, device=device), ) return eroded_image.to(dtype) @torch.jit.script def iou(mask1: torch.Tensor, mask2: torch.Tensor) -> torch.Tensor: device = mask1.device # Ensure the masks are binary (0 or 1) mask1 = mask1.to(torch.bool) mask2 
= mask2.to(torch.bool) # Compute intersection and union intersection = torch.sum(mask1 & mask2).to(torch.float32) union = torch.sum(mask1 | mask2).to(torch.float32) # Compute IoU iou = intersection / union.clip(min=1.0) return iou ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/misc.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ from functools import wraps from time import time import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, reduce, repeat from scipy import interpolate @torch.jit.script def max_stack(tensors: list[torch.Tensor]) -> torch.Tensor: if len(tensors) == 1: return tensors[0] return torch.stack(tensors, dim=-1).max(dim=-1).values def last_stack(tensors: list[torch.Tensor]) -> torch.Tensor: return tensors[-1] def first_stack(tensors: list[torch.Tensor]) -> torch.Tensor: return tensors[0] @torch.jit.script def softmax_stack( tensors: list[torch.Tensor], temperature: float = 1.0 ) -> torch.Tensor: if len(tensors) == 1: return tensors[0] return F.softmax(torch.stack(tensors, dim=-1) / temperature, dim=-1).sum(dim=-1) @torch.jit.script def mean_stack(tensors: list[torch.Tensor]) -> torch.Tensor: if len(tensors) == 1: return tensors[0] return torch.stack(tensors, dim=-1).mean(dim=-1) @torch.jit.script def sum_stack(tensors: list[torch.Tensor]) -> torch.Tensor: if len(tensors) == 1: return tensors[0] return torch.stack(tensors, dim=-1).sum(dim=-1) def convert_module_to_f16(l): """ Convert primitive modules to float16. """ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): l.weight.data = l.weight.data.half() if l.bias is not None: l.bias.data = l.bias.data.half() def convert_module_to_f32(l): """ Convert primitive modules to float32, undoing convert_module_to_f16(). 
""" if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): l.weight.data = l.weight.data.float() if l.bias is not None: l.bias.data = l.bias.data.float() def format_seconds(seconds): minutes, seconds = divmod(seconds, 60) hours, minutes = divmod(minutes, 60) return f"{hours:d}:{minutes:02d}:{seconds:02d}" def get_params(module, lr, wd): skip_list = {} skip_keywords = {} if hasattr(module, "no_weight_decay"): skip_list = module.no_weight_decay() if hasattr(module, "no_weight_decay_keywords"): skip_keywords = module.no_weight_decay_keywords() has_decay = [] no_decay = [] for name, param in module.named_parameters(): if not param.requires_grad: continue # frozen weights if ( (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1 or name.endswith(".gamma") or name.endswith(".beta") or name.endswith(".bias") ): # if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1: no_decay.append(param) else: has_decay.append(param) group1 = { "params": has_decay, "weight_decay": wd, "lr": lr, "weight_decay_init": wd, "weight_decay_base": wd, # "lr_init": lr, "lr_base": lr, } group2 = { "params": no_decay, "weight_decay": 0.0, "lr": lr, "weight_decay_init": 0.0, "weight_decay_base": 0.0, "weight_decay_final": 0.0, # "lr_init": lr, "lr_base": lr, } return [group1, group2], [lr, lr] def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage): if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"): return 0 elif var_name.startswith("patch_embed"): return 0 elif var_name.startswith("layers"): if var_name.split(".")[2] == "blocks": stage_id = int(var_name.split(".")[1]) layer_id = int(var_name.split(".")[3]) + sum(layers_per_stage[:stage_id]) return layer_id + 1 elif var_name.split(".")[2] == "downsample": stage_id = int(var_name.split(".")[1]) layer_id = sum(layers_per_stage[: stage_id + 1]) return layer_id else: return num_max_layer - 1 def get_params_layerdecayswin(module, lr, wd, ld): skip_list = {} skip_keywords = {} if hasattr(module, "no_weight_decay"): skip_list = module.no_weight_decay() if hasattr(module, "no_weight_decay_keywords"): skip_keywords = module.no_weight_decay_keywords() layers_per_stage = module.depths num_layers = sum(layers_per_stage) + 1 lrs = [] params = [] for name, param in module.named_parameters(): if not param.requires_grad: print(f"{name} frozen") continue # frozen weights layer_id = get_num_layer_for_swin(name, num_layers, layers_per_stage) lr_cur = lr * ld ** (num_layers - layer_id - 1) # if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1 or name.endswith(".bias"): if (name in skip_list) or any((kw in name for kw in skip_keywords)): wd_cur = 0.0 else: wd_cur = wd params.append({"params": param, "weight_decay": wd_cur, "lr": lr_cur}) lrs.append(lr_cur) return params, lrs def log(t, eps: float = 1e-5): return torch.log(t.clamp(min=eps)) def l2norm(t): return F.normalize(t, dim=-1) def exists(val): return val is not None def identity(t, *args, **kwargs): return t def divisible_by(numer, denom): return (numer % denom) == 0 def first(arr, d=None): if len(arr) == 0: return d return arr[0] def default(val, d): if exists(val): return val return d() if callable(d) else d def maybe(fn): @wraps(fn) def inner(x): if not exists(x): return x return fn(x) return inner def once(fn): called = False @wraps(fn) def inner(x): nonlocal called if called: return called = True return fn(x) return inner def _many(fn): @wraps(fn) def inner(tensors, pattern, 
**kwargs): return (fn(tensor, pattern, **kwargs) for tensor in tensors) return inner rearrange_many = _many(rearrange) repeat_many = _many(repeat) reduce_many = _many(reduce) def load_pretrained(state_dict, checkpoint): checkpoint_model = checkpoint["model"] if any([True if "encoder." in k else False for k in checkpoint_model.keys()]): checkpoint_model = { k.replace("encoder.", ""): v for k, v in checkpoint_model.items() if k.startswith("encoder.") } print("Detect pre-trained model, remove [encoder.] prefix.") else: print("Detect non-pre-trained model, pass without doing anything.") print(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") checkpoint = load_checkpoint_swin(state_dict, checkpoint_model) def load_checkpoint_swin(model, checkpoint_model): state_dict = model.state_dict() # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size all_keys = list(checkpoint_model.keys()) for key in all_keys: if "relative_position_bias_table" in key: relative_position_bias_table_pretrained = checkpoint_model[key] relative_position_bias_table_current = state_dict[key] L1, nH1 = relative_position_bias_table_pretrained.size() L2, nH2 = relative_position_bias_table_current.size() if nH1 != nH2: print(f"Error in loading {key}, passing......") else: if L1 != L2: print(f"{key}: Interpolate relative_position_bias_table using geo.") src_size = int(L1**0.5) dst_size = int(L2**0.5) def geometric_progression(a, r, n): return a * (1.0 - r**n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size // 2: right = q else: left = q # if q > 1.090307: # q = 1.090307 dis = [] cur = 1 for i in range(src_size // 2): dis.append(cur) cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_size // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) print("Original positions = %s" % str(x)) print("Target positions = %s" % str(dx)) all_rel_pos_bias = [] for i in range(nH1): z = ( relative_position_bias_table_pretrained[:, i] .view(src_size, src_size) .float() .numpy() ) f_cubic = interpolate.interp2d(x, y, z, kind="cubic") all_rel_pos_bias.append( torch.Tensor(f_cubic(dx, dy)) .contiguous() .view(-1, 1) .to(relative_position_bias_table_pretrained.device) ) new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) checkpoint_model[key] = new_rel_pos_bias # delete relative_position_index since we always re-init it relative_position_index_keys = [ k for k in checkpoint_model.keys() if "relative_position_index" in k ] for k in relative_position_index_keys: del checkpoint_model[k] # delete relative_coords_table since we always re-init it relative_coords_table_keys = [ k for k in checkpoint_model.keys() if "relative_coords_table" in k ] for k in relative_coords_table_keys: del checkpoint_model[k] # # re-map keys due to name change rpe_mlp_keys = [k for k in checkpoint_model.keys() if "cpb_mlp" in k] for k in rpe_mlp_keys: checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k) # delete attn_mask since we always re-init it attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] for k in attn_mask_keys: del checkpoint_model[k] encoder_keys = [k for k in checkpoint_model.keys() if k.startswith("encoder.")] for k in encoder_keys: checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k) return checkpoint_model def add_padding_metas(out, image_metas): device = out.device # left, 
right, top, bottom paddings = [img_meta.get("paddings", [0] * 4) for img_meta in image_metas] paddings = torch.stack(paddings).to(device) outs = [F.pad(o, padding, value=0.0) for padding, o in zip(paddings, out)] return torch.stack(outs) # left, right, top, bottom def remove_padding(out, paddings): H, W = out.shape[-2:] outs = [ o[..., padding[2] : H - padding[3], padding[0] : W - padding[1]] for padding, o in zip(paddings, out) ] return torch.stack(outs) def remove_padding_metas(out, image_metas): B, C, H, W = out.shape device = out.device # left, right, top, bottom paddings = [ torch.tensor(img_meta.get("paddings", [0] * 4)) for img_meta in image_metas ] return remove_padding(out, paddings) def ssi_helper(tensor1, tensor2): stability_mat = 1e-4 * torch.eye(2, device=tensor1.device) tensor2_one = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1) scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ ( tensor2_one.T @ tensor1.unsqueeze(1) ) scale, shift = scale_shift.squeeze().chunk(2, dim=0) return scale, shift def calculate_mean_values(names, values): # Create a defaultdict to store sum and count for each name name_values = {name: {} for name in names} # Iterate through the lists and accumulate values for each name for name, value in zip(names, values): name_values[name]["sum"] = name_values[name].get("sum", 0.0) + value name_values[name]["count"] = name_values[name].get("count", 0.0) + 1 # Calculate mean values and create the output dictionary output_dict = { name: name_values[name]["sum"] / name_values[name]["count"] for name in name_values } return output_dict def remove_leading_dim(infos): if isinstance(infos, dict): return {k: remove_leading_dim(v) for k, v in infos.items()} elif isinstance(infos, torch.Tensor): return infos.squeeze(0) else: return infos def recursive_index(infos, index): if isinstance(infos, dict): return {k: recursive_index(v, index) for k, v in infos.items()} elif isinstance(infos, torch.Tensor): return infos[index] else: return infos def to_cpu(infos): if isinstance(infos, dict): return {k: to_cpu(v) for k, v in infos.items()} elif isinstance(infos, torch.Tensor): return infos.detach() else: return infos def recursive_to(infos, device, non_blocking, cls): if isinstance(infos, dict): return {k: recursive_to(v, device, non_blocking, cls) for k, v in infos.items()} elif isinstance(infos, list): return [recursive_to(v, device, non_blocking, cls) for v in infos] elif isinstance(infos, cls): return infos.to(device, non_blocking=non_blocking) else: return infos def masked_mean( data: torch.Tensor, mask: torch.Tensor | None = None, dim: list[int] | None = None, keepdim: bool = False, ) -> torch.Tensor: dim = dim if dim is not None else list(range(data.dim())) if mask is None: return data.mean(dim=dim, keepdim=keepdim) mask = mask.float() mask_sum = torch.sum(mask, dim=dim, keepdim=True) mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( mask_sum, min=1.0 ) return mask_mean.squeeze(dim) if not keepdim else mask_mean class ProfileMethod: def __init__(self, model, func_name, track_statistics=True, verbose=False): self.model = model self.func_name = func_name self.verbose = verbose self.track_statistics = track_statistics self.timings = [] def __enter__(self): # Start timing if self.verbose: if torch.cuda.is_available(): torch.cuda.synchronize() self.start_time = time() return self def __exit__(self, exc_type, exc_val, exc_tb): if self.verbose: if torch.cuda.is_available(): torch.cuda.synchronize() self.end_time = time() 
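            # (Added note) The lines below append the elapsed wall-clock time of this
            # call to self.timings; once more than 25 samples are recorded and
            # track_statistics is set, the mean, std and 0/25/50/75/100th percentiles
            # of all timings are printed together with the per-call time, otherwise
            # only the per-call time is printed.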
elapsed_time = self.end_time - self.start_time self.timings.append(elapsed_time) if self.track_statistics and len(self.timings) > 25: # Compute statistics if tracking timings_array = np.array(self.timings) mean_time = np.mean(timings_array) std_time = np.std(timings_array) quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100]) print( f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds" ) print(f"Mean Time: {mean_time:.4f} seconds") print(f"Std Time: {std_time:.4f} seconds") print( f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}" ) else: print( f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds" ) def profile_method(track_statistics=True, verbose=False): def decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): with ProfileMethod(self, func.__name__, track_statistics, verbose): return func(self, *args, **kwargs) return wrapper return decorator class ProfileFunction: def __init__(self, func_name, track_statistics=True, verbose=False): self.func_name = func_name self.verbose = verbose self.track_statistics = track_statistics self.timings = [] def __enter__(self): # Start timing if self.verbose: if torch.cuda.is_available(): torch.cuda.synchronize() self.start_time = time() return self def __exit__(self, exc_type, exc_val, exc_tb): if self.verbose: if torch.cuda.is_available(): torch.cuda.synchronize() self.end_time = time() elapsed_time = self.end_time - self.start_time self.timings.append(elapsed_time) if self.track_statistics and len(self.timings) > 25: # Compute statistics if tracking timings_array = np.array(self.timings) mean_time = np.mean(timings_array) std_time = np.std(timings_array) quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100]) print(f"{self.func_name} took {elapsed_time:.4f} seconds") print(f"Mean Time: {mean_time:.4f} seconds") print(f"Std Time: {std_time:.4f} seconds") print( f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}" ) else: print(f"{self.func_name} took {elapsed_time:.4f} seconds") def profile_function(track_statistics=True, verbose=False): def decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): with ProfileFunction(func.__name__, track_statistics, verbose): return func(self, *args, **kwargs) return wrapper return decorator def squeeze_list(nested_list, dim, current_dim=0): # If the current dimension is in the list of indices to squeeze if isinstance(nested_list, list) and len(nested_list) == 1 and current_dim == dim: return squeeze_list(nested_list[0], dim, current_dim + 1) elif isinstance(nested_list, list): return [squeeze_list(item, dim, current_dim + 1) for item in nested_list] else: return nested_list def match_gt(tensor1, tensor2, padding1, padding2, mode: str = "bilinear"): """ Transform each item in tensor1 batch to match tensor2's dimensions and padding. Args: tensor1 (torch.Tensor): The input tensor to transform, with shape (batch_size, channels, height, width). tensor2 (torch.Tensor): The target tensor to match, with shape (batch_size, channels, height, width). padding1 (tuple): Padding applied to tensor1 (pad_left, pad_right, pad_top, pad_bottom). padding2 (tuple): Desired padding to be applied to match tensor2 (pad_left, pad_right, pad_top, pad_bottom). Returns: torch.Tensor: The batch of transformed tensors matching tensor2's size and padding. 
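
    Example (added sketch with hypothetical shapes, not from the original docstring):
        >>> gt = torch.rand(2, 1, 480, 640)       # e.g. ground-truth depth maps
        >>> pred = torch.rand(2, 1, 462, 616)     # prediction padded by (4, 4, 7, 7)
        >>> pad_gt = [(0, 0, 0, 0), (0, 0, 0, 0)]
        >>> pad_pred = [(4, 4, 7, 7), (4, 4, 7, 7)]
        >>> match_gt(gt, pred, pad_gt, pad_pred, mode="nearest").shape
        torch.Size([2, 1, 462, 616])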
""" # Get batch size batch_size = len(tensor1) src_dtype = tensor1[0].dtype tgt_dtype = tensor2[0].dtype # List to store transformed tensors transformed_tensors = [] for i in range(batch_size): item1 = tensor1[i] item2 = tensor2[i] h1, w1 = item1.shape[1], item1.shape[2] pad1_l, pad1_r, pad1_t, pad1_b = ( padding1[i] if padding1 is not None else (0, 0, 0, 0) ) pad2_l, pad2_r, pad2_t, pad2_b = ( padding2[i] if padding2 is not None else (0, 0, 0, 0) ) item1_unpadded = item1[:, pad1_t : h1 - pad1_b, pad1_l : w1 - pad1_r] h2, w2 = ( item2.shape[1] - pad2_t - pad2_b, item2.shape[2] - pad2_l - pad2_r, ) item1_resized = F.interpolate( item1_unpadded.unsqueeze(0).to(tgt_dtype), size=(h2, w2), mode=mode ) item1_padded = F.pad(item1_resized, (pad2_l, pad2_r, pad2_t, pad2_b)) transformed_tensors.append(item1_padded) transformed_batch = torch.cat(transformed_tensors) return transformed_batch.to(src_dtype) def match_intrinsics(K1, tensor1, tensor2, padding1, padding2): """ Adjust camera intrinsics K1 to match the size and padding transformation applied to tensor1 so that it corresponds correctly to tensor2. Args: K1 (torch.Tensor): The camera intrinsics matrix for tensor1, shape (batch_size, 3, 3). tensor1 (torch.Tensor): The original image tensor, shape (batch_size, C, H1, W1). tensor2 (torch.Tensor): The target image tensor, shape (batch_size, C, H2, W2). padding1 (list of tuples): List of padding applied to tensor1 (pad_left, pad_right, pad_top, pad_bottom). padding2 (list of tuples): Desired padding to be applied to match tensor2 (pad_left, pad_right, pad_top, pad_bottom). Returns: torch.Tensor: The adjusted intrinsics matrix of shape (batch_size, 3, 3). """ batch_size = K1.shape[0] K1_new = K1.clone() for i in range(batch_size): h1, w1 = tensor1.shape[2], tensor1.shape[3] h2, w2 = tensor2.shape[2], tensor2.shape[3] # Remove original padding pad1_l, pad1_r, pad1_t, pad1_b = ( padding1[i] if padding1 is not None else (0, 0, 0, 0) ) w1_unpadded, h1_unpadded = w1 - (pad1_l + pad1_r), h1 - (pad1_t + pad1_b) # Compute new image size after removing original padding pad2_l, pad2_r, pad2_t, pad2_b = ( padding2[i] if padding2 is not None else (0, 0, 0, 0) ) w2_unpadded, h2_unpadded = w2 - (pad2_l + pad2_r), h2 - (pad2_t + pad2_b) # Compute scaling factors scale_x = w2_unpadded / w1_unpadded scale_y = h2_unpadded / h1_unpadded # Update focal length (fx, fy) and principal point (cx, cy) K1_new[i, 0, 0] *= scale_x # fx K1_new[i, 1, 1] *= scale_y # fy K1_new[i, 0, 2] = (K1[i, 0, 2] - pad1_l) * scale_x + pad2_l # cx K1_new[i, 1, 2] = (K1[i, 1, 2] - pad1_t) * scale_y + pad2_t # cy return K1_new ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/positional_embedding.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ from math import pi from typing import Optional import torch import torch.nn as nn from einops import rearrange, repeat class PositionEmbeddingSine(nn.Module): def __init__( self, num_pos_feats=64, temperature=10000, normalize=False, scale=None ): super().__init__() self.num_pos_feats = num_pos_feats self.temperature = temperature self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") if scale is None: scale = 2 * pi self.scale = scale def forward( self, x: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: if mask is 
None: mask = torch.zeros( (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool ) not_mask = ~mask y_embed = not_mask.cumsum(1, dtype=torch.float32) x_embed = not_mask.cumsum(2, dtype=torch.float32) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = self.temperature ** ( 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats ) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack( (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos_y = torch.stack( (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) return pos def __repr__(self, _repr_indent=4): head = "Positional encoding " + self.__class__.__name__ body = [ "num_pos_feats: {}".format(self.num_pos_feats), "temperature: {}".format(self.temperature), "normalize: {}".format(self.normalize), "scale: {}".format(self.scale), ] # _repr_indent = 4 lines = [head] + [" " * _repr_indent + line for line in body] return "\n".join(lines) class LearnedSinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() assert (dim % 2) == 0 half_dim = dim // 2 self.weights = nn.Parameter(torch.randn(half_dim)) def forward(self, x): x = rearrange(x, "b -> b 1") freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) fouriered = torch.cat((x, fouriered), dim=-1) return fouriered def broadcat(tensors, dim=-1): num_tensors = len(tensors) shape_lens = set(list(map(lambda t: len(t.shape), tensors))) assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" shape_len = list(shape_lens)[0] dim = (dim + shape_len) if dim < 0 else dim dims = list(zip(*map(lambda t: list(t.shape), tensors))) expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] assert all( [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] ), "invalid dimensions for broadcastable concatentation" max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) expanded_dims.insert(dim, (dim, dims[dim])) expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) return torch.cat(tensors, dim=dim) def rotate_half(x): x = rearrange(x, "... (d r) -> ... d r", r=2) x1, x2 = x.unbind(dim=-1) x = torch.stack((-x2, x1), dim=-1) return rearrange(x, "... d r -> ... (d r)") class VisionRotaryEmbedding(nn.Module): def __init__( self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, ): super().__init__() if custom_freqs: freqs = custom_freqs elif freqs_for == "lang": freqs = 1.0 / ( theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) ) elif freqs_for == "pixel": freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi elif freqs_for == "constant": freqs = torch.ones(num_freqs).float() else: raise ValueError(f"unknown modality {freqs_for}") if ft_seq_len is None: ft_seq_len = pt_seq_len t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len freqs_h = torch.einsum("..., f -> ... f", t, freqs) freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) freqs_w = torch.einsum("..., f -> ... 
f", t, freqs) freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) self.register_buffer("freqs_cos", freqs.cos()) self.register_buffer("freqs_sin", freqs.sin()) print("======== shape of rope freq", self.freqs_cos.shape, "========") def forward(self, t, start_index=0): rot_dim = self.freqs_cos.shape[-1] end_index = start_index + rot_dim assert ( rot_dim <= t.shape[-1] ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" t_left, t, t_right = ( t[..., :start_index], t[..., start_index:end_index], t[..., end_index:], ) t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) return torch.cat((t_left, t, t_right), dim=-1) class VisionRotaryEmbeddingFast(nn.Module): def __init__( self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, ): super().__init__() if custom_freqs: freqs = custom_freqs elif freqs_for == "lang": freqs = 1.0 / ( theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) ) elif freqs_for == "pixel": freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi elif freqs_for == "constant": freqs = torch.ones(num_freqs).float() else: raise ValueError(f"unknown modality {freqs_for}") if ft_seq_len is None: ft_seq_len = pt_seq_len t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len freqs = torch.einsum("..., f -> ... f", t, freqs) freqs = repeat(freqs, "... n -> ... (n r)", r=2) freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) self.register_buffer("freqs_cos", freqs_cos) self.register_buffer("freqs_sin", freqs_sin) def forward(self, t): return t * self.freqs_cos + rotate_half(t) * self.freqs_sin from math import log2 def generate_fourier_features( x: torch.Tensor, dim: int = 512, max_freq: int = 64, use_cos: bool = False, use_log: bool = False, cat_orig: bool = False, ): x_orig = x device, dtype, input_dim = x.device, x.dtype, x.shape[-1] num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim if use_log: scales = 2.0 ** torch.linspace( 0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype ) else: scales = torch.linspace( 1.0, max_freq / 2, num_bands, device=device, dtype=dtype ) x = x.unsqueeze(-1) scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] x = x * scales * pi x = torch.cat( ( [x.sin(), x.cos()] if use_cos else [ x.sin(), ] ), dim=-1, ) x = x.flatten(-2) if cat_orig: return torch.cat((x, x_orig), dim=-1) return x # from PIL import Image # from unidepth.utils import image_grid, colorize # if __name__ == "__main__": # H, W = 512, 512 # resolution = 128 # mesh = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W)) # mesh = torch.stack(mesh, dim=0).unsqueeze(0) # mesh = mesh.view(1, 2, -1).permute(0, 2, 1) # features = generate_fourier_features(mesh, dim=32, max_freq=resolution, use_log=True) # channels = features.shape[-1] # print(features.shape) # features = features[0].view(H, W, channels).permute(2, 0, 1).numpy() # Image.fromarray(image_grid([colorize(1+x, 0.0, 2.0, "viridis") for x in features], rows=8, cols=4)).save(f"tmp_{resolution}.png") ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/sht.py ================================================ """Real spherical harmonics in Cartesian form for PyTorch. This is an autogenerated file. 
See https://github.com/cheind/torch-spherical-harmonics for more information. """ import torch def rsh_cart_0(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 0. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,1) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), ], -1, ) def rsh_cart_1(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 1. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,4) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, ], -1, ) def rsh_cart_2(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 2. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,9) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, ], -1, ) def rsh_cart_3(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 3. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,16) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), ], -1, ) def rsh_cart_4(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 4. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,25) real spherical harmonics projections of input. 
Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z x4 = x2**2 y4 = y2**2 z4 = z2**2 return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), 2.5033429417967 * xy * (x2 - y2), -1.77013076977993 * yz * (3.0 * x2 - y2), 0.126156626101008 * xy * (52.5 * z2 - 7.5), 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 0.952069922236839 * z2 + 0.317356640745613, 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), -1.77013076977993 * xz * (x2 - 3.0 * y2), -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4, ], -1, ) def rsh_cart_5(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 5. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,36) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. 
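
    Call sketch (added, hypothetical input):
        >>> xyz = torch.nn.functional.normalize(torch.randn(1024, 3), dim=-1)
        >>> rsh_cart_5(xyz).shape    # (degree + 1) ** 2 = 36 channels
        torch.Size([1024, 36])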
""" x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z x4 = x2**2 y4 = y2**2 z4 = z2**2 return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), 2.5033429417967 * xy * (x2 - y2), -1.77013076977993 * yz * (3.0 * x2 - y2), 0.126156626101008 * xy * (52.5 * z2 - 7.5), 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 0.952069922236839 * z2 + 0.317356640745613, 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), -1.77013076977993 * xz * (x2 - 3.0 * y2), -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4, -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 8.30264925952416 * xy * z * (x2 - y2), 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.241571547304372 * y * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), -1.24747010616985 * z * (1.5 * z2 - 0.5) + 1.6840846433293 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.498988042467941 * z, 0.241571547304372 * x * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), ], -1, ) def rsh_cart_6(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 6. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,49) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. 
""" x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z x4 = x2**2 y4 = y2**2 z4 = z2**2 return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), 2.5033429417967 * xy * (x2 - y2), -1.77013076977993 * yz * (3.0 * x2 - y2), 0.126156626101008 * xy * (52.5 * z2 - 7.5), 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 0.952069922236839 * z2 + 0.317356640745613, 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), -1.77013076977993 * xz * (x2 - 3.0 * y2), -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4, -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 8.30264925952416 * xy * z * (x2 - y2), 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.241571547304372 * y * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), -1.24747010616985 * z * (1.5 * z2 - 0.5) + 1.6840846433293 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.498988042467941 * z, 0.241571547304372 * x * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 4.09910463115149 * x**4 * xy - 13.6636821038383 * xy**3 + 4.09910463115149 * xy * y**4, -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), 0.00584892228263444 * y * (3.0 * x2 - y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 0.0701870673916132 * xy * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.221950995245231 * y * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), -1.48328138624466 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + 1.86469659985043 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.953538034014426 * z2 - 0.317846011338142, 0.221950995245231 * x * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), 0.0350935336958066 * (x2 - y2) * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.00584892228263444 * x * (x2 - 3.0 * y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 
0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 0.683184105191914 * x2**3 + 10.2477615778787 * x2 * y4 - 10.2477615778787 * x4 * y2 - 0.683184105191914 * y2**3, ], -1, ) def rsh_cart_7(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 7. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,64) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z x4 = x2**2 y4 = y2**2 z4 = z2**2 return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), 2.5033429417967 * xy * (x2 - y2), -1.77013076977993 * yz * (3.0 * x2 - y2), 0.126156626101008 * xy * (52.5 * z2 - 7.5), 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 0.952069922236839 * z2 + 0.317356640745613, 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), -1.77013076977993 * xz * (x2 - 3.0 * y2), -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4, -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 8.30264925952416 * xy * z * (x2 - y2), 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.241571547304372 * y * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), -1.24747010616985 * z * (1.5 * z2 - 0.5) + 1.6840846433293 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.498988042467941 * z, 0.241571547304372 * x * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 4.09910463115149 * x**4 * xy - 13.6636821038383 * xy**3 + 4.09910463115149 * xy * y**4, -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), 0.00584892228263444 * y * (3.0 * x2 - y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 0.0701870673916132 * xy * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.221950995245231 * y * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), -1.48328138624466 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 
0.666666666666667 * z) + 1.86469659985043 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.953538034014426 * z2 - 0.317846011338142, 0.221950995245231 * x * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), 0.0350935336958066 * (x2 - y2) * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.00584892228263444 * x * (x2 - 3.0 * y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 0.683184105191914 * x2**3 + 10.2477615778787 * x2 * y4 - 10.2477615778787 * x4 * y2 - 0.683184105191914 * y2**3, -0.707162732524596 * y * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), 9.98394571852353e-5 * y * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00239614697244565 * xy * (x2 - y2) * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z), 0.00397356022507413 * y * (3.0 * x2 - y2) * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ), 0.0561946276120613 * xy * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ), 0.206472245902897 * y * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ), 1.24862677781952 * z * (1.5 * z2 - 0.5) - 1.68564615005635 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 2.02901851395672 * z * ( -1.45833333333333 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + 1.83333333333333 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * ( 1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z ) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.9375 * z2 - 0.3125 ) - 0.499450711127808 * z, 0.206472245902897 * x * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ), 0.0280973138060306 * (x2 - y2) * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ), 0.00397356022507413 * x * (x2 - 3.0 * y2) * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ), 0.000599036743111412 * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) * (-6.0 * x2 * y2 + x4 + y4), 9.98394571852353e-5 * x * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), -0.707162732524596 * x * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), ], -1, ) # @torch.jit.script def rsh_cart_8(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 8. This is an autogenerated method. 
See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,81) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z x4 = x2**2 y4 = y2**2 # z4 = z2**2 return torch.stack( [ 0.282094791773878 * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), 2.5033429417967 * xy * (x2 - y2), -1.77013076977993 * yz * (3.0 * x2 - y2), 0.126156626101008 * xy * (52.5 * z2 - 7.5), 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 0.952069922236839 * z2 + 0.317356640745613, 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), -1.77013076977993 * xz * (x2 - 3.0 * y2), -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4, -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 8.30264925952416 * xy * z * (x2 - y2), 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.241571547304372 * y * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), -1.24747010616985 * z * (1.5 * z2 - 0.5) + 1.6840846433293 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.498988042467941 * z, 0.241571547304372 * x * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 4.09910463115149 * x**4 * xy - 13.6636821038383 * xy**3 + 4.09910463115149 * xy * y**4, -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), 0.00584892228263444 * y * (3.0 * x2 - y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 0.0701870673916132 * xy * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.221950995245231 * y * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), -1.48328138624466 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + 1.86469659985043 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.953538034014426 * z2 - 0.317846011338142, 0.221950995245231 * x * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * 
(2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), 0.0350935336958066 * (x2 - y2) * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.00584892228263444 * x * (x2 - 3.0 * y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 0.683184105191914 * x2**3 + 10.2477615778787 * x2 * y4 - 10.2477615778787 * x4 * y2 - 0.683184105191914 * y2**3, -0.707162732524596 * y * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), 9.98394571852353e-5 * y * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00239614697244565 * xy * (x2 - y2) * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z), 0.00397356022507413 * y * (3.0 * x2 - y2) * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ), 0.0561946276120613 * xy * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ), 0.206472245902897 * y * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ), 1.24862677781952 * z * (1.5 * z2 - 0.5) - 1.68564615005635 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 2.02901851395672 * z * ( -1.45833333333333 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + 1.83333333333333 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * ( 1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z ) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.9375 * z2 - 0.3125 ) - 0.499450711127808 * z, 0.206472245902897 * x * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ), 0.0280973138060306 * (x2 - y2) * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ), 0.00397356022507413 * x * (x2 - 3.0 * y2) * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ), 0.000599036743111412 * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) * (-6.0 * x2 * y2 + x4 + y4), 9.98394571852353e-5 * x * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), -0.707162732524596 * x * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), 5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3), -2.91570664069932 * yz * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), 7.87853281621404e-6 * (1013512.5 * z2 - 67567.5) * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), 5.10587282657803e-5 * y * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00147275890257803 * xy * (x2 - y2) * ( 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) - 14293.125 * z2 + 1299.375 ), 0.0028519853513317 * y * (3.0 * x2 - y2) * ( -7.33333333333333 
* z * (52.5 - 472.5 * z2) + 3.0 * z * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ) - 560.0 * z ), 0.0463392770473559 * xy * ( -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + 2.5 * z * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ) + 137.8125 * z2 - 19.6875 ), 0.193851103820053 * y * ( 3.2 * z * (1.5 - 7.5 * z2) - 2.51428571428571 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) + 2.14285714285714 * z * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ) + 5.48571428571429 * z ), 1.48417251362228 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.86581687426801 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 2.1808249179756 * z * ( 1.14285714285714 * z * (1.5 * z2 - 0.5) - 1.54285714285714 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 1.85714285714286 * z * ( -1.45833333333333 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + 1.83333333333333 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * ( 1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z ) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.9375 * z2 - 0.3125 ) - 0.457142857142857 * z ) - 0.954110901614325 * z2 + 0.318036967204775, 0.193851103820053 * x * ( 3.2 * z * (1.5 - 7.5 * z2) - 2.51428571428571 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) + 2.14285714285714 * z * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ) + 5.48571428571429 * z ), 0.0231696385236779 * (x2 - y2) * ( -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + 2.5 * z * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ) + 137.8125 * z2 - 19.6875 ), 0.0028519853513317 * x * (x2 - 3.0 * y2) * ( -7.33333333333333 * z * (52.5 - 472.5 * z2) + 3.0 * z * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ) - 560.0 * z ), 0.000368189725644507 * (-6.0 * x2 * y2 + x4 + y4) * ( 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) - 14293.125 * z2 + 1299.375 ), 5.10587282657803e-5 * x * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 7.87853281621404e-6 * (1013512.5 * z2 - 67567.5) * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), -2.91570664069932 * xz * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), -20.4099464848952 * x2**3 * y2 - 20.4099464848952 * x2 * y2**3 + 0.72892666017483 * x4**2 + 51.0248662122381 * x4 * y4 + 0.72892666017483 * y4**2, ], -1, ) __all__ = [ "rsh_cart_0", "rsh_cart_1", "rsh_cart_2", "rsh_cart_3", "rsh_cart_4", "rsh_cart_5", "rsh_cart_6", "rsh_cart_7", "rsh_cart_8", ] from typing import Optional import torch class 
SphHarm(torch.nn.Module): def __init__(self, m, n, dtype=torch.float32) -> None: super().__init__() self.dtype = dtype m = torch.tensor(list(range(-m + 1, m))) n = torch.tensor(list(range(n))) self.is_normalized = False vals = torch.cartesian_prod(m, n).T vals = vals[:, vals[0] <= vals[1]] m, n = vals.unbind(0) self.register_buffer("m", tensor=m) self.register_buffer("n", tensor=n) self.register_buffer("l_max", tensor=torch.max(self.n)) f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre() self.register_buffer("f_a", tensor=f_a) self.register_buffer("f_b", tensor=f_b) self.register_buffer("d0_mask_3d", tensor=d0_mask_3d) self.register_buffer("d1_mask_3d", tensor=d1_mask_3d) self.register_buffer("initial_value", tensor=initial_value) @property def device(self): return next(self.buffers()).device def forward(self, points: torch.Tensor) -> torch.Tensor: """Computes the spherical harmonics.""" # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi) B, N, D = points.shape dtype = points.dtype theta, phi = points.view(-1, D).to(self.dtype).unbind(-1) cos_colatitude = torch.cos(phi) legendre = self._gen_associated_legendre(cos_colatitude) vals = torch.stack([self.m.abs(), self.n], dim=0) vals = torch.cat( [ vals.repeat(1, theta.shape[0]), torch.arange(theta.shape[0], device=theta.device) .unsqueeze(0) .repeat_interleave(vals.shape[1], dim=1), ], dim=0, ) legendre_vals = legendre[vals[0], vals[1], vals[2]] legendre_vals = legendre_vals.reshape(-1, theta.shape[0]) angle = torch.outer(self.m.abs(), theta) vandermonde = torch.complex(torch.cos(angle), torch.sin(angle)) harmonics = torch.complex( legendre_vals * torch.real(vandermonde), legendre_vals * torch.imag(vandermonde), ) # Negative order. m = self.m.unsqueeze(-1) harmonics = torch.where( m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics ) harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype) return harmonics def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]: """Generates mask for recurrence relation on the remaining entries. The remaining entries are with respect to the diagonal and offdiagonal entries. Args: l_max: see `gen_normalized_legendre`. Returns: torch.Tensors representing the mask used by the recurrence relations. """ # Computes all coefficients. m_mat, l_mat = torch.meshgrid( torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype), torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype), indexing="ij", ) if self.is_normalized: c0 = l_mat * l_mat c1 = m_mat * m_mat c2 = 2.0 * l_mat c3 = (l_mat - 1.0) * (l_mat - 1.0) d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1)) d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1))) else: d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat) d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat) d0_mask_indices = torch.triu_indices(self.l_max + 1, 1) d1_mask_indices = torch.triu_indices(self.l_max + 1, 2) d_zeros = torch.zeros( (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device ) d_zeros[d0_mask_indices] = d0[d0_mask_indices] d0_mask = d_zeros d_zeros = torch.zeros( (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device ) d_zeros[d1_mask_indices] = d1[d1_mask_indices] d1_mask = d_zeros # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere. 
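        # (Added note) mask[i, j, k] equals 1 exactly where k == i + j, so slicing
        # d0_mask_3d[i] / d1_mask_3d[i] in _recursive exposes, at recursion depth i,
        # only the coefficients of the three-term associated-Legendre recurrence
        # P_l^m = d0 * x * P_{l-1}^m - d1 * P_{l-2}^m that belong to that step.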
i = torch.arange(self.l_max + 1, device=self.device)[:, None, None] j = torch.arange(self.l_max + 1, device=self.device)[None, :, None] k = torch.arange(self.l_max + 1, device=self.device)[None, None, :] mask = (i + j - k == 0).to(self.dtype) d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask) d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask) return (d0_mask_3d, d1_mask_3d) def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor: coeff_0 = self.d0_mask_3d[i] coeff_1 = self.d1_mask_3d[i] h = torch.einsum( "ij,ijk->ijk", coeff_0, torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x), ) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1)) p_val = p_val + h return p_val def _init_legendre(self): a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device) b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device) if self.is_normalized: # The initial value p(0,0). initial_value: torch.Tensor = torch.tensor( 0.5 / (torch.pi**0.5), device=self.device ) f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0) f_b = torch.sqrt(2.0 * b_idx + 3.0) else: # The initial value p(0,0). initial_value = torch.tensor(1.0, device=self.device) f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0) f_b = 2.0 * b_idx + 1.0 d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask() return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor: r"""Computes associated Legendre functions (ALFs) of the first kind. The ALFs of the first kind are used in spherical harmonics. The spherical harmonic of degree `l` and order `m` can be written as `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the normalization factor and θ and φ are the colatitude and longitude, repectively. `N_l^m` is chosen in the way that the spherical harmonics form a set of orthonormal basis function of L^2(S^2). For the computational efficiency of spherical harmonics transform, the normalization factor is used in the computation of the ALFs. In addition, normalizing `P_l^m` avoids overflow/underflow and achieves better numerical stability. Three recurrence relations are used in the computation. Args: l_max: The maximum degree of the associated Legendre function. Both the degrees and orders are `[0, 1, 2, ..., l_max]`. x: A vector of type `float32`, `float64` containing the sampled points in spherical coordinates, at which the ALFs are computed; `x` is essentially `cos(θ)`. For the numerical integration used by the spherical harmonics transforms, `x` contains the quadrature points in the interval of `[-1, 1]`. There are several approaches to provide the quadrature points: Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev method (`scipy.special.roots_chebyu`), and Driscoll & Healy method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier transforms and convolutions on the 2-sphere." Advances in applied mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature points are nearly equal-spaced along θ and provide exact discrete orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose operation, `W` is a diagonal matrix containing the quadrature weights, and `I` is the identity matrix. The Gauss-Chebyshev points are equally spaced, which only provide approximate discrete orthogonality. The Driscoll & Healy qudarture points are equally spaced and provide the exact discrete orthogonality. 
The number of sampling points is required to be twice as the number of frequency points (modes) in the Driscoll & Healy approach, which enables FFT and achieves a fast spherical harmonics transform. is_normalized: True if the associated Legendre functions are normalized. With normalization, `N_l^m` is applied such that the spherical harmonics form a set of orthonormal basis functions of L^2(S^2). Returns: The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values of the ALFs at `x`; the dimensions in the sequence of order, degree, and evalution points. """ p = torch.zeros( (self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device ) p[0, 0] = self.initial_value # Compute the diagonal entries p(l,l) with recurrence. y = torch.cumprod( torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0 ) p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y) # torch.diag_indices(l_max + 1) diag_indices = torch.stack( [torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0 ) p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag diag_indices = torch.stack( [torch.arange(0, self.l_max, device=x.device)] * 2, dim=0 ) # Compute the off-diagonal entries with recurrence. p_offdiag = torch.einsum( "ij,ij->ij", torch.einsum("i,j->ij", self.f_b, x), p[(diag_indices[0], diag_indices[1])], ) # p[torch.diag_indices(l_max)]) p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = ( p_offdiag ) # Compute the remaining entries with recurrence. if self.l_max > 1: for i in range(2, self.l_max + 1): p = self._recursive(i, p, x) return p ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/validation.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import torch import torch.utils.data.distributed import wandb from torch.nn import functional as F from unidepth.utils import barrier, is_main_process from unidepth.utils.misc import remove_padding def original_image(batch, preds=None): paddings = [ torch.tensor(pads) for img_meta in batch["img_metas"] for pads in img_meta.get("paddings", [[0] * 4]) ] paddings = torch.stack(paddings).to(batch["data"]["image"].device)[ ..., [0, 2, 1, 3] ] # lrtb T, _, H, W = batch["data"]["depth"].shape batch["data"]["image"] = F.interpolate( batch["data"]["image"], (H + paddings[2] + paddings[3], W + paddings[1] + paddings[2]), mode="bilinear", align_corners=False, antialias=True, ) batch["data"]["image"] = remove_padding( batch["data"]["image"], paddings.repeat(T, 1) ) if preds is not None: for key in ["depth"]: if key in preds: preds[key] = F.interpolate( preds[key], (H + paddings[2] + paddings[3], W + paddings[1] + paddings[2]), mode="bilinear", align_corners=False, antialias=True, ) preds[key] = remove_padding(preds[key], paddings.repeat(T, 1)) return batch, preds def log_metrics(metrics_all, step): for name_ds, metrics in metrics_all.items(): for metrics_name, metrics_value in metrics.items(): try: print(f"Metrics/{name_ds}/{metrics_name} {round(metrics_value, 4)}") wandb.log( {f"Metrics/{name_ds}/{metrics_name}": metrics_value}, step=step ) except: pass def validate(model, test_loaders, step, context): metrics_all = {} for name_ds, test_loader in test_loaders.items(): for i, batch in enumerate(test_loader): with context: batch["data"] = { k: v.to(model.device) for k, v in batch["data"].items() } # remove 
temporal dimension of the dataloder, here is always 1! batch["data"] = {k: v.squeeze(1) for k, v in batch["data"].items()} batch["img_metas"] = [ {k: v[0] for k, v in meta.items() if isinstance(v, list)} for meta in batch["img_metas"] ] preds = model(batch["data"], batch["img_metas"]) batch, _ = original_image(batch, preds=None) test_loader.dataset.accumulate_metrics( inputs=batch["data"], preds=preds, keyframe_idx=batch["img_metas"][0].get("keyframe_idx"), ) barrier() metrics_all[name_ds] = test_loader.dataset.get_evaluation() barrier() if is_main_process(): log_metrics(metrics_all=metrics_all, step=step) return metrics_all ================================================ FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/visualization.py ================================================ """ Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import os import matplotlib.pyplot as plt import numpy as np import torch import wandb from PIL import Image from unidepth.utils.misc import ssi_helper def colorize( value: np.ndarray, vmin: float = None, vmax: float = None, cmap: str = "magma_r" ): # if already RGB, do nothing if value.ndim > 2: if value.shape[-1] > 1: return value value = value[..., 0] invalid_mask = value < 0.0001 # normalize vmin = value.min() if vmin is None else vmin vmax = value.max() if vmax is None else vmax value = (value - vmin) / (vmax - vmin) # vmin..vmax # set color cmapper = plt.get_cmap(cmap) value = cmapper(value, bytes=True) # (nxmx4) value[invalid_mask] = 0 img = value[..., :3] return img def image_grid(imgs: list[np.ndarray], rows: int, cols: int) -> np.ndarray: if not len(imgs): return None assert len(imgs) == rows * cols h, w = imgs[0].shape[:2] grid = Image.new("RGB", size=(cols * w, rows * h)) for i, img in enumerate(imgs): grid.paste( Image.fromarray(img.astype(np.uint8)).resize( (w, h), resample=Image.BILINEAR ), box=(i % cols * w, i // cols * h), ) return np.array(grid) def get_pointcloud_from_rgbd( image: np.array, depth: np.array, mask: np.ndarray, intrinsic_matrix: np.array, extrinsic_matrix: np.array = None, ): depth = np.array(depth).squeeze() mask = np.array(mask).squeeze() # Mask the depth array masked_depth = np.ma.masked_where(mask == False, depth) # masked_depth = np.ma.masked_greater(masked_depth, 8000) # Create idx array idxs = np.indices(masked_depth.shape) u_idxs = idxs[1] v_idxs = idxs[0] # Get only non-masked depth and idxs z = masked_depth[~masked_depth.mask] compressed_u_idxs = u_idxs[~masked_depth.mask] compressed_v_idxs = v_idxs[~masked_depth.mask] image = np.stack( [image[..., i][~masked_depth.mask] for i in range(image.shape[-1])], axis=-1 ) # Calculate local position of each point # Apply vectorized math to depth using compressed arrays cx = intrinsic_matrix[0, 2] fx = intrinsic_matrix[0, 0] x = (compressed_u_idxs - cx) * z / fx cy = intrinsic_matrix[1, 2] fy = intrinsic_matrix[1, 1] # Flip y as we want +y pointing up not down y = -((compressed_v_idxs - cy) * z / fy) # # Apply camera_matrix to pointcloud as to get the pointcloud in world coords # if extrinsic_matrix is not None: # # Calculate camera pose from extrinsic matrix # camera_matrix = np.linalg.inv(extrinsic_matrix) # # Create homogenous array of vectors by adding 4th entry of 1 # # At the same time flip z as for eye space the camera is looking down the -z axis # w = np.ones(z.shape) # x_y_z_eye_hom = np.vstack((x, y, -z, w)) # # Transform the points from eye space to world space # x_y_z_world 
= np.dot(camera_matrix, x_y_z_eye_hom)[:3] # return x_y_z_world.T # else: x_y_z_local = np.stack((x, y, z), axis=-1) return np.concatenate([x_y_z_local, image], axis=-1) def save_file_ply(xyz, rgb, pc_file): if rgb.max() < 1.001: rgb = rgb * 255.0 rgb = rgb.astype(np.uint8) # print(rgb) with open(pc_file, "w") as f: # headers f.writelines( [ "ply\n" "format ascii 1.0\n", "element vertex {}\n".format(xyz.shape[0]), "property float x\n", "property float y\n", "property float z\n", "property uchar red\n", "property uchar green\n", "property uchar blue\n", "end_header\n", ] ) for i in range(xyz.shape[0]): str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format( xyz[i, 0], xyz[i, 1], xyz[i, 2], rgb[i, 0], rgb[i, 1], rgb[i, 2] ) f.write(str_v) # really awful fct... FIXME def log_train_artifacts(rgbs, gts, preds, ds_name, step, infos={}): rgbs = [ (127.5 * (rgb + 1)) .clip(0, 255) .to(torch.uint8) .cpu() .detach() .permute(1, 2, 0) .numpy() for rgb in rgbs ] new_gts, new_preds = [], [] if len(gts) > 0: for i, gt in enumerate(gts): scale, shift = ssi_helper( gts[i][gts[i] > 0].cpu().detach(), preds[i][gts[i] > 0].cpu().detach() ) gt = gts[i].cpu().detach().squeeze().numpy() pred = (preds[i].cpu().detach() * scale + shift).squeeze().numpy() vmin = gt[gt > 0].min() if (gt > 0).any() else 0.0 vmax = gt.max() if (gt > 0).any() else 0.1 new_gts.append(colorize(gt, vmin=vmin, vmax=vmax)) new_preds.append(colorize(pred, vmin=vmin, vmax=vmax)) gts, preds = new_gts, new_preds else: preds = [ colorize(pred.cpu().detach().squeeze().numpy(), 0.0, 80.0) for i, pred in enumerate(preds) ] num_additional, additionals = 0, [] for name, info in infos.items(): num_additional += 1 if info.shape[1] == 3: additionals.extend( [ (127.5 * (x + 1)) .clip(0, 255) .to(torch.uint8) .cpu() .detach() .permute(1, 2, 0) .numpy() for x in info[:4] ] ) else: additionals.extend( [ colorize(x.cpu().detach().squeeze().numpy()) for i, x in enumerate(info[:4]) ] ) num_rows = 2 + int(len(gts) > 0) + num_additional artifacts_grid = image_grid( [*rgbs, *gts, *preds, *additionals], num_rows, len(rgbs) ) try: wandb.log({f"{ds_name}_training": [wandb.Image(artifacts_grid)]}, step=step) except: Image.fromarray(artifacts_grid).save( os.path.join(os.environ["HOME"], "Workspace", f"art_grid{step}.png") ) print("Logging training images failed") ================================================ FILE: camera_pose_annotation/depth_estimation/__init__.py ================================================ ================================================ FILE: camera_pose_annotation/dynamic_mask/__init__.py ================================================ ================================================ FILE: camera_pose_annotation/dynamic_mask/inference_batch.py ================================================ """ Batch inference script for dynamic mask generation using SAM2. Processes video frames to generate dynamic object masks based on motion probabilities. 
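Per clip, the script loads per-frame motion probabilities, thresholds them
adaptively into coarse moving-region masks, samples contour points as SAM2
prompts (skipping points that fall in the sky), and merges the refined masks.
The per-frame masks are written as scipy CSR components in a compressed .npz
(see `compress`). A frame can be restored roughly as follows (illustrative
sketch, not part of the original script):

    import numpy as np
    from scipy.sparse import csr_matrix

    data = np.load("dyn_masks.npz")
    mask0 = csr_matrix(
        (data["f_0_data"], data["f_0_indices"], data["f_0_indptr"]),
        shape=tuple(data["shape"]),
    ).toarray().astype(bool)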
""" import os import numpy as np import torch import torch.nn.functional as F from glob import glob import cv2 from scipy import ndimage from scipy.sparse import csr_matrix import argparse import pandas as pd import subprocess from multiprocessing import Manager import queue import concurrent.futures from tqdm import tqdm from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor def compress(dyn_masks, save_path=None): """Compress dynamic masks using sparse matrix representation.""" assert save_path.endswith(".npz") # Transform to sparse matrices sparse_matrices_list = [csr_matrix(dyn_mask) for dyn_mask in dyn_masks] sparse_matrices = {} for i, dyn_mask in enumerate(sparse_matrices_list): sparse_matrices[f"f_{i}_data"] = dyn_mask.data sparse_matrices[f"f_{i}_indices"] = dyn_mask.indices sparse_matrices[f"f_{i}_indptr"] = dyn_mask.indptr if i == 0: sparse_matrices["shape"] = dyn_mask.shape np.savez_compressed(save_path, **sparse_matrices) def segment_sky(image): """Segment sky regions from image using HSV color space and morphological operations.""" # Convert RGB to HSV hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) # Define range for blue color and create mask lower_blue = np.array([0, 0, 100]) upper_blue = np.array([30, 255, 255]) mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool) # Add luminous gray regions (likely sky) mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150) mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180) mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220) # Morphological operations to clean up mask kernel = np.ones((5, 5), np.uint8) mask2 = ndimage.binary_opening(mask, structure=kernel) # Keep only largest connected components _, labels, stats, _ = cv2.connectedComponentsWithStats( mask2.view(np.uint8), connectivity=8 ) cc_sizes = stats[1:, cv2.CC_STAT_AREA] order = cc_sizes.argsort()[::-1] # Bigger first i = 0 selection = [] while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2: selection.append(1 + order[i]) i += 1 mask3 = np.isin(labels, selection).reshape(labels.shape) # Return as tensor return torch.from_numpy(mask3) def predict_mask(predictor, row, args, device): """Generate dynamic masks for a video using SAM2 and motion probabilities.""" dir_path = os.path.join(args.dir_path, str(row["id"])) if not os.path.exists(dir_path): print(f"Directory '{dir_path}' not found. Exit.") return img_dir = os.path.join(args.dir_path, str(row["id"]), "img") if not os.path.exists(img_dir): print(f"Image directory '{img_dir}' not found. Exit.") return rec_dir = os.path.join(dir_path, "reconstructions") if not os.path.exists(rec_dir): print(f"Reconstructions directory '{rec_dir}' not found. Exit.") return prob_file = os.path.join(rec_dir, "motion_prob.npy") if not os.path.exists(prob_file): print(f"Motion probability file '{prob_file}' not found. Exit.") return compress_file = os.path.join(rec_dir, "dyn_masks.npz") if os.path.exists(compress_file): return # Load motion probabilities motion_probs = torch.from_numpy(np.load(prob_file)).to(device) # Load images images_list = list(sorted(glob(os.path.join(img_dir, "*.jpg")))) images = [cv2.imread(img_path) for img_path in images_list] if len(images) == 0 or len(images) != len(motion_probs): print( f"{row['video_path']},Number of frames ({len(images)}) does not match number of motion probabilities ({len(motion_probs)}). Exit." 
) return width, height = images[0].shape[1], images[0].shape[0] area = width * height masks = [] # Process each frame for i in range(len(images)): motion_prob = motion_probs[i].to(device) # Segment sky to avoid false detections sky_mask = segment_sky(images[i]) predictor.set_image(images[i]) # Adaptive thresholding based on motion probability distribution # We use an adaptive thresholding based on motion probability distribution to create initial masks. Then prob_min, prob_max = motion_prob.min(), motion_prob.max() threshold = (prob_max - prob_min) * 0.4 + prob_min if threshold > prob_max - 0.1: masks.append(np.zeros((height, width), dtype=np.uint8)) continue # Create initial mask from motion probabilities mask = (motion_prob < threshold).float() mask = F.interpolate( mask.unsqueeze(0).unsqueeze(0), size=(height, width), mode="bilinear", align_corners=False, ).squeeze() # Find contours and use them as SAM2 prompts mask_np = mask.cpu().numpy().astype(np.uint8) contours, _ = cv2.findContours( mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE ) merged_mask = np.zeros_like(mask_np) for c in contours: points = [] for point in c: points.append(point[0]) points = np.array(points) # Sample points from contour as prompts interval = max(1, len(points) // 3) input_points = points[::interval].astype(np.float32) # Skip if points are in sky region if sky_mask[input_points[:, 1], input_points[:, 0]].any(): continue input_labels = np.ones(input_points.shape[0], dtype=np.int64) # Use SAM2 to refine mask mask, score, _ = predictor.predict( point_coords=input_points, point_labels=input_labels, multimask_output=False, ) # Skip if mask area is too large (likely background) if mask[0].sum() > area * 0.3: continue merged_mask = np.logical_or(merged_mask, mask[0]) masks.append(merged_mask) # Save compressed masks masks = np.stack(masks, axis=0) compress(masks, compress_file) def worker(task_queue, progress_queue, args, id): """Worker function for parallel dynamic mask generation.""" gpu_id = id % args.gpu_num os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) # Bind to specific GPU device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu") sam2_model = None predictor = None while True: try: index, row = task_queue.get_nowait() except queue.Empty: break # Initialize SAM2 model and predictor lazily if sam2_model is None: sam2_model = build_sam2( args.model_cfg, args.checkpoints_path, device=device ) if predictor is None: predictor = SAM2ImagePredictor(sam2_model) predictor.reset_predictor() predict_mask(predictor, row, args, device) progress_queue.put(index) def parse_args(): """Parse command line arguments for dynamic mask generation.""" parser = argparse.ArgumentParser(description="SAM2 Image Predictor") parser.add_argument("--csv_path", type=str, help="Path to the csv file") parser.add_argument( "--dir_path", type=str, required=True, help="Path to the directory containing images and masks", ) parser.add_argument( "--num_workers", type=int, default=16, help="#workers for concurrent.futures" ) parser.add_argument("--gpu_num", type=int, default=1, help="gpu number") parser.add_argument( "--checkpoints_path", type=str, default="checkpoints", help="Path to the model checkpoint", ) parser.add_argument( "--model_cfg", type=str, default="configs/sam2.1/sam2.1_hiera_l.yaml", help="Path to the model configuration file", ) return parser.parse_args() def main(): args = parse_args() if not os.path.exists(args.csv_path): print(f"Meta file '{args.csv_path}' not found. 
Exit.") return # Set SAM2 checkpoint path args.checkpoints_path = os.path.join( args.checkpoints_path, "SAM2/sam2.1_hiera_large.pt" ) df = pd.read_csv(args.csv_path) # Setup multiprocessing manager = Manager() task_queue = manager.Queue() progress_queue = manager.Queue() for index, row in df.iterrows(): task_queue.put((index, row)) # Process tasks with multiple workers with concurrent.futures.ProcessPoolExecutor( max_workers=args.num_workers ) as executor: futures = [] for id in range(args.num_workers): futures.append(executor.submit(worker, task_queue, progress_queue, args, id)) processed = 0 total_tasks = len(df) with tqdm(total=total_tasks, desc="Processing rows") as pbar: while processed < total_tasks: try: progress_queue.get(timeout=1) processed += 1 pbar.update(1) except queue.Empty: if all(f.done() for f in futures) and progress_queue.empty(): break for future in futures: future.result() if __name__ == "__main__": main() ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from hydra import initialize_config_module from hydra.core.global_hydra import GlobalHydra if not GlobalHydra.instance().is_initialized(): initialize_config_module("sam2", version_base="1.2") ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/automatic_mask_generator.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch from torchvision.ops.boxes import batched_nms, box_area # type: ignore from sam2.modeling.sam2_base import SAM2Base from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.utils.amg import ( area_from_rle, batch_iterator, batched_mask_to_box, box_xyxy_to_xywh, build_all_layer_point_grids, calculate_stability_score, coco_encode_rle, generate_crop_boxes, is_box_near_crop_edge, mask_to_rle_pytorch, MaskData, remove_small_regions, rle_to_mask, uncrop_boxes_xyxy, uncrop_masks, uncrop_points, ) class SAM2AutomaticMaskGenerator: def __init__( self, model: SAM2Base, points_per_side: Optional[int] = 32, points_per_batch: int = 64, pred_iou_thresh: float = 0.8, stability_score_thresh: float = 0.95, stability_score_offset: float = 1.0, mask_threshold: float = 0.0, box_nms_thresh: float = 0.7, crop_n_layers: int = 0, crop_nms_thresh: float = 0.7, crop_overlap_ratio: float = 512 / 1500, crop_n_points_downscale_factor: int = 1, point_grids: Optional[List[np.ndarray]] = None, min_mask_region_area: int = 0, output_mode: str = "binary_mask", use_m2m: bool = False, multimask_output: bool = True, **kwargs, ) -> None: """ Using a SAM 2 model, generates masks for the entire image. Generates a grid of point prompts over the image, then filters low quality and duplicate masks. The default settings are chosen for SAM 2 with a HieraL backbone. Arguments: model (Sam): The SAM 2 model to use for mask prediction. 
points_per_side (int or None): The number of points to be sampled along one side of the image. The total number of points is points_per_side**2. If None, 'point_grids' must provide explicit point sampling. points_per_batch (int): Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. pred_iou_thresh (float): A filtering threshold in [0,1], using the model's predicted mask quality. stability_score_thresh (float): A filtering threshold in [0,1], using the stability of the mask under changes to the cutoff used to binarize the model's mask predictions. stability_score_offset (float): The amount to shift the cutoff when calculated the stability score. mask_threshold (float): Threshold for binarizing the mask logits box_nms_thresh (float): The box IoU cutoff used by non-maximal suppression to filter duplicate masks. crop_n_layers (int): If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where each layer has 2**i_layer number of image crops. crop_nms_thresh (float): The box IoU cutoff used by non-maximal suppression to filter duplicate masks between different crops. crop_overlap_ratio (float): Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the image length. Later layers with more crops scale down this overlap. crop_n_points_downscale_factor (int): The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. point_grids (list(np.ndarray) or None): A list over explicit grids of points used for sampling, normalized to [0,1]. The nth grid in the list is used in the nth crop layer. Exclusive with points_per_side. min_mask_region_area (int): If >0, postprocessing will be applied to remove disconnected regions and holes in masks with area smaller than min_mask_region_area. Requires opencv. output_mode (str): The form masks are returned in. Can be 'binary_mask', 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. For large resolutions, 'binary_mask' may consume large amounts of memory. use_m2m (bool): Whether to add a one step refinement using previous mask predictions. multimask_output (bool): Whether to output multimask at each point of the grid. """ assert (points_per_side is None) != ( point_grids is None ), "Exactly one of points_per_side or point_grid must be provided." if points_per_side is not None: self.point_grids = build_all_layer_point_grids( points_per_side, crop_n_layers, crop_n_points_downscale_factor, ) elif point_grids is not None: self.point_grids = point_grids else: raise ValueError("Can't have both points_per_side and point_grid be None.") assert output_mode in [ "binary_mask", "uncompressed_rle", "coco_rle", ], f"Unknown output_mode {output_mode}." 
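# Illustrative usage (not part of this file; the config and checkpoint paths
# below are examples only):
#
#   import cv2
#   from sam2.build_sam import build_sam2
#
#   sam2_model = build_sam2("configs/sam2.1/sam2.1_hiera_l.yaml",
#                           "checkpoints/sam2.1_hiera_large.pt", device="cuda")
#   mask_generator = SAM2AutomaticMaskGenerator(sam2_model, points_per_side=32)
#   image = cv2.cvtColor(cv2.imread("frame.jpg"), cv2.COLOR_BGR2RGB)  # HWC uint8
#   records = mask_generator.generate(image)  # list of mask records, see `generate`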
if output_mode == "coco_rle": try: from pycocotools import mask as mask_utils # type: ignore # noqa: F401 except ImportError as e: print("Please install pycocotools") raise e self.predictor = SAM2ImagePredictor( model, max_hole_area=min_mask_region_area, max_sprinkle_area=min_mask_region_area, ) self.points_per_batch = points_per_batch self.pred_iou_thresh = pred_iou_thresh self.stability_score_thresh = stability_score_thresh self.stability_score_offset = stability_score_offset self.mask_threshold = mask_threshold self.box_nms_thresh = box_nms_thresh self.crop_n_layers = crop_n_layers self.crop_nms_thresh = crop_nms_thresh self.crop_overlap_ratio = crop_overlap_ratio self.crop_n_points_downscale_factor = crop_n_points_downscale_factor self.min_mask_region_area = min_mask_region_area self.output_mode = output_mode self.use_m2m = use_m2m self.multimask_output = multimask_output @classmethod def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator": """ Load a pretrained model from the Hugging Face hub. Arguments: model_id (str): The Hugging Face repository ID. **kwargs: Additional arguments to pass to the model constructor. Returns: (SAM2AutomaticMaskGenerator): The loaded model. """ from sam2.build_sam import build_sam2_hf sam_model = build_sam2_hf(model_id, **kwargs) return cls(sam_model, **kwargs) @torch.no_grad() def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: """ Generates masks for the given image. Arguments: image (np.ndarray): The image to generate masks for, in HWC uint8 format. Returns: list(dict(str, any)): A list over records for masks. Each record is a dict containing the following keys: segmentation (dict(str, any) or np.ndarray): The mask. If output_mode='binary_mask', is an array of shape HW. Otherwise, is a dictionary containing the RLE. bbox (list(float)): The box around the mask, in XYWH format. area (int): The area in pixels of the mask. predicted_iou (float): The model's own prediction of the mask's quality. This is filtered by the pred_iou_thresh parameter. point_coords (list(list(float))): The point coordinates input to the model to generate this mask. stability_score (float): A measure of the mask's quality. This is filtered on using the stability_score_thresh parameter. crop_box (list(float)): The crop of the image used to generate the mask, given in XYWH format. 
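            Example record with illustrative values (not the output of a real run):
              {"segmentation": <HxW bool array>, "area": 5123,
               "bbox": [21.0, 48.0, 130.0, 96.0], "predicted_iou": 0.92,
               "point_coords": [[310.0, 275.0]], "stability_score": 0.97,
               "crop_box": [0, 0, 1280, 720]}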
""" # Generate masks mask_data = self._generate_masks(image) # Encode masks if self.output_mode == "coco_rle": mask_data["segmentations"] = [ coco_encode_rle(rle) for rle in mask_data["rles"] ] elif self.output_mode == "binary_mask": mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] else: mask_data["segmentations"] = mask_data["rles"] # Write mask records curr_anns = [] for idx in range(len(mask_data["segmentations"])): ann = { "segmentation": mask_data["segmentations"][idx], "area": area_from_rle(mask_data["rles"][idx]), "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), "predicted_iou": mask_data["iou_preds"][idx].item(), "point_coords": [mask_data["points"][idx].tolist()], "stability_score": mask_data["stability_score"][idx].item(), "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), } curr_anns.append(ann) return curr_anns def _generate_masks(self, image: np.ndarray) -> MaskData: orig_size = image.shape[:2] crop_boxes, layer_idxs = generate_crop_boxes( orig_size, self.crop_n_layers, self.crop_overlap_ratio ) # Iterate over image crops data = MaskData() for crop_box, layer_idx in zip(crop_boxes, layer_idxs): crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) data.cat(crop_data) # Remove duplicate masks between crops if len(crop_boxes) > 1: # Prefer masks from smaller crops scores = 1 / box_area(data["crop_boxes"]) scores = scores.to(data["boxes"].device) keep_by_nms = batched_nms( data["boxes"].float(), scores, torch.zeros_like(data["boxes"][:, 0]), # categories iou_threshold=self.crop_nms_thresh, ) data.filter(keep_by_nms) data.to_numpy() return data def _process_crop( self, image: np.ndarray, crop_box: List[int], crop_layer_idx: int, orig_size: Tuple[int, ...], ) -> MaskData: # Crop the image and calculate embeddings x0, y0, x1, y1 = crop_box cropped_im = image[y0:y1, x0:x1, :] cropped_im_size = cropped_im.shape[:2] self.predictor.set_image(cropped_im) # Get points for this crop points_scale = np.array(cropped_im_size)[None, ::-1] points_for_image = self.point_grids[crop_layer_idx] * points_scale # Generate masks for this crop in batches data = MaskData() for (points,) in batch_iterator(self.points_per_batch, points_for_image): batch_data = self._process_batch( points, cropped_im_size, crop_box, orig_size, normalize=True ) data.cat(batch_data) del batch_data self.predictor.reset_predictor() # Remove duplicates within this crop. 
keep_by_nms = batched_nms( data["boxes"].float(), data["iou_preds"], torch.zeros_like(data["boxes"][:, 0]), # categories iou_threshold=self.box_nms_thresh, ) data.filter(keep_by_nms) # Return to the original image frame data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) data["points"] = uncrop_points(data["points"], crop_box) data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) return data def _process_batch( self, points: np.ndarray, im_size: Tuple[int, ...], crop_box: List[int], orig_size: Tuple[int, ...], normalize=False, ) -> MaskData: orig_h, orig_w = orig_size # Run model on this batch points = torch.as_tensor( points, dtype=torch.float32, device=self.predictor.device ) in_points = self.predictor._transforms.transform_coords( points, normalize=normalize, orig_hw=im_size ) in_labels = torch.ones( in_points.shape[0], dtype=torch.int, device=in_points.device ) masks, iou_preds, low_res_masks = self.predictor._predict( in_points[:, None, :], in_labels[:, None], multimask_output=self.multimask_output, return_logits=True, ) # Serialize predictions and store in MaskData data = MaskData( masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1), points=points.repeat_interleave(masks.shape[1], dim=0), low_res_masks=low_res_masks.flatten(0, 1), ) del masks if not self.use_m2m: # Filter by predicted IoU if self.pred_iou_thresh > 0.0: keep_mask = data["iou_preds"] > self.pred_iou_thresh data.filter(keep_mask) # Calculate and filter by stability score data["stability_score"] = calculate_stability_score( data["masks"], self.mask_threshold, self.stability_score_offset ) if self.stability_score_thresh > 0.0: keep_mask = data["stability_score"] >= self.stability_score_thresh data.filter(keep_mask) else: # One step refinement using previous mask predictions in_points = self.predictor._transforms.transform_coords( data["points"], normalize=normalize, orig_hw=im_size ) labels = torch.ones( in_points.shape[0], dtype=torch.int, device=in_points.device ) masks, ious = self.refine_with_m2m( in_points, labels, data["low_res_masks"], self.points_per_batch ) data["masks"] = masks.squeeze(1) data["iou_preds"] = ious.squeeze(1) if self.pred_iou_thresh > 0.0: keep_mask = data["iou_preds"] > self.pred_iou_thresh data.filter(keep_mask) data["stability_score"] = calculate_stability_score( data["masks"], self.mask_threshold, self.stability_score_offset ) if self.stability_score_thresh > 0.0: keep_mask = data["stability_score"] >= self.stability_score_thresh data.filter(keep_mask) # Threshold masks and calculate boxes data["masks"] = data["masks"] > self.mask_threshold data["boxes"] = batched_mask_to_box(data["masks"]) # Filter boxes that touch crop boundaries keep_mask = ~is_box_near_crop_edge( data["boxes"], crop_box, [0, 0, orig_w, orig_h] ) if not torch.all(keep_mask): data.filter(keep_mask) # Compress to RLE data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) data["rles"] = mask_to_rle_pytorch(data["masks"]) del data["masks"] return data @staticmethod def postprocess_small_regions( mask_data: MaskData, min_area: int, nms_thresh: float ) -> MaskData: """ Removes small disconnected regions and holes in masks, then reruns box NMS to remove any new duplicates. Edits mask_data in place. Requires open-cv as a dependency. 
""" if len(mask_data["rles"]) == 0: return mask_data # Filter small disconnected regions and holes new_masks = [] scores = [] for rle in mask_data["rles"]: mask = rle_to_mask(rle) mask, changed = remove_small_regions(mask, min_area, mode="holes") unchanged = not changed mask, changed = remove_small_regions(mask, min_area, mode="islands") unchanged = unchanged and not changed new_masks.append(torch.as_tensor(mask).unsqueeze(0)) # Give score=0 to changed masks and score=1 to unchanged masks # so NMS will prefer ones that didn't need postprocessing scores.append(float(unchanged)) # Recalculate boxes and remove any new duplicates masks = torch.cat(new_masks, dim=0) boxes = batched_mask_to_box(masks) keep_by_nms = batched_nms( boxes.float(), torch.as_tensor(scores), torch.zeros_like(boxes[:, 0]), # categories iou_threshold=nms_thresh, ) # Only recalculate RLEs for masks that have changed for i_mask in keep_by_nms: if scores[i_mask] == 0.0: mask_torch = masks[i_mask].unsqueeze(0) mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly mask_data.filter(keep_by_nms) return mask_data def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch): new_masks = [] new_iou_preds = [] for cur_points, cur_point_labels, low_res_mask in batch_iterator( points_per_batch, points, point_labels, low_res_masks ): best_masks, best_iou_preds, _ = self.predictor._predict( cur_points[:, None, :], cur_point_labels[:, None], mask_input=low_res_mask[:, None, :], multimask_output=False, return_logits=True, ) new_masks.append(best_masks) new_iou_preds.append(best_iou_preds) masks = torch.cat(new_masks, dim=0) return masks, torch.cat(new_iou_preds, dim=0) ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/benchmark.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import os import time import numpy as np import torch from tqdm import tqdm from sam2.build_sam import build_sam2_video_predictor # Only cuda supported assert torch.cuda.is_available() device = torch.device("cuda") torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() if torch.cuda.get_device_properties(0).major >= 8: # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True # Config and checkpoint sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt" model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml" # Build video predictor with vos_optimized=True setting predictor = build_sam2_video_predictor( model_cfg, sam2_checkpoint, device=device, vos_optimized=True ) # Initialize with video video_dir = "notebooks/videos/bedroom" # scan all the JPEG frame names in this directory frame_names = [ p for p in os.listdir(video_dir) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] ] frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) inference_state = predictor.init_state(video_path=video_dir) # Number of runs, warmup etc warm_up, runs = 5, 25 verbose = True num_frames = len(frame_names) total, count = 0, 0 torch.cuda.empty_cache() # We will select an object with a click. 
# See video_predictor_example.ipynb for more detailed explanation ann_frame_idx, ann_obj_id = 0, 1 # Add a positive click at (x, y) = (210, 350) # For labels, `1` means positive click points = np.array([[210, 350]], dtype=np.float32) labels = np.array([1], np.int32) _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( inference_state=inference_state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels, ) # Warmup and then average FPS over several runs with torch.autocast("cuda", torch.bfloat16): with torch.inference_mode(): for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"): start = time.time() # Start tracking for ( out_frame_idx, out_obj_ids, out_mask_logits, ) in predictor.propagate_in_video(inference_state): pass end = time.time() total += end - start count += 1 if i == warm_up - 1: print("Warmup FPS: ", count * num_frames / total) total = 0 count = 0 print("FPS: ", count * num_frames / total) ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/build_sam.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import logging import os import torch from hydra import compose from hydra.utils import instantiate from omegaconf import OmegaConf import sam2 # Check if the user is running Python from the parent directory of the sam2 repo # (i.e. the directory where this repo is cloned into) -- this is not supported since # it could shadow the sam2 package and cause issues. if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): # If the user has "sam2/sam2" in their path, they are likey importing the repo itself # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). # This typically happens because the user is running Python from the parent directory # that contains the sam2 repo they cloned. raise RuntimeError( "You're likely running Python from the parent directory of the sam2 repository " "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " "This is not supported since the `sam2` Python package could be shadowed by the " "repository name (the repository is also named `sam2` and contains the Python package " "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " "rather than its parent dir, or from your home directory) after installing SAM 2." 
) HF_MODEL_ID_TO_FILENAMES = { "facebook/sam2-hiera-tiny": ( "configs/sam2/sam2_hiera_t.yaml", "sam2_hiera_tiny.pt", ), "facebook/sam2-hiera-small": ( "configs/sam2/sam2_hiera_s.yaml", "sam2_hiera_small.pt", ), "facebook/sam2-hiera-base-plus": ( "configs/sam2/sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt", ), "facebook/sam2-hiera-large": ( "configs/sam2/sam2_hiera_l.yaml", "sam2_hiera_large.pt", ), "facebook/sam2.1-hiera-tiny": ( "configs/sam2.1/sam2.1_hiera_t.yaml", "sam2.1_hiera_tiny.pt", ), "facebook/sam2.1-hiera-small": ( "configs/sam2.1/sam2.1_hiera_s.yaml", "sam2.1_hiera_small.pt", ), "facebook/sam2.1-hiera-base-plus": ( "configs/sam2.1/sam2.1_hiera_b+.yaml", "sam2.1_hiera_base_plus.pt", ), "facebook/sam2.1-hiera-large": ( "configs/sam2.1/sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt", ), } def build_sam2( config_file, ckpt_path=None, device="cuda", mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, **kwargs, ): if apply_postprocessing: hydra_overrides_extra = hydra_overrides_extra.copy() hydra_overrides_extra += [ # dynamically fall back to multi-mask if the single mask is not stable "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", ] # Read config and init model cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) OmegaConf.resolve(cfg) model = instantiate(cfg.model, _recursive_=True) _load_checkpoint(model, ckpt_path) model = model.to(device) if mode == "eval": model.eval() return model def build_sam2_video_predictor( config_file, ckpt_path=None, device="cuda", mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, vos_optimized=False, **kwargs, ): hydra_overrides = [ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", ] if vos_optimized: hydra_overrides = [ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS", "++model.compile_image_encoder=True", # Let sam2_base handle this ] if apply_postprocessing: hydra_overrides_extra = hydra_overrides_extra.copy() hydra_overrides_extra += [ # dynamically fall back to multi-mask if the single mask is not stable "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking "++model.binarize_mask_from_pts_for_mem_enc=true", # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) "++model.fill_hole_area=8", ] hydra_overrides.extend(hydra_overrides_extra) # Read config and init model cfg = compose(config_name=config_file, overrides=hydra_overrides) OmegaConf.resolve(cfg) model = instantiate(cfg.model, _recursive_=True) _load_checkpoint(model, ckpt_path) model = model.to(device) if mode == "eval": model.eval() return model def _hf_download(model_id): from huggingface_hub import hf_hub_download config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) return config_name, ckpt_path def build_sam2_hf(model_id, **kwargs): config_name, ckpt_path = _hf_download(model_id) return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) def 
build_sam2_video_predictor_hf(model_id, **kwargs): config_name, ckpt_path = _hf_download(model_id) return build_sam2_video_predictor( config_file=config_name, ckpt_path=ckpt_path, **kwargs ) def _load_checkpoint(model, ckpt_path): if ckpt_path is not None: sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] missing_keys, unexpected_keys = model.load_state_dict(sd) if missing_keys: logging.error(missing_keys) raise RuntimeError() if unexpected_keys: logging.error(unexpected_keys) raise RuntimeError() logging.info("Loaded checkpoint sucessfully") ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2/sam2_hiera_b+.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 112 num_heads: 2 neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [896, 448, 224, 112] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: false only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: 
true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag compile_image_encoder: False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2/sam2_hiera_l.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 144 num_heads: 2 stages: [2, 6, 36, 4] global_att_blocks: [23, 33, 43] window_pos_embed_bkg_spatial_size: [7, 7] window_spec: [8, 4, 16, 8] neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [1152, 576, 288, 144] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: false only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag compile_image_encoder: False ================================================ FILE: 
camera_pose_annotation/dynamic_mask/sam2/configs/sam2/sam2_hiera_s.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 11, 2] global_att_blocks: [7, 10, 13] window_pos_embed_bkg_spatial_size: [7, 7] neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [768, 384, 192, 96] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: false only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag compile_image_encoder: False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2/sam2_hiera_t.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: 
sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 7, 2] global_att_blocks: [5, 7, 9] window_pos_embed_bkg_spatial_size: [7, 7] neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [768, 384, 192, 96] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask # SAM decoder sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: false only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag # HieraT does not currently support compilation, should always be set to False compile_image_encoder: False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 112 num_heads: 2 neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 
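# DETR-style fixed sine/cosine positional encoding producing num_pos_feats channels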
num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [896, 448, 224, 112] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true no_obj_embed_spatial: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: true proj_tpos_enc_in_obj_ptrs: true use_signed_tpos_enc_to_obj_ptrs: true only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag compile_image_encoder: False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2.1/sam2.1_hiera_l.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 144 num_heads: 2 stages: [2, 6, 36, 4] global_att_blocks: [23, 33, 43] window_pos_embed_bkg_spatial_size: [7, 7] window_spec: [8, 4, 16, 8] neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [1152, 576, 288, 144] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly 
use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true no_obj_embed_spatial: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: true proj_tpos_enc_in_obj_ptrs: true use_signed_tpos_enc_to_obj_ptrs: true only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag compile_image_encoder: False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2.1/sam2.1_hiera_s.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 11, 2] global_att_blocks: [7, 10, 13] window_pos_embed_bkg_spatial_size: [7, 7] neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [768, 384, 192, 96] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: 
sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true no_obj_embed_spatial: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: true proj_tpos_enc_in_obj_ptrs: true use_signed_tpos_enc_to_obj_ptrs: true only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag compile_image_encoder: False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2.1/sam2.1_hiera_t.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 7, 2] global_att_blocks: [5, 7, 9] window_pos_embed_bkg_spatial_size: [7, 7] neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [768, 384, 192, 96] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 
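# rotary position embedding (RoPE) attention; feat_sizes below is the feature-map grid it is configured for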
feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask # SAM decoder sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true no_obj_embed_spatial: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: true proj_tpos_enc_in_obj_ptrs: true use_signed_tpos_enc_to_obj_ptrs: true only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag # HieraT does not currently support compilation, should always be set to False compile_image_encoder: False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml ================================================ # @package _global_ scratch: resolution: 1024 train_batch_size: 1 num_train_workers: 10 num_frames: 8 max_num_objects: 3 base_lr: 5.0e-6 vision_lr: 3.0e-06 phases_per_epoch: 1 num_epochs: 40 dataset: # PATHS to Dataset img_folder: null # PATH to MOSE JPEGImages folder gt_folder: null # PATH to MOSE Annotations folder file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training multiplier: 2 # Video transforms vos: train_transforms: - _target_: training.dataset.transforms.ComposeAPI transforms: - _target_: training.dataset.transforms.RandomHorizontalFlip consistent_transform: True - _target_: training.dataset.transforms.RandomAffine degrees: 25 shear: 20 image_interpolation: bilinear consistent_transform: True - _target_: training.dataset.transforms.RandomResizeAPI sizes: ${scratch.resolution} square: true consistent_transform: True - _target_: training.dataset.transforms.ColorJitter consistent_transform: True brightness: 0.1 contrast: 0.03 saturation: 0.03 hue: null - _target_: training.dataset.transforms.RandomGrayscale p: 0.05 consistent_transform: True - _target_: 
training.dataset.transforms.ColorJitter consistent_transform: False brightness: 0.1 contrast: 0.05 saturation: 0.05 hue: null - _target_: training.dataset.transforms.ToTensorAPI - _target_: training.dataset.transforms.NormalizeAPI mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] trainer: _target_: training.trainer.Trainer mode: train_only max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}} accelerator: cuda seed_value: 123 model: _target_: training.model.sam2.SAM2Train image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 112 num_heads: 2 drop_path_rate: 0.1 neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [896, 448, 224, 112] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: ${scratch.resolution} # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true no_obj_embed_spatial: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: true proj_tpos_enc_in_obj_ptrs: true use_signed_tpos_enc_to_obj_ptrs: true only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag # compile_image_encoder: False ####### Training specific 
params ####### # box/point input and corrections prob_to_use_pt_input_for_train: 0.5 prob_to_use_pt_input_for_eval: 0.0 prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points prob_to_use_box_input_for_eval: 0.0 prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame) num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2 add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame) # maximum 2 initial conditioning frames num_init_cond_frames_for_train: 2 rand_init_cond_frames_for_train: True # random 1~2 num_correction_pt_per_frame: 7 use_act_ckpt_iterative_pt_sampling: false num_init_cond_frames_for_eval: 1 # only mask on the first frame forward_backbone_per_frame_for_eval: True data: train: _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset phases_per_epoch: ${scratch.phases_per_epoch} batch_sizes: - ${scratch.train_batch_size} datasets: - _target_: training.dataset.utils.RepeatFactorWrapper dataset: _target_: training.dataset.utils.ConcatDataset datasets: - _target_: training.dataset.vos_dataset.VOSDataset transforms: ${vos.train_transforms} training: true video_dataset: _target_: training.dataset.vos_raw_dataset.PNGRawDataset img_folder: ${dataset.img_folder} gt_folder: ${dataset.gt_folder} file_list_txt: ${dataset.file_list_txt} sampler: _target_: training.dataset.vos_sampler.RandomUniformSampler num_frames: ${scratch.num_frames} max_num_objects: ${scratch.max_num_objects} multiplier: ${dataset.multiplier} shuffle: True num_workers: ${scratch.num_train_workers} pin_memory: True drop_last: True collate_fn: _target_: training.utils.data_utils.collate_fn _partial_: true dict_key: all optim: amp: enabled: True amp_dtype: bfloat16 optimizer: _target_: torch.optim.AdamW gradient_clip: _target_: training.optimizer.GradientClipper max_norm: 0.1 norm_type: 2 param_group_modifiers: - _target_: training.optimizer.layer_decay_param_modifier _partial_: True layer_decay_value: 0.9 apply_to: 'image_encoder.trunk' overrides: - pattern: '*pos_embed*' value: 1.0 options: lr: - scheduler: _target_: fvcore.common.param_scheduler.CosineParamScheduler start_value: ${scratch.base_lr} end_value: ${divide:${scratch.base_lr},10} - scheduler: _target_: fvcore.common.param_scheduler.CosineParamScheduler start_value: ${scratch.vision_lr} end_value: ${divide:${scratch.vision_lr},10} param_names: - 'image_encoder.*' weight_decay: - scheduler: _target_: fvcore.common.param_scheduler.ConstantParamScheduler value: 0.1 - scheduler: _target_: fvcore.common.param_scheduler.ConstantParamScheduler value: 0.0 param_names: - '*bias*' module_cls_names: ['torch.nn.LayerNorm'] loss: all: _target_: training.loss_fns.MultiStepMultiMasksAndIous weight_dict: loss_mask: 20 loss_dice: 1 loss_iou: 1 loss_class: 1 supervise_all_iou: true iou_use_l1_loss: true pred_obj_scores: true focal_gamma_obj_score: 0.0 focal_alpha_obj_score: -1.0 distributed: backend: nccl find_unused_parameters: True logging: tensorboard_writer: _target_: training.utils.logger.make_tensorboard_logger log_dir: ${launcher.experiment_log_dir}/tensorboard flush_secs: 120 should_log: True log_dir: ${launcher.experiment_log_dir}/logs log_freq: 10 # 
initialize from a SAM 2 checkpoint checkpoint: save_dir: ${launcher.experiment_log_dir}/checkpoints save_freq: 0 # 0 only last checkpoint is saved. model_weight_initializer: _partial_: True _target_: training.utils.checkpoint_utils.load_state_dict_into_model strict: True ignore_unexpected_keys: null ignore_missing_keys: null state_dict: _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint ckpt_state_dict_keys: ['model'] launcher: num_nodes: 1 gpus_per_node: 8 experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name} # SLURM args if running on a cluster submitit: partition: null account: null qos: null cpus_per_task: 10 use_cluster: false timeout_hour: 24 name: null port_range: [10000, 65000] ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/csrc/connected_components.cu ================================================ // Copyright (c) Meta Platforms, Inc. and affiliates. // All rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. // adapted from https://github.com/zsef123/Connected_components_PyTorch // with license found in the LICENSE_cctorch file in the root directory. #include #include #include #include #include #include // 2d #define BLOCK_ROWS 16 #define BLOCK_COLS 16 namespace cc2d { template __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { return (bitmap >> pos) & 1; } __device__ int32_t find(const int32_t* s_buf, int32_t n) { while (s_buf[n] != n) n = s_buf[n]; return n; } __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { const int32_t id = n; while (s_buf[n] != n) { n = s_buf[n]; s_buf[id] = n; } return n; } __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { bool done; do { a = find(s_buf, a); b = find(s_buf, b); if (a < b) { int32_t old = atomicMin(s_buf + b, a); done = (old == b); b = old; } else if (b < a) { int32_t old = atomicMin(s_buf + a, b); done = (old == a); a = old; } else done = true; } while (!done); } __global__ void init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; const uint32_t idx = row * W + col; if (row < H && col < W) label[idx] = idx; } __global__ void merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; const uint32_t idx = row * W + col; if (row >= H || col >= W) return; uint32_t P = 0; if (img[idx]) P |= 0x777; if (row + 1 < H && img[idx + W]) P |= 0x777 << 4; if (col + 1 < W && img[idx + 1]) P |= 0x777 << 1; if (col == 0) P &= 0xEEEE; if (col + 1 >= W) P &= 0x3333; else if (col + 2 >= W) P &= 0x7777; if (row == 0) P &= 0xFFF0; if (row + 1 >= H) P &= 0xFF; if (P > 0) { // If need check about top-left pixel(if flag the first bit) and hit the // top-left pixel if (hasBit(P, 0) && img[idx - W - 1]) { union_(label, idx, idx - 2 * W - 2); // top left block } if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) union_(label, idx, idx - 2 * W); // top bottom block if (hasBit(P, 3) && img[idx + 2 - W]) union_(label, idx, idx - 2 * W + 2); // top right block if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 
1])) union_(label, idx, idx - 2); // just left block } } __global__ void compression(int32_t* label, const int32_t W, const int32_t H) { const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; const uint32_t idx = row * W + col; if (row < H && col < W) find_n_compress(label, idx); } __global__ void final_labeling( const uint8_t* img, int32_t* label, const int32_t W, const int32_t H) { const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; const uint32_t idx = row * W + col; if (row >= H || col >= W) return; int32_t y = label[idx] + 1; if (img[idx]) label[idx] = y; else label[idx] = 0; if (col + 1 < W) { if (img[idx + 1]) label[idx + 1] = y; else label[idx + 1] = 0; if (row + 1 < H) { if (img[idx + W + 1]) label[idx + W + 1] = y; else label[idx + W + 1] = 0; } } if (row + 1 < H) { if (img[idx + W]) label[idx + W] = y; else label[idx + W] = 0; } } __global__ void init_counting( const int32_t* label, int32_t* count_init, const int32_t W, const int32_t H) { const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); const uint32_t idx = row * W + col; if (row >= H || col >= W) return; int32_t y = label[idx]; if (y > 0) { int32_t count_idx = y - 1; atomicAdd(count_init + count_idx, 1); } } __global__ void final_counting( const int32_t* label, const int32_t* count_init, int32_t* count_final, const int32_t W, const int32_t H) { const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); const uint32_t idx = row * W + col; if (row >= H || col >= W) return; int32_t y = label[idx]; if (y > 0) { int32_t count_idx = y - 1; count_final[idx] = count_init[count_idx]; } else { count_final[idx] = 0; } } } // namespace cc2d std::vector get_connected_componnets( const torch::Tensor& inputs) { AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); AT_ASSERTM( inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); const uint32_t N = inputs.size(0); const uint32_t C = inputs.size(1); const uint32_t H = inputs.size(2); const uint32_t W = inputs.size(3); AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); AT_ASSERTM((H % 2) == 0, "height must be an even number"); AT_ASSERTM((W % 2) == 0, "width must be an even number"); // label must be uint32_t auto label_options = torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); dim3 grid = dim3( ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); dim3 grid_count = dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); for (int n = 0; n < N; n++) { uint32_t offset = n * H * W; cc2d::init_labeling<<>>( labels.data_ptr() + offset, W, H); cc2d::merge<<>>( inputs.data_ptr() + offset, labels.data_ptr() + offset, W, H); cc2d::compression<<>>( labels.data_ptr() + offset, W, H); cc2d::final_labeling<<>>( inputs.data_ptr() + offset, labels.data_ptr() + offset, W, H); // get the counting of 
each pixel cc2d::init_counting<<>>( labels.data_ptr() + offset, counts_init.data_ptr() + offset, W, H); cc2d::final_counting<<>>( labels.data_ptr() + offset, counts_init.data_ptr() + offset, counts_final.data_ptr() + offset, W, H); } // returned values are [labels, counts] std::vector outputs; outputs.push_back(labels); outputs.push_back(counts_final); return outputs; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "get_connected_componnets", &get_connected_componnets, "get_connected_componnets"); } ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/backbones/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/backbones/hieradet.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import logging from functools import partial from typing import List, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from iopath.common.file_io import g_pathmgr from sam2.modeling.backbones.utils import ( PatchEmbed, window_partition, window_unpartition, ) from sam2.modeling.sam2_utils import DropPath, MLP def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: if pool is None: return x # (B, H, W, C) -> (B, C, H, W) x = x.permute(0, 3, 1, 2) x = pool(x) # (B, C, H', W') -> (B, H', W', C) x = x.permute(0, 2, 3, 1) if norm: x = norm(x) return x class MultiScaleAttention(nn.Module): def __init__( self, dim: int, dim_out: int, num_heads: int, q_pool: nn.Module = None, ): super().__init__() self.dim = dim self.dim_out = dim_out self.num_heads = num_heads self.q_pool = q_pool self.qkv = nn.Linear(dim, dim_out * 3) self.proj = nn.Linear(dim_out, dim_out) def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, _ = x.shape # qkv with shape (B, H * W, 3, nHead, C) qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) # q, k, v with shape (B, H * W, nheads, C) q, k, v = torch.unbind(qkv, 2) # Q pooling (for downsample at stage changes) if self.q_pool: q = do_pool(q.reshape(B, H, W, -1), self.q_pool) H, W = q.shape[1:3] # downsampled shape q = q.reshape(B, H * W, self.num_heads, -1) # Torch's SDPA expects [B, nheads, H*W, C] so we transpose x = F.scaled_dot_product_attention( q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), ) # Transpose back x = x.transpose(1, 2) x = x.reshape(B, H, W, -1) x = self.proj(x) return x class MultiScaleBlock(nn.Module): def __init__( self, dim: int, dim_out: int, num_heads: int, mlp_ratio: float = 4.0, drop_path: float = 0.0, norm_layer: Union[nn.Module, str] = "LayerNorm", q_stride: Tuple[int, int] = None, act_layer: nn.Module = nn.GELU, window_size: int = 0, ): super().__init__() if isinstance(norm_layer, 
str): norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) self.dim = dim self.dim_out = dim_out self.norm1 = norm_layer(dim) self.window_size = window_size self.pool, self.q_stride = None, q_stride if self.q_stride: self.pool = nn.MaxPool2d( kernel_size=q_stride, stride=q_stride, ceil_mode=False ) self.attn = MultiScaleAttention( dim, dim_out, num_heads=num_heads, q_pool=self.pool, ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim_out) self.mlp = MLP( dim_out, int(dim_out * mlp_ratio), dim_out, num_layers=2, activation=act_layer, ) if dim != dim_out: self.proj = nn.Linear(dim, dim_out) def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x # B, H, W, C x = self.norm1(x) # Skip connection if self.dim != self.dim_out: shortcut = do_pool(self.proj(x), self.pool) # Window partition window_size = self.window_size if window_size > 0: H, W = x.shape[1], x.shape[2] x, pad_hw = window_partition(x, window_size) # Window Attention + Q Pooling (if stage change) x = self.attn(x) if self.q_stride: # Shapes have changed due to Q pooling window_size = self.window_size // self.q_stride[0] H, W = shortcut.shape[1:3] pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size pad_hw = (H + pad_h, W + pad_w) # Reverse window partition if self.window_size > 0: x = window_unpartition(x, window_size, pad_hw, (H, W)) x = shortcut + self.drop_path(x) # MLP x = x + self.drop_path(self.mlp(self.norm2(x))) return x class Hiera(nn.Module): """ Reference: https://arxiv.org/abs/2306.00989 """ def __init__( self, embed_dim: int = 96, # initial embed dim num_heads: int = 1, # initial number of heads drop_path_rate: float = 0.0, # stochastic depth q_pool: int = 3, # number of q_pool stages q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage dim_mul: float = 2.0, # dim_mul factor at stage shift head_mul: float = 2.0, # head_mul factor at stage shift window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), # window size per stage, when not using global att. window_spec: Tuple[int, ...] = ( 8, 4, 14, 7, ), # global attn in these blocks global_att_blocks: Tuple[int, ...] = ( 12, 16, 20, ), weights_path=None, return_interm_layers=True, # return feats from every stage ): super().__init__() assert len(stages) == len(window_spec) self.window_spec = window_spec depth = sum(stages) self.q_stride = q_stride self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] assert 0 <= q_pool <= len(self.stage_ends[:-1]) self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] self.return_interm_layers = return_interm_layers self.patch_embed = PatchEmbed( embed_dim=embed_dim, ) # Which blocks have global att? 
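# Blocks listed in global_att_blocks attend globally (their window_size is set to 0 in the
# stage loop below); all other blocks use the per-stage window size from window_spec.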
self.global_att_blocks = global_att_blocks # Windowed positional embedding (https://arxiv.org/abs/2311.05613) self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size self.pos_embed = nn.Parameter( torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) ) self.pos_embed_window = nn.Parameter( torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) ) dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, depth) ] # stochastic depth decay rule cur_stage = 1 self.blocks = nn.ModuleList() for i in range(depth): dim_out = embed_dim # lags by a block, so first block of # next stage uses an initial window size # of previous stage and final window size of current stage window_size = self.window_spec[cur_stage - 1] if self.global_att_blocks is not None: window_size = 0 if i in self.global_att_blocks else window_size if i - 1 in self.stage_ends: dim_out = int(embed_dim * dim_mul) num_heads = int(num_heads * head_mul) cur_stage += 1 block = MultiScaleBlock( dim=embed_dim, dim_out=dim_out, num_heads=num_heads, drop_path=dpr[i], q_stride=self.q_stride if i in self.q_pool_blocks else None, window_size=window_size, ) embed_dim = dim_out self.blocks.append(block) self.channel_list = ( [self.blocks[i].dim_out for i in self.stage_ends[::-1]] if return_interm_layers else [self.blocks[-1].dim_out] ) if weights_path is not None: with g_pathmgr.open(weights_path, "rb") as f: chkpt = torch.load(f, map_location="cpu") logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: h, w = hw window_embed = self.pos_embed_window pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") pos_embed = pos_embed + window_embed.tile( [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] ) pos_embed = pos_embed.permute(0, 2, 3, 1) return pos_embed def forward(self, x: torch.Tensor) -> List[torch.Tensor]: x = self.patch_embed(x) # x: (B, H, W, C) # Add pos embed x = x + self._get_pos_embed(x.shape[1:3]) outputs = [] for i, blk in enumerate(self.blocks): x = blk(x) if (i == self.stage_ends[-1]) or ( i in self.stage_ends and self.return_interm_layers ): feats = x.permute(0, 3, 1, 2) outputs.append(feats) return outputs def get_layer_id(self, layer_name): # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 num_layers = self.get_num_layers() if layer_name.find("rel_pos") != -1: return num_layers + 1 elif layer_name.find("pos_embed") != -1: return 0 elif layer_name.find("patch_embed") != -1: return 0 elif layer_name.find("blocks") != -1: return int(layer_name.split("blocks")[1].split(".")[1]) + 1 else: return num_layers + 1 def get_num_layers(self) -> int: return len(self.blocks) ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/backbones/image_encoder.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import List, Optional import torch import torch.nn as nn import torch.nn.functional as F class ImageEncoder(nn.Module): def __init__( self, trunk: nn.Module, neck: nn.Module, scalp: int = 0, ): super().__init__() self.trunk = trunk self.neck = neck self.scalp = scalp assert ( self.trunk.channel_list == self.neck.backbone_channel_list ), f"Channel dims of trunk and neck do not match. 
Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" def forward(self, sample: torch.Tensor): # Forward through backbone features, pos = self.neck(self.trunk(sample)) if self.scalp > 0: # Discard the lowest resolution features features, pos = features[: -self.scalp], pos[: -self.scalp] src = features[-1] output = { "vision_features": src, "vision_pos_enc": pos, "backbone_fpn": features, } return output class FpnNeck(nn.Module): """ A modified variant of Feature Pyramid Network (FPN) neck (we remove output conv and also do bicubic interpolation similar to ViT pos embed interpolation) """ def __init__( self, position_encoding: nn.Module, d_model: int, backbone_channel_list: List[int], kernel_size: int = 1, stride: int = 1, padding: int = 0, fpn_interp_model: str = "bilinear", fuse_type: str = "sum", fpn_top_down_levels: Optional[List[int]] = None, ): """Initialize the neck :param trunk: the backbone :param position_encoding: the positional encoding to use :param d_model: the dimension of the model :param neck_norm: the normalization to use """ super().__init__() self.position_encoding = position_encoding self.convs = nn.ModuleList() self.backbone_channel_list = backbone_channel_list self.d_model = d_model for dim in backbone_channel_list: current = nn.Sequential() current.add_module( "conv", nn.Conv2d( in_channels=dim, out_channels=d_model, kernel_size=kernel_size, stride=stride, padding=padding, ), ) self.convs.append(current) self.fpn_interp_model = fpn_interp_model assert fuse_type in ["sum", "avg"] self.fuse_type = fuse_type # levels to have top-down features in its outputs # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 # have top-down propagation, while outputs of level 0 and level 1 have only # lateral features from the same backbone level. if fpn_top_down_levels is None: # default is to have top-down features on all levels fpn_top_down_levels = range(len(self.convs)) self.fpn_top_down_levels = list(fpn_top_down_levels) def forward(self, xs: List[torch.Tensor]): out = [None] * len(self.convs) pos = [None] * len(self.convs) assert len(xs) == len(self.convs) # fpn forward pass # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py prev_features = None # forward in top-down order (from low to high resolution) n = len(self.convs) - 1 for i in range(n, -1, -1): x = xs[i] lateral_features = self.convs[n - i](x) if i in self.fpn_top_down_levels and prev_features is not None: top_down_features = F.interpolate( prev_features.to(dtype=torch.float32), scale_factor=2.0, mode=self.fpn_interp_model, align_corners=( None if self.fpn_interp_model == "nearest" else False ), antialias=False, ) prev_features = lateral_features + top_down_features if self.fuse_type == "avg": prev_features /= 2 else: prev_features = lateral_features x_out = prev_features out[i] = x_out pos[i] = self.position_encoding(x_out).to(x_out.dtype) return out, pos ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/backbones/utils.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
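# Note on the windowing helpers defined below: window_partition pads H and W up to multiples
# of window_size before splitting into windows, so window_unpartition needs both the padded
# size (Hp, Wp) and the original (H, W) to crop the padding back off. Shape example:
# x of shape (2, 30, 30, 96) with window_size=7 is padded to (35, 35) and partitioned into
# (2*5*5, 7, 7, 96) windows with pad_hw=(35, 35); window_unpartition(windows, 7, (35, 35), (30, 30))
# restores (2, 30, 30, 96).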
"""Some utilities for backbones, in particular for windowing""" from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F def window_partition(x, window_size): """ Partition into non-overlapping windows with padding if needed. Args: x (tensor): input tokens with [B, H, W, C]. window_size (int): window size. Returns: windows: windows after partition with [B * num_windows, window_size, window_size, C]. (Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C) return windows, (Hp, Wp) def window_unpartition(windows, window_size, pad_hw, hw): """ Window unpartition into original sequences and removing padding. Args: x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. pad_hw (Tuple): padded height and width (Hp, Wp). hw (Tuple): original height and width (H, W) before padding. Returns: x: unpartitioned sequences with [B, H, W, C]. """ Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.reshape( B, Hp // window_size, Wp // window_size, window_size, window_size, -1 ) x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1) if Hp > H or Wp > W: x = x[:, :H, :W, :] return x class PatchEmbed(nn.Module): """ Image to Patch Embedding. """ def __init__( self, kernel_size: Tuple[int, ...] = (7, 7), stride: Tuple[int, ...] = (4, 4), padding: Tuple[int, ...] = (3, 3), in_chans: int = 3, embed_dim: int = 768, ): """ Args: kernel_size (Tuple): kernel size of the projection layer. stride (Tuple): stride of the projection layer. padding (Tuple): padding size of the projection layer. in_chans (int): Number of input image channels. embed_dim (int): embed_dim (int): Patch embedding dimension. """ super().__init__() self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) # B C H W -> B H W C x = x.permute(0, 2, 3, 1) return x ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/memory_attention.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
from typing import Optional import torch from torch import nn, Tensor from sam2.modeling.sam.transformer import RoPEAttention from sam2.modeling.sam2_utils import get_activation_fn, get_clones class MemoryAttentionLayer(nn.Module): def __init__( self, activation: str, cross_attention: nn.Module, d_model: int, dim_feedforward: int, dropout: float, pos_enc_at_attn: bool, pos_enc_at_cross_attn_keys: bool, pos_enc_at_cross_attn_queries: bool, self_attention: nn.Module, ): super().__init__() self.d_model = d_model self.dim_feedforward = dim_feedforward self.dropout_value = dropout self.self_attn = self_attention self.cross_attn_image = cross_attention # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) self.activation_str = activation self.activation = get_activation_fn(activation) # Where to add pos enc self.pos_enc_at_attn = pos_enc_at_attn self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys def _forward_sa(self, tgt, query_pos): # Self-Attention tgt2 = self.norm1(tgt) q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 tgt2 = self.self_attn(q, k, v=tgt2) tgt = tgt + self.dropout1(tgt2) return tgt def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): kwds = {} if num_k_exclude_rope > 0: assert isinstance(self.cross_attn_image, RoPEAttention) kwds = {"num_k_exclude_rope": num_k_exclude_rope} # Cross-Attention tgt2 = self.norm2(tgt) tgt2 = self.cross_attn_image( q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, v=memory, **kwds, ) tgt = tgt + self.dropout2(tgt2) return tgt def forward( self, tgt, memory, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, num_k_exclude_rope: int = 0, ) -> torch.Tensor: # Self-Attn, Cross-Attn tgt = self._forward_sa(tgt, query_pos) tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) # MLP tgt2 = self.norm3(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout3(tgt2) return tgt class MemoryAttention(nn.Module): def __init__( self, d_model: int, pos_enc_at_input: bool, layer: nn.Module, num_layers: int, batch_first: bool = True, # Do layers expect batch first input? 
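# (forward receives sequence-first tensors and, when this flag is True, transposes them to
# batch-first for the layers and back again before returning)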
): super().__init__() self.d_model = d_model self.layers = get_clones(layer, num_layers) self.num_layers = num_layers self.norm = nn.LayerNorm(d_model) self.pos_enc_at_input = pos_enc_at_input self.batch_first = batch_first def forward( self, curr: torch.Tensor, # self-attention inputs memory: torch.Tensor, # cross-attention inputs curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* ): if isinstance(curr, list): assert isinstance(curr_pos, list) assert len(curr) == len(curr_pos) == 1 curr, curr_pos = ( curr[0], curr_pos[0], ) assert ( curr.shape[1] == memory.shape[1] ), "Batch size must be the same for curr and memory" output = curr if self.pos_enc_at_input and curr_pos is not None: output = output + 0.1 * curr_pos if self.batch_first: # Convert to batch first output = output.transpose(0, 1) curr_pos = curr_pos.transpose(0, 1) memory = memory.transpose(0, 1) memory_pos = memory_pos.transpose(0, 1) for layer in self.layers: kwds = {} if isinstance(layer.cross_attn_image, RoPEAttention): kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} output = layer( tgt=output, memory=memory, pos=memory_pos, query_pos=curr_pos, **kwds, ) normed_output = self.norm(output) if self.batch_first: # Convert back to seq first normed_output = normed_output.transpose(0, 1) curr_pos = curr_pos.transpose(0, 1) return normed_output ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/memory_encoder.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import math from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d class MaskDownSampler(nn.Module): """ Progressively downsample a mask by total_stride, each time by stride. Note that LayerNorm is applied per *token*, like in ViT. With each downsample (by a factor stride**2), channel capacity increases by the same factor. In the end, we linearly project to embed_dim channels. """ def __init__( self, embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16, activation=nn.GELU, ): super().__init__() num_layers = int(math.log2(total_stride) // math.log2(stride)) assert stride**num_layers == total_stride self.encoder = nn.Sequential() mask_in_chans, mask_out_chans = 1, 1 for _ in range(num_layers): mask_out_chans = mask_in_chans * (stride**2) self.encoder.append( nn.Conv2d( mask_in_chans, mask_out_chans, kernel_size=kernel_size, stride=stride, padding=padding, ) ) self.encoder.append(LayerNorm2d(mask_out_chans)) self.encoder.append(activation()) mask_in_chans = mask_out_chans self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) def forward(self, x): return self.encoder(x) # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) class CXBlock(nn.Module): r"""ConvNeXt Block. There are two equivalent implementations: (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back We use (2) as we find it slightly faster in PyTorch Args: dim (int): Number of input channels. 
drop_path (float): Stochastic depth rate. Default: 0.0 layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. """ def __init__( self, dim, kernel_size=7, padding=3, drop_path=0.0, layer_scale_init_value=1e-6, use_dwconv=True, ): super().__init__() self.dwconv = nn.Conv2d( dim, dim, kernel_size=kernel_size, padding=padding, groups=dim if use_dwconv else 1, ) # depthwise conv self.norm = LayerNorm2d(dim, eps=1e-6) self.pwconv1 = nn.Linear( dim, 4 * dim ) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Linear(4 * dim, dim) self.gamma = ( nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) if layer_scale_init_value > 0 else None ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x): input = x x = self.dwconv(x) x = self.norm(x) x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) if self.gamma is not None: x = self.gamma * x x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) x = input + self.drop_path(x) return x class Fuser(nn.Module): def __init__(self, layer, num_layers, dim=None, input_projection=False): super().__init__() self.proj = nn.Identity() self.layers = get_clones(layer, num_layers) if input_projection: assert dim is not None self.proj = nn.Conv2d(dim, dim, kernel_size=1) def forward(self, x): # normally x: (N, C, H, W) x = self.proj(x) for layer in self.layers: x = layer(x) return x class MemoryEncoder(nn.Module): def __init__( self, out_dim, mask_downsampler, fuser, position_encoding, in_dim=256, # in_dim of pix_feats ): super().__init__() self.mask_downsampler = mask_downsampler self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) self.fuser = fuser self.position_encoding = position_encoding self.out_proj = nn.Identity() if out_dim != in_dim: self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) def forward( self, pix_feat: torch.Tensor, masks: torch.Tensor, skip_mask_sigmoid: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: ## Process masks # sigmoid, so that less domain shift from gt masks which are bool if not skip_mask_sigmoid: masks = F.sigmoid(masks) masks = self.mask_downsampler(masks) ## Fuse pix_feats and downsampled masks # in case the visual features are on CPU, cast them to CUDA pix_feat = pix_feat.to(masks.device) x = self.pix_feat_proj(pix_feat) x = x + masks x = self.fuser(x) x = self.out_proj(x) pos = self.position_encoding(x).to(x.dtype) return {"vision_features": x, "vision_pos_enc": [pos]} ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/position_encoding.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import math from typing import Any, Optional, Tuple import numpy as np import torch from torch import nn class PositionEmbeddingSine(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention Is All You Need paper, generalized to work on images. 
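The computed encodings are cached per feature-map size (H, W), so repeated forward passes at the same resolution reuse the cached tensor.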
""" def __init__( self, num_pos_feats, temperature: int = 10000, normalize: bool = True, scale: Optional[float] = None, # Following settings only relevant # for warmping up cache for compilation warmup_cache: bool = True, image_size: int = 1024, strides: Tuple[int] = (4, 8, 16, 32), ): super().__init__() assert num_pos_feats % 2 == 0, "Expecting even model width" self.num_pos_feats = num_pos_feats // 2 self.temperature = temperature self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") if scale is None: scale = 2 * math.pi self.scale = scale self.cache = {} if warmup_cache and torch.cuda.is_available(): # Warmup cache for cuda, to help with compilation device = torch.device("cuda") for stride in strides: cache_key = (image_size // stride, image_size // stride) self._pe(1, device, *cache_key) def _encode_xy(self, x, y): # The positions are expected to be normalized assert len(x) == len(y) and x.ndim == y.ndim == 1 x_embed = x * self.scale y_embed = y * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, None] / dim_t pos_y = y_embed[:, None] / dim_t pos_x = torch.stack( (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 ).flatten(1) pos_y = torch.stack( (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 ).flatten(1) return pos_x, pos_y @torch.no_grad() def encode_boxes(self, x, y, w, h): pos_x, pos_y = self._encode_xy(x, y) pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) return pos encode = encode_boxes # Backwards compatibility @torch.no_grad() def encode_points(self, x, y, labels): (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape assert bx == by and nx == ny and bx == bl and nx == nl pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) return pos @torch.no_grad() def _pe(self, B, device, *cache_key): H, W = cache_key if cache_key in self.cache: return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1) y_embed = ( torch.arange(1, H + 1, dtype=torch.float32, device=device) .view(1, -1, 1) .repeat(B, 1, W) ) x_embed = ( torch.arange(1, W + 1, dtype=torch.float32, device=device) .view(1, 1, -1) .repeat(B, H, 1) ) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack( (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos_y = torch.stack( (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) self.cache[cache_key] = pos[0] return pos @torch.no_grad() def forward(self, x: torch.Tensor): B = x.shape[0] cache_key = (x.shape[-2], x.shape[-1]) return self._pe(B, x.device, *cache_key) class PositionEmbeddingRandom(nn.Module): """ Positional encoding using random spatial frequencies. 
""" def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: super().__init__() if scale is None or scale <= 0.0: scale = 1.0 self.register_buffer( "positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)), ) def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: """Positionally encode points that are normalized to [0,1].""" # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape coords = 2 * coords - 1 coords = coords @ self.positional_encoding_gaussian_matrix coords = 2 * np.pi * coords # outputs d_1 x ... x d_n x C shape return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) def forward(self, size: Tuple[int, int]) -> torch.Tensor: """Generate positional encoding for a grid of the specified size.""" h, w = size device: Any = self.positional_encoding_gaussian_matrix.device grid = torch.ones((h, w), device=device, dtype=torch.float32) y_embed = grid.cumsum(dim=0) - 0.5 x_embed = grid.cumsum(dim=1) - 0.5 y_embed = y_embed / h x_embed = x_embed / w pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) return pe.permute(2, 0, 1) # C x H x W def forward_with_coords( self, coords_input: torch.Tensor, image_size: Tuple[int, int] ) -> torch.Tensor: """Positionally encode points that are not normalized to [0,1].""" coords = coords_input.clone() coords[:, :, 0] = coords[:, :, 0] / image_size[1] coords[:, :, 1] = coords[:, :, 1] / image_size[0] return self._pe_encoding(coords.to(torch.float)) # B x N x C # Rotary Positional Encoding, adapted from: # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py # 2. https://github.com/naver-ai/rope-vit # 3. https://github.com/lucidrains/rotary-embedding-torch def init_t_xy(end_x: int, end_y: int): t = torch.arange(end_x * end_y, dtype=torch.float32) t_x = (t % end_x).float() t_y = torch.div(t, end_x, rounding_mode="floor").float() return t_x, t_y def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) t_x, t_y = init_t_xy(end_x, end_y) freqs_x = torch.outer(t_x, freqs_x) freqs_y = torch.outer(t_y, freqs_y) freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): ndim = x.ndim assert 0 <= 1 < ndim assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) def apply_rotary_enc( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, repeat_freqs_k: bool = False, ): xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = ( torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None ) freqs_cis = reshape_for_broadcast(freqs_cis, xq_) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) if xk_ is None: # no keys to rotate, due to dropout return xq_out.type_as(xq).to(xq.device), xk # repeat freqs along seq_len dim to match k seq_len if repeat_freqs_k: r = xk_.shape[-2] // xq_.shape[-2] if freqs_cis.is_cuda: freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) else: # torch.repeat on complex numbers may not be supported on non-CUDA devices # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten freqs_cis = 
freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam/mask_decoder.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import List, Optional, Tuple, Type import torch from torch import nn from sam2.modeling.sam2_utils import LayerNorm2d, MLP class MaskDecoder(nn.Module): def __init__( self, *, transformer_dim: int, transformer: nn.Module, num_multimask_outputs: int = 3, activation: Type[nn.Module] = nn.GELU, iou_head_depth: int = 3, iou_head_hidden_dim: int = 256, use_high_res_features: bool = False, iou_prediction_use_sigmoid=False, dynamic_multimask_via_stability=False, dynamic_multimask_stability_delta=0.05, dynamic_multimask_stability_thresh=0.98, pred_obj_scores: bool = False, pred_obj_scores_mlp: bool = False, use_multimask_token_for_obj_ptr: bool = False, ) -> None: """ Predicts masks given an image and prompt embeddings, using a transformer architecture. Arguments: transformer_dim (int): the channel dimension of the transformer transformer (nn.Module): the transformer used to predict masks num_multimask_outputs (int): the number of masks to predict when disambiguating masks activation (nn.Module): the type of activation to use when upscaling masks iou_head_depth (int): the depth of the MLP used to predict mask quality iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality """ super().__init__() self.transformer_dim = transformer_dim self.transformer = transformer self.num_multimask_outputs = num_multimask_outputs self.iou_token = nn.Embedding(1, transformer_dim) self.num_mask_tokens = num_multimask_outputs + 1 self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) self.pred_obj_scores = pred_obj_scores if self.pred_obj_scores: self.obj_score_token = nn.Embedding(1, transformer_dim) self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr self.output_upscaling = nn.Sequential( nn.ConvTranspose2d( transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 ), LayerNorm2d(transformer_dim // 4), activation(), nn.ConvTranspose2d( transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 ), activation(), ) self.use_high_res_features = use_high_res_features if use_high_res_features: self.conv_s0 = nn.Conv2d( transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 ) self.conv_s1 = nn.Conv2d( transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 ) self.output_hypernetworks_mlps = nn.ModuleList( [ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for i in range(self.num_mask_tokens) ] ) self.iou_prediction_head = MLP( transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth, sigmoid_output=iou_prediction_use_sigmoid, ) if self.pred_obj_scores: 
self.pred_obj_score_head = nn.Linear(transformer_dim, 1) if pred_obj_scores_mlp: self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) # When outputting a single mask, optionally we can dynamically fall back to the best # multimask output token if the single mask output token gives low stability scores. self.dynamic_multimask_via_stability = dynamic_multimask_via_stability self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh def forward( self, image_embeddings: torch.Tensor, image_pe: torch.Tensor, sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, multimask_output: bool, repeat_image: bool, high_res_features: Optional[List[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Predict masks given image and prompt embeddings. Arguments: image_embeddings (torch.Tensor): the embeddings from the image encoder image_pe (torch.Tensor): positional encoding with the shape of image_embeddings sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs multimask_output (bool): Whether to return multiple masks or a single mask. Returns: torch.Tensor: batched predicted masks torch.Tensor: batched predictions of mask quality torch.Tensor: batched SAM token for mask output """ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( image_embeddings=image_embeddings, image_pe=image_pe, sparse_prompt_embeddings=sparse_prompt_embeddings, dense_prompt_embeddings=dense_prompt_embeddings, repeat_image=repeat_image, high_res_features=high_res_features, ) # Select the correct mask or masks for output if multimask_output: masks = masks[:, 1:, :, :] iou_pred = iou_pred[:, 1:] elif self.dynamic_multimask_via_stability and not self.training: masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) else: masks = masks[:, 0:1, :, :] iou_pred = iou_pred[:, 0:1] if multimask_output and self.use_multimask_token_for_obj_ptr: sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape else: # Take the mask output token. Here we *always* use the token for single mask output. # At test time, even if we track after 1-click (and using multimask_output=True), # we still take the single mask token here. The rationale is that we always track # after multiple clicks during training, so the past tokens seen during training # are always the single mask token (and we'll let it be the object-memory token). sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape # Prepare output return masks, iou_pred, sam_tokens_out, object_score_logits def predict_masks( self, image_embeddings: torch.Tensor, image_pe: torch.Tensor, sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, repeat_image: bool, high_res_features: Optional[List[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Predicts masks. 
See 'forward' for more details.""" # Concatenate output tokens s = 0 if self.pred_obj_scores: output_tokens = torch.cat( [ self.obj_score_token.weight, self.iou_token.weight, self.mask_tokens.weight, ], dim=0, ) s = 1 else: output_tokens = torch.cat( [self.iou_token.weight, self.mask_tokens.weight], dim=0 ) output_tokens = output_tokens.unsqueeze(0).expand( sparse_prompt_embeddings.size(0), -1, -1 ) tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # Expand per-image data in batch direction to be per-mask if repeat_image: src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) else: assert image_embeddings.shape[0] == tokens.shape[0] src = image_embeddings src = src + dense_prompt_embeddings assert ( image_pe.size(0) == 1 ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) b, c, h, w = src.shape # Run the transformer hs, src = self.transformer(src, pos_src, tokens) iou_token_out = hs[:, s, :] mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] # Upscale mask embeddings and predict masks using the mask tokens src = src.transpose(1, 2).view(b, c, h, w) if not self.use_high_res_features: upscaled_embedding = self.output_upscaling(src) else: dc1, ln1, act1, dc2, act2 = self.output_upscaling feat_s0, feat_s1 = high_res_features upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) hyper_in_list: List[torch.Tensor] = [] for i in range(self.num_mask_tokens): hyper_in_list.append( self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) ) hyper_in = torch.stack(hyper_in_list, dim=1) b, c, h, w = upscaled_embedding.shape masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) # Generate mask quality predictions iou_pred = self.iou_prediction_head(iou_token_out) if self.pred_obj_scores: assert s == 1 object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) else: # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) return masks, iou_pred, mask_tokens_out, object_score_logits def _get_stability_scores(self, mask_logits): """ Compute stability scores of the mask logits based on the IoU between upper and lower thresholds. """ mask_logits = mask_logits.flatten(-2) stability_delta = self.dynamic_multimask_stability_delta area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) return stability_scores def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): """ When outputting a single mask, if the stability score from the current single-mask output (based on output token 0) falls below a threshold, we instead select from multi-mask outputs (based on output token 1~3) the mask with the highest predicted IoU score. This is intended to ensure a valid mask for both clicking and tracking. 
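        Here a mask's stability score is the IoU between its binarizations at logit
        thresholds of +/- `dynamic_multimask_stability_delta` (see `_get_stability_scores`),
        and the fallback is triggered when that score drops below
        `dynamic_multimask_stability_thresh`.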
""" # The best mask from multimask output tokens (1~3) multimask_logits = all_mask_logits[:, 1:, :, :] multimask_iou_scores = all_iou_scores[:, 1:] best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) batch_inds = torch.arange( multimask_iou_scores.size(0), device=all_iou_scores.device ) best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] best_multimask_logits = best_multimask_logits.unsqueeze(1) best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) # The mask from singlemask output token 0 and its stability score singlemask_logits = all_mask_logits[:, 0:1, :, :] singlemask_iou_scores = all_iou_scores[:, 0:1] stability_scores = self._get_stability_scores(singlemask_logits) is_stable = stability_scores >= self.dynamic_multimask_stability_thresh # Dynamically fall back to best multimask output upon low stability scores. mask_logits_out = torch.where( is_stable[..., None, None].expand_as(singlemask_logits), singlemask_logits, best_multimask_logits, ) iou_scores_out = torch.where( is_stable.expand_as(singlemask_iou_scores), singlemask_iou_scores, best_multimask_iou_scores, ) return mask_logits_out, iou_scores_out ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam/prompt_encoder.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Optional, Tuple, Type import torch from torch import nn from sam2.modeling.position_encoding import PositionEmbeddingRandom from sam2.modeling.sam2_utils import LayerNorm2d class PromptEncoder(nn.Module): def __init__( self, embed_dim: int, image_embedding_size: Tuple[int, int], input_image_size: Tuple[int, int], mask_in_chans: int, activation: Type[nn.Module] = nn.GELU, ) -> None: """ Encodes prompts for input to SAM's mask decoder. Arguments: embed_dim (int): The prompts' embedding dimension image_embedding_size (tuple(int, int)): The spatial size of the image embedding, as (H, W). input_image_size (int): The padded size of the image as input to the image encoder, as (H, W). mask_in_chans (int): The number of hidden channels used for encoding input masks. activation (nn.Module): The activation to use when encoding input masks. 
""" super().__init__() self.embed_dim = embed_dim self.input_image_size = input_image_size self.image_embedding_size = image_embedding_size self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners point_embeddings = [ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) ] self.point_embeddings = nn.ModuleList(point_embeddings) self.not_a_point_embed = nn.Embedding(1, embed_dim) self.mask_input_size = ( 4 * image_embedding_size[0], 4 * image_embedding_size[1], ) self.mask_downscaling = nn.Sequential( nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), LayerNorm2d(mask_in_chans // 4), activation(), nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), LayerNorm2d(mask_in_chans), activation(), nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), ) self.no_mask_embed = nn.Embedding(1, embed_dim) def get_dense_pe(self) -> torch.Tensor: """ Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the image encoding. Returns: torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w) """ return self.pe_layer(self.image_embedding_size).unsqueeze(0) def _embed_points( self, points: torch.Tensor, labels: torch.Tensor, pad: bool, ) -> torch.Tensor: """Embeds point prompts.""" points = points + 0.5 # Shift to center of pixel if pad: padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) points = torch.cat([points, padding_point], dim=1) labels = torch.cat([labels, padding_label], dim=1) point_embedding = self.pe_layer.forward_with_coords( points, self.input_image_size ) point_embedding = torch.where( (labels == -1).unsqueeze(-1), torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, point_embedding, ) point_embedding = torch.where( (labels == 0).unsqueeze(-1), point_embedding + self.point_embeddings[0].weight, point_embedding, ) point_embedding = torch.where( (labels == 1).unsqueeze(-1), point_embedding + self.point_embeddings[1].weight, point_embedding, ) point_embedding = torch.where( (labels == 2).unsqueeze(-1), point_embedding + self.point_embeddings[2].weight, point_embedding, ) point_embedding = torch.where( (labels == 3).unsqueeze(-1), point_embedding + self.point_embeddings[3].weight, point_embedding, ) return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: """Embeds box prompts.""" boxes = boxes + 0.5 # Shift to center of pixel coords = boxes.reshape(-1, 2, 2) corner_embedding = self.pe_layer.forward_with_coords( coords, self.input_image_size ) corner_embedding[:, 0, :] += self.point_embeddings[2].weight corner_embedding[:, 1, :] += self.point_embeddings[3].weight return corner_embedding def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: """Embeds mask inputs.""" mask_embedding = self.mask_downscaling(masks) return mask_embedding def _get_batch_size( self, points: Optional[Tuple[torch.Tensor, torch.Tensor]], boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor], ) -> int: """ Gets the batch size of the output given the batch size of the input prompts. 
""" if points is not None: return points[0].shape[0] elif boxes is not None: return boxes.shape[0] elif masks is not None: return masks.shape[0] else: return 1 def _get_device(self) -> torch.device: return self.point_embeddings[0].weight.device def forward( self, points: Optional[Tuple[torch.Tensor, torch.Tensor]], boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: """ Embeds different types of prompts, returning both sparse and dense embeddings. Arguments: points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates and labels to embed. boxes (torch.Tensor or none): boxes to embed masks (torch.Tensor or none): masks to embed Returns: torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined by the number of input points and boxes. torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W) """ bs = self._get_batch_size(points, boxes, masks) sparse_embeddings = torch.empty( (bs, 0, self.embed_dim), device=self._get_device() ) if points is not None: coords, labels = points point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) if boxes is not None: box_embeddings = self._embed_boxes(boxes) sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) if masks is not None: dense_embeddings = self._embed_masks(masks) else: dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] ) return sparse_embeddings, dense_embeddings ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam/transformer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import math from functools import partial from typing import Tuple, Type import torch import torch.nn.functional as F from torch import nn, Tensor from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis from sam2.modeling.sam2_utils import MLP class TwoWayTransformer(nn.Module): def __init__( self, depth: int, embedding_dim: int, num_heads: int, mlp_dim: int, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, ) -> None: """ A transformer decoder that attends to an input image using queries whose positional embedding is supplied. Args: depth (int): number of layers in the transformer embedding_dim (int): the channel dimension for the input embeddings num_heads (int): the number of heads for multihead attention. 
Must divide embedding_dim mlp_dim (int): the channel dimension internal to the MLP block activation (nn.Module): the activation to use in the MLP block """ super().__init__() self.depth = depth self.embedding_dim = embedding_dim self.num_heads = num_heads self.mlp_dim = mlp_dim self.layers = nn.ModuleList() for i in range(depth): self.layers.append( TwoWayAttentionBlock( embedding_dim=embedding_dim, num_heads=num_heads, mlp_dim=mlp_dim, activation=activation, attention_downsample_rate=attention_downsample_rate, skip_first_layer_pe=(i == 0), ) ) self.final_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm_final_attn = nn.LayerNorm(embedding_dim) def forward( self, image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor, ) -> Tuple[Tensor, Tensor]: """ Args: image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w. image_pe (torch.Tensor): the positional encoding to add to the image. Must have the same shape as image_embedding. point_embedding (torch.Tensor): the embedding to add to the query points. Must have shape B x N_points x embedding_dim for any N_points. Returns: torch.Tensor: the processed point_embedding torch.Tensor: the processed image_embedding """ # BxCxHxW -> BxHWxC == B x N_image_tokens x C bs, c, h, w = image_embedding.shape image_embedding = image_embedding.flatten(2).permute(0, 2, 1) image_pe = image_pe.flatten(2).permute(0, 2, 1) # Prepare queries queries = point_embedding keys = image_embedding # Apply transformer blocks and final layernorm for layer in self.layers: queries, keys = layer( queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe, ) # Apply the final attention layer from the points to the image q = queries + point_embedding k = keys + image_pe attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm_final_attn(queries) return queries, keys class TwoWayAttentionBlock(nn.Module): def __init__( self, embedding_dim: int, num_heads: int, mlp_dim: int = 2048, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, ) -> None: """ A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse inputs. 
Arguments: embedding_dim (int): the channel dimension of the embeddings num_heads (int): the number of heads in the attention layers mlp_dim (int): the hidden dimension of the mlp block activation (nn.Module): the activation of the mlp block skip_first_layer_pe (bool): skip the PE on the first layer """ super().__init__() self.self_attn = Attention(embedding_dim, num_heads) self.norm1 = nn.LayerNorm(embedding_dim) self.cross_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm2 = nn.LayerNorm(embedding_dim) self.mlp = MLP( embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation ) self.norm3 = nn.LayerNorm(embedding_dim) self.norm4 = nn.LayerNorm(embedding_dim) self.cross_attn_image_to_token = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.skip_first_layer_pe = skip_first_layer_pe def forward( self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor ) -> Tuple[Tensor, Tensor]: # Self attention block if self.skip_first_layer_pe: queries = self.self_attn(q=queries, k=queries, v=queries) else: q = queries + query_pe attn_out = self.self_attn(q=q, k=q, v=queries) queries = queries + attn_out queries = self.norm1(queries) # Cross attention block, tokens attending to image embedding q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm2(queries) # MLP block mlp_out = self.mlp(queries) queries = queries + mlp_out queries = self.norm3(queries) # Cross attention block, image embedding attending to tokens q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) keys = keys + attn_out keys = self.norm4(keys) return queries, keys class Attention(nn.Module): """ An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. """ def __init__( self, embedding_dim: int, num_heads: int, downsample_rate: int = 1, dropout: float = 0.0, kv_in_dim: int = None, ) -> None: super().__init__() self.embedding_dim = embedding_dim self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads assert ( self.internal_dim % num_heads == 0 ), "num_heads must divide embedding_dim." 
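        # q/k/v are projected down to internal_dim = embedding_dim // downsample_rate
        # (then split across num_heads); out_proj maps the result back to embedding_dim,
        # so downsample_rate > 1 reduces the width of the attention computation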
self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, embedding_dim) self.dropout_p = dropout def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head def _recombine_heads(self, x: Tensor) -> Tensor: b, n_heads, n_tokens, c_per_head = x.shape x = x.transpose(1, 2) return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # Input projections q = self.q_proj(q) k = self.k_proj(k) v = self.v_proj(v) # Separate into heads q = self._separate_heads(q, self.num_heads) k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) dropout_p = self.dropout_p if self.training else 0.0 # Attention out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) out = self.out_proj(out) return out class RoPEAttention(Attention): """Attention with rotary position encoding.""" def __init__( self, *args, rope_theta=10000.0, # whether to repeat q rope to match k length # this is needed for cross-attention to memories rope_k_repeat=False, feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution **kwargs, ): super().__init__(*args, **kwargs) self.compute_cis = partial( compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta ) freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) self.freqs_cis = ( freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis ) self.rope_k_repeat = rope_k_repeat def forward( self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 ) -> Tensor: # Input projections q = self.q_proj(q) k = self.k_proj(k) v = self.v_proj(v) # Separate into heads q = self._separate_heads(q, self.num_heads) k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) # Apply rotary position encoding w = h = math.sqrt(q.shape[-2]) self.freqs_cis = self.freqs_cis.to(q.device) if self.freqs_cis.shape[0] != q.shape[-2]: self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) if q.shape[-2] != k.shape[-2]: assert self.rope_k_repeat num_k_rope = k.size(-2) - num_k_exclude_rope q, k[:, :, :num_k_rope] = apply_rotary_enc( q, k[:, :, :num_k_rope], freqs_cis=self.freqs_cis, repeat_freqs_k=self.rope_k_repeat, ) dropout_p = self.dropout_p if self.training else 0.0 # Attention out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) out = self.out_proj(out) return out ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam2_base.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
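# SAM2Base assembles the full tracking model: an image encoder backbone (Part 1),
# memory attention and a memory encoder over past frames (Parts 2 and 3), and
# SAM-style prompt encoder / mask decoder heads (Part 4). Its forward() is
# intentionally not implemented; inference is driven through SAM2VideoPredictor
# (see the NotImplementedError raised below).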
import torch import torch.distributed import torch.nn.functional as F from torch.nn.init import trunc_normal_ from sam2.modeling.sam.mask_decoder import MaskDecoder from sam2.modeling.sam.prompt_encoder import PromptEncoder from sam2.modeling.sam.transformer import TwoWayTransformer from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames # a large negative value as a placeholder score for missing objects NO_OBJ_SCORE = -1024.0 class SAM2Base(torch.nn.Module): def __init__( self, image_encoder, memory_attention, memory_encoder, num_maskmem=7, # default 1 input frame + 6 previous frames image_size=512, backbone_stride=16, # stride of the image backbone output sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks binarize_mask_from_pts_for_mem_enc=False, use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. max_cond_frames_in_attn=-1, # on the first frame, whether to directly add the no-memory embedding to the image feature # (instead of using the transformer encoder) directly_add_no_mem_embed=False, # whether to use high-resolution feature maps in the SAM mask decoder use_high_res_features_in_sam=False, # whether to output multiple (3) masks for the first click on initial conditioning frames multimask_output_in_sam=False, # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) multimask_min_pt_num=1, multimask_max_pt_num=1, # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) multimask_output_for_tracking=False, # Whether to use multimask tokens for obj ptr; Only relevant when both # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True use_multimask_token_for_obj_ptr: bool = False, # whether to use sigmoid to restrict ious prediction to [0-1] iou_prediction_use_sigmoid=False, # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. 
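        # (E.g., during evaluation with num_maskmem=7 and r=2, tracking frame 20 forward
        # draws its non-conditioning memories from frames 19, 18, 16, 14, 12 and 10;
        # see `_prepare_memory_conditioned_features` for the exact selection rule.)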
memory_temporal_stride_for_eval=1, # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) non_overlap_masks_for_mem_enc=False, # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder=False, # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) max_obj_ptrs_in_encoder=16, # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) add_tpos_enc_to_obj_ptrs=True, # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) proj_tpos_enc_in_obj_ptrs=False, # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) use_signed_tpos_enc_to_obj_ptrs=False, # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) only_obj_ptrs_in_the_past_for_eval=False, # Whether to predict if there is an object in the frame pred_obj_scores: bool = False, # Whether to use an MLP to predict object scores pred_obj_scores_mlp: bool = False, # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; # Whether to have a fixed no obj pointer when there is no object present # or to use it as an additive embedding with obj_ptr produced by decoder fixed_no_obj_ptr: bool = False, # Soft no object, i.e. mix in no_obj_ptr softly, # hope to make recovery easier if there is a mistake and mitigate accumulation of errors soft_no_obj_ptr: bool = False, use_mlp_for_obj_ptr_proj: bool = False, # add no obj embedding to spatial frames no_obj_embed_spatial: bool = False, # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. sam_mask_decoder_extra_args=None, compile_image_encoder: bool = False, ): super().__init__() # Part 1: the image backbone self.image_encoder = image_encoder # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting self.use_high_res_features_in_sam = use_high_res_features_in_sam self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder if use_obj_ptrs_in_encoder: # A conv layer to downsample the mask prompt to stride 4 (the same stride as # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, # so that it can be fed into the SAM mask decoder to generate a pointer. 
self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs if proj_tpos_enc_in_obj_ptrs: assert add_tpos_enc_to_obj_ptrs # these options need to be used together self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval # Part 2: memory attention to condition current frame's visual features # with memories (and obj ptrs) from past frames self.memory_attention = memory_attention self.hidden_dim = image_encoder.neck.d_model # Part 3: memory encoder for the previous frame's outputs self.memory_encoder = memory_encoder self.mem_dim = self.hidden_dim if hasattr(self.memory_encoder, "out_proj") and hasattr( self.memory_encoder.out_proj, "weight" ): # if there is compression of memories along channel dim self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] self.num_maskmem = num_maskmem # Number of memories accessible # Temporal encoding of the memories self.maskmem_tpos_enc = torch.nn.Parameter( torch.zeros(num_maskmem, 1, 1, self.mem_dim) ) trunc_normal_(self.maskmem_tpos_enc, std=0.02) # a single token to indicate no memory embedding from previous frames self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) trunc_normal_(self.no_mem_embed, std=0.02) trunc_normal_(self.no_mem_pos_enc, std=0.02) self.directly_add_no_mem_embed = directly_add_no_mem_embed # Apply sigmoid to the output raw mask logits (to turn them from # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval # On frames with mask input, whether to directly output the input mask without # using a SAM prompt encoder + mask decoder self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam self.multimask_output_in_sam = multimask_output_in_sam self.multimask_min_pt_num = multimask_min_pt_num self.multimask_max_pt_num = multimask_max_pt_num self.multimask_output_for_tracking = multimask_output_for_tracking self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid # Part 4: SAM-style prompt encoder (for both mask and point inputs) # and SAM-style mask decoder for the final mask output self.image_size = image_size self.backbone_stride = backbone_stride self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args self.pred_obj_scores = pred_obj_scores self.pred_obj_scores_mlp = pred_obj_scores_mlp self.fixed_no_obj_ptr = fixed_no_obj_ptr self.soft_no_obj_ptr = soft_no_obj_ptr if self.fixed_no_obj_ptr: assert self.pred_obj_scores assert self.use_obj_ptrs_in_encoder if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) trunc_normal_(self.no_obj_ptr, std=0.02) self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj self.no_obj_embed_spatial = None if no_obj_embed_spatial: self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) trunc_normal_(self.no_obj_embed_spatial, std=0.02) 
self._build_sam_heads() self.max_cond_frames_in_attn = max_cond_frames_in_attn # Model compilation if compile_image_encoder: # Compile the forward function (not the full module) to allow loading checkpoints. print( "Image encoder compilation is enabled. First forward pass will be slow." ) self.image_encoder.forward = torch.compile( self.image_encoder.forward, mode="max-autotune", fullgraph=True, dynamic=False, ) @property def device(self): return next(self.parameters()).device def forward(self, *args, **kwargs): raise NotImplementedError( "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning" "See notebooks/video_predictor_example.ipynb for an inference example." ) def _build_sam_heads(self): """Build SAM-style prompt encoder and mask decoder.""" self.sam_prompt_embed_dim = self.hidden_dim self.sam_image_embedding_size = self.image_size // self.backbone_stride # build PromptEncoder and MaskDecoder from SAM # (their hyperparameters like `mask_in_chans=16` are from SAM code) self.sam_prompt_encoder = PromptEncoder( embed_dim=self.sam_prompt_embed_dim, image_embedding_size=( self.sam_image_embedding_size, self.sam_image_embedding_size, ), input_image_size=(self.image_size, self.image_size), mask_in_chans=16, ) self.sam_mask_decoder = MaskDecoder( num_multimask_outputs=3, transformer=TwoWayTransformer( depth=2, embedding_dim=self.sam_prompt_embed_dim, mlp_dim=2048, num_heads=8, ), transformer_dim=self.sam_prompt_embed_dim, iou_head_depth=3, iou_head_hidden_dim=256, use_high_res_features=self.use_high_res_features_in_sam, iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, pred_obj_scores=self.pred_obj_scores, pred_obj_scores_mlp=self.pred_obj_scores_mlp, use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, **(self.sam_mask_decoder_extra_args or {}), ) if self.use_obj_ptrs_in_encoder: # a linear projection on SAM output tokens to turn them into object pointers self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) if self.use_mlp_for_obj_ptr_proj: self.obj_ptr_proj = MLP( self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 ) else: self.obj_ptr_proj = torch.nn.Identity() if self.proj_tpos_enc_in_obj_ptrs: # a linear projection on temporal positional encoding in object pointers to # avoid potential interference with spatial positional encoding self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) else: self.obj_ptr_tpos_proj = torch.nn.Identity() def _forward_sam_heads( self, backbone_features, point_inputs=None, mask_inputs=None, high_res_features=None, multimask_output=False, ): """ Forward SAM prompt encoders and mask heads. Inputs: - backbone_features: image features of [B, C, H, W] shape - point_inputs: a dictionary with "point_coords" and "point_labels", where 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the absolute pixel-unit coordinate in (x, y) format of the P input points 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means positive clicks, 0 means negative clicks, and -1 means padding - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the same spatial size as the image. - high_res_features: either 1) None or 2) or a list of length 2 containing two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, which will be used as high-resolution feature maps for SAM decoder. 
        - multimask_output: if it's True, we output 3 candidate masks and their 3
          corresponding IoU estimates, and if it's False, we output only 1 mask and
          its corresponding IoU estimate.

        Outputs:
        - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
          `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
          output mask logits (before sigmoid) for the low-resolution masks, with 4x
          the resolution (1/4 stride) of the input backbone_features.
        - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
          if `multimask_output=True` and M = 1 if `multimask_output=False`),
          upsampled from the low-resolution masks, with the same spatial size as the
          image (stride is 1 pixel).
        - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
          if `multimask_output=False`), the estimated IoU of each output mask.
        - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
          If `multimask_output=True`, it's the mask with the highest IoU estimate.
          If `multimask_output=False`, it's the same as `low_res_multimasks`.
        - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
          If `multimask_output=True`, it's the mask with the highest IoU estimate.
          If `multimask_output=False`, it's the same as `high_res_multimasks`.
        - obj_ptr: [B, C] shape, the object pointer vector for the output mask,
          extracted based on the output token from the SAM mask decoder.
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size
        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM prompt encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
sam_mask_prompt = None sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( points=(sam_point_coords, sam_point_labels), boxes=None, masks=sam_mask_prompt, ) ( low_res_multimasks, ious, sam_output_tokens, object_score_logits, ) = self.sam_mask_decoder( image_embeddings=backbone_features, image_pe=self.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, repeat_image=False, # the image is already batched high_res_features=high_res_features, ) if self.pred_obj_scores: is_obj_appearing = object_score_logits > 0 # Mask used for spatial memories is always a *hard* choice between obj and no obj, # consistent with the actual mask prediction low_res_multimasks = torch.where( is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE, ) # convert masks from possibly bfloat16 (or float16) to float32 # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) low_res_multimasks = low_res_multimasks.float() high_res_multimasks = F.interpolate( low_res_multimasks, size=(self.image_size, self.image_size), mode="bilinear", align_corners=False, ) sam_output_token = sam_output_tokens[:, 0] if multimask_output: # take the best mask prediction (with the highest IoU estimation) best_iou_inds = torch.argmax(ious, dim=-1) batch_inds = torch.arange(B, device=device) low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) if sam_output_tokens.size(1) > 1: sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] else: low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks # Extract object pointer from the SAM output token (with occlusion handling) obj_ptr = self.obj_ptr_proj(sam_output_token) if self.pred_obj_scores: # Allow *soft* no obj ptr, unlike for masks if self.soft_no_obj_ptr: lambda_is_obj_appearing = object_score_logits.sigmoid() else: lambda_is_obj_appearing = is_obj_appearing.float() if self.fixed_no_obj_ptr: obj_ptr = lambda_is_obj_appearing * obj_ptr obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr return ( low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits, ) def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): """ Directly turn binary `mask_inputs` into a output mask logits without using SAM. (same input and output shapes as in _forward_sam_heads above). """ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). 
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 mask_inputs_float = mask_inputs.float() high_res_masks = mask_inputs_float * out_scale + out_bias low_res_masks = F.interpolate( high_res_masks, size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), align_corners=False, mode="bilinear", antialias=True, # use antialias for downsampling ) # a dummy IoU prediction of all 1's under mask input ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() if not self.use_obj_ptrs_in_encoder: # all zeros as a dummy object pointer (of shape [B, C]) obj_ptr = torch.zeros( mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device ) else: # produce an object pointer using the SAM decoder from the mask input _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( backbone_features=backbone_features, mask_inputs=self.mask_downsample(mask_inputs_float), high_res_features=high_res_features, ) # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying # on the object_scores from the SAM decoder. is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) is_obj_appearing = is_obj_appearing[..., None] lambda_is_obj_appearing = is_obj_appearing.float() object_score_logits = out_scale * lambda_is_obj_appearing + out_bias if self.pred_obj_scores: if self.fixed_no_obj_ptr: obj_ptr = lambda_is_obj_appearing * obj_ptr obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr return ( low_res_masks, high_res_masks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits, ) def forward_image(self, img_batch: torch.Tensor): """Get the image feature on the input batch.""" backbone_out = self.image_encoder(img_batch) if self.use_high_res_features_in_sam: # precompute projected level 0 and level 1 features in SAM decoder # to avoid running it again on every SAM click backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( backbone_out["backbone_fpn"][0] ) backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( backbone_out["backbone_fpn"][1] ) return backbone_out def _prepare_backbone_features(self, backbone_out): """Prepare and flatten visual features.""" backbone_out = backbone_out.copy() assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] # flatten NxCxHxW to HWxNxC vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] return backbone_out, vision_feats, vision_pos_embeds, feat_sizes def _prepare_memory_conditioned_features( self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, output_dict, num_frames, track_in_reverse=False, # tracking in reverse time order (for demo usage) ): """Fuse the current frame's visual feature map with previous memory.""" B = current_vision_feats[-1].size(1) # batch size on this frame C = self.hidden_dim H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size device = current_vision_feats[-1].device # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. 
# In this case, we skip the fusion with any memory. if self.num_maskmem == 0: # Disable memory and skip fusion pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) return pix_feat num_obj_ptr_tokens = 0 tpos_sign_mul = -1 if track_in_reverse else 1 # Step 1: condition the visual features of the current frame on previous memories if not is_init_cond_frame: # Retrieve the memories encoded with the maskmem backbone to_cat_memory, to_cat_memory_pos_embed = [], [] # Add conditioning frames's output first (all cond frames have t_pos=0 for # when getting temporal positional embedding below) assert len(output_dict["cond_frame_outputs"]) > 0 # Select a maximum number of temporally closest cond frames for cross attention cond_outputs = output_dict["cond_frame_outputs"] selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( frame_idx, cond_outputs, self.max_cond_frames_in_attn ) t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 # We also allow taking the memory frame non-consecutively (with stride>1), in which case # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame. stride = 1 if self.training else self.memory_temporal_stride_for_eval for t_pos in range(1, self.num_maskmem): t_rel = self.num_maskmem - t_pos # how many frames before current frame if t_rel == 1: # for t_rel == 1, we take the last frame (regardless of r) if not track_in_reverse: # the frame immediately before this frame (i.e. frame_idx - 1) prev_frame_idx = frame_idx - t_rel else: # the frame immediately after this frame (i.e. frame_idx + 1) prev_frame_idx = frame_idx + t_rel else: # for t_rel >= 2, we take the memory frame from every r-th frames if not track_in_reverse: # first find the nearest frame among every r-th frames before this frame # for r=1, this would be (frame_idx - 2) prev_frame_idx = ((frame_idx - 2) // stride) * stride # then seek further among every r-th frames prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride else: # first find the nearest frame among every r-th frames after this frame # for r=1, this would be (frame_idx + 2) prev_frame_idx = -(-(frame_idx + 2) // stride) * stride # then seek further among every r-th frames prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) if out is None: # If an unselected conditioning frame is among the last (self.num_maskmem - 1) # frames, we still attend to it as if it's a non-conditioning frame. out = unselected_cond_outputs.get(prev_frame_idx, None) t_pos_and_prevs.append((t_pos, out)) for t_pos, prev in t_pos_and_prevs: if prev is None: continue # skip padding frames # "maskmem_features" might have been offloaded to CPU in demo use cases, # so we load it back to GPU (it's a no-op if it's already on GPU). 
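                # each spatial memory has shape (B, mem_dim, H', W') and is flattened
                # below into (H'*W', B, mem_dim) tokens before being appended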
feats = prev["maskmem_features"].to(device, non_blocking=True) to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) # Spatial positional encoding (it might have been offloaded to CPU in eval) maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) # Temporal positional encoding maskmem_enc = ( maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] ) to_cat_memory_pos_embed.append(maskmem_enc) # Construct the list of past object pointers if self.use_obj_ptrs_in_encoder: max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) # First add those object pointers from selected conditioning frames # (optionally, only include object pointers in the past during evaluation) if not self.training and self.only_obj_ptrs_in_the_past_for_eval: ptr_cond_outputs = { t: out for t, out in selected_cond_outputs.items() if (t >= frame_idx if track_in_reverse else t <= frame_idx) } else: ptr_cond_outputs = selected_cond_outputs pos_and_ptrs = [ # Temporal pos encoding contains how far away each pointer is from current frame ( ( (frame_idx - t) * tpos_sign_mul if self.use_signed_tpos_enc_to_obj_ptrs else abs(frame_idx - t) ), out["obj_ptr"], ) for t, out in ptr_cond_outputs.items() ] # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame for t_diff in range(1, max_obj_ptrs_in_encoder): t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff if t < 0 or (num_frames is not None and t >= num_frames): break out = output_dict["non_cond_frame_outputs"].get( t, unselected_cond_outputs.get(t, None) ) if out is not None: pos_and_ptrs.append((t_diff, out["obj_ptr"])) # If we have at least one object pointer, add them to the across attention if len(pos_and_ptrs) > 0: pos_list, ptrs_list = zip(*pos_and_ptrs) # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape obj_ptrs = torch.stack(ptrs_list, dim=0) # a temporal positional embedding based on how far each object pointer is from # the current frame (sine embedding normalized by the max pointer num). 
if self.add_tpos_enc_to_obj_ptrs: t_diff_max = max_obj_ptrs_in_encoder - 1 tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim obj_pos = torch.tensor(pos_list).to( device=device, non_blocking=True ) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = self.obj_ptr_tpos_proj(obj_pos) obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) else: obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) if self.mem_dim < C: # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C obj_ptrs = obj_ptrs.reshape( -1, B, C // self.mem_dim, self.mem_dim ) obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) to_cat_memory.append(obj_ptrs) to_cat_memory_pos_embed.append(obj_pos) num_obj_ptr_tokens = obj_ptrs.shape[0] else: num_obj_ptr_tokens = 0 else: # for initial conditioning frames, encode them without using any previous memory if self.directly_add_no_mem_embed: # directly add no-mem embedding (instead of using the transformer encoder) pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) return pix_feat_with_mem # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] # Step 2: Concatenate the memories and forward through the transformer encoder memory = torch.cat(to_cat_memory, dim=0) memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) pix_feat_with_mem = self.memory_attention( curr=current_vision_feats, curr_pos=current_vision_pos_embeds, memory=memory, memory_pos=memory_pos_embed, num_obj_ptr_tokens=num_obj_ptr_tokens, ) # reshape the output (HW)BC => BCHW pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) return pix_feat_with_mem def _encode_new_memory( self, current_vision_feats, feat_sizes, pred_masks_high_res, object_score_logits, is_mask_from_pts, ): """Encode the current image and its prediction into a memory feature.""" B = current_vision_feats[-1].size(1) # batch size on this frame C = self.hidden_dim H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size # top-level feature, (HW)BC => BCHW pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) if self.non_overlap_masks_for_mem_enc and not self.training: # optionally, apply non-overlapping constraints to the masks (it's applied # in the batch dimension and should only be used during eval, where all # the objects come from the same video under batch size 1). 
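        # Whichever branch ran above, to_cat_memory / to_cat_memory_pos_embed now hold
        # matching lists of (seq_len, B, mem_dim) tensors: spatial memories plus
        # object-pointer tokens when past memories exist, or a single dummy no-memory
        # token on an initial conditioning frame.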
pred_masks_high_res = self._apply_non_overlapping_constraints( pred_masks_high_res ) # scale the raw mask logits with a temperature before applying sigmoid binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts if binarize and not self.training: mask_for_mem = (pred_masks_high_res > 0).float() else: # apply sigmoid on the raw mask logits to turn them into range (0, 1) mask_for_mem = torch.sigmoid(pred_masks_high_res) # apply scale and bias terms to the sigmoid probabilities if self.sigmoid_scale_for_mem_enc != 1.0: mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc if self.sigmoid_bias_for_mem_enc != 0.0: mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc maskmem_out = self.memory_encoder( pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied ) maskmem_features = maskmem_out["vision_features"] maskmem_pos_enc = maskmem_out["vision_pos_enc"] # add a no-object embedding to the spatial memory to indicate that the frame # is predicted to be occluded (i.e. no object is appearing in the frame) if self.no_obj_embed_spatial is not None: is_obj_appearing = (object_score_logits > 0).float() maskmem_features += ( 1 - is_obj_appearing[..., None, None] ) * self.no_obj_embed_spatial[..., None, None].expand( *maskmem_features.shape ) return maskmem_features, maskmem_pos_enc def _track_step( self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, point_inputs, mask_inputs, output_dict, num_frames, track_in_reverse, prev_sam_mask_logits, ): current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW if len(current_vision_feats) > 1: high_res_features = [ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) ] else: high_res_features = None if mask_inputs is not None and self.use_mask_input_as_output_without_sam: # When use_mask_input_as_output_without_sam=True, we directly output the mask input # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. pix_feat = current_vision_feats[-1].permute(1, 2, 0) pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) sam_outputs = self._use_mask_as_output( pix_feat, high_res_features, mask_inputs ) else: # fused the visual feature with previous memory features in the memory bank pix_feat = self._prepare_memory_conditioned_features( frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, current_vision_feats=current_vision_feats[-1:], current_vision_pos_embeds=current_vision_pos_embeds[-1:], feat_sizes=feat_sizes[-1:], output_dict=output_dict, num_frames=num_frames, track_in_reverse=track_in_reverse, ) # apply SAM-style segmentation head # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, # e.g. 
in demo where such logits come from earlier interaction instead of correction sampling # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) if prev_sam_mask_logits is not None: assert point_inputs is not None and mask_inputs is None mask_inputs = prev_sam_mask_logits multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) sam_outputs = self._forward_sam_heads( backbone_features=pix_feat, point_inputs=point_inputs, mask_inputs=mask_inputs, high_res_features=high_res_features, multimask_output=multimask_output, ) return current_out, sam_outputs, high_res_features, pix_feat def _encode_memory_in_output( self, current_vision_feats, feat_sizes, point_inputs, run_mem_encoder, high_res_masks, object_score_logits, current_out, ): if run_mem_encoder and self.num_maskmem > 0: high_res_masks_for_mem_enc = high_res_masks maskmem_features, maskmem_pos_enc = self._encode_new_memory( current_vision_feats=current_vision_feats, feat_sizes=feat_sizes, pred_masks_high_res=high_res_masks_for_mem_enc, object_score_logits=object_score_logits, is_mask_from_pts=(point_inputs is not None), ) current_out["maskmem_features"] = maskmem_features current_out["maskmem_pos_enc"] = maskmem_pos_enc else: current_out["maskmem_features"] = None current_out["maskmem_pos_enc"] = None def track_step( self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, point_inputs, mask_inputs, output_dict, num_frames, track_in_reverse=False, # tracking in reverse time order (for demo usage) # Whether to run the memory encoder on the predicted masks. Sometimes we might want # to skip the memory encoder with `run_mem_encoder=False`. For example, # in demo we might call `track_step` multiple times for each user click, # and only encode the memory when the user finalizes their clicks. And in ablation # settings like SAM training on static images, we don't need the memory encoder. run_mem_encoder=True, # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). 
prev_sam_mask_logits=None, ): current_out, sam_outputs, _, _ = self._track_step( frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, point_inputs, mask_inputs, output_dict, num_frames, track_in_reverse, prev_sam_mask_logits, ) ( _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits, ) = sam_outputs current_out["pred_masks"] = low_res_masks current_out["pred_masks_high_res"] = high_res_masks current_out["obj_ptr"] = obj_ptr if not self.training: # Only add this in inference (to avoid unused param in activation checkpointing; # it's mainly used in the demo to encode spatial memories w/ consolidated masks) current_out["object_score_logits"] = object_score_logits # Finally run the memory encoder on the predicted mask to encode # it into a new memory feature (that can be used in future frames) self._encode_memory_in_output( current_vision_feats, feat_sizes, point_inputs, run_mem_encoder, high_res_masks, object_score_logits, current_out, ) return current_out def _use_multimask(self, is_init_cond_frame, point_inputs): """Whether to use multimask output in the SAM head.""" num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) multimask_output = ( self.multimask_output_in_sam and (is_init_cond_frame or self.multimask_output_for_tracking) and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) ) return multimask_output def _apply_non_overlapping_constraints(self, pred_masks): """ Apply non-overlapping constraints to the object scores in pred_masks. Here we keep only the highest scoring object at each spatial location in pred_masks. """ batch_size = pred_masks.size(0) if batch_size == 1: return pred_masks device = pred_masks.device # "max_obj_inds": object index of the object with the highest score at each location max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] keep = max_obj_inds == batch_obj_inds # suppress overlapping regions' scores below -10.0 so that the foreground regions # don't overlap (here sigmoid(-10.0)=4.5398e-05) pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) return pred_masks ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam2_utils.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import copy from typing import Tuple import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from sam2.utils.misc import mask_to_box def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): """ Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` that are temporally closest to the current frame at `frame_idx`. Here, we take - a) the closest conditioning frame before `frame_idx` (if any); - b) the closest conditioning frame after `frame_idx` (if any); - c) any other temporally closest conditioning frames until reaching a total of `max_cond_frame_num` conditioning frames. Outputs: - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. 
""" if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: selected_outputs = cond_frame_outputs unselected_outputs = {} else: assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" selected_outputs = {} # the closest conditioning frame before `frame_idx` (if any) idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) if idx_before is not None: selected_outputs[idx_before] = cond_frame_outputs[idx_before] # the closest conditioning frame after `frame_idx` (if any) idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) if idx_after is not None: selected_outputs[idx_after] = cond_frame_outputs[idx_after] # add other temporally closest conditioning frames until reaching a total # of `max_cond_frame_num` conditioning frames. num_remain = max_cond_frame_num - len(selected_outputs) inds_remain = sorted( (t for t in cond_frame_outputs if t not in selected_outputs), key=lambda x: abs(x - frame_idx), )[:num_remain] selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) unselected_outputs = { t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs } return selected_outputs, unselected_outputs def get_1d_sine_pe(pos_inds, dim, temperature=10000): """ Get 1D sine positional embedding as in the original Transformer paper. """ pe_dim = dim // 2 dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) pos_embed = pos_inds.unsqueeze(-1) / dim_t pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) return pos_embed def get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu raise RuntimeError(f"activation should be relu/gelu, not {activation}.") def get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) class DropPath(nn.Module): # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py def __init__(self, drop_prob=0.0, scale_by_keep=True): super(DropPath, self).__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep def forward(self, x): if self.drop_prob == 0.0 or not self.training: return x keep_prob = 1 - self.drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0 and self.scale_by_keep: random_tensor.div_(keep_prob) return x * random_tensor # Lightly adapted from # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa class MLP(nn.Module): def __init__( self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, activation: nn.Module = nn.ReLU, sigmoid_output: bool = False, ) -> None: super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList( nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) ) self.sigmoid_output = sigmoid_output self.act = activation() def forward(self, x): for i, layer in enumerate(self.layers): x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) if self.sigmoid_output: x = F.sigmoid(x) return x # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa # Itself from 
https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa class LayerNorm2d(nn.Module): def __init__(self, num_channels: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x def sample_box_points( masks: torch.Tensor, noise: float = 0.1, # SAM default noise_bound: int = 20, # SAM default top_left_label: int = 2, bottom_right_label: int = 3, ) -> Tuple[np.array, np.array]: """ Sample a noised version of the top left and bottom right corners of a given `bbox` Inputs: - masks: [B, 1, H,W] boxes, dtype=torch.Tensor - noise: noise as a fraction of box width and height, dtype=float - noise_bound: maximum amount of noise (in pure pixels), dtype=int Returns: - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float - box_labels: [B, num_pt], label 2 is reserved for top left and 3 for bottom right corners, dtype=torch.int32 """ device = masks.device box_coords = mask_to_box(masks) B, _, H, W = masks.shape box_labels = torch.tensor( [top_left_label, bottom_right_label], dtype=torch.int, device=device ).repeat(B) if noise > 0.0: if not isinstance(noise_bound, torch.Tensor): noise_bound = torch.tensor(noise_bound, device=device) bbox_w = box_coords[..., 2] - box_coords[..., 0] bbox_h = box_coords[..., 3] - box_coords[..., 1] max_dx = torch.min(bbox_w * noise, noise_bound) max_dy = torch.min(bbox_h * noise, noise_bound) box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1 box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1) box_coords = box_coords + box_noise img_bounds = ( torch.tensor([W, H, W, H], device=device) - 1 ) # uncentered pixel coords box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping box_coords = box_coords.reshape(-1, 2, 2) # always 2 points box_labels = box_labels.reshape(-1, 2) return box_coords, box_labels def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1): """ Sample `num_pt` random points (along with their labels) independently from the error regions.
Inputs: - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None - num_pt: int, number of points to sample independently for each of the B error maps Outputs: - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks """ if pred_masks is None: # if pred_masks is not provided, treat it as empty pred_masks = torch.zeros_like(gt_masks) assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape assert num_pt >= 0 B, _, H_im, W_im = gt_masks.shape device = gt_masks.device # false positive region, a new point sampled in this region should have # negative label to correct the FP error fp_masks = ~gt_masks & pred_masks # false negative region, a new point sampled in this region should have # positive label to correct the FN error fn_masks = gt_masks & ~pred_masks # whether the prediction completely matches the ground-truth on each mask all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2) all_correct = all_correct[..., None, None] # channel 0 is FP map, while channel 1 is FN map pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device) # sample a negative new click from FP region or a positive new click # from FN region, depending on where the maximum falls, # and in case the predictions are all correct (no FP or FN), we just # sample a negative click from the background region pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks) pts_noise[..., 1] *= fn_masks pts_idx = pts_noise.flatten(2).argmax(dim=2) labels = (pts_idx % 2).to(torch.int32) pts_idx = pts_idx // 2 pts_x = pts_idx % W_im pts_y = pts_idx // W_im points = torch.stack([pts_x, pts_y], dim=2).to(torch.float) return points, labels def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True): """ Sample 1 random point (along with its label) from the center of each error region, that is, the point with the largest distance to the boundary of each error region.
This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py Inputs: - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None - padding: if True, pad with boundary of 1 px for distance transform Outputs: - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks """ import cv2 if pred_masks is None: pred_masks = torch.zeros_like(gt_masks) assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape B, _, _, W_im = gt_masks.shape device = gt_masks.device # false positive region, a new point sampled in this region should have # negative label to correct the FP error fp_masks = ~gt_masks & pred_masks # false negative region, a new point sampled in this region should have # positive label to correct the FN error fn_masks = gt_masks & ~pred_masks fp_masks = fp_masks.cpu().numpy() fn_masks = fn_masks.cpu().numpy() points = torch.zeros(B, 1, 2, dtype=torch.float) labels = torch.ones(B, 1, dtype=torch.int32) for b in range(B): fn_mask = fn_masks[b, 0] fp_mask = fp_masks[b, 0] if padding: fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") # compute the distance of each point in FN/FP region to its boundary fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) if padding: fn_mask_dt = fn_mask_dt[1:-1, 1:-1] fp_mask_dt = fp_mask_dt[1:-1, 1:-1] # take the point in FN/FP region with the largest distance to its boundary fn_mask_dt_flat = fn_mask_dt.reshape(-1) fp_mask_dt_flat = fp_mask_dt.reshape(-1) fn_argmax = np.argmax(fn_mask_dt_flat) fp_argmax = np.argmax(fp_mask_dt_flat) is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax] pt_idx = fn_argmax if is_positive else fp_argmax points[b, 0, 0] = pt_idx % W_im # x points[b, 0, 1] = pt_idx // W_im # y labels[b, 0] = int(is_positive) points = points.to(device) labels = labels.to(device) return points, labels def get_next_point(gt_masks, pred_masks, method): if method == "uniform": return sample_random_points_from_errors(gt_masks, pred_masks) elif method == "center": return sample_one_point_from_error_center(gt_masks, pred_masks) else: raise ValueError(f"unknown sampling method {method}") ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_hiera_b+.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 112 num_heads: 2 neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [896, 448, 224, 112] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 
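# Note: each `_target_` in this config is resolved through Hydra instantiation, so the
# nested keys become constructor kwargs for the referenced class. A rough loading sketch,
# with an illustrative checkpoint path (not guaranteed to match this repo's layout):
#   from sam2.build_sam import build_sam2
#   model = build_sam2("sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt")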
activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: false only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag compile_image_encoder: False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_hiera_l.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 144 num_heads: 2 stages: [2, 6, 36, 4] global_att_blocks: [23, 33, 43] window_pos_embed_bkg_spatial_size: [7, 7] window_spec: [8, 4, 16, 8] neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [1152, 576, 288, 144] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true 
pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: false only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag compile_image_encoder: False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_hiera_s.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 11, 2] global_att_blocks: [7, 10, 13] window_pos_embed_bkg_spatial_size: [7, 7] neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [768, 384, 192, 96] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: 
sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: false only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag compile_image_encoder: False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_hiera_t.yaml ================================================ # @package _global_ # Model model: _target_: sam2.modeling.sam2_base.SAM2Base image_encoder: _target_: sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: _target_: sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 7, 2] global_att_blocks: [5, 7, 9] window_pos_embed_bkg_spatial_size: [7, 7] neck: _target_: sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null temperature: 10000 d_model: 256 backbone_channel_list: [768, 384, 192, 96] fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features fpn_interp_model: nearest memory_attention: _target_: sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: _target_: sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 d_model: 256 pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 downsample_rate: 1 dropout: 0.1 kv_in_dim: 64 num_layers: 4 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: _target_: sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: _target_: 
sam2.modeling.memory_encoder.Fuser layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs num_layers: 2 num_maskmem: 7 image_size: 1024 # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask # SAM decoder sigmoid_scale_for_mem_enc: 20.0 sigmoid_bias_for_mem_enc: -10.0 use_mask_input_as_output_without_sam: true # Memory directly_add_no_mem_embed: true # use high-resolution feature map in the SAM mask decoder use_high_res_features_in_sam: true # output 3 masks on the first click on initial conditioning frames multimask_output_in_sam: true # SAM heads iou_prediction_use_sigmoid: True # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder use_obj_ptrs_in_encoder: true add_tpos_enc_to_obj_ptrs: false only_obj_ptrs_in_the_past_for_eval: true # object occlusion prediction pred_obj_scores: true pred_obj_scores_mlp: true fixed_no_obj_ptr: true # multimask tracking settings multimask_output_for_tracking: true use_multimask_token_for_obj_ptr: true multimask_min_pt_num: 0 multimask_max_pt_num: 1 use_mlp_for_obj_ptr_proj: true # Compilation flag # HieraT does not currently support compilation, should always be set to False compile_image_encoder: False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_image_predictor.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import logging from typing import List, Optional, Tuple, Union import numpy as np import torch from PIL.Image import Image from sam2.modeling.sam2_base import SAM2Base from sam2.utils.transforms import SAM2Transforms class SAM2ImagePredictor: def __init__( self, sam_model: SAM2Base, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0, **kwargs, ) -> None: """ Uses SAM-2 to calculate the image embedding for an image, and then allow repeated, efficient mask prediction given prompts. Arguments: sam_model (Sam-2): The model to use for mask prediction. mask_threshold (float): The threshold to use when converting mask logits to binary masks. Masks are thresholded at 0 by default. max_hole_area (int): If max_hole_area > 0, we fill small holes in up to the maximum area of max_hole_area in low_res_masks. max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to the maximum area of max_sprinkle_area in low_res_masks. """ super().__init__() self.model = sam_model self._transforms = SAM2Transforms( resolution=self.model.image_size, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area, ) # Predictor state self._is_image_set = False self._features = None self._orig_hw = None # Whether the predictor is set for single image or a batch of images self._is_batch = False # Predictor config self.mask_threshold = mask_threshold # Spatial dim for backbone feature maps self._bb_feat_sizes = [ (256, 256), (128, 128), (64, 64), ] @classmethod def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": """ Load a pretrained model from the Hugging Face hub. Arguments: model_id (str): The Hugging Face repository ID. **kwargs: Additional arguments to pass to the model constructor. Returns: (SAM2ImagePredictor): The loaded model. 
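Example (illustrative model id; assumes the checkpoint is published on the Hugging Face hub): predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")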
""" from sam2.build_sam import build_sam2_hf sam_model = build_sam2_hf(model_id, **kwargs) return cls(sam_model, **kwargs) @torch.no_grad() def set_image( self, image: Union[np.ndarray, Image], ) -> None: """ Calculates the image embeddings for the provided image, allowing masks to be predicted with the 'predict' method. Arguments: image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image with pixel values in [0, 255]. image_format (str): The color format of the image, in ['RGB', 'BGR']. """ self.reset_predictor() # Transform the image to the form expected by the model if isinstance(image, np.ndarray): logging.info("For numpy array image, we assume (HxWxC) format") self._orig_hw = [image.shape[:2]] elif isinstance(image, Image): w, h = image.size self._orig_hw = [(h, w)] else: raise NotImplementedError("Image format not supported") input_image = self._transforms(image) input_image = input_image[None, ...].to(self.device) assert ( len(input_image.shape) == 4 and input_image.shape[1] == 3 ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" logging.info("Computing image embeddings for the provided image...") backbone_out = self.model.forward_image(input_image) _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos if self.model.directly_add_no_mem_embed: vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed feats = [ feat.permute(1, 2, 0).view(1, -1, *feat_size) for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) ][::-1] self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} self._is_image_set = True logging.info("Image embeddings computed.") @torch.no_grad() def set_image_batch( self, image_list: List[Union[np.ndarray]], ) -> None: """ Calculates the image embeddings for the provided image batch, allowing masks to be predicted with the 'predict_batch' method. Arguments: image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray with pixel values in [0, 255]. """ self.reset_predictor() assert isinstance(image_list, list) self._orig_hw = [] for image in image_list: assert isinstance( image, np.ndarray ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" self._orig_hw.append(image.shape[:2]) # Transform the image to the form expected by the model img_batch = self._transforms.forward_batch(image_list) img_batch = img_batch.to(self.device) batch_size = img_batch.shape[0] assert ( len(img_batch.shape) == 4 and img_batch.shape[1] == 3 ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" logging.info("Computing image embeddings for the provided images...") backbone_out = self.model.forward_image(img_batch) _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) # Add no_mem_embed, which is added to the lowest rest feat. 
map during training on videos if self.model.directly_add_no_mem_embed: vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed feats = [ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) ][::-1] self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} self._is_image_set = True self._is_batch = True logging.info("Image embeddings computed.") def predict_batch( self, point_coords_batch: List[np.ndarray] = None, point_labels_batch: List[np.ndarray] = None, box_batch: List[np.ndarray] = None, mask_input_batch: List[np.ndarray] = None, multimask_output: bool = True, return_logits: bool = False, normalize_coords=True, ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. It returns a tuple of lists of masks, ious, and low_res_masks_logits. """ assert self._is_batch, "This function should only be used when in batched mode" if not self._is_image_set: raise RuntimeError( "An image must be set with .set_image_batch(...) before mask prediction." ) num_images = len(self._features["image_embed"]) all_masks = [] all_ious = [] all_low_res_masks = [] for img_idx in range(num_images): # Transform input prompts point_coords = ( point_coords_batch[img_idx] if point_coords_batch is not None else None ) point_labels = ( point_labels_batch[img_idx] if point_labels_batch is not None else None ) box = box_batch[img_idx] if box_batch is not None else None mask_input = ( mask_input_batch[img_idx] if mask_input_batch is not None else None ) mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( point_coords, point_labels, box, mask_input, normalize_coords, img_idx=img_idx, ) masks, iou_predictions, low_res_masks = self._predict( unnorm_coords, labels, unnorm_box, mask_input, multimask_output, return_logits=return_logits, img_idx=img_idx, ) masks_np = masks.squeeze(0).float().detach().cpu().numpy() iou_predictions_np = ( iou_predictions.squeeze(0).float().detach().cpu().numpy() ) low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() all_masks.append(masks_np) all_ious.append(iou_predictions_np) all_low_res_masks.append(low_res_masks_np) return all_masks, all_ious, all_low_res_masks def predict( self, point_coords: Optional[np.ndarray] = None, point_labels: Optional[np.ndarray] = None, box: Optional[np.ndarray] = None, mask_input: Optional[np.ndarray] = None, multimask_output: bool = True, return_logits: bool = False, normalize_coords=True, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Predict masks for the given input prompts, using the currently set image. Arguments: point_coords (np.ndarray or None): A Nx2 array of point prompts to the model. Each point is in (X,Y) in pixels. point_labels (np.ndarray or None): A length N array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point. box (np.ndarray or None): A length 4 array given a box prompt to the model, in XYXY format. mask_input (np.ndarray): A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256. multimask_output (bool): If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. 
If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results. return_logits (bool): If true, returns un-thresholded masks logits instead of a binary mask. normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions. Returns: (np.ndarray): The output masks in CxHxW format, where C is the number of masks, and (H, W) is the original image size. (np.ndarray): An array of length C containing the model's predictions for the quality of each mask. (np.ndarray): An array of shape CxHxW, where C is the number of masks and H=W=256. These low resolution logits can be passed to a subsequent iteration as mask input. """ if not self._is_image_set: raise RuntimeError( "An image must be set with .set_image(...) before mask prediction." ) # Transform input prompts mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( point_coords, point_labels, box, mask_input, normalize_coords ) masks, iou_predictions, low_res_masks = self._predict( unnorm_coords, labels, unnorm_box, mask_input, multimask_output, return_logits=return_logits, ) masks_np = masks.squeeze(0).float().detach().cpu().numpy() iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() return masks_np, iou_predictions_np, low_res_masks_np def _prep_prompts( self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1 ): unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None if point_coords is not None: assert ( point_labels is not None ), "point_labels must be supplied if point_coords is supplied." point_coords = torch.as_tensor( point_coords, dtype=torch.float, device=self.device ) unnorm_coords = self._transforms.transform_coords( point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] ) labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) if len(unnorm_coords.shape) == 2: unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...] if box is not None: box = torch.as_tensor(box, dtype=torch.float, device=self.device) unnorm_box = self._transforms.transform_boxes( box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] ) # Bx2x2 if mask_logits is not None: mask_input = torch.as_tensor( mask_logits, dtype=torch.float, device=self.device ) if len(mask_input.shape) == 3: mask_input = mask_input[None, :, :, :] return mask_input, unnorm_coords, labels, unnorm_box @torch.no_grad() def _predict( self, point_coords: Optional[torch.Tensor], point_labels: Optional[torch.Tensor], boxes: Optional[torch.Tensor] = None, mask_input: Optional[torch.Tensor] = None, multimask_output: bool = True, return_logits: bool = False, img_idx: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Predict masks for the given input prompts, using the currently set image. Input prompts are batched torch tensors and are expected to already be transformed to the input frame using SAM2Transforms. Arguments: point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the model. Each point is in (X,Y) in pixels. point_labels (torch.Tensor or None): A BxN array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point. 
boxes (np.ndarray or None): A Bx4 array given a box prompt to the model, in XYXY format. mask_input (np.ndarray): A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form Bx1xHxW, where for SAM, H=W=256. Masks returned by a previous iteration of the predict method do not need further transformation. multimask_output (bool): If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results. return_logits (bool): If true, returns un-thresholded masks logits instead of a binary mask. Returns: (torch.Tensor): The output masks in BxCxHxW format, where C is the number of masks, and (H, W) is the original image size. (torch.Tensor): An array of shape BxC containing the model's predictions for the quality of each mask. (torch.Tensor): An array of shape BxCxHxW, where C is the number of masks and H=W=256. These low res logits can be passed to a subsequent iteration as mask input. """ if not self._is_image_set: raise RuntimeError( "An image must be set with .set_image(...) before mask prediction." ) if point_coords is not None: concat_points = (point_coords, point_labels) else: concat_points = None # Embed prompts if boxes is not None: box_coords = boxes.reshape(-1, 2, 2) box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) box_labels = box_labels.repeat(boxes.size(0), 1) # we merge "boxes" and "points" into a single "concat_points" input (where # boxes are added at the beginning) to sam_prompt_encoder if concat_points is not None: concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) concat_points = (concat_coords, concat_labels) else: concat_points = (box_coords, box_labels) sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( points=concat_points, boxes=None, masks=mask_input, ) # Predict masks batched_mode = ( concat_points is not None and concat_points[0].shape[0] > 1 ) # multi object prediction high_res_features = [ feat_level[img_idx].unsqueeze(0) for feat_level in self._features["high_res_feats"] ] low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), image_pe=self.model.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, repeat_image=batched_mode, high_res_features=high_res_features, ) # Upscale the masks to the original image resolution masks = self._transforms.postprocess_masks( low_res_masks, self._orig_hw[img_idx] ) low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) if not return_logits: masks = masks > self.mask_threshold return masks, iou_predictions, low_res_masks def get_image_embedding(self) -> torch.Tensor: """ Returns the image embeddings for the currently set image, with shape 1xCxHxW, where C is the embedding dimension and (H,W) are the embedding spatial dimension of SAM (typically C=256, H=W=64). """ if not self._is_image_set: raise RuntimeError( "An image must be set with .set_image(...) to generate an embedding." ) assert ( self._features is not None ), "Features must exist if an image has been set." 
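# A minimal end-to-end sketch of the image-prediction API above (array names are illustrative):
#   predictor = SAM2ImagePredictor(sam_model)
#   predictor.set_image(image_rgb)                    # HxWx3 RGB uint8 numpy array
#   masks, scores, low_res_logits = predictor.predict(
#       point_coords=np.array([[500, 375]]),          # one (x, y) click in pixels
#       point_labels=np.array([1]),                   # 1 = foreground click
#       multimask_output=True,
#   )
# `masks` comes back as CxHxW at the original image resolution, per the predict() docstring.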
return self._features["image_embed"] @property def device(self) -> torch.device: return self.model.device def reset_predictor(self) -> None: """ Resets the image embeddings and other state variables. """ self._is_image_set = False self._features = None self._orig_hw = None self._is_batch = False ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_video_predictor.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import warnings from collections import OrderedDict import torch import torch.nn.functional as F from tqdm import tqdm from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames class SAM2VideoPredictor(SAM2Base): """The predictor class to handle user interactions and manage inference states.""" def __init__( self, fill_hole_area=0, # whether to apply non-overlapping constraints on the output object masks non_overlap_masks=False, # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) clear_non_cond_mem_around_input=False, # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames add_all_frames_to_correct_as_cond=False, **kwargs, ): super().__init__(**kwargs) self.fill_hole_area = fill_hole_area self.non_overlap_masks = non_overlap_masks self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond @torch.inference_mode() def init_state( self, video_path, offload_video_to_cpu=False, offload_state_to_cpu=False, async_loading_frames=False, ): """Initialize an inference state.""" compute_device = self.device # device of the model images, video_height, video_width = load_video_frames( video_path=video_path, image_size=self.image_size, offload_video_to_cpu=offload_video_to_cpu, async_loading_frames=async_loading_frames, compute_device=compute_device, ) inference_state = {} inference_state["images"] = images inference_state["num_frames"] = len(images) # whether to offload the video frames to CPU memory # turning on this option saves the GPU memory with only a very small overhead inference_state["offload_video_to_cpu"] = offload_video_to_cpu # whether to offload the inference state to CPU memory # turning on this option saves the GPU memory at the cost of a lower tracking fps # (e.g. 
in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object # and from 24 to 21 when tracking two objects) inference_state["offload_state_to_cpu"] = offload_state_to_cpu # the original video height and width, used for resizing final output scores inference_state["video_height"] = video_height inference_state["video_width"] = video_width inference_state["device"] = compute_device if offload_state_to_cpu: inference_state["storage_device"] = torch.device("cpu") else: inference_state["storage_device"] = compute_device # inputs on each frame inference_state["point_inputs_per_obj"] = {} inference_state["mask_inputs_per_obj"] = {} # visual features on a small number of recently visited frames for quick interactions inference_state["cached_features"] = {} # values that don't change across frames (so we only need to hold one copy of them) inference_state["constants"] = {} # mapping between client-side object id and model-side object index inference_state["obj_id_to_idx"] = OrderedDict() inference_state["obj_idx_to_id"] = OrderedDict() inference_state["obj_ids"] = [] # Slice (view) of each object tracking results, sharing the same memory with "output_dict" inference_state["output_dict_per_obj"] = {} # A temporary storage to hold new outputs when user interact with a frame # to add clicks or mask (it's merged into "output_dict" before propagation starts) inference_state["temp_output_dict_per_obj"] = {} # Frames that already holds consolidated outputs from click or mask inputs # (we directly use their consolidated outputs during tracking) # metadata for each tracking frame (e.g. which direction it's tracked) inference_state["frames_tracked_per_obj"] = {} # Warm up the visual backbone and cache the image feature on frame 0 self._get_image_feature(inference_state, frame_idx=0, batch_size=1) return inference_state @classmethod def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": """ Load a pretrained model from the Hugging Face hub. Arguments: model_id (str): The Hugging Face repository ID. **kwargs: Additional arguments to pass to the model constructor. Returns: (SAM2VideoPredictor): The loaded model. """ from sam2.build_sam import build_sam2_video_predictor_hf sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) return sam_model def _obj_id_to_idx(self, inference_state, obj_id): """Map client-side object id to model-side object index.""" obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) if obj_idx is not None: return obj_idx # We always allow adding new objects (including after tracking starts). allow_new_object = True if allow_new_object: # get the next object slot obj_idx = len(inference_state["obj_id_to_idx"]) inference_state["obj_id_to_idx"][obj_id] = obj_idx inference_state["obj_idx_to_id"][obj_idx] = obj_id inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) # set up input and output structures for this object inference_state["point_inputs_per_obj"][obj_idx] = {} inference_state["mask_inputs_per_obj"][obj_idx] = {} inference_state["output_dict_per_obj"][obj_idx] = { "cond_frame_outputs": {}, # dict containing {frame_idx: } "non_cond_frame_outputs": {}, # dict containing {frame_idx: } } inference_state["temp_output_dict_per_obj"][obj_idx] = { "cond_frame_outputs": {}, # dict containing {frame_idx: } "non_cond_frame_outputs": {}, # dict containing {frame_idx: } } inference_state["frames_tracked_per_obj"][obj_idx] = {} return obj_idx else: raise RuntimeError( f"Cannot add new object id {obj_id} after tracking starts. 
" f"All existing object ids: {inference_state['obj_ids']}. " f"Please call 'reset_state' to restart from scratch." ) def _obj_idx_to_id(self, inference_state, obj_idx): """Map model-side object index to client-side object id.""" return inference_state["obj_idx_to_id"][obj_idx] def _get_obj_num(self, inference_state): """Get the total number of unique object ids received so far in this session.""" return len(inference_state["obj_idx_to_id"]) @torch.inference_mode() def add_new_points_or_box( self, inference_state, frame_idx, obj_id, points=None, labels=None, clear_old_points=True, normalize_coords=True, box=None, ): """Add new points to a frame.""" obj_idx = self._obj_id_to_idx(inference_state, obj_id) point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] if (points is not None) != (labels is not None): raise ValueError("points and labels must be provided together") if points is None and box is None: raise ValueError("at least one of points or box must be provided as input") if points is None: points = torch.zeros(0, 2, dtype=torch.float32) elif not isinstance(points, torch.Tensor): points = torch.tensor(points, dtype=torch.float32) if labels is None: labels = torch.zeros(0, dtype=torch.int32) elif not isinstance(labels, torch.Tensor): labels = torch.tensor(labels, dtype=torch.int32) if points.dim() == 2: points = points.unsqueeze(0) # add batch dimension if labels.dim() == 1: labels = labels.unsqueeze(0) # add batch dimension # If `box` is provided, we add it as the first two points with labels 2 and 3 # along with the user-provided points (consistent with how SAM 2 is trained). if box is not None: if not clear_old_points: raise ValueError( "cannot add box without clearing old points, since " "box prompt must be provided before any point prompt " "(please use clear_old_points=True instead)" ) if not isinstance(box, torch.Tensor): box = torch.tensor(box, dtype=torch.float32, device=points.device) box_coords = box.reshape(1, 2, 2) box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) box_labels = box_labels.reshape(1, 2) points = torch.cat([box_coords, points], dim=1) labels = torch.cat([box_labels, labels], dim=1) if normalize_coords: video_H = inference_state["video_height"] video_W = inference_state["video_width"] points = points / torch.tensor([video_W, video_H]).to(points.device) # scale the (normalized) coordinates by the model's internal image size points = points * self.image_size points = points.to(inference_state["device"]) labels = labels.to(inference_state["device"]) if not clear_old_points: point_inputs = point_inputs_per_frame.get(frame_idx, None) else: point_inputs = None point_inputs = concat_points(point_inputs, points, labels) point_inputs_per_frame[frame_idx] = point_inputs mask_inputs_per_frame.pop(frame_idx, None) # If this frame hasn't been tracked before, we treat it as an initial conditioning # frame, meaning that the inputs points are to generate segments on this frame without # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), # the input points will be used to correct the already tracked masks. 
obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx] is_init_cond_frame = frame_idx not in obj_frames_tracked # whether to track in reverse time order if is_init_cond_frame: reverse = False else: reverse = obj_frames_tracked[frame_idx]["reverse"] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] # Add a frame to conditioning output if it's an initial conditioning frame or # if the model sees all frames receiving clicks/mask as conditioning frames. is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" # Get any previously predicted mask logits on this object and feed it along with # the new clicks into the SAM mask decoder. prev_sam_mask_logits = None # lookup temporary output dict first, which contains the most recent output # (if not found, then lookup conditioning and non-conditioning frame output) prev_out = obj_temp_output_dict[storage_key].get(frame_idx) if prev_out is None: prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) if prev_out is None: prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) if prev_out is not None and prev_out["pred_masks"] is not None: device = inference_state["device"] prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) current_out, _ = self._run_single_frame_inference( inference_state=inference_state, output_dict=obj_output_dict, # run on the slice of a single object frame_idx=frame_idx, batch_size=1, # run on the slice of a single object is_init_cond_frame=is_init_cond_frame, point_inputs=point_inputs, mask_inputs=None, reverse=reverse, # Skip the memory encoder when adding clicks or mask. We execute the memory encoder # at the beginning of `propagate_in_video` (after user finalize their clicks). This # allows us to enforce non-overlapping constraints on all objects before encoding # them into memory. run_mem_encoder=False, prev_sam_mask_logits=prev_sam_mask_logits, ) # Add the output to the output dict (to be used as future memory) obj_temp_output_dict[storage_key][frame_idx] = current_out # Resize the output mask to the original video resolution obj_ids = inference_state["obj_ids"] consolidated_out = self._consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond=is_cond, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( inference_state, consolidated_out["pred_masks_video_res"] ) return frame_idx, obj_ids, video_res_masks def add_new_points(self, *args, **kwargs): """Deprecated method. 
Please use `add_new_points_or_box` instead.""" return self.add_new_points_or_box(*args, **kwargs) @torch.inference_mode() def add_new_mask( self, inference_state, frame_idx, obj_id, mask, ): """Add new mask to a frame.""" obj_idx = self._obj_id_to_idx(inference_state, obj_id) point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] if not isinstance(mask, torch.Tensor): mask = torch.tensor(mask, dtype=torch.bool) assert mask.dim() == 2 mask_H, mask_W = mask.shape mask_inputs_orig = mask[None, None] # add batch and channel dimension mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) # resize the mask if it doesn't match the model's image size if mask_H != self.image_size or mask_W != self.image_size: mask_inputs = torch.nn.functional.interpolate( mask_inputs_orig, size=(self.image_size, self.image_size), align_corners=False, mode="bilinear", antialias=True, # use antialias for downsampling ) mask_inputs = (mask_inputs >= 0.5).float() else: mask_inputs = mask_inputs_orig mask_inputs_per_frame[frame_idx] = mask_inputs point_inputs_per_frame.pop(frame_idx, None) # If this frame hasn't been tracked before, we treat it as an initial conditioning # frame, meaning that the inputs points are to generate segments on this frame without # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), # the input points will be used to correct the already tracked masks. obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx] is_init_cond_frame = frame_idx not in obj_frames_tracked # whether to track in reverse time order if is_init_cond_frame: reverse = False else: reverse = obj_frames_tracked[frame_idx]["reverse"] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] # Add a frame to conditioning output if it's an initial conditioning frame or # if the model sees all frames receiving clicks/mask as conditioning frames. is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" current_out, _ = self._run_single_frame_inference( inference_state=inference_state, output_dict=obj_output_dict, # run on the slice of a single object frame_idx=frame_idx, batch_size=1, # run on the slice of a single object is_init_cond_frame=is_init_cond_frame, point_inputs=None, mask_inputs=mask_inputs, reverse=reverse, # Skip the memory encoder when adding clicks or mask. We execute the memory encoder # at the beginning of `propagate_in_video` (after user finalize their clicks). This # allows us to enforce non-overlapping constraints on all objects before encoding # them into memory. 
run_mem_encoder=False, ) # Add the output to the output dict (to be used as future memory) obj_temp_output_dict[storage_key][frame_idx] = current_out # Resize the output mask to the original video resolution obj_ids = inference_state["obj_ids"] consolidated_out = self._consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond=is_cond, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( inference_state, consolidated_out["pred_masks_video_res"] ) return frame_idx, obj_ids, video_res_masks def _get_orig_video_res_output(self, inference_state, any_res_masks): """ Resize the object scores to the original video resolution (video_res_masks) and apply non-overlapping constraints for final output. """ device = inference_state["device"] video_H = inference_state["video_height"] video_W = inference_state["video_width"] any_res_masks = any_res_masks.to(device, non_blocking=True) if any_res_masks.shape[-2:] == (video_H, video_W): video_res_masks = any_res_masks else: video_res_masks = torch.nn.functional.interpolate( any_res_masks, size=(video_H, video_W), mode="bilinear", align_corners=False, ) if self.non_overlap_masks: video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) return any_res_masks, video_res_masks def _consolidate_temp_output_across_obj( self, inference_state, frame_idx, is_cond, consolidate_at_video_res=False, ): """ Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on a frame into a single output for all objects, including 1) fill any missing objects either from `output_dict_per_obj` (if they exist in `output_dict_per_obj` for this frame) or leave them as placeholder values (if they don't exist in `output_dict_per_obj` for this frame); 2) if specified, rerun memory encoder after apply non-overlapping constraints on the object scores. """ batch_size = self._get_obj_num(inference_state) storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" # Optionally, we allow consolidating the temporary outputs at the original # video resolution (to provide a better editing experience for mask prompts). if consolidate_at_video_res: consolidated_H = inference_state["video_height"] consolidated_W = inference_state["video_width"] consolidated_mask_key = "pred_masks_video_res" else: consolidated_H = consolidated_W = self.image_size // 4 consolidated_mask_key = "pred_masks" # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" # will be added when rerunning the memory encoder after applying non-overlapping # constraints to object scores. Its "pred_masks" are prefilled with a large # negative value (NO_OBJ_SCORE) to represent missing objects. consolidated_out = { consolidated_mask_key: torch.full( size=(batch_size, 1, consolidated_H, consolidated_W), fill_value=NO_OBJ_SCORE, dtype=torch.float32, device=inference_state["storage_device"], ), } for obj_idx in range(batch_size): obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] out = obj_temp_output_dict[storage_key].get(frame_idx, None) # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, # we fall back and look up its previous output in "output_dict_per_obj". # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in # "output_dict_per_obj" to find a previous output for this object. 
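        # Illustrative note (added commentary, not part of the original source): with
        # consolidate_at_video_res=True and, say, two tracked objects on a 480x854 video,
        # consolidated_out["pred_masks_video_res"] above is a float32 tensor of shape
        # (2, 1, 480, 854) prefilled with NO_OBJ_SCORE; this per-object loop overwrites the
        # slot of every object for which the lookup described above finds an output and
        # leaves missing objects at the placeholder value.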
if out is None: out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) if out is None: out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) # If the object doesn't appear in "output_dict_per_obj" either, we skip it # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE # placeholder above) and set its object pointer to be a dummy pointer. if out is None: continue # Add the temporary object output mask to consolidated output mask obj_mask = out["pred_masks"] consolidated_pred_masks = consolidated_out[consolidated_mask_key] if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask else: # Resize first if temporary object mask has a different resolution resized_obj_mask = torch.nn.functional.interpolate( obj_mask, size=consolidated_pred_masks.shape[-2:], mode="bilinear", align_corners=False, ) consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask return consolidated_out @torch.inference_mode() def propagate_in_video_preflight(self, inference_state): """Prepare inference_state and consolidate temporary outputs before tracking.""" # Check and make sure that every object has received input points or masks. batch_size = self._get_obj_num(inference_state) if batch_size == 0: raise RuntimeError( "No input points or masks are provided for any object; please add inputs first." ) # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and # add them into "output_dict". for obj_idx in range(batch_size): obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] for is_cond in [False, True]: # Separately consolidate conditioning and non-conditioning temp outputs storage_key = ( "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" ) # Find all the frames that contain temporary outputs for any objects # (these should be the frames that have just received clicks for mask inputs # via `add_new_points_or_box` or `add_new_mask`) for frame_idx, out in obj_temp_output_dict[storage_key].items(): # Run memory encoder on the temporary outputs (if the memory feature is missing) if out["maskmem_features"] is None: high_res_masks = torch.nn.functional.interpolate( out["pred_masks"].to(inference_state["device"]), size=(self.image_size, self.image_size), mode="bilinear", align_corners=False, ) maskmem_features, maskmem_pos_enc = self._run_memory_encoder( inference_state=inference_state, frame_idx=frame_idx, batch_size=1, # run on the slice of a single object high_res_masks=high_res_masks, object_score_logits=out["object_score_logits"], # these frames are what the user interacted with is_mask_from_pts=True, ) out["maskmem_features"] = maskmem_features out["maskmem_pos_enc"] = maskmem_pos_enc obj_output_dict[storage_key][frame_idx] = out if self.clear_non_cond_mem_around_input: # clear non-conditioning memory of the surrounding frames self._clear_obj_non_cond_mem_around_input( inference_state, frame_idx, obj_idx ) # clear temporary outputs in `temp_output_dict_per_obj` obj_temp_output_dict[storage_key].clear() # check and make sure that every object has received input points or masks obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] if len(obj_output_dict["cond_frame_outputs"]) == 0: obj_id = self._obj_idx_to_id(inference_state, obj_idx) raise RuntimeError( f"No input points or masks are provided for object id {obj_id}; please add inputs first." 
) # edge case: if an output is added to "cond_frame_outputs", we remove any prior # output on the same frame in "non_cond_frame_outputs" for frame_idx in obj_output_dict["cond_frame_outputs"]: obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) @torch.inference_mode() def propagate_in_video( self, inference_state, start_frame_idx=None, max_frame_num_to_track=None, reverse=False, ): """Propagate the input points across frames to track in the entire video.""" self.propagate_in_video_preflight(inference_state) obj_ids = inference_state["obj_ids"] num_frames = inference_state["num_frames"] batch_size = self._get_obj_num(inference_state) # set start index, end index, and processing order if start_frame_idx is None: # default: start from the earliest frame with input points start_frame_idx = min( t for obj_output_dict in inference_state["output_dict_per_obj"].values() for t in obj_output_dict["cond_frame_outputs"] ) if max_frame_num_to_track is None: # default: track all the frames in the video max_frame_num_to_track = num_frames if reverse: end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) if start_frame_idx > 0: processing_order = range(start_frame_idx, end_frame_idx - 1, -1) else: processing_order = [] # skip reverse tracking if starting from frame 0 else: end_frame_idx = min( start_frame_idx + max_frame_num_to_track, num_frames - 1 ) processing_order = range(start_frame_idx, end_frame_idx + 1) for frame_idx in tqdm(processing_order, desc="propagate in video"): pred_masks_per_obj = [None] * batch_size for obj_idx in range(batch_size): obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] # We skip those frames already in consolidated outputs (these are frames # that received input clicks or mask). Note that we cannot directly run # batched forward on them via `_run_single_frame_inference` because the # number of clicks on each object might be different. 
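        # Hedged usage sketch (added commentary, not part of the original source): a typical
        # caller drives this generator roughly as follows, assuming prompts were already
        # added with add_new_points_or_box or add_new_mask:
        #     state = predictor.init_state(video_path="<frames_or_video>")  # hypothetical path
        #     predictor.add_new_points_or_box(state, frame_idx=0, obj_id=1,
        #                                     points=[[210, 350]], labels=[1])
        #     for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
        #         binary = (masks > 0.0).cpu().numpy()  # (num_objects, 1, video_H, video_W)
        # Each yielded `masks` tensor holds per-object mask logits at the original video
        # resolution; thresholding at 0 gives binary masks.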
if frame_idx in obj_output_dict["cond_frame_outputs"]: storage_key = "cond_frame_outputs" current_out = obj_output_dict[storage_key][frame_idx] device = inference_state["device"] pred_masks = current_out["pred_masks"].to(device, non_blocking=True) if self.clear_non_cond_mem_around_input: # clear non-conditioning memory of the surrounding frames self._clear_obj_non_cond_mem_around_input( inference_state, frame_idx, obj_idx ) else: storage_key = "non_cond_frame_outputs" current_out, pred_masks = self._run_single_frame_inference( inference_state=inference_state, output_dict=obj_output_dict, frame_idx=frame_idx, batch_size=1, # run on the slice of a single object is_init_cond_frame=False, point_inputs=None, mask_inputs=None, reverse=reverse, run_mem_encoder=True, ) obj_output_dict[storage_key][frame_idx] = current_out inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = { "reverse": reverse } pred_masks_per_obj[obj_idx] = pred_masks # Resize the output mask to the original video resolution (we directly use # the mask scores on GPU for output to avoid any CPU conversion in between) if len(pred_masks_per_obj) > 1: all_pred_masks = torch.cat(pred_masks_per_obj, dim=0) else: all_pred_masks = pred_masks_per_obj[0] _, video_res_masks = self._get_orig_video_res_output( inference_state, all_pred_masks ) yield frame_idx, obj_ids, video_res_masks @torch.inference_mode() def clear_all_prompts_in_frame( self, inference_state, frame_idx, obj_id, need_output=True ): """Remove all input points or mask in a specific frame for a given object.""" obj_idx = self._obj_id_to_idx(inference_state, obj_id) # Clear the conditioning information on the given frame inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) if out is not None: # The frame is not a conditioning frame anymore since it's not receiving inputs, # so we "downgrade" its output (if exists) to a non-conditioning frame output. 
obj_output_dict["non_cond_frame_outputs"][frame_idx] = out inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None) if not need_output: return # Finally, output updated masks per object (after removing the inputs above) obj_ids = inference_state["obj_ids"] is_cond = any( frame_idx in obj_temp_output_dict["cond_frame_outputs"] for obj_temp_output_dict in temp_output_dict_per_obj.values() ) consolidated_out = self._consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond=is_cond, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( inference_state, consolidated_out["pred_masks_video_res"] ) return frame_idx, obj_ids, video_res_masks @torch.inference_mode() def reset_state(self, inference_state): """Remove all input points or mask in all frames throughout the video.""" self._reset_tracking_results(inference_state) # Remove all object ids inference_state["obj_id_to_idx"].clear() inference_state["obj_idx_to_id"].clear() inference_state["obj_ids"].clear() inference_state["point_inputs_per_obj"].clear() inference_state["mask_inputs_per_obj"].clear() inference_state["output_dict_per_obj"].clear() inference_state["temp_output_dict_per_obj"].clear() inference_state["frames_tracked_per_obj"].clear() def _reset_tracking_results(self, inference_state): """Reset all tracking inputs and results across the videos.""" for v in inference_state["point_inputs_per_obj"].values(): v.clear() for v in inference_state["mask_inputs_per_obj"].values(): v.clear() for v in inference_state["output_dict_per_obj"].values(): v["cond_frame_outputs"].clear() v["non_cond_frame_outputs"].clear() for v in inference_state["temp_output_dict_per_obj"].values(): v["cond_frame_outputs"].clear() v["non_cond_frame_outputs"].clear() for v in inference_state["frames_tracked_per_obj"].values(): v.clear() def _get_image_feature(self, inference_state, frame_idx, batch_size): """Compute the image features on a given frame.""" # Look up in the cache first image, backbone_out = inference_state["cached_features"].get( frame_idx, (None, None) ) if backbone_out is None: # Cache miss -- we will run inference on a single image device = inference_state["device"] image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) backbone_out = self.forward_image(image) # Cache the most recent frame's feature (for repeated interactions with # a frame; we can use an LRU cache for more frames in the future). 
inference_state["cached_features"] = {frame_idx: (image, backbone_out)} # expand the features to have the same dimension as the number of objects expanded_image = image.expand(batch_size, -1, -1, -1) expanded_backbone_out = { "backbone_fpn": backbone_out["backbone_fpn"].copy(), "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), } for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): expanded_backbone_out["backbone_fpn"][i] = feat.expand( batch_size, -1, -1, -1 ) for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): pos = pos.expand(batch_size, -1, -1, -1) expanded_backbone_out["vision_pos_enc"][i] = pos features = self._prepare_backbone_features(expanded_backbone_out) features = (expanded_image,) + features return features def _run_single_frame_inference( self, inference_state, output_dict, frame_idx, batch_size, is_init_cond_frame, point_inputs, mask_inputs, reverse, run_mem_encoder, prev_sam_mask_logits=None, ): """Run tracking on a single frame based on current inputs and previous memory.""" # Retrieve correct image features ( _, _, current_vision_feats, current_vision_pos_embeds, feat_sizes, ) = self._get_image_feature(inference_state, frame_idx, batch_size) # point and mask should not appear as input simultaneously on the same frame assert point_inputs is None or mask_inputs is None current_out = self.track_step( frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds, feat_sizes=feat_sizes, point_inputs=point_inputs, mask_inputs=mask_inputs, output_dict=output_dict, num_frames=inference_state["num_frames"], track_in_reverse=reverse, run_mem_encoder=run_mem_encoder, prev_sam_mask_logits=prev_sam_mask_logits, ) # optionally offload the output to CPU memory to save GPU space storage_device = inference_state["storage_device"] maskmem_features = current_out["maskmem_features"] if maskmem_features is not None: maskmem_features = maskmem_features.to(torch.bfloat16) maskmem_features = maskmem_features.to(storage_device, non_blocking=True) pred_masks_gpu = current_out["pred_masks"] # potentially fill holes in the predicted masks if self.fill_hole_area > 0: pred_masks_gpu = fill_holes_in_mask_scores( pred_masks_gpu, self.fill_hole_area ) pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) # object pointer is a small tensor, so we always keep it on GPU memory for fast access obj_ptr = current_out["obj_ptr"] object_score_logits = current_out["object_score_logits"] # make a compact version of this frame's output to reduce the state size compact_current_out = { "maskmem_features": maskmem_features, "maskmem_pos_enc": maskmem_pos_enc, "pred_masks": pred_masks, "obj_ptr": obj_ptr, "object_score_logits": object_score_logits, } return compact_current_out, pred_masks_gpu def _run_memory_encoder( self, inference_state, frame_idx, batch_size, high_res_masks, object_score_logits, is_mask_from_pts, ): """ Run the memory encoder on `high_res_masks`. This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their memory also need to be computed again with the memory encoder. 
""" # Retrieve correct image features _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( inference_state, frame_idx, batch_size ) maskmem_features, maskmem_pos_enc = self._encode_new_memory( current_vision_feats=current_vision_feats, feat_sizes=feat_sizes, pred_masks_high_res=high_res_masks, object_score_logits=object_score_logits, is_mask_from_pts=is_mask_from_pts, ) # optionally offload the output to CPU memory to save GPU space storage_device = inference_state["storage_device"] maskmem_features = maskmem_features.to(torch.bfloat16) maskmem_features = maskmem_features.to(storage_device, non_blocking=True) # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it maskmem_pos_enc = self._get_maskmem_pos_enc( inference_state, {"maskmem_pos_enc": maskmem_pos_enc} ) return maskmem_features, maskmem_pos_enc def _get_maskmem_pos_enc(self, inference_state, current_out): """ `maskmem_pos_enc` is the same across frames and objects, so we cache it as a constant in the inference session to reduce session storage size. """ model_constants = inference_state["constants"] # "out_maskmem_pos_enc" should be either a list of tensors or None out_maskmem_pos_enc = current_out["maskmem_pos_enc"] if out_maskmem_pos_enc is not None: if "maskmem_pos_enc" not in model_constants: assert isinstance(out_maskmem_pos_enc, list) # only take the slice for one object, since it's same across objects maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] model_constants["maskmem_pos_enc"] = maskmem_pos_enc else: maskmem_pos_enc = model_constants["maskmem_pos_enc"] # expand the cached maskmem_pos_enc to the actual batch size batch_size = out_maskmem_pos_enc[0].size(0) expanded_maskmem_pos_enc = [ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc ] else: expanded_maskmem_pos_enc = None return expanded_maskmem_pos_enc @torch.inference_mode() def remove_object(self, inference_state, obj_id, strict=False, need_output=True): """ Remove an object id from the tracking state. If strict is True, we check whether the object id actually exists and raise an error if it doesn't exist. """ old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None) updated_frames = [] # Check whether this object_id to remove actually exists and possibly raise an error. if old_obj_idx_to_rm is None: if not strict: return inference_state["obj_ids"], updated_frames raise RuntimeError( f"Cannot remove object id {obj_id} as it doesn't exist. " f"All existing object ids: {inference_state['obj_ids']}." ) # If this is the only remaining object id, we simply reset the state. if len(inference_state["obj_id_to_idx"]) == 1: self.reset_state(inference_state) return inference_state["obj_ids"], updated_frames # There are still remaining objects after removing this object id. In this case, # we need to delete the object storage from inference state tensors. 
# Step 0: clear the input on those frames where this object id has point or mask input # (note that this step is required as it might downgrade conditioning frames to # non-conditioning ones) obj_input_frames_inds = set() obj_input_frames_inds.update( inference_state["point_inputs_per_obj"][old_obj_idx_to_rm] ) obj_input_frames_inds.update( inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm] ) for frame_idx in obj_input_frames_inds: self.clear_all_prompts_in_frame( inference_state, frame_idx, obj_id, need_output=False ) # Step 1: Update the object id mapping (note that it must be done after Step 0, # since Step 0 still requires the old object id mappings in inference_state) old_obj_ids = inference_state["obj_ids"] old_obj_inds = list(range(len(old_obj_ids))) remain_old_obj_inds = old_obj_inds.copy() remain_old_obj_inds.remove(old_obj_idx_to_rm) new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] new_obj_inds = list(range(len(new_obj_ids))) # build new mappings old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) inference_state["obj_ids"] = new_obj_ids # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. def _map_keys(container): new_kvs = [] for k in old_obj_inds: v = container.pop(k) if k in old_idx_to_new_idx: new_kvs.append((old_idx_to_new_idx[k], v)) container.update(new_kvs) _map_keys(inference_state["point_inputs_per_obj"]) _map_keys(inference_state["mask_inputs_per_obj"]) _map_keys(inference_state["output_dict_per_obj"]) _map_keys(inference_state["temp_output_dict_per_obj"]) _map_keys(inference_state["frames_tracked_per_obj"]) # Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which # could show an updated mask for objects previously occluded by the object being removed if need_output: temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] for frame_idx in obj_input_frames_inds: is_cond = any( frame_idx in obj_temp_output_dict["cond_frame_outputs"] for obj_temp_output_dict in temp_output_dict_per_obj.values() ) consolidated_out = self._consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond=is_cond, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( inference_state, consolidated_out["pred_masks_video_res"] ) updated_frames.append((frame_idx, video_res_masks)) return inference_state["obj_ids"], updated_frames def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): """ Remove the non-conditioning memory around the input frame. When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated object appearance information and could confuse the model. This method clears those non-conditioning memories surrounding the interacted frame to avoid giving the model both old and new information about the object. 
""" r = self.memory_temporal_stride_for_eval frame_idx_begin = frame_idx - r * self.num_maskmem frame_idx_end = frame_idx + r * self.num_maskmem batch_size = self._get_obj_num(inference_state) for obj_idx in range(batch_size): obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"] for t in range(frame_idx_begin, frame_idx_end + 1): non_cond_frame_outputs.pop(t, None) class SAM2VideoPredictorVOS(SAM2VideoPredictor): """Optimized for the VOS setting""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._compile_all_components() def _compile_all_components(self): print("Compiling all components for VOS setting. First time may be very slow.") self.memory_encoder.forward = torch.compile( self.memory_encoder.forward, mode="max-autotune", fullgraph=True, dynamic=False, ) self.memory_attention.forward = torch.compile( self.memory_attention.forward, mode="max-autotune", fullgraph=True, dynamic=True, # Num. of memories varies ) self.sam_prompt_encoder.forward = torch.compile( self.sam_prompt_encoder.forward, mode="max-autotune", fullgraph=True, dynamic=False, # Accuracy regression on True ) self.sam_mask_decoder.forward = torch.compile( self.sam_mask_decoder.forward, mode="max-autotune", fullgraph=True, dynamic=False, # Accuracy regression on True ) def forward_image(self, img_batch: torch.Tensor): """ Identical to the corresponding method in the parent (SAM2VideoPredictor), but cloning the backbone features and pos encoding to enable compilation. """ backbone_out = self.image_encoder(img_batch) if self.use_high_res_features_in_sam: # precompute projected level 0 and level 1 features in SAM decoder # to avoid running it again on every SAM click backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( backbone_out["backbone_fpn"][0] ) backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( backbone_out["backbone_fpn"][1] ) # Clone to help torch.compile for i in range(len(backbone_out["backbone_fpn"])): backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone() backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][ i ].clone() return backbone_out def _forward_sam_heads( self, backbone_features, point_inputs=None, mask_inputs=None, high_res_features=None, multimask_output=False, ): """ Identical to the corresponding method in the parent (SAM2VideoPredictor), but cloning the outputs of prompt_encoder and mask_decoder to enable compilation. 
""" B = backbone_features.size(0) device = backbone_features.device assert backbone_features.size(1) == self.sam_prompt_embed_dim assert backbone_features.size(2) == self.sam_image_embedding_size assert backbone_features.size(3) == self.sam_image_embedding_size # a) Handle point prompts if point_inputs is not None: sam_point_coords = point_inputs["point_coords"] sam_point_labels = point_inputs["point_labels"] assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B else: # If no points are provide, pad with an empty point (with label -1) sam_point_coords = torch.zeros(B, 1, 2, device=device) sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) # b) Handle mask prompts if mask_inputs is not None: # If mask_inputs is provided, downsize it into low-res mask input if needed # and feed it as a dense mask prompt into the SAM mask encoder assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: sam_mask_prompt = F.interpolate( mask_inputs.float(), size=self.sam_prompt_encoder.mask_input_size, align_corners=False, mode="bilinear", antialias=True, # use antialias for downsampling ) else: sam_mask_prompt = mask_inputs else: # Otherwise, simply feed None (and SAM's prompt encoder will add # a learned `no_mask_embed` to indicate no mask input in this case). sam_mask_prompt = None sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( points=(sam_point_coords, sam_point_labels), boxes=None, masks=sam_mask_prompt, ) # Clone image_pe and the outputs of sam_prompt_encoder # to enable compilation sparse_embeddings = sparse_embeddings.clone() dense_embeddings = dense_embeddings.clone() image_pe = self.sam_prompt_encoder.get_dense_pe().clone() ( low_res_multimasks, ious, sam_output_tokens, object_score_logits, ) = self.sam_mask_decoder( image_embeddings=backbone_features, image_pe=image_pe, sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, repeat_image=False, # the image is already batched high_res_features=high_res_features, ) # Clone the output of sam_mask_decoder # to enable compilation low_res_multimasks = low_res_multimasks.clone() ious = ious.clone() sam_output_tokens = sam_output_tokens.clone() object_score_logits = object_score_logits.clone() if self.pred_obj_scores: is_obj_appearing = object_score_logits > 0 # Mask used for spatial memories is always a *hard* choice between obj and no obj, # consistent with the actual mask prediction low_res_multimasks = torch.where( is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE, ) # convert masks from possibly bfloat16 (or float16) to float32 # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) low_res_multimasks = low_res_multimasks.float() high_res_multimasks = F.interpolate( low_res_multimasks, size=(self.image_size, self.image_size), mode="bilinear", align_corners=False, ) sam_output_token = sam_output_tokens[:, 0] if multimask_output: # take the best mask prediction (with the highest IoU estimation) best_iou_inds = torch.argmax(ious, dim=-1) batch_inds = torch.arange(B, device=device) low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) if sam_output_tokens.size(1) > 1: sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] else: low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks # Extract object 
pointer from the SAM output token (with occlusion handling) obj_ptr = self.obj_ptr_proj(sam_output_token) if self.pred_obj_scores: # Allow *soft* no obj ptr, unlike for masks if self.soft_no_obj_ptr: lambda_is_obj_appearing = object_score_logits.sigmoid() else: lambda_is_obj_appearing = is_obj_appearing.float() if self.fixed_no_obj_ptr: obj_ptr = lambda_is_obj_appearing * obj_ptr obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr return ( low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits, ) def _encode_new_memory( self, current_vision_feats, feat_sizes, pred_masks_high_res, object_score_logits, is_mask_from_pts, ): """ Identical to the corresponding method in the parent (SAM2VideoPredictor), but cloning the memories and their pos enc to enable compilation. """ B = current_vision_feats[-1].size(1) # batch size on this frame C = self.hidden_dim H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size # top-level feature, (HW)BC => BCHW pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) if self.non_overlap_masks_for_mem_enc and not self.training: # optionally, apply non-overlapping constraints to the masks (it's applied # in the batch dimension and should only be used during eval, where all # the objects come from the same video under batch size 1). pred_masks_high_res = self._apply_non_overlapping_constraints( pred_masks_high_res ) # scale the raw mask logits with a temperature before applying sigmoid binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts if binarize and not self.training: mask_for_mem = (pred_masks_high_res > 0).float() else: # apply sigmoid on the raw mask logits to turn them into range (0, 1) mask_for_mem = torch.sigmoid(pred_masks_high_res) # apply scale and bias terms to the sigmoid probabilities if self.sigmoid_scale_for_mem_enc != 1.0: mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc if self.sigmoid_bias_for_mem_enc != 0.0: mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc maskmem_out = self.memory_encoder( pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied ) # Clone the feats and pos_enc to enable compilation maskmem_features = maskmem_out["vision_features"].clone() maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]] # add a no-object embedding to the spatial memory to indicate that the frame # is predicted to be occluded (i.e. no object is appearing in the frame) if self.no_obj_embed_spatial is not None: is_obj_appearing = (object_score_logits > 0).float() maskmem_features += ( 1 - is_obj_appearing[..., None, None] ) * self.no_obj_embed_spatial[..., None, None].expand( *maskmem_features.shape ) return maskmem_features, maskmem_pos_enc ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_video_predictor_legacy.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import warnings from collections import OrderedDict import torch from tqdm import tqdm from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames class SAM2VideoPredictor(SAM2Base): """The predictor class to handle user interactions and manage inference states.""" def __init__( self, fill_hole_area=0, # whether to apply non-overlapping constraints on the output object masks non_overlap_masks=False, # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) clear_non_cond_mem_around_input=False, # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). clear_non_cond_mem_for_multi_obj=False, # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames add_all_frames_to_correct_as_cond=False, **kwargs, ): super().__init__(**kwargs) self.fill_hole_area = fill_hole_area self.non_overlap_masks = non_overlap_masks self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond @torch.inference_mode() def init_state( self, video_path, offload_video_to_cpu=False, offload_state_to_cpu=False, async_loading_frames=False, ): """Initialize an inference state.""" compute_device = self.device # device of the model images, video_height, video_width = load_video_frames( video_path=video_path, image_size=self.image_size, offload_video_to_cpu=offload_video_to_cpu, async_loading_frames=async_loading_frames, compute_device=compute_device, ) inference_state = {} inference_state["images"] = images inference_state["num_frames"] = len(images) # whether to offload the video frames to CPU memory # turning on this option saves the GPU memory with only a very small overhead inference_state["offload_video_to_cpu"] = offload_video_to_cpu # whether to offload the inference state to CPU memory # turning on this option saves the GPU memory at the cost of a lower tracking fps # (e.g. 
in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object # and from 24 to 21 when tracking two objects) inference_state["offload_state_to_cpu"] = offload_state_to_cpu # the original video height and width, used for resizing final output scores inference_state["video_height"] = video_height inference_state["video_width"] = video_width inference_state["device"] = compute_device if offload_state_to_cpu: inference_state["storage_device"] = torch.device("cpu") else: inference_state["storage_device"] = compute_device # inputs on each frame inference_state["point_inputs_per_obj"] = {} inference_state["mask_inputs_per_obj"] = {} # visual features on a small number of recently visited frames for quick interactions inference_state["cached_features"] = {} # values that don't change across frames (so we only need to hold one copy of them) inference_state["constants"] = {} # mapping between client-side object id and model-side object index inference_state["obj_id_to_idx"] = OrderedDict() inference_state["obj_idx_to_id"] = OrderedDict() inference_state["obj_ids"] = [] # A storage to hold the model's tracking results and states on each frame inference_state["output_dict"] = { "cond_frame_outputs": {}, # dict containing {frame_idx: } "non_cond_frame_outputs": {}, # dict containing {frame_idx: } } # Slice (view) of each object tracking results, sharing the same memory with "output_dict" inference_state["output_dict_per_obj"] = {} # A temporary storage to hold new outputs when user interact with a frame # to add clicks or mask (it's merged into "output_dict" before propagation starts) inference_state["temp_output_dict_per_obj"] = {} # Frames that already holds consolidated outputs from click or mask inputs # (we directly use their consolidated outputs during tracking) inference_state["consolidated_frame_inds"] = { "cond_frame_outputs": set(), # set containing frame indices "non_cond_frame_outputs": set(), # set containing frame indices } # metadata for each tracking frame (e.g. which direction it's tracked) inference_state["tracking_has_started"] = False inference_state["frames_already_tracked"] = {} # Warm up the visual backbone and cache the image feature on frame 0 self._get_image_feature(inference_state, frame_idx=0, batch_size=1) return inference_state @classmethod def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": """ Load a pretrained model from the Hugging Face hub. Arguments: model_id (str): The Hugging Face repository ID. **kwargs: Additional arguments to pass to the model constructor. Returns: (SAM2VideoPredictor): The loaded model. """ from sam2.build_sam import build_sam2_video_predictor_hf sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) return sam_model def _obj_id_to_idx(self, inference_state, obj_id): """Map client-side object id to model-side object index.""" obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) if obj_idx is not None: return obj_idx # This is a new object id not sent to the server before. We only allow adding # new objects *before* the tracking starts. 
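        # Illustrative note (added commentary, not part of the original source): caller-chosen
        # object ids are mapped to consecutive model-side indices in the order they are first
        # seen, e.g. adding prompts for obj_id=5 and then obj_id=42 yields
        # obj_id_to_idx == {5: 0, 42: 1} and obj_ids == [5, 42].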
allow_new_object = not inference_state["tracking_has_started"] if allow_new_object: # get the next object slot obj_idx = len(inference_state["obj_id_to_idx"]) inference_state["obj_id_to_idx"][obj_id] = obj_idx inference_state["obj_idx_to_id"][obj_idx] = obj_id inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) # set up input and output structures for this object inference_state["point_inputs_per_obj"][obj_idx] = {} inference_state["mask_inputs_per_obj"][obj_idx] = {} inference_state["output_dict_per_obj"][obj_idx] = { "cond_frame_outputs": {}, # dict containing {frame_idx: } "non_cond_frame_outputs": {}, # dict containing {frame_idx: } } inference_state["temp_output_dict_per_obj"][obj_idx] = { "cond_frame_outputs": {}, # dict containing {frame_idx: } "non_cond_frame_outputs": {}, # dict containing {frame_idx: } } return obj_idx else: raise RuntimeError( f"Cannot add new object id {obj_id} after tracking starts. " f"All existing object ids: {inference_state['obj_ids']}. " f"Please call 'reset_state' to restart from scratch." ) def _obj_idx_to_id(self, inference_state, obj_idx): """Map model-side object index to client-side object id.""" return inference_state["obj_idx_to_id"][obj_idx] def _get_obj_num(self, inference_state): """Get the total number of unique object ids received so far in this session.""" return len(inference_state["obj_idx_to_id"]) @torch.inference_mode() def add_new_points_or_box( self, inference_state, frame_idx, obj_id, points=None, labels=None, clear_old_points=True, normalize_coords=True, box=None, ): """Add new points to a frame.""" obj_idx = self._obj_id_to_idx(inference_state, obj_id) point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] if (points is not None) != (labels is not None): raise ValueError("points and labels must be provided together") if points is None and box is None: raise ValueError("at least one of points or box must be provided as input") if points is None: points = torch.zeros(0, 2, dtype=torch.float32) elif not isinstance(points, torch.Tensor): points = torch.tensor(points, dtype=torch.float32) if labels is None: labels = torch.zeros(0, dtype=torch.int32) elif not isinstance(labels, torch.Tensor): labels = torch.tensor(labels, dtype=torch.int32) if points.dim() == 2: points = points.unsqueeze(0) # add batch dimension if labels.dim() == 1: labels = labels.unsqueeze(0) # add batch dimension # If `box` is provided, we add it as the first two points with labels 2 and 3 # along with the user-provided points (consistent with how SAM 2 is trained). if box is not None: if not clear_old_points: raise ValueError( "cannot add box without clearing old points, since " "box prompt must be provided before any point prompt " "(please use clear_old_points=True instead)" ) if inference_state["tracking_has_started"]: warnings.warn( "You are adding a box after tracking starts. SAM 2 may not always be " "able to incorporate a box prompt for *refinement*. 
If you intend to " "use box prompt as an *initial* input before tracking, please call " "'reset_state' on the inference state to restart from scratch.", category=UserWarning, stacklevel=2, ) if not isinstance(box, torch.Tensor): box = torch.tensor(box, dtype=torch.float32, device=points.device) box_coords = box.reshape(1, 2, 2) box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) box_labels = box_labels.reshape(1, 2) points = torch.cat([box_coords, points], dim=1) labels = torch.cat([box_labels, labels], dim=1) if normalize_coords: video_H = inference_state["video_height"] video_W = inference_state["video_width"] points = points / torch.tensor([video_W, video_H]).to(points.device) # scale the (normalized) coordinates by the model's internal image size points = points * self.image_size points = points.to(inference_state["device"]) labels = labels.to(inference_state["device"]) if not clear_old_points: point_inputs = point_inputs_per_frame.get(frame_idx, None) else: point_inputs = None point_inputs = concat_points(point_inputs, points, labels) point_inputs_per_frame[frame_idx] = point_inputs mask_inputs_per_frame.pop(frame_idx, None) # If this frame hasn't been tracked before, we treat it as an initial conditioning # frame, meaning that the inputs points are to generate segments on this frame without # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), # the input points will be used to correct the already tracked masks. is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] # whether to track in reverse time order if is_init_cond_frame: reverse = False else: reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] # Add a frame to conditioning output if it's an initial conditioning frame or # if the model sees all frames receiving clicks/mask as conditioning frames. is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" # Get any previously predicted mask logits on this object and feed it along with # the new clicks into the SAM mask decoder. prev_sam_mask_logits = None # lookup temporary output dict first, which contains the most recent output # (if not found, then lookup conditioning and non-conditioning frame output) prev_out = obj_temp_output_dict[storage_key].get(frame_idx) if prev_out is None: prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) if prev_out is None: prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) if prev_out is not None and prev_out["pred_masks"] is not None: device = inference_state["device"] prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) current_out, _ = self._run_single_frame_inference( inference_state=inference_state, output_dict=obj_output_dict, # run on the slice of a single object frame_idx=frame_idx, batch_size=1, # run on the slice of a single object is_init_cond_frame=is_init_cond_frame, point_inputs=point_inputs, mask_inputs=None, reverse=reverse, # Skip the memory encoder when adding clicks or mask. We execute the memory encoder # at the beginning of `propagate_in_video` (after user finalize their clicks). 
This # allows us to enforce non-overlapping constraints on all objects before encoding # them into memory. run_mem_encoder=False, prev_sam_mask_logits=prev_sam_mask_logits, ) # Add the output to the output dict (to be used as future memory) obj_temp_output_dict[storage_key][frame_idx] = current_out # Resize the output mask to the original video resolution obj_ids = inference_state["obj_ids"] consolidated_out = self._consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( inference_state, consolidated_out["pred_masks_video_res"] ) return frame_idx, obj_ids, video_res_masks def add_new_points(self, *args, **kwargs): """Deprecated method. Please use `add_new_points_or_box` instead.""" return self.add_new_points_or_box(*args, **kwargs) @torch.inference_mode() def add_new_mask( self, inference_state, frame_idx, obj_id, mask, ): """Add new mask to a frame.""" obj_idx = self._obj_id_to_idx(inference_state, obj_id) point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] if not isinstance(mask, torch.Tensor): mask = torch.tensor(mask, dtype=torch.bool) assert mask.dim() == 2 mask_H, mask_W = mask.shape mask_inputs_orig = mask[None, None] # add batch and channel dimension mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) # resize the mask if it doesn't match the model's image size if mask_H != self.image_size or mask_W != self.image_size: mask_inputs = torch.nn.functional.interpolate( mask_inputs_orig, size=(self.image_size, self.image_size), align_corners=False, mode="bilinear", antialias=True, # use antialias for downsampling ) mask_inputs = (mask_inputs >= 0.5).float() else: mask_inputs = mask_inputs_orig mask_inputs_per_frame[frame_idx] = mask_inputs point_inputs_per_frame.pop(frame_idx, None) # If this frame hasn't been tracked before, we treat it as an initial conditioning # frame, meaning that the inputs points are to generate segments on this frame without # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), # the input points will be used to correct the already tracked masks. is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] # whether to track in reverse time order if is_init_cond_frame: reverse = False else: reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] # Add a frame to conditioning output if it's an initial conditioning frame or # if the model sees all frames receiving clicks/mask as conditioning frames. is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" current_out, _ = self._run_single_frame_inference( inference_state=inference_state, output_dict=obj_output_dict, # run on the slice of a single object frame_idx=frame_idx, batch_size=1, # run on the slice of a single object is_init_cond_frame=is_init_cond_frame, point_inputs=None, mask_inputs=mask_inputs, reverse=reverse, # Skip the memory encoder when adding clicks or mask. We execute the memory encoder # at the beginning of `propagate_in_video` (after user finalize their clicks). 
This # allows us to enforce non-overlapping constraints on all objects before encoding # them into memory. run_mem_encoder=False, ) # Add the output to the output dict (to be used as future memory) obj_temp_output_dict[storage_key][frame_idx] = current_out # Resize the output mask to the original video resolution obj_ids = inference_state["obj_ids"] consolidated_out = self._consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( inference_state, consolidated_out["pred_masks_video_res"] ) return frame_idx, obj_ids, video_res_masks def _get_orig_video_res_output(self, inference_state, any_res_masks): """ Resize the object scores to the original video resolution (video_res_masks) and apply non-overlapping constraints for final output. """ device = inference_state["device"] video_H = inference_state["video_height"] video_W = inference_state["video_width"] any_res_masks = any_res_masks.to(device, non_blocking=True) if any_res_masks.shape[-2:] == (video_H, video_W): video_res_masks = any_res_masks else: video_res_masks = torch.nn.functional.interpolate( any_res_masks, size=(video_H, video_W), mode="bilinear", align_corners=False, ) if self.non_overlap_masks: video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) return any_res_masks, video_res_masks def _consolidate_temp_output_across_obj( self, inference_state, frame_idx, is_cond, run_mem_encoder, consolidate_at_video_res=False, ): """ Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on a frame into a single output for all objects, including 1) fill any missing objects either from `output_dict_per_obj` (if they exist in `output_dict_per_obj` for this frame) or leave them as placeholder values (if they don't exist in `output_dict_per_obj` for this frame); 2) if specified, rerun memory encoder after apply non-overlapping constraints on the object scores. """ batch_size = self._get_obj_num(inference_state) storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" # Optionally, we allow consolidating the temporary outputs at the original # video resolution (to provide a better editing experience for mask prompts). if consolidate_at_video_res: assert not run_mem_encoder, "memory encoder cannot run at video resolution" consolidated_H = inference_state["video_height"] consolidated_W = inference_state["video_width"] consolidated_mask_key = "pred_masks_video_res" else: consolidated_H = consolidated_W = self.image_size // 4 consolidated_mask_key = "pred_masks" # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" # will be added when rerunning the memory encoder after applying non-overlapping # constraints to object scores. Its "pred_masks" are prefilled with a large # negative value (NO_OBJ_SCORE) to represent missing objects. consolidated_out = { "maskmem_features": None, "maskmem_pos_enc": None, consolidated_mask_key: torch.full( size=(batch_size, 1, consolidated_H, consolidated_W), fill_value=NO_OBJ_SCORE, dtype=torch.float32, device=inference_state["storage_device"], ), "obj_ptr": torch.full( size=(batch_size, self.hidden_dim), fill_value=NO_OBJ_SCORE, dtype=torch.float32, device=inference_state["device"], ), "object_score_logits": torch.full( size=(batch_size, 1), # default to 10.0 for object_score_logits, i.e. 
assuming the object is # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` fill_value=10.0, dtype=torch.float32, device=inference_state["device"], ), } empty_mask_ptr = None for obj_idx in range(batch_size): obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] out = obj_temp_output_dict[storage_key].get(frame_idx, None) # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, # we fall back and look up its previous output in "output_dict_per_obj". # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in # "output_dict_per_obj" to find a previous output for this object. if out is None: out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) if out is None: out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) # If the object doesn't appear in "output_dict_per_obj" either, we skip it # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE # placeholder above) and set its object pointer to be a dummy pointer. if out is None: # Fill in dummy object pointers for those objects without any inputs or # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, # i.e. when we need to build the memory for tracking). if run_mem_encoder: if empty_mask_ptr is None: empty_mask_ptr = self._get_empty_mask_ptr( inference_state, frame_idx ) # fill object pointer with a dummy pointer (based on an empty mask) consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr continue # Add the temporary object output mask to consolidated output mask obj_mask = out["pred_masks"] consolidated_pred_masks = consolidated_out[consolidated_mask_key] if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask else: # Resize first if temporary object mask has a different resolution resized_obj_mask = torch.nn.functional.interpolate( obj_mask, size=consolidated_pred_masks.shape[-2:], mode="bilinear", align_corners=False, ) consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[ "object_score_logits" ] # Optionally, apply non-overlapping constraints on the consolidated scores # and rerun the memory encoder if run_mem_encoder: device = inference_state["device"] high_res_masks = torch.nn.functional.interpolate( consolidated_out["pred_masks"].to(device, non_blocking=True), size=(self.image_size, self.image_size), mode="bilinear", align_corners=False, ) if self.non_overlap_masks_for_mem_enc: high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) maskmem_features, maskmem_pos_enc = self._run_memory_encoder( inference_state=inference_state, frame_idx=frame_idx, batch_size=batch_size, high_res_masks=high_res_masks, object_score_logits=consolidated_out["object_score_logits"], is_mask_from_pts=True, # these frames are what the user interacted with ) consolidated_out["maskmem_features"] = maskmem_features consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc return consolidated_out def _get_empty_mask_ptr(self, inference_state, frame_idx): """Get a dummy object pointer based on an empty mask on the current frame.""" # A dummy (empty) mask with a single object batch_size = 1 mask_inputs = torch.zeros( (batch_size, 1, self.image_size, self.image_size), dtype=torch.float32, device=inference_state["device"], ) # 
Retrieve correct image features ( _, _, current_vision_feats, current_vision_pos_embeds, feat_sizes, ) = self._get_image_feature(inference_state, frame_idx, batch_size) # Feed the empty mask and image feature above to get a dummy object pointer current_out = self.track_step( frame_idx=frame_idx, is_init_cond_frame=True, current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds, feat_sizes=feat_sizes, point_inputs=None, mask_inputs=mask_inputs, output_dict={}, num_frames=inference_state["num_frames"], track_in_reverse=False, run_mem_encoder=False, prev_sam_mask_logits=None, ) return current_out["obj_ptr"] @torch.inference_mode() def propagate_in_video_preflight(self, inference_state): """Prepare inference_state and consolidate temporary outputs before tracking.""" # Tracking has started and we don't allow adding new objects until session is reset. inference_state["tracking_has_started"] = True batch_size = self._get_obj_num(inference_state) # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and # add them into "output_dict". temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] output_dict = inference_state["output_dict"] # "consolidated_frame_inds" contains indices of those frames where consolidated # temporary outputs have been added (either in this call or any previous calls # to `propagate_in_video_preflight`). consolidated_frame_inds = inference_state["consolidated_frame_inds"] for is_cond in [False, True]: # Separately consolidate conditioning and non-conditioning temp outputs storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" # Find all the frames that contain temporary outputs for any objects # (these should be the frames that have just received clicks for mask inputs # via `add_new_points_or_box` or `add_new_mask`) temp_frame_inds = set() for obj_temp_output_dict in temp_output_dict_per_obj.values(): temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) consolidated_frame_inds[storage_key].update(temp_frame_inds) # consolidate the temporary output across all objects on this frame for frame_idx in temp_frame_inds: consolidated_out = self._consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True ) # merge them into "output_dict" and also create per-object slices output_dict[storage_key][frame_idx] = consolidated_out self._add_output_per_object( inference_state, frame_idx, consolidated_out, storage_key ) clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 ) if clear_non_cond_mem: # clear non-conditioning memory of the surrounding frames self._clear_non_cond_mem_around_input(inference_state, frame_idx) # clear temporary outputs in `temp_output_dict_per_obj` for obj_temp_output_dict in temp_output_dict_per_obj.values(): obj_temp_output_dict[storage_key].clear() # edge case: if an output is added to "cond_frame_outputs", we remove any prior # output on the same frame in "non_cond_frame_outputs" for frame_idx in output_dict["cond_frame_outputs"]: output_dict["non_cond_frame_outputs"].pop(frame_idx, None) for obj_output_dict in inference_state["output_dict_per_obj"].values(): for frame_idx in obj_output_dict["cond_frame_outputs"]: obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: assert frame_idx in output_dict["cond_frame_outputs"] 
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames # with either points or mask inputs (which should be true under a correct workflow). all_consolidated_frame_inds = ( consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"] ) input_frames_inds = set() for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): input_frames_inds.update(point_inputs_per_frame.keys()) for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): input_frames_inds.update(mask_inputs_per_frame.keys()) assert all_consolidated_frame_inds == input_frames_inds @torch.inference_mode() def propagate_in_video( self, inference_state, start_frame_idx=None, max_frame_num_to_track=None, reverse=False, ): """Propagate the input points across frames to track in the entire video.""" self.propagate_in_video_preflight(inference_state) output_dict = inference_state["output_dict"] consolidated_frame_inds = inference_state["consolidated_frame_inds"] obj_ids = inference_state["obj_ids"] num_frames = inference_state["num_frames"] batch_size = self._get_obj_num(inference_state) if len(output_dict["cond_frame_outputs"]) == 0: raise RuntimeError("No points are provided; please add points first") clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 ) # set start index, end index, and processing order if start_frame_idx is None: # default: start from the earliest frame with input points start_frame_idx = min(output_dict["cond_frame_outputs"]) if max_frame_num_to_track is None: # default: track all the frames in the video max_frame_num_to_track = num_frames if reverse: end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) if start_frame_idx > 0: processing_order = range(start_frame_idx, end_frame_idx - 1, -1) else: processing_order = [] # skip reverse tracking if starting from frame 0 else: end_frame_idx = min( start_frame_idx + max_frame_num_to_track, num_frames - 1 ) processing_order = range(start_frame_idx, end_frame_idx + 1) for frame_idx in tqdm(processing_order, desc="propagate in video"): # We skip those frames already in consolidated outputs (these are frames # that received input clicks or mask). Note that we cannot directly run # batched forward on them via `_run_single_frame_inference` because the # number of clicks on each object might be different. if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: storage_key = "cond_frame_outputs" current_out = output_dict[storage_key][frame_idx] pred_masks = current_out["pred_masks"] if clear_non_cond_mem: # clear non-conditioning memory of the surrounding frames self._clear_non_cond_mem_around_input(inference_state, frame_idx) elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: storage_key = "non_cond_frame_outputs" current_out = output_dict[storage_key][frame_idx] pred_masks = current_out["pred_masks"] else: storage_key = "non_cond_frame_outputs" current_out, pred_masks = self._run_single_frame_inference( inference_state=inference_state, output_dict=output_dict, frame_idx=frame_idx, batch_size=batch_size, is_init_cond_frame=False, point_inputs=None, mask_inputs=None, reverse=reverse, run_mem_encoder=True, ) output_dict[storage_key][frame_idx] = current_out # Create slices of per-object outputs for subsequent interaction with each # individual object after tracking. 
self._add_output_per_object( inference_state, frame_idx, current_out, storage_key ) inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} # Resize the output mask to the original video resolution (we directly use # the mask scores on GPU for output to avoid any CPU conversion in between) _, video_res_masks = self._get_orig_video_res_output( inference_state, pred_masks ) yield frame_idx, obj_ids, video_res_masks def _add_output_per_object( self, inference_state, frame_idx, current_out, storage_key ): """ Split a multi-object output into per-object output slices and add them into `output_dict_per_obj`. The resulting slices share the same tensor storage. """ maskmem_features = current_out["maskmem_features"] assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) maskmem_pos_enc = current_out["maskmem_pos_enc"] assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) output_dict_per_obj = inference_state["output_dict_per_obj"] for obj_idx, obj_output_dict in output_dict_per_obj.items(): obj_slice = slice(obj_idx, obj_idx + 1) obj_out = { "maskmem_features": None, "maskmem_pos_enc": None, "pred_masks": current_out["pred_masks"][obj_slice], "obj_ptr": current_out["obj_ptr"][obj_slice], "object_score_logits": current_out["object_score_logits"][obj_slice], } if maskmem_features is not None: obj_out["maskmem_features"] = maskmem_features[obj_slice] if maskmem_pos_enc is not None: obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] obj_output_dict[storage_key][frame_idx] = obj_out @torch.inference_mode() def clear_all_prompts_in_frame( self, inference_state, frame_idx, obj_id, need_output=True ): """Remove all input points or mask in a specific frame for a given object.""" obj_idx = self._obj_id_to_idx(inference_state, obj_id) # Clear the conditioning information on the given frame inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) # Check and see if there are still any inputs left on this frame batch_size = self._get_obj_num(inference_state) frame_has_input = False for obj_idx2 in range(batch_size): if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: frame_has_input = True break if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]: frame_has_input = True break # If this frame has no remaining inputs for any objects, we further clear its # conditioning frame status if not frame_has_input: output_dict = inference_state["output_dict"] consolidated_frame_inds = inference_state["consolidated_frame_inds"] consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) out = output_dict["cond_frame_outputs"].pop(frame_idx, None) if out is not None: # The frame is not a conditioning frame anymore since it's not receiving inputs, # so we "downgrade" its output (if exists) to a non-conditioning frame output. output_dict["non_cond_frame_outputs"][frame_idx] = out inference_state["frames_already_tracked"].pop(frame_idx, None) # Similarly, do it for the sliced output on each object. 
for obj_idx2 in range(batch_size): obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) if obj_out is not None: obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out # If all the conditioning frames have been removed, we also clear the tracking outputs if len(output_dict["cond_frame_outputs"]) == 0: self._reset_tracking_results(inference_state) if not need_output: return # Finally, output updated masks per object (after removing the inputs above) obj_ids = inference_state["obj_ids"] is_cond = any( frame_idx in obj_temp_output_dict["cond_frame_outputs"] for obj_temp_output_dict in temp_output_dict_per_obj.values() ) consolidated_out = self._consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( inference_state, consolidated_out["pred_masks_video_res"] ) return frame_idx, obj_ids, video_res_masks @torch.inference_mode() def reset_state(self, inference_state): """Remove all input points or mask in all frames throughout the video.""" self._reset_tracking_results(inference_state) # Remove all object ids inference_state["obj_id_to_idx"].clear() inference_state["obj_idx_to_id"].clear() inference_state["obj_ids"].clear() inference_state["point_inputs_per_obj"].clear() inference_state["mask_inputs_per_obj"].clear() inference_state["output_dict_per_obj"].clear() inference_state["temp_output_dict_per_obj"].clear() def _reset_tracking_results(self, inference_state): """Reset all tracking inputs and results across the videos.""" for v in inference_state["point_inputs_per_obj"].values(): v.clear() for v in inference_state["mask_inputs_per_obj"].values(): v.clear() for v in inference_state["output_dict_per_obj"].values(): v["cond_frame_outputs"].clear() v["non_cond_frame_outputs"].clear() for v in inference_state["temp_output_dict_per_obj"].values(): v["cond_frame_outputs"].clear() v["non_cond_frame_outputs"].clear() inference_state["output_dict"]["cond_frame_outputs"].clear() inference_state["output_dict"]["non_cond_frame_outputs"].clear() inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() inference_state["tracking_has_started"] = False inference_state["frames_already_tracked"].clear() def _get_image_feature(self, inference_state, frame_idx, batch_size): """Compute the image features on a given frame.""" # Look up in the cache first image, backbone_out = inference_state["cached_features"].get( frame_idx, (None, None) ) if backbone_out is None: # Cache miss -- we will run inference on a single image device = inference_state["device"] image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) backbone_out = self.forward_image(image) # Cache the most recent frame's feature (for repeated interactions with # a frame; we can use an LRU cache for more frames in the future). 
inference_state["cached_features"] = {frame_idx: (image, backbone_out)} # expand the features to have the same dimension as the number of objects expanded_image = image.expand(batch_size, -1, -1, -1) expanded_backbone_out = { "backbone_fpn": backbone_out["backbone_fpn"].copy(), "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), } for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): expanded_backbone_out["backbone_fpn"][i] = feat.expand( batch_size, -1, -1, -1 ) for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): pos = pos.expand(batch_size, -1, -1, -1) expanded_backbone_out["vision_pos_enc"][i] = pos features = self._prepare_backbone_features(expanded_backbone_out) features = (expanded_image,) + features return features def _run_single_frame_inference( self, inference_state, output_dict, frame_idx, batch_size, is_init_cond_frame, point_inputs, mask_inputs, reverse, run_mem_encoder, prev_sam_mask_logits=None, ): """Run tracking on a single frame based on current inputs and previous memory.""" # Retrieve correct image features ( _, _, current_vision_feats, current_vision_pos_embeds, feat_sizes, ) = self._get_image_feature(inference_state, frame_idx, batch_size) # point and mask should not appear as input simultaneously on the same frame assert point_inputs is None or mask_inputs is None current_out = self.track_step( frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds, feat_sizes=feat_sizes, point_inputs=point_inputs, mask_inputs=mask_inputs, output_dict=output_dict, num_frames=inference_state["num_frames"], track_in_reverse=reverse, run_mem_encoder=run_mem_encoder, prev_sam_mask_logits=prev_sam_mask_logits, ) # optionally offload the output to CPU memory to save GPU space storage_device = inference_state["storage_device"] maskmem_features = current_out["maskmem_features"] if maskmem_features is not None: maskmem_features = maskmem_features.to(torch.bfloat16) maskmem_features = maskmem_features.to(storage_device, non_blocking=True) pred_masks_gpu = current_out["pred_masks"] # potentially fill holes in the predicted masks if self.fill_hole_area > 0: pred_masks_gpu = fill_holes_in_mask_scores( pred_masks_gpu, self.fill_hole_area ) pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) # object pointer is a small tensor, so we always keep it on GPU memory for fast access obj_ptr = current_out["obj_ptr"] object_score_logits = current_out["object_score_logits"] # make a compact version of this frame's output to reduce the state size compact_current_out = { "maskmem_features": maskmem_features, "maskmem_pos_enc": maskmem_pos_enc, "pred_masks": pred_masks, "obj_ptr": obj_ptr, "object_score_logits": object_score_logits, } return compact_current_out, pred_masks_gpu def _run_memory_encoder( self, inference_state, frame_idx, batch_size, high_res_masks, object_score_logits, is_mask_from_pts, ): """ Run the memory encoder on `high_res_masks`. This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their memory also need to be computed again with the memory encoder. 
""" # Retrieve correct image features _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( inference_state, frame_idx, batch_size ) maskmem_features, maskmem_pos_enc = self._encode_new_memory( current_vision_feats=current_vision_feats, feat_sizes=feat_sizes, pred_masks_high_res=high_res_masks, object_score_logits=object_score_logits, is_mask_from_pts=is_mask_from_pts, ) # optionally offload the output to CPU memory to save GPU space storage_device = inference_state["storage_device"] maskmem_features = maskmem_features.to(torch.bfloat16) maskmem_features = maskmem_features.to(storage_device, non_blocking=True) # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it maskmem_pos_enc = self._get_maskmem_pos_enc( inference_state, {"maskmem_pos_enc": maskmem_pos_enc} ) return maskmem_features, maskmem_pos_enc def _get_maskmem_pos_enc(self, inference_state, current_out): """ `maskmem_pos_enc` is the same across frames and objects, so we cache it as a constant in the inference session to reduce session storage size. """ model_constants = inference_state["constants"] # "out_maskmem_pos_enc" should be either a list of tensors or None out_maskmem_pos_enc = current_out["maskmem_pos_enc"] if out_maskmem_pos_enc is not None: if "maskmem_pos_enc" not in model_constants: assert isinstance(out_maskmem_pos_enc, list) # only take the slice for one object, since it's same across objects maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] model_constants["maskmem_pos_enc"] = maskmem_pos_enc else: maskmem_pos_enc = model_constants["maskmem_pos_enc"] # expand the cached maskmem_pos_enc to the actual batch size batch_size = out_maskmem_pos_enc[0].size(0) expanded_maskmem_pos_enc = [ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc ] else: expanded_maskmem_pos_enc = None return expanded_maskmem_pos_enc @torch.inference_mode() def remove_object(self, inference_state, obj_id, strict=False, need_output=True): """ Remove an object id from the tracking state. If strict is True, we check whether the object id actually exists and raise an error if it doesn't exist. """ old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None) updated_frames = [] # Check whether this object_id to remove actually exists and possibly raise an error. if old_obj_idx_to_rm is None: if not strict: return inference_state["obj_ids"], updated_frames raise RuntimeError( f"Cannot remove object id {obj_id} as it doesn't exist. " f"All existing object ids: {inference_state['obj_ids']}." ) # If this is the only remaining object id, we simply reset the state. if len(inference_state["obj_id_to_idx"]) == 1: self.reset_state(inference_state) return inference_state["obj_ids"], updated_frames # There are still remaining objects after removing this object id. In this case, # we need to delete the object storage from inference state tensors. 
# Step 0: clear the input on those frames where this object id has point or mask input # (note that this step is required as it might downgrade conditioning frames to # non-conditioning ones) obj_input_frames_inds = set() obj_input_frames_inds.update( inference_state["point_inputs_per_obj"][old_obj_idx_to_rm] ) obj_input_frames_inds.update( inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm] ) for frame_idx in obj_input_frames_inds: self.clear_all_prompts_in_frame( inference_state, frame_idx, obj_id, need_output=False ) # Step 1: Update the object id mapping (note that it must be done after Step 0, # since Step 0 still requires the old object id mappings in inference_state) old_obj_ids = inference_state["obj_ids"] old_obj_inds = list(range(len(old_obj_ids))) remain_old_obj_inds = old_obj_inds.copy() remain_old_obj_inds.remove(old_obj_idx_to_rm) new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] new_obj_inds = list(range(len(new_obj_ids))) # build new mappings old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) inference_state["obj_ids"] = new_obj_ids # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. # (note that "consolidated_frame_inds" doesn't need to be updated in this step as # it's already handled in Step 0) def _map_keys(container): new_kvs = [] for k in old_obj_inds: v = container.pop(k) if k in old_idx_to_new_idx: new_kvs.append((old_idx_to_new_idx[k], v)) container.update(new_kvs) _map_keys(inference_state["point_inputs_per_obj"]) _map_keys(inference_state["mask_inputs_per_obj"]) _map_keys(inference_state["output_dict_per_obj"]) _map_keys(inference_state["temp_output_dict_per_obj"]) # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices. 
def _slice_state(output_dict, storage_key): for frame_idx, out in output_dict[storage_key].items(): out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds] out["maskmem_pos_enc"] = [ x[remain_old_obj_inds] for x in out["maskmem_pos_enc"] ] # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out) out["pred_masks"] = out["pred_masks"][remain_old_obj_inds] out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds] out["object_score_logits"] = out["object_score_logits"][ remain_old_obj_inds ] # also update the per-object slices self._add_output_per_object( inference_state, frame_idx, out, storage_key ) _slice_state(inference_state["output_dict"], "cond_frame_outputs") _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which # could show an updated mask for objects previously occluded by the object being removed if need_output: temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] for frame_idx in obj_input_frames_inds: is_cond = any( frame_idx in obj_temp_output_dict["cond_frame_outputs"] for obj_temp_output_dict in temp_output_dict_per_obj.values() ) consolidated_out = self._consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=False, consolidate_at_video_res=True, ) _, video_res_masks = self._get_orig_video_res_output( inference_state, consolidated_out["pred_masks_video_res"] ) updated_frames.append((frame_idx, video_res_masks)) return inference_state["obj_ids"], updated_frames def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): """ Remove the non-conditioning memory around the input frame. When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated object appearance information and could confuse the model. This method clears those non-conditioning memories surrounding the interacted frame to avoid giving the model both old and new information about the object. """ r = self.memory_temporal_stride_for_eval frame_idx_begin = frame_idx - r * self.num_maskmem frame_idx_end = frame_idx + r * self.num_maskmem output_dict = inference_state["output_dict"] non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] for t in range(frame_idx_begin, frame_idx_end + 1): non_cond_frame_outputs.pop(t, None) for obj_output_dict in inference_state["output_dict_per_obj"].values(): obj_output_dict["non_cond_frame_outputs"].pop(t, None) ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/utils/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/utils/amg.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import math from copy import deepcopy from itertools import product from typing import Any, Dict, Generator, ItemsView, List, Tuple import numpy as np import torch # Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py class MaskData: """ A structure for storing masks and their related data in batched format. Implements basic filtering and concatenation. """ def __init__(self, **kwargs) -> None: for v in kwargs.values(): assert isinstance( v, (list, np.ndarray, torch.Tensor) ), "MaskData only supports list, numpy arrays, and torch tensors." self._stats = dict(**kwargs) def __setitem__(self, key: str, item: Any) -> None: assert isinstance( item, (list, np.ndarray, torch.Tensor) ), "MaskData only supports list, numpy arrays, and torch tensors." self._stats[key] = item def __delitem__(self, key: str) -> None: del self._stats[key] def __getitem__(self, key: str) -> Any: return self._stats[key] def items(self) -> ItemsView[str, Any]: return self._stats.items() def filter(self, keep: torch.Tensor) -> None: for k, v in self._stats.items(): if v is None: self._stats[k] = None elif isinstance(v, torch.Tensor): self._stats[k] = v[torch.as_tensor(keep, device=v.device)] elif isinstance(v, np.ndarray): self._stats[k] = v[keep.detach().cpu().numpy()] elif isinstance(v, list) and keep.dtype == torch.bool: self._stats[k] = [a for i, a in enumerate(v) if keep[i]] elif isinstance(v, list): self._stats[k] = [v[i] for i in keep] else: raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") def cat(self, new_stats: "MaskData") -> None: for k, v in new_stats.items(): if k not in self._stats or self._stats[k] is None: self._stats[k] = deepcopy(v) elif isinstance(v, torch.Tensor): self._stats[k] = torch.cat([self._stats[k], v], dim=0) elif isinstance(v, np.ndarray): self._stats[k] = np.concatenate([self._stats[k], v], axis=0) elif isinstance(v, list): self._stats[k] = self._stats[k] + deepcopy(v) else: raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") def to_numpy(self) -> None: for k, v in self._stats.items(): if isinstance(v, torch.Tensor): self._stats[k] = v.float().detach().cpu().numpy() def is_box_near_crop_edge( boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 ) -> torch.Tensor: """Filter masks at the edge of a crop, but not at the edge of the original image.""" crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) boxes = uncrop_boxes_xyxy(boxes, crop_box).float() near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) return torch.any(near_crop_edge, dim=1) def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: box_xywh = deepcopy(box_xyxy) box_xywh[2] = box_xywh[2] - box_xywh[0] box_xywh[3] = box_xywh[3] - box_xywh[1] return box_xywh def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: assert len(args) > 0 and all( len(a) == len(args[0]) for a in args ), "Batched iteration must have inputs of all the same size." 
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) for b in range(n_batches): yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: """ Encodes masks to an uncompressed RLE, in the format expected by pycoco tools. """ # Put in fortran order and flatten h,w b, h, w = tensor.shape tensor = tensor.permute(0, 2, 1).flatten(1) # Compute change indices diff = tensor[:, 1:] ^ tensor[:, :-1] change_indices = diff.nonzero() # Encode run length out = [] for i in range(b): cur_idxs = change_indices[change_indices[:, 0] == i, 1] cur_idxs = torch.cat( [ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), cur_idxs + 1, torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), ] ) btw_idxs = cur_idxs[1:] - cur_idxs[:-1] counts = [] if tensor[i, 0] == 0 else [0] counts.extend(btw_idxs.detach().cpu().tolist()) out.append({"size": [h, w], "counts": counts}) return out def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: """Compute a binary mask from an uncompressed RLE.""" h, w = rle["size"] mask = np.empty(h * w, dtype=bool) idx = 0 parity = False for count in rle["counts"]: mask[idx : idx + count] = parity idx += count parity ^= True mask = mask.reshape(w, h) return mask.transpose() # Put in C order def area_from_rle(rle: Dict[str, Any]) -> int: return sum(rle["counts"][1::2]) def calculate_stability_score( masks: torch.Tensor, mask_threshold: float, threshold_offset: float ) -> torch.Tensor: """ Computes the stability score for a batch of masks. The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high and low values. """ # One mask is always contained inside the other. # Save memory by preventing unnecessary cast to torch.int64 intersections = ( (masks > (mask_threshold + threshold_offset)) .sum(-1, dtype=torch.int16) .sum(-1, dtype=torch.int32) ) unions = ( (masks > (mask_threshold - threshold_offset)) .sum(-1, dtype=torch.int16) .sum(-1, dtype=torch.int32) ) return intersections / unions def build_point_grid(n_per_side: int) -> np.ndarray: """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" offset = 1 / (2 * n_per_side) points_one_side = np.linspace(offset, 1 - offset, n_per_side) points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) points_y = np.tile(points_one_side[:, None], (1, n_per_side)) points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) return points def build_all_layer_point_grids( n_per_side: int, n_layers: int, scale_per_layer: int ) -> List[np.ndarray]: """Generates point grids for all crop layers.""" points_by_layer = [] for i in range(n_layers + 1): n_points = int(n_per_side / (scale_per_layer**i)) points_by_layer.append(build_point_grid(n_points)) return points_by_layer def generate_crop_boxes( im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float ) -> Tuple[List[List[int]], List[int]]: """ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. 
""" crop_boxes, layer_idxs = [], [] im_h, im_w = im_size short_side = min(im_h, im_w) # Original image crop_boxes.append([0, 0, im_w, im_h]) layer_idxs.append(0) def crop_len(orig_len, n_crops, overlap): return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) for i_layer in range(n_layers): n_crops_per_side = 2 ** (i_layer + 1) overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) crop_w = crop_len(im_w, n_crops_per_side, overlap) crop_h = crop_len(im_h, n_crops_per_side, overlap) crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] # Crops in XYWH format for x0, y0 in product(crop_box_x0, crop_box_y0): box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] crop_boxes.append(box) layer_idxs.append(i_layer + 1) return crop_boxes, layer_idxs def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: x0, y0, _, _ = crop_box offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) # Check if boxes has a channel dimension if len(boxes.shape) == 3: offset = offset.unsqueeze(1) return boxes + offset def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: x0, y0, _, _ = crop_box offset = torch.tensor([[x0, y0]], device=points.device) # Check if points has a channel dimension if len(points.shape) == 3: offset = offset.unsqueeze(1) return points + offset def uncrop_masks( masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int ) -> torch.Tensor: x0, y0, x1, y1 = crop_box if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: return masks # Coordinate transform masks pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) pad = (x0, pad_x - x0, y0, pad_y - y0) return torch.nn.functional.pad(masks, pad, value=0) def remove_small_regions( mask: np.ndarray, area_thresh: float, mode: str ) -> Tuple[np.ndarray, bool]: """ Removes small disconnected regions and holes in a mask. Returns the mask and an indicator of if the mask has been modified. """ import cv2 # type: ignore assert mode in ["holes", "islands"] correct_holes = mode == "holes" working_mask = (correct_holes ^ mask).astype(np.uint8) n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) sizes = stats[:, -1][1:] # Row 0 is background label small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] if len(small_regions) == 0: return mask, False fill_labels = [0] + small_regions if not correct_holes: fill_labels = [i for i in range(n_labels) if i not in fill_labels] # If every region is below threshold, keep largest if len(fill_labels) == 0: fill_labels = [int(np.argmax(sizes)) + 1] mask = np.isin(regions, fill_labels) return mask, True def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: from pycocotools import mask as mask_utils # type: ignore h, w = uncompressed_rle["size"] rle = mask_utils.frPyObjects(uncompressed_rle, h, w) rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json return rle def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: """ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 
""" # torch.max below raises an error on empty inputs, just skip in this case if torch.numel(masks) == 0: return torch.zeros(*masks.shape[:-2], 4, device=masks.device) # Normalize shape to CxHxW shape = masks.shape h, w = shape[-2:] if len(shape) > 2: masks = masks.flatten(0, -3) else: masks = masks.unsqueeze(0) # Get top and bottom edges in_height, _ = torch.max(masks, dim=-1) in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] bottom_edges, _ = torch.max(in_height_coords, dim=-1) in_height_coords = in_height_coords + h * (~in_height) top_edges, _ = torch.min(in_height_coords, dim=-1) # Get left and right edges in_width, _ = torch.max(masks, dim=-2) in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] right_edges, _ = torch.max(in_width_coords, dim=-1) in_width_coords = in_width_coords + w * (~in_width) left_edges, _ = torch.min(in_width_coords, dim=-1) # If the mask is empty the right edge will be to the left of the left edge. # Replace these boxes with [0, 0, 0, 0] empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) out = out * (~empty_filter).unsqueeze(-1) # Return to original shape if len(shape) > 2: out = out.reshape(*shape[:-2], 4) else: out = out[0] return out ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/utils/misc.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import os import warnings from threading import Thread import numpy as np import torch from PIL import Image from tqdm import tqdm def get_sdpa_settings(): if torch.cuda.is_available(): old_gpu = torch.cuda.get_device_properties(0).major < 7 # only use Flash Attention on Ampere (8.0) or newer GPUs use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 if not use_flash_attn: warnings.warn( "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", category=UserWarning, stacklevel=2, ) # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) if pytorch_version < (2, 2): warnings.warn( f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", category=UserWarning, stacklevel=2, ) math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn else: old_gpu = True use_flash_attn = False math_kernel_on = True return old_gpu, use_flash_attn, math_kernel_on def get_connected_components(mask): """ Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). Inputs: - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is background. Outputs: - labels: A tensor of shape (N, 1, H, W) containing the connected component labels for foreground pixels and 0 for background pixels. - counts: A tensor of shape (N, 1, H, W) containing the area of the connected components for foreground pixels and 0 for background pixels. 
""" from sam2 import _C return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) def mask_to_box(masks: torch.Tensor): """ compute bounding box given an input mask Inputs: - masks: [B, 1, H, W] masks, dtype=torch.Tensor Returns: - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor """ B, _, h, w = masks.shape device = masks.device xs = torch.arange(w, device=device, dtype=torch.int32) ys = torch.arange(h, device=device, dtype=torch.int32) grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) return bbox_coords def _load_img_as_tensor(img_path, image_size): img_pil = Image.open(img_path) img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images img_np = img_np / 255.0 else: raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") img = torch.from_numpy(img_np).permute(2, 0, 1) video_width, video_height = img_pil.size # the original video size return img, video_height, video_width class AsyncVideoFrameLoader: """ A list of video frames to be load asynchronously without blocking session start. """ def __init__( self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std, compute_device, ): self.img_paths = img_paths self.image_size = image_size self.offload_video_to_cpu = offload_video_to_cpu self.img_mean = img_mean self.img_std = img_std # items in `self.images` will be loaded asynchronously self.images = [None] * len(img_paths) # catch and raise any exceptions in the async loading thread self.exception = None # video_height and video_width be filled when loading the first image self.video_height = None self.video_width = None self.compute_device = compute_device # load the first frame to fill video_height and video_width and also # to cache it (since it's most likely where the user will click) self.__getitem__(0) # load the rest of frames asynchronously without blocking the session start def _load_frames(): try: for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): self.__getitem__(n) except Exception as e: self.exception = e self.thread = Thread(target=_load_frames, daemon=True) self.thread.start() def __getitem__(self, index): if self.exception is not None: raise RuntimeError("Failure in frame loading thread") from self.exception img = self.images[index] if img is not None: return img img, video_height, video_width = _load_img_as_tensor( self.img_paths[index], self.image_size ) self.video_height = video_height self.video_width = video_width # normalize by mean and std img -= self.img_mean img /= self.img_std if not self.offload_video_to_cpu: img = img.to(self.compute_device, non_blocking=True) self.images[index] = img return img def __len__(self): return len(self.images) def load_video_frames( video_path, image_size, offload_video_to_cpu, img_mean=(0.485, 0.456, 0.406), img_std=(0.229, 0.224, 0.225), async_loading_frames=False, compute_device=torch.device("cuda"), ): """ Load the video frames from video_path. 
The frames are resized to image_size as in the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. """ is_bytes = isinstance(video_path, bytes) is_str = isinstance(video_path, str) is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] if is_bytes or is_mp4_path: return load_video_frames_from_video_file( video_path=video_path, image_size=image_size, offload_video_to_cpu=offload_video_to_cpu, img_mean=img_mean, img_std=img_std, compute_device=compute_device, ) elif is_str and os.path.isdir(video_path): return load_video_frames_from_jpg_images( video_path=video_path, image_size=image_size, offload_video_to_cpu=offload_video_to_cpu, img_mean=img_mean, img_std=img_std, async_loading_frames=async_loading_frames, compute_device=compute_device, ) else: raise NotImplementedError( "Only MP4 video and JPEG folder are supported at this moment" ) def load_video_frames_from_jpg_images( video_path, image_size, offload_video_to_cpu, img_mean=(0.485, 0.456, 0.406), img_std=(0.229, 0.224, 0.225), async_loading_frames=False, compute_device=torch.device("cuda"), ): """ Load the video frames from a directory of JPEG files (".jpg" format). The frames are resized to image_size x image_size and are loaded to GPU if `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. You can load a frame asynchronously by setting `async_loading_frames` to `True`. """ if isinstance(video_path, str) and os.path.isdir(video_path): jpg_folder = video_path else: raise NotImplementedError( "Only JPEG frames are supported at this moment. For video files, you may use " "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" "```\n" "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" "```\n" "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " "ffmpeg to start the JPEG file from 00000.jpg." 
) frame_names = [ p for p in os.listdir(jpg_folder) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] ] frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) num_frames = len(frame_names) if num_frames == 0: raise RuntimeError(f"no images found in {jpg_folder}") img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] if async_loading_frames: lazy_images = AsyncVideoFrameLoader( img_paths, image_size, offload_video_to_cpu, img_mean, img_std, compute_device, ) return lazy_images, lazy_images.video_height, lazy_images.video_width images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) if not offload_video_to_cpu: images = images.to(compute_device) img_mean = img_mean.to(compute_device) img_std = img_std.to(compute_device) # normalize by mean and std images -= img_mean images /= img_std return images, video_height, video_width def load_video_frames_from_video_file( video_path, image_size, offload_video_to_cpu, img_mean=(0.485, 0.456, 0.406), img_std=(0.229, 0.224, 0.225), compute_device=torch.device("cuda"), ): """Load the video frames from a video file.""" import decord img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] # Get the original video height and width decord.bridge.set_bridge("torch") video_height, video_width, _ = decord.VideoReader(video_path).next().shape # Iterate over all frames in the video images = [] for frame in decord.VideoReader(video_path, width=image_size, height=image_size): images.append(frame.permute(2, 0, 1)) images = torch.stack(images, dim=0).float() / 255.0 if not offload_video_to_cpu: images = images.to(compute_device) img_mean = img_mean.to(compute_device) img_std = img_std.to(compute_device) # normalize by mean and std images -= img_mean images /= img_std return images, video_height, video_width def fill_holes_in_mask_scores(mask, max_area): """ A post processor to fill small holes in mask scores with area under `max_area`. """ # Holes are those connected components in background with area <= self.max_area # (background regions are those with mask scores <= 0) assert max_area > 0, "max_area must be positive" input_mask = mask try: labels, areas = get_connected_components(mask <= 0) is_hole = (labels > 0) & (areas <= max_area) # We fill holes with a small positive mask score (0.1) to change them to foreground. mask = torch.where(is_hole, 0.1, mask) except Exception as e: # Skip the post-processing step on removing small holes if the CUDA kernel fails warnings.warn( f"{e}\n\nSkipping the post-processing step due to the error above. 
You can " "still use SAM 2 and it's OK to ignore the error above, although some post-processing " "functionality may be limited (which doesn't affect the results in most cases; see " "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", category=UserWarning, stacklevel=2, ) mask = input_mask return mask def concat_points(old_point_inputs, new_points, new_labels): """Add new points and labels to previous point inputs (add at the end).""" if old_point_inputs is None: points, labels = new_points, new_labels else: points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) return {"point_coords": points, "point_labels": labels} ================================================ FILE: camera_pose_annotation/dynamic_mask/sam2/utils/transforms.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import warnings import torch import torch.nn as nn import torch.nn.functional as F from torchvision.transforms import Normalize, Resize, ToTensor class SAM2Transforms(nn.Module): def __init__( self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 ): """ Transforms for SAM2. """ super().__init__() self.resolution = resolution self.mask_threshold = mask_threshold self.max_hole_area = max_hole_area self.max_sprinkle_area = max_sprinkle_area self.mean = [0.485, 0.456, 0.406] self.std = [0.229, 0.224, 0.225] self.to_tensor = ToTensor() self.transforms = torch.jit.script( nn.Sequential( Resize((self.resolution, self.resolution)), Normalize(self.mean, self.std), ) ) def __call__(self, x): x = self.to_tensor(x) return self.transforms(x) def forward_batch(self, img_list): img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] img_batch = torch.stack(img_batch, dim=0) return img_batch def transform_coords( self, coords: torch.Tensor, normalize=False, orig_hw=None ) -> torch.Tensor: """ Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. Returns Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. """ if normalize: assert orig_hw is not None h, w = orig_hw coords = coords.clone() coords[..., 0] = coords[..., 0] / w coords[..., 1] = coords[..., 1] / h coords = coords * self.resolution # unnormalize coords return coords def transform_boxes( self, boxes: torch.Tensor, normalize=False, orig_hw=None ) -> torch.Tensor: """ Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. """ boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) return boxes def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: """ Perform PostProcessing on output masks. 
""" from sam2.utils.misc import get_connected_components masks = masks.float() input_masks = masks mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image try: if self.max_hole_area > 0: # Holes are those connected components in background with area <= self.fill_hole_area # (background regions are those with mask scores <= self.mask_threshold) labels, areas = get_connected_components( mask_flat <= self.mask_threshold ) is_hole = (labels > 0) & (areas <= self.max_hole_area) is_hole = is_hole.reshape_as(masks) # We fill holes with a small positive mask score (10.0) to change them to foreground. masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) if self.max_sprinkle_area > 0: labels, areas = get_connected_components( mask_flat > self.mask_threshold ) is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) is_hole = is_hole.reshape_as(masks) # We fill holes with negative mask score (-10.0) to change them to background. masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) except Exception as e: # Skip the post-processing step if the CUDA kernel fails warnings.warn( f"{e}\n\nSkipping the post-processing step due to the error above. You can " "still use SAM 2 and it's OK to ignore the error above, although some post-processing " "functionality may be limited (which doesn't affect the results in most cases; see " "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", category=UserWarning, stacklevel=2, ) masks = input_masks masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) return masks ================================================ FILE: caption/LLM/__init__.py ================================================ ================================================ FILE: caption/LLM/inference.py ================================================ import os import time import queue from argparse import ArgumentParser from multiprocessing import Manager from concurrent.futures import ThreadPoolExecutor from tqdm import tqdm import concurrent import pandas as pd import numpy as np import sys sys.path.append(os.path.abspath(os.path.join(__file__, "../.."))) from utils.api_call import api_call def get_pose(pose_dir): """ Retrieve and process pose data from extrinsics.npy file """ # Base directory for pose data pose_path = os.path.join(pose_dir, 'extrinsics.npy') assert os.path.isfile(pose_path), f"Pose file not found: {pose_path}" # Load and process the pose file poses = np.load(pose_path) # Data processing steps poses = poses[::5, :, 3] # Take first row for every 5 rows max_value = np.max(poses) min_value = np.min(poses) min_abs_value = np.min(np.abs(poses)) # Normalize and convert to integers (minimize integer digits) poses = np.round(poses / (max_value - min_value) / min_abs_value).astype(int) # Keep only first 3 columns and transpose poses = poses[:, :3].T # Extract individual axes poses1, poses2, poses3 = poses[0], poses[1], poses[2] # Convert each axis to string poses1_str = ' '.join(map(str, poses1)) poses2_str = ' '.join(map(str, poses2)) poses3_str = ' '.join(map(str, poses3)) # Combine into formatted string poses_str = f'x:{poses1_str}\ny:{poses2_str}\nz:{poses3_str}' return poses_str def get_prompt(pose_dir, prompt_dir, vqa_caption, dist_level): """ Construct a prompt by combining content from prompt1.txt, prompt2.txt, VQA caption, and pose data """ # Read prompt components p1_file = os.path.join(prompt_dir, 'prompt1.txt') p2_file = os.path.join(prompt_dir, 'prompt2.txt') with open(p1_file, 'r', encoding='utf-8') as f: 
p1_content = f.read().strip() with open(p2_file, 'r', encoding='utf-8') as f: p2_content = f.read().strip() # Get pose data poses = get_pose(pose_dir) # Assemble final prompt prompt = (f"{p1_content}\nGiven Information:\n{vqa_caption}\n3.Camera Position Data:\n{poses}\n" f"\n4.Motion intensity:\n{dist_level}\n{p2_content}") return prompt def process_single_row(args, row): """ Process a single row of data by calling API and saving the result """ # Check if VQA file exists vqa_path = os.path.join(args.vqa_path, f"{row['id']}.txt") assert os.path.isfile(vqa_path), f"VQA file not found: {vqa_path}" # Read VQA caption with open(vqa_path, "r") as f: vqa_caption = f.read() # Skip processing if file already exists save_file = os.path.join(args.llm_path, f"{row['id']}.txt") if os.path.exists(save_file) and os.path.getsize(save_file) > 0: return # Call API with retry mechanism pose_dir = os.path.join(args.pose_load_dir, row["id"], "reconstructions") prompt_text = get_prompt(pose_dir, args.prompt_dir, vqa_caption, row["distLevel"]) llm_caption = api_call(prompt_text, args.model, args.api_key, args.base_domain) assert llm_caption is not None, f"API call failed for id {row['id']}" # Save the result with model information with open(save_file, 'w', encoding='utf-8') as f: f.write(llm_caption + f"\n\n6. Qwen model: \n{args.model}") return def worker(args, task_queue, pbar): """ Worker function to process tasks from the queue Args: task_queue: Queue containing tasks to process pbar: Progress bar object for tracking progress """ while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break # Add delay to prevent overwhelming API time.sleep(args.wait_time) # Process the single row process_single_row(args, row) # Update progress task_queue.task_done() pbar.update(1) def parse_args(): """ Parse command line arguments Returns: Parsed arguments object """ parser = ArgumentParser(description='VQA Processing Program') parser.add_argument('--csv_path', type=str, required=True, help='Path to CSV file') parser.add_argument('--pose_load_dir', type=str, required=True, help='Directory to load pose data') parser.add_argument('--output_dir', type=str, required=True, help='Directory to save results') parser.add_argument('--prompt_dir', type=str, default=os.path.join(os.path.dirname( __file__), "vqa_prompt.txt"), help='Path to prompt file') parser.add_argument('--model', type=str, default="qwen3-30b-a3b", help='Model name') parser.add_argument('--api_key', type=str, default="sk-****", help='API key') parser.add_argument('--num_workers', type=int, default=1, help='Number of worker threads') parser.add_argument('--wait_time', type=float, default=0.5, help='Time between requests in seconds') parser.add_argument('--base_domain', type=str, default="https://cn2us02.opapi.win/", help='API base domain') return parser.parse_args() def main(): """ Main processing function that handles multiple rows using parallel workers Args: group_id (str): Identifier for the group prompt_dir (str): Directory containing prompt files model_file (str): Path to file containing model names api_key_file (str): Path to file containing API keys num_workers (int): Number of worker threads wait_time (float): Time to wait between requests base_domain (str): Base domain for API calls record_time (bool): Whether to record processing time Returns: None """ args = parse_args() # Validate temporary directory exists # Create LLM directory if it doesn't exist args.llm_path = os.path.join(args.output_dir, "LLM") if not 
os.path.isdir(args.llm_path): os.makedirs(args.llm_path, exist_ok=True) # Validate VQA directory exists args.vqa_path = os.path.join(args.output_dir, "VQA") assert os.path.isdir( args.vqa_path), f"VQA directory not found: {args.vqa_path}" # Read CSV file containing scene information df = pd.read_csv(args.csv_path) # Initialize task queue with all rows manager = Manager() task_queue = manager.Queue() for index, row in df.iterrows(): task_queue.put((index, row)) # Start processing with progress bar with tqdm(total=len(df), desc="LLM Finished") as pbar: with ThreadPoolExecutor(max_workers=args.num_workers) as executor: # Start worker threads futures = [executor.submit(worker, args, task_queue, pbar) for _ in range(args.num_workers)] # Wait for all workers to complete for future in concurrent.futures.as_completed(futures): future.result() if __name__ == "__main__": main() ================================================ FILE: caption/LLM/prompt1.txt ================================================ You are given a video sequence with camera trajectory data representing the camera's movement through a scene. The data consists of: Camera Motion Caption: A basic description of how the camera moves. Scene Description: A detailed visual summary of the environment. Camera position data: Three lines, representing the sequence of the camera's x-coordinate, y-coordinate, and z-coordinate. These values are derived from normalized 3D pose data using the following formula: poses = np.round(poses / (max_value - min_value) / min_abs_value).astype(int); Each value is then multiplied by 1,000,000 and rounded to the nearest integer. Motion intensity: An integer that indicates the level of camera movement, where a value of 0 means the camera is static, 1 indicates slight movement, and 2 or higher represents normal or noticeable motion. In tasks such as Optimized Camera Motion Caption and Main Motion Trend Summary, this intensity value should be used to qualify the degree of motion described — for example, using "slight forward translate" when the intensity is 1. Your Tasks: 1. Optimized Camera Motion Caption Generate a refined motion caption **from the perspective of the camera itself**, using only the **camera position data** to determine movement direction and dynamics. Use the following rules to interpret motion: x increasing: camera moves right x decreasing: camera moves left y increasing: camera moves down y decreasing: camera moves up z increasing: camera moves forward z decreasing: camera moves backward Analyze the full trajectory over time to capture acceleration, deceleration, or steady motion. Integrate scene context but prioritize accuracy based on numerical data. Avoid vague phrases like "zoom out" unless it's clearly due to focal length change — here, use translation terms instead. If motion intensity is 0, describe the fixed viewpoint and what the camera observes from that vantage point, incorporating compositional or environmental elements from the original caption. If intensity is 1, reflect subtle movement in the description (e.g., "slight right translate") without exaggerating the motion. For both cases, preserve visual context while aligning with the actual movement level. Avoid mentioning data analysis or detection explicitly — let the description itself reflect the motion state. Target Length: 50–100 words 2. 
Scene Abstract Caption Provide a single-sentence summary that captures: - Key architectural elements - Overall atmosphere/style - Notable design features Target Length: About 50 words 3. Main Motion Trend Summary Summarize the general movement using only 1–3 short motion phrases , depending on how many are clearly present. Focus strictly on major, sustained movements — ignore minor fluctuations or brief directional changes. If only one or two movements dominate, list only those. Use directional translation terms (e.g., forward translate, left translate, upward drift) 4. Scene Keywords Extract up to 4 keywords summarizing the key aspects of the scene. Include one term that broadly describes the scene type. Use nouns/noun phrases related to weather, place, time, lighting, scene type. Avoid adjectives/gerunds except for weather. Example: sunset, foggy, marketplace, city street, village 5. Immersive Shot Summary Blend Optimized Camera Motion Caption and Scene Description evenly — do not focus more on the camera or the scene alone. Describe the visuals as if someone is watching a moving image unfold. Use descriptive, cinematic language that evokes imagery and emotion. Keep it concise but expressive — suitable for use in scripts, storyboards, or AI video/image generation. Target Length: 50–100 words ================================================ FILE: caption/LLM/prompt2.txt ================================================ Output Format: 1. Camera Motion Caption: [From the perspective of the camera holder, with the camera as the subject. Combine camera pose information to describe] 2. Scene Abstract Caption: [A concise one sentence summary of the scene] 3. Main Motion Trend Summary: [keywords separated by commas, e.g., forward translate, downward tilt] 4. Scene Keywords: [word1, word2, word3, ...] (max 5 words) 5. Immersive Shot Summary: ================================================ FILE: caption/README.md ================================================ # Semantic Information Annotation This script automates the process of generating structured text descriptions (captions) for videos through a multi-step pipeline involving Visual Question Answering (VQA), Large Language Models (LLM), result combination, and tagging. ## Captioning Workflow The video captioning process follows these sequential steps: 1. **VQA Captioning**: Uses a Visual Question Answering model to analyze visual content and generate initial captions based on predefined prompts. 2. **LLM Captioning**: Employs a Large Language Model to process pose data and generate additional descriptive captions. 3. **Result Combination**: Merges the outputs from the VQA and LLM steps into a unified structure. 4. **Tag Addition**: Enhances the combined results with relevant tags using a language model.
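The four steps above can also be chained by a small driver. The sketch below is hypothetical: the paths and API key are placeholders, `--prompt_dir` is assumed to be the directory holding `prompt1.txt`/`prompt2.txt`, and the per-step entry points are the scripts included in this repository; in practice the provided `caption_pipeline.sh` runs the same sequence.

```python
# Hypothetical sequential driver for the four captioning stages.
# Paths, the API key, and the --prompt_dir value are assumptions/placeholders.
import subprocess

CSV = "meta/clips_info.csv"      # CSV produced by the annotation step (placeholder)
SRC_DIR = "annotation_output"    # frames (<id>/img) and pose data (<id>/reconstructions)
OUTPUT_DIR = "caption_output"    # VQA/, LLM/, and combined JSON are written here
API_KEY = "sk-****"

steps = [
    ["python", "caption/VQA/inference.py", "--csv_path", CSV,
     "--fig_load_dir", SRC_DIR, "--output_dir", OUTPUT_DIR,
     "--prompt_file", "caption/VQA/prompt.txt", "--api_key", API_KEY],
    ["python", "caption/LLM/inference.py", "--csv_path", CSV,
     "--pose_load_dir", SRC_DIR, "--output_dir", OUTPUT_DIR,
     "--prompt_dir", "caption/LLM",  # assumed location of prompt1.txt / prompt2.txt
     "--api_key", API_KEY],
    ["python", "caption/utils/combine.py", "--csv_path", CSV,
     "--load_dir", OUTPUT_DIR, "--output_dir", f"{OUTPUT_DIR}/json"],
    ["python", "caption/tagging/inference.py", "--csv_path", CSV,
     "--json_load_dir", f"{OUTPUT_DIR}/json",
     "--prompt_file", "caption/tagging/prompt.txt", "--api_key", API_KEY],
]

for cmd in steps:
    subprocess.run(cmd, check=True)  # each script skips clips whose output already exists
```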
*(Figure: caption pipeline)*
## Script Explanation

### Configuration Parameters

- `CSV`: Path to the result CSV file generated in the annotation step
- `SRC_DIR`: Path to the annotation output directory containing video frames and pose data
- `OUTPUT_DIR`: Path where all output files will be saved
- `num_workers`: Number of parallel workers to use for processing
- `wait_time`: Waiting time between API requests (in seconds)

### Step 1: VQA Captioning

Generates captions by analyzing visual content using a VQA model.

Parameters:
- `--csv_path`: Path to the input CSV file
- `--fig_load_dir`: Directory containing video frames/images
- `--output_dir`: Directory to save VQA results
- `--prompt_file`: Path to the VQA prompt template file
- `--model`: VQA model to use (default: gemini-2.0-flash)
- `--api_key`: API key for accessing the VQA model service
- `--base_domain`: API endpoint domain for the VQA model
- `--num_workers`: Number of parallel workers
- `--wait_time`: Waiting time between API requests

### Step 2: LLM Captioning

Generates additional captions by processing pose data using a Large Language Model.

Parameters:
- `--csv_path`: Path to the input CSV file
- `--pose_load_dir`: Directory containing pose data
- `--output_dir`: Directory to save LLM results
- `--prompt_dir`: Directory containing LLM prompt templates
- `--model`: LLM model to use (default: qwen3-30b-a3b)
- `--api_key`: API key for accessing the LLM service
- `--base_domain`: API endpoint domain for the LLM
- `--num_workers`: Number of parallel workers
- `--wait_time`: Waiting time between API requests

### Step 3: Combine Results

Merges the outputs from the VQA and LLM steps into a unified format.

Parameters:
- `--csv_path`: Path to the input CSV file
- `--load_dir`: Directory containing VQA and LLM results
- `--output_dir`: Directory to save combined results
- `--num_workers`: Number of parallel workers

### Step 4: Add Tags

Enhances the combined results with relevant tags using a language model.

Parameters:
- `--csv_path`: Path to the input CSV file
- `--json_load_dir`: Directory containing combined results
- `--prompt_file`: Path to the tagging prompt template file
- `--model`: Model to use for tagging (default: qwen3-30b-a3b)
- `--api_key`: API key for accessing the tagging model service
- `--base_domain`: API endpoint domain for the tagging model
- `--num_workers`: Number of parallel workers
- `--wait_time`: Waiting time between API requests

## Usage

1. Replace all placeholder values (enclosed in square brackets) with your actual paths and API keys.
2. Make the script executable: `chmod +x caption_pipeline.sh`
3. Run the script: `./caption_pipeline.sh`

The script will execute each step sequentially, displaying start/end times and the duration of each step, and save all outputs to the specified `OUTPUT_DIR`.

## Results Example

The following are several samples of video captions generated by the model after each step.
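For reference, after Step 4 each clip is described by a single JSON file (`<clip_id>.json`) whose fields come from the VQA parser, the LLM parser, and the tagging step. A minimal sketch of that structure, with placeholder values only, is shown below; the list-valued fields (`MotionTrends`, `SceneTags`) are split on commas by `combine.py`, and the remaining fields are plain strings.

```python
# Sketch of the per-clip JSON written by combine.py and enriched by the tagging step.
# All values below are placeholders, not real outputs.
combined_caption = {
    "CamMotion": "<camera motion caption from the VQA step>",
    "SceneDesc": "<scene description from the VQA step>",
    "OptCamMotion": "<optimized camera motion caption from the LLM step>",
    "SceneSummary": "<one-sentence scene abstract>",
    "MotionTrends": ["forward translate", "slight right translate"],
    "SceneTags": ["sunset", "city street", "marketplace"],
    "ShotImmersion": "<immersive shot summary>",
    "LLM": "qwen3-30b-a3b",
    "CategoryTag": {
        "sceneType": {"first": "Urban", "second": "Street Scene"},
        "lighting": "Bright",
        "timeOfDay": "Dusk/Evening",
        "weather": "Sunny",
        "crowdDensity": "Sparse",
    },
}
```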
================================================ FILE: caption/VQA/__init__.py ================================================ ================================================ FILE: caption/VQA/inference.py ================================================ import os import concurrent.futures from multiprocessing import Manager import queue import pandas as pd from tqdm import tqdm import argparse import time import base64 import cv2 from glob import glob import sys sys.path.append(os.path.abspath(os.path.join(__file__, "../.."))) from utils.api_call import api_call def encode_image(image_path): """ Resizes an image to 640x360 and encodes it as a Base64 string with data URI prefix. """ # Read image using OpenCV image = cv2.imread(image_path) # Resize image to standard dimensions (640x360) resized_image = cv2.resize(image, (640, 360)) # Encode image as JPEG and convert to Base64 _, buffer = cv2.imencode('.jpeg', resized_image) base64_data = base64.b64encode(buffer).decode("utf-8") # Return with data URI format for API compatibility return f"data:image/jpeg;base64,{base64_data}" def get_prompt(fig_dir, prompt_text): """ Load key frames from a video, constructs a multimodal request, and calls the API. """ # Get frames from directory frames = sorted(glob(f"{fig_dir}/*.jpg"))[::5] # Construct multimodal input content messages_content = [] # Add encoded images to request content for frame in frames: try: encoded_frame = encode_image(frame) messages_content.append({ "type": "image_url", "image_url": {"url": encoded_frame} }) except Exception as e: print(f"Image processing error: {str(e)}") return None # Add text prompt to request content messages_content.append({"type": "text", "text": prompt_text}) return messages_content def process_single_row(args, row): """ Process a single row: call the VQA API and save the result for one scene. Handles retries and error logging. """ save_path = os.path.join(args.output_dir, "VQA") if not os.path.isdir(save_path): os.makedirs(save_path, exist_ok=True) save_file = os.path.join(save_path, f"{row['id']}.txt") if os.path.exists(save_file) and os.path.getsize(save_file) > 0: # Skip if already exists return # Call API fig_dir = os.path.join(args.fig_load_dir, row['id'], "img") prompt_text = get_prompt(fig_dir, args.prompt_text) vqa_caption = api_call(prompt_text, args.model, args.api_key, args.base_domain) assert vqa_caption is not None, f"API call failed for id {row['id']}" # Save result with open(save_file, 'w', encoding='utf-8') as f: f.write(vqa_caption) def worker(args, task_queue, pbar): while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break time.sleep(args.wait_time) process_single_row(args, row) task_queue.task_done() pbar.update(1) def parse_args(): """ Parse command line arguments for VQA batch processing. 
""" parser = argparse.ArgumentParser(description='VQA batch processing script') parser.add_argument('--csv_path', type=str, required=True, help='CSV file path') parser.add_argument('--fig_load_dir', type=str, required=True, help='Directory to load figures') parser.add_argument('--output_dir', type=str, required=True, help='Directory to save results') parser.add_argument('--prompt_file', type=str, default="vqa_prompt.txt", help='Prompt file path') parser.add_argument('--model', type=str, default="gemini-2.0-flash", help='Model name') parser.add_argument('--api_key', type=str, default="sk-****", help='API key') parser.add_argument('--num_workers', type=int, default=4, help='Number of worker threads') parser.add_argument('--wait_time', type=float, default=0.8, help='Request interval for each thread (seconds)') parser.add_argument('--base_domain', type=str, default="https://cn2us02.opapi.win/", help='API base domain') return parser.parse_args() def main(): """ Batch process all scenes in a group: call VQA API for each row in the CSV. Uses a thread pool for concurrency and supports timing. """ args = parse_args() df = pd.read_csv(args.csv_path) # Read prompt text with open(args.prompt_file, "r", encoding="utf-8") as f: args.prompt_text = f.read().strip() manager = Manager() task_queue = manager.Queue() for index, row in df.iterrows(): task_queue.put((index, row)) with tqdm(total=len(df), desc=f"VQA Finished") as pbar: with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_workers) as executor: futures = [] for worker_id in range(args.num_workers): futures.append(executor.submit(worker, args, task_queue, pbar)) for future in concurrent.futures.as_completed(futures): future.result() if __name__ == "__main__": main() ================================================ FILE: caption/VQA/prompt.txt ================================================ You are given a sequence of video frames in chronological order. Analyze them carefully and generate two distinct captions based on the following instructions: 1. Camera Motion Caption: From the perspective of the camera operator, describe the entire motion trajectory of the camera throughout the clip using precise cinematography terminology (e.g., static, pan, tilt, dolly, handheld, crane, aerial, zoom, etc.). Do NOT assume the camera starts in a "static" position just because it appears stationary in the first frame.Only describe the camera as stationary if there is no visual change across multiple consecutive frames. Instead, focus on changes between frames to infer movement. Describe motion state transitions, not frame-by-frame repetition (e.g., do not say “the camera moves forward again” if it’s continuous). For example: - Starting with a dolly forward along a straight path, - Then transitioning into a slow right-hand pan, - Or shifting from handheld walking movement to a stationary pivot tilt. Include brief environmental context where relevant to clarify direction or intent (e.g., "The camera dollies forward through a narrow alleyway, then smoothly turns left at the intersection"). Keep the final caption concise, between 50–100 words, focused only on motion and its evolution over time. 2. Scene Description: Provide a rich, holistic description of the visual content. 
Include: - Main subjects and dynamic objects: who or what is present, and what they are doing (e.g., a cyclist rides past from left to right, a group of people gather near a bench), - Background/environment: setting (urban street, forest trail, indoor space), notable landmarks or structures, - Lighting and atmosphere: time of day, weather conditions, mood (e.g., golden-hour lighting, overcast sky casting soft shadows, neon-lit nighttime scene), - Overall tone or emotion conveyed by the scene. Avoid focusing on individual frames—describe the general impression and ongoing activity across the entire clip. Aim for around 100 words, balancing detail and conciseness. Output Format: Do not include any explanations or extra text before or after your response. Begin directly with: 1. Camera Motion Caption: ... followed by 2. Scene Description: ... ================================================ FILE: caption/__init__.py ================================================ ================================================ FILE: caption/tagging/__init__.py ================================================ ================================================ FILE: caption/tagging/inference.py ================================================ import os import time import json import queue import argparse import pandas as pd from tqdm import tqdm from multiprocessing import Manager import concurrent.futures import sys sys.path.append(os.path.abspath(os.path.join(__file__, "../.."))) from utils.api_call import api_call def parse_category_tags(tag_caption): """ Parse API response to structured category data using camelCase naming convention """ lines = [line.strip() for line in tag_caption.strip().split('\n') if line.strip()] # Initialize category data with default values category_data = { "sceneType": { "first": "Unknown", "second": "Unknown" }, "lighting": "Unknown", "timeOfDay": "Unknown", "weather": "Unknown", "crowdDensity": "Unknown" } # Parse each line to extract category information for line in lines: line_lower = line.lower() if line_lower.startswith("primary scene type:"): category_data["sceneType"]["first"] = line.split(":", 1)[1].strip() elif line_lower.startswith("secondary scene type:"): category_data["sceneType"]["second"] = line.split(":", 1)[ 1].strip() elif line_lower.startswith("lighting:"): category_data["lighting"] = line.split(":", 1)[1].strip() elif line_lower.startswith("time of day:"): category_data["timeOfDay"] = line.split(":", 1)[1].strip() elif line_lower.startswith("weather:"): category_data["weather"] = line.split(":", 1)[1].strip() elif line_lower.startswith("crowd density:"): category_data["crowdDensity"] = line.split(":", 1)[1].strip() return category_data def process_single_row(args, json_file): """ Process a single JSON file to add category tags via API call """ # Check if CategoryTag field already exists with open(json_file, 'r') as f: data = json.load(f) # Skip if CategoryTag already exists if "CategoryTag" in data: return description = data['SceneDesc'] prompt_text = args.prompt_text + description # Call API to get category tags with retry mechanism tag_caption = api_call(prompt_text, args.model, args.api_key, args.base_domain) assert tag_caption is not None, f"API call failed for file {json_file}" # Parse and add category tags to the JSON file category_tag = parse_category_tags(tag_caption) # Merge new data with existing data data["CategoryTag"] = category_tag # Overwrite file with updated data with open(json_file, 'w', encoding='utf-8') as f: json.dump(data, f, 
ensure_ascii=False, indent=2) def worker(args, task_queue, pbar): while True: try: index, json_file = task_queue.get(timeout=1) except queue.Empty: break time.sleep(args.wait_time) process_single_row(args, json_file) task_queue.task_done() pbar.update(1) def parse_args(): """Parse command line arguments""" parser = argparse.ArgumentParser( description='Category Tag Processing Program') parser.add_argument('--csv_path', type=str, required=True, help='Path to the CSV file') parser.add_argument('--json_load_dir', type=str, required=True, help='Directory containing JSON files') parser.add_argument('--prompt_file', type=str, default="prompt.txt", help='Path to prompt file') parser.add_argument('--model', type=str, default="qwen3-30b-a3b", help='Model name') parser.add_argument('--api_key', type=str, default="sk-****", help='API key') parser.add_argument('--num_workers', type=int, default=4, help='Number of worker threads') parser.add_argument('--wait_time', type=float, default=0.8, help='Time interval between requests per thread (seconds)') parser.add_argument('--base_domain', type=str, default="https://cn2us02.opapi.win/", help='API base domain') return parser.parse_args() def main(): """ Process a group of JSON files using multiple threads to add category tags """ args = parse_args() df = pd.read_csv(args.csv_path) with open(args.prompt_file, 'r', encoding='utf-8') as f: args.prompt_text = f.read().strip() # Initialize task queue and add all files to process manager = Manager() task_queue = manager.Queue() for index, row in df.iterrows(): clip_id = row['id'] json_file = os.path.join(args.json_load_dir, f"{clip_id}.json") task_queue.put((index, json_file)) # Start processing with progress bar with tqdm(total=task_queue.qsize(), desc="Tags Completed") as pbar: with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_workers) as executor: # Start worker threads futures = [executor.submit(worker, args, task_queue, pbar) for _ in range(args.num_workers)] # Wait for all workers to complete for future in concurrent.futures.as_completed(futures): future.result() if __name__ == "__main__": main() ================================================ FILE: caption/tagging/prompt.txt ================================================ You are an AI assistant specialized in analyzing scene descriptions and extracting structured metadata. Your task is to read the provided scene description and infer six attributes with hierarchical classification where applicable. Output Rules: 1. Scene Type (Choose one from): - [Urban, Natural Landscape, Interior, Rural, Waterfront, Unknown] - Add a custom secondary tag (unrestricted) to further define the scene (e.g., Urban → "Street Scene", Interior → "Library") 2. Lighting: [Bright / Dim/Dark / Unknown] 3. Time of Day: [Dawn/Morning / Daytime / Dusk/Evening / Night / Unknown] 4. Weather: [Sunny / Rainy / Foggy / Cloudy / Snowy / Unknown] 5. Crowd Density: [Deserted / Sparse / Moderate / Crowded / Unknown] Deduction Guidelines: - Prioritize explicit descriptors over implied cues (e.g., "snow scattered" → Snowy; "wet surfaces + cloudy" → Cloudy) - If no evidence exists for any attribute, output 'Unknown' for that field. Maintain strict objectivity - never assume information beyond the text. 
Output Format (strictly follow line breaks): Primary Scene Type: [X] Secondary Scene Type: [CustomTag] Lighting: [X] Time of Day: [X] Weather: [X] Crowd Density: [X] The following is a scene description: ================================================ FILE: caption/utils/__init__.py ================================================ ================================================ FILE: caption/utils/api_call.py ================================================ import requests def api_call(prompt_text, model, api_key, base_domain): """ Make an API call to a language model with a constructed prompt, handling different API formats for different model providers. """ # Determine if using Qwen model API (Aliyun) is_qwen = "dashscope.aliyuncs.com" in base_domain # Configure API endpoint and payload based on model type if is_qwen: api_url = base_domain + "v1/chat/completions" # Payload format specific to Qwen model payload = { "model": model, "messages": [ # {"role": "system", "content": "You are a helpful assistant."}, # Optional system message {"role": "user", "content": prompt_text} ], "enable_thinking": False, "temperature": 0.1 # Low temperature for more deterministic output } else: # Payload format for other models api_url = base_domain + "v1beta/openai/" payload = { "model": model, "messages": [ {"role": "user", "content": prompt_text} ], "temperature": 0.1, "user": "User" } # Configure request headers if is_qwen: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", "Accept": "application/json" } else: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", "User-Agent": f"({base_domain})", "Accept": "application/json" } try: # Execute API request with timeout response = requests.post( api_url, headers=headers, json=payload, timeout=120 # 2-minute timeout ) response.raise_for_status() # Raise exception for HTTP errors response_data = response.json() # Optional: Uncomment to log token usage # if 'usage' in response_data: # usage = response_data.get('usage', {}) # prompt_tokens = usage.get('prompt_tokens', 0) # completion_tokens = usage.get('completion_tokens', 0) # total_tokens = usage.get('total_tokens', 0) # # print(f"Input tokens: {prompt_tokens}") # print(f"Output tokens: {completion_tokens}") # print(f"Total tokens: {total_tokens}") # else: # print("API response does not contain token usage information") # Extract and return response content based on API format if is_qwen: return response_data.get("choices", [{}])[0].get("message", {}).get("content", "") else: return response_data.get("choices", [{}])[0].get("message", {}).get("content", "") except Exception as e: print(f"API request error: {str(e)}") return None ================================================ FILE: caption/utils/combine.py ================================================ import os import json import re import queue import argparse import pandas as pd from multiprocessing import Manager from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm def parse_text_to_json(text): """ Parses text in a specific format into a JSON structure. 
""" # Define mapping between text labels and JSON keys labels = { "Camera Motion Caption": "OptCamMotion", "Scene Abstract Caption": "SceneSummary", "Main Motion Trend Summary": "MotionTrends", "Scene Keywords": "SceneTags", "Immersive Shot Summary": "ShotImmersion", "Qwen model": "LLM" } # Initialize result dictionary with empty values result = {key: "" for key in labels.values()} current_label = None current_content = [] # Process text line by line lines = text.split('\n') i = 0 while i < len(lines): line = lines[i].strip() # Check if current line contains any label for label, json_key in labels.items(): if label in line: # Find position of first letter after the label start_pos = line.find(label) + len(label) # Skip non-alphabet characters while start_pos < len(line) and not line[start_pos].isalpha(): start_pos += 1 content = line[start_pos:].strip() # Process content if it exists after the label if content: if json_key in ["MotionTrends", "SceneTags"]: # Split by commas (both Chinese and English), preserve spaces in phrases items = re.split(r'[,,]\s*', content) result[json_key] = [item.strip() for item in items if item.strip()] else: result[json_key] = content current_label = None current_content = [] else: # No content after label, continue reading subsequent lines current_label = json_key current_content = [] break else: # If collecting content for a label if current_label: # Check if line is not empty if line: # Find first letter in line start_pos = 0 while start_pos < len(line) and not line[start_pos].isalpha(): start_pos += 1 if start_pos < len(line): current_content.append(line[start_pos:]) else: # Empty line indicates end of current label content content = ' '.join(current_content).strip() if current_label in ["MotionTrends", "SceneTags"]: # Split by commas (both Chinese and English), preserve spaces in phrases items = re.split(r'[,,]\s*', content) result[current_label] = [item.strip() for item in items if item.strip()] else: result[current_label] = content current_label = None current_content = [] i += 1 # Handle Qwen model label which might extend to the end if current_label == "LLM" and current_content: content = ' '.join(current_content).strip() result[current_label] = content return result def vqa_parse_text_to_json(text): """ Parses text containing Camera Motion Caption and Scene Description into JSON format. """ result = { "CamMotion": "", "SceneDesc": "" } # Process Camera Motion Caption - from first letter after label to newline camera_pattern = r'Camera Motion Caption:\s*(\w[\s\S]*?)(?=\n|$)' camera_match = re.search(camera_pattern, text) if camera_match: result["CamMotion"] = camera_match.group(1).strip() # Process Scene Description - from first letter after label to end of text scene_pattern = r'Scene Description:\s*(\w[\s\S]*)$' scene_match = re.search(scene_pattern, text, re.DOTALL) if scene_match: result["SceneDesc"] = scene_match.group(1).strip() return result def process_single_row(args, clip_id): """ Processes VQA and LLM captions for a single clip and merges them into one JSON file. 
""" # Define file paths vqa_path = os.path.join(args.load_dir, "VQA", f"{clip_id}.txt") assert os.path.exists(vqa_path), f"VQA path does not exist: {vqa_path}" llm_path = os.path.join(args.load_dir, "LLM", f"{clip_id}.txt") assert os.path.exists(llm_path), f"LLM path does not exist: {llm_path}" output_path = os.path.join(args.output_dir, f"{clip_id}.json") # Skip if output file already exists if os.path.exists(output_path) and os.path.getsize(output_path) > 0: return # Read VQA file content with open(vqa_path, 'r', encoding='utf-8') as f: vqa_text = f.read() # Read LLM file content with open(llm_path, 'r', encoding='utf-8') as f: llm_text = f.read() # Parse text content to JSON vqa_json = vqa_parse_text_to_json(vqa_text) llm_json = parse_text_to_json(llm_text) # Merge JSON objects combined_json = {**vqa_json, **llm_json} # Save merged JSON to output file with open(output_path, 'w', encoding='utf-8') as f: json.dump(combined_json, f, ensure_ascii=False, indent=2) def worker(args, task_queue, pbar): while True: try: idx, clip_id = task_queue.get(timeout=1) except queue.Empty: break process_single_row(args, clip_id) task_queue.task_done() pbar.update(1) def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser( description='Merge VQA and LLM caption data') parser.add_argument('--csv_path', type=str, required=True, help='Path to the CSV file') parser.add_argument('--load_dir', type=str, required=True, help='Directory containing caption files') parser.add_argument('--output_dir', type=str, required=True, help='Directory to save merged JSON files') parser.add_argument('--num_workers', type=int, default=32, help='Number of worker threads') return parser.parse_args() def main(): """ Processes all scenes in the specified batch. """ args = parse_args() df = pd.read_csv(args.csv_path) os.makedirs(args.output_dir, exist_ok=True) # Use multiprocessing manager for thread-safe queue manager = Manager() task_queue = manager.Queue() # Add tasks to queue for index, row in df.iterrows(): task_queue.put((index, row['id'])) # Start multi-threaded processing with progress bar with tqdm(total=len(df), desc="Processing progress") as pbar: with ThreadPoolExecutor(max_workers=args.num_workers) as executor: futures = [] for _ in range(args.num_workers): futures.append(executor.submit(worker, args, task_queue, pbar)) # Wait for all futures to complete for future in as_completed(futures): future.result() if __name__ == "__main__": main() ================================================ FILE: docker-entrypoint.sh ================================================ #!/usr/bin/env bash # Simple entrypoint: activate venv if present and run provided command set -euo pipefail if [ -f "/workspace/venv/bin/activate" ]; then echo "Activating venv" # shellcheck disable=SC1091 source /workspace/venv/bin/activate fi if [ "$#" -gt 0 ]; then exec "$@" else exec bash fi ================================================ FILE: requirements/requirements.txt ================================================ torch==2.7.0 --index-url https://download.pytorch.org/whl/cu126 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126 opencv-python==4.11.0.86 tqdm==4.67.1 imageio==2.37.0 einops==0.8.1 scipy==1.15.2 matplotlib==3.10.0 ninja==1.11.1.3 numpy==1.26.4 pandas==2.2.3 huggingface_hub ================================================ FILE: requirements/requirements_annotation.txt ================================================ 
wandb==0.19.8 timm==1.0.15 kornia==0.8.0 xformers==0.0.30 torch_scatter==2.1.2 gradio_imageslider==0.0.20 gradio==4.29.0 # sam2 hydra-core==1.3.2 iopath==0.1.10 OpenEXR ================================================ FILE: requirements/requirements_scoring.txt ================================================ ftfy==6.3.1 diffusers==0.29.0 accelerate==1.4.0 av==14.2.0 scenedetect==0.6.5.2 decord==0.6.0 imageio-ffmpeg==0.6.0 ffmpeg-python==0.2.0 clip @ git+https://github.com/openai/CLIP.git cpbd==1.0.7 # paddlepaddle-gpu==3.0.0 --index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ paddleocr==3.0.0 nvidia-nccl-cu12==2.26.2 numpy==1.26.4 ================================================ FILE: scoring/README.md ================================================ # Scoring ## Aesthetic Score To evaluate the aesthetic quality of videos, we use the scoring model from [CLIP+MLP Aesthetic Score Predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor). This model is trained on 176K SAC (Simulacra Aesthetic Captions) pairs, 15K LAION-Logos (Logos) pairs, and 250K AVA (The Aesthetic Visual Analysis) image-text pairs. The aesthetic score is between 1 and 10, where 5.5 can be considered as the threshold for fair aesthetics, and 6.5 for high aesthetics. Good text-to-image models can achieve a score of 7.0 or higher. First, download the scoring model to `./checkpoints/aesthetic.pth`. Skip this step if you already follow the installation instructions in [README](../README.md). ```bash wget https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac+logos+ava1-l14-linearMSE.pth -O checkpoints/aesthetic.pth ``` Then, run the following command to compute aesthetic scores. ```bash torchrun --nproc_per_node ${GPU_NUM} scoring/aesthetic/inference.py \ ${ROOT_META}/clips_info.csv \ --bs 16 \ --num_workers ${NUM_WORKERS} \ --fig_load_dir ${ROOT_FIG} ``` ## Luminance Score Luminance was calculated for the first, middle, and last frames using the standard formula $L = 0.2126 R + 0.7152 G + 0.0722 B$, where $R$, $G$, and $B$ are the respective channel values. Clips with average luminance outside the range [20, 140], either too dark or too bright, were excluded, ensuring that only videos with proper exposure were retained. Run the following command to compute luminance scores. ```bash torchrun --nproc_per_node ${GPU_NUM} scoring/luminance/inference.py \ ${ROOT_META}/clips_info.csv \ --bs 16 \ --num_workers ${NUM_WORKERS} \ --fig_load_dir ${ROOT_FIG} ``` ## Motion Score Conventional motion analysis using optical flow is computationally expensive and less effective for videos with complex motion patterns. Inspired by Open-Sora 2.0, we adopted a lightweight VMAF-based motion analysis method integrated into FFMPEG. This method yields a motion score between 0 and 20. Clips with scores outside the valid range of [2, 14], either too static (scores $<$ 2) or excessively chaotic (scores $>$ 14), were filtered out. Run the following command to compute motion scores. ```bash python scoring/motion/inference.py ${ROOT_META}/clips_info.csv \ --temp_save_dir ${ROOT_TEMP} \ --num_workers $((GPU_NUM * 4)) \ --gpu_num ${GPU_NUM} ``` ## OCR For text detection, we used the latest release of PaddleOCR, which offers high accuracy and robust multilingual support. We processed the first, middle, and last frames of each clip to detect text regions, computing the ratio of text area to frame size. 
Clips where the text area exceeded 30% were removed, as these were considered informational rather than visual. Run the following command to compute OCR scores. ```bash python scoring/ocr/inference.py ${ROOT_META}/clips_info.csv \ --fig_load_dir ${ROOT_FIG} \ --num_workers $((GPU_NUM * 4)) \ --gpu_num ${GPU_NUM} ``` ================================================ FILE: scoring/__init__.py ================================================ ================================================ FILE: scoring/aesthetic/__init__.py ================================================ ================================================ FILE: scoring/aesthetic/inference.py ================================================ """ Aesthetic scoring script for video frames using CLIP and MLP models. Adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py Calculates aesthetic scores for video clips using distributed processing. """ # adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py import argparse import gc import os from glob import glob from datetime import timedelta from PIL import Image import clip import numpy as np import pandas as pd import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch.utils.data import DataLoader, DistributedSampler from tqdm import tqdm def merge_scores(gathered_list: list, csv: pd.DataFrame, column): """Merge aesthetic scores from all distributed processes.""" # Reorder results from all processes indices_list = list(map(lambda x: x[0], gathered_list)) scores_list = list(map(lambda x: x[1], gathered_list)) flat_indices = [] for x in zip(*indices_list): flat_indices.extend(x) flat_scores = [] for x in zip(*scores_list): flat_scores.extend(x) flat_indices = np.array(flat_indices) flat_scores = np.array(flat_scores) # Filter duplicates from distributed processing unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) csv.loc[unique_indices, column] = flat_scores[unique_indices_idx] # Drop indices in csv not in unique_indices csv = csv.loc[unique_indices] return csv class VideoTextDataset(torch.utils.data.Dataset): """Dataset for loading video frames for aesthetic scoring.""" def __init__(self, csv_path, fig_load_dir, transform=None): self.csv_path = csv_path self.csv = pd.read_csv(csv_path) self.transform = transform self.fig_load_dir = fig_load_dir def __getitem__(self, index): """Load and transform video frames for a single sample.""" sample = self.csv.iloc[index] # Load first 3 frames from video clip images_dir = os.path.join(self.fig_load_dir, sample["id"]) images = sorted(glob(f"{images_dir}/img/*.jpg"))[:3] # Apply CLIP preprocessing transforms images = [self.transform(Image.open(img).convert("RGB")) for img in images] # Stack images into tensor images = torch.stack(images) return dict(index=index, images=images) def __len__(self): return len(self.csv) class MLP(nn.Module): """Multi-layer perceptron for aesthetic score prediction.""" def __init__(self, input_size): super().__init__() self.input_size = input_size self.layers = nn.Sequential( nn.Linear(self.input_size, 1024), nn.Dropout(0.2), nn.Linear(1024, 128), nn.Dropout(0.2), nn.Linear(128, 64), nn.Dropout(0.1), nn.Linear(64, 16), nn.Linear(16, 1), ) def forward(self, x): return self.layers(x) class AestheticScorer(nn.Module): """Combined CLIP + MLP model for aesthetic scoring.""" def __init__(self, input_size, 
device): super().__init__() self.mlp = MLP(input_size) self.clip, self.preprocess = clip.load("ViT-L/14", device=device) self.eval() self.to(device) def forward(self, x): """Extract CLIP features and predict aesthetic scores.""" image_features = self.clip.encode_image(x) image_features = F.normalize(image_features, p=2, dim=-1).float() return self.mlp(image_features) def parse_args(): """Parse command line arguments for aesthetic scoring.""" parser = argparse.ArgumentParser() parser.add_argument("--csv_path", type=str, help="Path to the input CSV file") parser.add_argument( "--load_num", type=int, default=4, help="Number of frames to load" ) parser.add_argument("--bs", type=int, default=1024, help="Batch size") parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") parser.add_argument( "--fig_load_dir", type=str, required=True, help="Directory to load the extracted frames", ) parser.add_argument( "--prefetch_factor", type=int, default=3, help="Prefetch factor" ) parser.add_argument("--skip_if_existing", action="store_true") args = parser.parse_args() return args def main(): args = parse_args() csv_path = args.csv_path if not os.path.exists(csv_path): print(f"CSV file '{csv_path}' not found. Exit.") exit() wo_ext, ext = os.path.splitext(csv_path) out_path = f"{wo_ext}_aes{ext}" if args.skip_if_existing and os.path.exists(out_path): print(f"Output CSV file '{out_path}' already exists. Exit.") exit() # Initialize distributed processing dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) # Build aesthetic scoring model device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") model = AestheticScorer(768, device) model.mlp.load_state_dict( torch.load("checkpoints/aesthetic.pth", map_location=device) ) preprocess = model.preprocess # Build dataset and dataloader dataset = VideoTextDataset( args.csv_path, transform=preprocess, fig_load_dir=args.fig_load_dir ) dataloader = DataLoader( dataset, batch_size=args.bs, num_workers=args.num_workers, sampler=DistributedSampler( dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False, ), ) # Compute aesthetic scores for all batches indices_list = [] scores_list = [] model.eval() for batch in tqdm( dataloader, disable=(dist.get_rank() != 0), position=dist.get_rank() ): indices = batch["index"] images = batch["images"].to(device, non_blocking=True) B = images.shape[0] images = rearrange(images, "B N C H W -> (B N) C H W") # Compute aesthetic scores using CLIP + MLP with torch.no_grad(): scores = model(images) # Average scores across frames for each video scores = rearrange(scores, "(B N) 1 -> B N", B=B) scores = scores.mean(dim=1) scores_np = scores.to(torch.float32).cpu().numpy() indices_list.extend(indices.tolist()) scores_list.extend(scores_np.tolist()) # Wait for all ranks to finish data processing dist.barrier() # Gather results from all processes and save torch.cuda.empty_cache() gc.collect() gathered_list = [None] * dist.get_world_size() dist.all_gather_object(gathered_list, (indices_list, scores_list)) if dist.get_rank() == 0: csv_new = merge_scores(gathered_list, dataset.csv, column="aesthetic score") csv_new.to_csv(out_path, index=False) print(f"New csv with aesthetic scores saved to '{out_path}'.") if __name__ == "__main__": main() ================================================ FILE: scoring/luminance/__init__.py ================================================ 
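The following is a quick, hedged usage sketch for the `AestheticScorer` defined in `scoring/aesthetic/inference.py`: it scores a single frame rather than a distributed batch. It assumes the checkpoint has been downloaded to `checkpoints/aesthetic.pth` as described in `scoring/README.md`, that the snippet is run from the repository root, and that `frame.jpg` is a placeholder image path.

```python
# Minimal single-image sketch for the CLIP+MLP aesthetic scorer.
# Assumptions: checkpoints/aesthetic.pth exists, run from the repo root,
# and "frame.jpg" is a placeholder path.
import torch
from PIL import Image

from scoring.aesthetic.inference import AestheticScorer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AestheticScorer(768, device)  # 768 = CLIP ViT-L/14 embedding width
model.mlp.load_state_dict(torch.load("checkpoints/aesthetic.pth", map_location=device))

image = model.preprocess(Image.open("frame.jpg").convert("RGB")).unsqueeze(0).to(device)
with torch.no_grad():
    score = model(image).item()
print(f"aesthetic score: {score:.2f}")  # roughly 5.5 = fair, 6.5 = high (per scoring/README.md)
```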
================================================ FILE: scoring/luminance/inference.py ================================================ """ Luminance analysis script for video frames using distributed processing. Calculates mean, min, and max luminance scores for video clips using PyTorch distributed computing. """ import argparse import os import gc from glob import glob from datetime import timedelta from PIL import Image import numpy as np import pandas as pd import torch import torch.distributed as dist from torch.utils.data import DataLoader, DistributedSampler from torchvision.transforms.functional import pil_to_tensor from tqdm import tqdm def merge_scores(gathered_list: list, csv: pd.DataFrame): """Merge luminance scores from all distributed processes.""" # Reorder results from all processes indices_list = list(map(lambda x: x[0], gathered_list)) mean_scores_list = list(map(lambda x: x[1], gathered_list)) min_scores_list = list(map(lambda x: x[2], gathered_list)) max_scores_list = list(map(lambda x: x[3], gathered_list)) flat_indices = [] for x in zip(*indices_list): flat_indices.extend(x) flat_mean_scores = [] for x in zip(*mean_scores_list): flat_mean_scores.extend(x) flat_min_scores = [] for x in zip(*min_scores_list): flat_min_scores.extend(x) flat_max_scores = [] for x in zip(*max_scores_list): flat_max_scores.extend(x) flat_indices = np.array(flat_indices) flat_mean_scores = np.array(flat_mean_scores) flat_min_scores = np.array(flat_min_scores) flat_max_scores = np.array(flat_max_scores) # Filter duplicates from distributed processing unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) csv.loc[unique_indices, "luminance mean"] = flat_mean_scores[unique_indices_idx] csv.loc[unique_indices, "luminance min"] = flat_min_scores[unique_indices_idx] csv.loc[unique_indices, "luminance max"] = flat_max_scores[unique_indices_idx] # Drop indices in csv not in unique_indices csv = csv.loc[unique_indices] return csv class VideoDataset(torch.utils.data.Dataset): """Dataset to handle video luminance computation.""" def __init__(self, csv_path, fig_load_dir): self.csv_path = csv_path self.csv = pd.read_csv(csv_path) self.fig_load_dir = fig_load_dir def __getitem__(self, index): """Get video frames and compute luminance for a single sample.""" sample = self.csv.iloc[index] # Load first 3 frames from video clip images_dir = os.path.join(self.fig_load_dir, sample["id"]) images = sorted(glob(f"{images_dir}/img/*.jpg"))[:3] # Transform images to tensors images = torch.stack( [pil_to_tensor(Image.open(img).convert("RGB")) for img in images] ) return {"index": index, "images": images} def __len__(self): return len(self.csv) def parse_args(): """Parse command line arguments for luminance analysis.""" parser = argparse.ArgumentParser() parser.add_argument("--csv_path", type=str, help="Path to the input CSV file") parser.add_argument("--bs", type=int, default=4, help="Batch size") parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") parser.add_argument( "--fig_load_dir", type=str, required=True, help="Directory to load the extracted frames", ) parser.add_argument("--skip_if_existing", action="store_true") return parser.parse_args() def main(): args = parse_args() csv_path = args.csv_path if not os.path.exists(csv_path): print(f"csvdata file '{csv_path}' not found. Exiting.") return output_path = csv_path.replace(".csv", "_lum.csv") if args.skip_if_existing and os.path.exists(output_path): print(f"Output '{output_path}' already exists. 
Exiting.") return # Initialize distributed processing dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) ( torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) if torch.cuda.is_available() else None ) # Setup dataset and distributed dataloader dataset = VideoDataset(csv_path, fig_load_dir=args.fig_load_dir) dataloader = DataLoader( dataset, batch_size=args.bs, num_workers=args.num_workers, sampler=DistributedSampler( dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank() ), ) # Process batches and calculate luminance scores indices_list = [] mean_scores_list = [] max_scores_list = [] min_scores_list = [] device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") for batch in tqdm( dataloader, disable=(dist.get_rank() != 0), position=dist.get_rank() ): indices = batch["index"] images = batch["images"].to(device, non_blocking=True) # [B, N, C, H, W] # Calculate luminance using standard RGB weights R, G, B = images[:, :, 0], images[:, :, 1], images[:, :, 2] luminance = 0.2126 * R + 0.7152 * G + 0.0722 * B scores = luminance.mean(dim=[2, 3]) # Compute statistics across frames mean_scores = scores.mean(dim=1).cpu().numpy() max_scores = scores.max(dim=1)[0].cpu().numpy() min_scores = scores.min(dim=1)[0].cpu().numpy() indices_list.extend(indices.tolist()) mean_scores_list.extend(mean_scores.tolist()) max_scores_list.extend(max_scores.tolist()) min_scores_list.extend(min_scores.tolist()) # Wait for all ranks to finish data processing dist.barrier() # Gather results from all processes and save torch.cuda.empty_cache() gc.collect() gathered_list = [None] * dist.get_world_size() dist.all_gather_object( gathered_list, (indices_list, mean_scores_list, min_scores_list, max_scores_list), ) if dist.get_rank() == 0: csv_new = merge_scores(gathered_list, dataset.csv) csv_new.to_csv(output_path, index=False) print(f"New csv with luminance scores saved to '{output_path}'") if __name__ == "__main__": main() ================================================ FILE: scoring/motion/INSTALL.md ================================================ # Compiling FFmpeg with NVIDIA GPU Acceleration and VMAF on Ubuntu This guide provides a comprehensive walkthrough for compiling FFmpeg from source on an Ubuntu system equipped with an NVIDIA GPU. The resulting build will support NVIDIA's hardware encoding/decoding (NVENC/DEC), NPP filters (NVIDIA Performance Primitives), and CUDA-based VMAF (Video Multi-Method Assessment Fusion) for video quality assessment. ## Environment and Versions Before you begin, ensure your system environment is similar to the configuration below. Version matching is crucial for a successful compilation. The GPU needs to support HEVC; refer to the [NVIDIA NVDEC Support Matrix](https://en.wikipedia.org/wiki/NVIDIA_Video_Coding_Engine#NVDEC). - **GPU**: NVIDIA GeForce RTX 4090 or other compatible models - **OS**: Ubuntu 22.04 - **NVIDIA Driver Version**: A version compatible with CUDA 12.6 - **CUDA Version (from `nvidia-smi`)**: `12.x` - **CUDA Toolkit Version**: `12.6` (This is the version used for compilation) - **Target FFmpeg Version**: `6.1` **Key Tip**: The version of the `NVIDIA Codec Headers` (`ffnvcodec`) must be compatible with the `CUDA Toolkit` version installed on your system and the version of `FFmpeg` you intend to compile. ## Compilation Steps Please follow these steps in order. 
### Step 1: Install System Dependencies Update system packages and install required development tools and libraries: ```bash sudo apt-get update sudo DEBIAN_FRONTEND=noninteractive apt-get install -y \ libopenjp2-7-dev \ ninja-build \ cmake \ git \ python3 \ python3-pip \ nasm \ xxd \ pkg-config \ curl \ unzip \ ca-certificates \ libnuma-dev \ libsm6 \ libxext6 \ libxrender1 \ libgl1 \ vim \ nvidia-cuda-toolkit ``` ### Step 2: Clone Required Repositories ```bash # Create a working directory (custom path allowed) mkdir -p ~/ffmpeg-build && cd ~/ffmpeg-build # Clone nv-codec-headers (NVIDIA codec headers) git clone https://github.com/FFmpeg/nv-codec-headers.git # Clone libvmaf (video quality assessment library) git clone https://github.com/Netflix/vmaf.git cd vmaf && git checkout master # Switch to master branch (modify version if needed) cd .. # Clone FFmpeg source code git clone https://github.com/FFmpeg/FFmpeg.git cd FFmpeg && git checkout master # Switch to master branch (modify version if needed) cd .. ``` ### Step 3: Install nv-codec-headers ```bash cd nv-codec-headers make sudo make install cd .. ``` ### Step 4: Compile and Install libvmaf (with CUDA Support) 1. Install the meson build tool: ```bash python3 -m pip install meson ``` 2. Compile and install libvmaf: ```bash cd vmaf meson libvmaf/build libvmaf \ -Denable_cuda=true \ -Denable_avx512=true \ --buildtype release ninja -vC libvmaf/build sudo ninja -vC libvmaf/build install cd .. ``` 3. Update system library cache: ```bash export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/x86_64-linux-gnu/ sudo ldconfig ``` ### Step 5: Compile and Install FFmpeg (with NVIDIA and libvmaf Support) ```bash cd FFmpeg # Configure compilation options (enable CUDA, NVENC, NVDEC, and libvmaf) ./configure \ --enable-libnpp \ --enable-nonfree \ --enable-nvdec \ --enable-nvenc \ --enable-cuvid \ --enable-cuda \ --enable-cuda-nvcc \ --enable-libvmaf \ --enable-ffnvcodec \ --disable-stripping \ --extra-cflags="-I/usr/local/cuda/include" \ --extra-ldflags="-L/usr/local/cuda/lib64 -L/usr/local/cuda/lib64/stubs/" # Compile (adjust the number after -j based on CPU cores for faster compilation) make -j$(nproc) # Install sudo make install cd .. ``` ### Step 6: Configure Python Environment 1. Upgrade pip and set up links: ```bash sudo ln -sf /usr/bin/python3 /usr/bin/python python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel ``` 2. Install Python dependencies (assuming project code is cloned locally; replace with actual path): ```bash # Navigate to the project root directory cd /path/to/your/project # Install dependencies python3 -m pip --no-cache-dir install -r requirements/requirements.txt python3 -m pip --no-cache-dir install -r requirements/requirements_scoring.txt || true python3 -m pip --no-cache-dir install -r requirements/requirements_annotation.txt || true ``` ### Step 7: Verify Installation 1. Check FFmpeg version and configuration: ```bash ffmpeg -version ffmpeg -encoders | grep nvenc # Verify NVENC support ffmpeg -decoders | grep nvdec # Verify NVDEC support ffmpeg -filters | grep vmaf # Verify libvmaf support ``` 2. If all the above commands output corresponding content correctly, the installation is successful. ## Troubleshooting ### Issue 1: VMAF compilation fails with `vcs_version.h: No such file or directory` - **Cause**: This error typically occurs if you downloaded the VMAF source code as a ZIP archive instead of using `git clone`. The build script relies on the `.git` directory to generate version header files. 
- **Solution**: Always use `git clone` to get the source code. ```bash git clone https://github.com/Netflix/vmaf.git ``` ### Issue 2: FFmpeg `configure` fails with error about Video Codec SDK version being too low - **Error Message**: Something like `ERROR: nvenc requested, but NVIDIA Video Codec SDK 12.1 or later is required.` (The version number may vary). - **Cause**: This means the version of `nv-codec-headers` you checked out is not compatible with your NVIDIA driver, CUDA Toolkit, or the version of FFmpeg you are building. - **Solution**: 1. Carefully re-check your [NVIDIA Driver](https://www.nvidia.com/Download/index.aspx) and [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) versions. 2. Go back to [Step 3: Install NVIDIA Codec Headers](#step-3-install-nvidia-codec-headers) and ensure you `git checkout` the branch that best matches your environment (e.g., `sdk/12.6`). 3. Consult the [Official NVIDIA FFmpeg Guide](https://docs.nvidia.com/video-technologies/video-codec-sdk/ffmpeg-with-nvidia-gpu/index.html) or the `nv-codec-headers` repository to confirm version compatibility. ## References - [VMAF on GitHub](https://github.com/Netflix/vmaf) - [FFmpeg Official Source](https://github.com/FFmpeg/FFmpeg/tree/release/6.1) - [NVIDIA Codec Headers Source](https://github.com/FFmpeg/nv-codec-headers/tree/sdk/12.6) - [Official NVIDIA Guide for Compiling FFmpeg](https://docs.nvidia.com/video-technologies/video-codec-sdk/ffmpeg-with-nvidia-gpu/index.html) ================================================ FILE: scoring/motion/__init__.py ================================================ ================================================ FILE: scoring/motion/inference.py ================================================ """ Motion analysis script for video quality assessment using FFmpeg and VMAF. Calculates motion scores for video clips using hardware acceleration when available. """ import os import argparse import pandas as pd import subprocess from multiprocessing import Manager import queue import concurrent.futures from tqdm import tqdm FFMPEG_PATH = "/usr/local/bin/ffmpeg" def get_ffmpeg_acceleration(): """ Auto detect the best acceleration method. Priority: NVIDIA GPU > CPU. 
""" try: # Get the list of ffmpeg configuration output = subprocess.check_output( [FFMPEG_PATH, "-version"], stderr=subprocess.DEVNULL ).decode("utf-8") if "--enable-cuda-nvcc" in output and "--enable-libvmaf" in output: return "nvidia" else: return "cpu" # Use CPU except Exception as e: print(f"FFmpeg acceleration detection failed: {e}") return "cpu" ACCELERATION_TYPE = get_ffmpeg_acceleration() print(f"FFmpeg acceleration type: {ACCELERATION_TYPE}") def process_single_row(video_path, args, process_id): """Process a single video to generate motion analysis CSV using FFmpeg.""" path = os.path.join( args.temp_save_dir, os.path.basename(video_path).split(".")[0] + ".csv" ) # Build FFmpeg command with appropriate acceleration command = [FFMPEG_PATH] if ACCELERATION_TYPE == "nvidia": command += [ "-hwaccel", "cuda", "-hwaccel_output_format", "cuda", "-hwaccel_device", f"{process_id % args.gpu_num}", ] command += ["-i", f"{video_path}"] if ACCELERATION_TYPE == "nvidia": command += [ "-hwaccel", "cuda", "-hwaccel_output_format", "cuda", "-hwaccel_device", f"{process_id % args.gpu_num}", ] command += ["-i", f"{video_path}"] if ACCELERATION_TYPE == "nvidia": command += [ "-filter_complex", f"[0:v]scale_cuda=format=yuv420p[dis],[1:v]scale_cuda=format=yuv420p[ref],[dis][ref]libvmaf_cuda=log_fmt=csv:log_path={path}", ] else: command += ["-lavfi", f"libvmaf=log_fmt=csv:log_path={path}"] command += ["-f", "null", "-"] try: result = subprocess.run(command, capture_output=True, text=True, check=True) except subprocess.CalledProcessError as e: print(f"Error: {e.stderr}") def calculate_score(row, args): """Calculate motion score for a specific video clip segment.""" csv_path = os.path.join(args.temp_save_dir, f'{row["id_ori"]}.csv') df = pd.read_csv(csv_path) df = df[(df["Frame"] >= row["frame_start"]) & (df["Frame"] <= row["frame_end"])] mean_value = df["integer_motion2"].mean() return mean_value def worker1(task_queue, progress_queue, args, process_id): """Worker function for processing videos in parallel.""" while True: try: video_path = task_queue.get(timeout=1) except queue.Empty: break process_single_row(video_path, args, process_id) progress_queue.put(video_path) task_queue.task_done() def worker2(task_queue, results_queue, args): """Worker function for calculating motion scores in parallel.""" while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break value = calculate_score(row, args) results_queue.put((index, value)) task_queue.task_done() def parse_args(): """Parse command line arguments for motion analysis.""" parser = argparse.ArgumentParser() parser.add_argument("--csv_path", type=str, required=True, help="Path to the CSV file") parser.add_argument( "--temp_save_dir", type=str, required=True, help="Directory to save the temporary files", ) parser.add_argument( "--num_workers", type=int, default=None, help="#workers for concurrent.futures" ) parser.add_argument( "--disable_parallel", action="store_true", help="disable parallel processing" ) parser.add_argument("--gpu_num", type=int, default=1, help="gpu number") parser.add_argument("--skip_if_existing", action="store_true") args = parser.parse_args() return args def main(): args = parse_args() wo_ext, ext = os.path.splitext(args.csv_path) out_path = f"{wo_ext}_motion{ext}" if args.skip_if_existing and os.path.exists(out_path): print(f"Output CSV file '{out_path}' already exists. 
Exit.") exit() df = pd.read_csv(args.csv_path) video_paths = df["video_path"].unique() if args.disable_parallel: # Sequential processing results = [] for video_path in tqdm(video_paths, desc="Processing videos"): result = process_single_row(video_path, args, 0) results.append(result) for index, row in tqdm( df.iterrows(), total=len(df), desc="Calculating scores" ): result = calculate_score(row, args) df.at[index, "motion"] = result else: # Parallel processing if args.num_workers is not None: num_workers = args.num_workers else: num_workers = os.cpu_count() or 1 # First phase: process videos to generate CSV files manager = Manager() task_queue = manager.Queue() progress_queue = manager.Queue() for video_path in video_paths: task_queue.put(video_path) with concurrent.futures.ProcessPoolExecutor( max_workers=num_workers ) as executor: futures = [] for id in range(num_workers): futures.append( executor.submit(worker1, task_queue, progress_queue, args, id) ) processed = 0 total_video_tasks = len(video_paths) with tqdm(total=total_video_tasks, desc="Processing videos") as pbar: while processed < total_video_tasks: try: progress_queue.get(timeout=1) processed += 1 pbar.update(1) except queue.Empty: if all(f.done() for f in futures) and progress_queue.empty(): break for future in futures: future.result() # Second phase: calculate motion scores result_queue = manager.Queue() task_queue = manager.Queue() for index, row in df.iterrows(): task_queue.put((index, row)) with concurrent.futures.ProcessPoolExecutor( max_workers=num_workers ) as executor: futures = [] for _ in range(num_workers): futures.append(executor.submit(worker2, task_queue, result_queue, args)) results = [] processed = 0 total_score_tasks = len(df) with tqdm(total=total_score_tasks, desc="Calculating scores") as pbar: while processed < total_score_tasks: try: results.append(result_queue.get(timeout=1)) processed += 1 pbar.update(1) except queue.Empty: if all(f.done() for f in futures) and result_queue.empty(): break for future in futures: future.result() # Collect and sort results while not result_queue.empty(): results.append(result_queue.get()) results.sort(key=lambda x: x[0]) results = list(map(lambda x: x[1], results)) df["motion score"] = results df.to_csv(out_path, index=False) print(f"New df with motion scores saved to '{out_path}'.") if __name__ == "__main__": main() ================================================ FILE: scoring/ocr/__init__.py ================================================ ================================================ FILE: scoring/ocr/inference.py ================================================ """ OCR analysis script for video frames using PaddleOCR. Calculates text area ratios for video clips using distributed processing. 
""" import os from glob import glob import argparse import pandas as pd from multiprocessing import Manager import queue import concurrent.futures from tqdm import tqdm import cv2 from paddleocr import PaddleOCR def process_single_row(row, args, model): """Process a single row to calculate OCR text area ratio.""" img_dir = os.path.join(args.fig_load_dir, row["id"]) img_list = sorted(glob(f"{img_dir}/img/*.jpg"))[:3] # Load images images = [cv2.imread(img_path) for img_path in img_list] images = [img for img in images if img is not None] if not images: return 0.0 result = model.predict(input=images) area = images[0].shape[0] * images[0].shape[1] # Image area area_list = [] for res in result: total_text_area = 0 # Initialize total text area for rec_box in res["rec_boxes"]: x_min, y_min, x_max, y_max = ( float(rec_box[0]), float(rec_box[1]), float(rec_box[2]), float(rec_box[3]), ) # Extract top-left and bottom-right coordinates text_area = (x_max - x_min) * (y_max - y_min) # Calculate text area total_text_area += text_area ratio = total_text_area / area area_list.append(ratio) return ( max(area_list) if area_list else 0.0 ) # Return max area ratio, 0.0 if no text detected def worker(task_queue, result_queue, args, id): """Worker function for multiprocessing OCR inference.""" gpu_id = id % args.gpu_num os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) # Bind to specific GPU device = "gpu:0" # if torch.cuda.is_available() else "cpu" # Initialize PaddleOCR model with disabled orientation and unwarping features model = PaddleOCR( device=device, use_doc_orientation_classify=False, # Disable document orientation classification use_doc_unwarping=False, # Disable text image correction use_textline_orientation=False, # Disable text line orientation classification ) while True: try: index, row = task_queue.get_nowait() except queue.Empty: break area_list = process_single_row(row, args, model) result_queue.put((index, area_list)) def parse_args(): """Parse command line arguments for OCR inference.""" parser = argparse.ArgumentParser(description="SAM2 Image Predictor") parser.add_argument("--csv_path", type=str, help="Path to the csv file") parser.add_argument( "--fig_load_dir", type=str, default="img", help="Directory containing images" ) parser.add_argument( "--num_workers", type=int, default=16, help="#workers for concurrent.futures" ) parser.add_argument("--gpu_num", type=int, default=1, help="gpu number") parser.add_argument("--skip_if_existing", action="store_true") parser.add_argument( "--disable_parallel", action="store_true", help="Disable parallel processing" ) return parser.parse_args() def main(): args = parse_args() if not os.path.exists(args.csv_path): print(f"csv file '{args.csv_path}' not found. Exit.") return wo_ext, ext = os.path.splitext(args.csv_path) out_path = f"{wo_ext}_ocr{ext}" if args.skip_if_existing and os.path.exists(out_path): print(f"Output csv file '{out_path}' already exists. 
Exit.") exit() df = pd.read_csv(args.csv_path) results = [] if args.disable_parallel: # Sequential processing model = PaddleOCR( device="gpu:0", # if torch.cuda.is_available() else "cpu" use_doc_orientation_classify=False, # Disable document orientation classification use_doc_unwarping=False, # Disable text image correction use_textline_orientation=False, # Disable text line orientation classification ) ocr_scores = [] for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"): score = process_single_row(row, args, model) ocr_scores.append(score) results.append((index, score)) else: # Set up multiprocessing queues manager = Manager() task_queue = manager.Queue() result_queue = manager.Queue() for index, row in df.iterrows(): task_queue.put((index, row)) # Process tasks with multiple workers with concurrent.futures.ProcessPoolExecutor( max_workers=args.num_workers ) as executor: futures = [] for id in range(args.num_workers): futures.append( executor.submit(worker, task_queue, result_queue, args, id) ) processed = 0 total_tasks = len(df) with tqdm(total=total_tasks, desc="Processing rows") as pbar: while processed < total_tasks: try: results.append(result_queue.get(timeout=1)) processed += 1 pbar.update(1) except queue.Empty: if all(f.done() for f in futures) and result_queue.empty(): break for future in futures: future.result() # Collect and sort results while not result_queue.empty(): index, area_list = result_queue.get() results.append((index, area_list)) results.sort(key=lambda x: x[0]) df["ocr score"] = [x[1] for x in results] df.to_csv(out_path, index=False) print(f"New csv (shape={df.shape}) with ocr results saved to '{out_path}'.") if __name__ == "__main__": main() ================================================ FILE: scripts/annotation.sh ================================================ #!/bin/bash CSV=[Replace with the path to the CSV file generated in the scoring step] OUTPUT_DIR=[Replace with the path to your output directory] mkdir -p ${OUTPUT_DIR} CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 GPU_NUM=8 ENHANCED=true # Set to true to enable enhanced instruction generation measure_time() { local step_number=$1 shift local green="\e[32m" local red="\e[31m" local no_color="\e[0m" local yellow="\e[33m" start_time=$(date +%s) echo -e "${green}Step ${step_number} started at: $(date)${no_color}" "$@" end_time=$(date +%s) echo -e "${red}Step ${step_number} finished at: $(date)${no_color}" echo -e "${yellow}Duration: $((end_time - start_time)) seconds${no_color}" echo "---------------------------------------" } # 1. Extract frames measure_time 1 python utils/extract_frames.py \ --csv_path ${CSV} \ --output_dir ${OUTPUT_DIR} \ --num_workers $((GPU_NUM * 2)) \ --target_size "1280*720" \ --backend "opencv" \ --interval 0.2 # 2.1 Depth Estimation with Depth-Anything CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 2.1 torchrun --standalone --nproc_per_node ${GPU_NUM} camera_pose_annotation/depth_estimation/Depth-Anything/inference_batch.py \ --csv_path ${CSV} \ --encoder vitl \ --checkpoints_path checkpoints \ --output_dir ${OUTPUT_DIR} \ --bs 16 \ --num_workers ${GPU_NUM} # 2.2 Depth Estimation with UniDepth CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 2.2 torchrun --standalone --nproc_per_node ${GPU_NUM} camera_pose_annotation/depth_estimation/UniDepth/inference_batch.py \ --csv_path ${CSV} \ --output_dir ${OUTPUT_DIR} \ --checkpoints_path checkpoints \ --bs 32 \ --num_workers ${GPU_NUM} # 3. 
Camera Tracking CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 3 python camera_pose_annotation/camera_tracking/inference_batch.py \ --csv_path ${CSV} \ --dir_path ${OUTPUT_DIR} \ --checkpoints_path checkpoints \ --gpu_id ${CUDA_VISIBLE_DEVICES} \ --num_workers $((GPU_NUM * 2)) # 4.1 CVD Optimization Preprocess CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 4.1 python camera_pose_annotation/cvd_opt/preprocess/inference_batch.py \ --csv_path ${CSV} \ --dir_path ${OUTPUT_DIR} \ --checkpoints_path checkpoints \ --gpu_id ${CUDA_VISIBLE_DEVICES} \ --num_workers $((GPU_NUM * 2)) # 4.2 CVD Optimization CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 4.2 python camera_pose_annotation/cvd_opt/inference_batch.py \ --csv_path ${CSV} \ --dir_path ${OUTPUT_DIR} \ --gpu_id ${CUDA_VISIBLE_DEVICES} \ --num_workers $((GPU_NUM * 2)) # --only_depth # 5. Dynamic Mask Prediction CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 5 python camera_pose_annotation/dynamic_mask/inference_batch.py \ --csv_path ${CSV} \ --dir_path ${OUTPUT_DIR} \ --checkpoints_path checkpoints \ --gpu_num ${GPU_NUM} \ --num_workers $((GPU_NUM * 2)) # 6. Evaluation of the results CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 6 python utils/evaluation.py \ --csv_path ${CSV} \ --dir_path ${OUTPUT_DIR} \ --gpu_num ${GPU_NUM} \ --num_workers $((GPU_NUM * 2)) \ --output_path ${OUTPUT_DIR}/final_results.csv # 7. Get motion instructions if [ "$ENHANCED" = false ] ; then CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 7 python utils/get_instructions.py \ --csv_path ${CSV} \ --dir_path ${OUTPUT_DIR} \ --interval 2 \ --num_workers $((GPU_NUM * 2)) else echo "Enhanced instruction generation is enabled." CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 7 python utils/get_instructions_enhanced.py \ --csv_path ${OUTPUT_DIR}/final_results.csv \ --dir_path ${OUTPUT_DIR} \ --num_workers $((GPU_NUM * 2)) fi # 8. 
Normalize the intrinsics CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 8 python utils/normalize_intrinsics.py \ --csv_path ${CSV} \ --dir_path ${OUTPUT_DIR} \ --num_workers $((GPU_NUM * 2)) # [Optional] Convert the output poses.npy into a c2w/w2c matrix CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 9 python utils/quat_to_mat.py \ --csv_path ${CSV} \ --format c2w \ --dir_path ${OUTPUT_DIR} \ --num_workers $((GPU_NUM * 2)) ================================================ FILE: scripts/caption.sh ================================================ #!/bin/bash CSV=[Replace with the path to the result CSV file generated in the annotation step] SRC_DIR=[Replace with the path to the annotation output directory] OUTPUT_DIR=[Replace with the path to your output directory] mkdir -p ${OUTPUT_DIR} num_workers=8 wait_time=1 # VQA vqa_prompt_file=caption/VQA/prompt.txt vqa_model=gemini-2.0-flash vqa_api_key=[Replace with your api key] vqa_base_domain=https://generativelanguage.googleapis.com/ # LLM llm_prompt_dir=caption/LLM llm_model=qwen3-30b-a3b llm_api_key=[Replace with your api key] llm_base_domain=https://dashscope.aliyuncs.com/compatible-mode/ # Tagging tag_prompt_file=caption/tagging/prompt.txt tag_model=qwen3-30b-a3b tag_api_key=[Replace with your api key] tag_base_domain=https://dashscope.aliyuncs.com/compatible-mode/ measure_time() { local step_number=$1 shift local green="\e[32m" local red="\e[31m" local no_color="\e[0m" local yellow="\e[33m" start_time=$(date +%s) echo -e "${green}Step $step_number started at: $(date)${no_color}" "$@" end_time=$(date +%s) echo -e "${red}Step $step_number finished at: $(date)${no_color}" echo -e "${yellow}Duration: $((end_time - start_time)) seconds${no_color}" echo "---------------------------------------" } # 1. VQA caption measure_time 1 python caption/VQA/inference.py \ --csv_path ${CSV} \ --fig_load_dir ${SRC_DIR} \ --output_dir ${OUTPUT_DIR} \ --prompt_file ${vqa_prompt_file} \ --model ${vqa_model} \ --api_key ${vqa_api_key} \ --base_domain ${vqa_base_domain} \ --num_workers ${num_workers} \ --wait_time ${wait_time} # 2. LLM caption measure_time 2 python caption/LLM/inference.py \ --csv_path $CSV \ --pose_load_dir $SRC_DIR \ --output_dir $OUTPUT_DIR \ --prompt_dir $llm_prompt_dir \ --model $llm_model \ --api_key $llm_api_key \ --num_workers $num_workers \ --base_domain $llm_base_domain \ --wait_time $wait_time # 3. Combine results measure_time 3 python caption/utils/combine.py \ --csv_path $CSV \ --load_dir $OUTPUT_DIR \ --output_dir $OUTPUT_DIR/results \ --num_workers $num_workers # 4. Add tags python caption/tagging/inference.py \ --csv_path $CSV \ --json_load_dir $OUTPUT_DIR/results \ --prompt_file $tag_prompt_file \ --model $tag_model \ --api_key $tag_api_key \ --num_workers $num_workers \ --base_domain $tag_base_domain \ --wait_time $wait_time ================================================ FILE: scripts/docker_prepulls.sh ================================================ #!/usr/bin/env bash # This script pre-pulls and tags GPU-related Docker images from specified registries. set -euo pipefail # Minimal script: pre-pull three images (builder/runtime/buildkit) and tag them to # canonical names so downstream scripts can rely on the expected tags. # You can override these by setting the env vars before running this script. 
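# For example, to pull directly from Docker Hub instead of the mirror (illustrative override, not a required step):
#   BUILDER_IMAGE=docker.io/nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 \
#   RUNTIME_IMAGE=docker.io/nvidia/cuda:12.6.3-runtime-ubuntu22.04 \
#   BUILDKIT_IMAGE=docker.io/moby/buildkit:buildx-stable-1 \
#   bash scripts/docker_prepulls.sh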
BUILDER_IMAGE=${BUILDER_IMAGE:-swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04} RUNTIME_IMAGE=${RUNTIME_IMAGE:-swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/nvidia/cuda:12.6.3-runtime-ubuntu22.04} BUILDKIT_IMAGE=${BUILDKIT_IMAGE:-swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/moby/buildkit:buildx-stable-1} retry_pull() { local img="$1" for i in 1 2 3; do echo "pull attempt $i for ${img}..." if docker pull "${img}"; then echo "pulled ${img}" return 0 fi sleep $((i * 2)) done echo "Failed to pull ${img} after retries" >&2 return 1 } echo "Pre-pulling images..." echo "- builder: ${BUILDER_IMAGE}" echo "- runtime: ${RUNTIME_IMAGE}" echo "- buildkit: ${BUILDKIT_IMAGE}" retry_pull "${BUILDER_IMAGE}" || true retry_pull "${RUNTIME_IMAGE}" || true retry_pull "${BUILDKIT_IMAGE}" || true CANONICAL_BUILDKIT_TAG="moby/buildkit:buildx-stable-1" if docker image inspect "${BUILDKIT_IMAGE}" >/dev/null 2>&1; then echo "Tagging ${BUILDKIT_IMAGE} -> ${CANONICAL_BUILDKIT_TAG} (local only)" docker tag "${BUILDKIT_IMAGE}" "${CANONICAL_BUILDKIT_TAG}" || true fi # Also tag the mirrored CUDA images to the original docker.io names expected by # Dockerfiles and other scripts. This lets downstream tooling refer to # docker.io/nvidia/cuda:12.6.3-... even when images were pulled from a mirror. ORIG_BUILDER_TAG="docker.io/nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04" ORIG_RUNTIME_TAG="docker.io/nvidia/cuda:12.6.3-runtime-ubuntu22.04" if docker image inspect "${BUILDER_IMAGE}" >/dev/null 2>&1; then echo "Tagging ${BUILDER_IMAGE} -> ${ORIG_BUILDER_TAG}" docker tag "${BUILDER_IMAGE}" "${ORIG_BUILDER_TAG}" || true fi if docker image inspect "${RUNTIME_IMAGE}" >/dev/null 2>&1; then echo "Tagging ${RUNTIME_IMAGE} -> ${ORIG_RUNTIME_TAG}" docker tag "${RUNTIME_IMAGE}" "${ORIG_RUNTIME_TAG}" || true fi echo "Done pulling/tagging images." echo "You can now run downstream build steps that expect these images to exist locally." ================================================ FILE: scripts/download_checkpoints.sh ================================================ mkdir -p ./checkpoints/ cd ./checkpoints/ # aesthetic wget https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac+logos+ava1-l14-linearMSE.pth -O aesthetic.pth # megasam wget https://github.com/mega-sam/mega-sam/raw/main/checkpoints/megasam_final.pth -O megasam_final.pth # raft gdown -c https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM -O raft-things.pth # depth anything huggingface-cli download --resume-download depth-anything/Depth-Anything-V2-Large --local-dir Depth-Anything # unidepth huggingface-cli download --resume-download lpiccinelli/unidepth-v2-vitl14 --local-dir UniDepth # sam huggingface-cli download --resume-download facebook/sam2.1-hiera-large --local-dir SAM2 ================================================ FILE: scripts/scoring.sh ================================================ #!/bin/bash VIDEO_DIR=[Replace with the path to your video files] OUTPUT_DIR=[Replace with the path to your output directory] mkdir -p ${OUTPUT_DIR} # Choose whether to cut the clips precisely based on the timestamps or to cut them fast based on keyframes. # Precise cutting is slower but frame-accurate; fast cutting is quicker, but clip boundaries may snap to nearby keyframes. 
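# Rough illustration of the two strategies (hypothetical commands, not executed by this script):
#   precise (utils/cut.py):   re-encode so the first frame is exact, e.g.
#     ffmpeg -ss <coarse_seek> -i in.mp4 -ss <fine_seek> -t <duration> -c:v libx264 -crf 18 clip.mp4
#   fast (utils/cut_fast.py): stream copy, so boundaries align to keyframes, e.g.
#     ffmpeg -ss <start> -t <duration> -i in.mp4 -c:v copy -an clip.mp4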
FAST_CUT=False GPU_NUM=8 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NUM_WORKERS=$((GPU_NUM * 2)) ROOT_CLIPS=${OUTPUT_DIR}/clip ROOT_META=${OUTPUT_DIR}/meta ROOT_FIG=${OUTPUT_DIR}/fig ROOT_TEMP=${OUTPUT_DIR}/temp for dir in ${ROOT_CLIPS} ${ROOT_META} ${ROOT_FIG} ${ROOT_TEMP}; do if [ ! -d ${dir} ]; then mkdir -p ${dir} fi done measure_time() { local step_number=$1 shift local green="\e[32m" local red="\e[31m" local no_color="\e[0m" local yellow="\e[33m" start_time=$(date +%s) echo -e "${green}Step ${step_number} started at: $(date)${no_color}" "$@" end_time=$(date +%s) echo -e "${red}Step ${step_number} finished at: $(date)${no_color}" echo -e "${yellow}Duration: $((end_time - start_time)) seconds${no_color}" echo "---------------------------------------" } # 1.1 Create a meta file from a video folder. This should output ${ROOT_META}/meta.csv CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 1.1 python utils/convert.py \ --video_dir ${VIDEO_DIR} \ --output ${ROOT_META}/meta.csv # 1.2 Get video information and remove broken videos. This should output ${ROOT_META}/meta_info.csv CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 1.2 python utils/get_info.py \ --csv_path ${ROOT_META}/meta.csv \ --csv_save_path ${ROOT_META}/meta_info.csv \ --backend "opencv" \ --num_workers 16 # 2.1 Detect scenes. This should output ${ROOT_META}/meta_info_timestamp.csv # Also, you can set the params like "--start_remove_sec 0.5 --end_remove_sec 0.5" CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 2.1 python utils/scene_detect.py \ --csv_path ${ROOT_META}/meta_info.csv \ --backend "opencv" \ --num_workers 64 \ --frame_skip 2 \ --start_remove_sec 0.3 \ --end_remove_sec 0.3 \ --min_seconds 3 \ --max_seconds 15 # 2.2 Get clips. This should output ${ROOT_META}/clips_info.csv CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 2.2 python utils/get_clip.py \ --csv_path ${ROOT_META}/meta_info_timestamp.csv \ --csv_save_dir ${ROOT_META} \ --num_workers $((GPU_NUM * 4)) # 2.3 Extract frames for scoring. CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 2.3 python utils/extract_frames.py \ --csv_path ${ROOT_META}/clips_info.csv \ --output_dir ${ROOT_FIG} \ --num_workers 64 \ --target_size "640*360" \ --backend "opencv" # 3.1 Predict aesthetic scores. This should output ${ROOT_META}/clips_info_aes.csv CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 3.1 torchrun --nproc_per_node ${GPU_NUM} scoring/aesthetic/inference.py \ --csv_path ${ROOT_META}/clips_info.csv \ --bs 16 \ --num_workers ${NUM_WORKERS} \ --fig_load_dir ${ROOT_FIG} # 3.2 Predict luminance scores. This should output ${ROOT_META}/clips_info_lum.csv CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 3.2 torchrun --nproc_per_node ${GPU_NUM} scoring/luminance/inference.py \ --csv_path ${ROOT_META}/clips_info.csv \ --bs 16 \ --num_workers ${NUM_WORKERS} \ --fig_load_dir ${ROOT_FIG} # 3.3 Predict motion scores. This should output ${ROOT_META}/clips_info_motion.csv CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 3.3 python scoring/motion/inference.py \ --csv_path ${ROOT_META}/clips_info.csv \ --temp_save_dir ${ROOT_TEMP} \ --num_workers $((GPU_NUM * 4)) \ --gpu_num ${GPU_NUM} # 3.4 Predict OCR text-area scores using PaddleOCR. This should output ${ROOT_META}/clips_info_ocr.csv CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 3.4 python scoring/ocr/inference.py \ --csv_path ${ROOT_META}/clips_info.csv \ --fig_load_dir ${ROOT_FIG} \ --num_workers $((GPU_NUM * 4)) \ --gpu_num ${GPU_NUM} # 4 Merge all the scores. This should output ${ROOT_META}/clips_scores.csv CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 4 python utils/merge_tables.py \ --csv_dir ${ROOT_META} \ --output ${ROOT_META}/clips_scores.csv # 5 Filter the clips. CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 5 python utils/filter.py \ --csv_path ${ROOT_META}/clips_scores.csv \ --csv_save_path ${ROOT_META}/filtered_clips.csv \ --aes_min 4 \ --lum_min 20 \ --lum_max 140 \ --motion_min 2 \ --motion_max 14 \ --ocr_max 0.3 # 6 Cut the clips. if [ "$FAST_CUT" = False ]; then echo "Using precise cutting based on timestamps." CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 6 python utils/cut.py \ --csv_path ${ROOT_META}/filtered_clips.csv \ --csv_save_path ${OUTPUT_DIR}/results.csv \ --video_save_dir ${ROOT_CLIPS} \ --num_workers $((GPU_NUM * 4)) \ --gpu_num $GPU_NUM \ # --keep_audio else echo "Using fast cutting based on keyframes." CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 6 python utils/cut_fast.py \ --csv_path ${ROOT_META}/filtered_clips.csv \ --csv_save_path ${OUTPUT_DIR}/results.csv \ --video_save_dir ${ROOT_CLIPS} \ --num_workers $((GPU_NUM * 4)) \ # --keep_audio fi ================================================ FILE: utils/README.md ================================================ # Utils - [`convert.py`](convert.py): collect the paths of all videos in a directory into a CSV metadata file. - [`cut.py`](cut.py): cut videos into clips. - [`download_SpatialVID.py`](download_SpatialVID.py): download the SpatialVID dataset. - [`download_YouTube.py`](download_YouTube.py): download videos from YouTube. - [`evaluation.py`](evaluation.py): evaluate the reconstructed camera trajectories (anomaly detection and motion statistics). - [`expand_npz.py`](expand_npz.py): expand dynamic masks compressed in an npz file. - [`extract_frames.py`](extract_frames.py): extract frames from videos. - [`filter.py`](filter.py): filter video clips based on their scores. - [`get_clip.py`](get_clip.py): get the clips separated from each video. - [`get_info.py`](get_info.py): get video information, such as duration and resolution. - [`get_instructions.py`](get_instructions.py): get motion instructions from camera poses. - [`get_instructions_enhanced.py`](get_instructions_enhanced.py): an enhanced version to get more detailed and accurate motion instructions from camera poses. - [`merge_tables.py`](merge_tables.py): merge multiple csv tables into one. - [`normalize_intrinsics.py`](normalize_intrinsics.py): normalize camera intrinsics. - [`pack_clip_assets.py`](pack_clip_assets.py): pack all the output files into an npz file for visualization. - [`quat_to_mat.py`](quat_to_mat.py): convert camera parameters to camera-to-world or world-to-camera matrices (see the sketch below). - [`read_video.py`](read_video.py): read videos using opencv or av. - [`scene_detect.py`](scene_detect.py): separate videos into clips. 
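As a rough illustration of the kind of conversion `quat_to_mat.py` performs, the sketch below turns one pose row into a 4x4 camera-to-world matrix. It assumes the `[tx, ty, tz, qx, qy, qz, qw]` row layout implied by `evaluation.py` (position first, quaternion in x-y-z-w order) and is only a sketch, not the utility's actual implementation.

```python
import numpy as np
from scipy.spatial.transform import Rotation

def pose_row_to_c2w(row: np.ndarray) -> np.ndarray:
    """Build a 4x4 camera-to-world matrix from an assumed [tx, ty, tz, qx, qy, qz, qw] row."""
    t, quat = row[:3], row[3:]
    c2w = np.eye(4)
    c2w[:3, :3] = Rotation.from_quat(quat).as_matrix()  # scipy expects [x, y, z, w] order
    c2w[:3, 3] = t
    return c2w

# Usage sketch: poses.npy is assumed to hold one pose row per frame
# poses = np.load("reconstructions/poses.npy")
# c2w_first = pose_row_to_c2w(poses[0])
```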
================================================ FILE: utils/__init__.py ================================================ ================================================ FILE: utils/convert.py ================================================ """ Video file conversion utility for the SpatialVID project. This module provides functionality to scan directories for video files, process them, and generate CSV metadata files containing video information. """ import argparse import os import time import pandas as pd # Supported video file extensions VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv", ".m2ts", ".webm") def scan_recursively(root): """ Recursively scan a directory tree and yield all entries. """ num = 0 for entry in os.scandir(root): if entry.is_file(): yield entry elif entry.is_dir(): num += 1 if num % 100 == 0: print(f"Scanned {num} directories.") yield from scan_recursively(entry.path) def get_filelist(file_path, exts=None): """ Get a list of files from a directory tree, optionally filtered by extensions. """ filelist = [] time_start = time.time() # Use recursive scanning to find all files obj = scan_recursively(file_path) for entry in obj: if entry.is_file(): ext = os.path.splitext(entry.name)[-1].lower() if exts is None or ext in exts: filelist.append(entry.path) time_end = time.time() print(f"Scanned {len(filelist)} files in {time_end - time_start:.2f} seconds.") return filelist def split_by_capital(name): """ Split a camelCase or PascalCase string by capital letters. """ new_name = "" for i in range(len(name)): if name[i].isupper() and i != 0: new_name += " " new_name += name[i] return new_name def process_general_videos(root, output): """ Process video files in a directory and generate a CSV metadata file. """ # Expand user path (e.g., ~ to home directory) root = os.path.expanduser(root) if not os.path.exists(root): return # Get list of video files with supported extensions path_list = get_filelist(root, VID_EXTENSIONS) # Note: In some cases (like realestate dataset), you might want to use: # path_list = get_filelist(root) # without extension filtering path_list = list(set(path_list)) # Remove duplicate entries # Extract filename without extension as ID fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list] # Get relative paths from root directory relpath_list = [os.path.relpath(x, root) for x in path_list] # Create DataFrame with video metadata df = pd.DataFrame(dict(video_path=path_list, id=fname_list, relpath=relpath_list)) # Ensure output directory exists os.makedirs(os.path.dirname(output), exist_ok=True) df.to_csv(output, index=False) print(f"Saved {len(df)} samples to {output}.") if __name__ == "__main__": # Set up command line argument parser parser = argparse.ArgumentParser( description="Convert video directory structure to CSV metadata file" ) parser.add_argument("--video_dir", type=str, help="Root directory containing video files") parser.add_argument("--split", type=str, default="train", help="Dataset split name") parser.add_argument("--info", type=str, default=None, help="Additional info file") parser.add_argument( "--output", type=str, default=None, required=True, help="Output CSV file path" ) args = parser.parse_args() # Process videos and generate metadata CSV process_general_videos(args.video_dir, args.output) ================================================ FILE: utils/cut.py ================================================ """ Precise frame-level video cutting tool Strategy: Two-phase seek + forced keyframe alignment output """ import 
argparse import os import concurrent.futures from functools import partial import pandas as pd import subprocess from scenedetect import FrameTimecode from tqdm import tqdm FFMPEG_PATH = "/usr/local/bin/ffmpeg" def get_ffmpeg_acceleration(): try: output = subprocess.check_output( [FFMPEG_PATH, "-encoders"], stderr=subprocess.DEVNULL ).decode("utf-8") if "hevc_nvenc" in output: return "nvidia" return "cpu" except Exception as e: print(f"FFmpeg acceleration detection failed: {e}") return "cpu" ACCELERATION_TYPE = get_ffmpeg_acceleration() print(f"FFmpeg acceleration type: {ACCELERATION_TYPE}") # ════════════════════════════════════════════════════════════ # Core Utility Functions # ════════════════════════════════════════════════════════════ def seconds_to_timecode(seconds: float) -> str: """ Convert seconds to FFmpeg precise timecode string. Keep enough decimal places to ensure frame accuracy. Example: 1.033333 -> "0:00:01.033333" """ hours = int(seconds // 3600) minutes = int((seconds % 3600) // 60) secs = seconds % 60 # Keep 6 decimal places (microsecond-level precision) return f"{hours}:{minutes:02d}:{secs:09.6f}" def build_precise_cut_cmd( video_path: str, start_sec: float, end_sec: float, save_path: str, args, process_id: int, shorter_size: int | None, ) -> list[str]: """ Build frame-precise FFmpeg cut command. Strategy: Two-phase seek ┌──────────────────────────────────────────────────────────┐ │ -ss (pre, coarse seek) │ │ -> Jump to nearest keyframe before start_sec │ │ -> Avoid decoding from file start (speed optimize) │ │ │ │ -i input │ │ │ │ -ss (post, fine seek) │ │ -> Decode from coarse point to exact start_sec │ │ -> value = start_sec - coarse_seek (always positive) │ │ │ │ -t duration │ │ -> Exact duration │ │ │ │ Force re-encode (cannot use -c copy, otherwise │ │ start frame won't be precise) │ └──────────────────────────────────────────────────────────┘ """ duration = end_sec - start_sec if duration <= 0: raise ValueError(f"Invalid duration {duration:.4f}s (start={start_sec}, end={end_sec})") # ==== Phase 1: Coarse seek (pre seek) ==== # Safety margin: ensure coarse point is before start_sec keyframe # Too little -> may land after start_sec (seek ineffective) # Too much -> decode more frames (slightly slower) # Experience: max(GOP_size, 5s) covers most videos GOP_SAFETY_MARGIN = 5.0 coarse_seek = max(0.0, start_sec - GOP_SAFETY_MARGIN) # Offset for post precise seek = target time - coarse time fine_seek = start_sec - coarse_seek cmd = [FFMPEG_PATH, "-nostdin", "-y"] # ==== GPU hardware acceleration (decode phase) ==== if ACCELERATION_TYPE == "nvidia": cmd += [ "-hwaccel", "cuda", "-hwaccel_output_format", "cuda", "-hwaccel_device", str(process_id % args.gpu_num), ] # ==== Phase 1: Coarse seek (pre, fast jump to GOP boundary) ==== cmd += ["-ss", seconds_to_timecode(coarse_seek)] # ==== Input file ==== cmd += ["-i", video_path] # ==== Phase 2: Precise seek (post, decode from GOP boundary to exact frame) ==== # Only need post seek when fine_seek > 0 # When coarse_seek == 0, fine_seek == start_sec, still correct if fine_seek > 0.001: # Ignore errors less than 1ms cmd += ["-ss", seconds_to_timecode(fine_seek)] # ==== Exact duration ==== cmd += ["-t", seconds_to_timecode(duration)] # ==== Video filters (scale + fps) ==== filters = _build_video_filters(shorter_size, args, ACCELERATION_TYPE) if filters: cmd += ["-vf", ",".join(filters)] # ==== Encoder (must re-encode to ensure frame precision) ==== cmd += _build_encoder_args(ACCELERATION_TYPE) # ==== Frame rate ==== if 
args.target_fps is not None: cmd += ["-r", str(args.target_fps)] # ==== Audio ==== if args.keep_audio: cmd += ["-map", "0:v", "-map", "0:a?", "-c:a", "aac", "-b:a", "128k"] else: cmd += ["-map", "0:v", "-an"] # ==== Output: force keyframe at first frame for easy concatenation/playback ==== cmd += [ "-force_key_frames", "expr:gte(t,0)", # Force keyframe at second 0 save_path, ] return cmd def _build_video_filters(shorter_size, args, accel_type) -> list[str]: """Build video filter list""" filters = [] if shorter_size is not None: if accel_type == "nvidia": # CUDA scale filter scale = ( f"scale_cuda=" f"'if(gt(iw,ih),-2,{shorter_size})':" f"'if(gt(iw,ih),{shorter_size},-2)'" ) else: # Software scale: lanczos best quality, bicubic next scale = ( f"scale=" f"'if(gt(iw,ih),-2,{shorter_size})':" f"'if(gt(iw,ih),{shorter_size},-2)'" f":flags=lanczos" ) filters.append(scale) if args.target_fps is not None: # fps filter more accurate than -r parameter (-r sometimes drops frames) filters.append(f"fps={args.target_fps}") return filters def _build_encoder_args(accel_type) -> list[str]: """Build encoder arguments""" if accel_type == "nvidia": return [ "-c:v", "hevc_nvenc", "-preset", "p4", # p4=quality/speed balance, p7=slowest best "-rc", "vbr", "-cq", "24", # Quality factor, smaller is better (like CRF) "-b:v", "0", # No bitrate limit in VBR mode ] else: return [ "-c:v", "libx264", "-preset", "fast", # fast is best speed/quality for precise cutting "-crf", "18", # High quality (0=lossless, 23=default, 18=visually lossless) "-pix_fmt", "yuv420p", # Most compatible pixel format ] # ════════════════════════════════════════════════════════════ # Single Row Processing (maintains compatibility with original interface) # ════════════════════════════════════════════════════════════ def process_single_row(row, args, process_id): """ Precise frame-level cutting of a single segment. 
Returns: (row_values_list, valid, error_message) """ video_path = row["video_path"] save_dir = args.video_save_dir # # ==== Scale size calculation ==== shorter_size = args.shorter_size if (shorter_size is not None) and ("height" in row) and ("width" in row): min_size = min(row["height"], row["width"]) if min_size <= shorter_size: shorter_size = None # Already small enough, skip scaling (no upsample) # ==== Timestamp parsing ==== try: seg_start = FrameTimecode(timecode=row["timestamp_start"], fps=row["fps"]) seg_end = FrameTimecode(timecode=row["timestamp_end"], fps=row["fps"]) except Exception as e: error_msg = f"Invalid timestamp for id={row.get('id', '?')}: {e}" print(error_msg) return row.values.tolist(), False, error_msg start_sec = seg_start.get_seconds() end_sec = seg_end.get_seconds() duration = end_sec - start_sec if duration <= 0: error_msg = ( f"Invalid duration {duration:.4f}s for id={row.get('id','?')} " f"(start={start_sec:.4f}, end={end_sec:.4f})" ) print(error_msg) return row.values.tolist(), False, error_msg clip_id = row["id"] save_path = os.path.join(save_dir, f"{clip_id}.mp4") # ==== Skip if already exists ==== if os.path.exists(save_path) and os.path.getsize(save_path) > 0: row = row.copy() row["video_path"] = save_path return row.values.tolist(), True, "" # ==== Source file check ==== if not os.path.exists(video_path): error_msg = f"Source video not found: {video_path} (id={clip_id})" print(error_msg) return row.values.tolist(), False, error_msg # ==== Build precise cut command ==== try: cmd = build_precise_cut_cmd( video_path = video_path, start_sec = start_sec, end_sec = end_sec, save_path = save_path, args = args, process_id = process_id, shorter_size = shorter_size, ) except ValueError as e: error_msg = f"Command build failed for id={clip_id}: {e}" print(error_msg) return row.values.tolist(), False, error_msg # ==== Execute FFmpeg ==== try: subprocess.run(cmd, check=True, stderr=subprocess.PIPE) except subprocess.CalledProcessError as e: stderr_text = e.stderr.decode("utf-8", errors="replace") if e.stderr else str(e) error_msg = f"FFmpeg failed for id={clip_id}:\n{stderr_text}" print(error_msg) _cleanup(save_path) return row.values.tolist(), False, error_msg except Exception as e: error_msg = f"Unexpected error for id={clip_id}: {e}" print(error_msg) _cleanup(save_path) return row.values.tolist(), False, error_msg # ==== Basic integrity check ==== if not os.path.exists(save_path) or os.path.getsize(save_path) == 0: _cleanup(save_path) error_msg = f"FFmpeg produced empty/missing output for id={clip_id}" print(error_msg) return row.values.tolist(), False, error_msg row = row.copy() row["video_path"] = save_path return row.values.tolist(), True, "" def _cleanup(path: str): """Safely delete file""" try: if os.path.exists(path): os.remove(path) except OSError: pass # ════════════════════════════════════════════════════════════ # Argument Parsing # ════════════════════════════════════════════════════════════ def parse_args(): parser = argparse.ArgumentParser( description="Precise frame-level video cutting tool", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) # ==== Input/Output ==== parser.add_argument("--csv_path", type=str, required=True, help="Input CSV file path") parser.add_argument("--csv_save_path", type=str, required=True, help="Output CSV file path (success records)") parser.add_argument("--video_save_dir", type=str, required=True, help="Directory to save cut segments") # ==== Video parameters ==== parser.add_argument("--target_fps", type=int, 
default=None, help="Target frame rate (None=keep source frame rate)") parser.add_argument("--shorter_size", type=int, default=None, help="Short edge target size (maintain aspect ratio, no upsample)") parser.add_argument("--keep_audio", action="store_true", help="Keep audio track (default: discard)") # ==== Parallel control ==== parser.add_argument("--num_workers", type=int, default=None, help="Number of parallel workers (None=auto=CPU cores)") parser.add_argument("--disable_parallel", action="store_true", help="Disable parallel processing (for debugging)") parser.add_argument("--gpu_num", type=int, default=1, help="Number of available GPUs") # ==== Result handling ==== parser.add_argument("--drop_invalid_timestamps", action="store_true", help="Filter invalid timestamps and save corrected CSV") return parser.parse_args() # ════════════════════════════════════════════════════════════ # Parallel Worker # ════════════════════════════════════════════════════════════ def _worker_fn(task: tuple, args, process_id: int) -> tuple: """ Top-level worker function for ProcessPoolExecutor (must be serializable). Args: task: (index, row_dict) <- Use dict instead of Series to avoid serialization issues Returns: (index, row_values, valid, error_msg) """ index, row_dict = task # Restore dict to pandas Series (process_single_row depends on Series interface) row = pd.Series(row_dict) return (index,) + tuple(process_single_row(row, args, process_id)[0:3]) # Note: process_single_row returns (row_values, valid, error_msg) # Packed here as (index, row_values, valid, error_msg) # ════════════════════════════════════════════════════════════ # Result Saving # ════════════════════════════════════════════════════════════ def save_results(all_results: list, csv: pd.DataFrame, args): """ Save processing results to success/failure CSVs separately. 
Success CSV: Remove timestamp helper columns, update video_path to cut path Failure CSV: Keep all original columns, add error column """ columns = csv.columns.tolist() success_rows, failed_rows, failed_errors = [], [], [] for index, row_values, valid, error_msg in all_results: if valid: success_rows.append(row_values) else: failed_rows.append(row_values) failed_errors.append(error_msg) # ==== Save success records ==== if success_rows: success_df = pd.DataFrame(success_rows, columns=columns) # Remove cutting process helper columns (not needed by downstream) drop_cols = [ c for c in ["timestamp_start", "timestamp_end", "frame_start", "frame_end"] if c in success_df.columns ] if drop_cols: success_df = success_df.drop(columns=drop_cols) success_df.to_csv(args.csv_save_path, index=False) print(f"\n[OK] Success: {len(success_df)} records -> {args.csv_save_path}") else: print("\n[X] No success records") # ==== Save failure records ==== if failed_rows: base, ext = os.path.splitext(args.csv_save_path) failed_csv_path = f"{base}_failed{ext}" failed_df = pd.DataFrame(failed_rows, columns=columns) failed_df["error"] = failed_errors failed_df.to_csv(failed_csv_path, index=False) print(f"[X] Failed: {len(failed_df)} records -> {failed_csv_path}") # ==== Save corrected timestamps (optional) ==== if args.drop_invalid_timestamps and failed_rows: valid_indices = [r[0] for r in all_results if r[2]] filtered_csv = csv.iloc[valid_indices] assert args.csv_path.endswith("timestamp.csv"), \ "--drop_invalid_timestamps only supports *timestamp.csv files" corrected_path = args.csv_path.replace("timestamp.csv", "correct_timestamp.csv") filtered_csv.to_csv(corrected_path, index=False) print(f"[OK] Corrected timestamps -> {corrected_path}") # ════════════════════════════════════════════════════════════ # Main Function # ════════════════════════════════════════════════════════════ def main(): args = parse_args() # ==== Pre-check ==== if not os.path.exists(args.csv_path): print(f"[ERROR] CSV file does not exist: {args.csv_path}") return os.makedirs(args.video_save_dir, exist_ok=True) csv = pd.read_csv(args.csv_path) total = len(csv) print(f"Total {total} records to process") all_results = [] # ==== Serial mode ==== if args.disable_parallel: for index, row in tqdm(csv.iterrows(), total=total, desc="Cutting progress"): row_values, valid, error_msg = process_single_row(row, args, process_id=0) all_results.append((index, row_values, valid, error_msg)) # ==== Parallel mode ==== else: num_workers = args.num_workers or (os.cpu_count() or 1) num_workers = min(num_workers, total) # worker count not exceeding task count # Convert row to dict to avoid pandas Series serialization issues tasks = [ (index, row.to_dict()) for index, row in csv.iterrows() ] with concurrent.futures.ProcessPoolExecutor( max_workers=num_workers ) as executor: # Use enumerate to round-robin process_id (GPU rotation) futures = { executor.submit( _worker_fn, task, args, task_idx % max(args.gpu_num, 1), # GPU rotation ): task_idx for task_idx, task in enumerate(tasks) } with tqdm(total=total, desc="Cutting progress") as pbar: for future in concurrent.futures.as_completed(futures): try: result = future.result() # (index, row_values, valid, error_msg) all_results.append(result) except Exception as e: task_idx = futures[future] index, _ = tasks[task_idx] row_values = csv.iloc[index].values.tolist() all_results.append((index, row_values, False, str(e))) print(f"\n[ERROR] Worker exception (task_idx={task_idx}): {e}") finally: pbar.update(1) # ==== Sort by original 
order ==== all_results.sort(key=lambda x: x[0]) # ==== Statistics summary ==== success_count = sum(1 for r in all_results if r[2]) failed_count = total - success_count print(f"\n{'='*50}") print(f"Processing complete: Total={total}, Success={success_count}, Failed={failed_count}") print(f"{'='*50}") # ==== Save results ==== save_results(all_results, csv, args) if __name__ == "__main__": main() ================================================ FILE: utils/cut_fast.py ================================================ """ High-speed video cutting utility using FFmpeg stream copy. Features: - No re-encoding: uses `-c copy` - Optional audio: use --keep_audio to retain audio tracks - Group tasks by source video_path for better efficiency - Parallel processing by video group - Per-clip progress bar - Save successful and failed CSVs Notes: - This method is very fast, but not always frame-accurate. - Clip boundaries may align to nearby keyframes depending on source encoding. """ import argparse import os import queue import subprocess import concurrent.futures from multiprocessing import Manager import pandas as pd from scenedetect import FrameTimecode from tqdm import tqdm FFMPEG_PATH = "/usr/local/bin/ffmpeg" def process_single_row(row, save_dir, keep_audio=False): """ Cut one clip from source video using ffmpeg stream copy. Args: row: DataFrame row with clip metadata save_dir: directory to save output clips keep_audio: if True, copy audio streams; if False, drop audio Returns: (row_values_list, valid, error_message) """ video_path = row["video_path"] sample_id = row["id"] save_path = os.path.join(save_dir, f"{sample_id}.mp4") # Already exists -> treat as success if os.path.exists(save_path) and os.path.getsize(save_path) > 0: row = row.copy() row["video_path"] = save_path return row.values.tolist(), True, "" if not os.path.exists(video_path): error_msg = f"Source video not found: {video_path} (id={sample_id})" return row.values.tolist(), False, error_msg # Parse timestamps try: fps = row["fps"] seg_start = FrameTimecode(timecode=row["timestamp_start"], fps=fps) seg_end = FrameTimecode(timecode=row["timestamp_end"], fps=fps) start_sec = float(seg_start.get_seconds()) end_sec = float(seg_end.get_seconds()) duration = end_sec - start_sec if duration <= 0: error_msg = f"Non-positive duration for id={sample_id}: {duration}" return row.values.tolist(), False, error_msg except Exception as e: error_msg = f"Invalid timestamp for id={sample_id}: {e}" return row.values.tolist(), False, error_msg try: # Build stream mapping and audio arguments based on keep_audio flag. # '0:a?' uses '?' so FFmpeg silently skips if no audio track exists. if keep_audio: map_args = ["-map", "0:v:0", "-map", "0:a?"] audio_args = ["-c:a", "copy"] else: map_args = ["-map", "0:v:0"] audio_args = ["-an"] # Fast seek + stream copy; explicitly specify video codec to avoid ambiguity. 
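# With -c copy, the input-side -ss snaps to the nearest preceding keyframe, so the clip
    # may start slightly before start_sec; -avoid_negative_ts make_zero rebases timestamps
    # so the copied segment starts at t=0.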
cmd = [ FFMPEG_PATH, "-nostdin", "-y", "-ss", str(start_sec), "-t", str(duration), "-i", video_path, *map_args, *audio_args, "-c:v", "copy", "-avoid_negative_ts", "make_zero", save_path, ] subprocess.run( cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, ) # Verify output exists and non-empty if not os.path.exists(save_path) or os.path.getsize(save_path) == 0: if os.path.exists(save_path): os.remove(save_path) error_msg = f"FFmpeg produced empty/missing output for id={sample_id}" return row.values.tolist(), False, error_msg row = row.copy() row["video_path"] = save_path return row.values.tolist(), True, "" except subprocess.CalledProcessError as e: stderr_text = e.stderr.decode("utf-8", errors="ignore") if e.stderr else str(e) error_msg = f"FFmpeg failed for id={sample_id}: {stderr_text}" if os.path.exists(save_path): os.remove(save_path) return row.values.tolist(), False, error_msg except Exception as e: error_msg = f"Unexpected error for id={sample_id}: {e}" if os.path.exists(save_path): os.remove(save_path) return row.values.tolist(), False, error_msg def process_video_group(group_df, save_dir, keep_audio=False): """ Process all clips from the same source video. Args: group_df: DataFrame containing rows from one source video_path save_dir: output clip directory keep_audio: passed through to process_single_row Returns: list of tuples: (index, row_values, valid, error_msg) """ results = [] # Sort by start timestamp to make access pattern a bit more sequential if "timestamp_start" in group_df.columns: group_df = group_df.sort_values(by="timestamp_start") for index, row in group_df.iterrows(): row_values, valid, error_msg = process_single_row( row, save_dir, keep_audio=keep_audio ) results.append((index, row_values, valid, error_msg)) return results def worker(task_queue, results_queue, video_save_dir, keep_audio=False): """ Worker that processes one video group at a time. """ while True: try: video_path, group_df = task_queue.get(timeout=1) except queue.Empty: break try: group_results = process_video_group( group_df, video_save_dir, keep_audio=keep_audio ) for item in group_results: results_queue.put(item) finally: task_queue.task_done() def parse_args(): parser = argparse.ArgumentParser( description="Fast video cutting utility using FFmpeg stream copy" ) parser.add_argument("--csv_path", type=str, required=True, help="Input CSV path") parser.add_argument( "--csv_save_path", type=str, required=True, help="Output CSV path" ) parser.add_argument( "--video_save_dir", type=str, required=True, help="Directory to save clips" ) parser.add_argument( "--num_workers", type=int, default=None, help="Number of parallel workers (defaults to CPU count)", ) parser.add_argument( "--disable_parallel", action="store_true", help="Disable parallel processing", ) parser.add_argument( "--drop_invalid_timestamps", action="store_true", help="Drop invalid timestamp rows and save corrected CSV", ) parser.add_argument( "--keep_audio", action="store_true", help="Retain audio tracks in output clips (dropped by default)", ) return parser.parse_args() def main(): args = parse_args() if not os.path.exists(args.csv_path): print(f"csv file '{args.csv_path}' not found. Exit.") return os.makedirs(args.video_save_dir, exist_ok=True) csv = pd.read_csv(args.csv_path) if len(csv) == 0: print("Input CSV is empty. 
Exit.") return required_cols = ["id", "video_path", "timestamp_start", "timestamp_end", "fps"] missing_cols = [c for c in required_cols if c not in csv.columns] if missing_cols: raise ValueError(f"Missing required columns: {missing_cols}") results = [] # Group by source video grouped_items = list(csv.groupby("video_path", sort=False)) total_tasks = len(csv) if args.disable_parallel: success_cnt = 0 fail_cnt = 0 with tqdm(total=total_tasks, desc="Processing clips", dynamic_ncols=True) as pbar: for video_path, group_df in grouped_items: group_results = process_video_group( group_df, args.video_save_dir, keep_audio=args.keep_audio ) for item in group_results: results.append(item) _, _, valid, _ = item if valid: success_cnt += 1 else: fail_cnt += 1 pbar.update(1) pbar.set_postfix(success=success_cnt, fail=fail_cnt) else: manager = Manager() task_queue = manager.Queue() results_queue = manager.Queue() for video_path, group_df in grouped_items: task_queue.put((video_path, group_df)) num_workers = args.num_workers if args.num_workers else os.cpu_count() num_workers = max(1, num_workers) with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor: futures = [] for _ in range(num_workers): futures.append( executor.submit( worker, task_queue, results_queue, args.video_save_dir, args.keep_audio, # Forward keep_audio flag to each worker ) ) finished = 0 success_cnt = 0 fail_cnt = 0 with tqdm(total=total_tasks, desc="Processing clips", dynamic_ncols=True) as pbar: while finished < total_tasks: try: item = results_queue.get(timeout=1) except queue.Empty: continue results.append(item) finished += 1 _, _, valid, _ = item if valid: success_cnt += 1 else: fail_cnt += 1 pbar.update(1) pbar.set_postfix(success=success_cnt, fail=fail_cnt) for future in futures: future.result() # Sort back by original row index results.sort(key=lambda x: x[0]) # Separate successful and failed success_rows = [] failed_rows = [] failed_errors = [] for index, row_values, valid, error_msg in results: if valid: success_rows.append(row_values) else: failed_rows.append(row_values) failed_errors.append(error_msg) # Optional corrected timestamp CSV if args.drop_invalid_timestamps: valid_indices = [r[0] for r in results if r[2]] filtered_csv = csv.iloc[valid_indices] if args.csv_path.endswith("timestamp.csv"): corrected_path = args.csv_path.replace("timestamp.csv", "correct_timestamp.csv") else: base, ext = os.path.splitext(args.csv_path) corrected_path = f"{base}_corrected{ext}" filtered_csv.to_csv(corrected_path, index=False) print(f"Corrected timestamp file saved to '{corrected_path}'") columns = csv.columns # Save successful clips CSV if success_rows: success_df = pd.DataFrame(success_rows, columns=columns) for col in ["timestamp_start", "timestamp_end", "frame_start", "frame_end"]: if col in success_df.columns: success_df = success_df.drop(columns=[col]) success_df.to_csv(args.csv_save_path, index=False) print(f"Saved {len(success_df)} successful clip(s) to {args.csv_save_path}.") else: print("No successful clips were generated.") # Save failed clips CSV if failed_rows: base, ext = os.path.splitext(args.csv_save_path) failed_csv_path = f"{base}_failed{ext}" failed_df = pd.DataFrame(failed_rows, columns=columns) failed_df["error"] = failed_errors failed_df.to_csv(failed_csv_path, index=False) print(f"Saved {len(failed_df)} failed record(s) to {failed_csv_path}.") if __name__ == "__main__": main() ================================================ FILE: utils/download_SpatialVID.py 
================================================ import argparse from huggingface_hub import hf_hub_download, snapshot_download def main(): # Setup command line arguments parser = argparse.ArgumentParser( description="Download SpatialVID dataset from Hugging Face Hub." ) parser.add_argument( "--repo_id", type=str, choices=["SpatialVID", "SpatialVID-HQ"], required=True, help="Dataset type to download (SpatialVID or SpatialVID-HQ)", ) parser.add_argument( "--type", type=str, choices=["videos", "annotations", "depths", "metadata", "all"], required=True, help="Type of data to download (videos, annotations, metadata, all)", ) parser.add_argument( "--group_id", type=int, help="Specific group ID to download (e.g., 'group_1'). If not provided, downloads all groups.", default=None, ) parser.add_argument( "--output_dir", type=str, help="Local directory to save dataset", default="./SpatialVID_data", ) args = parser.parse_args() repo_id = f"SpatialVID/{args.repo_id}" # Download csv metadata if args.type == "metadata": hub_path = f"data/train/{args.repo_id.replace('-', '_')}_metadata.csv" hf_hub_download( repo_id=repo_id, repo_type="dataset", filename=hub_path, local_dir=args.output_dir, resume_download=True, ) print(f"Downloaded file '{hub_path}' from {repo_id} to {args.output_dir}") # Download specific group elif args.group_id: hub_path = f"{args.type}/group_{args.group_id:04d}.tar.gz" hf_hub_download( repo_id=repo_id, repo_type="dataset", filename=hub_path, local_dir=args.output_dir, resume_download=True, ) print(f"Downloaded file '{hub_path}' from {repo_id} to {args.output_dir}") # Download entire type directory elif args.type == "all": snapshot_download( repo_id=repo_id, repo_type="dataset", local_dir=args.output_dir, resume_download=True, ) print(f"Downloaded entire dataset from {repo_id} to {args.output_dir}") if __name__ == "__main__": main() ================================================ FILE: utils/download_YouTube.py ================================================ """ Utility script to download YouTube videos using yt-dlp with support for concurrency and sharding. Adapted from https://huggingface.co/Ligeng-Zhu/panda70m-download running script: python download_YouTube.py --csv="$csv_file" # this csv file must contains 'url' column if you want to download a specific youtube video, consider using: - yt-dlp -F --list-formats https://www.youtube.com/watch\?v\=omP01s7RUSA # --proxy 127.0.0.1:xxxx --cookies cookies.txt run 'ls -l /path/to/folder/*.json | wc -l' for counting the videos already downloaded Customization Guide: For customizing download settings (such as video format, cookie configurations like automatic Chrome cookie retrieval or custom cookie file usage), refer to the official documentation at https://github.com/yt-dlp/yt-dlp-wiki. """ import sys, os, os.path as osp import yt_dlp import asyncio from concurrent.futures import ProcessPoolExecutor import fire import pandas as pd import json import time def ytb_download(url, json_info, output_dir="ytb_videos/"): """ Download a specified YouTube video using yt-dlp and save related metadata. 
""" os.makedirs(output_dir, exist_ok=True) uid = url.split("?v=")[-1] yt_opts = { "format": "bv[height=720][ext=mp4]" # "format": "bv[height=720]", # Download the best quality available # "format": "bv[height=720][ext=mp4][vcodec!^=av]", # "proxy": "127.0.0.1:xxxx", "outtmpl": osp.join(output_dir, f"{uid}.%(ext)s"), # Output template # "cookiesfrombrowser": "chrome", # Use Chrome's cookies automatically # "cookiefile": "cookies.txt", # Use a custom cookies file # "postprocessors": [ # { # "key": "FFmpegVideoConvertor", # "preferedformat": "mp4", # Convert video to mp4 format (slow) # } # ], # "verbose" : True, "abort-on-error": True, # Abort downloading when an error occurs "retries": 60, # Number of retries "ffmpeg_location": "/usr/bin/ffmpeg", # Path to ffmpeg "quiet": True, # Suppress output "sleep-requested": 5, # Sleep for 1.25 seconds between requests "min-sleep-interval": 60, "max-sleep-interval": 90, } video_path_mp4 = osp.join(output_dir, f"{uid}.mp4") video_path_webm = osp.join(output_dir, f"{uid}.webm") meta_path = osp.join(output_dir, f"{uid}.json") if (osp.exists(video_path_mp4) or osp.exists(video_path_webm)) and osp.exists( meta_path ): print(f"\033[91m{uid} already labeled.\033[0m") return 0 try: with yt_dlp.YoutubeDL(yt_opts) as ydl: ydl.download([url]) with open(meta_path, "w") as fp: json.dump(json_info, fp, indent=2) return 0 # exception logs except Exception as e: print(f"\033[91mError downloading {url}: {e}\033[0m") err_map = { "Requested format is not available": "z0322_dld_format_noavailable.log", "removed by": "z0322_dld_removed_by.log", "Private video": "z0322_dld_private_video.log", } for key, log_file in err_map.items(): if key in str(e): with open(osp.join(output_dir, f"{log_file}"), "a") as f: f.write(f"{url}\n") break else: with open(osp.join(output_dir, f"z0322_dld_othererr.log"), "a") as f: f.write(f"{url}, {str(e)}\n") return -1 async def main(csv_path, output_dir, max_workers=10, shards=0, total=-1, limit=False): """ Batch download YouTube videos specified in a CSV file, supporting sharding and concurrency. """ PPE = ProcessPoolExecutor(max_workers=max_workers) loop = asyncio.get_event_loop() df = pd.read_csv(csv_path) csv_path = os.path.basename(csv_path) output_dir = f'{output_dir}/{csv_path.split(".")[0]}' data_list = list(df.iterrows()) if total > 0: chunk = len(data_list) // total begin_idx = shards * chunk end_idx = (shards + 1) * chunk if shards < total - 1 else len(data_list) data_list = data_list[begin_idx:end_idx] print(f"download total {len(data_list)} videos") tasks = [] for idx, (index, row) in enumerate(data_list): video_url = row["url"] # json_info = {"caption": row["caption"]} json_info = {"caption": ''} # for file checking. tasks.append( loop.run_in_executor(PPE, ytb_download, video_url, json_info, output_dir) ) if limit and idx >= 20: break res = await asyncio.gather(*tasks) print(f"[{sum(res)} / {len(res)}]") def entry( csv="meta_data_sample_500.csv", output_dir="path/to/output", shards=0, total=-1, limit=False, max_workers=2, ): """ Command line entry function, supports fire invocation. 
""" print(csv, output_dir, shards, total, max_workers) start_time = time.time() print( f"\033[92mStarting execution at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}\033[0m" ) asyncio.run( main( csv, output_dir, max_workers=max_workers, shards=shards, total=total, limit=limit, ) ) end_time = time.time() print( f"\033[92mFinished execution at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))}\033[0m" ) print(f"\033[92mTotal execution time: {end_time - start_time:.2f} seconds\033[0m") def add_download(csv_path): """ Download missing videos according to the new_vid_path field in the CSV file. """ data = pd.read_csv(csv_path) unique_ids = data['YouTube id'].unique() for uid in unique_ids: video_url = f"https://www.youtube.com/watch?v={uid}" ytb_download(video_url, json_info={}, output_dir="videos/") print(f"Downloaded {video_url}") if __name__ == "__main__": # Call entry function via command line arguments fire.Fire(entry) # for supplement download: add_download(csv_path='xxx.csv') ================================================ FILE: utils/evaluation.py ================================================ """ Camera trajectory evaluation utility with anomaly detection and motion analysis. """ import os import argparse import pandas as pd import numpy as np import torch import concurrent.futures import multiprocessing as mp from multiprocessing import Manager import queue from tqdm import tqdm from scipy.signal import find_peaks from scipy.ndimage import gaussian_filter # Import mask utility functions from expand_npz import expand def load_file(cam_pos_file, mask_file, device): """Load camera parameters and dynamic masks from files""" try: # Load camera parameters and split into position and rotation params = torch.from_numpy(np.load(cam_pos_file)).float().to(device) cam_pos = params[:, :3] # Position coordinates cam_rotate = params[:, 3:] # Rotation quaternions time_steps = params.shape[0] # Load and expand dynamic masks masks = torch.from_numpy(expand(np.load(mask_file))).to(device) except FileNotFoundError: print(f"Error: File not found - {cam_pos_file}") exit() except Exception as e: print(f"Error processing {cam_pos_file}: {e}") exit() return cam_pos, cam_rotate, time_steps, masks def anomaly_detection(cam_pos, time_steps, threshold, device): """Detect trajectory anomalies using linear prediction with acceleration""" if time_steps < 4: return True # Not enough data preds = torch.zeros((time_steps, 3), dtype=torch.float32, device=device) error_count = 0 # Linear prediction with acceleration for t in range(0, time_steps - 3): # Calculate velocity and acceleration v1 = cam_pos[t + 2] - cam_pos[t + 1] v2 = cam_pos[t + 1] - cam_pos[t] acceleration = v1 - v2 # Predict next position preds[t + 3] = cam_pos[t + 2] + v1 + 0.5 * acceleration # Check prediction error error = torch.sqrt(torch.sum((preds[t + 3] - cam_pos[t + 3]) ** 2)) if error > 0.03: error_count += 1 if error_count >= threshold: return True else: error_count = 0 return False def move_distance(cam_pos, time_steps, device): """Calculate total movement distance and classify into levels""" total_distance = torch.tensor(0., dtype=torch.float32, device=device) # Distance thresholds for classification thresholds = [0.08, 0.28, 0.92, 2.41] # Calculate cumulative distance for i in range(0, time_steps - 1): total_distance += torch.norm(cam_pos[i + 1] - cam_pos[i]) # Determine movement level distance_val = total_distance.item() level = sum(1 for threshold in thresholds if distance_val >= threshold) return 
distance_val, level def quaternion_multiply(q1, q2): """Multiply two quaternions""" # Extract components (q in [x, y, z, w] format) w1, x1, y1, z1 = q1[3], q1[0], q1[1], q1[2] w2, x2, y2, z2 = q2[3], q2[0], q2[1], q2[2] # Quaternion multiplication matrix = torch.tensor([ [w1, -z1, y1, x1], [z1, w1, -x1, y1], [-y1, x1, w1, z1], [-x1, -y1, -z1, w1] ], dtype=q1.dtype, device=q1.device) vector = torch.tensor([x2, y2, z2, w2], dtype=q2.dtype, device=q2.device) result = torch.matmul(matrix, vector) return result def rotation_angle(cam_rotate, time_steps, device): """Calculate total rotation angle between consecutive frames""" total_radians = torch.tensor(0.0, device=device) for i in range(0, time_steps - 1): q1 = cam_rotate[i] q2 = cam_rotate[i + 1] # Calculate relative rotation q1_inverse = torch.stack([-q1[0], -q1[1], -q1[2], q1[3]], dim=0) q_relative = quaternion_multiply(q2, q1_inverse) w = torch.clamp(q_relative[3], -1.0, 1.0) # Convert to angle rotation_angle_rad = 2 * torch.arccos(w) total_radians += rotation_angle_rad return total_radians.item() def trajectory_turns(cam_pos, time_steps, device, threshold=0.45): """Detect significant turns in camera trajectory""" if time_steps < 3: return [], 0 angles = [] # Calculate angles between trajectory segments for t in range(1, time_steps - 1): v1 = cam_pos[t] - cam_pos[0] v2 = cam_pos[time_steps - 1] - cam_pos[t] # Avoid division by zero v1_norm = torch.norm(v1) v2_norm = torch.norm(v2) if v1_norm < 1e-8 or v2_norm < 1e-8: continue # Calculate angle between vectors cos_theta = torch.dot(v1, v2) / (v1_norm * v2_norm) cos_theta = torch.clamp(cos_theta, -1.0, 1.0) angle = torch.arccos(cos_theta) angles.append(angle.item()) # Smooth and find peaks angles = gaussian_filter(angles, sigma=5) peaks, _ = find_peaks(angles, height=threshold, distance=5) peaks_values = [angles[i] for i in peaks] # Include maximum angle if significant max_angle = max(angles) if max_angle > threshold and max_angle not in peaks_values: peaks_values.append(max_angle) return len(peaks_values) def dynamic_ratio(masks): """Calculate ratio of dynamic pixels in video frames""" # Downsample for efficiency masks = masks[::5, :, :] dynamic_pixels = torch.sum(masks) total_pixels = masks.shape[1] * masks.shape[2] * masks.shape[0] return (dynamic_pixels / total_pixels).item() def process_single_row(row, index, args, device): """Process a single video row to extract trajectory metrics""" video_id = row['id'] rec_path = os.path.join(args.dir_path, video_id, "reconstructions") cam_pos_file = os.path.join(rec_path, "poses.npy") mask_file = os.path.join(rec_path, "dyn_masks.npz") # Check file existence if not os.path.exists(cam_pos_file) or not os.path.exists(mask_file): print(f"File not found: {cam_pos_file} or {mask_file}") return False, False, -1, -1, -1, -1, -1 # Load and process data cam_pos, cam_rotate, time_steps, masks = load_file(cam_pos_file, mask_file, device) # Calculate metrics anomaly = anomaly_detection(cam_pos, time_steps, args.anomaly_threshold, device) move_dist, dist_level = move_distance(cam_pos, time_steps, device) rot_angle = rotation_angle(cam_rotate, time_steps, device) traj_turns = trajectory_turns(cam_pos, time_steps, device) dyn_ratio = dynamic_ratio(masks) return True, anomaly, move_dist, dist_level, rot_angle, traj_turns, dyn_ratio def worker(task_queue, result_queue, args, worker_id): """Worker function for parallel processing""" # Assign GPU based on worker ID device = torch.device( f"cuda:{worker_id % args.gpu_num}" if torch.cuda.is_available() else "cpu" ) 
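    # Workers are pinned to GPUs round-robin via worker_id % args.gpu_num, e.g. with
    # gpu_num=2 and four workers the mapping is 0->cuda:0, 1->cuda:1, 2->cuda:0, 3->cuda:1;
    # everything falls back to CPU when CUDA is unavailable.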
while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break result = process_single_row(row, index, args, device) result_queue.put((index, result)) task_queue.task_done() def parse_args(): """Parse command line arguments""" parser = argparse.ArgumentParser(description="Camera Trajectory Evaluation") parser.add_argument("--csv_path", type=str, help="Path to input CSV file") parser.add_argument("--dir_path", type=str, default="./outputs", help="Base directory with reconstruction data") parser.add_argument("--output_path", type=str, default="./outputs/evaluation_results.csv", help="Output CSV path") parser.add_argument("--anomaly_threshold", type=int, default=2, help="Anomaly detection threshold") parser.add_argument('--gpu_num', type=int, default=1, help='Number of GPUs to use') parser.add_argument("--num_workers", type=int, default=4, help="Number of parallel workers") parser.add_argument("--disable_parallel", action="store_true", help="Disable parallel processing") return parser.parse_args() if __name__ == "__main__": # Setup multiprocessing mp.set_start_method('spawn') args = parse_args() # Load input data df = pd.read_csv(args.csv_path) results = [] if args.disable_parallel: # Sequential processing for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing videos"): device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu") result = process_single_row(row, index, args, device) results.append((index, result)) else: # Parallel processing manager = Manager() task_queue = manager.Queue() # Add tasks to queue for index, row in df.iterrows(): task_queue.put((index, row)) result_queue = manager.Queue() # Run workers with concurrent.futures.ProcessPoolExecutor(max_workers=args.num_workers) as executor: futures = [] for worker_id in range(args.num_workers): futures.append(executor.submit(worker, task_queue, result_queue, args, worker_id)) processed = 0 total_tasks = len(df) with tqdm(total=total_tasks, desc="Processing videos") as pbar: while processed < total_tasks: try: index, result = result_queue.get(timeout=1) results.append((index, result)) processed += 1 pbar.update(1) except queue.Empty: if all(f.done() for f in futures) and result_queue.empty(): break for future in futures: future.result() # Collect results while not result_queue.empty(): index, result = result_queue.get() results.append((index, result)) # Sort and save results results.sort(key=lambda x: x[0]) df['success'] = [result[1][0] for result in results] df['anomaly'] = [result[1][1] for result in results] df['moveDist'] = [result[1][2] for result in results] df['distLevel'] = [result[1][3] for result in results] df['rotAngle'] = [result[1][4] for result in results] df['trajTurns'] = [result[1][5] for result in results] df['dynamicRatio'] = [result[1][6] for result in results] df.to_csv(args.output_path, index=False) print(f"Results saved to {args.output_path}") ================================================ FILE: utils/expand_npz.py ================================================ """ Mask utility functions for processing sparse matrix data. """ import numpy as np from scipy.sparse import csr_matrix def expand(loaded_data): """ Reconstruct 3D mask from sparse matrix data. 
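    The expected npz layout can be illustrated by the inverse operation (a sketch, assuming
    the masks were written frame by frame as CSR components):

        payload = {"shape": masks.shape[1:]}      # (H, W) of a single frame
        for i, frame in enumerate(masks):         # masks: (N, H, W) array
            m = csr_matrix(frame)
            payload[f"f_{i}_data"] = m.data
            payload[f"f_{i}_indices"] = m.indices
            payload[f"f_{i}_indptr"] = m.indptr
        np.savez_compressed("dyn_masks.npz", **payload)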
Args: loaded_data (dict): Dictionary containing sparse matrix data with keys: - 'shape': Original matrix dimensions - 'f_{i}_data': Sparse matrix data for frame i - 'f_{i}_indices': Sparse matrix indices for frame i - 'f_{i}_indptr': Sparse matrix index pointers for frame i Returns: np.ndarray: 3D array with shape (frames, height, width) """ reconstructed_sparse_matrices = [] num_frames = (len(loaded_data) - 1) // 3 # Calculate number of frames matrix_shape = loaded_data['shape'] # Get original matrix dimensions # Reconstruct sparse matrix for each frame for i in range(num_frames): data = loaded_data[f'f_{i}_data'] indices = loaded_data[f'f_{i}_indices'] indptr = loaded_data[f'f_{i}_indptr'] reconstructed_matrix = csr_matrix((data, indices, indptr), shape=matrix_shape) reconstructed_sparse_matrices.append(reconstructed_matrix) # Stack all frames into a 3D array (frames, height, width) reconstructed_mask_3d = np.stack([m.toarray() for m in reconstructed_sparse_matrices], axis=0) return reconstructed_mask_3d ================================================ FILE: utils/extract_frames.py ================================================ """ Video frame extraction utility with parallel processing support. """ import os import sys import cv2 import av import glob import argparse import pandas as pd import queue import concurrent.futures from multiprocessing import Manager from tqdm import tqdm def extract_frames_opencv( video_path, output_dir, interval, frame_start, num_frames, target_size=None ): """Extract frames from video at specified intervals""" # Create output directory if not os.path.exists(output_dir): os.makedirs(output_dir) # Open video file cap = cv2.VideoCapture(video_path) cap.set(cv2.CAP_PROP_POS_FRAMES, frame_start) if not cap.isOpened(): print(f"Error: Could not open video file {video_path}") sys.exit(1) # Extract frames for frame in range(num_frames): ret, image = cap.read() if not ret: break # Save frame at specified intervals if frame % interval == 0: frame_filename = os.path.join(output_dir, f"frame_{frame:06d}.jpg") if target_size is not None: h, w = image.shape[:2] # Adaptively adjust target size based on video orientation # For portrait videos (height > width), swap width and height of target size if h > w: # Portrait video target_w, target_h = target_size[1], target_size[0] else: # Landscape video target_w, target_h = target_size[0], target_size[1] image = cv2.resize(image, (target_w, target_h)) cv2.imwrite(frame_filename, image) cap.release() def extract_frames_av( video_path, output_dir, interval, frame_start, num_frames, target_size=None ): """ Extract frames from video at specified intervals using PyAV backend. """ # Create output directory if not os.path.exists(output_dir): os.makedirs(output_dir) try: # Open video file container = av.open(video_path) stream = container.streams.video[0] stream.thread_type = 'AUTO' except Exception as e: print(f"Error: Could not open video file {video_path}. 
Reason: {e}") return # Get video properties fps = float(stream.average_rate) time_base = stream.time_base target_sec = frame_start / fps # Set a small tolerance (e.g., half a frame time) to prevent frame loss due to floating-point precision issues epsilon = 0.5 / fps # Seek to the target start time if frame_start > 0: target_pts = int(target_sec / time_base) container.seek(target_pts, stream=stream, backward=True) count = 0 for packet in container.demux(stream): try: for frame in packet.decode(): if frame.pts is None: continue current_sec = frame.pts * time_base if current_sec < (target_sec - epsilon): continue if count >= num_frames: break if count % interval == 0: image = frame.to_ndarray(format='bgr24') frame_filename = os.path.join(output_dir, f"frame_{count:06d}.jpg") if target_size is not None: if isinstance(target_size, str): w, h = map(int, target_size.split('*')) target_size = (w, h) h, w = image.shape[:2] if h > w: target_w, target_h = target_size[1], target_size[0] else: target_w, target_h = target_size[0], target_size[1] image = cv2.resize(image, (target_w, target_h)) cv2.imwrite(frame_filename, image) count += 1 if count >= num_frames: break except av.error.InvalidDataError: continue # 跳过损坏的包 container.close() def _calc_expected_frames(num_frames, interval): """Calculate the expected number of output frames based on total frames and interval.""" if interval <= 0: return num_frames # Frames at indices 0, interval, 2*interval, ... that are < num_frames return (num_frames - 1) // interval + 1 def _verify_frames(img_dir, expected_frames): """Check if img_dir has enough valid (non-empty) frame files. Returns True if the directory exists and contains at least `expected_frames` non-zero-byte frame_*.jpg files. """ if not os.path.isdir(img_dir): return False frame_files = glob.glob(os.path.join(img_dir, "frame_*.jpg")) if len(frame_files) < expected_frames: return False if any(os.path.getsize(f) == 0 for f in frame_files): return False return True def process_single_row(row, row_index, args): """Process a single video row to extract frames. Returns: True if processing succeeded or was skipped (already done), False if an error occurred. """ video_path = row["video_path"] frame_start = row.get("frame_start", 0) num_frames = row["num_frames"] output_dir = os.path.join(args.output_dir, row["id"]) img_dir = os.path.join(output_dir, "img") # Calculate frame extraction interval if args.interval is None: interval = row["num_frames"] // 3 # Extract 3 frames by default elif args.interval == 0: interval = 1 # Extract every frame else: interval = int(args.interval * row["fps"]) expected_frames = _calc_expected_frames(num_frames, interval) # --- Skip logic: already has enough valid frames --- if _verify_frames(img_dir, expected_frames): return True if not os.path.exists(output_dir): os.makedirs(output_dir) try: if args.backend == "opencv": extract_frames_opencv( video_path, img_dir, interval, frame_start, num_frames, args.target_size ) elif args.backend == "av": extract_frames_av( video_path, img_dir, interval, frame_start, num_frames, args.target_size ) # Post-extraction verification if not _verify_frames(img_dir, expected_frames): actual_count = len(glob.glob(os.path.join(img_dir, "frame_*.jpg"))) print( f"[Verify FAIL] {row['id']}: expected {expected_frames} frames, " f"got {actual_count} (or contains empty files)." ) return False return True except Exception as e: print(f"Error: Could not extract frames from video {video_path}. 
Reason: {e}") return False def worker(task_queue, progress_queue, failed_indices, args): """Worker function for parallel frame extraction""" while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break success = process_single_row(row, index, args) if not success: failed_indices.append(index) progress_queue.put(index) task_queue.task_done() def parse_args(): """Parse command line arguments""" parser = argparse.ArgumentParser(description="Extract frames from video files") parser.add_argument( "--csv_path", type=str, help="Path to CSV file with video csvdata" ) parser.add_argument( "--output_dir", type=str, default="extract_frames", help="Output directory for extracted frames", ) parser.add_argument( "--interval", type=float, default=None, help="Frame extraction interval in seconds (set to 0 to extract every frame)", ) parser.add_argument( "--target_size", type=str, default=None, help="Resize frames to size (width*height). For portrait videos (h>w), dimensions will be automatically swapped to (height*width) to maintain correct orientation.", ) parser.add_argument( "--num_workers", type=int, default=None, help="Number of parallel workers" ) parser.add_argument( "--backend", type=str, default="opencv", choices=["opencv", "av"], help="Backend for video reading", ) parser.add_argument( "--disable_parallel", action="store_true", help="Disable parallel processing" ) return parser.parse_args() def main(): """Main function to process frame extraction""" args = parse_args() # Parse target size if provided if args.target_size is not None: args.target_size = tuple(map(int, args.target_size.split("*"))) # Create output directory if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) # Load video csvdata csv = pd.read_csv(args.csv_path) failed_indices = [] if args.disable_parallel: # Sequential processing for index, row in tqdm( csv.iterrows(), total=len(csv), desc="Processing videos" ): success = process_single_row(row, index, args) if not success: failed_indices.append(index) else: # Parallel processing num_workers = args.num_workers if args.num_workers else os.cpu_count() or 1 manager = Manager() task_queue = manager.Queue() progress_queue = manager.Queue() shared_failed_indices = manager.list() # Add tasks to queue for index, row in csv.iterrows(): task_queue.put((index, row)) # Execute workers with concurrent.futures.ProcessPoolExecutor( max_workers=num_workers ) as executor: futures = [] for _ in range(num_workers): future = executor.submit(worker, task_queue, progress_queue, shared_failed_indices, args) futures.append(future) processed = 0 total_tasks = len(csv) with tqdm(total=total_tasks, desc="Processing videos") as pbar: while processed < total_tasks: try: progress_queue.get(timeout=1) processed += 1 pbar.update(1) except queue.Empty: if all(f.done() for f in futures) and progress_queue.empty(): break for future in futures: future.result() failed_indices = list(shared_failed_indices) # Save failed rows to a separate CSV; keep only successful rows in the original CSV if failed_indices: failed_csv = csv.loc[failed_indices] base, ext = os.path.splitext(args.csv_path) failed_csv_path = f"{base}_failed{ext}" failed_csv.to_csv(failed_csv_path, index=False) csv = csv.drop(index=failed_indices) csv.to_csv(args.csv_path, index=False) print(f"\n{len(failed_indices)} video(s) failed. Saved to: {failed_csv_path}") print(f"Original CSV updated. 
Remaining rows: {len(csv)}") else: print("\nAll videos processed successfully.") if __name__ == "__main__": main() ================================================ FILE: utils/filter.py ================================================ """ Dataset filtering utility for video metadata with various quality metrics. """ import argparse import os import random from glob import glob import numpy as np import pandas as pd def main(args): """Apply filtering criteria to dataset""" # Load data data = pd.read_csv(args.csv_path) # Apply filters based on various metrics if args.frames_min is not None: assert "num_frames" in data.columns data = data[data["num_frames"] >= args.frames_min] if args.frames_max is not None: assert "num_frames" in data.columns data = data[data["num_frames"] <= args.frames_max] if args.fps_max is not None: assert "fps" in data.columns data = data[(data["fps"] <= args.fps_max) | np.isnan(data["fps"])] if args.fps_min is not None: assert "fps" in data.columns data = data[(data["fps"] >= args.fps_min) | np.isnan(data["fps"])] if args.resolution_max is not None: if "resolution" not in data.columns: height = data["height"] width = data["width"] data["resolution"] = height * width data = data[data["resolution"] <= args.resolution_max] if args.aes_min is not None: assert "aesthetic score" in data.columns data = data[data["aesthetic score"] >= args.aes_min] if args.ocr_max is not None: assert "ocr score" in data.columns data = data[data["ocr score"] <= args.ocr_max] if args.ocr_min is not None: assert "ocr score" in data.columns data = data[data["ocr score"] >= args.ocr_min] if args.lum_min is not None: assert "luminance mean" in data.columns data = data[data["luminance mean"] >= args.lum_min] if args.lum_max is not None: assert "luminance mean" in data.columns data = data[data["luminance mean"] <= args.lum_max] if args.motion_min is not None: assert "motion score" in data.columns data = data[data["motion score"] >= args.motion_min] if args.motion_max is not None: assert "motion score" in data.columns data = data[data["motion score"] <= args.motion_max] # Save filtered data data.to_csv(args.csv_save_path, index=False) print(f"Saved {len(data)} samples to {args.csv_save_path}.") def parse_args(): """Parse command line arguments for dataset filtering""" parser = argparse.ArgumentParser( description="Filter video dataset by quality metrics" ) parser.add_argument( "--csv_path", type=str, required=True, help="Path to input CSV file" ) parser.add_argument( "--csv_save_path", type=str, default=None, help="Path to save output CSV file" ) parser.add_argument("--seed", type=int, default=42, help="Random seed") # Video property filters parser.add_argument( "--frames_min", type=int, default=None, help="Minimum number of frames" ) parser.add_argument( "--frames_max", type=int, default=None, help="Maximum number of frames" ) parser.add_argument( "--resolution_max", type=int, default=None, help="Maximum resolution" ) parser.add_argument("--fps_max", type=float, default=None, help="Maximum FPS") parser.add_argument("--fps_min", type=float, default=None, help="Minimum FPS") # Quality metric filters parser.add_argument( "--aes_min", type=float, default=None, help="Minimum aesthetic score" ) parser.add_argument( "--flow_min", type=float, default=None, help="Minimum optical flow score" ) parser.add_argument( "--flow_max", type=float, default=None, help="Maximum optical flow score" ) parser.add_argument("--ocr_max", type=float, default=None, help="Maximum OCR score") parser.add_argument("--ocr_min", 
type=float, default=None, help="Minimum OCR score") parser.add_argument( "--lum_min", type=float, default=None, help="Minimum luminance score" ) parser.add_argument( "--lum_max", type=float, default=None, help="Maximum luminance score" ) parser.add_argument( "--blur_max", type=float, default=None, help="Maximum blur score" ) parser.add_argument( "--motion_min", type=float, default=None, help="Minimum motion score" ) parser.add_argument( "--motion_max", type=float, default=None, help="Maximum motion score" ) return parser.parse_args() if __name__ == "__main__": args = parse_args() # Set random seeds for reproducibility if args.seed is not None: random.seed(args.seed) np.random.seed(args.seed) main(args) ================================================ FILE: utils/get_clip.py ================================================ """ Video clip information extraction utility with timestamp parsing. """ import argparse import os import queue import concurrent.futures from functools import partial import pandas as pd from scenedetect import FrameTimecode import re from tqdm import tqdm def process_single_row(row, args): """Process a single video row to extract clip information""" video_path = row["video_path"] new_rows = [] try: if "timestamp" in row: timestamp_str = row["timestamp"] # Parse timestamps using regex timestamp_pattern = ( r"\('(\d{2}:\d{2}:\d{2}\.\d+)', '(\d{2}:\d{2}:\d{2}\.\d+)'\)" ) matches = re.findall(timestamp_pattern, timestamp_str) scene_list = [ (FrameTimecode(s, fps=row["fps"]), FrameTimecode(t, fps=row["fps"])) for s, t in matches ] else: scene_list = [None] if args.drop_invalid_timestamps: return new_rows, True except Exception as e: if args.drop_invalid_timestamps: return new_rows, False height = row["height"] width = row["width"] fps = row["fps"] # Extract clip information for each scene for idx, scene in enumerate(scene_list): if scene is not None: s, t = scene # FrameTimecode objects fname = os.path.basename(video_path) fname_wo_ext = os.path.splitext(fname)[0] # Calculate clip metrics num_frames = t.frame_num - s.frame_num aspect_ratio = width / height if height != 0 else 0 resolution = f"{width}x{height}" timestamp_start = s.get_timecode() timestamp_end = t.get_timecode() frame_start = s.frame_num frame_end = t.frame_num id_ori = row["id"] if "id" in row else "" id = f"{fname_wo_ext}_{idx}" new_rows.append( [ video_path, id, num_frames, height, width, aspect_ratio, fps, resolution, timestamp_start, timestamp_end, frame_start, frame_end, id_ori, ] ) return (new_rows, True) def worker(task_queue, results_queue, args): """Worker function for parallel processing""" while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break result = process_single_row(row, args) results_queue.put((index, result)) task_queue.task_done() def parse_args(): """Parse command line arguments""" parser = argparse.ArgumentParser( description="Extract video clip information from csvdata" ) parser.add_argument("--csv_path", type=str, help="Path to the input CSV file") parser.add_argument( "--csv_save_dir", type=str, required=True, help="Directory to save output CSV file", ) parser.add_argument( "--num_workers", type=int, default=None, help="Number of parallel workers" ) parser.add_argument( "--disable_parallel", action="store_true", help="Disable parallel processing" ) parser.add_argument( "--drop_invalid_timestamps", action="store_true", help="Drop rows with invalid timestamps", ) args = parser.parse_args() return args def main(): args = parse_args() csv_path = args.csv_path if 
not os.path.exists(csv_path): print(f"csv file '{csv_path}' not found. Exit.") return os.makedirs(args.csv_save_dir, exist_ok=True) # Load csvdata csv = pd.read_csv(args.csv_path) # Setup multiprocessing from multiprocessing import Manager manager = Manager() task_queue = manager.Queue() results_queue = manager.Queue() for index, row in csv.iterrows(): task_queue.put((index, row)) if args.disable_parallel: # Sequential processing results = [] for index, row in tqdm( csv.iterrows(), total=len(csv), desc="Processing rows" ): result = process_single_row(row, args) results.append((index, result)) else: # Parallel processing num_workers = args.num_workers if args.num_workers else os.cpu_count() or 1 with concurrent.futures.ProcessPoolExecutor( max_workers=num_workers ) as executor: futures = [] for _ in range(num_workers): future = executor.submit(worker, task_queue, results_queue, args) futures.append(future) # Per-row progress is more informative than per-worker completion. results = [] processed = 0 total_tasks = len(csv) with tqdm(total=total_tasks, desc="Processing rows") as pbar: while processed < total_tasks: try: results.append(results_queue.get(timeout=1)) processed += 1 pbar.update(1) except queue.Empty: if all(f.done() for f in futures) and results_queue.empty(): break for future in futures: future.result() while not results_queue.empty(): results.append(results_queue.get()) # Process results results.sort(key=lambda x: x[0]) new_rows = [] valid_rows = [] for index, (rows, valid) in results: if valid: valid_rows.append(index) new_rows.extend(rows) # Save corrected timestamps if needed if args.drop_invalid_timestamps: csv = csv[valid_rows] assert args.csv_path.endswith("timestamp.csv"), "Only support *timestamp.csv" csv.to_csv( args.csv_path.replace("timestamp.csv", "correct_timestamp.csv"), index=False, ) print( f"Corrected timestamp file saved to '{args.csv_path.replace('timestamp.csv', 'correct_timestamp.csv')}'" ) # Create and save clip information DataFrame columns = [ "video_path", "id", "num_frames", "height", "width", "aspect_ratio", "fps", "resolution", "timestamp_start", "timestamp_end", "frame_start", "frame_end", "id_ori", ] new_df = pd.DataFrame(new_rows, columns=columns) new_csv_path = os.path.join(args.csv_save_dir, "clips_info.csv") new_df.to_csv(new_csv_path, index=False) print(f"Saved {len(new_df)} clip information to {new_csv_path}.") if __name__ == "__main__": main() ================================================ FILE: utils/get_info.py ================================================ """ Video information extraction utility supporting multiple backends (OpenCV, TorchVision, AV). 
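Typical usage (paths are illustrative):

    python utils/get_info.py --csv_path clips_info.csv \
        --csv_save_path clips_info_meta.csv --backend av --num_workers 8

The output CSV gains success, num_frames, height, width, aspect_ratio, resolution and fps
columns and is sorted by num_frames in ascending order.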
""" import argparse import os import random import cv2 import av import numpy as np import pandas as pd from tqdm import tqdm import concurrent.futures def get_video_length(cap, method="header"): """Get video frame count using different methods""" assert method in ["header", "set"] if method == "header": length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) else: cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1) length = int(cap.get(cv2.CAP_PROP_POS_FRAMES)) return length def get_video_info(args): """Extract video information using specified backend""" idx, path, backend = args try: if backend == "torchvision": from tools.utils.read_video import read_video vframes, infos = read_video(path) num_frames, height, width = ( vframes.shape[0], vframes.shape[2], vframes.shape[3], ) fps = ( float(infos.get("video_fps", np.nan)) if isinstance(infos, dict) else np.nan ) elif backend == "opencv": cap = cv2.VideoCapture(path) if not cap.isOpened(): raise ValueError("Video open failed") num_frames = get_video_length(cap, method="header") height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) fps = float(cap.get(cv2.CAP_PROP_FPS)) cap.release() elif backend == "av": container = av.open(path) stream = container.streams.video[0] num_frames = int(stream.frames) height = int(stream.height) width = int(stream.width) if stream.average_rate is not None: fps = float(stream.average_rate) elif stream.guessed_rate is not None: fps = float(stream.guessed_rate) else: fps = np.nan else: raise ValueError("Unknown backend") # Calculate derived metrics hw = height * width aspect_ratio = height / width if width > 0 else np.nan return (idx, True, num_frames, height, width, aspect_ratio, hw, fps) except Exception: return (idx, False, 0, 0, 0, np.nan, np.nan, np.nan) def main(args): """Main function to extract video information""" # Load data data = pd.read_csv(args.csv_path) if data.empty: data.to_csv(args.csv_save_path, index=False) print(f"Input CSV is empty. 
Saved 0 samples to {args.csv_save_path}.") return tasks = [(index, row["video_path"], args.backend) for index, row in data.iterrows()] num_workers = args.num_workers if args.num_workers else os.cpu_count() or 1 # Process videos with a per-video progress bar (more intuitive than per-worker) if args.disable_parallel or num_workers <= 1: ret = [ get_video_info(task) for task in tqdm(tasks, total=len(tasks), desc="Processing videos") ] else: with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor: ret = list( tqdm( executor.map(get_video_info, tasks, chunksize=16), total=len(tasks), desc="Processing videos", ) ) ret.sort(key=lambda x: x[0]) ( _idx_list, success_list, num_frames_list, height_list, width_list, aspect_ratio_list, hw_list, fps_list, ) = zip(*ret) # Add extracted information to DataFrame data["success"] = success_list data["num_frames"] = num_frames_list data["height"] = height_list data["width"] = width_list data["aspect_ratio"] = aspect_ratio_list data["resolution"] = hw_list data["fps"] = fps_list # Filter existing files if requested if args.ext: assert "video_path" in data.columns data = data[data["video_path"].apply(os.path.exists)] # Sort by frame count if "num_frames" in data.columns: data = data.sort_values(by="num_frames", ascending=True) data.to_csv(args.csv_save_path, index=False) print(f"Saved {len(data)} samples to {args.csv_save_path}.") def parse_args(): """Parse command line arguments""" parser = argparse.ArgumentParser( description="Extract video information using multiple backends" ) parser.add_argument( "--csv_path", type=str, required=True, help="Path to input CSV file" ) parser.add_argument( "--csv_save_path", type=str, default=None, help="Path to save output CSV file" ) parser.add_argument( "--backend", type=str, default="opencv", help="Video backend", choices=["opencv", "torchvision", "av"], ) parser.add_argument( "--disable-parallel", action="store_true", help="Disable parallel processing" ) parser.add_argument( "--num_workers", type=int, default=None, help="Number of parallel workers" ) parser.add_argument("--seed", type=int, default=42, help="Random seed") # File existence checking parser.add_argument("--ext", action="store_true", help="Check if video files exist") return parser.parse_args() if __name__ == "__main__": args = parse_args() # Set random seeds for reproducibility if args.seed is not None: random.seed(args.seed) np.random.seed(args.seed) main(args) ================================================ FILE: utils/get_instructions.py ================================================ """ This module processes camera pose sequences and generates movement instructions. """ import argparse import numpy as np from scipy.spatial.transform import Rotation as R import os import pandas as pd from multiprocessing import Manager import concurrent.futures import queue from tqdm import tqdm import json def filter_poses(poses_array, alpha): """ Filter pose sequences using exponential moving average. 
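    Both streams share the same per-frame blend, filtered[i] = alpha * raw[i] + (1 - alpha) * filtered[i-1]: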
- Position: Exponential moving average (EMA) - Orientation (quaternion): NLERP-based EMA with hemisphere flip handling Args: poses_array: Array of poses [position(3) + quaternion(4)] alpha: Smoothing factor (0 < alpha < 1) Returns: Filtered pose array with same shape as input """ positions = poses_array[:, :3] quaternions = poses_array[:, 3:] filtered_positions = np.zeros_like(positions) filtered_quaternions = np.zeros_like(quaternions) # Initialize with first frame filtered_positions[0] = positions[0] filtered_quaternions[0] = quaternions[0] for i in range(1, len(poses_array)): filtered_positions[i] = ( alpha * positions[i] + (1 - alpha) * filtered_positions[i - 1] ) # quaternion filtering with hemisphere check last_q = filtered_quaternions[i - 1] current_q = quaternions[i] # 1. Check hemisphere to ensure interpolation takes the "shortest path" if np.dot(last_q, current_q) < 0: current_q = -current_q # 2. Linear interpolation interp_q = (1 - alpha) * last_q + alpha * current_q # 3. Re-normalize to ensure unit quaternion filtered_quaternions[i] = interp_q / np.linalg.norm(interp_q) return np.hstack([filtered_positions, filtered_quaternions]) def poses_to_multi_instructions(poses_array, translation_thresh, rotation_thresh_deg): """ Convert camera pose sequence to concurrent movement instruction sequence. """ # Convert NumPy array to Scipy Rotation objects for easier computation poses = [] for row in poses_array: pos = row[:3] rot = R.from_quat(row[3:]) poses.append((pos, rot)) command_sequence = [] rotation_thresh_rad = np.deg2rad(rotation_thresh_deg) for i in range(len(poses) - 1): # Calculate local relative movement pos_t_w2c, rot_t_w2c = poses[i] pos_t1_w2c, rot_t1_w2c = poses[i+1] delta_rot = rot_t1_w2c * rot_t_w2c.inv() pos_t_c2w = -rot_t_w2c.inv().apply(pos_t_w2c) pos_t1_c2w = -rot_t1_w2c.inv().apply(pos_t1_w2c) local_delta_pos = rot_t_w2c.apply(pos_t1_c2w - pos_t_c2w) dx, dy, dz = local_delta_pos euler_angles_rad = delta_rot.as_euler( "yxz" ) # 'y' for yaw, 'x' for pitch, 'z' for roll yaw_change, pitch_change, roll_change = euler_angles_rad instructions = [] # Translation movements if dz < -translation_thresh: instructions.append("Dolly Out") elif dz > translation_thresh: instructions.append("Dolly In") if dx > translation_thresh: instructions.append("Truck Right") elif dx < -translation_thresh: instructions.append("Truck Left") if dy > translation_thresh: instructions.append("Pedestal Down") elif dy < -translation_thresh: instructions.append("Pedestal Up") # Rotation movements if yaw_change > rotation_thresh_rad: instructions.append("Pan Left") elif yaw_change < -rotation_thresh_rad: instructions.append("Pan Right") if pitch_change > rotation_thresh_rad: instructions.append("Tilt Down") elif pitch_change < -rotation_thresh_rad: instructions.append("Tilt Up") if roll_change > rotation_thresh_rad: instructions.append("Roll CCW") elif roll_change < -rotation_thresh_rad: instructions.append("Roll CW") if not instructions: instructions.append("Stay") command_sequence.append(instructions) return command_sequence def process_single_row(args, row): """Process a single video row to generate camera movement instructions.""" npy_path = os.path.join(args.dir_path, row["id"], "reconstructions", "poses.npy") # Load and subsample poses, then apply filtering raw_poses = np.load(npy_path)[:: args.interval] filtered_poses = filter_poses(raw_poses, alpha=args.alpha) # Generate movement instructions instructions = poses_to_multi_instructions( filtered_poses, args.translation_threshold, 
args.rotation_threshold ) json_file = os.path.join(args.dir_path, row["id"], "instructions.json") if os.path.exists(json_file) and os.path.getsize(json_file) > 0: return # Merge consecutive identical instructions merged_instructions = {} start = 0 prev_cmd = instructions[0] for i in range(1, len(instructions)): if instructions[i] == prev_cmd: continue else: key = f"{start}->{i}" merged_instructions[key] = prev_cmd start = i prev_cmd = instructions[i] # Add final segment key = f"{start}->{len(instructions)}" merged_instructions[key] = prev_cmd # Save instructions to JSON file with open(json_file, "w") as f: json.dump(merged_instructions, f, ensure_ascii=False, indent=2) def worker(task_queue, args, pbar): """Worker function for parallel processing of video rows.""" while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break process_single_row(args, row) task_queue.task_done() pbar.update(1) def args_parser(): """Parse command line arguments.""" parser = argparse.ArgumentParser() parser.add_argument( "--csv_path", type=str, default="outputs.csv", help="Path to the input CSV file" ) parser.add_argument("--dir_path", type=str, default="./outputs") parser.add_argument( "--interval", type=int, default=2, help="Interval for computing instructions" ) parser.add_argument( "--alpha", type=float, default=0.1, help="Smoothing factor for filtering (0 < alpha < 1)", ) parser.add_argument( "--translation_threshold", type=float, default=0.02, help="Translation threshold for command generation", ) parser.add_argument( "--rotation_threshold", type=float, default=0.5, help="Rotation threshold for command generation", ) parser.add_argument( "--num_workers", type=int, default=8, help="Number of parallel workers" ) parser.add_argument( "--disable_parallel", action="store_true", help="Disable parallel processing" ) return parser.parse_args() def main(): args = args_parser() csv = pd.read_csv(args.csv_path) if args.disable_parallel: # Sequential processing for index, row in tqdm(csv.iterrows(), total=len(csv)): process_single_row(args, row) else: # Parallel processing using ThreadPoolExecutor manager = Manager() task_queue = manager.Queue() for index, row in csv.iterrows(): task_queue.put((index, row)) with tqdm(total=len(csv), desc="Finished tasks") as pbar: with concurrent.futures.ThreadPoolExecutor( max_workers=args.num_workers ) as executor: futures = [] for _ in range(args.num_workers): futures.append(executor.submit(worker, task_queue, args, pbar)) for future in concurrent.futures.as_completed(futures): future.result() if __name__ == "__main__": main() ================================================ FILE: utils/get_instructions_enhanced.py ================================================ import argparse from math import sqrt import numpy as np from scipy.spatial.transform import Rotation as R import os import pandas as pd from multiprocessing import Manager import concurrent.futures import queue from tqdm import tqdm import json from collections import defaultdict, Counter import itertools def filter_poses(poses_array, alpha): """ Smooth pose sequence using Exponential Moving Average (EMA). 
- Position: Standard EMA - Quaternion: EMA with hemisphere check (shortest path interpolation) """ positions = poses_array[:, :3] quaternions = poses_array[:, 3:] filtered_pos = np.zeros_like(positions) filtered_quat = np.zeros_like(quaternions) # Initialize with first frame filtered_pos[0], filtered_quat[0] = positions[0], quaternions[0] for i in range(1, len(poses_array)): # Position smoothing filtered_pos[i] = alpha * positions[i] + \ (1 - alpha) * filtered_pos[i-1] # Quaternion smoothing with hemisphere correction last_q, curr_q = filtered_quat[i-1], quaternions[i] if np.dot(last_q, curr_q) < 0: # Flip to shortest interpolation path curr_q = -curr_q interp_q = (1 - alpha) * last_q + alpha * curr_q filtered_quat[i] = interp_q / \ np.linalg.norm(interp_q) # Keep unit quaternion return np.hstack([filtered_pos, filtered_quat]) def poses_to_multi_instructions(poses_array, translation_thresh, rotation_thresh_deg, interval=1): """ Convert pose sequence to motion instructions (e.g., Dolly, Pan). Calculates pose difference between frame i and i+interval (convolution-like). """ # Convert to (position, Rotation object) pairs poses = [(row[:3], R.from_quat(row[3:])) for row in poses_array] command_seq = [] # Adjust thresholds by interval (scaling for longer gaps) rotation_thresh_deg *= sqrt(interval) / 1.8 rotation_thresh_rad = np.deg2rad(rotation_thresh_deg) translation_thresh *= sqrt(interval) stride = int(sqrt(interval) + 1) i = 0 while True: if i + interval >= len(poses): # Ensure valid frame pair break # Calculate relative motion (local coordinate system) pos_t_w2c, rot_t_w2c = poses[i] pos_t1_w2c, rot_t1_w2c = poses[i+interval] delta_rot = rot_t1_w2c * rot_t_w2c.inv() pos_t_c2w = -rot_t_w2c.inv().apply(pos_t_w2c) pos_t1_c2w = -rot_t1_w2c.inv().apply(pos_t1_w2c) local_delta_pos = rot_t_w2c.apply(pos_t1_c2w - pos_t_c2w) dx, dy, dz = local_delta_pos yaw, pitch, roll = delta_rot.as_euler( "yxz") # Yaw:Pan, Pitch:Tilt, Roll:Rotate instructions = [] # Translation commands if dz < -translation_thresh: instructions.append("Dolly Out") elif dz > translation_thresh: instructions.append("Dolly In") if dx > translation_thresh: instructions.append("Truck Right") elif dx < -translation_thresh: instructions.append("Truck Left") if dy > translation_thresh: instructions.append("Pedestal Down") elif dy < -translation_thresh: instructions.append("Pedestal Up") # Rotation commands if yaw > rotation_thresh_rad: instructions.append("Pan Left") elif yaw < -rotation_thresh_rad: instructions.append("Pan Right") if pitch > rotation_thresh_rad: instructions.append("Tilt Down") elif pitch < -rotation_thresh_rad: instructions.append("Tilt Up") if roll > rotation_thresh_rad: instructions.append("Roll CCW") elif roll < -rotation_thresh_rad: instructions.append("Roll CW") command_seq.append(instructions if instructions else ["Stay"]) i += stride return command_seq def calculate_relative_scale(total_distance, num_poses, f_translation, min_threshold=0.001): """ Calculate relative translation threshold (dynamic scaling by total motion). """ if num_poses <= 1: return min_threshold base_scale = total_distance / num_poses # Base scale per frame return max(base_scale / f_translation, min_threshold) def voter(args, row, interval, alpha): """ Process single video with specific (interval, alpha) parameter pair. 
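    Each call corresponds to one cell of the grid search: the translation threshold is
    rescaled per video from row["moveDist"] via calculate_relative_scale, the rotation
    threshold stays fixed at args.rotation_threshold, and both are further scaled by the
    frame interval inside poses_to_multi_instructions.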
""" # Locate pose file npy_path = os.path.join( args.dir_path, row["id"], "reconstructions", "poses.npy" ) try: raw_poses = np.load(npy_path) filtered_poses = filter_poses(raw_poses, alpha) # Calculate dynamic thresholds translation_thresh = calculate_relative_scale( row["moveDist"], len( filtered_poses), args.f_translation, args.min_threshold_translation ) rotation_thresh = args.rotation_threshold return poses_to_multi_instructions( filtered_poses, translation_thresh, rotation_thresh, interval ) except Exception as e: print(f"Error processing {row['id']}: {e}") return None def collect_all_results(args, row, param_combinations): """Collect instruction results for all (interval, alpha) pairs.""" results = [] for interval, alpha in param_combinations: res = voter(args, row, interval, alpha) if res is not None: results.append(res) return results # ------------------------------ Voting Logic ------------------------------ def get_mutually_exclusive_groups(): """Return groups of conflicting instructions (cannot coexist).""" return [ ["Dolly In", "Dolly Out"], ["Truck Left", "Truck Right"], ["Pedestal Up", "Pedestal Down"], ["Pan Left", "Pan Right"], ["Tilt Up", "Tilt Down"], ["Roll CW", "Roll CCW"] ] def remove_conflicting_instructions(instructions, conflict_groups): """Remove conflicting instructions (keep higher-voted ones).""" selected = [] selected_set = set() for inst, count in instructions: conflict = False for group in conflict_groups: if inst in group and any(s in group for s in selected_set): conflict = True break if not conflict: selected.append((inst, count)) selected_set.add(inst) return selected def smart_instruction_selection(non_conflicting_inst): """ Smart instruction selection based on vote distribution: - Keep leading votes (3x threshold for断层) - Max 4 instructions - Prioritize non-"Stay" commands """ if not non_conflicting_inst: return ["Stay"] if len(non_conflicting_inst) == 1: return [non_conflicting_inst[0][0]] # Separate Stay and other instructions stay = [i for i in non_conflicting_inst if i[0] == "Stay"] others = [i for i in non_conflicting_inst if i[0] != "Stay"] if not others: return ["Stay"] votes = [c for _, c in others] max_vote = votes[0] selected = [] # Check for vote gap (3x threshold) if len(others) >= 2 and max_vote >= votes[1] * 3: selected = [i[0] for i in others if i[1] == max_vote] else: # Select up to 4 leading instructions gap_thresh = max_vote * 0.5 selected = [i[0] for i in others if i[1] >= gap_thresh][:4] # Ensure minimum 2 instructions if no large gap if len(selected) < 2 and len(others) >= 2 and max_vote < votes[1] * 3: selected = [i[0] for i in others[:2]] return selected if selected else ["Stay"] def collect_interval_based_votes(all_results, param_combinations): """ Vote by time interval: collect all instructions covering (start_frame->end_frame). Handles overlapping segments from different (interval, alpha) pairs. 
""" if not all_results: return {} # Get max frame covered by any parameter pair max_frames = 0 for index, res in enumerate(all_results): interval = param_combinations[index][0] stride = int(sqrt(interval) + 1) if res: last_start = (len(res)-1) * stride max_frames = max(max_frames, last_start + interval) interval_votes = {} for start in range(max_frames): end = start + 1 vote_counter = Counter() # Check all parameter results for coverage of (start->end) for res_index, res in enumerate(all_results): interval, _ = param_combinations[res_index] stride = int(sqrt(interval) + 1) for seg_index, seg in enumerate(res): seg_start = seg_index * stride seg_end = seg_start + interval # Check if segment covers target interval if seg_start <= start < seg_end and seg_start < end <= seg_end: for inst in seg: vote_counter[inst] += 1 interval_votes[f"{start}->{end}"] = vote_counter return interval_votes def vote_for_final_instructions(all_results, param_combinations=None): """Generate final instructions via voting (interval-based if possible).""" if not all_results: return [] conflict_groups = get_mutually_exclusive_groups() final_seq = [] # Use interval-based voting if parameters are provided if param_combinations and len(param_combinations) == len(all_results): interval_votes = collect_interval_based_votes( all_results, param_combinations) for key in sorted(interval_votes.keys(), key=lambda x: int(x.split('->')[0])): votes = interval_votes[key] if votes: sorted_inst = votes.most_common() non_conflict = remove_conflicting_instructions( sorted_inst, conflict_groups) selected = smart_instruction_selection(non_conflict) else: selected = ["Stay"] final_seq.append(selected) else: # Fallback: frame-wise voting max_len = max(len(res) for res in all_results) for frame_index in range(max_len): votes = Counter() for res in all_results: if frame_index < len(res): for inst in res[frame_index]: votes[inst] += 1 if votes: sorted_inst = votes.most_common() non_conflict = remove_conflicting_instructions( sorted_inst, conflict_groups) selected = smart_instruction_selection(non_conflict) else: selected = ["Stay"] final_seq.append(selected) return final_seq # ------------------------------ Main Workflow ------------------------------ def merge_consecutive_instructions(instructions): """Merge consecutive identical instruction lists (e.g., [A,A,A] → "0->3":[A]).""" if not instructions: return {} merged = {} start, prev = 0, instructions[0] for i in range(1, len(instructions)): if instructions[i] != prev: merged[f"{start}->{i}"] = prev start, prev = i, instructions[i] merged[f"{start}->{len(instructions)}"] = prev # Add final segment return merged def process_single_row(args, row, param_combinations): # Skip if output exists out_file = os.path.join(args.dir_path, row['id'], "instructions.json") if os.path.exists(out_file) and os.path.getsize(out_file) > 0: return # Collect results & vote all_results = collect_all_results(args, row, param_combinations) if not all_results: print(f"No valid results for {row['id']}") return final_inst = vote_for_final_instructions(all_results, param_combinations) merged_inst = merge_consecutive_instructions(final_inst) # Save to JSON with open(out_file, "w") as f: json.dump(merged_inst, f, ensure_ascii=False, indent=2) def generate_param_combinations(args): """Generate all (interval, alpha) parameter pairs for grid search.""" intervals = getattr(args, "intervals", [1, 3, 5]) alphas = getattr(args, "alphas", [0.03, 0.05, 0.1]) return list(itertools.product(intervals, alphas)) def worker(task_queue, 
args, param_combinations, pbar): """Parallel worker: process tasks from queue.""" while True: try: index, row = task_queue.get(timeout=1) process_single_row(args, row, param_combinations) except queue.Empty: break task_queue.task_done() pbar.update(1) def args_parser(): parser = argparse.ArgumentParser( description="Enhanced Camera Pose Instruction Generator") parser.add_argument("--csv_path", type=str, required=True, help="Input CSV path (The final_results.csv generated by evaluation.py)") parser.add_argument("--dir_path", type=str, required=True, help="Annotation directory path") parser.add_argument("--intervals", type=int, nargs="+", default=[1, 3, 5], help="Frame intervals for grid search") parser.add_argument("--alphas", type=float, nargs="+", default=[0.03, 0.05, 0.1], help="Smoothing factors for grid search") parser.add_argument("--f_translation", type=float, default=1.1, help="Translation scale factor (>1)") parser.add_argument("--min_threshold_translation", type=float, default=0.01, help="Min translation threshold") parser.add_argument("--rotation_threshold", type=float, default=1.5, help="Fixed rotation threshold (degrees)") parser.add_argument("--num_workers", type=int, default=8, help="Parallel workers count") parser.add_argument("--disable_parallel", action="store_true", help="Disable parallel processing") return parser.parse_args() def main(): args = args_parser() csv = pd.read_csv(args.csv_path) param_combinations = generate_param_combinations(args) if args.disable_parallel: # Serial processing for index, row in tqdm(csv.iterrows(), total=len(csv), desc="Processing"): process_single_row(args, row, param_combinations) else: # Parallel processing manager = Manager() task_queue = manager.Queue() for index, row in csv.iterrows(): task_queue.put((index, row)) with tqdm(total=len(csv), desc="Processing") as pbar: with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_workers) as executor: for _ in range(args.num_workers): executor.submit(worker, task_queue, args, param_combinations, pbar) task_queue.join() if __name__ == "__main__": main() ================================================ FILE: utils/merge_tables.py ================================================ """ CSV table merging utility for combining multiple clip information files. """ import os import glob import argparse import pandas as pd def read_csv_file(file_path): """Read a single CSV file""" return pd.read_csv(file_path) def merge_tables_from_files(file_list, output_file, merge_on=None): """ Merge multiple CSV files using common columns as merge keys. Args: file_list: List of CSV file paths to merge output_file: Output path for merged CSV file merge_on: List of column names for merging (defaults to first 13 columns) """ if not file_list: raise ValueError("File list is empty!") # Read all CSV files dfs = [read_csv_file(f) for f in file_list] # Auto-select merge keys: first 13 columns if merge_on is None: merge_on = dfs[0].columns[:13].tolist() # Merge dataframes df_merged = dfs[0] for df in dfs[1:]: # Check if merge keys are consistent if merge_on != df.columns[:13].tolist(): raise ValueError( f"Common columns in one file are inconsistent with previous files!" ) # Merge based on specified keys df_merged = pd.merge(df_merged, df, on=merge_on) # Save merged result df_merged.to_csv(output_file, index=False) print(f"Merge completed. 
Saved to {output_file}") return df_merged def main(): parser = argparse.ArgumentParser( description="Merge multiple CSV files from a folder" ) parser.add_argument("--csv_dir", type=str, help="Path to folder containing CSV files") parser.add_argument( "--output", type=str, required=True, help="Output path for merged CSV file" ) args = parser.parse_args() # Match CSV files with 'clips_info_' prefix pattern = os.path.join(args.csv_dir, "clips_info_*.csv") file_list = glob.glob(pattern) file_list.sort() # Sort to ensure consistent merge order if not file_list: raise ValueError(f"No matching CSV files found in folder {args.csv_dir}!") print(f"Found {len(file_list)} CSV files:") for f in file_list: print(f" {f}") # Perform merge merge_tables_from_files(file_list, args.output) if __name__ == "__main__": main() ================================================ FILE: utils/normalize_intrinsics.py ================================================ """ Camera intrinsics normalization utility. This module provides functionality for: - Normalizing camera intrinsics to standard format - Converting focal length to normalized coordinates - Parallel processing of multiple camera files - Support for both threaded and sequential processing """ import numpy as np import os import pandas as pd import argparse import concurrent.futures import multiprocessing as mp from multiprocessing import Manager import queue from tqdm import tqdm def possess_single_row(row, args): """ Process a single row to normalize camera intrinsics. """ id = row["id"] dir_path = os.path.join(args.dir_path, id, "reconstructions") cam_intrinsics_file = os.path.join(dir_path, "intrinsics.npy") # Load and normalize intrinsics intrinsics = np.load(cam_intrinsics_file) intrinsics[:, 0] /= intrinsics[:, 2] * 2 # Normalize focal length x intrinsics[:, 1] /= intrinsics[:, 3] * 2 # Normalize focal length y intrinsics[:, 2] = 0.5 # Set principal point x to center intrinsics[:, 3] = 0.5 # Set principal point y to center # Save normalized intrinsics np.save(cam_intrinsics_file, intrinsics) def worker(task_queue, args, pbar): """ Worker function for parallel processing of intrinsics normalization. 
""" while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break possess_single_row(row, args) task_queue.task_done() pbar.update(1) def parse_args(): """Parse command line arguments for intrinsics normalization.""" parser = argparse.ArgumentParser(description="Normalize camera intrinsics to standard format") parser.add_argument("--csv_path", type=str, help="Path to the csv file") parser.add_argument("--dir_path", type=str, default="./outputs") parser.add_argument( "--num_workers", type=int, default=8, help="Number of workers for parallel processing", ) parser.add_argument( "--disable_parallel", action="store_true", help="Disable parallel processing" ) return parser.parse_args() if __name__ == "__main__": args = parse_args() df = pd.read_csv(args.csv_path) if args.disable_parallel: # Sequential processing for index, row in tqdm(df.iterrows(), total=len(df)): possess_single_row(row, index, args) else: # Parallel processing using thread pool manager = Manager() task_queue = manager.Queue() # Add all tasks to queue for index, row in df.iterrows(): task_queue.put((index, row)) with tqdm(total=len(df), desc="Finished tasks") as pbar: with concurrent.futures.ThreadPoolExecutor( max_workers=args.num_workers ) as executor: futures = [] for _ in range(args.num_workers): futures.append(executor.submit(worker, task_queue, args, pbar)) for future in concurrent.futures.as_completed(futures): future.result() ================================================ FILE: utils/pack_clip_assets.py ================================================ """ pack_clip_assets.py ------------------ This script unifies depth, RGB frames, intrinsics, extrinsics, etc. of a specified video clip into a single npz file for downstream 3D reconstruction or analysis. Usage example: python pack_clip_assets.py --base_dir /path/to/HQ --clip_id group_xxxx/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx --height 328 --width 584 """ import argparse import numpy as np import torch from lietorch import SE3 import cv2 from read_depth import read_depth def load_video(clip_path, indexes_path, height=720, width=1280): """ Read video frames at specified indexes and resize to (height, width). Args: clip_path (str): Path to video file indexes_path (str): Path to frame indexes txt height (int): Output frame height width (int): Output frame width Returns: np.ndarray: (N, height, width, 3) RGB frames """ indexes = [] with open(indexes_path, 'r') as f: for line in f: parts = line.strip().split() if len(parts) == 2: indexes.append(int(parts[1])) print(f"Frame indexes: {indexes}") cap = cv2.VideoCapture(clip_path) frames = [] for idx in indexes: cap.set(cv2.CAP_PROP_POS_FRAMES, idx) ret, frame = cap.read() if not ret: raise ValueError(f"Frame at index {idx} could not be read.") frame = cv2.resize(frame, (width, height)) frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames.append(frame) cap.release() return np.array(frames) def load_intrinsics(intrinsics_path, tgt_width=1024, tgt_height=576): """ Read normalized intrinsics (n,4), convert to 3x3 matrix and scale to target resolution. 
Args: intrinsics_path (str): Path to intrinsics npy tgt_width (int): Target width tgt_height (int): Target height Returns: np.ndarray: (N, 3, 3) intrinsics matrices """ intrinsics = np.load(intrinsics_path) intrinsics_3x3 = [] for intrin in intrinsics: fx, fy, cx, cy = intrin K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) intrinsics_3x3.append(K) intrinsics_3x3 = np.array(intrinsics_3x3) intrinsics_3x3[:, 0, 0] *= tgt_width intrinsics_3x3[:, 1, 1] *= tgt_height intrinsics_3x3[:, 0, 2] *= tgt_width intrinsics_3x3[:, 1, 2] *= tgt_height return intrinsics_3x3 def main(): """ Main pipeline: load depth, RGB frames, intrinsics, extrinsics, and save as npz. """ parser = argparse.ArgumentParser(description="Pack clip assets into a single npz file.") parser.add_argument('--base_dir', type=str, required=True, help='Root directory of HQ data') parser.add_argument('--group_id', type=int, required=False, help='Group ID, e.g. group_xxxx') parser.add_argument('--clip_id', type=str, required=True, help='Clip ID, e.g. xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx') parser.add_argument('--height', type=int, default=328, help='Output image height') parser.add_argument('--width', type=int, default=584, help='Output image width') parser.add_argument('--output', type=str, default='sgd_cvd_hr.npz', help='Output npz filename') args = parser.parse_args() # Path construction annotation_dir = f'{args.base_dir}/annotations/group_{args.group_id:04d}/{args.clip_id}' depth_path = f'{args.base_dir}/depths/group_{args.group_id:04d}/{args.clip_id}.zip' clip_path = f'{args.base_dir}/videos/group_{args.group_id:04d}/{args.clip_id}.mp4' intrinsics_path = f'{annotation_dir}/intrinsics.npy' extrinsics_path = f'{annotation_dir}/poses.npy' indexes_path = f'{annotation_dir}/indexes.txt' # Load intrinsics and extrinsics intrinsics = load_intrinsics(intrinsics_path, tgt_width=args.width, tgt_height=args.height) extrinsics = np.load(extrinsics_path) # Load and resize depth depth = np.clip(read_depth(depth_path), 1e-3, 1e2) # (N, H, W) resized_depth = np.zeros((depth.shape[0], args.height, args.width), dtype=depth.dtype) for i in range(depth.shape[0]): resized_depth[i] = cv2.resize(depth[i], (args.width, args.height), interpolation=cv2.INTER_LINEAR) # Load RGB frames frames = load_video(clip_path, indexes_path, args.height, args.width) # Compute camera poses poses_th = torch.as_tensor(extrinsics, device="cpu").float() cam_c2w = SE3(poses_th).inv().matrix() K = intrinsics[0] K_o = torch.from_numpy(K).float() # Save as npz np.savez( args.output, images=frames, depths=resized_depth, intrinsic=K_o.detach().cpu().numpy(), cam_c2w=cam_c2w.detach().cpu().numpy(), ) print(f"Saved to {args.output}") if __name__ == "__main__": main() ================================================ FILE: utils/quat_to_mat.py ================================================ """ Camera pose conversion utility to camera-to-world (c2w) or world-to-camera (w2c) format. Converts quaternion representations to rotation matrices and handles pose transformations. 
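The input poses.npy is expected to hold one 7-vector per frame, three translation
components followed by a scalar-last quaternion (tx, ty, tz, qx, qy, qz, qw), in
world-to-camera convention; the script expands these to [R|t] matrices and optionally
inverts them to camera-to-world before saving extrinsics.npy.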
This module provides utilities for: - Converting between quaternion and matrix representations of camera poses - Transforming between world-to-camera (w2c) and camera-to-world (c2w) coordinate systems - Parallel processing of pose conversion for large datasets """ import einops import torch import torch.nn.functional as F import numpy as np import os import pandas as pd import argparse import concurrent.futures import multiprocessing as mp from multiprocessing import Manager import queue from tqdm import tqdm class Pose: """ A class of operations on camera poses (numpy arrays with shape [...,3,4]). Each [3,4] camera pose takes the form of [R|t]. """ def __call__(self, R=None, t=None): """ Construct a camera pose from the given rotation matrix R and/or translation vector t. Args: R: Rotation matrix [...,3,3] or None t: Translation vector [...,3] or None Returns: pose: Camera pose matrix [...,3,4] """ assert R is not None or t is not None if R is None: if not isinstance(t, np.ndarray): t = np.array(t) R = np.tile(np.eye(3, dtype=np.float32), (*t.shape[:-1], 1, 1)) elif t is None: if not isinstance(R, np.ndarray): R = np.array(R) t = np.zeros(R.shape[:-1], dtype=np.float32) else: if not isinstance(R, np.ndarray): R = np.array(R) if not isinstance(t, np.ndarray): t = np.array(t) assert R.shape[:-1] == t.shape and R.shape[-2:] == (3, 3) R = R.astype(np.float32) t = t.astype(np.float32) pose = np.concatenate([R, t[..., None]], axis=-1) # [...,3,4] assert pose.shape[-2:] == (3, 4) return pose def invert(self, pose, use_inverse=False): # c2w <==> w2c """ Invert a camera pose transformation matrix. Converts between camera-to-world (c2w) and world-to-camera (w2c) representations. For a pose [R|t], the inverse is [R^T | -R^T*t]. Args: pose: Camera pose matrix [...,3,4] of the form [R|t] use_inverse: Whether to use matrix inverse instead of transpose for rotation Returns: pose_inv: Inverted camera pose matrix [...,3,4] """ R, t = pose[..., :3], pose[..., 3:] R_inv = ( np.linalg.inv(R) if use_inverse else np.swapaxes(R, -1, -2) ) # For orthogonal matrices, transpose equals inverse t_inv = (-R_inv @ t)[..., 0] # Apply inverse rotation to negative translation pose_inv = self(R=R_inv, t=t_inv) return pose_inv def compose(self, pose_list): """ Compose a sequence of poses together. pose_new(x) = poseN o ... o pose2 o pose1(x) Args: pose_list: List of camera poses to compose Returns: pose_new: Composed camera pose """ pose_new = pose_list[0] for pose in pose_list[1:]: pose_new = self.compose_pair(pose_new, pose) return pose_new def compose_pair(self, pose_a, pose_b): """ Compose two poses together. pose_new(x) = pose_b o pose_a(x) Args: pose_a: First camera pose pose_b: Second camera pose Returns: pose_new: Composed camera pose """ R_a, t_a = pose_a[..., :3], pose_a[..., 3:] R_b, t_b = pose_b[..., :3], pose_b[..., 3:] R_new = R_b @ R_a t_new = (R_b @ t_a + t_b)[..., 0] pose_new = self(R=R_new, t=t_new) return pose_new def scale_center(self, pose, scale): """ Scale the camera center from the origin. 0 = R@c+t --> c = -R^T@t (camera center in world coordinates) 0 = R@(sc)+t' --> t' = -R@(sc) = -R@(-R^T@st) = st Args: pose: Camera pose to scale scale: Scale factor Returns: pose_new: Scaled camera pose """ R, t = pose[..., :3], pose[..., 3:] pose_new = np.concatenate([R, t * scale], axis=-1) return pose_new def quaternion_to_matrix(quaternions, eps: float = 1e-8): """ Convert 4-dimensional quaternions to 3x3 rotation matrices.
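For a unit quaternion (i, j, k, r) the rotation matrix is [[1-2(j^2+k^2), 2(ij-kr), 2(ik+jr)], [2(ij+kr), 1-2(i^2+k^2), 2(jk-ir)], [2(ik-jr), 2(jk+ir), 1-2(i^2+j^2)]]; the code below uses 2/(|q|^2+eps) in place of the factor 2, so inputs need not be exactly unit-norm.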
This is adapted from Pytorch3D: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py Args: quaternions: Quaternion tensor [..., 4] (order: i, j, k, r) eps: Small value for numerical stability Returns: Rotation matrices [..., 3, 3] """ # Order changed to match scipy format! i, j, k, r = torch.unbind(quaternions, dim=-1) two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return einops.rearrange(o, "... (i j) -> ... i j", i=3, j=3) def pose_from_quaternion(pose): """ Convert pose from quaternion representation to transformation matrix. Args: pose: Pose tensor [..., 7] where first 3 elements are translation (t) and last 4 elements are quaternion rotation (r) Returns: w2c_matrix: World-to-camera transformation matrices [..., 3, 4] """ # Input is w2c, pose(n,7) or (n,v,7), output is (N,3,4) w2c matrix # Tensor format from https://github.com/pointrix-project/Geomotion/blob/6ab0c364f1b44ab4ea190085dbf068f62b42727c/geomotion/model/cameras.py#L6 if type(pose) == np.ndarray: pose = torch.tensor(pose) if len(pose.shape) == 1: pose = pose[None] quat_t = pose[..., :3] # Translation quat_r = pose[..., 3:] # Quaternion rotation w2c_matrix = torch.zeros((*list(pose.shape)[:-1], 3, 4), device=pose.device) w2c_matrix[..., :3, 3] = quat_t w2c_matrix[..., :3, :3] = quaternion_to_matrix(quat_r) return w2c_matrix def possess_single_row(row, index, args): """ Process a single row to convert camera poses to c2w/w2c format. Args: row: Data row containing video ID index: Row index args: Command line arguments """ id = row["id"] dir_path = os.path.join(args.dir_path, id, "reconstructions") cam_pos_file = os.path.join(dir_path, "poses.npy") if not os.path.exists(cam_pos_file): return output_file = os.path.join(dir_path, "extrinsics.npy") if os.path.exists(output_file): return # Load quaternion poses pose = np.load(cam_pos_file) # Convert w2c quaternion format (N,v,7) to w2c matrix format (N,v,3,4) poses = pose_from_quaternion(pose) poses = poses.cpu().numpy() # Convert w2c matrices to c2w matrices (N,v,3,4) if args.format == "c2w": poses = Pose().invert(poses) np.save(output_file, poses) def worker(task_queue, args, pbar): """Worker function for parallel pose conversion processing.""" while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break possess_single_row(row, index, args) task_queue.task_done() pbar.update(1) def parse_args(): """Parse command line arguments for camera pose conversion.""" parser = argparse.ArgumentParser(description="Convert quaternion to camera pose") parser.add_argument("--csv_path", type=str, help="Path to the csv file") parser.add_argument("--dir_path", type=str, default="./outputs") parser.add_argument("--format", type=str, default="c2w", choices=["c2w", "w2c"]) parser.add_argument( "--num_workers", type=int, default=8, help="Number of workers for parallel processing", ) parser.add_argument( "--disable_parallel", action="store_true", help="Disable parallel processing" ) return parser.parse_args() if __name__ == "__main__": args = parse_args() df = pd.read_csv(args.csv_path) if args.disable_parallel: # Sequential processing for index, row in tqdm(df.iterrows(), total=len(df)): possess_single_row(row, index, args) else: # Parallel processing with 
multiple workers manager = Manager() task_queue = manager.Queue() for index, row in df.iterrows(): task_queue.put((index, row)) with tqdm(total=len(df), desc="Finished tasks") as pbar: with concurrent.futures.ThreadPoolExecutor( max_workers=args.num_workers ) as executor: futures = [] for _ in range(args.num_workers): futures.append(executor.submit(worker, task_queue, args, pbar)) for future in concurrent.futures.as_completed(futures): future.result() ================================================ FILE: utils/read_depth.py ================================================ import zipfile import numpy as np import OpenEXR def read_depth(zip_file_path): """ Read depth from zipped exr files. """ valid_width, valid_height = 0, 0 depth_data_list = [] with zipfile.ZipFile(zip_file_path, "r") as z: for file_name in sorted(z.namelist()): with z.open(file_name) as f: try: exr = OpenEXR.InputFile(f) except OSError: # Sometimes EXR loader might fail, we return all nan maps. assert valid_width > 0 and valid_height > 0 depth_data_list.append( np.full((valid_height, valid_width), np.nan, dtype=np.float32)) continue header = exr.header() dw = header["dataWindow"] valid_width = width = dw.max.x - dw.min.x + 1 valid_height = height = dw.max.y - dw.min.y + 1 channels = exr.channels(["Z"]) depth_data = np.frombuffer( channels[0], dtype=np.float16).reshape((height, width)) depth_data_list.append(depth_data.astype(np.float32)) # Note that the depth with a negative value is an invalid depth. # It can be set to the farthest point or other operations. depth_array = np.array(depth_data_list) depth_array_safe = np.where(depth_array == 0, 1e-12, depth_array) return 1.0 / depth_array_safe ================================================ FILE: utils/read_video.py ================================================ """ Video reading utilities with memory optimization and multiple backend support. """ import gc import math import os import re import warnings from fractions import Fraction from typing import Any, Dict, List, Optional, Tuple, Union from tools.logger import test_lg import av import cv2 import numpy as np import torch from torchvision import get_video_backend from torchvision.io.video import _check_av_available MAX_NUM_FRAMES = 2500 def read_video_av( filename: str, start_pts: Union[float, Fraction] = 0, end_pts: Optional[Union[float, Fraction]] = None, pts_unit: str = "pts", output_format: str = "THWC", ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: """ Read video frames using PyAV backend with memory optimization. 
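Returns a (vframes, aframes, info) tuple: video frames as a uint8 tensor, an empty audio tensor, and an info dict containing "video_fps" when the frame rate is available.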
Modified from torchvision.io.video.read_video with improvements: - No audio extraction (returns empty aframes) - PyAV backend only - Added container.close() and gc.collect() to prevent memory leaks - Optimized for memory efficiency """ # Validate format output_format = output_format.upper() if output_format not in ("THWC", "TCHW"): raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.") # Check file existence if not os.path.exists(filename): raise RuntimeError(f"File not found: {filename}") # Validate backend assert get_video_backend() == "pyav", "pyav backend is required for read_video_av" _check_av_available() # Validate time range if end_pts is None: end_pts = float("inf") if end_pts < start_pts: raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}") # Extract video metadata info = {} container = av.open(filename, metadata_errors="ignore") video_fps = container.streams.video[0].average_rate if video_fps is not None: info["video_fps"] = float(video_fps) # Get frame dimensions iter_video = container.decode(**{"video": 0}) frame = next(iter_video).to_rgb().to_ndarray() height, width = frame.shape[:2] total_frames = container.streams.video[0].frames if total_frames == 0: total_frames = MAX_NUM_FRAMES warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback") container.close() del container # Pre-allocate frame buffer (np.zeros doesn't actually allocate memory) video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8) # Read video frames try: container = av.open(filename, metadata_errors="ignore") assert container.streams.video is not None video_frames = _read_from_stream( video_frames, container, start_pts, end_pts, pts_unit, container.streams.video[0], {"video": 0}, filename=filename, ) except av.AVError as e: print(f"[Warning] Error while reading video {filename}: {e}") # Convert to tensor and adjust format vframes = torch.from_numpy(video_frames).clone() del video_frames if output_format == "TCHW": # Convert [T,H,W,C] to [T,C,H,W] vframes = vframes.permute(0, 3, 1, 2) aframes = torch.empty((1, 0), dtype=torch.float32) return vframes, aframes, info def _read_from_stream( video_frames, container: "av.container.Container", start_offset: float, end_offset: float, pts_unit: str, stream: "av.stream.Stream", stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]], filename: Optional[str] = None, ) -> List["av.frame.Frame"]: """Read frames from video stream with proper buffering and seeking""" # Convert time units if pts_unit == "sec": start_offset = int(math.floor(start_offset * (1 / stream.time_base))) if end_offset != float("inf"): end_offset = int(math.ceil(end_offset * (1 / stream.time_base))) else: warnings.warn("The pts_unit 'pts' gives wrong results. 
Please use pts_unit 'sec'.") # Check if buffering is needed for DivX packed B-frames should_buffer = True max_buffer_size = 5 if stream.type == "video": extradata = stream.codec_context.extradata if extradata and b"DivX" in extradata: pos = extradata.find(b"DivX") d = extradata[pos:] o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d) if o is None: o = re.search(rb"DivX(\d+)b(\d+)(\w)", d) if o is not None: should_buffer = o.group(3) == b"p" # Calculate seek offset with safety margin seek_offset = start_offset seek_offset = max(seek_offset - 1, 0) # Safety margin for seeking if should_buffer: seek_offset = max(seek_offset - max_buffer_size, 0) # Seek to start position try: container.seek(seek_offset, any_frame=False, backward=True, stream=stream) except av.AVError as e: print(f"[Warning] Error while seeking video {filename}: {e}") return [] # Read frames from stream buffer_count = 0 frames_pts = [] cnt = 0 try: for _idx, frame in enumerate(container.decode(**stream_name)): frames_pts.append(frame.pts) video_frames[cnt] = frame.to_rgb().to_ndarray() cnt += 1 if cnt >= len(video_frames): break if frame.pts >= end_offset: if should_buffer and buffer_count < max_buffer_size: buffer_count += 1 continue break except av.AVError as e: print(f"[Warning] Error while reading video {filename}: {e}") # Clean up resources to prevent memory leaks container.close() del container gc.collect() # Force garbage collection for PyAV threads # ensure that the results are sorted wrt the pts # NOTE: here we assert frames_pts is sorted start_ptr = 0 end_ptr = cnt while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset: start_ptr += 1 while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset: end_ptr -= 1 if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]: # if there is no frame that exactly matches the pts of start_offset # add the last frame smaller than start_offset, to guarantee that # we will have all the necessary data. This is most useful for audio if start_ptr > 0: start_ptr -= 1 result = video_frames[start_ptr:end_ptr].copy() return result def read_video_cv2(filename, start_pts=None, end_pts=None, pts_unit="pts"): """ Read video using OpenCV backend. """ if pts_unit != "frames": warnings.warn("Using pts_unit other than 'frames' is not supported for cv2 backend") cap = cv2.VideoCapture(filename) # Get video metadata fps = cap.get(cv2.CAP_PROP_FPS) frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # Calculate frame range if start_pts is None: start_pts = 0 if end_pts is None: end_pts = frame_count # Limit frame range to video bounds start_pts = max(0, start_pts) end_pts = min(frame_count, end_pts) num_frames = end_pts - start_pts if num_frames <= 0: return torch.zeros(0, 3, 0, 0), None, {"video_fps": fps} # Seek to start frame cap.set(cv2.CAP_PROP_POS_FRAMES, start_pts) # Read frames frames = [] for i in range(num_frames): ret, frame = cap.read() if not ret: break # Convert BGR to RGB and change HWC to CHW format frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame = torch.from_numpy(frame).permute(2, 0, 1).float() frames.append(frame) cap.release() if frames: video_tensor = torch.stack(frames) else: video_tensor = torch.zeros(0, 3, 0, 0) metadata = {"video_fps": fps} return video_tensor, None, metadata def read_video(video_path, backend="av"): """ Read video using specified backend. 
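Returns (vframes, vinfo): the decoded frames and a metadata dict with "video_fps". With the default "av" backend, frames are a uint8 tensor in TCHW layout.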
""" if backend == "cv2": vframes, vinfo = read_video_cv2(video_path) elif backend == "av": vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW") else: raise ValueError(f"Unsupported backend: {backend}") return vframes, vinfo ================================================ FILE: utils/scene_detect.py ================================================ """ Video scene detection and timestamp processing utility. This module provides functionality for: - Scene detection using PySceneDetect library - Timestamp processing and filtering - Scene duration management - Parallel processing of video files """ import argparse import os import ast import concurrent.futures import queue import numpy as np import pandas as pd from tqdm import tqdm from scenedetect import ( AdaptiveDetector, detect, ContentDetector, SceneManager, open_video, ) from multiprocessing import Manager def timecode_to_seconds(timecode): """Convert timecode string to seconds.""" h, m, s = map(float, timecode.split(":")) return h * 3600 + m * 60 + s def seconds_to_timecode(seconds): """Convert seconds to timecode string format.""" h = int(seconds // 3600) m = int((seconds % 3600) // 60) s = seconds % 60 return f"{h:02d}:{m:02d}:{s:06.3f}" def process_single_row( row, frame_skip=0, start_remove_sec=0, end_remove_sec=0, min_seconds=2, max_seconds=15, backend="opencv", ): """ Process a single video file for scene detection. """ video_path = row["video_path"] detector1 = ContentDetector(threshold=21, min_scene_len=15) detector2 = AdaptiveDetector( adaptive_threshold=3.0, min_scene_len=15, luma_only=True ) detector = [detector1, detector2] try: if isinstance(detector, list): scene_manager = SceneManager() for i in detector: scene_manager.add_detector(i) if backend == "opencv": video = open_video(video_path) elif backend == "av": video = open_video(video_path, backend="pyav") # Get video frame rate fps = video.frame_rate scene_manager.detect_scenes(video=video, frame_skip=frame_skip) scene_list = scene_manager.get_scene_list() else: video = open_video(video_path) # Get video frame rate fps = video.frame_rate scene_list = detect(video_path, detector, start_in_scene=True) if not scene_list: # If no scenes are detected, treat the entire video as one scene video_duration = video.duration timestamp = [("00:00:00.000", seconds_to_timecode(video_duration.get_seconds()))] else: timestamp = [(s.get_timecode(), t.get_timecode()) for s, t in scene_list] # Process timestamps: remove specified seconds from start/end, filter by duration new_timestamp = [] total_remove_sec = start_remove_sec + end_remove_sec for start_timecode, end_timecode in timestamp: start_seconds = timecode_to_seconds(start_timecode) end_seconds = timecode_to_seconds(end_timecode) duration = end_seconds - start_seconds # Only record scenes longer than total removal time if duration >= total_remove_sec: new_start_seconds = start_seconds + start_remove_sec new_end_seconds = end_seconds - end_remove_sec new_duration = new_end_seconds - new_start_seconds if new_duration <= max_seconds: # Duration within max_seconds, check if meets min_seconds if min_seconds <= new_duration: new_start_timecode = seconds_to_timecode(new_start_seconds) new_end_timecode = seconds_to_timecode(new_end_seconds) new_timestamp.append((new_start_timecode, new_end_timecode)) else: # Duration exceeds max_seconds, split into segments current_start = new_start_seconds while current_start + max_seconds <= new_end_seconds: new_start_timecode = seconds_to_timecode(current_start) 
new_end_timecode = seconds_to_timecode( current_start + max_seconds ) new_timestamp.append((new_start_timecode, new_end_timecode)) current_start += max_seconds # Handle remaining segment last_duration = new_end_seconds - current_start if last_duration >= min_seconds: new_start_timecode = seconds_to_timecode(current_start) new_end_timecode = seconds_to_timecode(new_end_seconds) new_timestamp.append((new_start_timecode, new_end_timecode)) return True, str(new_timestamp), float(fps) except Exception as e: print(f"Video '{video_path}' with error {e}") return False, "", None def timecode_to_frames(timecode, fps): """Convert timecode to frame number using fps.""" h, m, s = map(float, timecode.split(":")) total_seconds = h * 3600 + m * 60 + s return int(total_seconds * fps) def worker(task_queue, results_queue, args): """ Worker function for parallel scene detection processing. """ while True: try: index, row = task_queue.get(timeout=1) except queue.Empty: break result = process_single_row( row, frame_skip=args.frame_skip, start_remove_sec=args.start_remove_sec, end_remove_sec=args.end_remove_sec, min_seconds=args.min_seconds, max_seconds=args.max_seconds, backend=args.backend, ) results_queue.put((index, result)) task_queue.task_done() def parse_args(): """Parse command line arguments for scene detection.""" parser = argparse.ArgumentParser() parser.add_argument( "--csv_path", type=str, required=True, help="Path to the input CSV file containing video paths.", ) parser.add_argument( "--num_workers", type=int, default=1, help="#workers for concurrent.futures" ) parser.add_argument( "--frame_skip", type=int, default=0, help="skip frame for detect_scenes" ) parser.add_argument( "--start_remove_sec", type=float, default=0, help="Seconds to remove from the start of each timestamp", ) parser.add_argument( "--end_remove_sec", type=float, default=0, help="Seconds to remove from the end of each timestamp", ) parser.add_argument( "--min_seconds", type=float, default=2, help="Minimum duration of a scene in seconds", ) parser.add_argument( "--max_seconds", type=float, default=15, help="Maximum duration of a scene in seconds", ) parser.add_argument( "--backend", type=str, default="opencv", choices=["opencv", "av"], help="Backend for video reading", ) parser.add_argument( "--disable_parallel", action="store_true", help="Disable parallel processing" ) args = parser.parse_args() return args def main(): args = parse_args() csv_path = args.csv_path if not os.path.exists(csv_path): print(f"csv file '{csv_path}' not found. 
Exit.") return csv = pd.read_csv(csv_path) ret = [] if args.disable_parallel: for index, row in tqdm(csv.iterrows(), total=len(csv)): succ, timestamps, fps = process_single_row( row, frame_skip=args.frame_skip, start_remove_sec=args.start_remove_sec, end_remove_sec=args.end_remove_sec, min_seconds=args.min_seconds, max_seconds=args.max_seconds, ) csv.at[index, "fps"] = fps csv.at[index, "timestamp"] = timestamps ret.append((index, (succ, timestamps, fps))) else: manager = Manager() task_queue = manager.Queue() results_queue = manager.Queue() # Add all tasks to queue for index, row in csv.iterrows(): task_queue.put((index, row)) # Set number of workers if args.num_workers is not None: num_workers = args.num_workers else: num_workers = os.cpu_count() or 1 # Process videos in parallel with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor: futures = [] for _ in range(num_workers): future = executor.submit(worker, task_queue, results_queue, args) futures.append(future) processed = 0 total_tasks = len(csv) with tqdm(total=total_tasks, desc="Processing videos") as pbar: while processed < total_tasks: try: ret.append(results_queue.get(timeout=1)) processed += 1 pbar.update(1) except queue.Empty: if all(f.done() for f in futures) and results_queue.empty(): break for future in futures: future.result() # Collect results while not results_queue.empty(): ret.append(results_queue.get()) # Sort results by index ret.sort(key=lambda x: x[0]) succ, timestamps, fps_list = list(zip(*[result for _, result in ret])) csv["fps"] = fps_list csv["timestamp"] = timestamps csv = csv[np.array(succ)] def calculate_frame_numbers(row): """Calculate frame numbers from timestamps and fps.""" timestamp = ast.literal_eval(row["timestamp"]) fps = row["fps"] frame_numbers = [ (timecode_to_frames(start, fps), timecode_to_frames(end, fps)) for start, end in timestamp ] return str(frame_numbers) csv["frame_numbers"] = csv.apply(calculate_frame_numbers, axis=1) # Save results to new CSV file wo_ext, ext = os.path.splitext(csv_path) out_path = f"{wo_ext}_timestamp{ext}" csv.to_csv(out_path, index=False) print( f"New csv (shape={csv.shape}) with timestamp and frame numbers saved to '{out_path}'." ) if __name__ == "__main__": main() ================================================ FILE: viser/.clang-format ================================================ # C++ formatting rules; used for WebAssembly code. BasedOnStyle: LLVM AlignAfterOpenBracket: BlockIndent BinPackArguments: false BinPackParameters: false IndentWidth: 4 ================================================ FILE: viser/.gitignore ================================================ *.swp *.swo *.pyc *.egg-info *.ipynb_checkpoints __pycache__ .coverage htmlcov .mypy_cache .dmypy.json .hypothesis .envrc .lvimrc .DS_Store .envrc .vite build src/viser/client/build src/viser/client/.nodeenv record3d_dance ================================================ FILE: viser/.pre-commit-config.yaml ================================================ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks default_language_version: python: python3 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. rev: v0.6.2 hooks: # Run the linter. - id: ruff args: [--fix] # Run the formatter. 
- id: ruff-format ================================================ FILE: viser/.prettierignore ================================================ *.mjs build/ ================================================ FILE: viser/LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. 
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: viser/README.md ================================================


### This repo is a customized version of https://github.com/nerfstudio-project/viser for project MonST3R (https://monst3r-project.github.io/) `viser` is a library for interactive 3D visualization in Python. Features include: - API for visualizing 3D primitives - GUI building blocks: buttons, checkboxes, text inputs, sliders, etc. - Scene interaction tools (clicks, selection, transform gizmos) - Programmatic camera control and rendering - An entirely web-based client, for easy use over SSH! For usage and API reference, see our documentation. ## Installation You can install `viser` with `pip`: ```bash pip install viser ``` To include example dependencies: ```bash pip install viser[examples] ``` After an example script is running, you can connect by navigating to the printed URL (default: `http://localhost:8080`). See also: our [development docs](https://viser.studio/latest/development/). ## Examples **Point cloud visualization** https://github.com/nerfstudio-project/viser/assets/6992947/df35c6ee-78a3-43ad-a2c7-1dddf83f7458 Source: `./examples/07_record3d_visualizer.py` **Gaussian splatting visualization** https://github.com/nerfstudio-project/viser/assets/6992947/c51b4871-6cc8-4987-8751-2bf186bcb1ae Source: [WangFeng18/3d-gaussian-splatting](https://github.com/WangFeng18/3d-gaussian-splatting) and [heheyas/gaussian_splatting_3d](https://github.com/heheyas/gaussian_splatting_3d). **SMPLX visualizer** https://github.com/nerfstudio-project/viser/assets/6992947/78ba0e09-612d-4678-abf3-beaeeffddb01 Source: `./example/08_smpl_visualizer.py` ## Acknowledgements `viser` is heavily inspired by packages like [Pangolin](https://github.com/stevenlovegrove/Pangolin), [rviz](https://wiki.ros.org/rviz/), [meshcat](https://github.com/rdeits/meshcat), and [Gradio](https://github.com/gradio-app/gradio). It's made possible by several open-source projects. The web client is implemented using [React](https://react.dev/), with: - [Vite](https://vitejs.dev/) / [Rollup](https://rollupjs.org/) for bundling - [three.js](https://threejs.org/) via [react-three-fiber](https://github.com/pmndrs/react-three-fiber) and [drei](https://github.com/pmndrs/drei) - [Mantine](https://mantine.dev/) for UI components - [zustand](https://github.com/pmndrs/zustand) for state management - [vanilla-extract](https://vanilla-extract.style/) for stylesheets The Python API communicates via [msgpack](https://msgpack.org/index.html) and [websockets](https://websockets.readthedocs.io/en/stable/index.html). ================================================ FILE: viser/docs/.gitignore ================================================ build/ ================================================ FILE: viser/docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SPHINXPROJ = viser SOURCEDIR = source BUILDDIR = ./build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
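# For example, `make html` invokes `sphinx-build -M html source ./build` and writes the rendered docs under ./build.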
%: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: viser/docs/source/_static/css/custom.css ================================================ img.sidebar-logo { width: 5em; margin: 1em 0 0 0; } ================================================ FILE: viser/docs/source/_templates/sidebar/brand.html ================================================ {%- endif %} {%- if theme_light_logo and theme_dark_logo %} {%- endif %} {% endblock brand_content %}
Version: {{ version }}
================================================ FILE: viser/docs/source/camera_handles.md ================================================ # Camera Handles .. autoclass:: viser.CameraHandle :members: :undoc-members: :inherited-members: ================================================ FILE: viser/docs/source/client_handles.md ================================================ # Client Handles .. autoclass:: viser.ClientHandle :members: :undoc-members: :inherited-members: .. autoclass:: viser.NotificationHandle :members: :undoc-members: :inherited-members: ================================================ FILE: viser/docs/source/conf.py ================================================ # -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. For a # full list see the documentation: # http://www.sphinx-doc.org/en/stable/config from pathlib import Path from typing import Dict, List import m2r2 import toml # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # # -- Project information ----------------------------------------------------- project = "viser" copyright = "2024" author = "brentyi" # The short X.Y version version: str = toml.load( Path(__file__).absolute().parent.parent.parent / "pyproject.toml" )["project"]["version"] # Formatting! # 0.1.30 => v0.1.30 # dev => dev if not version.isalpha(): version = "v" + version # The full version, including alpha/beta/rc tags release = "" # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. # # needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ "sphinx.ext.autodoc", "sphinx.ext.todo", "sphinx.ext.coverage", "sphinx.ext.mathjax", "sphinx.ext.githubpages", "sphinx.ext.napoleon", # "sphinx.ext.inheritance_diagram", "sphinx.ext.viewcode", "m2r2", "sphinxcontrib.programoutput", "sphinxcontrib.ansi", "sphinxcontrib.googleanalytics", # google analytics extension https://github.com/sphinx-contrib/googleanalytics/tree/master ] programoutput_use_ansi = True html_ansi_stylesheet = "black-on-white.css" html_static_path = ["_static"] html_theme_options = { "light_css_variables": { "color-code-background": "#f4f4f4", "color-code-foreground": "#000", }, "footer_icons": [ { "name": "GitHub", "url": "https://github.com/nerfstudio-project/viser", "html": """ """, "class": "", }, ], "light_logo": "logo.svg", "dark_logo": "logo.svg", } # Pull documentation types from hints autodoc_typehints = "both" autodoc_class_signature = "separated" autodoc_default_options = { "members": True, "member-order": "bysource", "undoc-members": True, "inherited-members": True, "exclude-members": "__init__, __post_init__", "imported-members": True, } # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = [".rst", ".md"] # source_suffix = ".rst" # The master toctree document. master_doc = "index" # The language for content autogenerated by Sphinx. 
Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. language: str = "en" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path . exclude_patterns: List[str] = [] # The name of the Pygments (syntax highlighting) style to use. pygments_style = "monokai" # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = "furo" html_title = "viser" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # # html_theme_options = {} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. # # The default sidebars (for documents that don't match any pattern) are # defined by theme itself. Builtin themes are using these templates by # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', # 'searchbox.html']``. # # html_sidebars = {} # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = "viser_doc" # -- Options for Github output ------------------------------------------------ sphinx_to_github = True sphinx_to_github_verbose = True sphinx_to_github_encoding = "utf-8" # -- Options for LaTeX output ------------------------------------------------ latex_elements: Dict[str, str] = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. # # 'preamble': '', # Latex figure (float) alignment # # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ ( master_doc, "viser.tex", "viser", "brentyi", "manual", ), ] # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [(master_doc, "viser", "viser documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ( master_doc, "viser", "viser", author, "viser", "viser", "Miscellaneous", ), ] # -- Extension configuration -------------------------------------------------- # Google Analytics ID googleanalytics_id = "G-RRGY51J5ZH" # -- Options for todo extension ---------------------------------------------- # If true, `todo` and `todoList` produce output, else they produce nothing. 
todo_include_todos = True # -- Enable Markdown -> RST conversion ---------------------------------------- def docstring(app, what, name, obj, options, lines): md = "\n".join(lines) rst = m2r2.convert(md) lines.clear() lines += rst.splitlines() def setup(app): app.connect("autodoc-process-docstring", docstring) app.add_css_file("css/custom.css") # -- Napoleon settings ------------------------------------------------------- # Settings for parsing non-sphinx style docstrings. We use Google style in this # project. napoleon_google_docstring = True napoleon_numpy_docstring = False napoleon_include_init_with_doc = False napoleon_include_private_with_doc = False napoleon_include_special_with_doc = True napoleon_use_admonition_for_examples = False napoleon_use_admonition_for_notes = False napoleon_use_admonition_for_references = False napoleon_use_ivar = False napoleon_use_param = True napoleon_use_rtype = True napoleon_preprocess_types = False napoleon_type_aliases = None napoleon_attr_annotations = True ================================================ FILE: viser/docs/source/conventions.md ================================================ # Frame Conventions In this note, we describe the coordinate frame conventions used in `viser`. ## Scene tree naming Each object that we add to the scene in viser is instantiated as a node in a scene tree. The structure of this tree is determined by the names assigned to the nodes. If we add a coordinate frame called `/base_link/shoulder/wrist`, it signifies three nodes: the `wrist` is a child of the `shoulder` which is a child of the `base_link`. If we set the transformation of a given node like `/base_link/shoulder`, both it and its child `/base_link/shoulder/wrist` will move. Its parent, `/base_link`, will be unaffected. ## Poses Poses in `viser` are defined using a pair of fields: - `wxyz`, a unit quaternion orientation term. This should always be 4D. - `position`, a translation term. This should always be 3D. These correspond to a transformation from coordinates in the local frame to the parent frame: .. math:: p_\mathrm{parent} = \begin{bmatrix} R & t \end{bmatrix}\begin{bmatrix}p_\mathrm{local} \\ 1\end{bmatrix} where `wxyz` is the quaternion form of the :math:`\mathrm{SO}(3)` matrix :math:`R` and `position` is the :math:`\mathbb{R}^3` translation term :math:`t`. ## World coordinates In the world coordinate space, +Z points upward by default. This can be overridden with :func:`viser.SceneApi.set_up_direction()`. ## Cameras In `viser`, all camera parameters exposed to the Python API use the COLMAP/OpenCV convention: - Forward: +Z - Up: -Y - Right: +X Confusingly, this is different from Nerfstudio, which adopts the OpenGL/Blender convention: - Forward: -Z - Up: +Y - Right: +X Conversion between the two is a simple 180 degree rotation around the local X-axis. ================================================ FILE: viser/docs/source/development.md ================================================ # Development In this note, we outline current practices, tools, and workflows for `viser` development. We assume that the repository is cloned to `~/viser`. ## Python install We recommend an editable install for Python development, ideally in a virtual environment (eg via conda). ```bash # Install package. cd ~/viser pip install -e . # Install example dependencies. pip install -e .[examples] ``` After installation, any of the example scripts (`~/viser/examples`) should be runnable. 
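For example, to sanity-check the install you can launch one of the bundled demos (00_coordinate_frames.py is shown for illustration; any example script works): ```bash cd ~/viser/examples && python 00_coordinate_frames.py ``` and then open the printed URL (default: http://localhost:8080) in a browser.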
A few of them require downloading assets, which can be done via the scripts in `~/viser/examples/assets`. **Linting, formatting, type-checking.** First, install developer tools: ```bash # Using pip. pip install -e .[dev] pre-commit install ``` It would be hard to write unit tests for `viser`. We rely on static typing for robustness. To check your code, you can run the following: ```bash # runs linting, formatting, and type-checking viser-dev-checks ``` ## Message updates The `viser` frontend and backend communicate via a shared set of message definitions: - On the server, these are defined as Python dataclasses in `~/viser/src/viser/_messages.py`. - On the client, these are defined as TypeScript interfaces in `~/viser/src/viser/client/src/WebsocketMessages.tsx`. Note that there is a 1:1 correspondence between the dataclasses message types and the TypeScript ones. The TypeScript definitions should not be manually modified. Instead, changes should be made in Python and synchronized via the `sync_message_defs.py` script: ``` cd ~/viser python sync_message_defs.py ``` ## Client development For client development, we can start by launching a relevant Python script. The examples are a good place to start: ``` cd ~/viser/examples python 05_camera_commands.py ``` When a `viser` script is launched, two URLs will be printed: - An HTTP URL, like `http://localhost:8080`, which can be used to open a _pre-built_ version of the React frontend. - A websocket URL, like `ws://localhost:8080`, which client applications can connect to. If changes to the client source files are detected on startup, `viser` will re-build the client automatically. This is okay for quick changes, but for faster iteration we can also launch a development version of the frontend, which will reflect changes we make to the client source files (`~/viser/src/viser/client/src`) without a full build. This requires a few more steps. **Installing dependencies.** 1. [Install nodejs.](https://nodejs.dev/en/download/package-manager) 2. [Install yarn.](https://yarnpkg.com/getting-started/install) 3. Install dependencies. ``` cd ~/viser/src/viser/client yarn install ``` **Launching client.** To launch the client, we can run: ``` cd ~/viser/src/viser/client yarn start ``` from the `viser/src/viser/client` directory. After opening the client in a web browser, the websocket server address typically needs to be updated in the "Server" tab. **Formatting.** We use [prettier](https://prettier.io/docs/en/install.html). This can be run via one of: - `prettier -w .` - `npx prettier -w .` from `~/viser/src/viser/client`. ================================================ FILE: viser/docs/source/events.md ================================================ # Events We define a small set of event types, which are passed to callback functions when events like clicks or GUI updates are triggered. .. autoclass:: viser.ScenePointerEvent() .. autoclass:: viser.SceneNodePointerEvent() .. autoclass:: viser.GuiEvent() ================================================ FILE: viser/docs/source/examples/00_coordinate_frames.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Coordinate frames ========================================== In this basic example, we visualize a set of coordinate frames. Naming for all scene nodes are hierarchical; /tree/branch, for example, is defined relative to /tree. .. 
code-block:: python :linenos: import random import time import viser server = viser.ViserServer() while True: # Add some coordinate frames to the scene. These will be visualized in the viewer. server.scene.add_frame( "/tree", wxyz=(1.0, 0.0, 0.0, 0.0), position=(random.random() * 2.0, 2.0, 0.2), ) server.scene.add_frame( "/tree/branch", wxyz=(1.0, 0.0, 0.0, 0.0), position=(random.random() * 2.0, 2.0, 0.2), ) leaf = server.scene.add_frame( "/tree/branch/leaf", wxyz=(1.0, 0.0, 0.0, 0.0), position=(random.random() * 2.0, 2.0, 0.2), ) time.sleep(5.0) # Remove the leaf node from the scene. leaf.remove() time.sleep(0.5) ================================================ FILE: viser/docs/source/examples/01_image.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Images ========================================== Example for sending images to the viewer. We can send backgrond images to display behind the viewer (useful for visualizing NeRFs), or images to render as 3D textures. .. code-block:: python :linenos: import time from pathlib import Path import imageio.v3 as iio import numpy as onp import viser def main() -> None: server = viser.ViserServer() # Add a background image. server.scene.set_background_image( iio.imread(Path(__file__).parent / "assets/Cal_logo.png"), format="png", ) # Add main image. server.scene.add_image( "/img", iio.imread(Path(__file__).parent / "assets/Cal_logo.png"), 4.0, 4.0, format="png", wxyz=(1.0, 0.0, 0.0, 0.0), position=(2.0, 2.0, 0.0), ) while True: server.scene.add_image( "/noise", onp.random.randint( 0, 256, size=(400, 400, 3), dtype=onp.uint8, ), 4.0, 4.0, format="jpeg", wxyz=(1.0, 0.0, 0.0, 0.0), position=(2.0, 2.0, -1e-2), ) time.sleep(0.2) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/02_gui.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. GUI basics ========================================== Examples of basic GUI elements that we can create, read from, and write to. .. code-block:: python :linenos: import time import numpy as onp import viser def main() -> None: server = viser.ViserServer() # Add some common GUI elements: number inputs, sliders, vectors, checkboxes. 
with server.gui.add_folder("Read-only"): gui_counter = server.gui.add_number( "Counter", initial_value=0, disabled=True, ) gui_slider = server.gui.add_slider( "Slider", min=0, max=100, step=1, initial_value=0, disabled=True, ) gui_progress = server.gui.add_progress_bar(25, animated=True) with server.gui.add_folder("Editable"): gui_vector2 = server.gui.add_vector2( "Position", initial_value=(0.0, 0.0), step=0.1, ) gui_vector3 = server.gui.add_vector3( "Size", initial_value=(1.0, 1.0, 1.0), step=0.25, ) with server.gui.add_folder("Text toggle"): gui_checkbox_hide = server.gui.add_checkbox( "Hide", initial_value=False, ) gui_text = server.gui.add_text( "Text", initial_value="Hello world", ) gui_button = server.gui.add_button("Button") gui_checkbox_disable = server.gui.add_checkbox( "Disable", initial_value=False, ) gui_rgb = server.gui.add_rgb( "Color", initial_value=(255, 255, 0), ) gui_multi_slider = server.gui.add_multi_slider( "Multi slider", min=0, max=100, step=1, initial_value=(0, 30, 100), marks=((0, "0"), (50, "5"), (70, "7"), 99), ) gui_slider_positions = server.gui.add_slider( "# sliders", min=0, max=10, step=1, initial_value=3, marks=((0, "0"), (5, "5"), (7, "7"), 10), ) gui_upload_button = server.gui.add_upload_button( "Upload", icon=viser.Icon.UPLOAD ) @gui_upload_button.on_upload def _(_) -> None: """Callback for when a file is uploaded.""" file = gui_upload_button.value print(file.name, len(file.content), "bytes") # Pre-generate a point cloud to send. point_positions = onp.random.uniform(low=-1.0, high=1.0, size=(5000, 3)) color_coeffs = onp.random.uniform(0.4, 1.0, size=(point_positions.shape[0])) counter = 0 while True: # We can set the value of an input to a particular value. Changes are # automatically reflected in connected clients. gui_counter.value = counter gui_slider.value = counter % 100 # We can set the position of a scene node with `.position`, and read the value # of a gui element with `.value`. Changes are automatically reflected in # connected clients. server.scene.add_point_cloud( "/point_cloud", points=point_positions * onp.array(gui_vector3.value, dtype=onp.float32), colors=( onp.tile(gui_rgb.value, point_positions.shape[0]).reshape((-1, 3)) * color_coeffs[:, None] ).astype(onp.uint8), position=gui_vector2.value + (0,), point_shape="circle", ) gui_progress.value = float((counter % 100)) # We can use `.visible` and `.disabled` to toggle GUI elements. gui_text.visible = not gui_checkbox_hide.value gui_button.visible = not gui_checkbox_hide.value gui_rgb.disabled = gui_checkbox_disable.value gui_button.disabled = gui_checkbox_disable.value gui_upload_button.disabled = gui_checkbox_disable.value # Update the number of handles in the multi-slider. if gui_slider_positions.value != len(gui_multi_slider.value): gui_multi_slider.value = onp.linspace( 0, 100, gui_slider_positions.value, dtype=onp.int64 ) counter += 1 time.sleep(0.01) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/03_gui_callbacks.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. GUI callbacks ========================================== Asynchronous usage of GUI elements: we can attach callbacks that are called as soon as we get updates. .. 
code-block:: python :linenos: import time import numpy as onp from typing_extensions import assert_never import viser def main() -> None: server = viser.ViserServer() gui_reset_scene = server.gui.add_button("Reset Scene") gui_plane = server.gui.add_dropdown( "Grid plane", ("xz", "xy", "yx", "yz", "zx", "zy") ) def update_plane() -> None: server.scene.add_grid( "/grid", width=10.0, height=20.0, width_segments=10, height_segments=20, plane=gui_plane.value, ) gui_plane.on_update(lambda _: update_plane()) with server.gui.add_folder("Control"): gui_show_frame = server.gui.add_checkbox("Show Frame", initial_value=True) gui_show_everything = server.gui.add_checkbox( "Show Everything", initial_value=True ) gui_axis = server.gui.add_dropdown("Axis", ("x", "y", "z")) gui_include_z = server.gui.add_checkbox("Z in dropdown", initial_value=True) @gui_include_z.on_update def _(_) -> None: gui_axis.options = ("x", "y", "z") if gui_include_z.value else ("x", "y") with server.gui.add_folder("Sliders"): gui_location = server.gui.add_slider( "Location", min=-5.0, max=5.0, step=0.05, initial_value=0.0 ) gui_num_points = server.gui.add_slider( "# Points", min=1000, max=200_000, step=1000, initial_value=10_000 ) def draw_frame() -> None: axis = gui_axis.value if axis == "x": pos = (gui_location.value, 0.0, 0.0) elif axis == "y": pos = (0.0, gui_location.value, 0.0) elif axis == "z": pos = (0.0, 0.0, gui_location.value) else: assert_never(axis) server.scene.add_frame( "/frame", wxyz=(1.0, 0.0, 0.0, 0.0), position=pos, show_axes=gui_show_frame.value, axes_length=5.0, ) def draw_points() -> None: num_points = gui_num_points.value server.scene.add_point_cloud( "/frame/point_cloud", points=onp.random.normal(size=(num_points, 3)), colors=onp.random.randint(0, 256, size=(num_points, 3)), ) # We can (optionally) also attach callbacks! # Here, we update the point clouds + frames whenever any of the GUI items are updated. gui_show_frame.on_update(lambda _: draw_frame()) gui_show_everything.on_update( lambda _: server.scene.set_global_visibility(gui_show_everything.value) ) gui_axis.on_update(lambda _: draw_frame()) gui_location.on_update(lambda _: draw_frame()) gui_num_points.on_update(lambda _: draw_points()) @gui_reset_scene.on_click def _(_) -> None: """Reset the scene when the reset button is clicked.""" gui_show_frame.value = True gui_location.value = 0.0 gui_axis.value = "x" gui_num_points.value = 10_000 draw_frame() draw_points() # Finally, let's add the initial frame + point cloud and just loop infinitely. :) update_plane() draw_frame() draw_points() while True: time.sleep(1.0) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/04_camera_poses.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Camera poses ========================================== Example showing how we can detect new clients and read camera poses from them. .. code-block:: python :linenos: import time import viser server = viser.ViserServer() server.scene.world_axes.visible = True @server.on_client_connect def _(client: viser.ClientHandle) -> None: print("new client!") # This will run whenever we get a new camera! @client.camera.on_update def _(_: viser.CameraHandle) -> None: print(f"New camera on client {client.client_id}!") # Show the client ID in the GUI. 
gui_info = client.gui.add_text("Client ID", initial_value=str(client.client_id)) gui_info.disabled = True while True: # Get all currently connected clients. clients = server.get_clients() print("Connected client IDs", clients.keys()) for id, client in clients.items(): print(f"Camera pose for client {id}") print(f"\twxyz: {client.camera.wxyz}") print(f"\tposition: {client.camera.position}") print(f"\tfov: {client.camera.fov}") print(f"\taspect: {client.camera.aspect}") print(f"\tlast update: {client.camera.update_timestamp}") time.sleep(2.0) ================================================ FILE: viser/docs/source/examples/05_camera_commands.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Camera commands ========================================== In addition to reads, camera parameters also support writes. These are synced to the corresponding client automatically. .. code-block:: python :linenos: import time import numpy as onp import viser import viser.transforms as tf server = viser.ViserServer() num_frames = 20 @server.on_client_connect def _(client: viser.ClientHandle) -> None: """For each client that connects, we create a set of random frames + a click handler for each frame. When a frame is clicked, we move the camera to the corresponding frame. """ rng = onp.random.default_rng(0) def make_frame(i: int) -> None: # Sample a random orientation + position. wxyz = rng.normal(size=4) wxyz /= onp.linalg.norm(wxyz) position = rng.uniform(-3.0, 3.0, size=(3,)) # Create a coordinate frame and label. frame = client.scene.add_frame(f"/frame_{i}", wxyz=wxyz, position=position) client.scene.add_label(f"/frame_{i}/label", text=f"Frame {i}") # Move the camera when we click a frame. @frame.on_click def _(_): T_world_current = tf.SE3.from_rotation_and_translation( tf.SO3(client.camera.wxyz), client.camera.position ) T_world_target = tf.SE3.from_rotation_and_translation( tf.SO3(frame.wxyz), frame.position ) @ tf.SE3.from_translation(onp.array([0.0, 0.0, -0.5])) T_current_target = T_world_current.inverse() @ T_world_target for j in range(20): T_world_set = T_world_current @ tf.SE3.exp( T_current_target.log() * j / 19.0 ) # We can atomically set the orientation and the position of the camera # together to prevent jitter that might happen if one was set before the # other. with client.atomic(): client.camera.wxyz = T_world_set.rotation().wxyz client.camera.position = T_world_set.translation() client.flush() # Optional! time.sleep(1.0 / 60.0) # Mouse interactions should orbit around the frame origin. client.camera.look_at = frame.position for i in range(num_frames): make_frame(i) while True: time.sleep(1.0) ================================================ FILE: viser/docs/source/examples/06_mesh.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Meshes ========================================== Visualize a mesh. To get the demo data, see ``./assets/download_dragon_mesh.sh``. .. 
code-block:: python :linenos: import time from pathlib import Path import numpy as onp import trimesh import viser import viser.transforms as tf mesh = trimesh.load_mesh(str(Path(__file__).parent / "assets/dragon.obj")) assert isinstance(mesh, trimesh.Trimesh) mesh.apply_scale(0.05) vertices = mesh.vertices faces = mesh.faces print(f"Loaded mesh with {vertices.shape} vertices, {faces.shape} faces") server = viser.ViserServer() server.scene.add_mesh_simple( name="/simple", vertices=vertices, faces=faces, wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, position=(0.0, 0.0, 0.0), ) server.scene.add_mesh_trimesh( name="/trimesh", mesh=mesh.smoothed(), wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, position=(0.0, 5.0, 0.0), ) while True: time.sleep(10.0) ================================================ FILE: viser/docs/source/examples/07_record3d_visualizer.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Record3D visualizer ========================================== Parse and stream record3d captures. To get the demo data, see ``./assets/download_record3d_dance.sh``. .. code-block:: python :linenos: import time from pathlib import Path import numpy as onp import tyro from tqdm.auto import tqdm import viser import viser.extras import viser.transforms as tf def main( data_path: Path = Path(__file__).parent / "assets/record3d_dance", downsample_factor: int = 4, max_frames: int = 100, share: bool = False, ) -> None: server = viser.ViserServer() if share: server.request_share_url() print("Loading frames!") loader = viser.extras.Record3dLoader(data_path) num_frames = min(max_frames, loader.num_frames()) # Add playback UI. with server.gui.add_folder("Playback"): gui_timestep = server.gui.add_slider( "Timestep", min=0, max=num_frames - 1, step=1, initial_value=0, disabled=True, ) gui_next_frame = server.gui.add_button("Next Frame", disabled=True) gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True) gui_playing = server.gui.add_checkbox("Playing", True) gui_framerate = server.gui.add_slider( "FPS", min=1, max=60, step=0.1, initial_value=loader.fps ) gui_framerate_options = server.gui.add_button_group( "FPS options", ("10", "20", "30", "60") ) # Frame step buttons. @gui_next_frame.on_click def _(_) -> None: gui_timestep.value = (gui_timestep.value + 1) % num_frames @gui_prev_frame.on_click def _(_) -> None: gui_timestep.value = (gui_timestep.value - 1) % num_frames # Disable frame controls when we're playing. @gui_playing.on_update def _(_) -> None: gui_timestep.disabled = gui_playing.value gui_next_frame.disabled = gui_playing.value gui_prev_frame.disabled = gui_playing.value # Set the framerate when we click one of the options. @gui_framerate_options.on_click def _(_) -> None: gui_framerate.value = int(gui_framerate_options.value) prev_timestep = gui_timestep.value # Toggle frame visibility when the timestep slider changes. @gui_timestep.on_update def _(_) -> None: nonlocal prev_timestep current_timestep = gui_timestep.value with server.atomic(): frame_nodes[current_timestep].visible = True frame_nodes[prev_timestep].visible = False prev_timestep = current_timestep server.flush() # Optional! # Load in frames. 
server.scene.add_frame( "/frames", wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz, position=(0, 0, 0), show_axes=False, ) frame_nodes: list[viser.FrameHandle] = [] for i in tqdm(range(num_frames)): frame = loader.get_frame(i) position, color = frame.get_point_cloud(downsample_factor) # Add base frame. frame_nodes.append(server.scene.add_frame(f"/frames/t{i}", show_axes=False)) # Place the point cloud in the frame. server.scene.add_point_cloud( name=f"/frames/t{i}/point_cloud", points=position, colors=color, point_size=0.01, point_shape="rounded", ) # Place the frustum. fov = 2 * onp.arctan2(frame.rgb.shape[0] / 2, frame.K[0, 0]) aspect = frame.rgb.shape[1] / frame.rgb.shape[0] server.scene.add_camera_frustum( f"/frames/t{i}/frustum", fov=fov, aspect=aspect, scale=0.15, image=frame.rgb[::downsample_factor, ::downsample_factor], wxyz=tf.SO3.from_matrix(frame.T_world_camera[:3, :3]).wxyz, position=frame.T_world_camera[:3, 3], ) # Add some axes. server.scene.add_frame( f"/frames/t{i}/frustum/axes", axes_length=0.05, axes_radius=0.005, ) # Hide all but the current frame. for i, frame_node in enumerate(frame_nodes): frame_node.visible = i == gui_timestep.value # Playback update loop. prev_timestep = gui_timestep.value while True: if gui_playing.value: gui_timestep.value = (gui_timestep.value + 1) % num_frames time.sleep(1.0 / gui_framerate.value) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: viser/docs/source/examples/08_smpl_visualizer.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. SMPL model visualizer ========================================== Visualizer for SMPL human body models. Requires a .npz model file. See here for download instructions: https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model .. code-block:: python :linenos: from __future__ import annotations import time from dataclasses import dataclass from pathlib import Path import numpy as np import numpy as onp import tyro import viser import viser.transforms as tf @dataclass(frozen=True) class SmplOutputs: vertices: np.ndarray faces: np.ndarray T_world_joint: np.ndarray # (num_joints, 4, 4) T_parent_joint: np.ndarray # (num_joints, 4, 4) class SmplHelper: """Helper for models in the SMPL family, implemented in numpy.""" def __init__(self, model_path: Path) -> None: assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" body_dict = dict(**onp.load(model_path, allow_pickle=True)) self._J_regressor = body_dict["J_regressor"] self._weights = body_dict["weights"] self._v_template = body_dict["v_template"] self._posedirs = body_dict["posedirs"] self._shapedirs = body_dict["shapedirs"] self._faces = body_dict["f"] self.num_joints: int = self._weights.shape[-1] self.num_betas: int = self._shapedirs.shape[-1] self.parent_idx: np.ndarray = body_dict["kintree_table"][0] def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs: # Get shaped vertices + joint positions, when all local poses are identity. v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas) j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose) # Local SE(3) transforms. T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4) T_parent_joint[:, :3, :3] = joint_rotmats T_parent_joint[0, :3, 3] = j_tpose[0] T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]] # Forward kinematics. 
T_world_joint = T_parent_joint.copy() for i in range(1, self.num_joints): T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i] # Linear blend skinning. pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten() v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta) v_delta = np.ones((v_blend.shape[0], self.num_joints, 4)) v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :] v_posed = np.einsum( "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta ) return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint) def main(model_path: Path) -> None: server = viser.ViserServer() server.scene.set_up_direction("+y") server.gui.configure_theme(control_layout="collapsible") # Main loop. We'll read pose/shape from the GUI elements, compute the mesh, # and then send the updated mesh in a loop. model = SmplHelper(model_path) gui_elements = make_gui_elements( server, num_betas=model.num_betas, num_joints=model.num_joints, parent_idx=model.parent_idx, ) while True: # Do nothing if no change. time.sleep(0.02) if not gui_elements.changed: continue gui_elements.changed = False # Compute SMPL outputs. smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), joint_rotmats=tf.SO3.exp( # (num_joints, 3) np.array([x.value for x in gui_elements.gui_joints]) ).as_matrix(), ) server.scene.add_mesh_simple( "/human", smpl_outputs.vertices, smpl_outputs.faces, wireframe=gui_elements.gui_wireframe.value, color=gui_elements.gui_rgb.value, ) # Match transform control gizmos to joint positions. for i, control in enumerate(gui_elements.transform_controls): control.position = smpl_outputs.T_parent_joint[i, :3, 3] @dataclass class GuiElements: """Structure containing handles for reading from GUI elements.""" gui_rgb: viser.GuiInputHandle[tuple[int, int, int]] gui_wireframe: viser.GuiInputHandle[bool] gui_betas: list[viser.GuiInputHandle[float]] gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]] transform_controls: list[viser.TransformControlsHandle] changed: bool """This flag will be flipped to True whenever the mesh needs to be re-generated.""" def make_gui_elements( server: viser.ViserServer, num_betas: int, num_joints: int, parent_idx: np.ndarray, ) -> GuiElements: """Make GUI elements for interacting with the model.""" tab_group = server.gui.add_tab_group() def set_changed(_) -> None: out.changed = True # out is defined later! # GUI elements: mesh settings + visibility. with tab_group.add_tab("View", viser.Icon.VIEWFINDER): gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255)) gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False) gui_show_controls = server.gui.add_checkbox("Handles", initial_value=False) gui_rgb.on_update(set_changed) gui_wireframe.on_update(set_changed) @gui_show_controls.on_update def _(_): for control in transform_controls: control.visible = gui_show_controls.value # GUI elements: shape parameters. with tab_group.add_tab("Shape", viser.Icon.BOX): gui_reset_shape = server.gui.add_button("Reset Shape") gui_random_shape = server.gui.add_button("Random Shape") @gui_reset_shape.on_click def _(_): for beta in gui_betas: beta.value = 0.0 @gui_random_shape.on_click def _(_): for beta in gui_betas: beta.value = onp.random.normal(loc=0.0, scale=1.0) gui_betas = [] for i in range(num_betas): beta = server.gui.add_slider( f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0 ) gui_betas.append(beta) beta.on_update(set_changed) # GUI elements: joint angles.
with tab_group.add_tab("Joints", viser.Icon.ANGLE): gui_reset_joints = server.gui.add_button("Reset Joints") gui_random_joints = server.gui.add_button("Random Joints") @gui_reset_joints.on_click def _(_): for joint in gui_joints: joint.value = (0.0, 0.0, 0.0) @gui_random_joints.on_click def _(_): for joint in gui_joints: # It's hard to uniformly sample orientations directly in so(3), so we # first sample on S^3 and then convert. quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) quat /= onp.linalg.norm(quat) joint.value = tf.SO3(wxyz=quat).log() gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]] = [] for i in range(num_joints): gui_joint = server.gui.add_vector3( label=f"Joint {i}", initial_value=(0.0, 0.0, 0.0), step=0.05, ) gui_joints.append(gui_joint) def set_callback_in_closure(i: int) -> None: @gui_joint.on_update def _(_): transform_controls[i].wxyz = tf.SO3.exp( np.array(gui_joints[i].value) ).wxyz out.changed = True set_callback_in_closure(i) # Transform control gizmos on joints. transform_controls: list[viser.TransformControlsHandle] = [] prefixed_joint_names = [] # Joint names, but prefixed with parents. for i in range(num_joints): prefixed_joint_name = f"joint_{i}" if i > 0: prefixed_joint_name = ( prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name ) prefixed_joint_names.append(prefixed_joint_name) controls = server.scene.add_transform_controls( f"/smpl/{prefixed_joint_name}", depth_test=False, scale=0.2 * (0.75 ** prefixed_joint_name.count("/")), disable_axes=True, disable_sliders=True, visible=gui_show_controls.value, ) transform_controls.append(controls) def set_callback_in_closure(i: int) -> None: @controls.on_update def _(_) -> None: axisangle = tf.SO3(transform_controls[i].wxyz).log() gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2]) set_callback_in_closure(i) out = GuiElements( gui_rgb, gui_wireframe, gui_betas, gui_joints, transform_controls=transform_controls, changed=True, ) return out if __name__ == "__main__": tyro.cli(main, description=__doc__) ================================================ FILE: viser/docs/source/examples/09_urdf_visualizer.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Robot URDF visualizer ========================================== Requires yourdfpy and robot_descriptions. Any URDF supported by yourdfpy should work. * https://github.com/robot-descriptions/robot_descriptions.py * https://github.com/clemense/yourdfpy The :class:`viser.extras.ViserUrdf` is a lightweight interface between yourdfpy and viser. It can also take a path to a local URDF file as input. .. code-block:: python :linenos: from __future__ import annotations import time from typing import Literal import numpy as onp import tyro import viser from robot_descriptions.loaders.yourdfpy import load_robot_description from viser.extras import ViserUrdf def create_robot_control_sliders( server: viser.ViserServer, viser_urdf: ViserUrdf ) -> tuple[list[viser.GuiInputHandle[float]], list[float]]: """Create slider for each joint of the robot. 
We also update robot model when slider moves.""" slider_handles: list[viser.GuiInputHandle[float]] = [] initial_config: list[float] = [] for joint_name, ( lower, upper, ) in viser_urdf.get_actuated_joint_limits().items(): lower = lower if lower is not None else -onp.pi upper = upper if upper is not None else onp.pi initial_pos = 0.0 if lower < 0 and upper > 0 else (lower + upper) / 2.0 slider = server.gui.add_slider( label=joint_name, min=lower, max=upper, step=1e-3, initial_value=initial_pos, ) slider.on_update( # When sliders move, we update the URDF configuration. lambda _: viser_urdf.update_cfg( onp.array([slider.value for slider in slider_handles]) ) ) slider_handles.append(slider) initial_config.append(initial_pos) return slider_handles, initial_config def main( robot_type: Literal[ "panda", "ur10", "cassie", "allegro_hand", "barrett_hand", "robotiq_2f85", "atlas_drc", "g1", "h1", "anymal_c", "go2", ] = "panda", ) -> None: # Start viser server. server = viser.ViserServer() # Load URDF. # # This takes either a yourdfpy.URDF object or a path to a .urdf file. viser_urdf = ViserUrdf( server, urdf_or_path=load_robot_description(robot_type + "_description"), ) # Create sliders in GUI that help us move the robot joints. with server.gui.add_folder("Joint position control"): (slider_handles, initial_config) = create_robot_control_sliders( server, viser_urdf ) # Set initial robot configuration. viser_urdf.update_cfg(onp.array(initial_config)) # Create joint reset button. reset_button = server.gui.add_button("Reset") @reset_button.on_click def _(_): for s, init_q in zip(slider_handles, initial_config): s.value = init_q # Sleep forever. while True: time.sleep(10.0) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: viser/docs/source/examples/10_realsense.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. RealSense visualizer ========================================== Connect to a RealSense camera, then visualize RGB-D readings as a point cloud. Requires pyrealsense2. .. code-block:: python :linenos: from __future__ import annotations import contextlib import numpy as np import numpy.typing as npt import pyrealsense2 as rs # type: ignore from tqdm.auto import tqdm import viser @contextlib.contextmanager def realsense_pipeline(fps: int = 30): """Context manager that yields a RealSense pipeline.""" # Configure depth and color streams. pipeline = rs.pipeline() # type: ignore config = rs.config() # type: ignore pipeline_wrapper = rs.pipeline_wrapper(pipeline) # type: ignore config.resolve(pipeline_wrapper) config.enable_stream(rs.stream.depth, rs.format.z16, fps) # type: ignore config.enable_stream(rs.stream.color, rs.format.rgb8, fps) # type: ignore # Start streaming. pipeline.start(config) yield pipeline # Stop the pipeline when done. pipeline.stop() def point_cloud_arrays_from_frames( depth_frame, color_frame ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.uint8]]: """Maps realsense frames to two arrays. Returns: - A point position array: (N, 3) float32. - A point color array: (N, 3) uint8. """ # Processing blocks. Could be tuned. point_cloud = rs.pointcloud() # type: ignore decimate = rs.decimation_filter() # type: ignore decimate.set_option(rs.option.filter_magnitude, 3) # type: ignore # Downsample depth frame. depth_frame = decimate.process(depth_frame) # Map texture and calculate points from frames. Uses frame intrinsics.
point_cloud.map_to(color_frame) points = point_cloud.calculate(depth_frame) # Get color coordinates. texture_uv = ( np.asanyarray(points.get_texture_coordinates()) .view(np.float32) .reshape((-1, 2)) ) color_image = np.asanyarray(color_frame.get_data()) color_h, color_w, _ = color_image.shape # Note: for points that aren't in the view of our RGB camera, we currently clamp to # the closest available RGB pixel. We could also just remove these points. texture_uv = texture_uv.clip(0.0, 1.0) # Get positions and colors. positions = np.asanyarray(points.get_vertices()).view(np.float32) positions = positions.reshape((-1, 3)) colors = color_image[ (texture_uv[:, 1] * (color_h - 1.0)).astype(np.int32), (texture_uv[:, 0] * (color_w - 1.0)).astype(np.int32), :, ] N = positions.shape[0] assert positions.shape == (N, 3) assert positions.dtype == np.float32 assert colors.shape == (N, 3) assert colors.dtype == np.uint8 return positions, colors def main(): # Start visualization server. server = viser.ViserServer() with realsense_pipeline() as pipeline: for i in tqdm(range(10000000)): # Wait for a coherent pair of frames: depth and color frames = pipeline.wait_for_frames() depth_frame = frames.get_depth_frame() color_frame = frames.get_color_frame() # Compute point cloud from frames. positions, colors = point_cloud_arrays_from_frames(depth_frame, color_frame) R = np.array( [ [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, -1.0, 0.0], ], dtype=np.float32, ) positions = positions @ R.T # Visualize. server.scene.add_point_cloud( "/realsense", points=positions * 10.0, colors=colors, point_size=0.1, ) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/11_colmap_visualizer.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. COLMAP visualizer ========================================== Visualize COLMAP sparse reconstruction outputs. To get demo data, see ``./assets/download_colmap_garden.sh``. .. code-block:: python :linenos: import random import time from pathlib import Path import imageio.v3 as iio import numpy as onp import tyro from tqdm.auto import tqdm import viser import viser.transforms as tf from viser.extras.colmap import ( read_cameras_binary, read_images_binary, read_points3d_binary, ) def main( colmap_path: Path = Path(__file__).parent / "assets/colmap_garden/sparse/0", images_path: Path = Path(__file__).parent / "assets/colmap_garden/images_8", downsample_factor: int = 2, ) -> None: """Visualize COLMAP sparse reconstruction outputs. Args: colmap_path: Path to the COLMAP reconstruction directory. images_path: Path to the COLMAP images directory. downsample_factor: Downsample factor for the images. """ server = viser.ViserServer() server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") # Load the colmap info.
cameras = read_cameras_binary(colmap_path / "cameras.bin") images = read_images_binary(colmap_path / "images.bin") points3d = read_points3d_binary(colmap_path / "points3D.bin") gui_reset_up = server.gui.add_button( "Reset up direction", hint="Set the camera control 'up' direction to the current camera's 'up'.", ) @gui_reset_up.on_click def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None client.camera.up_direction = tf.SO3(client.camera.wxyz) @ onp.array( [0.0, -1.0, 0.0] ) gui_points = server.gui.add_slider( "Max points", min=1, max=len(points3d), step=1, initial_value=min(len(points3d), 50_000), ) gui_frames = server.gui.add_slider( "Max frames", min=1, max=len(images), step=1, initial_value=min(len(images), 100), ) gui_point_size = server.gui.add_number("Point size", initial_value=0.05) def visualize_colmap() -> None: """Send all COLMAP elements to viser for visualization. This could be optimized a ton!""" # Set the point cloud. points = onp.array([points3d[p_id].xyz for p_id in points3d]) colors = onp.array([points3d[p_id].rgb for p_id in points3d]) points_selection = onp.random.choice( points.shape[0], gui_points.value, replace=False ) points = points[points_selection] colors = colors[points_selection] server.scene.add_point_cloud( name="/colmap/pcd", points=points, colors=colors, point_size=gui_point_size.value, ) # Interpret the images and cameras. img_ids = [im.id for im in images.values()] random.shuffle(img_ids) img_ids = sorted(img_ids[: gui_frames.value]) def attach_callback( frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle ) -> None: @frustum.on_click def _(_) -> None: for client in server.get_clients().values(): client.camera.wxyz = frame.wxyz client.camera.position = frame.position for img_id in tqdm(img_ids): img = images[img_id] cam = cameras[img.camera_id] # Skip images that don't exist. image_filename = images_path / img.name if not image_filename.exists(): continue T_world_camera = tf.SE3.from_rotation_and_translation( tf.SO3(img.qvec), img.tvec ).inverse() frame = server.scene.add_frame( f"/colmap/frame_{img_id}", wxyz=T_world_camera.rotation().wxyz, position=T_world_camera.translation(), axes_length=0.1, axes_radius=0.005, ) # For pinhole cameras, cam.params will be (fx, fy, cx, cy). if cam.model != "PINHOLE": print(f"Expected pinhole camera, but got {cam.model}") H, W = cam.height, cam.width fy = cam.params[1] image = iio.imread(image_filename) image = image[::downsample_factor, ::downsample_factor] frustum = server.scene.add_camera_frustum( f"/colmap/frame_{img_id}/frustum", fov=2 * onp.arctan2(H / 2, fy), aspect=W / H, scale=0.15, image=image, ) attach_callback(frustum, frame) need_update = True @gui_points.on_update def _(_) -> None: nonlocal need_update need_update = True @gui_frames.on_update def _(_) -> None: nonlocal need_update need_update = True @gui_point_size.on_update def _(_) -> None: nonlocal need_update need_update = True while True: if need_update: need_update = False server.scene.reset() visualize_colmap() time.sleep(1e-3) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: viser/docs/source/examples/12_click_meshes.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Mesh click events ========================================== Click on meshes to select them. The index of the last clicked mesh is displayed in the GUI. .. 
code-block:: python :linenos: import time import matplotlib import viser def main() -> None: grid_shape = (4, 5) server = viser.ViserServer() with server.gui.add_folder("Last clicked"): x_value = server.gui.add_number( label="x", initial_value=0, disabled=True, hint="x coordinate of the last clicked mesh", ) y_value = server.gui.add_number( label="y", initial_value=0, disabled=True, hint="y coordinate of the last clicked mesh", ) def add_swappable_mesh(i: int, j: int) -> None: """Simple callback that swaps between: - a gray box - a colored box - a colored sphere Color is chosen based on the position (i, j) of the mesh in the grid. """ colormap = matplotlib.colormaps["tab20"] def create_mesh(counter: int) -> None: if counter == 0: color = (0.8, 0.8, 0.8) else: index = (i * grid_shape[1] + j) / (grid_shape[0] * grid_shape[1]) color = colormap(index)[:3] if counter in (0, 1): handle = server.scene.add_box( name=f"/sphere_{i}_{j}", position=(i, j, 0.0), color=color, dimensions=(0.5, 0.5, 0.5), ) else: handle = server.scene.add_icosphere( name=f"/sphere_{i}_{j}", radius=0.4, color=color, position=(i, j, 0.0), ) @handle.on_click def _(_) -> None: x_value.value = i y_value.value = j # The new mesh will replace the old one because the names # /sphere_{i}_{j} are the same. create_mesh((counter + 1) % 3) create_mesh(0) for i in range(grid_shape[0]): for j in range(grid_shape[1]): add_swappable_mesh(i, j) while True: time.sleep(10.0) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/13_theming.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Theming ========================================== Viser includes support for light theming. .. code-block:: python :linenos: import time import viser from viser.theme import TitlebarButton, TitlebarConfig, TitlebarImage def main(): server = viser.ViserServer(label="Viser Theming") buttons = ( TitlebarButton( text="Getting Started", icon=None, href="https://nerf.studio", ), TitlebarButton( text="Github", icon="GitHub", href="https://github.com/nerfstudio-project/nerfstudio", ), TitlebarButton( text="Documentation", icon="Description", href="https://docs.nerf.studio", ), ) image = TitlebarImage( image_url_light="https://docs.nerf.studio/_static/imgs/logo.png", image_url_dark="https://docs.nerf.studio/_static/imgs/logo-dark.png", image_alt="NerfStudio Logo", href="https://docs.nerf.studio/", ) titlebar_theme = TitlebarConfig(buttons=buttons, image=image) server.gui.add_markdown( "Viser includes support for light theming via the `.configure_theme()` method." ) gui_theme_code = server.gui.add_markdown("no theme applied yet") # GUI elements for controllable values. 
titlebar = server.gui.add_checkbox("Titlebar", initial_value=True) dark_mode = server.gui.add_checkbox("Dark mode", initial_value=True) show_logo = server.gui.add_checkbox("Show logo", initial_value=True) show_share_button = server.gui.add_checkbox("Show share button", initial_value=True) brand_color = server.gui.add_rgb("Brand color", (230, 180, 30)) control_layout = server.gui.add_dropdown( "Control layout", ("floating", "fixed", "collapsible") ) control_width = server.gui.add_dropdown( "Control width", ("small", "medium", "large"), initial_value="medium" ) synchronize = server.gui.add_button("Apply theme", icon=viser.Icon.CHECK) def synchronize_theme() -> None: server.gui.configure_theme( titlebar_content=titlebar_theme if titlebar.value else None, control_layout=control_layout.value, control_width=control_width.value, dark_mode=dark_mode.value, show_logo=show_logo.value, show_share_button=show_share_button.value, brand_color=brand_color.value, ) gui_theme_code.content = f""" ### Current applied theme ``` server.gui.configure_theme( titlebar_content={"titlebar_content" if titlebar.value else None}, control_layout="{control_layout.value}", control_width="{control_width.value}", dark_mode={dark_mode.value}, show_logo={show_logo.value}, show_share_button={show_share_button.value}, brand_color={brand_color.value}, ) ``` """ synchronize.on_click(lambda _: synchronize_theme()) synchronize_theme() while True: time.sleep(10.0) # main() if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/14_markdown.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Markdown demonstration ========================================== Viser GUI has MDX 2 support. .. code-block:: python :linenos: import time from pathlib import Path import viser server = viser.ViserServer() server.scene.world_axes.visible = True markdown_counter = server.gui.add_markdown("Counter: 0") here = Path(__file__).absolute().parent button = server.gui.add_button("Remove blurb") checkbox = server.gui.add_checkbox("Visibility", initial_value=True) markdown_source = (here / "./assets/mdx_example.mdx").read_text() markdown_blurb = server.gui.add_markdown( content=markdown_source, image_root=here, ) @button.on_click def _(_): markdown_blurb.remove() @checkbox.on_update def _(_): markdown_blurb.visible = checkbox.value counter = 0 while True: markdown_counter.content = f"Counter: {counter}" counter += 1 time.sleep(0.1) ================================================ FILE: viser/docs/source/examples/15_gui_in_scene.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. 3D GUI elements ========================================== ``add_3d_gui_container()`` allows standard GUI elements to be incorporated directly into a 3D scene. In this example, we click on coordinate frames to show actions that can be performed on them. .. code-block:: python :linenos: import time from typing import Optional import numpy as onp import viser import viser.transforms as tf server = viser.ViserServer() server.gui.configure_theme(dark_mode=True) num_frames = 20 @server.on_client_connect def _(client: viser.ClientHandle) -> None: """For each client that connects, we create a set of random frames + a click handler for each frame. When a frame is clicked, we display a 3D gui node. 
""" rng = onp.random.default_rng(0) displayed_3d_container: Optional[viser.Gui3dContainerHandle] = None def make_frame(i: int) -> None: # Sample a random orientation + position. wxyz = rng.normal(size=4) wxyz /= onp.linalg.norm(wxyz) position = rng.uniform(-3.0, 3.0, size=(3,)) # Create a coordinate frame and label. frame = client.scene.add_frame(f"/frame_{i}", wxyz=wxyz, position=position) # Move the camera when we click a frame. @frame.on_click def _(_): nonlocal displayed_3d_container # Close previously opened GUI. if displayed_3d_container is not None: displayed_3d_container.remove() displayed_3d_container = client.scene.add_3d_gui_container( f"/frame_{i}/gui" ) with displayed_3d_container: go_to = client.gui.add_button("Go to") randomize_orientation = client.gui.add_button("Randomize orientation") close = client.gui.add_button("Close GUI") @go_to.on_click def _(_) -> None: T_world_current = tf.SE3.from_rotation_and_translation( tf.SO3(client.camera.wxyz), client.camera.position ) T_world_target = tf.SE3.from_rotation_and_translation( tf.SO3(frame.wxyz), frame.position ) @ tf.SE3.from_translation(onp.array([0.0, 0.0, -0.5])) T_current_target = T_world_current.inverse() @ T_world_target for j in range(20): T_world_set = T_world_current @ tf.SE3.exp( T_current_target.log() * j / 19.0 ) # Important bit: we atomically set both the orientation and the position # of the camera. with client.atomic(): client.camera.wxyz = T_world_set.rotation().wxyz client.camera.position = T_world_set.translation() time.sleep(1.0 / 60.0) # Mouse interactions should orbit around the frame origin. client.camera.look_at = frame.position @randomize_orientation.on_click def _(_) -> None: wxyz = rng.normal(size=4) wxyz /= onp.linalg.norm(wxyz) frame.wxyz = wxyz @close.on_click def _(_) -> None: nonlocal displayed_3d_container if displayed_3d_container is None: return displayed_3d_container.remove() displayed_3d_container = None for i in range(num_frames): make_frame(i) while True: time.sleep(1.0) ================================================ FILE: viser/docs/source/examples/16_modal.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Modal basics ========================================== Examples of using modals in Viser. .. code-block:: python :linenos: import time import viser def main(): server = viser.ViserServer() @server.on_client_connect def _(client: viser.ClientHandle) -> None: with client.gui.add_modal("Modal example"): client.gui.add_markdown( "**The input below determines the title of the modal...**" ) gui_title = client.gui.add_text( "Title", initial_value="My Modal", ) modal_button = client.gui.add_button("Show more modals") @modal_button.on_click def _(_) -> None: with client.gui.add_modal(gui_title.value) as modal: client.gui.add_markdown("This is content inside the modal!") client.gui.add_button("Close").on_click(lambda _: modal.close()) while True: time.sleep(0.15) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/17_background_composite.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Depth compositing ========================================== In this example, we show how to use a background image with depth compositing. 
This can be useful when we want a 2D image to occlude 3D geometry, such as for NeRF rendering. .. code-block:: python :linenos: import time import numpy as onp import trimesh import trimesh.creation import viser server = viser.ViserServer() img = onp.random.randint(0, 255, size=(1000, 1000, 3), dtype=onp.uint8) depth = onp.ones((1000, 1000, 1), dtype=onp.float32) # Make a square middle portal. depth[250:750, 250:750, :] = 10.0 img[250:750, 250:750, :] = 255 mesh = trimesh.creation.box((0.5, 0.5, 0.5)) server.scene.add_mesh_trimesh( name="/cube", mesh=mesh, position=(0, 0, 0.0), ) server.scene.set_background_image(img, depth=depth) while True: time.sleep(1.0) ================================================ FILE: viser/docs/source/examples/18_splines.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Splines ========================================== Make a ball with some random splines. .. code-block:: python :linenos: import time import numpy as onp import viser def main() -> None: server = viser.ViserServer() for i in range(10): positions = onp.random.normal(size=(30, 3)) * 3.0 server.scene.add_spline_catmull_rom( f"/catmull_{i}", positions, tension=0.5, line_width=3.0, color=onp.random.uniform(size=3), segments=100, ) control_points = onp.random.normal(size=(30 * 2 - 2, 3)) * 3.0 server.scene.add_spline_cubic_bezier( f"/cubic_bezier_{i}", positions, control_points, line_width=3.0, color=onp.random.uniform(size=3), segments=100, ) while True: time.sleep(10.0) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/19_get_renders.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Get renders ========================================== Example for getting renders from a client's viewport to the Python API. .. code-block:: python :linenos: import time import imageio.v3 as iio import numpy as onp import viser def main(): server = viser.ViserServer() button = server.gui.add_button("Render a GIF") @button.on_click def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None client.scene.reset() images = [] for i in range(20): positions = onp.random.normal(size=(30, 3)) * 3.0 client.scene.add_spline_catmull_rom( f"/catmull_{i}", positions, tension=0.5, line_width=3.0, color=onp.random.uniform(size=3), ) images.append(client.camera.get_render(height=720, width=1280)) print("Generating and sending GIF...") client.send_file_download( "image.gif", iio.imwrite("", images, extension=".gif") ) print("Done!") while True: time.sleep(10.0) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/20_scene_pointer.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Scene pointer events. ========================================== This example shows how to use scene pointer events to specify rays, and how they can be used to interact with the scene (e.g., ray-mesh intersections). To get the demo data, see ``./assets/download_dragon_mesh.sh``. .. 
code-block:: python :linenos: from __future__ import annotations import time from pathlib import Path from typing import cast import numpy as onp import trimesh import trimesh.creation import trimesh.ray import viser import viser.transforms as tf from viser.theme import TitlebarConfig server = viser.ViserServer() server.gui.configure_theme( brand_color=(130, 0, 150), titlebar_content=TitlebarConfig(buttons=(), image=None), ) server.scene.set_up_direction("+y") mesh = cast( trimesh.Trimesh, trimesh.load_mesh(str(Path(__file__).parent / "assets/dragon.obj")) ) mesh.apply_scale(0.05) mesh_handle = server.scene.add_mesh_trimesh( name="/mesh", mesh=mesh, position=(0.0, 0.0, 0.0), ) hit_pos_handles: list[viser.GlbHandle] = [] # Buttons + callbacks will operate on a per-client basis, but will modify the global scene! :) @server.on_client_connect def _(client: viser.ClientHandle) -> None: # Set up the camera -- this gives a nice view of the full mesh. client.camera.position = onp.array([0.0, 0.0, -10.0]) client.camera.wxyz = onp.array([0.0, 0.0, 0.0, 1.0]) # Tests "click" scenepointerevent. click_button_handle = client.gui.add_button("Add sphere", icon=viser.Icon.POINTER) @click_button_handle.on_click def _(_): click_button_handle.disabled = True @client.scene.on_pointer_event(event_type="click") def _(event: viser.ScenePointerEvent) -> None: # Check for intersection with the mesh, using trimesh's ray-mesh intersection. # Note that mesh is in the mesh frame, so we need to transform the ray. R_world_mesh = tf.SO3(mesh_handle.wxyz) R_mesh_world = R_world_mesh.inverse() origin = (R_mesh_world @ onp.array(event.ray_origin)).reshape(1, 3) direction = (R_mesh_world @ onp.array(event.ray_direction)).reshape(1, 3) intersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh) hit_pos, _, _ = intersector.intersects_location(origin, direction) if len(hit_pos) == 0: return client.scene.remove_pointer_callback() # Get the first hit position (based on distance from the ray origin). hit_pos = hit_pos[onp.argmin(onp.sum((hit_pos - origin) ** 2, axis=-1))] # Create a sphere at the hit location. hit_pos_mesh = trimesh.creation.icosphere(radius=0.1) hit_pos_mesh.vertices += R_world_mesh @ hit_pos hit_pos_mesh.visual.vertex_colors = (0.5, 0.0, 0.7, 1.0) # type: ignore hit_pos_handle = server.scene.add_mesh_trimesh( name=f"/hit_pos_{len(hit_pos_handles)}", mesh=hit_pos_mesh ) hit_pos_handles.append(hit_pos_handle) @client.scene.on_pointer_callback_removed def _(): click_button_handle.disabled = False # Tests "rect-select" scenepointerevent. paint_button_handle = client.gui.add_button("Paint mesh", icon=viser.Icon.PAINT) @paint_button_handle.on_click def _(_): paint_button_handle.disabled = True @client.scene.on_pointer_event(event_type="rect-select") def _(message: viser.ScenePointerEvent) -> None: client.scene.remove_pointer_callback() global mesh_handle camera = message.client.camera # Put the mesh in the camera frame. R_world_mesh = tf.SO3(mesh_handle.wxyz) R_mesh_world = R_world_mesh.inverse() R_camera_world = tf.SE3.from_rotation_and_translation( tf.SO3(camera.wxyz), camera.position ).inverse() vertices = cast(onp.ndarray, mesh.vertices) vertices = (R_mesh_world.as_matrix() @ vertices.T).T vertices = ( R_camera_world.as_matrix() @ onp.hstack([vertices, onp.ones((vertices.shape[0], 1))]).T ).T[:, :3] # Get the camera intrinsics, and project the vertices onto the image plane. 
fov, aspect = camera.fov, camera.aspect vertices_proj = vertices[:, :2] / vertices[:, 2].reshape(-1, 1) vertices_proj /= onp.tan(fov / 2) vertices_proj[:, 0] /= aspect # Move the origin to the upper-left corner, and scale to [0, 1]. # ... make sure to match the OpenCV's image coordinates! vertices_proj = (1 + vertices_proj) / 2 # Select the vertices that lie inside the 2D selected box, once projected. mask = ( (vertices_proj > onp.array(message.screen_pos[0])) & (vertices_proj < onp.array(message.screen_pos[1])) ).all(axis=1)[..., None] # Update the mesh color based on whether the vertices are inside the box mesh.visual.vertex_colors = onp.where( # type: ignore mask, (0.5, 0.0, 0.7, 1.0), (0.9, 0.9, 0.9, 1.0) ) mesh_handle = server.scene.add_mesh_trimesh( name="/mesh", mesh=mesh, position=(0.0, 0.0, 0.0), ) @client.scene.on_pointer_callback_removed def _(): paint_button_handle.disabled = False # Button to clear spheres. clear_button_handle = client.gui.add_button("Clear scene", icon=viser.Icon.X) @clear_button_handle.on_click def _(_): """Reset the mesh color and remove all click-generated spheres.""" global mesh_handle for handle in hit_pos_handles: handle.remove() hit_pos_handles.clear() mesh.visual.vertex_colors = (0.9, 0.9, 0.9, 1.0) # type: ignore mesh_handle = server.scene.add_mesh_trimesh( name="/mesh", mesh=mesh, position=(0.0, 0.0, 0.0), ) while True: time.sleep(10.0) ================================================ FILE: viser/docs/source/examples/21_set_up_direction.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Set up direction ========================================== ``.set_up_direction()`` can help us set the global up direction. .. code-block:: python :linenos: import time import viser def main() -> None: server = viser.ViserServer() server.scene.world_axes.visible = True gui_up = server.gui.add_vector3( "Up Direction", initial_value=(0.0, 0.0, 1.0), step=0.01, ) @gui_up.on_update def _(_) -> None: server.scene.set_up_direction(gui_up.value) while True: time.sleep(1.0) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/22_games.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Games ========================================== Some two-player games implemented using scene click events. .. code-block:: python :linenos: import time from typing import Literal import numpy as onp import trimesh.creation from typing_extensions import assert_never import viser import viser.transforms as tf def main() -> None: server = viser.ViserServer() server.gui.configure_theme(dark_mode=True) play_connect_4(server) server.gui.add_button("Tic-Tac-Toe").on_click(lambda _: play_tic_tac_toe(server)) server.gui.add_button("Connect 4").on_click(lambda _: play_connect_4(server)) while True: time.sleep(10.0) def play_connect_4(server: viser.ViserServer) -> None: """Play a game of Connect 4.""" server.scene.reset() num_rows = 6 num_cols = 7 whose_turn: Literal["red", "yellow"] = "red" pieces_in_col = [0] * num_cols # Create the board frame. 
for col in range(num_cols): for row in range(num_rows): server.scene.add_mesh_trimesh( f"/structure/{row}_{col}", trimesh.creation.annulus(0.45, 0.55, 0.125), position=(0.0, col, row), wxyz=tf.SO3.from_y_radians(onp.pi / 2.0).wxyz, ) # Create a sphere to click on for each column. def setup_column(col: int) -> None: sphere = server.scene.add_icosphere( f"/spheres/{col}", radius=0.25, position=(0, col, num_rows - 0.25), color=(255, 255, 255), ) # Drop piece into the column. @sphere.on_click def _(_) -> None: nonlocal whose_turn whose_turn = "red" if whose_turn != "red" else "yellow" row = pieces_in_col[col] if row == num_rows - 1: sphere.remove() pieces_in_col[col] += 1 cylinder = trimesh.creation.cylinder(radius=0.4, height=0.125) piece = server.scene.add_mesh_simple( f"/game_pieces/{row}_{col}", cylinder.vertices, cylinder.faces, wxyz=tf.SO3.from_y_radians(onp.pi / 2.0).wxyz, color={"red": (255, 0, 0), "yellow": (255, 255, 0)}[whose_turn], ) for row_anim in onp.linspace(num_rows - 1, row, num_rows - row + 1): piece.position = ( 0, col, row_anim, ) time.sleep(1.0 / 30.0) for col in range(num_cols): setup_column(col) def play_tic_tac_toe(server: viser.ViserServer) -> None: """Play a game of tic-tac-toe.""" server.scene.reset() whose_turn: Literal["x", "o"] = "x" for i in range(4): server.scene.add_spline_catmull_rom( f"/gridlines/{i}", ((-0.5, -1.5, 0), (-0.5, 1.5, 0)), color=(127, 127, 127), position=(1, 1, 0), wxyz=tf.SO3.from_z_radians(onp.pi / 2 * i).wxyz, ) def draw_symbol(symbol: Literal["x", "o"], i: int, j: int) -> None: """Draw an X or O in the given cell.""" for scale in onp.linspace(0.01, 1.0, 5): if symbol == "x": for k in range(2): server.scene.add_box( f"/symbols/{i}_{j}/{k}", dimensions=(0.7 * scale, 0.125 * scale, 0.125), position=(i, j, 0), color=(0, 0, 255), wxyz=tf.SO3.from_z_radians( onp.pi / 2.0 * k + onp.pi / 4.0 ).wxyz, ) elif symbol == "o": mesh = trimesh.creation.annulus(0.25 * scale, 0.35 * scale, 0.125) server.scene.add_mesh_simple( f"/symbols/{i}_{j}", mesh.vertices, mesh.faces, position=(i, j, 0), color=(255, 0, 0), ) else: assert_never(symbol) server.flush() time.sleep(1.0 / 30.0) def setup_cell(i: int, j: int) -> None: """Create a clickable sphere in a given cell.""" sphere = server.scene.add_icosphere( f"/spheres/{i}_{j}", radius=0.25, position=(i, j, 0), color=(255, 255, 255), ) @sphere.on_click def _(_) -> None: nonlocal whose_turn whose_turn = "x" if whose_turn != "x" else "o" sphere.remove() draw_symbol(whose_turn, i, j) for i in range(3): for j in range(3): setup_cell(i, j) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/23_plotly.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Plotly ========================================== Examples of visualizing plotly plots in Viser. .. code-block:: python :linenos: import time import numpy as onp import plotly.express as px import plotly.graph_objects as go from PIL import Image import viser def create_sinusoidal_wave(t: float) -> go.Figure: """Create a sinusoidal wave plot, starting at time t.""" x_data = onp.linspace(t, t + 6 * onp.pi, 50) y_data = onp.sin(x_data) * 10 fig = px.line( x=list(x_data), y=list(y_data), labels={"x": "x", "y": "sin(x)"}, title="Sinusoidal Wave", ) # this sets the margins to be tight around the title. 
fig.layout.title.automargin = True # type: ignore fig.update_layout( margin=dict(l=20, r=20, t=20, b=20), ) # Reduce plot margins. return fig def main() -> None: server = viser.ViserServer() # Plot type 1: Line plot. line_plot_time = 0.0 line_plot = server.gui.add_plotly(figure=create_sinusoidal_wave(line_plot_time)) # Plot type 2: Image plot. fig = px.imshow(Image.open("assets/Cal_logo.png")) fig.update_layout( margin=dict(l=20, r=20, t=20, b=20), ) server.gui.add_plotly(figure=fig, aspect=1.0) # Plot type 3: 3D Scatter plot. fig = px.scatter_3d( px.data.iris(), x="sepal_length", y="sepal_width", z="petal_width", color="species", ) fig.update_layout(legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01)) fig.update_layout( margin=dict(l=20, r=20, t=20, b=20), ) server.gui.add_plotly(figure=fig, aspect=1.0) while True: # Update the line plot. line_plot_time += 0.1 line_plot.figure = create_sinusoidal_wave(line_plot_time) time.sleep(0.01) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/24_notification.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. Notifications ========================================== Examples of adding notifications per client in Viser. .. code-block:: python :linenos: import time import viser def main() -> None: server = viser.ViserServer() persistent_notif_button = server.gui.add_button( "Show persistent notification (default)" ) timed_notif_button = server.gui.add_button("Show timed notification") controlled_notif_button = server.gui.add_button("Show controlled notification") loading_notif_button = server.gui.add_button("Show loading notification") remove_controlled_notif = server.gui.add_button("Remove controlled notification") @persistent_notif_button.on_click def _(event: viser.GuiEvent) -> None: """Show persistent notification when the button is clicked.""" client = event.client assert client is not None client.add_notification( title="Persistent notification", body="This can be closed manually and does not disappear on its own!", loading=False, with_close_button=True, auto_close=False, ) @timed_notif_button.on_click def _(event: viser.GuiEvent) -> None: """Show timed notification when the button is clicked.""" client = event.client assert client is not None client.add_notification( title="Timed notification", body="This disappears automatically after 5 seconds!", loading=False, with_close_button=True, auto_close=5000, ) @controlled_notif_button.on_click def _(event: viser.GuiEvent) -> None: """Show controlled notification when the button is clicked.""" client = event.client assert client is not None controlled_notif = client.add_notification( title="Controlled notification", body="This cannot be closed by the user and is controlled in code only!", loading=False, with_close_button=False, auto_close=False, ) @remove_controlled_notif.on_click def _(_) -> None: """Remove controlled notification.""" controlled_notif.remove() @loading_notif_button.on_click def _(event: viser.GuiEvent) -> None: """Show loading notification when the button is clicked.""" client = event.client assert client is not None loading_notif = client.add_notification( title="Loading notification", body="This indicates that some action is in progress! 
It will be updated in 3 seconds.", loading=True, with_close_button=False, auto_close=False, ) time.sleep(3.0) loading_notif.title = "Updated notification" loading_notif.body = "This notification has been updated!" loading_notif.loading = False loading_notif.with_close_button = True loading_notif.auto_close = 5000 loading_notif.color = "green" while True: time.sleep(1.0) if __name__ == "__main__": main() ================================================ FILE: viser/docs/source/examples/25_smpl_visualizer_skinned.rst ================================================ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. SMPL visualizer (Skinned Mesh) ========================================== Requires a .npz model file. See here for download instructions: https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model .. code-block:: python :linenos: from __future__ import annotations import time from dataclasses import dataclass from pathlib import Path from typing import List, Tuple import numpy as np import numpy as onp import tyro import viser import viser.transforms as tf @dataclass(frozen=True) class SmplOutputs: vertices: np.ndarray faces: np.ndarray T_world_joint: np.ndarray # (num_joints, 4, 4) T_parent_joint: np.ndarray # (num_joints, 4, 4) class SmplHelper: """Helper for models in the SMPL family, implemented in numpy.""" def __init__(self, model_path: Path) -> None: assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" body_dict = dict(**onp.load(model_path, allow_pickle=True)) self._J_regressor = body_dict["J_regressor"] self._weights = body_dict["weights"] self._v_template = body_dict["v_template"] self._posedirs = body_dict["posedirs"] self._shapedirs = body_dict["shapedirs"] self._faces = body_dict["f"] self.num_joints: int = self._weights.shape[-1] self.num_betas: int = self._shapedirs.shape[-1] self.parent_idx: np.ndarray = body_dict["kintree_table"][0] def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs: # Get shaped vertices + joint positions, when all local poses are identity. v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas) j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose) # Local SE(3) transforms. T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4) T_parent_joint[:, :3, :3] = joint_rotmats T_parent_joint[0, :3, 3] = j_tpose[0] T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]] # Forward kinematics. T_world_joint = T_parent_joint.copy() for i in range(1, self.num_joints): T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i] # Linear blend skinning. pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten() v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta) v_delta = np.ones((v_blend.shape[0], self.num_joints, 4)) v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :] v_posed = np.einsum( "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta ) return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint) def main(model_path: Path) -> None: server = viser.ViserServer() server.scene.set_up_direction("+y") server.gui.configure_theme(control_layout="collapsible") # Main loop. We'll read pose/shape from the GUI elements, compute the mesh, # and then send the updated mesh in a loop. 
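# SmplHelper above is a pure-numpy implementation: get_outputs() applies the
# shape blendshapes (betas), runs forward kinematics over the kinematic tree,
# and skins the template vertices with linear blend skinning.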
model = SmplHelper(model_path) gui_elements = make_gui_elements( server, num_betas=model.num_betas, num_joints=model.num_joints, parent_idx=model.parent_idx, ) smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), joint_rotmats=onp.zeros((model.num_joints, 3, 3)) + onp.eye(3), ) bone_wxyzs = np.array( [tf.SO3.from_matrix(R).wxyz for R in smpl_outputs.T_world_joint[:, :3, :3]] ) bone_positions = smpl_outputs.T_world_joint[:, :3, 3] skinned_handle = server.scene.add_mesh_skinned( "/human", smpl_outputs.vertices, smpl_outputs.faces, bone_wxyzs=bone_wxyzs, bone_positions=bone_positions, skin_weights=model._weights, wireframe=gui_elements.gui_wireframe.value, color=gui_elements.gui_rgb.value, ) while True: # Do nothing if no change. time.sleep(0.02) if not gui_elements.changed: continue gui_elements.changed = False # Compute SMPL outputs. smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), joint_rotmats=np.stack( [ tf.SO3.exp(np.array(x.value)).as_matrix() for x in gui_elements.gui_joints ], axis=0, ), ) # Match transform control gizmos to joint positions. for i, control in enumerate(gui_elements.transform_controls): control.position = smpl_outputs.T_parent_joint[i, :3, 3] skinned_handle.bones[i].wxyz = tf.SO3.from_matrix( smpl_outputs.T_world_joint[i, :3, :3] ).wxyz skinned_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3] @dataclass class GuiElements: """Structure containing handles for reading from GUI elements.""" gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]] gui_wireframe: viser.GuiInputHandle[bool] gui_betas: List[viser.GuiInputHandle[float]] gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] transform_controls: List[viser.TransformControlsHandle] changed: bool """This flag will be flipped to True whenever the mesh needs to be re-generated.""" def make_gui_elements( server: viser.ViserServer, num_betas: int, num_joints: int, parent_idx: np.ndarray, ) -> GuiElements: """Make GUI elements for interacting with the model.""" tab_group = server.gui.add_tab_group() def set_changed(_) -> None: out.changed = True # out is define later! # GUI elements: mesh settings + visibility. with tab_group.add_tab("View", viser.Icon.VIEWFINDER): gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255)) gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False) gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True) gui_rgb.on_update(set_changed) gui_wireframe.on_update(set_changed) @gui_show_controls.on_update def _(_): for control in transform_controls: control.visible = gui_show_controls.value # GUI elements: shape parameters. with tab_group.add_tab("Shape", viser.Icon.BOX): gui_reset_shape = server.gui.add_button("Reset Shape") gui_random_shape = server.gui.add_button("Random Shape") @gui_reset_shape.on_click def _(_): for beta in gui_betas: beta.value = 0.0 @gui_random_shape.on_click def _(_): for beta in gui_betas: beta.value = onp.random.normal(loc=0.0, scale=1.0) gui_betas = [] for i in range(num_betas): beta = server.gui.add_slider( f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0 ) gui_betas.append(beta) beta.on_update(set_changed) # GUI elements: joint angles. 
with tab_group.add_tab("Joints", viser.Icon.ANGLE): gui_reset_joints = server.gui.add_button("Reset Joints") gui_random_joints = server.gui.add_button("Random Joints") @gui_reset_joints.on_click def _(_): for joint in gui_joints: joint.value = (0.0, 0.0, 0.0) @gui_random_joints.on_click def _(_): for joint in gui_joints: # It's hard to uniformly sample orientations directly in so(3), so we # first sample on S^3 and then convert. quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) quat /= onp.linalg.norm(quat) joint.value = tf.SO3(wxyz=quat).log() gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [] for i in range(num_joints): gui_joint = server.gui.add_vector3( label=f"Joint {i}", initial_value=(0.0, 0.0, 0.0), step=0.05, ) gui_joints.append(gui_joint) def set_callback_in_closure(i: int) -> None: @gui_joint.on_update def _(_): transform_controls[i].wxyz = tf.SO3.exp( np.array(gui_joints[i].value) ).wxyz out.changed = True set_callback_in_closure(i) # Transform control gizmos on joints. transform_controls: List[viser.TransformControlsHandle] = [] prefixed_joint_names = [] # Joint names, but prefixed with parents. for i in range(num_joints): prefixed_joint_name = f"joint_{i}" if i > 0: prefixed_joint_name = ( prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name ) prefixed_joint_names.append(prefixed_joint_name) controls = server.scene.add_transform_controls( f"/smpl/{prefixed_joint_name}", depth_test=False, scale=0.2 * (0.75 ** prefixed_joint_name.count("/")), disable_axes=True, disable_sliders=True, visible=gui_show_controls.value, ) transform_controls.append(controls) def set_callback_in_closure(i: int) -> None: @controls.on_update def _(_) -> None: axisangle = tf.SO3(transform_controls[i].wxyz).log() gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2]) set_callback_in_closure(i) out = GuiElements( gui_rgb, gui_wireframe, gui_betas, gui_joints, transform_controls=transform_controls, changed=True, ) return out if __name__ == "__main__": tyro.cli(main, description=__doc__) ================================================ FILE: viser/docs/source/extras.md ================================================ # Record3D + URDF Helpers .. automodule:: viser.extras ================================================ FILE: viser/docs/source/gui_api.md ================================================ # GUI API .. autoclass:: viser.GuiApi :members: :undoc-members: :inherited-members: ================================================ FILE: viser/docs/source/gui_handles.md ================================================ # GUI Handles .. autoclass:: viser.GuiInputHandle() .. autoclass:: viser.GuiButtonHandle() .. autoclass:: viser.GuiButtonGroupHandle() .. autoclass:: viser.GuiDropdownHandle() .. autoclass:: viser.GuiFolderHandle() .. autoclass:: viser.GuiMarkdownHandle() .. autoclass:: viser.GuiPlotlyHandle() .. autoclass:: viser.GuiTabGroupHandle() .. autoclass:: viser.GuiTabHandle() ================================================ FILE: viser/docs/source/icons.md ================================================ # Icons Icons for GUI elements (such as :meth:`GuiApi.add_button()`) can be specified using the :class:`viser.Icon` enum. .. autoclass:: viser.IconName .. autoclass:: viser.Icon ================================================ FILE: viser/docs/source/index.md ================================================ # viser |mypy| |nbsp| |pyright| |nbsp| |typescript| |nbsp| |versions| **viser** is a library for interactive 3D visualization in Python. 
Features include: - API for visualizing 3D primitives - GUI building blocks: buttons, checkboxes, text inputs, sliders, etc. - Scene interaction tools (clicks, selection, transform gizmos) - Programmatic camera control and rendering - An entirely web-based client, for easy use over SSH! ## Installation You can install `viser` with `pip`: ```bash pip install viser ``` To include example dependencies: ```bash pip install viser[examples] ``` After an example script is running, you can connect by navigating to the printed URL (default: `http://localhost:8080`). .. toctree:: :caption: Notes :hidden: :maxdepth: 1 :titlesonly: ./conventions.md ./development.md .. toctree:: :caption: API (Basics) :hidden: :maxdepth: 1 :titlesonly: ./server.md ./scene_api.md ./gui_api.md .. toctree:: :caption: API (Advanced) :hidden: :maxdepth: 1 :titlesonly: ./client_handles.md ./camera_handles.md ./gui_handles.md ./scene_handles.md ./events.md ./icons.md .. toctree:: :caption: API (Auxiliary) :hidden: :maxdepth: 1 :titlesonly: ./transforms.md ./infrastructure.md ./extras.md .. toctree:: :caption: Examples :hidden: :maxdepth: 1 :titlesonly: :glob: examples/* .. |build| image:: https://github.com/nerfstudio-project/viser/workflows/build/badge.svg :alt: Build status icon :target: https://github.com/nerfstudio-project/viser .. |mypy| image:: https://github.com/nerfstudio-project/viser/workflows/mypy/badge.svg?branch=main :alt: Mypy status icon :target: https://github.com/nerfstudio-project/viser .. |pyright| image:: https://github.com/nerfstudio-project/viser/workflows/pyright/badge.svg?branch=main :alt: Mypy status icon :target: https://github.com/nerfstudio-project/viser .. |typescript| image:: https://github.com/nerfstudio-project/viser/workflows/typescript-compile/badge.svg :alt: TypeScript status icon :target: https://github.com/nerfstudio-project/viser .. |versions| image:: https://img.shields.io/pypi/pyversions/viser :alt: Version icon :target: https://pypi.org/project/viser/ .. |nbsp| unicode:: 0xA0 :trim: ================================================ FILE: viser/docs/source/infrastructure.md ================================================ # Communication .. automodule:: viser.infra :show-inheritance: ================================================ FILE: viser/docs/source/scene_api.md ================================================ # Scene API .. autoclass:: viser.SceneApi :members: :undoc-members: :inherited-members: ================================================ FILE: viser/docs/source/scene_handles.md ================================================ # Scene Handles A handle is created for each object that is added to the scene. These can be used to read and set state, as well as detect clicks. When a scene node is added to a server (for example, via :func:`viser.ViserServer.add_frame()`), state is synchronized between all connected clients. When a scene node is added to a client (for example, via :func:`viser.ClientHandle.add_frame()`), state is local to a specific client. .. autoclass:: viser.SceneNodeHandle .. autoclass:: viser.CameraFrustumHandle .. autoclass:: viser.FrameHandle .. autoclass:: viser.BatchedAxesHandle .. autoclass:: viser.GlbHandle .. autoclass:: viser.Gui3dContainerHandle .. autoclass:: viser.ImageHandle .. autoclass:: viser.LabelHandle .. autoclass:: viser.MeshHandle .. autoclass:: viser.MeshSkinnedHandle .. autoclass:: viser.MeshSkinnedBoneHandle .. autoclass:: viser.PointCloudHandle .. autoclass:: viser.TransformControlsHandle .. 
autoclass:: viser.GaussianSplatHandle ================================================ FILE: viser/docs/source/server.md ================================================ # Viser Server .. autoclass:: viser.ViserServer ================================================ FILE: viser/docs/source/transforms.md ================================================ # Transforms .. automodule:: viser.transforms :show-inheritance: ================================================ FILE: viser/docs/update_example_docs.py ================================================ """Helper script for updating the auto-generated examples pages in the documentation.""" from __future__ import annotations import dataclasses import pathlib import shutil from typing import Iterable import m2r2 import tyro @dataclasses.dataclass class ExampleMetadata: index: str index_with_zero: str source: str title: str description: str @staticmethod def from_path(path: pathlib.Path) -> ExampleMetadata: # 01_functions -> 01, _, functions. index, _, _ = path.stem.partition("_") # 01 -> 1. index_with_zero = index index = str(int(index)) source = path.read_text().strip() docstring = source.split('"""')[1].strip() title, _, description = docstring.partition("\n") return ExampleMetadata( index=index, index_with_zero=index_with_zero, source=source.partition('"""')[2].partition('"""')[2].strip(), title=title, description=description.strip(), ) def get_example_paths(examples_dir: pathlib.Path) -> Iterable[pathlib.Path]: return filter( lambda p: not p.name.startswith("_"), sorted(examples_dir.glob("*.py")) ) REPO_ROOT = pathlib.Path(__file__).absolute().parent.parent def main( examples_dir: pathlib.Path = REPO_ROOT / "examples", sphinx_source_dir: pathlib.Path = REPO_ROOT / "docs" / "source", ) -> None: example_doc_dir = sphinx_source_dir / "examples" shutil.rmtree(example_doc_dir) example_doc_dir.mkdir() for path in get_example_paths(examples_dir): ex = ExampleMetadata.from_path(path) relative_dir = path.parent.relative_to(examples_dir) target_dir = example_doc_dir / relative_dir target_dir.mkdir(exist_ok=True, parents=True) (target_dir / f"{path.stem}.rst").write_text( "\n".join( [ ( ".. Comment: this file is automatically generated by" " `update_example_docs.py`." ), " It should not be modified manually.", "", f"{ex.title}", "==========================================", "", m2r2.convert(ex.description), "", "", ".. code-block:: python", " :linenos:", "", "", "\n".join( f" {line}".rstrip() for line in ex.source.split("\n") ), "", ] ) ) if __name__ == "__main__": tyro.cli(main, description=__doc__) ================================================ FILE: viser/examples/00_coordinate_frames.py ================================================ """Coordinate frames In this basic example, we visualize a set of coordinate frames. Naming for all scene nodes are hierarchical; /tree/branch, for example, is defined relative to /tree. """ import random import time import viser server = viser.ViserServer() while True: # Add some coordinate frames to the scene. These will be visualized in the viewer. server.scene.add_frame( "/tree", wxyz=(1.0, 0.0, 0.0, 0.0), position=(random.random() * 2.0, 2.0, 0.2), ) server.scene.add_frame( "/tree/branch", wxyz=(1.0, 0.0, 0.0, 0.0), position=(random.random() * 2.0, 2.0, 0.2), ) leaf = server.scene.add_frame( "/tree/branch/leaf", wxyz=(1.0, 0.0, 0.0, 0.0), position=(random.random() * 2.0, 2.0, 0.2), ) time.sleep(5.0) # Remove the leaf node from the scene. 
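# The handle returned by add_frame() lets us modify or delete the node later;
# calling .remove() below deletes /tree/branch/leaf for every connected client.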
leaf.remove() time.sleep(0.5) ================================================ FILE: viser/examples/01_image.py ================================================ """Images Example for sending images to the viewer. We can send backgrond images to display behind the viewer (useful for visualizing NeRFs), or images to render as 3D textures. """ import time from pathlib import Path import imageio.v3 as iio import numpy as onp import viser def main() -> None: server = viser.ViserServer() # Add a background image. server.scene.set_background_image( iio.imread(Path(__file__).parent / "assets/Cal_logo.png"), format="png", ) # Add main image. server.scene.add_image( "/img", iio.imread(Path(__file__).parent / "assets/Cal_logo.png"), 4.0, 4.0, format="png", wxyz=(1.0, 0.0, 0.0, 0.0), position=(2.0, 2.0, 0.0), ) while True: server.scene.add_image( "/noise", onp.random.randint( 0, 256, size=(400, 400, 3), dtype=onp.uint8, ), 4.0, 4.0, format="jpeg", wxyz=(1.0, 0.0, 0.0, 0.0), position=(2.0, 2.0, -1e-2), ) time.sleep(0.2) if __name__ == "__main__": main() ================================================ FILE: viser/examples/02_gui.py ================================================ """GUI basics Examples of basic GUI elements that we can create, read from, and write to.""" import time import numpy as onp import viser def main() -> None: server = viser.ViserServer() # Add some common GUI elements: number inputs, sliders, vectors, checkboxes. with server.gui.add_folder("Read-only"): gui_counter = server.gui.add_number( "Counter", initial_value=0, disabled=True, ) gui_slider = server.gui.add_slider( "Slider", min=0, max=100, step=1, initial_value=0, disabled=True, ) gui_progress = server.gui.add_progress_bar(25, animated=True) with server.gui.add_folder("Editable"): gui_vector2 = server.gui.add_vector2( "Position", initial_value=(0.0, 0.0), step=0.1, ) gui_vector3 = server.gui.add_vector3( "Size", initial_value=(1.0, 1.0, 1.0), step=0.25, ) with server.gui.add_folder("Text toggle"): gui_checkbox_hide = server.gui.add_checkbox( "Hide", initial_value=False, ) gui_text = server.gui.add_text( "Text", initial_value="Hello world", ) gui_button = server.gui.add_button("Button") gui_checkbox_disable = server.gui.add_checkbox( "Disable", initial_value=False, ) gui_rgb = server.gui.add_rgb( "Color", initial_value=(255, 255, 0), ) gui_multi_slider = server.gui.add_multi_slider( "Multi slider", min=0, max=100, step=1, initial_value=(0, 30, 100), marks=((0, "0"), (50, "5"), (70, "7"), 99), ) gui_slider_positions = server.gui.add_slider( "# sliders", min=0, max=10, step=1, initial_value=3, marks=((0, "0"), (5, "5"), (7, "7"), 10), ) gui_upload_button = server.gui.add_upload_button( "Upload", icon=viser.Icon.UPLOAD ) @gui_upload_button.on_upload def _(_) -> None: """Callback for when a file is uploaded.""" file = gui_upload_button.value print(file.name, len(file.content), "bytes") # Pre-generate a point cloud to send. point_positions = onp.random.uniform(low=-1.0, high=1.0, size=(5000, 3)) color_coeffs = onp.random.uniform(0.4, 1.0, size=(point_positions.shape[0])) counter = 0 while True: # We can set the value of an input to a particular value. Changes are # automatically reflected in connected clients. gui_counter.value = counter gui_slider.value = counter % 100 # We can set the position of a scene node with `.position`, and read the value # of a gui element with `.value`. Changes are automatically reflected in # connected clients. 
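# For example, the `gui_counter.value = counter` write above immediately updates
# the counter shown in every connected browser; `gui_vector2.value` below is read
# back from the GUI in exactly the same way.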
server.scene.add_point_cloud( "/point_cloud", points=point_positions * onp.array(gui_vector3.value, dtype=onp.float32), colors=( onp.tile(gui_rgb.value, point_positions.shape[0]).reshape((-1, 3)) * color_coeffs[:, None] ).astype(onp.uint8), position=gui_vector2.value + (0,), point_shape="circle", ) gui_progress.value = float((counter % 100)) # We can use `.visible` and `.disabled` to toggle GUI elements. gui_text.visible = not gui_checkbox_hide.value gui_button.visible = not gui_checkbox_hide.value gui_rgb.disabled = gui_checkbox_disable.value gui_button.disabled = gui_checkbox_disable.value gui_upload_button.disabled = gui_checkbox_disable.value # Update the number of handles in the multi-slider. if gui_slider_positions.value != len(gui_multi_slider.value): gui_multi_slider.value = onp.linspace( 0, 100, gui_slider_positions.value, dtype=onp.int64 ) counter += 1 time.sleep(0.01) if __name__ == "__main__": main() ================================================ FILE: viser/examples/03_gui_callbacks.py ================================================ """GUI callbacks Asynchronous usage of GUI elements: we can attach callbacks that are called as soon as we get updates.""" import time import numpy as onp from typing_extensions import assert_never import viser def main() -> None: server = viser.ViserServer() gui_reset_scene = server.gui.add_button("Reset Scene") gui_plane = server.gui.add_dropdown( "Grid plane", ("xz", "xy", "yx", "yz", "zx", "zy") ) def update_plane() -> None: server.scene.add_grid( "/grid", width=10.0, height=20.0, width_segments=10, height_segments=20, plane=gui_plane.value, ) gui_plane.on_update(lambda _: update_plane()) with server.gui.add_folder("Control"): gui_show_frame = server.gui.add_checkbox("Show Frame", initial_value=True) gui_show_everything = server.gui.add_checkbox( "Show Everything", initial_value=True ) gui_axis = server.gui.add_dropdown("Axis", ("x", "y", "z")) gui_include_z = server.gui.add_checkbox("Z in dropdown", initial_value=True) @gui_include_z.on_update def _(_) -> None: gui_axis.options = ("x", "y", "z") if gui_include_z.value else ("x", "y") with server.gui.add_folder("Sliders"): gui_location = server.gui.add_slider( "Location", min=-5.0, max=5.0, step=0.05, initial_value=0.0 ) gui_num_points = server.gui.add_slider( "# Points", min=1000, max=200_000, step=1000, initial_value=10_000 ) def draw_frame() -> None: axis = gui_axis.value if axis == "x": pos = (gui_location.value, 0.0, 0.0) elif axis == "y": pos = (0.0, gui_location.value, 0.0) elif axis == "z": pos = (0.0, 0.0, gui_location.value) else: assert_never(axis) server.scene.add_frame( "/frame", wxyz=(1.0, 0.0, 0.0, 0.0), position=pos, show_axes=gui_show_frame.value, axes_length=5.0, ) def draw_points() -> None: num_points = gui_num_points.value server.scene.add_point_cloud( "/frame/point_cloud", points=onp.random.normal(size=(num_points, 3)), colors=onp.random.randint(0, 256, size=(num_points, 3)), ) # We can (optionally) also attach callbacks! # Here, we update the point clouds + frames whenever any of the GUI items are updated. 
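# Each .on_update() callback receives the triggering GuiEvent; the lambdas below
# ignore it and simply redraw the frame or point cloud using the latest values.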
gui_show_frame.on_update(lambda _: draw_frame()) gui_show_everything.on_update( lambda _: server.scene.set_global_visibility(gui_show_everything.value) ) gui_axis.on_update(lambda _: draw_frame()) gui_location.on_update(lambda _: draw_frame()) gui_num_points.on_update(lambda _: draw_points()) @gui_reset_scene.on_click def _(_) -> None: """Reset the scene when the reset button is clicked.""" gui_show_frame.value = True gui_location.value = 0.0 gui_axis.value = "x" gui_num_points.value = 10_000 draw_frame() draw_points() # Finally, let's add the initial frame + point cloud and just loop infinitely. :) update_plane() draw_frame() draw_points() while True: time.sleep(1.0) if __name__ == "__main__": main() ================================================ FILE: viser/examples/04_camera_poses.py ================================================ """Camera poses Example showing how we can detect new clients and read camera poses from them. """ import time import viser server = viser.ViserServer() server.scene.world_axes.visible = True @server.on_client_connect def _(client: viser.ClientHandle) -> None: print("new client!") # This will run whenever we get a new camera! @client.camera.on_update def _(_: viser.CameraHandle) -> None: print(f"New camera on client {client.client_id}!") # Show the client ID in the GUI. gui_info = client.gui.add_text("Client ID", initial_value=str(client.client_id)) gui_info.disabled = True while True: # Get all currently connected clients. clients = server.get_clients() print("Connected client IDs", clients.keys()) for id, client in clients.items(): print(f"Camera pose for client {id}") print(f"\twxyz: {client.camera.wxyz}") print(f"\tposition: {client.camera.position}") print(f"\tfov: {client.camera.fov}") print(f"\taspect: {client.camera.aspect}") print(f"\tlast update: {client.camera.update_timestamp}") time.sleep(2.0) ================================================ FILE: viser/examples/05_camera_commands.py ================================================ """Camera commands In addition to reads, camera parameters also support writes. These are synced to the corresponding client automatically. """ import time import numpy as onp import viser import viser.transforms as tf server = viser.ViserServer() num_frames = 20 @server.on_client_connect def _(client: viser.ClientHandle) -> None: """For each client that connects, we create a set of random frames + a click handler for each frame. When a frame is clicked, we move the camera to the corresponding frame. """ rng = onp.random.default_rng(0) def make_frame(i: int) -> None: # Sample a random orientation + position. wxyz = rng.normal(size=4) wxyz /= onp.linalg.norm(wxyz) position = rng.uniform(-3.0, 3.0, size=(3,)) # Create a coordinate frame and label. frame = client.scene.add_frame(f"/frame_{i}", wxyz=wxyz, position=position) client.scene.add_label(f"/frame_{i}/label", text=f"Frame {i}") # Move the camera when we click a frame. @frame.on_click def _(_): T_world_current = tf.SE3.from_rotation_and_translation( tf.SO3(client.camera.wxyz), client.camera.position ) T_world_target = tf.SE3.from_rotation_and_translation( tf.SO3(frame.wxyz), frame.position ) @ tf.SE3.from_translation(onp.array([0.0, 0.0, -0.5])) T_current_target = T_world_current.inverse() @ T_world_target for j in range(20): T_world_set = T_world_current @ tf.SE3.exp( T_current_target.log() * j / 19.0 ) # We can atomically set the orientation and the position of the camera # together to prevent jitter that might happen if one was set before the # other. 
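# client.atomic() groups the two writes below so the client receives the new
# orientation and position together, rather than one update apart.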
with client.atomic(): client.camera.wxyz = T_world_set.rotation().wxyz client.camera.position = T_world_set.translation() client.flush() # Optional! time.sleep(1.0 / 60.0) # Mouse interactions should orbit around the frame origin. client.camera.look_at = frame.position for i in range(num_frames): make_frame(i) while True: time.sleep(1.0) ================================================ FILE: viser/examples/06_mesh.py ================================================ """Meshes Visualize a mesh. To get the demo data, see `./assets/download_dragon_mesh.sh`. """ import time from pathlib import Path import numpy as onp import trimesh import viser import viser.transforms as tf mesh = trimesh.load_mesh(str(Path(__file__).parent / "assets/dragon.obj")) assert isinstance(mesh, trimesh.Trimesh) mesh.apply_scale(0.05) vertices = mesh.vertices faces = mesh.faces print(f"Loaded mesh with {vertices.shape} vertices, {faces.shape} faces") server = viser.ViserServer() server.scene.add_mesh_simple( name="/simple", vertices=vertices, faces=faces, wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, position=(0.0, 0.0, 0.0), ) server.scene.add_mesh_trimesh( name="/trimesh", mesh=mesh.smoothed(), wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, position=(0.0, 5.0, 0.0), ) while True: time.sleep(10.0) ================================================ FILE: viser/examples/07_record3d_visualizer.py ================================================ """Record3D visualizer Parse and stream record3d captures. To get the demo data, see `./assets/download_record3d_dance.sh`. """ import time from pathlib import Path import numpy as onp import tyro from tqdm.auto import tqdm import viser import viser.extras import viser.transforms as tf def main( data_path: Path = Path(__file__).parent / "record3d_dance", downsample_factor: int = 4, max_frames: int = 100, share: bool = False, ) -> None: server = viser.ViserServer() if share: server.request_share_url() print("Loading frames!") loader = viser.extras.Record3dLoader(data_path) num_frames = min(max_frames, loader.num_frames()) # Add playback UI. with server.gui.add_folder("Playback"): gui_timestep = server.gui.add_slider( "Timestep", min=0, max=num_frames - 1, step=1, initial_value=0, disabled=True, ) gui_next_frame = server.gui.add_button("Next Frame", disabled=True) gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True) gui_playing = server.gui.add_checkbox("Playing", True) gui_framerate = server.gui.add_slider( "FPS", min=1, max=60, step=0.1, initial_value=loader.fps ) gui_framerate_options = server.gui.add_button_group( "FPS options", ("10", "20", "30", "60") ) # Frame step buttons. @gui_next_frame.on_click def _(_) -> None: gui_timestep.value = (gui_timestep.value + 1) % num_frames @gui_prev_frame.on_click def _(_) -> None: gui_timestep.value = (gui_timestep.value - 1) % num_frames # Disable frame controls when we're playing. @gui_playing.on_update def _(_) -> None: gui_timestep.disabled = gui_playing.value gui_next_frame.disabled = gui_playing.value gui_prev_frame.disabled = gui_playing.value # Set the framerate when we click one of the options. @gui_framerate_options.on_click def _(_) -> None: gui_framerate.value = int(gui_framerate_options.value) prev_timestep = gui_timestep.value # Toggle frame visibility when the timestep slider changes. 
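# Only one timestep's frame node is visible at a time: the handler below shows
# the newly selected frame and hides the previous one inside a single atomic block.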
@gui_timestep.on_update def _(_) -> None: nonlocal prev_timestep current_timestep = gui_timestep.value with server.atomic(): frame_nodes[current_timestep].visible = True frame_nodes[prev_timestep].visible = False prev_timestep = current_timestep server.flush() # Optional! # Load in frames. server.scene.add_frame( "/frames", wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz, position=(0, 0, 0), show_axes=False, ) frame_nodes: list[viser.FrameHandle] = [] for i in tqdm(range(num_frames)): frame = loader.get_frame(i) position, color = frame.get_point_cloud(downsample_factor) # Add base frame. frame_nodes.append(server.scene.add_frame(f"/frames/t{i}", show_axes=False)) # Place the point cloud in the frame. server.scene.add_point_cloud( name=f"/frames/t{i}/point_cloud", points=position, colors=color, point_size=0.01, point_shape="rounded", ) # Place the frustum. fov = 2 * onp.arctan2(frame.rgb.shape[0] / 2, frame.K[0, 0]) aspect = frame.rgb.shape[1] / frame.rgb.shape[0] server.scene.add_camera_frustum( f"/frames/t{i}/frustum", fov=fov, aspect=aspect, scale=0.15, image=frame.rgb[::downsample_factor, ::downsample_factor], wxyz=tf.SO3.from_matrix(frame.T_world_camera[:3, :3]).wxyz, position=frame.T_world_camera[:3, 3], ) # Add some axes. server.scene.add_frame( f"/frames/t{i}/frustum/axes", axes_length=0.05, axes_radius=0.005, ) # Hide all but the current frame. for i, frame_node in enumerate(frame_nodes): frame_node.visible = i == gui_timestep.value # Playback update loop. prev_timestep = gui_timestep.value while True: if gui_playing.value: gui_timestep.value = (gui_timestep.value + 1) % num_frames time.sleep(1.0 / gui_framerate.value) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: viser/examples/08_smpl_visualizer.py ================================================ """SMPL model visualizer Visualizer for SMPL human body models. Requires a .npz model file. See here for download instructions: https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model """ from __future__ import annotations import time from dataclasses import dataclass from pathlib import Path import numpy as np import numpy as onp import tyro import viser import viser.transforms as tf @dataclass(frozen=True) class SmplOutputs: vertices: np.ndarray faces: np.ndarray T_world_joint: np.ndarray # (num_joints, 4, 4) T_parent_joint: np.ndarray # (num_joints, 4, 4) class SmplHelper: """Helper for models in the SMPL family, implemented in numpy.""" def __init__(self, model_path: Path) -> None: assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" body_dict = dict(**onp.load(model_path, allow_pickle=True)) self._J_regressor = body_dict["J_regressor"] self._weights = body_dict["weights"] self._v_template = body_dict["v_template"] self._posedirs = body_dict["posedirs"] self._shapedirs = body_dict["shapedirs"] self._faces = body_dict["f"] self.num_joints: int = self._weights.shape[-1] self.num_betas: int = self._shapedirs.shape[-1] self.parent_idx: np.ndarray = body_dict["kintree_table"][0] def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs: # Get shaped vertices + joint positions, when all local poses are identity. v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas) j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose) # Local SE(3) transforms. 
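# T_parent_joint[i] holds joint i's local rotation plus its T-pose offset from its
# parent; forward kinematics then composes these down the kinematic tree:
#     T_world_joint[i] = T_world_joint[parent_idx[i]] @ T_parent_joint[i]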
T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4) T_parent_joint[:, :3, :3] = joint_rotmats T_parent_joint[0, :3, 3] = j_tpose[0] T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]] # Forward kinematics. T_world_joint = T_parent_joint.copy() for i in range(1, self.num_joints): T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i] # Linear blend skinning. pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten() v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta) v_delta = np.ones((v_blend.shape[0], self.num_joints, 4)) v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :] v_posed = np.einsum( "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta ) return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint) def main(model_path: Path) -> None: server = viser.ViserServer() server.scene.set_up_direction("+y") server.gui.configure_theme(control_layout="collapsible") # Main loop. We'll read pose/shape from the GUI elements, compute the mesh, # and then send the updated mesh in a loop. model = SmplHelper(model_path) gui_elements = make_gui_elements( server, num_betas=model.num_betas, num_joints=model.num_joints, parent_idx=model.parent_idx, ) while True: # Do nothing if no change. time.sleep(0.02) if not gui_elements.changed: continue gui_elements.changed = False # Compute SMPL outputs. smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), joint_rotmats=tf.SO3.exp( # (num_joints, 3) np.array([x.value for x in gui_elements.gui_joints]) ).as_matrix(), ) server.scene.add_mesh_simple( "/human", smpl_outputs.vertices, smpl_outputs.faces, wireframe=gui_elements.gui_wireframe.value, color=gui_elements.gui_rgb.value, ) # Match transform control gizmos to joint positions. for i, control in enumerate(gui_elements.transform_controls): control.position = smpl_outputs.T_parent_joint[i, :3, 3] @dataclass class GuiElements: """Structure containing handles for reading from GUI elements.""" gui_rgb: viser.GuiInputHandle[tuple[int, int, int]] gui_wireframe: viser.GuiInputHandle[bool] gui_betas: list[viser.GuiInputHandle[float]] gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]] transform_controls: list[viser.TransformControlsHandle] changed: bool """This flag will be flipped to True whenever the mesh needs to be re-generated.""" def make_gui_elements( server: viser.ViserServer, num_betas: int, num_joints: int, parent_idx: np.ndarray, ) -> GuiElements: """Make GUI elements for interacting with the model.""" tab_group = server.gui.add_tab_group() def set_changed(_) -> None: out.changed = True # out is define later! # GUI elements: mesh settings + visibility. with tab_group.add_tab("View", viser.Icon.VIEWFINDER): gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255)) gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False) gui_show_controls = server.gui.add_checkbox("Handles", initial_value=False) gui_rgb.on_update(set_changed) gui_wireframe.on_update(set_changed) @gui_show_controls.on_update def _(_): for control in transform_controls: control.visible = gui_show_controls.value # GUI elements: shape parameters. 
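# The beta sliders below are the SMPL shape coefficients: each one scales a single
# column of shapedirs in get_outputs() above.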
with tab_group.add_tab("Shape", viser.Icon.BOX): gui_reset_shape = server.gui.add_button("Reset Shape") gui_random_shape = server.gui.add_button("Random Shape") @gui_reset_shape.on_click def _(_): for beta in gui_betas: beta.value = 0.0 @gui_random_shape.on_click def _(_): for beta in gui_betas: beta.value = onp.random.normal(loc=0.0, scale=1.0) gui_betas = [] for i in range(num_betas): beta = server.gui.add_slider( f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0 ) gui_betas.append(beta) beta.on_update(set_changed) # GUI elements: joint angles. with tab_group.add_tab("Joints", viser.Icon.ANGLE): gui_reset_joints = server.gui.add_button("Reset Joints") gui_random_joints = server.gui.add_button("Random Joints") @gui_reset_joints.on_click def _(_): for joint in gui_joints: joint.value = (0.0, 0.0, 0.0) @gui_random_joints.on_click def _(_): for joint in gui_joints: # It's hard to uniformly sample orientations directly in so(3), so we # first sample on S^3 and then convert. quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) quat /= onp.linalg.norm(quat) joint.value = tf.SO3(wxyz=quat).log() gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]] = [] for i in range(num_joints): gui_joint = server.gui.add_vector3( label=f"Joint {i}", initial_value=(0.0, 0.0, 0.0), step=0.05, ) gui_joints.append(gui_joint) def set_callback_in_closure(i: int) -> None: @gui_joint.on_update def _(_): transform_controls[i].wxyz = tf.SO3.exp( np.array(gui_joints[i].value) ).wxyz out.changed = True set_callback_in_closure(i) # Transform control gizmos on joints. transform_controls: list[viser.TransformControlsHandle] = [] prefixed_joint_names = [] # Joint names, but prefixed with parents. for i in range(num_joints): prefixed_joint_name = f"joint_{i}" if i > 0: prefixed_joint_name = ( prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name ) prefixed_joint_names.append(prefixed_joint_name) controls = server.scene.add_transform_controls( f"/smpl/{prefixed_joint_name}", depth_test=False, scale=0.2 * (0.75 ** prefixed_joint_name.count("/")), disable_axes=True, disable_sliders=True, visible=gui_show_controls.value, ) transform_controls.append(controls) def set_callback_in_closure(i: int) -> None: @controls.on_update def _(_) -> None: axisangle = tf.SO3(transform_controls[i].wxyz).log() gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2]) set_callback_in_closure(i) out = GuiElements( gui_rgb, gui_wireframe, gui_betas, gui_joints, transform_controls=transform_controls, changed=True, ) return out if __name__ == "__main__": tyro.cli(main, description=__doc__) ================================================ FILE: viser/examples/09_urdf_visualizer.py ================================================ """Robot URDF visualizer Requires yourdfpy and robot_descriptions. Any URDF supported by yourdfpy should work. - https://github.com/robot-descriptions/robot_descriptions.py - https://github.com/clemense/yourdfpy The :class:`viser.extras.ViserUrdf` is a lightweight interface between yourdfpy and viser. It can also take a path to a local URDF file as input. """ from __future__ import annotations import time from typing import Literal import numpy as onp import tyro from robot_descriptions.loaders.yourdfpy import load_robot_description import viser from viser.extras import ViserUrdf def create_robot_control_sliders( server: viser.ViserServer, viser_urdf: ViserUrdf ) -> tuple[list[viser.GuiInputHandle[float]], list[float]]: """Create slider for each joint of the robot. 
We also update robot model when slider moves.""" slider_handles: list[viser.GuiInputHandle[float]] = [] initial_config: list[float] = [] for joint_name, ( lower, upper, ) in viser_urdf.get_actuated_joint_limits().items(): lower = lower if lower is not None else -onp.pi upper = upper if upper is not None else onp.pi initial_pos = 0.0 if lower < 0 and upper > 0 else (lower + upper) / 2.0 slider = server.gui.add_slider( label=joint_name, min=lower, max=upper, step=1e-3, initial_value=initial_pos, ) slider.on_update( # When sliders move, we update the URDF configuration. lambda _: viser_urdf.update_cfg( onp.array([slider.value for slider in slider_handles]) ) ) slider_handles.append(slider) initial_config.append(initial_pos) return slider_handles, initial_config def main( robot_type: Literal[ "panda", "ur10", "cassie", "allegro_hand", "barrett_hand", "robotiq_2f85", "atlas_drc", "g1", "h1", "anymal_c", "go2", ] = "panda", ) -> None: # Start viser server. server = viser.ViserServer() # Load URDF. # # This takes either a yourdfpy.URDF object or a path to a .urdf file. viser_urdf = ViserUrdf( server, urdf_or_path=load_robot_description(robot_type + "_description"), ) # Create sliders in GUI that help us move the robot joints. with server.gui.add_folder("Joint position control"): (slider_handles, initial_config) = create_robot_control_sliders( server, viser_urdf ) # Set initial robot configuration. viser_urdf.update_cfg(onp.array(initial_config)) # Create joint reset button. reset_button = server.gui.add_button("Reset") @reset_button.on_click def _(_): for s, init_q in zip(slider_handles, initial_config): s.value = init_q # Sleep forever. while True: time.sleep(10.0) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: viser/examples/10_realsense.py ================================================ """RealSense visualizer Connect to a RealSense camera, then visualize RGB-D readings as a point clouds. Requires pyrealsense2. """ from __future__ import annotations import contextlib import numpy as np import numpy.typing as npt import pyrealsense2 as rs # type: ignore from tqdm.auto import tqdm import viser @contextlib.contextmanager def realsense_pipeline(fps: int = 30): """Context manager that yields a RealSense pipeline.""" # Configure depth and color streams. pipeline = rs.pipeline() # type: ignore config = rs.config() # type: ignore pipeline_wrapper = rs.pipeline_wrapper(pipeline) # type: ignore config.resolve(pipeline_wrapper) config.enable_stream(rs.stream.depth, rs.format.z16, fps) # type: ignore config.enable_stream(rs.stream.color, rs.format.rgb8, fps) # type: ignore # Start streaming. pipeline.start(config) yield pipeline # Close pipeline when done. pipeline.close() def point_cloud_arrays_from_frames( depth_frame, color_frame ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.uint8]]: """Maps realsense frames to two arrays. Returns: - A point position array: (N, 3) float32. - A point color array: (N, 3) uint8. """ # Processing blocks. Could be tuned. point_cloud = rs.pointcloud() # type: ignore decimate = rs.decimation_filter() # type: ignore decimate.set_option(rs.option.filter_magnitude, 3) # type: ignore # Downsample depth frame. depth_frame = decimate.process(depth_frame) # Map texture and calculate points from frames. Uses frame intrinsics. point_cloud.map_to(color_frame) points = point_cloud.calculate(depth_frame) # Get color coordinates. 
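# texture_uv gives, for each 3D point, normalized (u, v) coordinates into the
# color image; we clamp them and index the image below to get one RGB per point.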
texture_uv = ( np.asanyarray(points.get_texture_coordinates()) .view(np.float32) .reshape((-1, 2)) ) color_image = np.asanyarray(color_frame.get_data()) color_h, color_w, _ = color_image.shape # Note: for points that aren't in the view of our RGB camera, we currently clamp to # the closes available RGB pixel. We could also just remove these points. texture_uv = texture_uv.clip(0.0, 1.0) # Get positions and colors. positions = np.asanyarray(points.get_vertices()).view(np.float32) positions = positions.reshape((-1, 3)) colors = color_image[ (texture_uv[:, 1] * (color_h - 1.0)).astype(np.int32), (texture_uv[:, 0] * (color_w - 1.0)).astype(np.int32), :, ] N = positions.shape[0] assert positions.shape == (N, 3) assert positions.dtype == np.float32 assert colors.shape == (N, 3) assert colors.dtype == np.uint8 return positions, colors def main(): # Start visualization server. server = viser.ViserServer() with realsense_pipeline() as pipeline: for i in tqdm(range(10000000)): # Wait for a coherent pair of frames: depth and color frames = pipeline.wait_for_frames() depth_frame = frames.get_depth_frame() color_frame = frames.get_color_frame() # Compute point cloud from frames. positions, colors = point_cloud_arrays_from_frames(depth_frame, color_frame) R = np.array( [ [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, -1.0, 0.0], ], dtype=np.float32, ) positions = positions @ R.T # Visualize. server.scene.add_point_cloud( "/realsense", points=positions * 10.0, colors=colors, point_size=0.1, ) if __name__ == "__main__": main() ================================================ FILE: viser/examples/11_colmap_visualizer.py ================================================ """COLMAP visualizer Visualize COLMAP sparse reconstruction outputs. To get demo data, see `./assets/download_colmap_garden.sh`. """ import random import time from pathlib import Path import imageio.v3 as iio import numpy as onp import tyro from tqdm.auto import tqdm import viser import viser.transforms as tf from viser.extras.colmap import ( read_cameras_binary, read_images_binary, read_points3d_binary, ) def main( colmap_path: Path = Path(__file__).parent / "assets/colmap_garden/sparse/0", images_path: Path = Path(__file__).parent / "assets/colmap_garden/images_8", downsample_factor: int = 2, ) -> None: """Visualize COLMAP sparse reconstruction outputs. Args: colmap_path: Path to the COLMAP reconstruction directory. images_path: Path to the COLMAP images directory. downsample_factor: Downsample factor for the images. """ server = viser.ViserServer() server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") # Load the colmap info. 
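# cameras.bin holds per-camera intrinsics, images.bin holds per-image poses
# (qvec, tvec) plus filenames, and points3D.bin holds the sparse point cloud.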
cameras = read_cameras_binary(colmap_path / "cameras.bin") images = read_images_binary(colmap_path / "images.bin") points3d = read_points3d_binary(colmap_path / "points3D.bin") gui_reset_up = server.gui.add_button( "Reset up direction", hint="Set the camera control 'up' direction to the current camera's 'up'.", ) @gui_reset_up.on_click def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None client.camera.up_direction = tf.SO3(client.camera.wxyz) @ onp.array( [0.0, -1.0, 0.0] ) gui_points = server.gui.add_slider( "Max points", min=1, max=len(points3d), step=1, initial_value=min(len(points3d), 50_000), ) gui_frames = server.gui.add_slider( "Max frames", min=1, max=len(images), step=1, initial_value=min(len(images), 100), ) gui_point_size = server.gui.add_number("Point size", initial_value=0.05) def visualize_colmap() -> None: """Send all COLMAP elements to viser for visualization. This could be optimized a ton!""" # Set the point cloud. points = onp.array([points3d[p_id].xyz for p_id in points3d]) colors = onp.array([points3d[p_id].rgb for p_id in points3d]) points_selection = onp.random.choice( points.shape[0], gui_points.value, replace=False ) points = points[points_selection] colors = colors[points_selection] server.scene.add_point_cloud( name="/colmap/pcd", points=points, colors=colors, point_size=gui_point_size.value, ) # Interpret the images and cameras. img_ids = [im.id for im in images.values()] random.shuffle(img_ids) img_ids = sorted(img_ids[: gui_frames.value]) def attach_callback( frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle ) -> None: @frustum.on_click def _(_) -> None: for client in server.get_clients().values(): client.camera.wxyz = frame.wxyz client.camera.position = frame.position for img_id in tqdm(img_ids): img = images[img_id] cam = cameras[img.camera_id] # Skip images that don't exist. image_filename = images_path / img.name if not image_filename.exists(): continue T_world_camera = tf.SE3.from_rotation_and_translation( tf.SO3(img.qvec), img.tvec ).inverse() frame = server.scene.add_frame( f"/colmap/frame_{img_id}", wxyz=T_world_camera.rotation().wxyz, position=T_world_camera.translation(), axes_length=0.1, axes_radius=0.005, ) # For pinhole cameras, cam.params will be (fx, fy, cx, cy). if cam.model != "PINHOLE": print(f"Expected pinhole camera, but got {cam.model}") H, W = cam.height, cam.width fy = cam.params[1] image = iio.imread(image_filename) image = image[::downsample_factor, ::downsample_factor] frustum = server.scene.add_camera_frustum( f"/colmap/frame_{img_id}/frustum", fov=2 * onp.arctan2(H / 2, fy), aspect=W / H, scale=0.15, image=image, ) attach_callback(frustum, frame) need_update = True @gui_points.on_update def _(_) -> None: nonlocal need_update need_update = True @gui_frames.on_update def _(_) -> None: nonlocal need_update need_update = True @gui_point_size.on_update def _(_) -> None: nonlocal need_update need_update = True while True: if need_update: need_update = False server.scene.reset() visualize_colmap() time.sleep(1e-3) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: viser/examples/12_click_meshes.py ================================================ """Mesh click events Click on meshes to select them. The index of the last clicked mesh is displayed in the GUI. 
""" import time import matplotlib import viser def main() -> None: grid_shape = (4, 5) server = viser.ViserServer() with server.gui.add_folder("Last clicked"): x_value = server.gui.add_number( label="x", initial_value=0, disabled=True, hint="x coordinate of the last clicked mesh", ) y_value = server.gui.add_number( label="y", initial_value=0, disabled=True, hint="y coordinate of the last clicked mesh", ) def add_swappable_mesh(i: int, j: int) -> None: """Simple callback that swaps between: - a gray box - a colored box - a colored sphere Color is chosen based on the position (i, j) of the mesh in the grid. """ colormap = matplotlib.colormaps["tab20"] def create_mesh(counter: int) -> None: if counter == 0: color = (0.8, 0.8, 0.8) else: index = (i * grid_shape[1] + j) / (grid_shape[0] * grid_shape[1]) color = colormap(index)[:3] if counter in (0, 1): handle = server.scene.add_box( name=f"/sphere_{i}_{j}", position=(i, j, 0.0), color=color, dimensions=(0.5, 0.5, 0.5), ) else: handle = server.scene.add_icosphere( name=f"/sphere_{i}_{j}", radius=0.4, color=color, position=(i, j, 0.0), ) @handle.on_click def _(_) -> None: x_value.value = i y_value.value = j # The new mesh will replace the old one because the names # /sphere_{i}_{j} are the same. create_mesh((counter + 1) % 3) create_mesh(0) for i in range(grid_shape[0]): for j in range(grid_shape[1]): add_swappable_mesh(i, j) while True: time.sleep(10.0) if __name__ == "__main__": main() ================================================ FILE: viser/examples/13_theming.py ================================================ """Theming Viser includes support for light theming. """ import time import viser from viser.theme import TitlebarButton, TitlebarConfig, TitlebarImage def main(): server = viser.ViserServer(label="Viser Theming") buttons = ( TitlebarButton( text="Getting Started", icon=None, href="https://nerf.studio", ), TitlebarButton( text="Github", icon="GitHub", href="https://github.com/nerfstudio-project/nerfstudio", ), TitlebarButton( text="Documentation", icon="Description", href="https://docs.nerf.studio", ), ) image = TitlebarImage( image_url_light="https://docs.nerf.studio/_static/imgs/logo.png", image_url_dark="https://docs.nerf.studio/_static/imgs/logo-dark.png", image_alt="NerfStudio Logo", href="https://docs.nerf.studio/", ) titlebar_theme = TitlebarConfig(buttons=buttons, image=image) server.gui.add_markdown( "Viser includes support for light theming via the `.configure_theme()` method." ) gui_theme_code = server.gui.add_markdown("no theme applied yet") # GUI elements for controllable values. 
titlebar = server.gui.add_checkbox("Titlebar", initial_value=True) dark_mode = server.gui.add_checkbox("Dark mode", initial_value=True) show_logo = server.gui.add_checkbox("Show logo", initial_value=True) show_share_button = server.gui.add_checkbox("Show share button", initial_value=True) brand_color = server.gui.add_rgb("Brand color", (230, 180, 30)) control_layout = server.gui.add_dropdown( "Control layout", ("floating", "fixed", "collapsible") ) control_width = server.gui.add_dropdown( "Control width", ("small", "medium", "large"), initial_value="medium" ) synchronize = server.gui.add_button("Apply theme", icon=viser.Icon.CHECK) def synchronize_theme() -> None: server.gui.configure_theme( titlebar_content=titlebar_theme if titlebar.value else None, control_layout=control_layout.value, control_width=control_width.value, dark_mode=dark_mode.value, show_logo=show_logo.value, show_share_button=show_share_button.value, brand_color=brand_color.value, ) gui_theme_code.content = f""" ### Current applied theme ``` server.gui.configure_theme( titlebar_content={"titlebar_content" if titlebar.value else None}, control_layout="{control_layout.value}", control_width="{control_width.value}", dark_mode={dark_mode.value}, show_logo={show_logo.value}, show_share_button={show_share_button.value}, brand_color={brand_color.value}, ) ``` """ synchronize.on_click(lambda _: synchronize_theme()) synchronize_theme() while True: time.sleep(10.0) # main() if __name__ == "__main__": main() ================================================ FILE: viser/examples/14_markdown.py ================================================ """Markdown demonstration Viser GUI has MDX 2 support. """ import time from pathlib import Path import viser server = viser.ViserServer() server.scene.world_axes.visible = True markdown_counter = server.gui.add_markdown("Counter: 0") here = Path(__file__).absolute().parent button = server.gui.add_button("Remove blurb") checkbox = server.gui.add_checkbox("Visibility", initial_value=True) markdown_source = (here / "./assets/mdx_example.mdx").read_text() markdown_blurb = server.gui.add_markdown( content=markdown_source, image_root=here, ) @button.on_click def _(_): markdown_blurb.remove() @checkbox.on_update def _(_): markdown_blurb.visible = checkbox.value counter = 0 while True: markdown_counter.content = f"Counter: {counter}" counter += 1 time.sleep(0.1) ================================================ FILE: viser/examples/15_gui_in_scene.py ================================================ """3D GUI elements `add_3d_gui_container()` allows standard GUI elements to be incorporated directly into a 3D scene. In this example, we click on coordinate frames to show actions that can be performed on them. """ import time from typing import Optional import numpy as onp import viser import viser.transforms as tf server = viser.ViserServer() server.gui.configure_theme(dark_mode=True) num_frames = 20 @server.on_client_connect def _(client: viser.ClientHandle) -> None: """For each client that connects, we create a set of random frames + a click handler for each frame. When a frame is clicked, we display a 3D gui node. """ rng = onp.random.default_rng(0) displayed_3d_container: Optional[viser.Gui3dContainerHandle] = None def make_frame(i: int) -> None: # Sample a random orientation + position. wxyz = rng.normal(size=4) wxyz /= onp.linalg.norm(wxyz) position = rng.uniform(-3.0, 3.0, size=(3,)) # Create a coordinate frame and label. 
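# These nodes are created through client.scene rather than server.scene, so they
# are local to the client that just connected instead of shared by everyone.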
frame = client.scene.add_frame(f"/frame_{i}", wxyz=wxyz, position=position) # Move the camera when we click a frame. @frame.on_click def _(_): nonlocal displayed_3d_container # Close previously opened GUI. if displayed_3d_container is not None: displayed_3d_container.remove() displayed_3d_container = client.scene.add_3d_gui_container( f"/frame_{i}/gui" ) with displayed_3d_container: go_to = client.gui.add_button("Go to") randomize_orientation = client.gui.add_button("Randomize orientation") close = client.gui.add_button("Close GUI") @go_to.on_click def _(_) -> None: T_world_current = tf.SE3.from_rotation_and_translation( tf.SO3(client.camera.wxyz), client.camera.position ) T_world_target = tf.SE3.from_rotation_and_translation( tf.SO3(frame.wxyz), frame.position ) @ tf.SE3.from_translation(onp.array([0.0, 0.0, -0.5])) T_current_target = T_world_current.inverse() @ T_world_target for j in range(20): T_world_set = T_world_current @ tf.SE3.exp( T_current_target.log() * j / 19.0 ) # Important bit: we atomically set both the orientation and the position # of the camera. with client.atomic(): client.camera.wxyz = T_world_set.rotation().wxyz client.camera.position = T_world_set.translation() time.sleep(1.0 / 60.0) # Mouse interactions should orbit around the frame origin. client.camera.look_at = frame.position @randomize_orientation.on_click def _(_) -> None: wxyz = rng.normal(size=4) wxyz /= onp.linalg.norm(wxyz) frame.wxyz = wxyz @close.on_click def _(_) -> None: nonlocal displayed_3d_container if displayed_3d_container is None: return displayed_3d_container.remove() displayed_3d_container = None for i in range(num_frames): make_frame(i) while True: time.sleep(1.0) ================================================ FILE: viser/examples/16_modal.py ================================================ """Modal basics Examples of using modals in Viser.""" import time import viser def main(): server = viser.ViserServer() @server.on_client_connect def _(client: viser.ClientHandle) -> None: with client.gui.add_modal("Modal example"): client.gui.add_markdown( "**The input below determines the title of the modal...**" ) gui_title = client.gui.add_text( "Title", initial_value="My Modal", ) modal_button = client.gui.add_button("Show more modals") @modal_button.on_click def _(_) -> None: with client.gui.add_modal(gui_title.value) as modal: client.gui.add_markdown("This is content inside the modal!") client.gui.add_button("Close").on_click(lambda _: modal.close()) while True: time.sleep(0.15) if __name__ == "__main__": main() ================================================ FILE: viser/examples/17_background_composite.py ================================================ """Depth compositing In this example, we show how to use a background image with depth compositing. This can be useful when we want a 2D image to occlude 3D geometry, such as for NeRF rendering. """ import time import numpy as onp import trimesh import trimesh.creation import viser server = viser.ViserServer() img = onp.random.randint(0, 255, size=(1000, 1000, 3), dtype=onp.uint8) depth = onp.ones((1000, 1000, 1), dtype=onp.float32) # Make a square middle portal. 
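# The background depth buffer drives the compositing: the centre square is pushed
# back to depth 10 so the 3D cube shows through it, while the surrounding pixels
# (depth 1) occlude any geometry farther than 1 unit from the camera.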
depth[250:750, 250:750, :] = 10.0 img[250:750, 250:750, :] = 255 mesh = trimesh.creation.box((0.5, 0.5, 0.5)) server.scene.add_mesh_trimesh( name="/cube", mesh=mesh, position=(0, 0, 0.0), ) server.scene.set_background_image(img, depth=depth) while True: time.sleep(1.0) ================================================ FILE: viser/examples/18_splines.py ================================================ """Splines Make a ball with some random splines. """ import time import numpy as onp import viser def main() -> None: server = viser.ViserServer() for i in range(10): positions = onp.random.normal(size=(30, 3)) * 3.0 server.scene.add_spline_catmull_rom( f"/catmull_{i}", positions, tension=0.5, line_width=3.0, color=onp.random.uniform(size=3), segments=100, ) control_points = onp.random.normal(size=(30 * 2 - 2, 3)) * 3.0 server.scene.add_spline_cubic_bezier( f"/cubic_bezier_{i}", positions, control_points, line_width=3.0, color=onp.random.uniform(size=3), segments=100, ) while True: time.sleep(10.0) if __name__ == "__main__": main() ================================================ FILE: viser/examples/19_get_renders.py ================================================ """Get renders Example for getting renders from a client's viewport to the Python API.""" import time import imageio.v3 as iio import numpy as onp import viser def main(): server = viser.ViserServer() button = server.gui.add_button("Render a GIF") @button.on_click def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None client.scene.reset() images = [] for i in range(20): positions = onp.random.normal(size=(30, 3)) * 3.0 client.scene.add_spline_catmull_rom( f"/catmull_{i}", positions, tension=0.5, line_width=3.0, color=onp.random.uniform(size=3), ) images.append(client.camera.get_render(height=720, width=1280)) print("Generating and sending GIF...") client.send_file_download( "image.gif", iio.imwrite("", images, extension=".gif") ) print("Done!") while True: time.sleep(10.0) if __name__ == "__main__": main() ================================================ FILE: viser/examples/20_scene_pointer.py ================================================ """Scene pointer events. This example shows how to use scene pointer events to specify rays, and how they can be used to interact with the scene (e.g., ray-mesh intersections). To get the demo data, see `./assets/download_dragon_mesh.sh`. """ from __future__ import annotations import time from pathlib import Path from typing import cast import numpy as onp import trimesh import trimesh.creation import trimesh.ray import viser import viser.transforms as tf from viser.theme import TitlebarConfig server = viser.ViserServer() server.gui.configure_theme( brand_color=(130, 0, 150), titlebar_content=TitlebarConfig(buttons=(), image=None), ) server.scene.set_up_direction("+y") mesh = cast( trimesh.Trimesh, trimesh.load_mesh(str(Path(__file__).parent / "assets/dragon.obj")) ) mesh.apply_scale(0.05) mesh_handle = server.scene.add_mesh_trimesh( name="/mesh", mesh=mesh, position=(0.0, 0.0, 0.0), ) hit_pos_handles: list[viser.GlbHandle] = [] # Buttons + callbacks will operate on a per-client basis, but will modify the global scene! :) @server.on_client_connect def _(client: viser.ClientHandle) -> None: # Set up the camera -- this gives a nice view of the full mesh. client.camera.position = onp.array([0.0, 0.0, -10.0]) client.camera.wxyz = onp.array([0.0, 0.0, 0.0, 1.0]) # Tests "click" scenepointerevent. 
click_button_handle = client.gui.add_button("Add sphere", icon=viser.Icon.POINTER) @click_button_handle.on_click def _(_): click_button_handle.disabled = True @client.scene.on_pointer_event(event_type="click") def _(event: viser.ScenePointerEvent) -> None: # Check for intersection with the mesh, using trimesh's ray-mesh intersection. # Note that mesh is in the mesh frame, so we need to transform the ray. R_world_mesh = tf.SO3(mesh_handle.wxyz) R_mesh_world = R_world_mesh.inverse() origin = (R_mesh_world @ onp.array(event.ray_origin)).reshape(1, 3) direction = (R_mesh_world @ onp.array(event.ray_direction)).reshape(1, 3) intersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh) hit_pos, _, _ = intersector.intersects_location(origin, direction) if len(hit_pos) == 0: return client.scene.remove_pointer_callback() # Get the first hit position (based on distance from the ray origin). hit_pos = hit_pos[onp.argmin(onp.sum((hit_pos - origin) ** 2, axis=-1))] # Create a sphere at the hit location. hit_pos_mesh = trimesh.creation.icosphere(radius=0.1) hit_pos_mesh.vertices += R_world_mesh @ hit_pos hit_pos_mesh.visual.vertex_colors = (0.5, 0.0, 0.7, 1.0) # type: ignore hit_pos_handle = server.scene.add_mesh_trimesh( name=f"/hit_pos_{len(hit_pos_handles)}", mesh=hit_pos_mesh ) hit_pos_handles.append(hit_pos_handle) @client.scene.on_pointer_callback_removed def _(): click_button_handle.disabled = False # Tests "rect-select" scenepointerevent. paint_button_handle = client.gui.add_button("Paint mesh", icon=viser.Icon.PAINT) @paint_button_handle.on_click def _(_): paint_button_handle.disabled = True @client.scene.on_pointer_event(event_type="rect-select") def _(message: viser.ScenePointerEvent) -> None: client.scene.remove_pointer_callback() global mesh_handle camera = message.client.camera # Put the mesh in the camera frame. R_world_mesh = tf.SO3(mesh_handle.wxyz) R_mesh_world = R_world_mesh.inverse() R_camera_world = tf.SE3.from_rotation_and_translation( tf.SO3(camera.wxyz), camera.position ).inverse() vertices = cast(onp.ndarray, mesh.vertices) vertices = (R_mesh_world.as_matrix() @ vertices.T).T vertices = ( R_camera_world.as_matrix() @ onp.hstack([vertices, onp.ones((vertices.shape[0], 1))]).T ).T[:, :3] # Get the camera intrinsics, and project the vertices onto the image plane. fov, aspect = camera.fov, camera.aspect vertices_proj = vertices[:, :2] / vertices[:, 2].reshape(-1, 1) vertices_proj /= onp.tan(fov / 2) vertices_proj[:, 0] /= aspect # Move the origin to the upper-left corner, and scale to [0, 1]. # ... make sure to match the OpenCV's image coordinates! vertices_proj = (1 + vertices_proj) / 2 # Select the vertices that lie inside the 2D selected box, once projected. mask = ( (vertices_proj > onp.array(message.screen_pos[0])) & (vertices_proj < onp.array(message.screen_pos[1])) ).all(axis=1)[..., None] # Update the mesh color based on whether the vertices are inside the box mesh.visual.vertex_colors = onp.where( # type: ignore mask, (0.5, 0.0, 0.7, 1.0), (0.9, 0.9, 0.9, 1.0) ) mesh_handle = server.scene.add_mesh_trimesh( name="/mesh", mesh=mesh, position=(0.0, 0.0, 0.0), ) @client.scene.on_pointer_callback_removed def _(): paint_button_handle.disabled = False # Button to clear spheres. 
clear_button_handle = client.gui.add_button("Clear scene", icon=viser.Icon.X) @clear_button_handle.on_click def _(_): """Reset the mesh color and remove all click-generated spheres.""" global mesh_handle for handle in hit_pos_handles: handle.remove() hit_pos_handles.clear() mesh.visual.vertex_colors = (0.9, 0.9, 0.9, 1.0) # type: ignore mesh_handle = server.scene.add_mesh_trimesh( name="/mesh", mesh=mesh, position=(0.0, 0.0, 0.0), ) while True: time.sleep(10.0) ================================================ FILE: viser/examples/21_set_up_direction.py ================================================ """Set up direction `.set_up_direction()` can help us set the global up direction.""" import time import viser def main() -> None: server = viser.ViserServer() server.scene.world_axes.visible = True gui_up = server.gui.add_vector3( "Up Direction", initial_value=(0.0, 0.0, 1.0), step=0.01, ) @gui_up.on_update def _(_) -> None: server.scene.set_up_direction(gui_up.value) while True: time.sleep(1.0) if __name__ == "__main__": main() ================================================ FILE: viser/examples/22_games.py ================================================ """Games Some two-player games implemented using scene click events.""" import time from typing import Literal import numpy as onp import trimesh.creation from typing_extensions import assert_never import viser import viser.transforms as tf def main() -> None: server = viser.ViserServer() server.gui.configure_theme(dark_mode=True) play_connect_4(server) server.gui.add_button("Tic-Tac-Toe").on_click(lambda _: play_tic_tac_toe(server)) server.gui.add_button("Connect 4").on_click(lambda _: play_connect_4(server)) while True: time.sleep(10.0) def play_connect_4(server: viser.ViserServer) -> None: """Play a game of Connect 4.""" server.scene.reset() num_rows = 6 num_cols = 7 whose_turn: Literal["red", "yellow"] = "red" pieces_in_col = [0] * num_cols # Create the board frame. for col in range(num_cols): for row in range(num_rows): server.scene.add_mesh_trimesh( f"/structure/{row}_{col}", trimesh.creation.annulus(0.45, 0.55, 0.125), position=(0.0, col, row), wxyz=tf.SO3.from_y_radians(onp.pi / 2.0).wxyz, ) # Create a sphere to click on for each column. def setup_column(col: int) -> None: sphere = server.scene.add_icosphere( f"/spheres/{col}", radius=0.25, position=(0, col, num_rows - 0.25), color=(255, 255, 255), ) # Drop piece into the column. 
@sphere.on_click def _(_) -> None: nonlocal whose_turn whose_turn = "red" if whose_turn != "red" else "yellow" row = pieces_in_col[col] if row == num_rows - 1: sphere.remove() pieces_in_col[col] += 1 cylinder = trimesh.creation.cylinder(radius=0.4, height=0.125) piece = server.scene.add_mesh_simple( f"/game_pieces/{row}_{col}", cylinder.vertices, cylinder.faces, wxyz=tf.SO3.from_y_radians(onp.pi / 2.0).wxyz, color={"red": (255, 0, 0), "yellow": (255, 255, 0)}[whose_turn], ) for row_anim in onp.linspace(num_rows - 1, row, num_rows - row + 1): piece.position = ( 0, col, row_anim, ) time.sleep(1.0 / 30.0) for col in range(num_cols): setup_column(col) def play_tic_tac_toe(server: viser.ViserServer) -> None: """Play a game of tic-tac-toe.""" server.scene.reset() whose_turn: Literal["x", "o"] = "x" for i in range(4): server.scene.add_spline_catmull_rom( f"/gridlines/{i}", ((-0.5, -1.5, 0), (-0.5, 1.5, 0)), color=(127, 127, 127), position=(1, 1, 0), wxyz=tf.SO3.from_z_radians(onp.pi / 2 * i).wxyz, ) def draw_symbol(symbol: Literal["x", "o"], i: int, j: int) -> None: """Draw an X or O in the given cell.""" for scale in onp.linspace(0.01, 1.0, 5): if symbol == "x": for k in range(2): server.scene.add_box( f"/symbols/{i}_{j}/{k}", dimensions=(0.7 * scale, 0.125 * scale, 0.125), position=(i, j, 0), color=(0, 0, 255), wxyz=tf.SO3.from_z_radians( onp.pi / 2.0 * k + onp.pi / 4.0 ).wxyz, ) elif symbol == "o": mesh = trimesh.creation.annulus(0.25 * scale, 0.35 * scale, 0.125) server.scene.add_mesh_simple( f"/symbols/{i}_{j}", mesh.vertices, mesh.faces, position=(i, j, 0), color=(255, 0, 0), ) else: assert_never(symbol) server.flush() time.sleep(1.0 / 30.0) def setup_cell(i: int, j: int) -> None: """Create a clickable sphere in a given cell.""" sphere = server.scene.add_icosphere( f"/spheres/{i}_{j}", radius=0.25, position=(i, j, 0), color=(255, 255, 255), ) @sphere.on_click def _(_) -> None: nonlocal whose_turn whose_turn = "x" if whose_turn != "x" else "o" sphere.remove() draw_symbol(whose_turn, i, j) for i in range(3): for j in range(3): setup_cell(i, j) if __name__ == "__main__": main() ================================================ FILE: viser/examples/23_plotly.py ================================================ """Plotly Examples of visualizing plotly plots in Viser.""" import time import numpy as onp import plotly.express as px import plotly.graph_objects as go from PIL import Image import viser def create_sinusoidal_wave(t: float) -> go.Figure: """Create a sinusoidal wave plot, starting at time t.""" x_data = onp.linspace(t, t + 6 * onp.pi, 50) y_data = onp.sin(x_data) * 10 fig = px.line( x=list(x_data), y=list(y_data), labels={"x": "x", "y": "sin(x)"}, title="Sinusoidal Wave", ) # this sets the margins to be tight around the title. fig.layout.title.automargin = True # type: ignore fig.update_layout( margin=dict(l=20, r=20, t=20, b=20), ) # Reduce plot margins. return fig def main() -> None: server = viser.ViserServer() # Plot type 1: Line plot. line_plot_time = 0.0 line_plot = server.gui.add_plotly(figure=create_sinusoidal_wave(line_plot_time)) # Plot type 2: Image plot. fig = px.imshow(Image.open("assets/Cal_logo.png")) fig.update_layout( margin=dict(l=20, r=20, t=20, b=20), ) server.gui.add_plotly(figure=fig, aspect=1.0) # Plot type 3: 3D Scatter plot. 
fig = px.scatter_3d( px.data.iris(), x="sepal_length", y="sepal_width", z="petal_width", color="species", ) fig.update_layout(legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01)) fig.update_layout( margin=dict(l=20, r=20, t=20, b=20), ) server.gui.add_plotly(figure=fig, aspect=1.0) while True: # Update the line plot. line_plot_time += 0.1 line_plot.figure = create_sinusoidal_wave(line_plot_time) time.sleep(0.01) if __name__ == "__main__": main() ================================================ FILE: viser/examples/24_notification.py ================================================ """Notifications Examples of adding notifications per client in Viser.""" import time import viser def main() -> None: server = viser.ViserServer() persistent_notif_button = server.gui.add_button( "Show persistent notification (default)" ) timed_notif_button = server.gui.add_button("Show timed notification") controlled_notif_button = server.gui.add_button("Show controlled notification") loading_notif_button = server.gui.add_button("Show loading notification") remove_controlled_notif = server.gui.add_button("Remove controlled notification") @persistent_notif_button.on_click def _(event: viser.GuiEvent) -> None: """Show persistent notification when the button is clicked.""" client = event.client assert client is not None client.add_notification( title="Persistent notification", body="This can be closed manually and does not disappear on its own!", loading=False, with_close_button=True, auto_close=False, ) @timed_notif_button.on_click def _(event: viser.GuiEvent) -> None: """Show timed notification when the button is clicked.""" client = event.client assert client is not None client.add_notification( title="Timed notification", body="This disappears automatically after 5 seconds!", loading=False, with_close_button=True, auto_close=5000, ) @controlled_notif_button.on_click def _(event: viser.GuiEvent) -> None: """Show controlled notification when the button is clicked.""" client = event.client assert client is not None controlled_notif = client.add_notification( title="Controlled notification", body="This cannot be closed by the user and is controlled in code only!", loading=False, with_close_button=False, auto_close=False, ) @remove_controlled_notif.on_click def _(_) -> None: """Remove controlled notification.""" controlled_notif.remove() @loading_notif_button.on_click def _(event: viser.GuiEvent) -> None: """Show loading notification when the button is clicked.""" client = event.client assert client is not None loading_notif = client.add_notification( title="Loading notification", body="This indicates that some action is in progress! It will be updated in 3 seconds.", loading=True, with_close_button=False, auto_close=False, ) time.sleep(3.0) loading_notif.title = "Updated notification" loading_notif.body = "This notification has been updated!" loading_notif.loading = False loading_notif.with_close_button = True loading_notif.auto_close = 5000 loading_notif.color = "green" while True: time.sleep(1.0) if __name__ == "__main__": main() ================================================ FILE: viser/examples/25_smpl_visualizer_skinned.py ================================================ # mypy: disable-error-code="assignment" # # Asymmetric properties are supported in Pyright, but not yet in mypy. # - https://github.com/python/mypy/issues/3004 # - https://github.com/python/mypy/pull/11643 """SMPL visualizer (Skinned Mesh) Requires a .npz model file. 
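A typical invocation looks something like the following (the flag name follows from how `tyro.cli` exposes the `model_path` argument of `main` below; the path itself is a placeholder):

    python 25_smpl_visualizer_skinned.py --model-path ./SMPL_NEUTRAL.npz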
See here for download instructions: https://github.com/vchoutas/smplx?tab=readme-ov-file#downloading-the-model """ from __future__ import annotations import time from dataclasses import dataclass from pathlib import Path from typing import List, Tuple import numpy as np import numpy as onp import tyro import viser import viser.transforms as tf @dataclass(frozen=True) class SmplOutputs: vertices: np.ndarray faces: np.ndarray T_world_joint: np.ndarray # (num_joints, 4, 4) T_parent_joint: np.ndarray # (num_joints, 4, 4) class SmplHelper: """Helper for models in the SMPL family, implemented in numpy.""" def __init__(self, model_path: Path) -> None: assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" body_dict = dict(**onp.load(model_path, allow_pickle=True)) self._J_regressor = body_dict["J_regressor"] self._weights = body_dict["weights"] self._v_template = body_dict["v_template"] self._posedirs = body_dict["posedirs"] self._shapedirs = body_dict["shapedirs"] self._faces = body_dict["f"] self.num_joints: int = self._weights.shape[-1] self.num_betas: int = self._shapedirs.shape[-1] self.parent_idx: np.ndarray = body_dict["kintree_table"][0] def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs: # Get shaped vertices + joint positions, when all local poses are identity. v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas) j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose) # Local SE(3) transforms. T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4) T_parent_joint[:, :3, :3] = joint_rotmats T_parent_joint[0, :3, 3] = j_tpose[0] T_parent_joint[1:, :3, 3] = j_tpose[1:] - j_tpose[self.parent_idx[1:]] # Forward kinematics. T_world_joint = T_parent_joint.copy() for i in range(1, self.num_joints): T_world_joint[i] = T_world_joint[self.parent_idx[i]] @ T_parent_joint[i] # Linear blend skinning. pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten() v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta) v_delta = np.ones((v_blend.shape[0], self.num_joints, 4)) v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :] v_posed = np.einsum( "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta ) return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint) def main(model_path: Path) -> None: server = viser.ViserServer() server.scene.set_up_direction("+y") server.gui.configure_theme(control_layout="collapsible") # Main loop. We'll read pose/shape from the GUI elements, compute the mesh, # and then send the updated mesh in a loop. model = SmplHelper(model_path) gui_elements = make_gui_elements( server, num_betas=model.num_betas, num_joints=model.num_joints, parent_idx=model.parent_idx, ) smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), joint_rotmats=onp.zeros((model.num_joints, 3, 3)) + onp.eye(3), ) bone_wxyzs = np.array( [tf.SO3.from_matrix(R).wxyz for R in smpl_outputs.T_world_joint[:, :3, :3]] ) bone_positions = smpl_outputs.T_world_joint[:, :3, 3] skinned_handle = server.scene.add_mesh_skinned( "/human", smpl_outputs.vertices, smpl_outputs.faces, bone_wxyzs=bone_wxyzs, bone_positions=bone_positions, skin_weights=model._weights, wireframe=gui_elements.gui_wireframe.value, color=gui_elements.gui_rgb.value, ) while True: # Do nothing if no change. time.sleep(0.02) if not gui_elements.changed: continue gui_elements.changed = False # Compute SMPL outputs. 
smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), joint_rotmats=np.stack( [ tf.SO3.exp(np.array(x.value)).as_matrix() for x in gui_elements.gui_joints ], axis=0, ), ) # Match transform control gizmos to joint positions. for i, control in enumerate(gui_elements.transform_controls): control.position = smpl_outputs.T_parent_joint[i, :3, 3] skinned_handle.bones[i].wxyz = tf.SO3.from_matrix( smpl_outputs.T_world_joint[i, :3, :3] ).wxyz skinned_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3] @dataclass class GuiElements: """Structure containing handles for reading from GUI elements.""" gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]] gui_wireframe: viser.GuiInputHandle[bool] gui_betas: List[viser.GuiInputHandle[float]] gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] transform_controls: List[viser.TransformControlsHandle] changed: bool """This flag will be flipped to True whenever the mesh needs to be re-generated.""" def make_gui_elements( server: viser.ViserServer, num_betas: int, num_joints: int, parent_idx: np.ndarray, ) -> GuiElements: """Make GUI elements for interacting with the model.""" tab_group = server.gui.add_tab_group() def set_changed(_) -> None: out.changed = True # out is define later! # GUI elements: mesh settings + visibility. with tab_group.add_tab("View", viser.Icon.VIEWFINDER): gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255)) gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False) gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True) gui_rgb.on_update(set_changed) gui_wireframe.on_update(set_changed) @gui_show_controls.on_update def _(_): for control in transform_controls: control.visible = gui_show_controls.value # GUI elements: shape parameters. with tab_group.add_tab("Shape", viser.Icon.BOX): gui_reset_shape = server.gui.add_button("Reset Shape") gui_random_shape = server.gui.add_button("Random Shape") @gui_reset_shape.on_click def _(_): for beta in gui_betas: beta.value = 0.0 @gui_random_shape.on_click def _(_): for beta in gui_betas: beta.value = onp.random.normal(loc=0.0, scale=1.0) gui_betas = [] for i in range(num_betas): beta = server.gui.add_slider( f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0 ) gui_betas.append(beta) beta.on_update(set_changed) # GUI elements: joint angles. with tab_group.add_tab("Joints", viser.Icon.ANGLE): gui_reset_joints = server.gui.add_button("Reset Joints") gui_random_joints = server.gui.add_button("Random Joints") @gui_reset_joints.on_click def _(_): for joint in gui_joints: joint.value = (0.0, 0.0, 0.0) @gui_random_joints.on_click def _(_): for joint in gui_joints: # It's hard to uniformly sample orientations directly in so(3), so we # first sample on S^3 and then convert. quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) quat /= onp.linalg.norm(quat) joint.value = tf.SO3(wxyz=quat).log() gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [] for i in range(num_joints): gui_joint = server.gui.add_vector3( label=f"Joint {i}", initial_value=(0.0, 0.0, 0.0), step=0.05, ) gui_joints.append(gui_joint) def set_callback_in_closure(i: int) -> None: @gui_joint.on_update def _(_): transform_controls[i].wxyz = tf.SO3.exp( np.array(gui_joints[i].value) ).wxyz out.changed = True set_callback_in_closure(i) # Transform control gizmos on joints. transform_controls: List[viser.TransformControlsHandle] = [] prefixed_joint_names = [] # Joint names, but prefixed with parents. 
for i in range(num_joints): prefixed_joint_name = f"joint_{i}" if i > 0: prefixed_joint_name = ( prefixed_joint_names[parent_idx[i]] + "/" + prefixed_joint_name ) prefixed_joint_names.append(prefixed_joint_name) controls = server.scene.add_transform_controls( f"/smpl/{prefixed_joint_name}", depth_test=False, scale=0.2 * (0.75 ** prefixed_joint_name.count("/")), disable_axes=True, disable_sliders=True, visible=gui_show_controls.value, ) transform_controls.append(controls) def set_callback_in_closure(i: int) -> None: @controls.on_update def _(_) -> None: axisangle = tf.SO3(transform_controls[i].wxyz).log() gui_joints[i].value = (axisangle[0], axisangle[1], axisangle[2]) set_callback_in_closure(i) out = GuiElements( gui_rgb, gui_wireframe, gui_betas, gui_joints, transform_controls=transform_controls, changed=True, ) return out if __name__ == "__main__": tyro.cli(main, description=__doc__) ================================================ FILE: viser/examples/assets/.gitignore ================================================ dragon.obj /record3d_dance/ /colmap_garden/ ================================================ FILE: viser/examples/assets/download_colmap_garden.sh ================================================ # This downloads the COLMAP model for the MIP-NeRF garden dataset # with the images that are downscaled by a factor of 8. # The full dataset is available at https://jonbarron.info/mipnerf360/. set -e -x gdown "https://drive.google.com/uc?id=1wYHdrgwXPHtREdCjItvt4gqRQGISMade" mkdir -p colmap_garden # shellcheck disable=SC2035 unzip *.zip && rm *.zip ================================================ FILE: viser/examples/assets/download_dragon_mesh.sh ================================================ set -e -x gdown "https://drive.google.com/uc?id=1uRDvoS_l2Or8g8YDDPYV79K6_RfFYBeF" ================================================ FILE: viser/examples/assets/download_record3d_dance.sh ================================================ set -e -x gdown "https://drive.google.com/uc?id=1_vd5bK_MhtlfisA6BkK1IgiJNfDbIntq" mkdir -p record3d_dance # shellcheck disable=SC2035 unzip *.r3d -d record3d_dance && rm *.r3d ================================================ FILE: viser/examples/assets/mdx_example.mdx ================================================ ## Markdown in Viser --- Viser has full support for the GFM markdown spec, including **bold**, _italics_, ~~strikethrough~~, and many other features. Here's a [masked link](https://github.com/nerfstudio-project/viser). Not a fan? Here's a normal one: https://pypi.org/project/viser/ Anywhere where you can insert GUI elements, you can also insert `images`, `blockquotes`, `lists`, `tables`, `task lists`, and `(unstyled) code blocks`. In inline code blocks, you can show off colors with color chips: `#FED363` `hsl(0, 0%, 82%)` `rgb(255, 255, 255)` Adding images from a remote origin is simple. ![Viser Logo](https://viser.studio/latest/_static/logo.svg) For local images with relative paths, you can either directly use a [data URL](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs) or set the `image_root` argument to the `Path` object that you'd like your paths relative to. If no such `image_root` is provided, the file system will be scoped to the directory that Viser is installed in. 
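As a rough sketch of that pattern (assuming the markdown file and image live in an `assets/` directory next to the script that starts Viser):

```python
from pathlib import Path

import viser

server = viser.ViserServer()
here = Path(__file__).absolute().parent

# Relative image paths inside the markdown are resolved against `image_root`.
server.gui.add_markdown(
    content=(here / "assets/mdx_example.mdx").read_text(),
    image_root=here,
)
```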
![Cal Logo](../examples/assets/Cal_logo.png) Tables follow the standard markdown spec: | Application | Description | | ------------------------------------------------------ | ---------------------------------------------------- | | [NS](https://nerf.studio) | A collaboration friendly studio for NeRFs | | [Viser](https://nerfstudio-project.github.io/viser/) | An interactive 3D visualization toolbox for Python | Code blocks, while being not nearly as exciting as some of the things presented, work as expected. Currently, while you can specify a language and metadata in your blocks, they will remain unused by the Markdown renderer. ```python """Markdown Demonstration Viser GUI has MDX 2 support. """ import time from pathlib import Path import viser server = viser.ViserServer() server.world_axes.visible = True @server.on_client_connect def _(client: viser.ClientHandle) -> None: with open("./assets/mdx_example.mdx", "r") as mkdn: markdown = client.gui.add_markdown( markdown=mkdn.read(), image_root=Path(__file__).parent ) button = client.gui.add_button("Remove Markdown") @button.on_click def _(_): markdown.remove() while True: time.sleep(10.0) ``` As a bonus, MDX is extensible and JS capable. This means that you have the freedom to do things like: This page loaded on {(new Date()).toString()} Or: > Oh yes, mdx PR would be exciting > > — Brent Yi **Note**: Be careful when playing with JSX, it's very easy to break markdown. So that's MDX in Viser. It has support for: - [x] CommonMark and GFM standards - bold, italics, strikethrough, images, blockquotes, tables, task lists, code blocks, inline code - [x] Color chips - [x] JSX enhanced components ================================================ FILE: viser/examples/experimental/gaussian_splats.py ================================================ """WebGL-based Gaussian splat rendering. This is still under development.""" from __future__ import annotations import time from pathlib import Path from typing import TypedDict import numpy as onp import numpy.typing as onpt import tyro from plyfile import PlyData import viser from viser import transforms as tf class SplatFile(TypedDict): """Data loaded from an antimatter15-style splat file.""" centers: onpt.NDArray[onp.floating] """(N, 3).""" rgbs: onpt.NDArray[onp.floating] """(N, 3). Range [0, 1].""" opacities: onpt.NDArray[onp.floating] """(N, 1). Range [0, 1].""" covariances: onpt.NDArray[onp.floating] """(N, 3, 3).""" def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: """Load an antimatter15-style splat file.""" start_time = time.time() splat_buffer = splat_path.read_bytes() bytes_per_gaussian = ( # Each Gaussian is serialized as: # - position (vec3, float32) 3 * 4 # - xyz (vec3, float32) + 3 * 4 # - rgba (vec4, uint8) + 4 # - ijkl (vec4, uint8), where 0 => -1, 255 => 1. + 4 ) assert len(splat_buffer) % bytes_per_gaussian == 0 num_gaussians = len(splat_buffer) // bytes_per_gaussian # Reinterpret cast to dtypes that we want to extract. 
splat_uint8 = onp.frombuffer(splat_buffer, dtype=onp.uint8).reshape( (num_gaussians, bytes_per_gaussian) ) scales = splat_uint8[:, 12:24].copy().view(onp.float32) wxyzs = splat_uint8[:, 28:32] / 255.0 * 2.0 - 1.0 Rs = tf.SO3(wxyzs).as_matrix() covariances = onp.einsum( "nij,njk,nlk->nil", Rs, onp.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs ) centers = splat_uint8[:, 0:12].copy().view(onp.float32) if center: centers -= onp.mean(centers, axis=0, keepdims=True) print( f"Splat file with {num_gaussians=} loaded in {time.time() - start_time} seconds" ) return { "centers": centers, # Colors should have shape (N, 3). "rgbs": splat_uint8[:, 24:27] / 255.0, "opacities": splat_uint8[:, 27:28] / 255.0, # Covariances should have shape (N, 3, 3). "covariances": covariances, } def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: """Load Gaussians stored in a PLY file.""" start_time = time.time() SH_C0 = 0.28209479177387814 plydata = PlyData.read(ply_file_path) v = plydata["vertex"] positions = onp.stack([v["x"], v["y"], v["z"]], axis=-1) scales = onp.exp(onp.stack([v["scale_0"], v["scale_1"], v["scale_2"]], axis=-1)) wxyzs = onp.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1) colors = 0.5 + SH_C0 * onp.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) opacities = 1.0 / (1.0 + onp.exp(-v["opacity"][:, None])) Rs = tf.SO3(wxyzs).as_matrix() covariances = onp.einsum( "nij,njk,nlk->nil", Rs, onp.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs ) if center: positions -= onp.mean(positions, axis=0, keepdims=True) num_gaussians = len(v) print( f"PLY file with {num_gaussians=} loaded in {time.time() - start_time} seconds" ) return { "centers": positions, "rgbs": colors, "opacities": opacities, "covariances": covariances, } def main(splat_paths: tuple[Path, ...]) -> None: server = viser.ViserServer() server.gui.configure_theme(dark_mode=True) gui_reset_up = server.gui.add_button( "Reset up direction", hint="Set the camera control 'up' direction to the current camera's 'up'.", ) @gui_reset_up.on_click def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None client.camera.up_direction = tf.SO3(client.camera.wxyz) @ onp.array( [0.0, -1.0, 0.0] ) for i, splat_path in enumerate(splat_paths): if splat_path.suffix == ".splat": splat_data = load_splat_file(splat_path, center=True) elif splat_path.suffix == ".ply": splat_data = load_ply_file(splat_path, center=True) else: raise SystemExit("Please provide a filepath to a .splat or .ply file.") server.scene.add_transform_controls(f"/{i}") gs_handle = server.scene._add_gaussian_splats( f"/{i}/gaussian_splats", centers=splat_data["centers"], rgbs=splat_data["rgbs"], opacities=splat_data["opacities"], covariances=splat_data["covariances"], ) remove_button = server.gui.add_button(f"Remove splat object {i}") @remove_button.on_click def _(_, gs_handle=gs_handle, remove_button=remove_button) -> None: gs_handle.remove() remove_button.remove() while True: time.sleep(10.0) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: viser/examples/quick_save.py ================================================ """Record3D visualizer Batch process Record3D captures to generate recordings for multiple data folders. 
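Example invocation (paths are placeholders; the full set of flags is defined by the argparse parser at the bottom of this file):

    python quick_save.py --data ./record3d_captures --output_dir ./viser_result --downsample 2 --max_frames 100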
""" import time import sys import os import argparse from pathlib import Path import numpy as onp from tqdm.auto import tqdm import viser import viser.extras import viser.transforms as tf import matplotlib.cm as cm # For colormap def process_folder( data_path: Path, downsample_factor: int, max_frames: int, conf_threshold: float, foreground_conf_threshold: float, point_size: float, camera_frustum_scale: float, no_mask: bool, xyzw: bool, axes_scale: float, bg_downsample_factor: int, output_dir: Path, ) -> None: print(f"Processing folder: {data_path}") server = viser.ViserServer() server.scene.set_up_direction('-z') loader = viser.extras.Record3dLoader_Customized( data_path, conf_threshold=conf_threshold, foreground_conf_threshold=foreground_conf_threshold, no_mask=no_mask, xyzw=xyzw, init_conf=True, ) num_frames = min(max_frames, loader.num_frames()) # Load frames server.scene.add_frame( "/frames", wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz, position=(0, 0, 0), show_axes=False, ) frame_nodes: list[viser.FrameHandle] = [] bg_positions = [] bg_colors = [] for i in tqdm(range(num_frames)): frame = loader.get_frame(i) position, color, bg_position, bg_color = frame.get_point_cloud(downsample_factor, bg_downsample_factor) bg_positions.append(bg_position) bg_colors.append(bg_color) # Add base frame frame_nodes.append(server.scene.add_frame(f"/frames/t{i}", show_axes=False)) # Place the point cloud in the frame server.scene.add_point_cloud( name=f"/frames/t{i}/point_cloud", points=position, colors=color, point_size=point_size, point_shape="rounded", ) # Compute color for frustum based on frame index norm_i = i / (num_frames - 1) if num_frames > 1 else 0 # Normalize index to [0, 1] color_rgba = cm.viridis(norm_i) # Get RGBA color from colormap color_rgb = color_rgba[:3] # Use RGB components # Place the frustum with the computed color fov = 2 * onp.arctan2(frame.rgb.shape[0] / 2, frame.K[0, 0]) aspect = frame.rgb.shape[1] / frame.rgb.shape[0] server.scene.add_camera_frustum( f"/frames/t{i}/frustum", fov=fov, aspect=aspect, scale=camera_frustum_scale, image=frame.rgb[::downsample_factor, ::downsample_factor], wxyz=tf.SO3.from_matrix(frame.T_world_camera[:3, :3]).wxyz, position=frame.T_world_camera[:3, 3], color=color_rgb, # Set the color for the frustum ) # Add axes server.scene.add_frame( f"/frames/t{i}/frustum/axes", axes_length=camera_frustum_scale * axes_scale * 10, axes_radius=camera_frustum_scale * axes_scale, ) # Add background frame bg_positions = onp.concatenate(bg_positions, axis=0) bg_colors = onp.concatenate(bg_colors, axis=0) server.scene.add_point_cloud( name=f"/frames/background", points=bg_positions, colors=bg_colors, point_size=point_size, point_shape="rounded", ) # Automatically play through frames and record the scene rec = server._start_scene_recording() rec.set_loop_start() sleep_duration = 1.0 / loader.fps if loader.fps > 0 else 0.033 # Default to ~30 FPS for t in range(num_frames): # Update the scene to show frame t with server.atomic(): for i, frame_node in enumerate(frame_nodes): frame_node.visible = (i == t) server.flush() rec.insert_sleep(sleep_duration) # Set all frames invisible with server.atomic(): for frame_node in frame_nodes: frame_node.visible = False server.flush() # Finish recording bs = rec.end_and_serialize() # Save the recording to a file output_path = output_dir / f"recording_{data_path.name}.viser" # Ensure the output directory exists output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_bytes(bs) print(f"Recording saved to 
{output_path.resolve()}") def main( data_paths: list[Path], output_dir: Path = Path("./viser_result"), downsample_factor: int = 1, max_frames: int = 100, conf_threshold: float = 1.0, foreground_conf_threshold: float = 0.1, point_size: float = 0.001, camera_frustum_scale: float = 0.02, no_mask: bool = False, xyzw: bool = True, axes_scale: float = 0.25, bg_downsample_factor: int = 1, ) -> None: # if data_path[0] has subfolders, process each subfolder if data_paths[0].is_dir(): new_data_paths = sorted([subfolder for subfolder in data_paths[0].iterdir() if subfolder.is_dir()]) if len(new_data_paths) > 0: data_paths = new_data_paths for data_path in data_paths: process_folder( data_path, downsample_factor, max_frames, conf_threshold, foreground_conf_threshold, point_size, camera_frustum_scale, no_mask, xyzw, axes_scale, bg_downsample_factor, output_dir, ) if __name__ == "__main__": # Initialize parser parser = argparse.ArgumentParser(description="Process input arguments.") # Define arguments parser.add_argument( "--data", type=Path, nargs="+", required=True, help="Paths to the data folders (can specify multiple paths)", ) parser.add_argument( "--output_dir", type=Path, default=Path("./viser_result"), help="Output directory for recordings", ) parser.add_argument( "--conf_thre", type=float, default=0.1, help="Confidence threshold, default is 0.1", ) parser.add_argument( "--fg_conf_thre", type=float, default=0.5, help="Foreground confidence threshold, default is 0.0", ) parser.add_argument( "--point_size", type=float, default=0.001, help="Point size, default is 0.001", ) parser.add_argument( "--camera_size", type=float, default=0.015, help="Camera frustum scale, default is 0.015", ) parser.add_argument( "--no_mask", action="store_true", help="Don't use mask to filter out points", ) parser.add_argument( "--wxyz", action="store_true", help="Use wxyz for SO3 representation", ) parser.add_argument( "--axes_scale", type=float, default=0.1, help="Scale of axes", ) parser.add_argument( "--bg_downsample", type=int, default=1, help="Background downsample factor", ) parser.add_argument( "--downsample", type=int, default=2, help="Downsample factor", ) parser.add_argument( "--max_frames", type=int, default=100, help="Maximum number of frames to process", ) # Parse arguments args = parser.parse_args() # Call the main function with the parsed arguments main( data_paths=args.data, output_dir=args.output_dir, conf_threshold=args.conf_thre, foreground_conf_threshold=args.fg_conf_thre, point_size=args.point_size, camera_frustum_scale=args.camera_size, no_mask=args.no_mask, xyzw=not args.wxyz, axes_scale=args.axes_scale, bg_downsample_factor=args.bg_downsample, downsample_factor=args.downsample, max_frames=args.max_frames, ) ================================================ FILE: viser/pyproject.toml ================================================ [build-system] requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "viser" version = "0.2.7" description = "3D visualization + Python" readme = "README.md" license = { text="MIT" } requires-python = ">=3.8" classifiers = [ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent" ] dependencies = [ "websockets>=10.4", "numpy>=1.0.0", "msgspec>=0.18.6", "imageio>=2.0.0", 
"pyliblzfse>=0.4.1; platform_system!='Windows'", "scikit-image>=0.18.0", "scipy>=1.7.3", "tqdm>=4.0.0", "tyro>=0.2.0", "rich>=13.3.3", "trimesh>=3.21.7", "nodeenv>=1.8.0", "psutil>=5.9.5", "yourdfpy>=0.0.53", "plyfile>=1.0.2" ] [project.optional-dependencies] dev = [ "pyright>=1.1.308", "ruff==0.6.2", "pre-commit==3.3.2", ] examples = [ "torch>=1.13.1", "matplotlib>=3.7.1", "plotly>=5.21.0", "robot_descriptions>=1.10.0", "gdown>=4.6.6", "plyfile", ] [project.urls] "GitHub" = "https://github.com/nerfstudio-project/viser" # <> # Important: in the ./.github/workflows/publish.yml action, we have sed # commands that assume the `viser = ...` line below directly follows # `[tool.setuptools.package-data]`. We use this to remove the client source # from PyPI builds. # # We should make sure that any modifications to the package-data list remain # compatible with the sed commands! # # We keep the client source in by default to support things like pip # installation via the Git URL, because build artifacts aren't # version-controlled. [tool.setuptools.package-data] viser = ["py.typed", "*.pyi", "_icons/tabler-icons.tar", "client/**/*", "client/**/.*"] # [tool.setuptools.exclude-package-data] # We exclude node_modules to prevent long build times for wheels when # installing from source, eg via `pip install .`. # # https://github.com/nerfstudio-project/viser/issues/271 viser = ["**/node_modules/**"] [project.scripts] viser-dev-checks = "viser.scripts.dev_checks:entrypoint" [tool.pyright] exclude = ["./docs/**/*", "./examples/assets/**/*", "./src/viser/client/.nodeenv", "./build"] [tool.ruff] lint.select = [ "E", # pycodestyle errors. "F", # Pyflakes rules. "PLC", # Pylint convention warnings. "PLE", # Pylint errors. "PLR", # Pylint refactor recommendations. "PLW", # Pylint warnings. "I", # Import sorting. ] lint.ignore = [ "E741", # Ambiguous variable name. (l, O, or I) "E501", # Line too long. "E721", # Do not compare types, use `isinstance()`. "F722", # Forward annotation false positive from jaxtyping. Should be caught by pyright. "F821", # Forward annotation false positive from jaxtyping. Should be caught by pyright. "PLR2004", # Magic value used in comparison. "PLR0915", # Too many statements. "PLR0913", # Too many arguments. "PLC0414", # Import alias does not rename variable. (this is used for exporting names) "PLC1901", # Use falsey strings. "PLR5501", # Use `elif` instead of `else if`. "PLR0911", # Too many return statements. "PLR0912", # Too many branches. "PLW0603", # Globa statement updates are discouraged. "PLW2901", # For loop variable overwritten. "PLW0642", # Reassigned self in instance method. 
] exclude = [ ".nodeenv" ] ================================================ FILE: viser/src/viser/__init__.py ================================================ from ._gui_api import GuiApi as GuiApi from ._gui_handles import GuiButtonGroupHandle as GuiButtonGroupHandle from ._gui_handles import GuiButtonHandle as GuiButtonHandle from ._gui_handles import GuiDropdownHandle as GuiDropdownHandle from ._gui_handles import GuiEvent as GuiEvent from ._gui_handles import GuiFolderHandle as GuiFolderHandle from ._gui_handles import GuiInputHandle as GuiInputHandle from ._gui_handles import GuiMarkdownHandle as GuiMarkdownHandle from ._gui_handles import GuiPlotlyHandle as GuiPlotlyHandle from ._gui_handles import GuiTabGroupHandle as GuiTabGroupHandle from ._gui_handles import GuiTabHandle as GuiTabHandle from ._icons_enum import Icon as Icon from ._icons_enum import IconName as IconName from ._notification_handle import NotificationHandle as NotificationHandle from ._scene_api import SceneApi as SceneApi from ._scene_handles import BatchedAxesHandle as BatchedAxesHandle from ._scene_handles import CameraFrustumHandle as CameraFrustumHandle from ._scene_handles import FrameHandle as FrameHandle from ._scene_handles import GaussianSplatHandle as GaussianSplatHandle from ._scene_handles import GlbHandle as GlbHandle from ._scene_handles import Gui3dContainerHandle as Gui3dContainerHandle from ._scene_handles import ImageHandle as ImageHandle from ._scene_handles import LabelHandle as LabelHandle from ._scene_handles import MeshHandle as MeshHandle from ._scene_handles import MeshSkinnedBoneHandle as MeshSkinnedBoneHandle from ._scene_handles import MeshSkinnedHandle as MeshSkinnedHandle from ._scene_handles import PointCloudHandle as PointCloudHandle from ._scene_handles import SceneNodeHandle as SceneNodeHandle from ._scene_handles import SceneNodePointerEvent as SceneNodePointerEvent from ._scene_handles import ScenePointerEvent as ScenePointerEvent from ._scene_handles import TransformControlsHandle as TransformControlsHandle from ._viser import CameraHandle as CameraHandle from ._viser import ClientHandle as ClientHandle from ._viser import ViserServer as ViserServer ================================================ FILE: viser/src/viser/_client_autobuild.py ================================================ import os import subprocess import sys from pathlib import Path import rich client_dir = Path(__file__).absolute().parent / "client" build_dir = client_dir / "build" def _check_viser_yarn_running() -> bool: """Returns True if the viewer client has been launched via `yarn start`.""" import psutil for process in psutil.process_iter(): try: if Path(process.cwd()).as_posix().endswith("viser/client") and any( [part.endswith("yarn") for part in process.cmdline()] + [part.endswith("yarn.js") for part in process.cmdline()] ): return True except (psutil.AccessDenied, psutil.ZombieProcess): pass return False def ensure_client_is_built() -> None: """Ensure that the client is built or already running.""" if not (client_dir / "src").exists(): # Can't build client. assert (build_dir / "index.html").exists(), ( "Something went wrong! At least one of the client source or build" " directories should be present." ) return # Do we need to re-trigger a build? build = False if _check_viser_yarn_running(): # Don't run `yarn build` if `yarn start` is already running. rich.print( "[bold](viser)[/bold] The Viser viewer looks like it has been launched via" " `yarn start`. Skipping build check..." 
) build = False elif not (build_dir / "index.html").exists(): rich.print("[bold](viser)[/bold] No client build found. Building now...") build = True elif ( # We should be at least 10 seconds newer than the last build. # This buffer is important when we install from pip, and the src/ + # build/ directories have very similar timestamps. _modified_time_recursive(client_dir / "src") > _modified_time_recursive(build_dir) + 10.0 ): rich.print( "[bold](viser)[/bold] Client build looks out of date. Building now..." ) build = True # Install nodejs and build if necessary. We assume bash is installed. if build: node_bin_dir = _install_sandboxed_node() npx_path = node_bin_dir / "npx" subprocess_env = os.environ.copy() subprocess_env["NODE_VIRTUAL_ENV"] = str(node_bin_dir.parent) subprocess_env["PATH"] = ( str(node_bin_dir) + (";" if sys.platform == "win32" else ":") + subprocess_env["PATH"] ) subprocess.run( args=f"{npx_path} --yes yarn install", env=subprocess_env, cwd=client_dir, shell=True, check=False, ) subprocess.run( args=f"{npx_path} --yes yarn run build", env=subprocess_env, cwd=client_dir, shell=True, check=False, ) def _install_sandboxed_node() -> Path: """Install a sandboxed copy of nodejs using nodeenv, and return a path to the environment's bin directory (`.nodeenv/bin` or `.nodeenv/Scripts`). On Windows, the `.nodeenv/bin` does not exist. Instead, executables are installed to `.nodeenv/Scripts`.""" def get_node_bin_dir() -> Path: env_dir = client_dir / ".nodeenv" node_bin_dir = env_dir / "bin" if not node_bin_dir.exists(): node_bin_dir = env_dir / "Scripts" return node_bin_dir node_bin_dir = get_node_bin_dir() if (node_bin_dir / "npx").exists(): rich.print("[bold](viser)[/bold] nodejs is set up!") return node_bin_dir env_dir = client_dir / ".nodeenv" subprocess.run( [sys.executable, "-m", "nodeenv", "--node=20.4.0", env_dir], check=False ) node_bin_dir = get_node_bin_dir() assert (node_bin_dir / "npx").exists() return node_bin_dir def _modified_time_recursive(dir: Path) -> float: """Recursively get the last time a file was modified in a directory.""" return max([f.stat().st_mtime for f in dir.glob("**/*")]) ================================================ FILE: viser/src/viser/_gui_api.py ================================================ from __future__ import annotations import builtins import colorsys import dataclasses import functools import threading import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import TYPE_CHECKING, Any, Sequence, Tuple, TypeVar, cast, overload import numpy as onp from typing_extensions import ( Literal, LiteralString, TypeAlias, TypedDict, get_type_hints, ) from viser import theme from . 
import _messages from ._gui_handles import ( GuiButtonGroupHandle, GuiButtonHandle, GuiContainerProtocol, GuiDropdownHandle, GuiEvent, GuiFolderHandle, GuiInputHandle, GuiMarkdownHandle, GuiModalHandle, GuiPlotlyHandle, GuiProgressBarHandle, GuiTabGroupHandle, GuiUploadButtonHandle, SupportsRemoveProtocol, UploadedFile, _GuiHandleState, _GuiInputHandle, _make_unique_id, ) from ._icons import svg_from_icon from ._icons_enum import IconName from ._messages import FileTransferPartAck from ._scene_api import cast_vector if TYPE_CHECKING: import plotly.graph_objects as go from ._viser import ClientHandle, ViserServer from .infra import ClientId IntOrFloat = TypeVar("IntOrFloat", int, float) TString = TypeVar("TString", bound=str) TLiteralString = TypeVar("TLiteralString", bound=LiteralString) T = TypeVar("T") LengthTenStrTuple: TypeAlias = Tuple[str, str, str, str, str, str, str, str, str, str] Color: TypeAlias = Literal[ "dark", "gray", "red", "pink", "grape", "violet", "indigo", "blue", "cyan", "green", "lime", "yellow", "orange", "teal", ] def _hex_from_hls(h: float, l: float, s: float) -> str: """Converts HLS values in [0.0, 1.0] to a hex-formatted string, eg 0xffffff.""" return "#" + "".join( [ int(min(255, max(0, channel * 255.0)) + 0.5).to_bytes(1, "little").hex() for channel in colorsys.hls_to_rgb(h, l, s) ] ) def _compute_step(x: float | None) -> float: # type: ignore """For number inputs: compute an increment size from some number. Example inputs/outputs: 100 => 1 12 => 1 12.1 => 0.1 12.02 => 0.01 0.004 => 0.001 """ return 1 if x is None else 10 ** (-_compute_precision_digits(x)) def _compute_precision_digits(x: float) -> int: """For number inputs: compute digits of precision from some number. Example inputs/outputs: 100 => 0 12 => 0 12.1 => 1 10.2 => 1 0.007 => 3 """ digits = 0 while x != round(x, ndigits=digits) and digits < 7: digits += 1 return digits @dataclasses.dataclass class _RootGuiContainer: _children: dict[str, SupportsRemoveProtocol] _global_order_counter = 0 def _apply_default_order(order: float | None) -> float: """Apply default ordering logic for GUI elements. If `order` is set to a float, this function is a no-op and returns it back. Otherwise, we increment and return the value of a global counter. """ if order is not None: return order global _global_order_counter _global_order_counter += 1 return _global_order_counter @functools.lru_cache(maxsize=None) def get_type_hints_cached(cls: type[Any]) -> dict[str, Any]: return get_type_hints(cls) # type: ignore class _FileUploadState(TypedDict): filename: str mime_type: str part_count: int parts: dict[int, bytes] total_bytes: int transferred_bytes: int lock: threading.Lock class GuiApi: """Interface for working with the 2D GUI in viser. Used by both our global server object, for sharing the same GUI elements with all clients, and by individual client handles.""" _target_container_from_thread_id: dict[int, str] = {} """ID of container to put GUI elements into.""" def __init__( self, owner: ViserServer | ClientHandle, # Who do I belong to? 
thread_executor: ThreadPoolExecutor, ) -> None: from ._viser import ViserServer self._owner = owner """Entity that owns this API.""" self._thread_executor = thread_executor self._websock_interface = ( owner._websock_server if isinstance(owner, ViserServer) else owner._websock_connection ) """Interface for sending and listening to messages.""" self._gui_input_handle_from_id: dict[str, _GuiInputHandle[Any]] = {} self._container_handle_from_id: dict[str, GuiContainerProtocol] = { "root": _RootGuiContainer({}) } self._current_file_upload_states: dict[str, _FileUploadState] = {} # Set to True when plotly.min.js has been sent to client. self._setup_plotly_js: bool = False self._websock_interface.register_handler( _messages.GuiUpdateMessage, self._handle_gui_updates ) self._websock_interface.register_handler( _messages.FileTransferStart, self._handle_file_transfer_start ) self._websock_interface.register_handler( _messages.FileTransferPart, self._handle_file_transfer_part, ) def _handle_gui_updates( self, client_id: ClientId, message: _messages.GuiUpdateMessage ) -> None: """Callback for handling GUI messages.""" handle = self._gui_input_handle_from_id.get(message.id, None) if handle is None: return handle_state = handle._impl has_changed = False updates_cast = {} for prop_name, prop_value in message.updates.items(): assert hasattr(handle_state, prop_name) current_value = getattr(handle_state, prop_name) # Do some type casting. This is brittle, but necessary (1) when we # expect floats but the Javascript side gives us integers or (2) # when we expect tuples but the Javascript side gives us lists. if prop_name == "value": if isinstance(handle_state.value, tuple): # We currently assume all tuple types have length >0, and # contents are all the same type. assert len(handle_state.value) > 0 typ = type(handle_state.value[0]) assert all([type(x) == typ for x in handle_state.value]) prop_value = tuple([typ(new) for new in prop_value]) else: prop_value = type(handle_state.value)(prop_value) # Update handle property. if current_value != prop_value: has_changed = True setattr(handle_state, prop_name, prop_value) # Save value, which might have been cast. updates_cast[prop_name] = prop_value # Only call update when value has actually changed. if not handle_state.is_button and not has_changed: return # GUI element has been updated! handle_state.update_timestamp = time.time() for cb in handle_state.update_cb: from ._viser import ClientHandle, ViserServer # Get the handle of the client that triggered this event. 
if isinstance(self._owner, ClientHandle): client = self._owner elif isinstance(self._owner, ViserServer): client = self._owner.get_clients()[client_id] else: assert False cb(GuiEvent(client, client_id, handle)) if handle_state.sync_cb is not None: handle_state.sync_cb(client_id, updates_cast) def _handle_file_transfer_start( self, client_id: ClientId, message: _messages.FileTransferStart ) -> None: if message.source_component_id not in self._gui_input_handle_from_id: return self._current_file_upload_states[message.transfer_uuid] = { "filename": message.filename, "mime_type": message.mime_type, "part_count": message.part_count, "parts": {}, "total_bytes": message.size_bytes, "transferred_bytes": 0, "lock": threading.Lock(), } def _handle_file_transfer_part( self, client_id: ClientId, message: _messages.FileTransferPart ) -> None: if message.transfer_uuid not in self._current_file_upload_states: return assert message.source_component_id in self._gui_input_handle_from_id state = self._current_file_upload_states[message.transfer_uuid] state["parts"][message.part] = message.content total_bytes = state["total_bytes"] with state["lock"]: state["transferred_bytes"] += len(message.content) # Send ack to the server. self._websock_interface.queue_message( FileTransferPartAck( source_component_id=message.source_component_id, transfer_uuid=message.transfer_uuid, transferred_bytes=state["transferred_bytes"], total_bytes=total_bytes, ) ) if state["transferred_bytes"] < total_bytes: return # Finish the upload. assert state["transferred_bytes"] == total_bytes state = self._current_file_upload_states.pop(message.transfer_uuid) handle = self._gui_input_handle_from_id.get(message.source_component_id, None) if handle is None: return handle_state = handle._impl value = UploadedFile( name=state["filename"], content=b"".join(state["parts"][i] for i in range(state["part_count"])), ) # Update state. with self._owner.atomic(): handle_state.value = value handle_state.update_timestamp = time.time() # Trigger callbacks. for cb in handle_state.update_cb: from ._viser import ClientHandle, ViserServer # Get the handle of the client that triggered this event. if isinstance(self._owner, ClientHandle): client = self._owner elif isinstance(self._owner, ViserServer): client = self._owner.get_clients()[client_id] else: assert False cb(GuiEvent(client, client_id, handle)) def _get_container_id(self) -> str: """Get container ID associated with the current thread.""" return self._target_container_from_thread_id.get(threading.get_ident(), "root") def _set_container_id(self, container_id: str) -> None: """Set container ID associated with the current thread.""" self._target_container_from_thread_id[threading.get_ident()] = container_id def reset(self) -> None: """Reset the GUI.""" self._websock_interface.queue_message(_messages.ResetGuiMessage()) def set_panel_label(self, label: str | None) -> None: """Set the main label that appears in the GUI panel. Args: label: The new label. """ self._websock_interface.queue_message(_messages.SetGuiPanelLabelMessage(label)) def configure_theme( self, *, titlebar_content: theme.TitlebarConfig | None = None, control_layout: Literal["floating", "collapsible", "fixed"] = "floating", control_width: Literal["small", "medium", "large"] = "medium", dark_mode: bool = False, show_logo: bool = True, show_share_button: bool = True, brand_color: tuple[int, int, int] | None = None, ) -> None: """Configures the visual appearance of the viser front-end. 
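A minimal usage sketch, with illustrative values borrowed from the bundled examples:

    server.gui.configure_theme(
        control_layout="collapsible",
        dark_mode=True,
        brand_color=(230, 180, 30),
    )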
Args: titlebar_content: Optional configuration for the title bar. control_layout: The layout of control elements, options are "floating", "collapsible", or "fixed". control_width: The width of control elements, options are "small", "medium", or "large". dark_mode: A boolean indicating if dark mode should be enabled. show_logo: A boolean indicating if the logo should be displayed. show_share_button: A boolean indicating if the share button should be displayed. brand_color: An optional tuple of integers (RGB) representing the brand color. """ colors_cast: LengthTenStrTuple | None = None if brand_color is not None: assert len(brand_color) in (3, 10) if len(brand_color) == 3: assert all( map(lambda val: isinstance(val, int), brand_color) ), "All channels should be integers." # RGB => HLS. h, l, s = colorsys.rgb_to_hls( brand_color[0] / 255.0, brand_color[1] / 255.0, brand_color[2] / 255.0, ) # Automatically generate a 10-color palette. min_l = max(l - 0.08, 0.0) max_l = min(0.8 + 0.5, 0.9) l = max(min_l, min(max_l, l)) primary_index = 8 ls = tuple( onp.interp( x=onp.arange(10), xp=onp.array([0, primary_index, 9]), fp=onp.array([max_l, l, min_l]), ) ) colors_cast = cast( LengthTenStrTuple, tuple(_hex_from_hls(h, ls[i], s) for i in range(10)), ) assert colors_cast is None or all( [isinstance(val, str) and val.startswith("#") for val in colors_cast] ), "All string colors should be in hexadecimal + prefixed with #, eg #ffffff." self._websock_interface.queue_message( _messages.ThemeConfigurationMessage( titlebar_content=titlebar_content, control_layout=control_layout, control_width=control_width, dark_mode=dark_mode, show_logo=show_logo, show_share_button=show_share_button, colors=colors_cast, ), ) def add_folder( self, label: str, order: float | None = None, expand_by_default: bool = True, visible: bool = True, ) -> GuiFolderHandle: """Add a folder, and return a handle that can be used to populate it. Args: label: Label to display on the folder. order: Optional ordering, smallest values will be displayed first. expand_by_default: Open the folder by default. Set to False to collapse it by default. visible: Whether the component is visible. Returns: A handle that can be used as a context to populate the folder. """ folder_container_id = _make_unique_id() order = _apply_default_order(order) self._websock_interface.queue_message( _messages.GuiAddFolderMessage( order=order, id=folder_container_id, label=label, container_id=self._get_container_id(), expand_by_default=expand_by_default, visible=visible, ) ) return GuiFolderHandle( _gui_api=self, _id=folder_container_id, _parent_container_id=self._get_container_id(), _order=order, ) def add_modal( self, title: str, order: float | None = None, ) -> GuiModalHandle: """Show a modal window, which can be useful for popups and messages, then return a handle that can be used to populate it. Args: title: Title to display on the modal. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used as a context to populate the modal. """ modal_container_id = _make_unique_id() order = _apply_default_order(order) self._websock_interface.queue_message( _messages.GuiModalMessage( order=order, id=modal_container_id, title=title, ) ) return GuiModalHandle( _gui_api=self, _id=modal_container_id, ) def add_tab_group( self, order: float | None = None, visible: bool = True, ) -> GuiTabGroupHandle: """Add a tab group. Args: order: Optional ordering, smallest values will be displayed first. visible: Whether the component is visible. 
Returns: A handle that can be used as a context to populate the tab group. """ tab_group_id = _make_unique_id() order = _apply_default_order(order) self._websock_interface.queue_message( _messages.GuiAddTabGroupMessage( order=order, id=tab_group_id, container_id=self._get_container_id(), tab_labels=(), visible=visible, tab_icons_html=(), tab_container_ids=(), ) ) return GuiTabGroupHandle( _tab_group_id=tab_group_id, _labels=[], _icons_html=[], _tabs=[], _gui_api=self, _parent_container_id=self._get_container_id(), _order=order, ) def add_markdown( self, content: str, image_root: Path | None = None, order: float | None = None, visible: bool = True, ) -> GuiMarkdownHandle: """Add markdown to the GUI. Args: content: Markdown content to display. image_root: Optional root directory to resolve relative image paths. order: Optional ordering, smallest values will be displayed first. visible: Whether the component is visible. Returns: A handle that can be used to interact with the GUI element. """ handle = GuiMarkdownHandle( _gui_api=self, _id=_make_unique_id(), _visible=visible, _parent_container_id=self._get_container_id(), _order=_apply_default_order(order), _image_root=image_root, _content=None, ) self._websock_interface.queue_message( _messages.GuiAddMarkdownMessage( order=handle._order, id=handle._id, markdown="", container_id=handle._parent_container_id, visible=visible, ) ) # Logic for processing markdown, handling images, etc is all in the # `.content` setter, which should send a GuiUpdateMessage. handle.content = content return handle def add_plotly( self, figure: go.Figure, aspect: float = 1.0, order: float | None = None, visible: bool = True, ) -> GuiPlotlyHandle: """Add a Plotly figure to the GUI. Requires the `plotly` package to be installed. Args: figure: Plotly figure to display. aspect: Aspect ratio of the plot in the control panel (width/height). order: Optional ordering, smallest values will be displayed first. visible: Whether the component is visible. Returns: A handle that can be used to interact with the GUI element. """ handle = GuiPlotlyHandle( _gui_api=self, _id=_make_unique_id(), _visible=visible, _parent_container_id=self._get_container_id(), _order=_apply_default_order(order), _figure=None, _aspect=None, ) # If plotly.min.js hasn't been sent to the client yet, the client won't be able # to render the plot. Send this large file now! (~3MB) if not self._setup_plotly_js: # Check if plotly is installed. try: import plotly except ImportError: raise ImportError( "You must have the `plotly` package installed to use the Plotly GUI element." ) # Check that plotly.min.js exists. plotly_path = ( Path(plotly.__file__).parent / "package_data" / "plotly.min.js" ) assert ( plotly_path.exists() ), f"Could not find plotly.min.js at {plotly_path}." # Send it over! plotly_js = plotly_path.read_text(encoding="utf-8") self._websock_interface.queue_message( _messages.RunJavascriptMessage(source=plotly_js) ) # Update the flag so we don't send it again. self._setup_plotly_js = True # After plotly.min.js has been sent, we can send the plotly figure. # Empty string for `plotly_json_str` is a signal to the client to render nothing. self._websock_interface.queue_message( _messages.GuiAddPlotlyMessage( order=handle._order, id=handle._id, plotly_json_str="", aspect=1.0, container_id=handle._parent_container_id, visible=visible, ) ) # Set the plotly handle properties. 
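        # Assigning to the handle's `figure` and `aspect` properties (defined on
        # GuiPlotlyHandle in `_gui_handles.py`) serializes the figure to JSON and
        # queues GuiUpdateMessages, filling in the empty plot created above.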
handle.figure = figure handle.aspect = aspect return handle def add_button( self, label: str, disabled: bool = False, visible: bool = True, hint: str | None = None, color: Color | None = None, icon: IconName | None = None, order: float | None = None, ) -> GuiButtonHandle: """Add a button to the GUI. The value of this input is set to `True` every time it is clicked; to detect clicks, we can manually set it back to `False`. Args: label: Label to display on the button. visible: Whether the button is visible. disabled: Whether the button is disabled. hint: Optional hint to display on hover. color: Optional color to use for the button. icon: Optional icon to display on the button. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. """ # Re-wrap the GUI handle with a button interface. id = _make_unique_id() order = _apply_default_order(order) return GuiButtonHandle( self._create_gui_input( value=False, message=_messages.GuiAddButtonMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, value=False, color=color, icon_html=None if icon is None else svg_from_icon(icon), disabled=disabled, visible=visible, ), is_button=True, )._impl ) def add_upload_button( self, label: str, disabled: bool = False, visible: bool = True, hint: str | None = None, color: Color | None = None, icon: IconName | None = None, mime_type: str = "*/*", order: float | None = None, ) -> GuiUploadButtonHandle: """Add a button to the GUI. The value of this input is set to `True` every time it is clicked; to detect clicks, we can manually set it back to `False`. Args: label: Label to display on the button. visible: Whether the button is visible. disabled: Whether the button is disabled. hint: Optional hint to display on hover. color: Optional color to use for the button. icon: Optional icon to display on the button. mime_type: Optional MIME type to filter the files that can be uploaded. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. """ # Re-wrap the GUI handle with a button interface. id = _make_unique_id() order = _apply_default_order(order) return GuiUploadButtonHandle( self._create_gui_input( value=UploadedFile("", b""), message=_messages.GuiAddUploadButtonMessage( value=None, disabled=disabled, visible=visible, order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, color=color, mime_type=mime_type, icon_html=None if icon is None else svg_from_icon(icon), ), is_button=True, )._impl ) # The TLiteralString overload tells pyright to resolve the value type to a Literal # whenever possible. # # TString is helpful when the input types are generic (could be str, could be # Literal). @overload def add_button_group( self, label: str, options: Sequence[TLiteralString], visible: bool = True, disabled: bool = False, hint: str | None = None, order: float | None = None, ) -> GuiButtonGroupHandle[TLiteralString]: ... @overload def add_button_group( self, label: str, options: Sequence[TString], visible: bool = True, disabled: bool = False, hint: str | None = None, order: float | None = None, ) -> GuiButtonGroupHandle[TString]: ... def add_button_group( self, label: str, options: Sequence[TLiteralString] | Sequence[TString], visible: bool = True, disabled: bool = False, hint: str | None = None, order: float | None = None, ) -> GuiButtonGroupHandle[Any]: # Return types are specified in overloads. 
"""Add a button group to the GUI. Args: label: Label to display on the button group. options: Sequence of options to display as buttons. visible: Whether the button group is visible. disabled: Whether the button group is disabled. hint: Optional hint to display on hover. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. """ value = options[0] id = _make_unique_id() order = _apply_default_order(order) return GuiButtonGroupHandle( self._create_gui_input( value, message=_messages.GuiAddButtonGroupMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, value=value, options=tuple(options), disabled=disabled, visible=visible, ), )._impl, ) def add_checkbox( self, label: str, initial_value: bool, disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiInputHandle[bool]: """Add a checkbox to the GUI. Args: label: Label to display on the checkbox. initial_value: Initial value of the checkbox. disabled: Whether the checkbox is disabled. visible: Whether the checkbox is visible. hint: Optional hint to display on hover. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. """ value = initial_value assert isinstance(value, bool) id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( value, message=_messages.GuiAddCheckboxMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, value=value, disabled=disabled, visible=visible, ), ) def add_text( self, label: str, initial_value: str, disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiInputHandle[str]: """Add a text input to the GUI. Args: label: Label to display on the text input. initial_value: Initial value of the text input. disabled: Whether the text input is disabled. visible: Whether the text input is visible. hint: Optional hint to display on hover. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. """ value = initial_value assert isinstance(value, str) id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( value, message=_messages.GuiAddTextMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, value=value, disabled=disabled, visible=visible, ), ) def add_number( self, label: str, initial_value: IntOrFloat, min: IntOrFloat | None = None, max: IntOrFloat | None = None, step: IntOrFloat | None = None, disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiInputHandle[IntOrFloat]: """Add a number input to the GUI, with user-specifiable bound and precision parameters. Args: label: Label to display on the number input. initial_value: Initial value of the number input. min: Optional minimum value of the number input. max: Optional maximum value of the number input. step: Optional step size of the number input. Computed automatically if not specified. disabled: Whether the number input is disabled. visible: Whether the number input is visible. hint: Optional hint to display on hover. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. 
""" value = initial_value assert isinstance(value, (int, float)) if step is None: # It's ok that `step` is always a float, even if the value is an integer, # because things all become `number` types after serialization. step = float( # type: ignore onp.min( [ _compute_step(value), _compute_step(min), _compute_step(max), ] ) ) assert step is not None id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( value, message=_messages.GuiAddNumberMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, value=value, min=min, max=max, precision=_compute_precision_digits(step), step=step, disabled=disabled, visible=visible, ), is_button=False, ) def add_vector2( self, label: str, initial_value: tuple[float, float] | onp.ndarray, min: tuple[float, float] | onp.ndarray | None = None, max: tuple[float, float] | onp.ndarray | None = None, step: float | None = None, disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiInputHandle[tuple[float, float]]: """Add a length-2 vector input to the GUI. Args: label: Label to display on the vector input. initial_value: Initial value of the vector input. min: Optional minimum value of the vector input. max: Optional maximum value of the vector input. step: Optional step size of the vector input. Computed automatically if not disabled: Whether the vector input is disabled. visible: Whether the vector input is visible. hint: Optional hint to display on hover. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. """ value = initial_value value = cast_vector(value, 2) min = cast_vector(min, 2) if min is not None else None max = cast_vector(max, 2) if max is not None else None id = _make_unique_id() order = _apply_default_order(order) if step is None: possible_steps: list[float] = [] possible_steps.extend([_compute_step(x) for x in value]) if min is not None: possible_steps.extend([_compute_step(x) for x in min]) if max is not None: possible_steps.extend([_compute_step(x) for x in max]) step = float(onp.min(possible_steps)) return self._create_gui_input( value, message=_messages.GuiAddVector2Message( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, value=value, min=min, max=max, step=step, precision=_compute_precision_digits(step), disabled=disabled, visible=visible, ), ) def add_vector3( self, label: str, initial_value: tuple[float, float, float] | onp.ndarray, min: tuple[float, float, float] | onp.ndarray | None = None, max: tuple[float, float, float] | onp.ndarray | None = None, step: float | None = None, disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiInputHandle[tuple[float, float, float]]: """Add a length-3 vector input to the GUI. Args: label: Label to display on the vector input. initial_value: Initial value of the vector input. min: Optional minimum value of the vector input. max: Optional maximum value of the vector input. step: Optional step size of the vector input. Computed automatically if not disabled: Whether the vector input is disabled. visible: Whether the vector input is visible. hint: Optional hint to display on hover. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. 
""" value = initial_value value = cast_vector(value, 3) min = cast_vector(min, 3) if min is not None else None max = cast_vector(max, 3) if max is not None else None id = _make_unique_id() order = _apply_default_order(order) if step is None: possible_steps: list[float] = [] possible_steps.extend([_compute_step(x) for x in value]) if min is not None: possible_steps.extend([_compute_step(x) for x in min]) if max is not None: possible_steps.extend([_compute_step(x) for x in max]) step = float(onp.min(possible_steps)) return self._create_gui_input( value, message=_messages.GuiAddVector3Message( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, value=value, min=min, max=max, step=step, precision=_compute_precision_digits(step), disabled=disabled, visible=visible, ), ) # See add_dropdown for notes on overloads. @overload def add_dropdown( self, label: str, options: Sequence[TLiteralString], initial_value: TLiteralString | None = None, disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiDropdownHandle[TLiteralString]: ... @overload def add_dropdown( self, label: str, options: Sequence[TString], initial_value: TString | None = None, disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiDropdownHandle[TString]: ... def add_dropdown( self, label: str, options: Sequence[TLiteralString] | Sequence[TString], initial_value: TLiteralString | TString | None = None, disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiDropdownHandle[Any]: # Output type is specified in overloads. """Add a dropdown to the GUI. Args: label: Label to display on the dropdown. options: Sequence of options to display in the dropdown. initial_value: Initial value of the dropdown. disabled: Whether the dropdown is disabled. visible: Whether the dropdown is visible. hint: Optional hint to display on hover. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. """ value = initial_value if value is None: value = options[0] id = _make_unique_id() order = _apply_default_order(order) return GuiDropdownHandle( self._create_gui_input( value, message=_messages.GuiAddDropdownMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, value=value, options=tuple(options), disabled=disabled, visible=visible, ), )._impl, _impl_options=tuple(options), ) def add_progress_bar( self, value: float, visible: bool = True, animated: bool = False, color: Color | None = None, order: float | None = None, ) -> GuiProgressBarHandle: """Add a progress bar to the GUI. Args: value: Value of the progress bar. (0 - 100) visible: Whether the progress bar is visible. animated: Whether the progress bar is in a loading state (animated, striped). color: The color of the progress bar. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. 
""" assert value >= 0 and value <= 100 handle = GuiProgressBarHandle( _gui_api=self, _id=_make_unique_id(), _visible=visible, _animated=animated, _parent_container_id=self._get_container_id(), _order=_apply_default_order(order), _value=value, ) self._websock_interface.queue_message( _messages.GuiAddProgressBarMessage( order=handle._order, id=handle._id, value=value, animated=animated, color=color, container_id=handle._parent_container_id, visible=visible, ) ) return handle def add_slider( self, label: str, min: IntOrFloat, max: IntOrFloat, step: IntOrFloat, initial_value: IntOrFloat, marks: tuple[IntOrFloat | tuple[IntOrFloat, str], ...] | None = None, disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiInputHandle[IntOrFloat]: """Add a slider to the GUI. Types of the min, max, step, and initial value should match. Args: label: Label to display on the slider. min: Minimum value of the slider. max: Maximum value of the slider. step: Step size of the slider. initial_value: Initial value of the slider. marks: tuple of marks to display below the slider. Each mark should either be a numerical or a (number, label) tuple, where the label is provided as a string. disabled: Whether the slider is disabled. visible: Whether the slider is visible. hint: Optional hint to display on hover. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. """ value: IntOrFloat = initial_value assert max >= min step = builtins.min(step, max - min) assert max >= value >= min # GUI callbacks cast incoming values to match the type of the initial value. If # the min, max, or step is a float, we should cast to a float. # # This should also match what the IntOrFloat TypeVar resolves to. if type(value) is int and ( type(min) is float or type(max) is float or type(step) is float ): value = float(value) # type: ignore # TODO: as of 6/5/2023, this assert will break something in nerfstudio. (at # least LERF) # # assert type(min) == type(max) == type(step) == type(value) id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( value, message=_messages.GuiAddSliderMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, min=min, max=max, step=step, value=value, precision=_compute_precision_digits(step), visible=visible, disabled=disabled, marks=tuple( {"value": float(x[0]), "label": x[1]} if isinstance(x, tuple) else {"value": float(x)} for x in marks ) if marks is not None else None, ), is_button=False, ) def add_multi_slider( self, label: str, min: IntOrFloat, max: IntOrFloat, step: IntOrFloat, initial_value: tuple[IntOrFloat, ...], min_range: IntOrFloat | None = None, fixed_endpoints: bool = False, marks: tuple[IntOrFloat | tuple[IntOrFloat, str], ...] | None = None, disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiInputHandle[tuple[IntOrFloat, ...]]: """Add a multi slider to the GUI. Types of the min, max, step, and initial value should match. Args: label: Label to display on the slider. min: Minimum value of the slider. max: Maximum value of the slider. step: Step size of the slider. initial_value: Initial values of the slider. min_range: Optional minimum difference between two values of the slider. fixed_endpoints: Whether the endpoints of the slider are fixed. marks: tuple of marks to display below the slider. 
Each mark should either be a numerical or a (number, label) tuple, where the label is provided as a string. disabled: Whether the slider is disabled. visible: Whether the slider is visible. hint: Optional hint to display on hover. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. """ assert max >= min step = builtins.min(step, max - min) assert all(max >= x >= min for x in initial_value) # GUI callbacks cast incoming values to match the type of the initial value. If # any of the arguments are floats, we should always use a float value. # # This should also match what the IntOrFloat TypeVar resolves to. if ( type(min) is float or type(max) is float or type(step) is float or type(min_range) is float ): initial_value = tuple(float(x) for x in initial_value) # type: ignore id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( value=initial_value, message=_messages.GuiAddMultiSliderMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, min=min, min_range=min_range, max=max, step=step, value=initial_value, visible=visible, disabled=disabled, fixed_endpoints=fixed_endpoints, precision=_compute_precision_digits(step), marks=tuple( {"value": float(x[0]), "label": x[1]} if isinstance(x, tuple) else {"value": float(x)} for x in marks ) if marks is not None else None, ), is_button=False, ) def add_rgb( self, label: str, initial_value: tuple[int, int, int], disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiInputHandle[tuple[int, int, int]]: """Add an RGB picker to the GUI. Args: label: Label to display on the RGB picker. initial_value: Initial value of the RGB picker. disabled: Whether the RGB picker is disabled. visible: Whether the RGB picker is visible. hint: Optional hint to display on hover. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. """ value = initial_value id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( value, message=_messages.GuiAddRgbMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, value=value, disabled=disabled, visible=visible, ), ) def add_rgba( self, label: str, initial_value: tuple[int, int, int, int], disabled: bool = False, visible: bool = True, hint: str | None = None, order: float | None = None, ) -> GuiInputHandle[tuple[int, int, int, int]]: """Add an RGBA picker to the GUI. Args: label: Label to display on the RGBA picker. initial_value: Initial value of the RGBA picker. disabled: Whether the RGBA picker is disabled. visible: Whether the RGBA picker is visible. hint: Optional hint to display on hover. order: Optional ordering, smallest values will be displayed first. Returns: A handle that can be used to interact with the GUI element. """ value = initial_value id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( value, message=_messages.GuiAddRgbaMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, value=value, disabled=disabled, visible=visible, ), ) def _create_gui_input( self, value: T, message: _messages._GuiAddInputBase, is_button: bool = False, ) -> GuiInputHandle[T]: """Private helper for adding a simple GUI element.""" # Send add GUI input message. 
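        # (The websocket interface forwards queued messages to the connected
        # client(s); the handle state constructed below mirrors the same values
        # locally, so reads like `handle.value` don't need a round-trip.)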
self._websock_interface.queue_message(message) # Construct handle. handle_state = _GuiHandleState( label=message.label, message_type=type(message), gui_api=self, value=value, update_timestamp=time.time(), parent_container_id=self._get_container_id(), update_cb=[], is_button=is_button, sync_cb=None, disabled=message.disabled, visible=message.visible, id=message.id, order=message.order, hint=message.hint, ) # For broadcasted GUI handles, we should synchronize all clients. # This will be a no-op for client handles. if not is_button: def sync_other_clients( client_id: ClientId, updates: dict[str, Any] ) -> None: message = _messages.GuiUpdateMessage(handle_state.id, updates) message.excluded_self_client = client_id self._websock_interface.queue_message(message) handle_state.sync_cb = sync_other_clients handle = GuiInputHandle(handle_state) return handle ================================================ FILE: viser/src/viser/_gui_handles.py ================================================ from __future__ import annotations import base64 import dataclasses import re import time import uuid import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, TypeVar import imageio.v3 as iio import numpy as onp from typing_extensions import Protocol from ._icons import svg_from_icon from ._icons_enum import IconName from ._messages import GuiCloseModalMessage, GuiRemoveMessage, GuiUpdateMessage, Message from ._scene_api import _encode_image_binary from .infra import ClientId if TYPE_CHECKING: import plotly.graph_objects as go from ._gui_api import GuiApi from ._viser import ClientHandle T = TypeVar("T") TGuiHandle = TypeVar("TGuiHandle", bound="_GuiInputHandle") def _make_unique_id() -> str: """Return a unique ID for referencing GUI elements.""" return str(uuid.uuid4()) class GuiContainerProtocol(Protocol): _children: dict[str, SupportsRemoveProtocol] = dataclasses.field( default_factory=dict ) class SupportsRemoveProtocol(Protocol): def remove(self) -> None: ... @dataclasses.dataclass class _GuiHandleState(Generic[T]): """Internal API for GUI elements.""" label: str gui_api: GuiApi value: T update_timestamp: float parent_container_id: str """Container that this GUI input was placed into.""" update_cb: list[Callable[[GuiEvent], None]] """Registered functions to call when this input is updated.""" is_button: bool """Indicates a button element, which requires special handling.""" sync_cb: Callable[[ClientId, dict[str, Any]], None] | None """Callback for synchronizing inputs across clients.""" disabled: bool visible: bool order: float id: str hint: str | None message_type: type[Message] @dataclasses.dataclass class _GuiInputHandle(Generic[T]): # Let's shove private implementation details in here... _impl: _GuiHandleState[T] # Should we use @property for get_value / set_value, set_hidden, etc? # # Benefits: # @property is syntactically very nice. # `gui.value = ...` is really tempting! # Feels a bit more magical. # # Downsides: # Consistency: not everything that can be written can be read, and not everything # that can be read can be written. `get_`/`set_` makes this really clear. # Clarity: some things that we read (like client mappings) are copied before # they're returned. An attribute access obfuscates the overhead here. # Flexibility: getter/setter types should match. https://github.com/python/mypy/issues/3004 # Feels a bit more magical. # # Is this worth the tradeoff? 
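    # As a rough usage sketch of the property style discussed above (assuming
    # the server below exposes this API as `server.gui`, per the docstrings in
    # this file):
    #
    #     import viser
    #
    #     server = viser.ViserServer()
    #     gain = server.gui.add_slider(
    #         "Gain", min=0.0, max=1.0, step=0.01, initial_value=0.5
    #     )
    #     gain.value = 0.25     # Pushes a GuiUpdateMessage to clients.
    #     gain.visible = False  # Temporarily hide the element.
    #
    #     @gain.on_update
    #     def _(event) -> None:
    #         print("New gain:", gain.value)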
@property def order(self) -> float: """Read-only order value, which dictates the position of the GUI element.""" return self._impl.order @property def value(self) -> T: """Value of the GUI input. Synchronized automatically when assigned.""" return self._impl.value @value.setter def value(self, value: T | onp.ndarray) -> None: if isinstance(value, onp.ndarray): assert len(value.shape) <= 1, f"{value.shape} should be at most 1D!" value = tuple(map(float, value)) # type: ignore # Send to client, except for buttons. if not self._impl.is_button: self._impl.gui_api._websock_interface.queue_message( GuiUpdateMessage(self._impl.id, {"value": value}) ) # Set internal state. We automatically convert numpy arrays to the expected # internal type. (eg 1D arrays to tuples) self._impl.value = type(self._impl.value)(value) # type: ignore self._impl.update_timestamp = time.time() # Call update callbacks. for cb in self._impl.update_cb: # Pushing callbacks into separate threads helps prevent deadlocks when we # have a lock in a callback. TODO: revisit other callbacks. self._impl.gui_api._thread_executor.submit( lambda: cb( GuiEvent( client_id=None, client=None, target=self, ) ) ) @property def update_timestamp(self) -> float: """Read-only timestamp when this input was last updated.""" return self._impl.update_timestamp @property def disabled(self) -> bool: """Allow/disallow user interaction with the input. Synchronized automatically when assigned.""" return self._impl.disabled @disabled.setter def disabled(self, disabled: bool) -> None: if disabled == self.disabled: return self._impl.gui_api._websock_interface.queue_message( GuiUpdateMessage(self._impl.id, {"disabled": disabled}) ) self._impl.disabled = disabled @property def visible(self) -> bool: """Temporarily show or hide this GUI element from the visualizer. Synchronized automatically when assigned.""" return self._impl.visible @visible.setter def visible(self, visible: bool) -> None: if visible == self.visible: return self._impl.gui_api._websock_interface.queue_message( GuiUpdateMessage(self._impl.id, {"visible": visible}) ) self._impl.visible = visible def __post_init__(self) -> None: """We need to register ourself after construction for callbacks to work.""" gui_api = self._impl.gui_api # TODO: the current way we track GUI handles and children is very manual + # error-prone. We should revist this design. gui_api._gui_input_handle_from_id[self._impl.id] = self parent = gui_api._container_handle_from_id[self._impl.parent_container_id] parent._children[self._impl.id] = self def remove(self) -> None: """Permanently remove this GUI element from the visualizer.""" gui_api = self._impl.gui_api gui_api._websock_interface.queue_message(GuiRemoveMessage(self._impl.id)) gui_api._gui_input_handle_from_id.pop(self._impl.id) parent = gui_api._container_handle_from_id[self._impl.parent_container_id] parent._children.pop(self._impl.id) StringType = TypeVar("StringType", bound=str) # GuiInputHandle[T] is used for all inputs except for buttons. # # We inherit from _GuiInputHandle to special-case buttons because the usage semantics # are slightly different: we have `on_click()` instead of `on_update()`. @dataclasses.dataclass class GuiInputHandle(_GuiInputHandle[T], Generic[T]): """A handle is created for each GUI element that is added in `viser`. Handles can be used to read and write state. When a GUI element is added via :attr:`ViserServer.gui`, state is synchronized between all connected clients. 
When a GUI element is added via :attr:`ClientHandle.gui`, state is local to a specific client. """ def on_update( self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] ) -> Callable[[GuiEvent[TGuiHandle]], None]: """Attach a function to call when a GUI input is updated. Happens in a thread.""" self._impl.update_cb.append(func) return func @dataclasses.dataclass(frozen=True) class GuiEvent(Generic[TGuiHandle]): """Information associated with a GUI event, such as an update or click. Passed as input to callback functions.""" client: ClientHandle | None """Client that triggered this event.""" client_id: int | None """ID of client that triggered this event.""" target: TGuiHandle """GUI element that was affected.""" @dataclasses.dataclass class GuiButtonHandle(_GuiInputHandle[bool]): """Handle for a button input in our visualizer. Lets us detect clicks.""" def on_click( self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] ) -> Callable[[GuiEvent[TGuiHandle]], None]: """Attach a function to call when a button is pressed. Happens in a thread.""" self._impl.update_cb.append(func) return func @dataclasses.dataclass class UploadedFile: """Result of a file upload.""" name: str """Name of the file.""" content: bytes """Contents of the file.""" @dataclasses.dataclass class GuiUploadButtonHandle(_GuiInputHandle[UploadedFile]): """Handle for an upload file button in our visualizer. The `.value` attribute will be updated with the contents of uploaded files. """ def on_upload( self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] ) -> Callable[[GuiEvent[TGuiHandle]], None]: """Attach a function to call when a button is pressed. Happens in a thread.""" self._impl.update_cb.append(func) return func @dataclasses.dataclass class GuiButtonGroupHandle(_GuiInputHandle[StringType], Generic[StringType]): """Handle for a button group input in our visualizer. Lets us detect clicks.""" def on_click( self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] ) -> Callable[[GuiEvent[TGuiHandle]], None]: """Attach a function to call when a button is pressed. Happens in a thread.""" self._impl.update_cb.append(func) return func @property def disabled(self) -> bool: """Button groups cannot be disabled.""" return False @disabled.setter def disabled(self, disabled: bool) -> None: """Button groups cannot be disabled.""" assert not disabled, "Button groups cannot be disabled." @dataclasses.dataclass class GuiDropdownHandle(GuiInputHandle[StringType], Generic[StringType]): """Handle for a dropdown-style GUI input in our visualizer. Lets us get values, set values, and detect updates.""" _impl_options: tuple[StringType, ...] @property def options(self) -> tuple[StringType, ...]: """Options for our dropdown. Synchronized automatically when assigned. For projects that care about typing: the static type of `options` should be consistent with the `StringType` associated with a handle. Literal types will be inferred where possible when handles are instantiated; for the most flexibility, we can declare handles as `GuiDropdownHandle[str]`. 
""" return self._impl_options @options.setter def options(self, options: Iterable[StringType]) -> None: self._impl_options = tuple(options) need_to_overwrite_value = self.value not in self._impl_options if need_to_overwrite_value: self._impl.gui_api._websock_interface.queue_message( GuiUpdateMessage( self._impl.id, {"options": self._impl_options, "value": self._impl_options[0]}, ) ) self._impl.value = self._impl_options[0] else: self._impl.gui_api._websock_interface.queue_message( GuiUpdateMessage( self._impl.id, {"options": self._impl_options}, ) ) @dataclasses.dataclass(frozen=True) class GuiTabGroupHandle: """Handle for a tab group. Call :meth:`add_tab()` to add a tab.""" _tab_group_id: str _labels: list[str] _icons_html: list[str | None] _tabs: list[GuiTabHandle] _gui_api: GuiApi _parent_container_id: str _order: float @property def order(self) -> float: """Read-only order value, which dictates the position of the GUI element.""" return self._order def add_tab(self, label: str, icon: IconName | None = None) -> GuiTabHandle: """Add a tab. Returns a handle we can use to add GUI elements to it.""" id = _make_unique_id() # We may want to make this thread-safe in the future. out = GuiTabHandle(_parent=self, _id=id) self._labels.append(label) self._icons_html.append(None if icon is None else svg_from_icon(icon)) self._tabs.append(out) self._sync_with_client() return out def __post_init__(self) -> None: parent = self._gui_api._container_handle_from_id[self._parent_container_id] parent._children[self._tab_group_id] = self def remove(self) -> None: """Remove this tab group and all contained GUI elements.""" for tab in tuple(self._tabs): tab.remove() gui_api = self._gui_api gui_api._websock_interface.queue_message(GuiRemoveMessage(self._tab_group_id)) parent = gui_api._container_handle_from_id[self._parent_container_id] parent._children.pop(self._tab_group_id) def _sync_with_client(self) -> None: """Send messages for syncing tab state with the client.""" self._gui_api._websock_interface.queue_message( GuiUpdateMessage( self._tab_group_id, { "tab_labels": tuple(self._labels), "tab_icons_html": tuple(self._icons_html), "tab_container_ids": tuple(tab._id for tab in self._tabs), }, ) ) @dataclasses.dataclass class GuiFolderHandle: """Use as a context to place GUI elements into a folder.""" _gui_api: GuiApi _id: str # Used as container ID for children. _order: float _parent_container_id: str # Container ID of parent. 
_container_id_restore: str | None = None _children: dict[str, SupportsRemoveProtocol] = dataclasses.field( default_factory=dict ) @property def order(self) -> float: """Read-only order value, which dictates the position of the GUI element.""" return self._order def __enter__(self) -> GuiFolderHandle: self._container_id_restore = self._gui_api._get_container_id() self._gui_api._set_container_id(self._id) return self def __exit__(self, *args) -> None: del args assert self._container_id_restore is not None self._gui_api._set_container_id(self._container_id_restore) self._container_id_restore = None def __post_init__(self) -> None: self._gui_api._container_handle_from_id[self._id] = self parent = self._gui_api._container_handle_from_id[self._parent_container_id] parent._children[self._id] = self def remove(self) -> None: """Permanently remove this folder and all contained GUI elements from the visualizer.""" self._gui_api._websock_interface.queue_message(GuiRemoveMessage(self._id)) for child in tuple(self._children.values()): child.remove() parent = self._gui_api._container_handle_from_id[self._parent_container_id] parent._children.pop(self._id) self._gui_api._container_handle_from_id.pop(self._id) @dataclasses.dataclass class GuiModalHandle: """Use as a context to place GUI elements into a modal.""" _gui_api: GuiApi _id: str # Used as container ID of children. _container_id_restore: str | None = None _children: dict[str, SupportsRemoveProtocol] = dataclasses.field( default_factory=dict ) def __enter__(self) -> GuiModalHandle: self._container_id_restore = self._gui_api._get_container_id() self._gui_api._set_container_id(self._id) return self def __exit__(self, *args) -> None: del args assert self._container_id_restore is not None self._gui_api._set_container_id(self._container_id_restore) self._container_id_restore = None def __post_init__(self) -> None: self._gui_api._container_handle_from_id[self._id] = self def close(self) -> None: """Close this modal and permanently remove all contained GUI elements.""" self._gui_api._websock_interface.queue_message( GuiCloseModalMessage(self._id), ) for child in tuple(self._children.values()): child.remove() self._gui_api._container_handle_from_id.pop(self._id) @dataclasses.dataclass class GuiTabHandle: """Use as a context to place GUI elements into a tab.""" _parent: GuiTabGroupHandle _id: str # Used as container ID of children. _container_id_restore: str | None = None _children: dict[str, SupportsRemoveProtocol] = dataclasses.field( default_factory=dict ) def __enter__(self) -> GuiTabHandle: self._container_id_restore = self._parent._gui_api._get_container_id() self._parent._gui_api._set_container_id(self._id) return self def __exit__(self, *args) -> None: del args assert self._container_id_restore is not None self._parent._gui_api._set_container_id(self._container_id_restore) self._container_id_restore = None def __post_init__(self) -> None: self._parent._gui_api._container_handle_from_id[self._id] = self def remove(self) -> None: """Permanently remove this tab and all contained GUI elements from the visualizer.""" # We may want to make this thread-safe in the future. container_index = -1 for i, tab in enumerate(self._parent._tabs): if tab is self: container_index = i break assert container_index != -1, "Tab already removed!"
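        # Drop this tab's label, icon, and handle from the parent's parallel
        # lists, then push the updated tab state to clients before removing
        # any child elements.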
self._parent._labels.pop(container_index) self._parent._icons_html.pop(container_index) self._parent._tabs.pop(container_index) self._parent._sync_with_client() for child in tuple(self._children.values()): child.remove() self._parent._gui_api._container_handle_from_id.pop(self._id) def _get_data_url(url: str, image_root: Path | None) -> str: if not url.startswith("http") and not image_root: warnings.warn( ( "No `image_root` provided. All relative paths will be scoped to viser's" " installation path." ), stacklevel=2, ) if url.startswith("http") or url.startswith("data:"): return url if image_root is None: image_root = Path(__file__).parent try: image = iio.imread(image_root / url) media_type, binary = _encode_image_binary(image, "png") url = base64.b64encode(binary).decode("utf-8") return f"data:{media_type};base64,{url}" except (IOError, FileNotFoundError): warnings.warn( f"Failed to read image {url}, with image_root set to {image_root}.", stacklevel=2, ) return url def _parse_markdown(markdown: str, image_root: Path | None) -> str: markdown = re.sub( r"\!\[([^]]*)\]\(([^]]*)\)", lambda match: ( f"![{match.group(1)}]({_get_data_url(match.group(2), image_root)})" ), markdown, ) return markdown @dataclasses.dataclass class GuiProgressBarHandle: """Use to remove markdown.""" _gui_api: GuiApi _id: str _visible: bool _animated: bool _parent_container_id: str _order: float _value: float @property def value(self) -> float: """Current content of this progress bar element, 0 - 100. Synchronized automatically when assigned.""" return self._value @value.setter def value(self, value: float) -> None: assert value >= 0 and value <= 100 self._value = value self._gui_api._websock_interface.queue_message( GuiUpdateMessage( self._id, {"value": value}, ) ) @property def animated(self) -> bool: """Show this progress bar as loading (animated, striped).""" return self._animated @animated.setter def animated(self, animated: bool) -> None: self._animated = animated self._gui_api._websock_interface.queue_message( GuiUpdateMessage( self._id, {"animated": animated}, ) ) @property def order(self) -> float: """Read-only order value, which dictates the position of the GUI element.""" return self._order @property def visible(self) -> bool: """Temporarily show or hide this GUI element from the visualizer. Synchronized automatically when assigned.""" return self._visible @visible.setter def visible(self, visible: bool) -> None: if visible == self.visible: return self._gui_api._websock_interface.queue_message( GuiUpdateMessage(self._id, {"visible": visible}) ) self._visible = visible def __post_init__(self) -> None: """We need to register ourself after construction for callbacks to work.""" parent = self._gui_api._container_handle_from_id[self._parent_container_id] parent._children[self._id] = self def remove(self) -> None: """Permanently remove this progress bar from the visualizer.""" self._gui_api._websock_interface.queue_message(GuiRemoveMessage(self._id)) parent = self._gui_api._container_handle_from_id[self._parent_container_id] parent._children.pop(self._id) @dataclasses.dataclass class GuiMarkdownHandle: """Use to remove markdown.""" _gui_api: GuiApi _id: str _visible: bool _parent_container_id: str _order: float _image_root: Path | None _content: str | None @property def content(self) -> str: """Current content of this markdown element. 
Synchronized automatically when assigned.""" assert self._content is not None return self._content @content.setter def content(self, content: str) -> None: self._content = content self._gui_api._websock_interface.queue_message( GuiUpdateMessage( self._id, {"markdown": _parse_markdown(content, self._image_root)}, ) ) @property def order(self) -> float: """Read-only order value, which dictates the position of the GUI element.""" return self._order @property def visible(self) -> bool: """Temporarily show or hide this GUI element from the visualizer. Synchronized automatically when assigned.""" return self._visible @visible.setter def visible(self, visible: bool) -> None: if visible == self.visible: return self._gui_api._websock_interface.queue_message( GuiUpdateMessage(self._id, {"visible": visible}) ) self._visible = visible def __post_init__(self) -> None: """We need to register ourself after construction for callbacks to work.""" parent = self._gui_api._container_handle_from_id[self._parent_container_id] parent._children[self._id] = self def remove(self) -> None: """Permanently remove this markdown from the visualizer.""" self._gui_api._websock_interface.queue_message(GuiRemoveMessage(self._id)) parent = self._gui_api._container_handle_from_id[self._parent_container_id] parent._children.pop(self._id) @dataclasses.dataclass class GuiPlotlyHandle: """Use to update or remove markdown elements.""" _gui_api: GuiApi _id: str _visible: bool _parent_container_id: str _order: float _figure: go.Figure | None _aspect: float | None @property def figure(self) -> go.Figure: """Current content of this markdown element. Synchronized automatically when assigned.""" assert self._figure is not None return self._figure @figure.setter def figure(self, figure: go.Figure) -> None: self._figure = figure json_str = figure.to_json() assert isinstance(json_str, str) self._gui_api._websock_interface.queue_message( GuiUpdateMessage( self._id, {"plotly_json_str": json_str}, ) ) @property def aspect(self) -> float: """Aspect ratio of the plotly figure, in the control panel.""" assert self._aspect is not None return self._aspect @aspect.setter def aspect(self, aspect: float) -> None: self._aspect = aspect self._gui_api._websock_interface.queue_message( GuiUpdateMessage( self._id, {"aspect": aspect}, ) ) @property def order(self) -> float: """Read-only order value, which dictates the position of the GUI element.""" return self._order @property def visible(self) -> bool: """Temporarily show or hide this GUI element from the visualizer. 
Synchronized automatically when assigned.""" return self._visible @visible.setter def visible(self, visible: bool) -> None: if visible == self.visible: return self._gui_api._websock_interface.queue_message( GuiUpdateMessage(self._id, {"visible": visible}) ) self._visible = visible def __post_init__(self) -> None: """We need to register ourself after construction for callbacks to work.""" parent = self._gui_api._container_handle_from_id[self._parent_container_id] parent._children[self._id] = self def remove(self) -> None: """Permanently remove this figure from the visualizer.""" self._gui_api._websock_interface.queue_message(GuiRemoveMessage(self._id)) parent = self._gui_api._container_handle_from_id[self._parent_container_id] parent._children.pop(self._id) ================================================ FILE: viser/src/viser/_icons.py ================================================ import tarfile from pathlib import Path from ._icons_enum import IconName ICONS_DIR = Path(__file__).absolute().parent / "_icons" def svg_from_icon(icon_name: IconName) -> str: """Read an icon and return it as a UTF string; we expect this to be an tag.""" assert isinstance(icon_name, str) icons_tarball = ICONS_DIR / "tabler-icons.tar" with tarfile.open(icons_tarball) as tar: icon_file = tar.extractfile(f"{icon_name}.svg") assert icon_file is not None out = icon_file.read() return out.decode("utf-8") ================================================ FILE: viser/src/viser/_icons_enum.py ================================================ # Automatically generated by `_icons_generate_enum.py` # See https://tabler-icons.io/ from typing import NewType IconName = NewType("IconName", str) """Name of an icon. Should be generated via `viser.Icon.*`.""" class _IconStringConverter(type): def __getattr__(self, __name: str) -> IconName: if not __name.startswith("_"): return IconName(__name.lower().replace("_", "-")) else: raise AttributeError() class Icon(metaclass=_IconStringConverter): """'Enum' class for referencing Tabler icons. We don't subclass enum.Enum for performance reasons -- importing an enum with thousands of names can result in import times in the hundreds of milliseconds. Attributes: ICON_123 (IconName): The :code:`123` icon. ICON_24_HOURS (IconName): The :code:`24-hours` icon. ICON_2FA (IconName): The :code:`2fa` icon. ICON_360 (IconName): The :code:`360` icon. ICON_360_VIEW (IconName): The :code:`360-view` icon. ICON_3D_CUBE_SPHERE (IconName): The :code:`3d-cube-sphere` icon. ICON_3D_CUBE_SPHERE_OFF (IconName): The :code:`3d-cube-sphere-off` icon. ICON_3D_ROTATE (IconName): The :code:`3d-rotate` icon. A_B (IconName): The :code:`a-b` icon. A_B_2 (IconName): The :code:`a-b-2` icon. A_B_OFF (IconName): The :code:`a-b-off` icon. ABACUS (IconName): The :code:`abacus` icon. ABACUS_OFF (IconName): The :code:`abacus-off` icon. ABC (IconName): The :code:`abc` icon. ACCESS_POINT (IconName): The :code:`access-point` icon. ACCESS_POINT_OFF (IconName): The :code:`access-point-off` icon. ACCESSIBLE (IconName): The :code:`accessible` icon. ACCESSIBLE_OFF (IconName): The :code:`accessible-off` icon. ACCESSIBLE_OFF_FILLED (IconName): The :code:`accessible-off-filled` icon. ACTIVITY (IconName): The :code:`activity` icon. ACTIVITY_HEARTBEAT (IconName): The :code:`activity-heartbeat` icon. AD (IconName): The :code:`ad` icon. AD_2 (IconName): The :code:`ad-2` icon. AD_CIRCLE (IconName): The :code:`ad-circle` icon. AD_CIRCLE_FILLED (IconName): The :code:`ad-circle-filled` icon. 
AD_CIRCLE_OFF (IconName): The :code:`ad-circle-off` icon. AD_FILLED (IconName): The :code:`ad-filled` icon. AD_OFF (IconName): The :code:`ad-off` icon. ADDRESS_BOOK (IconName): The :code:`address-book` icon. ADDRESS_BOOK_OFF (IconName): The :code:`address-book-off` icon. ADJUSTMENTS (IconName): The :code:`adjustments` icon. ADJUSTMENTS_ALT (IconName): The :code:`adjustments-alt` icon. ADJUSTMENTS_BOLT (IconName): The :code:`adjustments-bolt` icon. ADJUSTMENTS_CANCEL (IconName): The :code:`adjustments-cancel` icon. ADJUSTMENTS_CHECK (IconName): The :code:`adjustments-check` icon. ADJUSTMENTS_CODE (IconName): The :code:`adjustments-code` icon. ADJUSTMENTS_COG (IconName): The :code:`adjustments-cog` icon. ADJUSTMENTS_DOLLAR (IconName): The :code:`adjustments-dollar` icon. ADJUSTMENTS_DOWN (IconName): The :code:`adjustments-down` icon. ADJUSTMENTS_EXCLAMATION (IconName): The :code:`adjustments-exclamation` icon. ADJUSTMENTS_FILLED (IconName): The :code:`adjustments-filled` icon. ADJUSTMENTS_HEART (IconName): The :code:`adjustments-heart` icon. ADJUSTMENTS_HORIZONTAL (IconName): The :code:`adjustments-horizontal` icon. ADJUSTMENTS_MINUS (IconName): The :code:`adjustments-minus` icon. ADJUSTMENTS_OFF (IconName): The :code:`adjustments-off` icon. ADJUSTMENTS_PAUSE (IconName): The :code:`adjustments-pause` icon. ADJUSTMENTS_PIN (IconName): The :code:`adjustments-pin` icon. ADJUSTMENTS_PLUS (IconName): The :code:`adjustments-plus` icon. ADJUSTMENTS_QUESTION (IconName): The :code:`adjustments-question` icon. ADJUSTMENTS_SEARCH (IconName): The :code:`adjustments-search` icon. ADJUSTMENTS_SHARE (IconName): The :code:`adjustments-share` icon. ADJUSTMENTS_STAR (IconName): The :code:`adjustments-star` icon. ADJUSTMENTS_UP (IconName): The :code:`adjustments-up` icon. ADJUSTMENTS_X (IconName): The :code:`adjustments-x` icon. AERIAL_LIFT (IconName): The :code:`aerial-lift` icon. AFFILIATE (IconName): The :code:`affiliate` icon. AFFILIATE_FILLED (IconName): The :code:`affiliate-filled` icon. AIR_BALLOON (IconName): The :code:`air-balloon` icon. AIR_CONDITIONING (IconName): The :code:`air-conditioning` icon. AIR_CONDITIONING_DISABLED (IconName): The :code:`air-conditioning-disabled` icon. ALARM (IconName): The :code:`alarm` icon. ALARM_FILLED (IconName): The :code:`alarm-filled` icon. ALARM_MINUS (IconName): The :code:`alarm-minus` icon. ALARM_MINUS_FILLED (IconName): The :code:`alarm-minus-filled` icon. ALARM_OFF (IconName): The :code:`alarm-off` icon. ALARM_PLUS (IconName): The :code:`alarm-plus` icon. ALARM_PLUS_FILLED (IconName): The :code:`alarm-plus-filled` icon. ALARM_SNOOZE (IconName): The :code:`alarm-snooze` icon. ALARM_SNOOZE_FILLED (IconName): The :code:`alarm-snooze-filled` icon. ALBUM (IconName): The :code:`album` icon. ALBUM_OFF (IconName): The :code:`album-off` icon. ALERT_CIRCLE (IconName): The :code:`alert-circle` icon. ALERT_CIRCLE_FILLED (IconName): The :code:`alert-circle-filled` icon. ALERT_HEXAGON (IconName): The :code:`alert-hexagon` icon. ALERT_HEXAGON_FILLED (IconName): The :code:`alert-hexagon-filled` icon. ALERT_OCTAGON (IconName): The :code:`alert-octagon` icon. ALERT_OCTAGON_FILLED (IconName): The :code:`alert-octagon-filled` icon. ALERT_SMALL (IconName): The :code:`alert-small` icon. ALERT_SQUARE (IconName): The :code:`alert-square` icon. ALERT_SQUARE_FILLED (IconName): The :code:`alert-square-filled` icon. ALERT_SQUARE_ROUNDED (IconName): The :code:`alert-square-rounded` icon. ALERT_SQUARE_ROUNDED_FILLED (IconName): The :code:`alert-square-rounded-filled` icon. 
ALERT_TRIANGLE (IconName): The :code:`alert-triangle` icon. ALERT_TRIANGLE_FILLED (IconName): The :code:`alert-triangle-filled` icon. ALIEN (IconName): The :code:`alien` icon. ALIEN_FILLED (IconName): The :code:`alien-filled` icon. ALIGN_BOX_BOTTOM_CENTER (IconName): The :code:`align-box-bottom-center` icon. ALIGN_BOX_BOTTOM_CENTER_FILLED (IconName): The :code:`align-box-bottom-center-filled` icon. ALIGN_BOX_BOTTOM_LEFT (IconName): The :code:`align-box-bottom-left` icon. ALIGN_BOX_BOTTOM_LEFT_FILLED (IconName): The :code:`align-box-bottom-left-filled` icon. ALIGN_BOX_BOTTOM_RIGHT (IconName): The :code:`align-box-bottom-right` icon. ALIGN_BOX_BOTTOM_RIGHT_FILLED (IconName): The :code:`align-box-bottom-right-filled` icon. ALIGN_BOX_CENTER_BOTTOM (IconName): The :code:`align-box-center-bottom` icon. ALIGN_BOX_CENTER_MIDDLE (IconName): The :code:`align-box-center-middle` icon. ALIGN_BOX_CENTER_MIDDLE_FILLED (IconName): The :code:`align-box-center-middle-filled` icon. ALIGN_BOX_CENTER_STRETCH (IconName): The :code:`align-box-center-stretch` icon. ALIGN_BOX_CENTER_TOP (IconName): The :code:`align-box-center-top` icon. ALIGN_BOX_LEFT_BOTTOM (IconName): The :code:`align-box-left-bottom` icon. ALIGN_BOX_LEFT_BOTTOM_FILLED (IconName): The :code:`align-box-left-bottom-filled` icon. ALIGN_BOX_LEFT_MIDDLE (IconName): The :code:`align-box-left-middle` icon. ALIGN_BOX_LEFT_MIDDLE_FILLED (IconName): The :code:`align-box-left-middle-filled` icon. ALIGN_BOX_LEFT_STRETCH (IconName): The :code:`align-box-left-stretch` icon. ALIGN_BOX_LEFT_TOP (IconName): The :code:`align-box-left-top` icon. ALIGN_BOX_LEFT_TOP_FILLED (IconName): The :code:`align-box-left-top-filled` icon. ALIGN_BOX_RIGHT_BOTTOM (IconName): The :code:`align-box-right-bottom` icon. ALIGN_BOX_RIGHT_BOTTOM_FILLED (IconName): The :code:`align-box-right-bottom-filled` icon. ALIGN_BOX_RIGHT_MIDDLE (IconName): The :code:`align-box-right-middle` icon. ALIGN_BOX_RIGHT_MIDDLE_FILLED (IconName): The :code:`align-box-right-middle-filled` icon. ALIGN_BOX_RIGHT_STRETCH (IconName): The :code:`align-box-right-stretch` icon. ALIGN_BOX_RIGHT_TOP (IconName): The :code:`align-box-right-top` icon. ALIGN_BOX_RIGHT_TOP_FILLED (IconName): The :code:`align-box-right-top-filled` icon. ALIGN_BOX_TOP_CENTER (IconName): The :code:`align-box-top-center` icon. ALIGN_BOX_TOP_CENTER_FILLED (IconName): The :code:`align-box-top-center-filled` icon. ALIGN_BOX_TOP_LEFT (IconName): The :code:`align-box-top-left` icon. ALIGN_BOX_TOP_LEFT_FILLED (IconName): The :code:`align-box-top-left-filled` icon. ALIGN_BOX_TOP_RIGHT (IconName): The :code:`align-box-top-right` icon. ALIGN_BOX_TOP_RIGHT_FILLED (IconName): The :code:`align-box-top-right-filled` icon. ALIGN_CENTER (IconName): The :code:`align-center` icon. ALIGN_JUSTIFIED (IconName): The :code:`align-justified` icon. ALIGN_LEFT (IconName): The :code:`align-left` icon. ALIGN_RIGHT (IconName): The :code:`align-right` icon. ALPHA (IconName): The :code:`alpha` icon. ALPHABET_CYRILLIC (IconName): The :code:`alphabet-cyrillic` icon. ALPHABET_GREEK (IconName): The :code:`alphabet-greek` icon. ALPHABET_LATIN (IconName): The :code:`alphabet-latin` icon. AMBULANCE (IconName): The :code:`ambulance` icon. AMPERSAND (IconName): The :code:`ampersand` icon. ANALYZE (IconName): The :code:`analyze` icon. ANALYZE_FILLED (IconName): The :code:`analyze-filled` icon. ANALYZE_OFF (IconName): The :code:`analyze-off` icon. ANCHOR (IconName): The :code:`anchor` icon. ANCHOR_OFF (IconName): The :code:`anchor-off` icon. 
ANGLE (IconName): The :code:`angle` icon. ANKH (IconName): The :code:`ankh` icon. ANTENNA (IconName): The :code:`antenna` icon. ANTENNA_BARS_1 (IconName): The :code:`antenna-bars-1` icon. ANTENNA_BARS_2 (IconName): The :code:`antenna-bars-2` icon. ANTENNA_BARS_3 (IconName): The :code:`antenna-bars-3` icon. ANTENNA_BARS_4 (IconName): The :code:`antenna-bars-4` icon. ANTENNA_BARS_5 (IconName): The :code:`antenna-bars-5` icon. ANTENNA_BARS_OFF (IconName): The :code:`antenna-bars-off` icon. ANTENNA_OFF (IconName): The :code:`antenna-off` icon. APERTURE (IconName): The :code:`aperture` icon. APERTURE_OFF (IconName): The :code:`aperture-off` icon. API (IconName): The :code:`api` icon. API_APP (IconName): The :code:`api-app` icon. API_APP_OFF (IconName): The :code:`api-app-off` icon. API_OFF (IconName): The :code:`api-off` icon. APP_WINDOW (IconName): The :code:`app-window` icon. APP_WINDOW_FILLED (IconName): The :code:`app-window-filled` icon. APPLE (IconName): The :code:`apple` icon. APPS (IconName): The :code:`apps` icon. APPS_FILLED (IconName): The :code:`apps-filled` icon. APPS_OFF (IconName): The :code:`apps-off` icon. ARCHIVE (IconName): The :code:`archive` icon. ARCHIVE_FILLED (IconName): The :code:`archive-filled` icon. ARCHIVE_OFF (IconName): The :code:`archive-off` icon. ARMCHAIR (IconName): The :code:`armchair` icon. ARMCHAIR_2 (IconName): The :code:`armchair-2` icon. ARMCHAIR_2_OFF (IconName): The :code:`armchair-2-off` icon. ARMCHAIR_OFF (IconName): The :code:`armchair-off` icon. ARROW_AUTOFIT_CONTENT (IconName): The :code:`arrow-autofit-content` icon. ARROW_AUTOFIT_CONTENT_FILLED (IconName): The :code:`arrow-autofit-content-filled` icon. ARROW_AUTOFIT_DOWN (IconName): The :code:`arrow-autofit-down` icon. ARROW_AUTOFIT_HEIGHT (IconName): The :code:`arrow-autofit-height` icon. ARROW_AUTOFIT_LEFT (IconName): The :code:`arrow-autofit-left` icon. ARROW_AUTOFIT_RIGHT (IconName): The :code:`arrow-autofit-right` icon. ARROW_AUTOFIT_UP (IconName): The :code:`arrow-autofit-up` icon. ARROW_AUTOFIT_WIDTH (IconName): The :code:`arrow-autofit-width` icon. ARROW_BACK (IconName): The :code:`arrow-back` icon. ARROW_BACK_UP (IconName): The :code:`arrow-back-up` icon. ARROW_BACK_UP_DOUBLE (IconName): The :code:`arrow-back-up-double` icon. ARROW_BADGE_DOWN (IconName): The :code:`arrow-badge-down` icon. ARROW_BADGE_DOWN_FILLED (IconName): The :code:`arrow-badge-down-filled` icon. ARROW_BADGE_LEFT (IconName): The :code:`arrow-badge-left` icon. ARROW_BADGE_LEFT_FILLED (IconName): The :code:`arrow-badge-left-filled` icon. ARROW_BADGE_RIGHT (IconName): The :code:`arrow-badge-right` icon. ARROW_BADGE_RIGHT_FILLED (IconName): The :code:`arrow-badge-right-filled` icon. ARROW_BADGE_UP (IconName): The :code:`arrow-badge-up` icon. ARROW_BADGE_UP_FILLED (IconName): The :code:`arrow-badge-up-filled` icon. ARROW_BAR_BOTH (IconName): The :code:`arrow-bar-both` icon. ARROW_BAR_DOWN (IconName): The :code:`arrow-bar-down` icon. ARROW_BAR_LEFT (IconName): The :code:`arrow-bar-left` icon. ARROW_BAR_RIGHT (IconName): The :code:`arrow-bar-right` icon. ARROW_BAR_TO_DOWN (IconName): The :code:`arrow-bar-to-down` icon. ARROW_BAR_TO_LEFT (IconName): The :code:`arrow-bar-to-left` icon. ARROW_BAR_TO_RIGHT (IconName): The :code:`arrow-bar-to-right` icon. ARROW_BAR_TO_UP (IconName): The :code:`arrow-bar-to-up` icon. ARROW_BAR_UP (IconName): The :code:`arrow-bar-up` icon. ARROW_BEAR_LEFT (IconName): The :code:`arrow-bear-left` icon. ARROW_BEAR_LEFT_2 (IconName): The :code:`arrow-bear-left-2` icon. 
ARROW_BEAR_RIGHT (IconName): The :code:`arrow-bear-right` icon. ARROW_BEAR_RIGHT_2 (IconName): The :code:`arrow-bear-right-2` icon. ARROW_BIG_DOWN (IconName): The :code:`arrow-big-down` icon. ARROW_BIG_DOWN_FILLED (IconName): The :code:`arrow-big-down-filled` icon. ARROW_BIG_DOWN_LINE (IconName): The :code:`arrow-big-down-line` icon. ARROW_BIG_DOWN_LINE_FILLED (IconName): The :code:`arrow-big-down-line-filled` icon. ARROW_BIG_DOWN_LINES (IconName): The :code:`arrow-big-down-lines` icon. ARROW_BIG_DOWN_LINES_FILLED (IconName): The :code:`arrow-big-down-lines-filled` icon. ARROW_BIG_LEFT (IconName): The :code:`arrow-big-left` icon. ARROW_BIG_LEFT_FILLED (IconName): The :code:`arrow-big-left-filled` icon. ARROW_BIG_LEFT_LINE (IconName): The :code:`arrow-big-left-line` icon. ARROW_BIG_LEFT_LINE_FILLED (IconName): The :code:`arrow-big-left-line-filled` icon. ARROW_BIG_LEFT_LINES (IconName): The :code:`arrow-big-left-lines` icon. ARROW_BIG_LEFT_LINES_FILLED (IconName): The :code:`arrow-big-left-lines-filled` icon. ARROW_BIG_RIGHT (IconName): The :code:`arrow-big-right` icon. ARROW_BIG_RIGHT_FILLED (IconName): The :code:`arrow-big-right-filled` icon. ARROW_BIG_RIGHT_LINE (IconName): The :code:`arrow-big-right-line` icon. ARROW_BIG_RIGHT_LINE_FILLED (IconName): The :code:`arrow-big-right-line-filled` icon. ARROW_BIG_RIGHT_LINES (IconName): The :code:`arrow-big-right-lines` icon. ARROW_BIG_RIGHT_LINES_FILLED (IconName): The :code:`arrow-big-right-lines-filled` icon. ARROW_BIG_UP (IconName): The :code:`arrow-big-up` icon. ARROW_BIG_UP_FILLED (IconName): The :code:`arrow-big-up-filled` icon. ARROW_BIG_UP_LINE (IconName): The :code:`arrow-big-up-line` icon. ARROW_BIG_UP_LINE_FILLED (IconName): The :code:`arrow-big-up-line-filled` icon. ARROW_BIG_UP_LINES (IconName): The :code:`arrow-big-up-lines` icon. ARROW_BIG_UP_LINES_FILLED (IconName): The :code:`arrow-big-up-lines-filled` icon. ARROW_BOUNCE (IconName): The :code:`arrow-bounce` icon. ARROW_CAPSULE (IconName): The :code:`arrow-capsule` icon. ARROW_CURVE_LEFT (IconName): The :code:`arrow-curve-left` icon. ARROW_CURVE_RIGHT (IconName): The :code:`arrow-curve-right` icon. ARROW_DOWN (IconName): The :code:`arrow-down` icon. ARROW_DOWN_BAR (IconName): The :code:`arrow-down-bar` icon. ARROW_DOWN_CIRCLE (IconName): The :code:`arrow-down-circle` icon. ARROW_DOWN_LEFT (IconName): The :code:`arrow-down-left` icon. ARROW_DOWN_LEFT_CIRCLE (IconName): The :code:`arrow-down-left-circle` icon. ARROW_DOWN_RHOMBUS (IconName): The :code:`arrow-down-rhombus` icon. ARROW_DOWN_RIGHT (IconName): The :code:`arrow-down-right` icon. ARROW_DOWN_RIGHT_CIRCLE (IconName): The :code:`arrow-down-right-circle` icon. ARROW_DOWN_SQUARE (IconName): The :code:`arrow-down-square` icon. ARROW_DOWN_TAIL (IconName): The :code:`arrow-down-tail` icon. ARROW_ELBOW_LEFT (IconName): The :code:`arrow-elbow-left` icon. ARROW_ELBOW_RIGHT (IconName): The :code:`arrow-elbow-right` icon. ARROW_FORK (IconName): The :code:`arrow-fork` icon. ARROW_FORWARD (IconName): The :code:`arrow-forward` icon. ARROW_FORWARD_UP (IconName): The :code:`arrow-forward-up` icon. ARROW_FORWARD_UP_DOUBLE (IconName): The :code:`arrow-forward-up-double` icon. ARROW_GUIDE (IconName): The :code:`arrow-guide` icon. ARROW_ITERATION (IconName): The :code:`arrow-iteration` icon. ARROW_LEFT (IconName): The :code:`arrow-left` icon. ARROW_LEFT_BAR (IconName): The :code:`arrow-left-bar` icon. ARROW_LEFT_CIRCLE (IconName): The :code:`arrow-left-circle` icon. ARROW_LEFT_RHOMBUS (IconName): The :code:`arrow-left-rhombus` icon. 
ARROW_LEFT_RIGHT (IconName): The :code:`arrow-left-right` icon. ARROW_LEFT_SQUARE (IconName): The :code:`arrow-left-square` icon. ARROW_LEFT_TAIL (IconName): The :code:`arrow-left-tail` icon. ARROW_LOOP_LEFT (IconName): The :code:`arrow-loop-left` icon. ARROW_LOOP_LEFT_2 (IconName): The :code:`arrow-loop-left-2` icon. ARROW_LOOP_RIGHT (IconName): The :code:`arrow-loop-right` icon. ARROW_LOOP_RIGHT_2 (IconName): The :code:`arrow-loop-right-2` icon. ARROW_MERGE (IconName): The :code:`arrow-merge` icon. ARROW_MERGE_BOTH (IconName): The :code:`arrow-merge-both` icon. ARROW_MERGE_LEFT (IconName): The :code:`arrow-merge-left` icon. ARROW_MERGE_RIGHT (IconName): The :code:`arrow-merge-right` icon. ARROW_MOVE_DOWN (IconName): The :code:`arrow-move-down` icon. ARROW_MOVE_LEFT (IconName): The :code:`arrow-move-left` icon. ARROW_MOVE_RIGHT (IconName): The :code:`arrow-move-right` icon. ARROW_MOVE_UP (IconName): The :code:`arrow-move-up` icon. ARROW_NARROW_DOWN (IconName): The :code:`arrow-narrow-down` icon. ARROW_NARROW_LEFT (IconName): The :code:`arrow-narrow-left` icon. ARROW_NARROW_RIGHT (IconName): The :code:`arrow-narrow-right` icon. ARROW_NARROW_UP (IconName): The :code:`arrow-narrow-up` icon. ARROW_RAMP_LEFT (IconName): The :code:`arrow-ramp-left` icon. ARROW_RAMP_LEFT_2 (IconName): The :code:`arrow-ramp-left-2` icon. ARROW_RAMP_LEFT_3 (IconName): The :code:`arrow-ramp-left-3` icon. ARROW_RAMP_RIGHT (IconName): The :code:`arrow-ramp-right` icon. ARROW_RAMP_RIGHT_2 (IconName): The :code:`arrow-ramp-right-2` icon. ARROW_RAMP_RIGHT_3 (IconName): The :code:`arrow-ramp-right-3` icon. ARROW_RIGHT (IconName): The :code:`arrow-right` icon. ARROW_RIGHT_BAR (IconName): The :code:`arrow-right-bar` icon. ARROW_RIGHT_CIRCLE (IconName): The :code:`arrow-right-circle` icon. ARROW_RIGHT_RHOMBUS (IconName): The :code:`arrow-right-rhombus` icon. ARROW_RIGHT_SQUARE (IconName): The :code:`arrow-right-square` icon. ARROW_RIGHT_TAIL (IconName): The :code:`arrow-right-tail` icon. ARROW_ROTARY_FIRST_LEFT (IconName): The :code:`arrow-rotary-first-left` icon. ARROW_ROTARY_FIRST_RIGHT (IconName): The :code:`arrow-rotary-first-right` icon. ARROW_ROTARY_LAST_LEFT (IconName): The :code:`arrow-rotary-last-left` icon. ARROW_ROTARY_LAST_RIGHT (IconName): The :code:`arrow-rotary-last-right` icon. ARROW_ROTARY_LEFT (IconName): The :code:`arrow-rotary-left` icon. ARROW_ROTARY_RIGHT (IconName): The :code:`arrow-rotary-right` icon. ARROW_ROTARY_STRAIGHT (IconName): The :code:`arrow-rotary-straight` icon. ARROW_ROUNDABOUT_LEFT (IconName): The :code:`arrow-roundabout-left` icon. ARROW_ROUNDABOUT_RIGHT (IconName): The :code:`arrow-roundabout-right` icon. ARROW_SHARP_TURN_LEFT (IconName): The :code:`arrow-sharp-turn-left` icon. ARROW_SHARP_TURN_RIGHT (IconName): The :code:`arrow-sharp-turn-right` icon. ARROW_UP (IconName): The :code:`arrow-up` icon. ARROW_UP_BAR (IconName): The :code:`arrow-up-bar` icon. ARROW_UP_CIRCLE (IconName): The :code:`arrow-up-circle` icon. ARROW_UP_LEFT (IconName): The :code:`arrow-up-left` icon. ARROW_UP_LEFT_CIRCLE (IconName): The :code:`arrow-up-left-circle` icon. ARROW_UP_RHOMBUS (IconName): The :code:`arrow-up-rhombus` icon. ARROW_UP_RIGHT (IconName): The :code:`arrow-up-right` icon. ARROW_UP_RIGHT_CIRCLE (IconName): The :code:`arrow-up-right-circle` icon. ARROW_UP_SQUARE (IconName): The :code:`arrow-up-square` icon. ARROW_UP_TAIL (IconName): The :code:`arrow-up-tail` icon. ARROW_WAVE_LEFT_DOWN (IconName): The :code:`arrow-wave-left-down` icon. 
ARROW_WAVE_LEFT_UP (IconName): The :code:`arrow-wave-left-up` icon. ARROW_WAVE_RIGHT_DOWN (IconName): The :code:`arrow-wave-right-down` icon. ARROW_WAVE_RIGHT_UP (IconName): The :code:`arrow-wave-right-up` icon. ARROW_ZIG_ZAG (IconName): The :code:`arrow-zig-zag` icon. ARROWS_CROSS (IconName): The :code:`arrows-cross` icon. ARROWS_DIAGONAL (IconName): The :code:`arrows-diagonal` icon. ARROWS_DIAGONAL_2 (IconName): The :code:`arrows-diagonal-2` icon. ARROWS_DIAGONAL_MINIMIZE (IconName): The :code:`arrows-diagonal-minimize` icon. ARROWS_DIAGONAL_MINIMIZE_2 (IconName): The :code:`arrows-diagonal-minimize-2` icon. ARROWS_DIFF (IconName): The :code:`arrows-diff` icon. ARROWS_DOUBLE_NE_SW (IconName): The :code:`arrows-double-ne-sw` icon. ARROWS_DOUBLE_NW_SE (IconName): The :code:`arrows-double-nw-se` icon. ARROWS_DOUBLE_SE_NW (IconName): The :code:`arrows-double-se-nw` icon. ARROWS_DOUBLE_SW_NE (IconName): The :code:`arrows-double-sw-ne` icon. ARROWS_DOWN (IconName): The :code:`arrows-down` icon. ARROWS_DOWN_UP (IconName): The :code:`arrows-down-up` icon. ARROWS_EXCHANGE (IconName): The :code:`arrows-exchange` icon. ARROWS_EXCHANGE_2 (IconName): The :code:`arrows-exchange-2` icon. ARROWS_HORIZONTAL (IconName): The :code:`arrows-horizontal` icon. ARROWS_JOIN (IconName): The :code:`arrows-join` icon. ARROWS_JOIN_2 (IconName): The :code:`arrows-join-2` icon. ARROWS_LEFT (IconName): The :code:`arrows-left` icon. ARROWS_LEFT_DOWN (IconName): The :code:`arrows-left-down` icon. ARROWS_LEFT_RIGHT (IconName): The :code:`arrows-left-right` icon. ARROWS_MAXIMIZE (IconName): The :code:`arrows-maximize` icon. ARROWS_MINIMIZE (IconName): The :code:`arrows-minimize` icon. ARROWS_MOVE (IconName): The :code:`arrows-move` icon. ARROWS_MOVE_HORIZONTAL (IconName): The :code:`arrows-move-horizontal` icon. ARROWS_MOVE_VERTICAL (IconName): The :code:`arrows-move-vertical` icon. ARROWS_RANDOM (IconName): The :code:`arrows-random` icon. ARROWS_RIGHT (IconName): The :code:`arrows-right` icon. ARROWS_RIGHT_DOWN (IconName): The :code:`arrows-right-down` icon. ARROWS_RIGHT_LEFT (IconName): The :code:`arrows-right-left` icon. ARROWS_SHUFFLE (IconName): The :code:`arrows-shuffle` icon. ARROWS_SHUFFLE_2 (IconName): The :code:`arrows-shuffle-2` icon. ARROWS_SORT (IconName): The :code:`arrows-sort` icon. ARROWS_SPLIT (IconName): The :code:`arrows-split` icon. ARROWS_SPLIT_2 (IconName): The :code:`arrows-split-2` icon. ARROWS_TRANSFER_DOWN (IconName): The :code:`arrows-transfer-down` icon. ARROWS_TRANSFER_UP (IconName): The :code:`arrows-transfer-up` icon. ARROWS_UP (IconName): The :code:`arrows-up` icon. ARROWS_UP_DOWN (IconName): The :code:`arrows-up-down` icon. ARROWS_UP_LEFT (IconName): The :code:`arrows-up-left` icon. ARROWS_UP_RIGHT (IconName): The :code:`arrows-up-right` icon. ARROWS_VERTICAL (IconName): The :code:`arrows-vertical` icon. ARTBOARD (IconName): The :code:`artboard` icon. ARTBOARD_FILLED (IconName): The :code:`artboard-filled` icon. ARTBOARD_OFF (IconName): The :code:`artboard-off` icon. ARTICLE (IconName): The :code:`article` icon. ARTICLE_FILLED_FILLED (IconName): The :code:`article-filled-filled` icon. ARTICLE_OFF (IconName): The :code:`article-off` icon. ASPECT_RATIO (IconName): The :code:`aspect-ratio` icon. ASPECT_RATIO_FILLED (IconName): The :code:`aspect-ratio-filled` icon. ASPECT_RATIO_OFF (IconName): The :code:`aspect-ratio-off` icon. ASSEMBLY (IconName): The :code:`assembly` icon. ASSEMBLY_OFF (IconName): The :code:`assembly-off` icon. ASSET (IconName): The :code:`asset` icon. 
ASTERISK (IconName): The :code:`asterisk` icon. ASTERISK_SIMPLE (IconName): The :code:`asterisk-simple` icon. AT (IconName): The :code:`at` icon. AT_OFF (IconName): The :code:`at-off` icon. ATOM (IconName): The :code:`atom` icon. ATOM_2 (IconName): The :code:`atom-2` icon. ATOM_2_FILLED (IconName): The :code:`atom-2-filled` icon. ATOM_OFF (IconName): The :code:`atom-off` icon. AUGMENTED_REALITY (IconName): The :code:`augmented-reality` icon. AUGMENTED_REALITY_2 (IconName): The :code:`augmented-reality-2` icon. AUGMENTED_REALITY_OFF (IconName): The :code:`augmented-reality-off` icon. AWARD (IconName): The :code:`award` icon. AWARD_FILLED (IconName): The :code:`award-filled` icon. AWARD_OFF (IconName): The :code:`award-off` icon. AXE (IconName): The :code:`axe` icon. AXIS_X (IconName): The :code:`axis-x` icon. AXIS_Y (IconName): The :code:`axis-y` icon. BABY_BOTTLE (IconName): The :code:`baby-bottle` icon. BABY_CARRIAGE (IconName): The :code:`baby-carriage` icon. BACKHOE (IconName): The :code:`backhoe` icon. BACKPACK (IconName): The :code:`backpack` icon. BACKPACK_OFF (IconName): The :code:`backpack-off` icon. BACKSLASH (IconName): The :code:`backslash` icon. BACKSPACE (IconName): The :code:`backspace` icon. BACKSPACE_FILLED (IconName): The :code:`backspace-filled` icon. BADGE (IconName): The :code:`badge` icon. BADGE_3D (IconName): The :code:`badge-3d` icon. BADGE_4K (IconName): The :code:`badge-4k` icon. BADGE_8K (IconName): The :code:`badge-8k` icon. BADGE_AD (IconName): The :code:`badge-ad` icon. BADGE_AR (IconName): The :code:`badge-ar` icon. BADGE_CC (IconName): The :code:`badge-cc` icon. BADGE_FILLED (IconName): The :code:`badge-filled` icon. BADGE_HD (IconName): The :code:`badge-hd` icon. BADGE_OFF (IconName): The :code:`badge-off` icon. BADGE_SD (IconName): The :code:`badge-sd` icon. BADGE_TM (IconName): The :code:`badge-tm` icon. BADGE_VO (IconName): The :code:`badge-vo` icon. BADGE_VR (IconName): The :code:`badge-vr` icon. BADGE_WC (IconName): The :code:`badge-wc` icon. BADGES (IconName): The :code:`badges` icon. BADGES_FILLED (IconName): The :code:`badges-filled` icon. BADGES_OFF (IconName): The :code:`badges-off` icon. BAGUETTE (IconName): The :code:`baguette` icon. BALL_AMERICAN_FOOTBALL (IconName): The :code:`ball-american-football` icon. BALL_AMERICAN_FOOTBALL_OFF (IconName): The :code:`ball-american-football-off` icon. BALL_BASEBALL (IconName): The :code:`ball-baseball` icon. BALL_BASKETBALL (IconName): The :code:`ball-basketball` icon. BALL_BOWLING (IconName): The :code:`ball-bowling` icon. BALL_FOOTBALL (IconName): The :code:`ball-football` icon. BALL_FOOTBALL_OFF (IconName): The :code:`ball-football-off` icon. BALL_TENNIS (IconName): The :code:`ball-tennis` icon. BALL_VOLLEYBALL (IconName): The :code:`ball-volleyball` icon. BALLOON (IconName): The :code:`balloon` icon. BALLOON_FILLED (IconName): The :code:`balloon-filled` icon. BALLOON_OFF (IconName): The :code:`balloon-off` icon. BALLPEN (IconName): The :code:`ballpen` icon. BALLPEN_FILLED (IconName): The :code:`ballpen-filled` icon. BALLPEN_OFF (IconName): The :code:`ballpen-off` icon. BAN (IconName): The :code:`ban` icon. BANDAGE (IconName): The :code:`bandage` icon. BANDAGE_FILLED (IconName): The :code:`bandage-filled` icon. BANDAGE_OFF (IconName): The :code:`bandage-off` icon. BARBELL (IconName): The :code:`barbell` icon. BARBELL_OFF (IconName): The :code:`barbell-off` icon. BARCODE (IconName): The :code:`barcode` icon. BARCODE_OFF (IconName): The :code:`barcode-off` icon. BARREL (IconName): The :code:`barrel` icon. 
BARREL_OFF (IconName): The :code:`barrel-off` icon. BARRIER_BLOCK (IconName): The :code:`barrier-block` icon. BARRIER_BLOCK_OFF (IconName): The :code:`barrier-block-off` icon. BASELINE (IconName): The :code:`baseline` icon. BASELINE_DENSITY_LARGE (IconName): The :code:`baseline-density-large` icon. BASELINE_DENSITY_MEDIUM (IconName): The :code:`baseline-density-medium` icon. BASELINE_DENSITY_SMALL (IconName): The :code:`baseline-density-small` icon. BASKET (IconName): The :code:`basket` icon. BASKET_FILLED (IconName): The :code:`basket-filled` icon. BASKET_OFF (IconName): The :code:`basket-off` icon. BAT (IconName): The :code:`bat` icon. BATH (IconName): The :code:`bath` icon. BATH_FILLED (IconName): The :code:`bath-filled` icon. BATH_OFF (IconName): The :code:`bath-off` icon. BATTERY (IconName): The :code:`battery` icon. BATTERY_1 (IconName): The :code:`battery-1` icon. BATTERY_1_FILLED (IconName): The :code:`battery-1-filled` icon. BATTERY_2 (IconName): The :code:`battery-2` icon. BATTERY_2_FILLED (IconName): The :code:`battery-2-filled` icon. BATTERY_3 (IconName): The :code:`battery-3` icon. BATTERY_3_FILLED (IconName): The :code:`battery-3-filled` icon. BATTERY_4 (IconName): The :code:`battery-4` icon. BATTERY_4_FILLED (IconName): The :code:`battery-4-filled` icon. BATTERY_AUTOMOTIVE (IconName): The :code:`battery-automotive` icon. BATTERY_CHARGING (IconName): The :code:`battery-charging` icon. BATTERY_CHARGING_2 (IconName): The :code:`battery-charging-2` icon. BATTERY_ECO (IconName): The :code:`battery-eco` icon. BATTERY_FILLED (IconName): The :code:`battery-filled` icon. BATTERY_OFF (IconName): The :code:`battery-off` icon. BEACH (IconName): The :code:`beach` icon. BEACH_OFF (IconName): The :code:`beach-off` icon. BED (IconName): The :code:`bed` icon. BED_FILLED (IconName): The :code:`bed-filled` icon. BED_OFF (IconName): The :code:`bed-off` icon. BEER (IconName): The :code:`beer` icon. BEER_FILLED (IconName): The :code:`beer-filled` icon. BEER_OFF (IconName): The :code:`beer-off` icon. BELL (IconName): The :code:`bell` icon. BELL_BOLT (IconName): The :code:`bell-bolt` icon. BELL_CANCEL (IconName): The :code:`bell-cancel` icon. BELL_CHECK (IconName): The :code:`bell-check` icon. BELL_CODE (IconName): The :code:`bell-code` icon. BELL_COG (IconName): The :code:`bell-cog` icon. BELL_DOLLAR (IconName): The :code:`bell-dollar` icon. BELL_DOWN (IconName): The :code:`bell-down` icon. BELL_EXCLAMATION (IconName): The :code:`bell-exclamation` icon. BELL_FILLED (IconName): The :code:`bell-filled` icon. BELL_HEART (IconName): The :code:`bell-heart` icon. BELL_MINUS (IconName): The :code:`bell-minus` icon. BELL_MINUS_FILLED (IconName): The :code:`bell-minus-filled` icon. BELL_OFF (IconName): The :code:`bell-off` icon. BELL_PAUSE (IconName): The :code:`bell-pause` icon. BELL_PIN (IconName): The :code:`bell-pin` icon. BELL_PLUS (IconName): The :code:`bell-plus` icon. BELL_PLUS_FILLED (IconName): The :code:`bell-plus-filled` icon. BELL_QUESTION (IconName): The :code:`bell-question` icon. BELL_RINGING (IconName): The :code:`bell-ringing` icon. BELL_RINGING_2 (IconName): The :code:`bell-ringing-2` icon. BELL_RINGING_2_FILLED (IconName): The :code:`bell-ringing-2-filled` icon. BELL_RINGING_FILLED (IconName): The :code:`bell-ringing-filled` icon. BELL_SCHOOL (IconName): The :code:`bell-school` icon. BELL_SEARCH (IconName): The :code:`bell-search` icon. BELL_SHARE (IconName): The :code:`bell-share` icon. BELL_STAR (IconName): The :code:`bell-star` icon. BELL_UP (IconName): The :code:`bell-up` icon. 
BELL_X (IconName): The :code:`bell-x` icon. BELL_X_FILLED (IconName): The :code:`bell-x-filled` icon. BELL_Z (IconName): The :code:`bell-z` icon. BELL_Z_FILLED (IconName): The :code:`bell-z-filled` icon. BETA (IconName): The :code:`beta` icon. BIBLE (IconName): The :code:`bible` icon. BIKE (IconName): The :code:`bike` icon. BIKE_OFF (IconName): The :code:`bike-off` icon. BINARY (IconName): The :code:`binary` icon. BINARY_OFF (IconName): The :code:`binary-off` icon. BINARY_TREE (IconName): The :code:`binary-tree` icon. BINARY_TREE_2 (IconName): The :code:`binary-tree-2` icon. BIOHAZARD (IconName): The :code:`biohazard` icon. BIOHAZARD_OFF (IconName): The :code:`biohazard-off` icon. BLADE (IconName): The :code:`blade` icon. BLADE_FILLED (IconName): The :code:`blade-filled` icon. BLEACH (IconName): The :code:`bleach` icon. BLEACH_CHLORINE (IconName): The :code:`bleach-chlorine` icon. BLEACH_NO_CHLORINE (IconName): The :code:`bleach-no-chlorine` icon. BLEACH_OFF (IconName): The :code:`bleach-off` icon. BLOCKQUOTE (IconName): The :code:`blockquote` icon. BLUETOOTH (IconName): The :code:`bluetooth` icon. BLUETOOTH_CONNECTED (IconName): The :code:`bluetooth-connected` icon. BLUETOOTH_OFF (IconName): The :code:`bluetooth-off` icon. BLUETOOTH_X (IconName): The :code:`bluetooth-x` icon. BLUR (IconName): The :code:`blur` icon. BLUR_OFF (IconName): The :code:`blur-off` icon. BMP (IconName): The :code:`bmp` icon. BOLD (IconName): The :code:`bold` icon. BOLD_OFF (IconName): The :code:`bold-off` icon. BOLT (IconName): The :code:`bolt` icon. BOLT_OFF (IconName): The :code:`bolt-off` icon. BOMB (IconName): The :code:`bomb` icon. BOMB_FILLED (IconName): The :code:`bomb-filled` icon. BONE (IconName): The :code:`bone` icon. BONE_OFF (IconName): The :code:`bone-off` icon. BONG (IconName): The :code:`bong` icon. BONG_OFF (IconName): The :code:`bong-off` icon. BOOK (IconName): The :code:`book` icon. BOOK_2 (IconName): The :code:`book-2` icon. BOOK_DOWNLOAD (IconName): The :code:`book-download` icon. BOOK_FILLED (IconName): The :code:`book-filled` icon. BOOK_OFF (IconName): The :code:`book-off` icon. BOOK_UPLOAD (IconName): The :code:`book-upload` icon. BOOKMARK (IconName): The :code:`bookmark` icon. BOOKMARK_EDIT (IconName): The :code:`bookmark-edit` icon. BOOKMARK_FILLED (IconName): The :code:`bookmark-filled` icon. BOOKMARK_MINUS (IconName): The :code:`bookmark-minus` icon. BOOKMARK_OFF (IconName): The :code:`bookmark-off` icon. BOOKMARK_PLUS (IconName): The :code:`bookmark-plus` icon. BOOKMARK_QUESTION (IconName): The :code:`bookmark-question` icon. BOOKMARKS (IconName): The :code:`bookmarks` icon. BOOKMARKS_OFF (IconName): The :code:`bookmarks-off` icon. BOOKS (IconName): The :code:`books` icon. BOOKS_OFF (IconName): The :code:`books-off` icon. BORDER_ALL (IconName): The :code:`border-all` icon. BORDER_BOTTOM (IconName): The :code:`border-bottom` icon. BORDER_CORNERS (IconName): The :code:`border-corners` icon. BORDER_HORIZONTAL (IconName): The :code:`border-horizontal` icon. BORDER_INNER (IconName): The :code:`border-inner` icon. BORDER_LEFT (IconName): The :code:`border-left` icon. BORDER_NONE (IconName): The :code:`border-none` icon. BORDER_OUTER (IconName): The :code:`border-outer` icon. BORDER_RADIUS (IconName): The :code:`border-radius` icon. BORDER_RIGHT (IconName): The :code:`border-right` icon. BORDER_SIDES (IconName): The :code:`border-sides` icon. BORDER_STYLE (IconName): The :code:`border-style` icon. BORDER_STYLE_2 (IconName): The :code:`border-style-2` icon. 
BORDER_TOP (IconName): The :code:`border-top` icon. BORDER_VERTICAL (IconName): The :code:`border-vertical` icon. BOTTLE (IconName): The :code:`bottle` icon. BOTTLE_FILLED (IconName): The :code:`bottle-filled` icon. BOTTLE_OFF (IconName): The :code:`bottle-off` icon. BOUNCE_LEFT (IconName): The :code:`bounce-left` icon. BOUNCE_RIGHT (IconName): The :code:`bounce-right` icon. BOW (IconName): The :code:`bow` icon. BOWL (IconName): The :code:`bowl` icon. BOX (IconName): The :code:`box` icon. BOX_ALIGN_BOTTOM (IconName): The :code:`box-align-bottom` icon. BOX_ALIGN_BOTTOM_FILLED (IconName): The :code:`box-align-bottom-filled` icon. BOX_ALIGN_BOTTOM_LEFT (IconName): The :code:`box-align-bottom-left` icon. BOX_ALIGN_BOTTOM_LEFT_FILLED (IconName): The :code:`box-align-bottom-left-filled` icon. BOX_ALIGN_BOTTOM_RIGHT (IconName): The :code:`box-align-bottom-right` icon. BOX_ALIGN_BOTTOM_RIGHT_FILLED (IconName): The :code:`box-align-bottom-right-filled` icon. BOX_ALIGN_LEFT (IconName): The :code:`box-align-left` icon. BOX_ALIGN_LEFT_FILLED (IconName): The :code:`box-align-left-filled` icon. BOX_ALIGN_RIGHT (IconName): The :code:`box-align-right` icon. BOX_ALIGN_RIGHT_FILLED (IconName): The :code:`box-align-right-filled` icon. BOX_ALIGN_TOP (IconName): The :code:`box-align-top` icon. BOX_ALIGN_TOP_FILLED (IconName): The :code:`box-align-top-filled` icon. BOX_ALIGN_TOP_LEFT (IconName): The :code:`box-align-top-left` icon. BOX_ALIGN_TOP_LEFT_FILLED (IconName): The :code:`box-align-top-left-filled` icon. BOX_ALIGN_TOP_RIGHT (IconName): The :code:`box-align-top-right` icon. BOX_ALIGN_TOP_RIGHT_FILLED (IconName): The :code:`box-align-top-right-filled` icon. BOX_MARGIN (IconName): The :code:`box-margin` icon. BOX_MODEL (IconName): The :code:`box-model` icon. BOX_MODEL_2 (IconName): The :code:`box-model-2` icon. BOX_MODEL_2_OFF (IconName): The :code:`box-model-2-off` icon. BOX_MODEL_OFF (IconName): The :code:`box-model-off` icon. BOX_MULTIPLE (IconName): The :code:`box-multiple` icon. BOX_MULTIPLE_0 (IconName): The :code:`box-multiple-0` icon. BOX_MULTIPLE_1 (IconName): The :code:`box-multiple-1` icon. BOX_MULTIPLE_2 (IconName): The :code:`box-multiple-2` icon. BOX_MULTIPLE_3 (IconName): The :code:`box-multiple-3` icon. BOX_MULTIPLE_4 (IconName): The :code:`box-multiple-4` icon. BOX_MULTIPLE_5 (IconName): The :code:`box-multiple-5` icon. BOX_MULTIPLE_6 (IconName): The :code:`box-multiple-6` icon. BOX_MULTIPLE_7 (IconName): The :code:`box-multiple-7` icon. BOX_MULTIPLE_8 (IconName): The :code:`box-multiple-8` icon. BOX_MULTIPLE_9 (IconName): The :code:`box-multiple-9` icon. BOX_OFF (IconName): The :code:`box-off` icon. BOX_PADDING (IconName): The :code:`box-padding` icon. BOX_SEAM (IconName): The :code:`box-seam` icon. BRACES (IconName): The :code:`braces` icon. BRACES_OFF (IconName): The :code:`braces-off` icon. BRACKETS (IconName): The :code:`brackets` icon. BRACKETS_CONTAIN (IconName): The :code:`brackets-contain` icon. BRACKETS_CONTAIN_END (IconName): The :code:`brackets-contain-end` icon. BRACKETS_CONTAIN_START (IconName): The :code:`brackets-contain-start` icon. BRACKETS_OFF (IconName): The :code:`brackets-off` icon. BRAILLE (IconName): The :code:`braille` icon. BRAIN (IconName): The :code:`brain` icon. BRAND_4CHAN (IconName): The :code:`brand-4chan` icon. BRAND_ABSTRACT (IconName): The :code:`brand-abstract` icon. BRAND_ADOBE (IconName): The :code:`brand-adobe` icon. BRAND_ADONIS_JS (IconName): The :code:`brand-adonis-js` icon. BRAND_AIRBNB (IconName): The :code:`brand-airbnb` icon. 
BRAND_AIRTABLE (IconName): The :code:`brand-airtable` icon. BRAND_ALGOLIA (IconName): The :code:`brand-algolia` icon. BRAND_ALIPAY (IconName): The :code:`brand-alipay` icon. BRAND_ALPINE_JS (IconName): The :code:`brand-alpine-js` icon. BRAND_AMAZON (IconName): The :code:`brand-amazon` icon. BRAND_AMD (IconName): The :code:`brand-amd` icon. BRAND_AMIGO (IconName): The :code:`brand-amigo` icon. BRAND_AMONG_US (IconName): The :code:`brand-among-us` icon. BRAND_ANDROID (IconName): The :code:`brand-android` icon. BRAND_ANGULAR (IconName): The :code:`brand-angular` icon. BRAND_ANSIBLE (IconName): The :code:`brand-ansible` icon. BRAND_AO3 (IconName): The :code:`brand-ao3` icon. BRAND_APPGALLERY (IconName): The :code:`brand-appgallery` icon. BRAND_APPLE (IconName): The :code:`brand-apple` icon. BRAND_APPLE_ARCADE (IconName): The :code:`brand-apple-arcade` icon. BRAND_APPLE_PODCAST (IconName): The :code:`brand-apple-podcast` icon. BRAND_APPSTORE (IconName): The :code:`brand-appstore` icon. BRAND_ASANA (IconName): The :code:`brand-asana` icon. BRAND_AWS (IconName): The :code:`brand-aws` icon. BRAND_AZURE (IconName): The :code:`brand-azure` icon. BRAND_BACKBONE (IconName): The :code:`brand-backbone` icon. BRAND_BADOO (IconName): The :code:`brand-badoo` icon. BRAND_BAIDU (IconName): The :code:`brand-baidu` icon. BRAND_BANDCAMP (IconName): The :code:`brand-bandcamp` icon. BRAND_BANDLAB (IconName): The :code:`brand-bandlab` icon. BRAND_BEATS (IconName): The :code:`brand-beats` icon. BRAND_BEHANCE (IconName): The :code:`brand-behance` icon. BRAND_BILIBILI (IconName): The :code:`brand-bilibili` icon. BRAND_BINANCE (IconName): The :code:`brand-binance` icon. BRAND_BING (IconName): The :code:`brand-bing` icon. BRAND_BITBUCKET (IconName): The :code:`brand-bitbucket` icon. BRAND_BLACKBERRY (IconName): The :code:`brand-blackberry` icon. BRAND_BLENDER (IconName): The :code:`brand-blender` icon. BRAND_BLOGGER (IconName): The :code:`brand-blogger` icon. BRAND_BOOKING (IconName): The :code:`brand-booking` icon. BRAND_BOOTSTRAP (IconName): The :code:`brand-bootstrap` icon. BRAND_BULMA (IconName): The :code:`brand-bulma` icon. BRAND_BUMBLE (IconName): The :code:`brand-bumble` icon. BRAND_BUNPO (IconName): The :code:`brand-bunpo` icon. BRAND_C_SHARP (IconName): The :code:`brand-c-sharp` icon. BRAND_CAKE (IconName): The :code:`brand-cake` icon. BRAND_CAKEPHP (IconName): The :code:`brand-cakephp` icon. BRAND_CAMPAIGNMONITOR (IconName): The :code:`brand-campaignmonitor` icon. BRAND_CARBON (IconName): The :code:`brand-carbon` icon. BRAND_CASHAPP (IconName): The :code:`brand-cashapp` icon. BRAND_CHROME (IconName): The :code:`brand-chrome` icon. BRAND_CINEMA_4D (IconName): The :code:`brand-cinema-4d` icon. BRAND_CITYMAPPER (IconName): The :code:`brand-citymapper` icon. BRAND_CLOUDFLARE (IconName): The :code:`brand-cloudflare` icon. BRAND_CODECOV (IconName): The :code:`brand-codecov` icon. BRAND_CODEPEN (IconName): The :code:`brand-codepen` icon. BRAND_CODESANDBOX (IconName): The :code:`brand-codesandbox` icon. BRAND_COHOST (IconName): The :code:`brand-cohost` icon. BRAND_COINBASE (IconName): The :code:`brand-coinbase` icon. BRAND_COMEDY_CENTRAL (IconName): The :code:`brand-comedy-central` icon. BRAND_COREOS (IconName): The :code:`brand-coreos` icon. BRAND_COUCHDB (IconName): The :code:`brand-couchdb` icon. BRAND_COUCHSURFING (IconName): The :code:`brand-couchsurfing` icon. BRAND_CPP (IconName): The :code:`brand-cpp` icon. BRAND_CRAFT (IconName): The :code:`brand-craft` icon. 
BRAND_CRUNCHBASE (IconName): The :code:`brand-crunchbase` icon. BRAND_CSS3 (IconName): The :code:`brand-css3` icon. BRAND_CTEMPLAR (IconName): The :code:`brand-ctemplar` icon. BRAND_CUCUMBER (IconName): The :code:`brand-cucumber` icon. BRAND_CUPRA (IconName): The :code:`brand-cupra` icon. BRAND_CYPRESS (IconName): The :code:`brand-cypress` icon. BRAND_D3 (IconName): The :code:`brand-d3` icon. BRAND_DAYS_COUNTER (IconName): The :code:`brand-days-counter` icon. BRAND_DCOS (IconName): The :code:`brand-dcos` icon. BRAND_DEBIAN (IconName): The :code:`brand-debian` icon. BRAND_DEEZER (IconName): The :code:`brand-deezer` icon. BRAND_DELIVEROO (IconName): The :code:`brand-deliveroo` icon. BRAND_DENO (IconName): The :code:`brand-deno` icon. BRAND_DENODO (IconName): The :code:`brand-denodo` icon. BRAND_DEVIANTART (IconName): The :code:`brand-deviantart` icon. BRAND_DIGG (IconName): The :code:`brand-digg` icon. BRAND_DINGTALK (IconName): The :code:`brand-dingtalk` icon. BRAND_DISCORD (IconName): The :code:`brand-discord` icon. BRAND_DISCORD_FILLED (IconName): The :code:`brand-discord-filled` icon. BRAND_DISNEY (IconName): The :code:`brand-disney` icon. BRAND_DISQUS (IconName): The :code:`brand-disqus` icon. BRAND_DJANGO (IconName): The :code:`brand-django` icon. BRAND_DOCKER (IconName): The :code:`brand-docker` icon. BRAND_DOCTRINE (IconName): The :code:`brand-doctrine` icon. BRAND_DOLBY_DIGITAL (IconName): The :code:`brand-dolby-digital` icon. BRAND_DOUBAN (IconName): The :code:`brand-douban` icon. BRAND_DRIBBBLE (IconName): The :code:`brand-dribbble` icon. BRAND_DRIBBBLE_FILLED (IconName): The :code:`brand-dribbble-filled` icon. BRAND_DROPS (IconName): The :code:`brand-drops` icon. BRAND_DRUPAL (IconName): The :code:`brand-drupal` icon. BRAND_EDGE (IconName): The :code:`brand-edge` icon. BRAND_ELASTIC (IconName): The :code:`brand-elastic` icon. BRAND_ELECTRONIC_ARTS (IconName): The :code:`brand-electronic-arts` icon. BRAND_EMBER (IconName): The :code:`brand-ember` icon. BRAND_ENVATO (IconName): The :code:`brand-envato` icon. BRAND_ETSY (IconName): The :code:`brand-etsy` icon. BRAND_EVERNOTE (IconName): The :code:`brand-evernote` icon. BRAND_FACEBOOK (IconName): The :code:`brand-facebook` icon. BRAND_FACEBOOK_FILLED (IconName): The :code:`brand-facebook-filled` icon. BRAND_FEEDLY (IconName): The :code:`brand-feedly` icon. BRAND_FIGMA (IconName): The :code:`brand-figma` icon. BRAND_FILEZILLA (IconName): The :code:`brand-filezilla` icon. BRAND_FINDER (IconName): The :code:`brand-finder` icon. BRAND_FIREBASE (IconName): The :code:`brand-firebase` icon. BRAND_FIREFOX (IconName): The :code:`brand-firefox` icon. BRAND_FIVERR (IconName): The :code:`brand-fiverr` icon. BRAND_FLICKR (IconName): The :code:`brand-flickr` icon. BRAND_FLIGHTRADAR24 (IconName): The :code:`brand-flightradar24` icon. BRAND_FLIPBOARD (IconName): The :code:`brand-flipboard` icon. BRAND_FLUTTER (IconName): The :code:`brand-flutter` icon. BRAND_FORTNITE (IconName): The :code:`brand-fortnite` icon. BRAND_FOURSQUARE (IconName): The :code:`brand-foursquare` icon. BRAND_FRAMER (IconName): The :code:`brand-framer` icon. BRAND_FRAMER_MOTION (IconName): The :code:`brand-framer-motion` icon. BRAND_FUNIMATION (IconName): The :code:`brand-funimation` icon. BRAND_GATSBY (IconName): The :code:`brand-gatsby` icon. BRAND_GIT (IconName): The :code:`brand-git` icon. BRAND_GITHUB (IconName): The :code:`brand-github` icon. BRAND_GITHUB_COPILOT (IconName): The :code:`brand-github-copilot` icon. 
BRAND_GITHUB_FILLED (IconName): The :code:`brand-github-filled` icon. BRAND_GITLAB (IconName): The :code:`brand-gitlab` icon. BRAND_GMAIL (IconName): The :code:`brand-gmail` icon. BRAND_GOLANG (IconName): The :code:`brand-golang` icon. BRAND_GOOGLE (IconName): The :code:`brand-google` icon. BRAND_GOOGLE_ANALYTICS (IconName): The :code:`brand-google-analytics` icon. BRAND_GOOGLE_BIG_QUERY (IconName): The :code:`brand-google-big-query` icon. BRAND_GOOGLE_DRIVE (IconName): The :code:`brand-google-drive` icon. BRAND_GOOGLE_FIT (IconName): The :code:`brand-google-fit` icon. BRAND_GOOGLE_HOME (IconName): The :code:`brand-google-home` icon. BRAND_GOOGLE_MAPS (IconName): The :code:`brand-google-maps` icon. BRAND_GOOGLE_ONE (IconName): The :code:`brand-google-one` icon. BRAND_GOOGLE_PHOTOS (IconName): The :code:`brand-google-photos` icon. BRAND_GOOGLE_PLAY (IconName): The :code:`brand-google-play` icon. BRAND_GOOGLE_PODCASTS (IconName): The :code:`brand-google-podcasts` icon. BRAND_GRAMMARLY (IconName): The :code:`brand-grammarly` icon. BRAND_GRAPHQL (IconName): The :code:`brand-graphql` icon. BRAND_GRAVATAR (IconName): The :code:`brand-gravatar` icon. BRAND_GRINDR (IconName): The :code:`brand-grindr` icon. BRAND_GUARDIAN (IconName): The :code:`brand-guardian` icon. BRAND_GUMROAD (IconName): The :code:`brand-gumroad` icon. BRAND_HBO (IconName): The :code:`brand-hbo` icon. BRAND_HEADLESSUI (IconName): The :code:`brand-headlessui` icon. BRAND_HEXO (IconName): The :code:`brand-hexo` icon. BRAND_HIPCHAT (IconName): The :code:`brand-hipchat` icon. BRAND_HTML5 (IconName): The :code:`brand-html5` icon. BRAND_INERTIA (IconName): The :code:`brand-inertia` icon. BRAND_INSTAGRAM (IconName): The :code:`brand-instagram` icon. BRAND_INTERCOM (IconName): The :code:`brand-intercom` icon. BRAND_ITCH (IconName): The :code:`brand-itch` icon. BRAND_JAVASCRIPT (IconName): The :code:`brand-javascript` icon. BRAND_JUEJIN (IconName): The :code:`brand-juejin` icon. BRAND_KBIN (IconName): The :code:`brand-kbin` icon. BRAND_KICK (IconName): The :code:`brand-kick` icon. BRAND_KICKSTARTER (IconName): The :code:`brand-kickstarter` icon. BRAND_KOTLIN (IconName): The :code:`brand-kotlin` icon. BRAND_LARAVEL (IconName): The :code:`brand-laravel` icon. BRAND_LASTFM (IconName): The :code:`brand-lastfm` icon. BRAND_LEETCODE (IconName): The :code:`brand-leetcode` icon. BRAND_LETTERBOXD (IconName): The :code:`brand-letterboxd` icon. BRAND_LINE (IconName): The :code:`brand-line` icon. BRAND_LINKEDIN (IconName): The :code:`brand-linkedin` icon. BRAND_LINKTREE (IconName): The :code:`brand-linktree` icon. BRAND_LINQPAD (IconName): The :code:`brand-linqpad` icon. BRAND_LOOM (IconName): The :code:`brand-loom` icon. BRAND_MAILGUN (IconName): The :code:`brand-mailgun` icon. BRAND_MANTINE (IconName): The :code:`brand-mantine` icon. BRAND_MASTERCARD (IconName): The :code:`brand-mastercard` icon. BRAND_MASTODON (IconName): The :code:`brand-mastodon` icon. BRAND_MATRIX (IconName): The :code:`brand-matrix` icon. BRAND_MCDONALDS (IconName): The :code:`brand-mcdonalds` icon. BRAND_MEDIUM (IconName): The :code:`brand-medium` icon. BRAND_MERCEDES (IconName): The :code:`brand-mercedes` icon. BRAND_MESSENGER (IconName): The :code:`brand-messenger` icon. BRAND_META (IconName): The :code:`brand-meta` icon. BRAND_MICROSOFT_TEAMS (IconName): The :code:`brand-microsoft-teams` icon. BRAND_MINECRAFT (IconName): The :code:`brand-minecraft` icon. BRAND_MINIPROGRAM (IconName): The :code:`brand-miniprogram` icon. 
BRAND_MIXPANEL (IconName): The :code:`brand-mixpanel` icon. BRAND_MONDAY (IconName): The :code:`brand-monday` icon. BRAND_MONGODB (IconName): The :code:`brand-mongodb` icon. BRAND_MY_OPPO (IconName): The :code:`brand-my-oppo` icon. BRAND_MYSQL (IconName): The :code:`brand-mysql` icon. BRAND_NATIONAL_GEOGRAPHIC (IconName): The :code:`brand-national-geographic` icon. BRAND_NEM (IconName): The :code:`brand-nem` icon. BRAND_NETBEANS (IconName): The :code:`brand-netbeans` icon. BRAND_NETEASE_MUSIC (IconName): The :code:`brand-netease-music` icon. BRAND_NETFLIX (IconName): The :code:`brand-netflix` icon. BRAND_NEXO (IconName): The :code:`brand-nexo` icon. BRAND_NEXTCLOUD (IconName): The :code:`brand-nextcloud` icon. BRAND_NEXTJS (IconName): The :code:`brand-nextjs` icon. BRAND_NODEJS (IconName): The :code:`brand-nodejs` icon. BRAND_NORD_VPN (IconName): The :code:`brand-nord-vpn` icon. BRAND_NOTION (IconName): The :code:`brand-notion` icon. BRAND_NPM (IconName): The :code:`brand-npm` icon. BRAND_NUXT (IconName): The :code:`brand-nuxt` icon. BRAND_NYTIMES (IconName): The :code:`brand-nytimes` icon. BRAND_OAUTH (IconName): The :code:`brand-oauth` icon. BRAND_OFFICE (IconName): The :code:`brand-office` icon. BRAND_OK_RU (IconName): The :code:`brand-ok-ru` icon. BRAND_ONEDRIVE (IconName): The :code:`brand-onedrive` icon. BRAND_ONLYFANS (IconName): The :code:`brand-onlyfans` icon. BRAND_OPEN_SOURCE (IconName): The :code:`brand-open-source` icon. BRAND_OPENAI (IconName): The :code:`brand-openai` icon. BRAND_OPENVPN (IconName): The :code:`brand-openvpn` icon. BRAND_OPERA (IconName): The :code:`brand-opera` icon. BRAND_PAGEKIT (IconName): The :code:`brand-pagekit` icon. BRAND_PATREON (IconName): The :code:`brand-patreon` icon. BRAND_PAYPAL (IconName): The :code:`brand-paypal` icon. BRAND_PAYPAL_FILLED (IconName): The :code:`brand-paypal-filled` icon. BRAND_PAYPAY (IconName): The :code:`brand-paypay` icon. BRAND_PEANUT (IconName): The :code:`brand-peanut` icon. BRAND_PEPSI (IconName): The :code:`brand-pepsi` icon. BRAND_PHP (IconName): The :code:`brand-php` icon. BRAND_PICSART (IconName): The :code:`brand-picsart` icon. BRAND_PINTEREST (IconName): The :code:`brand-pinterest` icon. BRAND_PLANETSCALE (IconName): The :code:`brand-planetscale` icon. BRAND_POCKET (IconName): The :code:`brand-pocket` icon. BRAND_POLYMER (IconName): The :code:`brand-polymer` icon. BRAND_POWERSHELL (IconName): The :code:`brand-powershell` icon. BRAND_PRISMA (IconName): The :code:`brand-prisma` icon. BRAND_PRODUCTHUNT (IconName): The :code:`brand-producthunt` icon. BRAND_PUSHBULLET (IconName): The :code:`brand-pushbullet` icon. BRAND_PUSHOVER (IconName): The :code:`brand-pushover` icon. BRAND_PYTHON (IconName): The :code:`brand-python` icon. BRAND_QQ (IconName): The :code:`brand-qq` icon. BRAND_RADIX_UI (IconName): The :code:`brand-radix-ui` icon. BRAND_REACT (IconName): The :code:`brand-react` icon. BRAND_REACT_NATIVE (IconName): The :code:`brand-react-native` icon. BRAND_REASON (IconName): The :code:`brand-reason` icon. BRAND_REDDIT (IconName): The :code:`brand-reddit` icon. BRAND_REDHAT (IconName): The :code:`brand-redhat` icon. BRAND_REDUX (IconName): The :code:`brand-redux` icon. BRAND_REVOLUT (IconName): The :code:`brand-revolut` icon. BRAND_RUMBLE (IconName): The :code:`brand-rumble` icon. BRAND_RUST (IconName): The :code:`brand-rust` icon. BRAND_SAFARI (IconName): The :code:`brand-safari` icon. BRAND_SAMSUNGPASS (IconName): The :code:`brand-samsungpass` icon. BRAND_SASS (IconName): The :code:`brand-sass` icon. 
BRAND_SENTRY (IconName): The :code:`brand-sentry` icon. BRAND_SHARIK (IconName): The :code:`brand-sharik` icon. BRAND_SHAZAM (IconName): The :code:`brand-shazam` icon. BRAND_SHOPEE (IconName): The :code:`brand-shopee` icon. BRAND_SKETCH (IconName): The :code:`brand-sketch` icon. BRAND_SKYPE (IconName): The :code:`brand-skype` icon. BRAND_SLACK (IconName): The :code:`brand-slack` icon. BRAND_SNAPCHAT (IconName): The :code:`brand-snapchat` icon. BRAND_SNAPSEED (IconName): The :code:`brand-snapseed` icon. BRAND_SNOWFLAKE (IconName): The :code:`brand-snowflake` icon. BRAND_SOCKET_IO (IconName): The :code:`brand-socket-io` icon. BRAND_SOLIDJS (IconName): The :code:`brand-solidjs` icon. BRAND_SOUNDCLOUD (IconName): The :code:`brand-soundcloud` icon. BRAND_SPACEHEY (IconName): The :code:`brand-spacehey` icon. BRAND_SPEEDTEST (IconName): The :code:`brand-speedtest` icon. BRAND_SPOTIFY (IconName): The :code:`brand-spotify` icon. BRAND_STACKOVERFLOW (IconName): The :code:`brand-stackoverflow` icon. BRAND_STACKSHARE (IconName): The :code:`brand-stackshare` icon. BRAND_STEAM (IconName): The :code:`brand-steam` icon. BRAND_STORJ (IconName): The :code:`brand-storj` icon. BRAND_STORYBOOK (IconName): The :code:`brand-storybook` icon. BRAND_STORYTEL (IconName): The :code:`brand-storytel` icon. BRAND_STRAVA (IconName): The :code:`brand-strava` icon. BRAND_STRIPE (IconName): The :code:`brand-stripe` icon. BRAND_SUBLIME_TEXT (IconName): The :code:`brand-sublime-text` icon. BRAND_SUGARIZER (IconName): The :code:`brand-sugarizer` icon. BRAND_SUPABASE (IconName): The :code:`brand-supabase` icon. BRAND_SUPERHUMAN (IconName): The :code:`brand-superhuman` icon. BRAND_SUPERNOVA (IconName): The :code:`brand-supernova` icon. BRAND_SURFSHARK (IconName): The :code:`brand-surfshark` icon. BRAND_SVELTE (IconName): The :code:`brand-svelte` icon. BRAND_SWIFT (IconName): The :code:`brand-swift` icon. BRAND_SYMFONY (IconName): The :code:`brand-symfony` icon. BRAND_TABLER (IconName): The :code:`brand-tabler` icon. BRAND_TAILWIND (IconName): The :code:`brand-tailwind` icon. BRAND_TAOBAO (IconName): The :code:`brand-taobao` icon. BRAND_TED (IconName): The :code:`brand-ted` icon. BRAND_TELEGRAM (IconName): The :code:`brand-telegram` icon. BRAND_TERRAFORM (IconName): The :code:`brand-terraform` icon. BRAND_TETHER (IconName): The :code:`brand-tether` icon. BRAND_THREEJS (IconName): The :code:`brand-threejs` icon. BRAND_TIDAL (IconName): The :code:`brand-tidal` icon. BRAND_TIKTO_FILLED (IconName): The :code:`brand-tikto-filled` icon. BRAND_TIKTOK (IconName): The :code:`brand-tiktok` icon. BRAND_TINDER (IconName): The :code:`brand-tinder` icon. BRAND_TOPBUZZ (IconName): The :code:`brand-topbuzz` icon. BRAND_TORCHAIN (IconName): The :code:`brand-torchain` icon. BRAND_TOYOTA (IconName): The :code:`brand-toyota` icon. BRAND_TRELLO (IconName): The :code:`brand-trello` icon. BRAND_TRIPADVISOR (IconName): The :code:`brand-tripadvisor` icon. BRAND_TUMBLR (IconName): The :code:`brand-tumblr` icon. BRAND_TWILIO (IconName): The :code:`brand-twilio` icon. BRAND_TWITCH (IconName): The :code:`brand-twitch` icon. BRAND_TWITTER (IconName): The :code:`brand-twitter` icon. BRAND_TWITTER_FILLED (IconName): The :code:`brand-twitter-filled` icon. BRAND_TYPESCRIPT (IconName): The :code:`brand-typescript` icon. BRAND_UBER (IconName): The :code:`brand-uber` icon. BRAND_UBUNTU (IconName): The :code:`brand-ubuntu` icon. BRAND_UNITY (IconName): The :code:`brand-unity` icon. BRAND_UNSPLASH (IconName): The :code:`brand-unsplash` icon. 
BRAND_UPWORK (IconName): The :code:`brand-upwork` icon. BRAND_VALORANT (IconName): The :code:`brand-valorant` icon. BRAND_VERCEL (IconName): The :code:`brand-vercel` icon. BRAND_VIMEO (IconName): The :code:`brand-vimeo` icon. BRAND_VINTED (IconName): The :code:`brand-vinted` icon. BRAND_VISA (IconName): The :code:`brand-visa` icon. BRAND_VISUAL_STUDIO (IconName): The :code:`brand-visual-studio` icon. BRAND_VITE (IconName): The :code:`brand-vite` icon. BRAND_VIVALDI (IconName): The :code:`brand-vivaldi` icon. BRAND_VK (IconName): The :code:`brand-vk` icon. BRAND_VLC (IconName): The :code:`brand-vlc` icon. BRAND_VOLKSWAGEN (IconName): The :code:`brand-volkswagen` icon. BRAND_VSCO (IconName): The :code:`brand-vsco` icon. BRAND_VSCODE (IconName): The :code:`brand-vscode` icon. BRAND_VUE (IconName): The :code:`brand-vue` icon. BRAND_WALMART (IconName): The :code:`brand-walmart` icon. BRAND_WAZE (IconName): The :code:`brand-waze` icon. BRAND_WEBFLOW (IconName): The :code:`brand-webflow` icon. BRAND_WECHAT (IconName): The :code:`brand-wechat` icon. BRAND_WEIBO (IconName): The :code:`brand-weibo` icon. BRAND_WHATSAPP (IconName): The :code:`brand-whatsapp` icon. BRAND_WIKIPEDIA (IconName): The :code:`brand-wikipedia` icon. BRAND_WINDOWS (IconName): The :code:`brand-windows` icon. BRAND_WINDY (IconName): The :code:`brand-windy` icon. BRAND_WISH (IconName): The :code:`brand-wish` icon. BRAND_WIX (IconName): The :code:`brand-wix` icon. BRAND_WORDPRESS (IconName): The :code:`brand-wordpress` icon. BRAND_XAMARIN (IconName): The :code:`brand-xamarin` icon. BRAND_XBOX (IconName): The :code:`brand-xbox` icon. BRAND_XING (IconName): The :code:`brand-xing` icon. BRAND_YAHOO (IconName): The :code:`brand-yahoo` icon. BRAND_YANDEX (IconName): The :code:`brand-yandex` icon. BRAND_YATSE (IconName): The :code:`brand-yatse` icon. BRAND_YCOMBINATOR (IconName): The :code:`brand-ycombinator` icon. BRAND_YOUTUBE (IconName): The :code:`brand-youtube` icon. BRAND_YOUTUBE_KIDS (IconName): The :code:`brand-youtube-kids` icon. BRAND_ZALANDO (IconName): The :code:`brand-zalando` icon. BRAND_ZAPIER (IconName): The :code:`brand-zapier` icon. BRAND_ZEIT (IconName): The :code:`brand-zeit` icon. BRAND_ZHIHU (IconName): The :code:`brand-zhihu` icon. BRAND_ZOOM (IconName): The :code:`brand-zoom` icon. BRAND_ZULIP (IconName): The :code:`brand-zulip` icon. BRAND_ZWIFT (IconName): The :code:`brand-zwift` icon. BREAD (IconName): The :code:`bread` icon. BREAD_OFF (IconName): The :code:`bread-off` icon. BRIEFCASE (IconName): The :code:`briefcase` icon. BRIEFCASE_OFF (IconName): The :code:`briefcase-off` icon. BRIGHTNESS (IconName): The :code:`brightness` icon. BRIGHTNESS_2 (IconName): The :code:`brightness-2` icon. BRIGHTNESS_DOWN (IconName): The :code:`brightness-down` icon. BRIGHTNESS_HALF (IconName): The :code:`brightness-half` icon. BRIGHTNESS_OFF (IconName): The :code:`brightness-off` icon. BRIGHTNESS_UP (IconName): The :code:`brightness-up` icon. BROADCAST (IconName): The :code:`broadcast` icon. BROADCAST_OFF (IconName): The :code:`broadcast-off` icon. BROWSER (IconName): The :code:`browser` icon. BROWSER_CHECK (IconName): The :code:`browser-check` icon. BROWSER_OFF (IconName): The :code:`browser-off` icon. BROWSER_PLUS (IconName): The :code:`browser-plus` icon. BROWSER_X (IconName): The :code:`browser-x` icon. BRUSH (IconName): The :code:`brush` icon. BRUSH_OFF (IconName): The :code:`brush-off` icon. BUCKET (IconName): The :code:`bucket` icon. BUCKET_DROPLET (IconName): The :code:`bucket-droplet` icon. 
BUCKET_OFF (IconName): The :code:`bucket-off` icon. BUG (IconName): The :code:`bug` icon. BUG_OFF (IconName): The :code:`bug-off` icon. BUILDING (IconName): The :code:`building` icon. BUILDING_ARCH (IconName): The :code:`building-arch` icon. BUILDING_BANK (IconName): The :code:`building-bank` icon. BUILDING_BRIDGE (IconName): The :code:`building-bridge` icon. BUILDING_BRIDGE_2 (IconName): The :code:`building-bridge-2` icon. BUILDING_BROADCAST_TOWER (IconName): The :code:`building-broadcast-tower` icon. BUILDING_CAROUSEL (IconName): The :code:`building-carousel` icon. BUILDING_CASTLE (IconName): The :code:`building-castle` icon. BUILDING_CHURCH (IconName): The :code:`building-church` icon. BUILDING_CIRCUS (IconName): The :code:`building-circus` icon. BUILDING_COMMUNITY (IconName): The :code:`building-community` icon. BUILDING_COTTAGE (IconName): The :code:`building-cottage` icon. BUILDING_ESTATE (IconName): The :code:`building-estate` icon. BUILDING_FACTORY (IconName): The :code:`building-factory` icon. BUILDING_FACTORY_2 (IconName): The :code:`building-factory-2` icon. BUILDING_FORTRESS (IconName): The :code:`building-fortress` icon. BUILDING_HOSPITAL (IconName): The :code:`building-hospital` icon. BUILDING_LIGHTHOUSE (IconName): The :code:`building-lighthouse` icon. BUILDING_MONUMENT (IconName): The :code:`building-monument` icon. BUILDING_MOSQUE (IconName): The :code:`building-mosque` icon. BUILDING_PAVILION (IconName): The :code:`building-pavilion` icon. BUILDING_SKYSCRAPER (IconName): The :code:`building-skyscraper` icon. BUILDING_STADIUM (IconName): The :code:`building-stadium` icon. BUILDING_STORE (IconName): The :code:`building-store` icon. BUILDING_TUNNEL (IconName): The :code:`building-tunnel` icon. BUILDING_WAREHOUSE (IconName): The :code:`building-warehouse` icon. BUILDING_WIND_TURBINE (IconName): The :code:`building-wind-turbine` icon. BULB (IconName): The :code:`bulb` icon. BULB_FILLED (IconName): The :code:`bulb-filled` icon. BULB_OFF (IconName): The :code:`bulb-off` icon. BULLDOZER (IconName): The :code:`bulldozer` icon. BUS (IconName): The :code:`bus` icon. BUS_OFF (IconName): The :code:`bus-off` icon. BUS_STOP (IconName): The :code:`bus-stop` icon. BUSINESSPLAN (IconName): The :code:`businessplan` icon. BUTTERFLY (IconName): The :code:`butterfly` icon. CACTUS (IconName): The :code:`cactus` icon. CACTUS_OFF (IconName): The :code:`cactus-off` icon. CAKE (IconName): The :code:`cake` icon. CAKE_OFF (IconName): The :code:`cake-off` icon. CALCULATOR (IconName): The :code:`calculator` icon. CALCULATOR_OFF (IconName): The :code:`calculator-off` icon. CALENDAR (IconName): The :code:`calendar` icon. CALENDAR_BOLT (IconName): The :code:`calendar-bolt` icon. CALENDAR_CANCEL (IconName): The :code:`calendar-cancel` icon. CALENDAR_CHECK (IconName): The :code:`calendar-check` icon. CALENDAR_CODE (IconName): The :code:`calendar-code` icon. CALENDAR_COG (IconName): The :code:`calendar-cog` icon. CALENDAR_DOLLAR (IconName): The :code:`calendar-dollar` icon. CALENDAR_DOWN (IconName): The :code:`calendar-down` icon. CALENDAR_DUE (IconName): The :code:`calendar-due` icon. CALENDAR_EVENT (IconName): The :code:`calendar-event` icon. CALENDAR_EXCLAMATION (IconName): The :code:`calendar-exclamation` icon. CALENDAR_HEART (IconName): The :code:`calendar-heart` icon. CALENDAR_MINUS (IconName): The :code:`calendar-minus` icon. CALENDAR_OFF (IconName): The :code:`calendar-off` icon. CALENDAR_PAUSE (IconName): The :code:`calendar-pause` icon. CALENDAR_PIN (IconName): The :code:`calendar-pin` icon. 
CALENDAR_PLUS (IconName): The :code:`calendar-plus` icon. CALENDAR_QUESTION (IconName): The :code:`calendar-question` icon. CALENDAR_REPEAT (IconName): The :code:`calendar-repeat` icon. CALENDAR_SEARCH (IconName): The :code:`calendar-search` icon. CALENDAR_SHARE (IconName): The :code:`calendar-share` icon. CALENDAR_STAR (IconName): The :code:`calendar-star` icon. CALENDAR_STATS (IconName): The :code:`calendar-stats` icon. CALENDAR_TIME (IconName): The :code:`calendar-time` icon. CALENDAR_UP (IconName): The :code:`calendar-up` icon. CALENDAR_X (IconName): The :code:`calendar-x` icon. CAMERA (IconName): The :code:`camera` icon. CAMERA_BOLT (IconName): The :code:`camera-bolt` icon. CAMERA_CANCEL (IconName): The :code:`camera-cancel` icon. CAMERA_CHECK (IconName): The :code:`camera-check` icon. CAMERA_CODE (IconName): The :code:`camera-code` icon. CAMERA_COG (IconName): The :code:`camera-cog` icon. CAMERA_DOLLAR (IconName): The :code:`camera-dollar` icon. CAMERA_DOWN (IconName): The :code:`camera-down` icon. CAMERA_EXCLAMATION (IconName): The :code:`camera-exclamation` icon. CAMERA_FILLED (IconName): The :code:`camera-filled` icon. CAMERA_HEART (IconName): The :code:`camera-heart` icon. CAMERA_MINUS (IconName): The :code:`camera-minus` icon. CAMERA_OFF (IconName): The :code:`camera-off` icon. CAMERA_PAUSE (IconName): The :code:`camera-pause` icon. CAMERA_PIN (IconName): The :code:`camera-pin` icon. CAMERA_PLUS (IconName): The :code:`camera-plus` icon. CAMERA_QUESTION (IconName): The :code:`camera-question` icon. CAMERA_ROTATE (IconName): The :code:`camera-rotate` icon. CAMERA_SEARCH (IconName): The :code:`camera-search` icon. CAMERA_SELFIE (IconName): The :code:`camera-selfie` icon. CAMERA_SHARE (IconName): The :code:`camera-share` icon. CAMERA_STAR (IconName): The :code:`camera-star` icon. CAMERA_UP (IconName): The :code:`camera-up` icon. CAMERA_X (IconName): The :code:`camera-x` icon. CAMPER (IconName): The :code:`camper` icon. CAMPFIRE (IconName): The :code:`campfire` icon. CANDLE (IconName): The :code:`candle` icon. CANDY (IconName): The :code:`candy` icon. CANDY_OFF (IconName): The :code:`candy-off` icon. CANE (IconName): The :code:`cane` icon. CANNABIS (IconName): The :code:`cannabis` icon. CAPSULE (IconName): The :code:`capsule` icon. CAPSULE_HORIZONTAL (IconName): The :code:`capsule-horizontal` icon. CAPTURE (IconName): The :code:`capture` icon. CAPTURE_OFF (IconName): The :code:`capture-off` icon. CAR (IconName): The :code:`car` icon. CAR_CRANE (IconName): The :code:`car-crane` icon. CAR_CRASH (IconName): The :code:`car-crash` icon. CAR_OFF (IconName): The :code:`car-off` icon. CAR_TURBINE (IconName): The :code:`car-turbine` icon. CARAVAN (IconName): The :code:`caravan` icon. CARDBOARDS (IconName): The :code:`cardboards` icon. CARDBOARDS_OFF (IconName): The :code:`cardboards-off` icon. CARDS (IconName): The :code:`cards` icon. CARET_DOWN (IconName): The :code:`caret-down` icon. CARET_LEFT (IconName): The :code:`caret-left` icon. CARET_RIGHT (IconName): The :code:`caret-right` icon. CARET_UP (IconName): The :code:`caret-up` icon. CAROUSEL_HORIZONTAL (IconName): The :code:`carousel-horizontal` icon. CAROUSEL_HORIZONTAL_FILLED (IconName): The :code:`carousel-horizontal-filled` icon. CAROUSEL_VERTICAL (IconName): The :code:`carousel-vertical` icon. CAROUSEL_VERTICAL_FILLED (IconName): The :code:`carousel-vertical-filled` icon. CARROT (IconName): The :code:`carrot` icon. CARROT_OFF (IconName): The :code:`carrot-off` icon. CASH (IconName): The :code:`cash` icon. 
CASH_BANKNOTE (IconName): The :code:`cash-banknote` icon. CASH_BANKNOTE_OFF (IconName): The :code:`cash-banknote-off` icon. CASH_OFF (IconName): The :code:`cash-off` icon. CAST (IconName): The :code:`cast` icon. CAST_OFF (IconName): The :code:`cast-off` icon. CAT (IconName): The :code:`cat` icon. CATEGORY (IconName): The :code:`category` icon. CATEGORY_2 (IconName): The :code:`category-2` icon. CE (IconName): The :code:`ce` icon. CE_OFF (IconName): The :code:`ce-off` icon. CELL (IconName): The :code:`cell` icon. CELL_SIGNAL_1 (IconName): The :code:`cell-signal-1` icon. CELL_SIGNAL_2 (IconName): The :code:`cell-signal-2` icon. CELL_SIGNAL_3 (IconName): The :code:`cell-signal-3` icon. CELL_SIGNAL_4 (IconName): The :code:`cell-signal-4` icon. CELL_SIGNAL_5 (IconName): The :code:`cell-signal-5` icon. CELL_SIGNAL_OFF (IconName): The :code:`cell-signal-off` icon. CERTIFICATE (IconName): The :code:`certificate` icon. CERTIFICATE_2 (IconName): The :code:`certificate-2` icon. CERTIFICATE_2_OFF (IconName): The :code:`certificate-2-off` icon. CERTIFICATE_OFF (IconName): The :code:`certificate-off` icon. CHAIR_DIRECTOR (IconName): The :code:`chair-director` icon. CHALKBOARD (IconName): The :code:`chalkboard` icon. CHALKBOARD_OFF (IconName): The :code:`chalkboard-off` icon. CHARGING_PILE (IconName): The :code:`charging-pile` icon. CHART_ARCS (IconName): The :code:`chart-arcs` icon. CHART_ARCS_3 (IconName): The :code:`chart-arcs-3` icon. CHART_AREA (IconName): The :code:`chart-area` icon. CHART_AREA_FILLED (IconName): The :code:`chart-area-filled` icon. CHART_AREA_LINE (IconName): The :code:`chart-area-line` icon. CHART_AREA_LINE_FILLED (IconName): The :code:`chart-area-line-filled` icon. CHART_ARROWS (IconName): The :code:`chart-arrows` icon. CHART_ARROWS_VERTICAL (IconName): The :code:`chart-arrows-vertical` icon. CHART_BAR (IconName): The :code:`chart-bar` icon. CHART_BAR_OFF (IconName): The :code:`chart-bar-off` icon. CHART_BUBBLE (IconName): The :code:`chart-bubble` icon. CHART_BUBBLE_FILLED (IconName): The :code:`chart-bubble-filled` icon. CHART_CANDLE (IconName): The :code:`chart-candle` icon. CHART_CANDLE_FILLED (IconName): The :code:`chart-candle-filled` icon. CHART_CIRCLES (IconName): The :code:`chart-circles` icon. CHART_DONUT (IconName): The :code:`chart-donut` icon. CHART_DONUT_2 (IconName): The :code:`chart-donut-2` icon. CHART_DONUT_3 (IconName): The :code:`chart-donut-3` icon. CHART_DONUT_4 (IconName): The :code:`chart-donut-4` icon. CHART_DONUT_FILLED (IconName): The :code:`chart-donut-filled` icon. CHART_DOTS (IconName): The :code:`chart-dots` icon. CHART_DOTS_2 (IconName): The :code:`chart-dots-2` icon. CHART_DOTS_3 (IconName): The :code:`chart-dots-3` icon. CHART_GRID_DOTS (IconName): The :code:`chart-grid-dots` icon. CHART_HISTOGRAM (IconName): The :code:`chart-histogram` icon. CHART_INFOGRAPHIC (IconName): The :code:`chart-infographic` icon. CHART_LINE (IconName): The :code:`chart-line` icon. CHART_PIE (IconName): The :code:`chart-pie` icon. CHART_PIE_2 (IconName): The :code:`chart-pie-2` icon. CHART_PIE_3 (IconName): The :code:`chart-pie-3` icon. CHART_PIE_4 (IconName): The :code:`chart-pie-4` icon. CHART_PIE_FILLED (IconName): The :code:`chart-pie-filled` icon. CHART_PIE_OFF (IconName): The :code:`chart-pie-off` icon. CHART_PPF (IconName): The :code:`chart-ppf` icon. CHART_RADAR (IconName): The :code:`chart-radar` icon. CHART_SANKEY (IconName): The :code:`chart-sankey` icon. CHART_TREEMAP (IconName): The :code:`chart-treemap` icon. CHECK (IconName): The :code:`check` icon. 
CHECKBOX (IconName): The :code:`checkbox` icon. CHECKLIST (IconName): The :code:`checklist` icon. CHECKS (IconName): The :code:`checks` icon. CHECKUP_LIST (IconName): The :code:`checkup-list` icon. CHEESE (IconName): The :code:`cheese` icon. CHEF_HAT (IconName): The :code:`chef-hat` icon. CHEF_HAT_OFF (IconName): The :code:`chef-hat-off` icon. CHERRY (IconName): The :code:`cherry` icon. CHERRY_FILLED (IconName): The :code:`cherry-filled` icon. CHESS (IconName): The :code:`chess` icon. CHESS_BISHOP (IconName): The :code:`chess-bishop` icon. CHESS_BISHOP_FILLED (IconName): The :code:`chess-bishop-filled` icon. CHESS_FILLED (IconName): The :code:`chess-filled` icon. CHESS_KING (IconName): The :code:`chess-king` icon. CHESS_KING_FILLED (IconName): The :code:`chess-king-filled` icon. CHESS_KNIGHT (IconName): The :code:`chess-knight` icon. CHESS_KNIGHT_FILLED (IconName): The :code:`chess-knight-filled` icon. CHESS_QUEEN (IconName): The :code:`chess-queen` icon. CHESS_QUEEN_FILLED (IconName): The :code:`chess-queen-filled` icon. CHESS_ROOK (IconName): The :code:`chess-rook` icon. CHESS_ROOK_FILLED (IconName): The :code:`chess-rook-filled` icon. CHEVRON_COMPACT_DOWN (IconName): The :code:`chevron-compact-down` icon. CHEVRON_COMPACT_LEFT (IconName): The :code:`chevron-compact-left` icon. CHEVRON_COMPACT_RIGHT (IconName): The :code:`chevron-compact-right` icon. CHEVRON_COMPACT_UP (IconName): The :code:`chevron-compact-up` icon. CHEVRON_DOWN (IconName): The :code:`chevron-down` icon. CHEVRON_DOWN_LEFT (IconName): The :code:`chevron-down-left` icon. CHEVRON_DOWN_RIGHT (IconName): The :code:`chevron-down-right` icon. CHEVRON_LEFT (IconName): The :code:`chevron-left` icon. CHEVRON_LEFT_PIPE (IconName): The :code:`chevron-left-pipe` icon. CHEVRON_RIGHT (IconName): The :code:`chevron-right` icon. CHEVRON_RIGHT_PIPE (IconName): The :code:`chevron-right-pipe` icon. CHEVRON_UP (IconName): The :code:`chevron-up` icon. CHEVRON_UP_LEFT (IconName): The :code:`chevron-up-left` icon. CHEVRON_UP_RIGHT (IconName): The :code:`chevron-up-right` icon. CHEVRONS_DOWN (IconName): The :code:`chevrons-down` icon. CHEVRONS_DOWN_LEFT (IconName): The :code:`chevrons-down-left` icon. CHEVRONS_DOWN_RIGHT (IconName): The :code:`chevrons-down-right` icon. CHEVRONS_LEFT (IconName): The :code:`chevrons-left` icon. CHEVRONS_RIGHT (IconName): The :code:`chevrons-right` icon. CHEVRONS_UP (IconName): The :code:`chevrons-up` icon. CHEVRONS_UP_LEFT (IconName): The :code:`chevrons-up-left` icon. CHEVRONS_UP_RIGHT (IconName): The :code:`chevrons-up-right` icon. CHISEL (IconName): The :code:`chisel` icon. CHRISTMAS_TREE (IconName): The :code:`christmas-tree` icon. CHRISTMAS_TREE_OFF (IconName): The :code:`christmas-tree-off` icon. CIRCLE (IconName): The :code:`circle` icon. CIRCLE_0_FILLED (IconName): The :code:`circle-0-filled` icon. CIRCLE_1_FILLED (IconName): The :code:`circle-1-filled` icon. CIRCLE_2_FILLED (IconName): The :code:`circle-2-filled` icon. CIRCLE_3_FILLED (IconName): The :code:`circle-3-filled` icon. CIRCLE_4_FILLED (IconName): The :code:`circle-4-filled` icon. CIRCLE_5_FILLED (IconName): The :code:`circle-5-filled` icon. CIRCLE_6_FILLED (IconName): The :code:`circle-6-filled` icon. CIRCLE_7_FILLED (IconName): The :code:`circle-7-filled` icon. CIRCLE_8_FILLED (IconName): The :code:`circle-8-filled` icon. CIRCLE_9_FILLED (IconName): The :code:`circle-9-filled` icon. CIRCLE_ARROW_DOWN (IconName): The :code:`circle-arrow-down` icon. CIRCLE_ARROW_DOWN_FILLED (IconName): The :code:`circle-arrow-down-filled` icon. 
CIRCLE_ARROW_DOWN_LEFT (IconName): The :code:`circle-arrow-down-left` icon. CIRCLE_ARROW_DOWN_LEFT_FILLED (IconName): The :code:`circle-arrow-down-left-filled` icon. CIRCLE_ARROW_DOWN_RIGHT (IconName): The :code:`circle-arrow-down-right` icon. CIRCLE_ARROW_DOWN_RIGHT_FILLED (IconName): The :code:`circle-arrow-down-right-filled` icon. CIRCLE_ARROW_LEFT (IconName): The :code:`circle-arrow-left` icon. CIRCLE_ARROW_LEFT_FILLED (IconName): The :code:`circle-arrow-left-filled` icon. CIRCLE_ARROW_RIGHT (IconName): The :code:`circle-arrow-right` icon. CIRCLE_ARROW_RIGHT_FILLED (IconName): The :code:`circle-arrow-right-filled` icon. CIRCLE_ARROW_UP (IconName): The :code:`circle-arrow-up` icon. CIRCLE_ARROW_UP_FILLED (IconName): The :code:`circle-arrow-up-filled` icon. CIRCLE_ARROW_UP_LEFT (IconName): The :code:`circle-arrow-up-left` icon. CIRCLE_ARROW_UP_LEFT_FILLED (IconName): The :code:`circle-arrow-up-left-filled` icon. CIRCLE_ARROW_UP_RIGHT (IconName): The :code:`circle-arrow-up-right` icon. CIRCLE_ARROW_UP_RIGHT_FILLED (IconName): The :code:`circle-arrow-up-right-filled` icon. CIRCLE_CARET_DOWN (IconName): The :code:`circle-caret-down` icon. CIRCLE_CARET_LEFT (IconName): The :code:`circle-caret-left` icon. CIRCLE_CARET_RIGHT (IconName): The :code:`circle-caret-right` icon. CIRCLE_CARET_UP (IconName): The :code:`circle-caret-up` icon. CIRCLE_CHECK (IconName): The :code:`circle-check` icon. CIRCLE_CHECK_FILLED (IconName): The :code:`circle-check-filled` icon. CIRCLE_CHEVRON_DOWN (IconName): The :code:`circle-chevron-down` icon. CIRCLE_CHEVRON_LEFT (IconName): The :code:`circle-chevron-left` icon. CIRCLE_CHEVRON_RIGHT (IconName): The :code:`circle-chevron-right` icon. CIRCLE_CHEVRON_UP (IconName): The :code:`circle-chevron-up` icon. CIRCLE_CHEVRONS_DOWN (IconName): The :code:`circle-chevrons-down` icon. CIRCLE_CHEVRONS_LEFT (IconName): The :code:`circle-chevrons-left` icon. CIRCLE_CHEVRONS_RIGHT (IconName): The :code:`circle-chevrons-right` icon. CIRCLE_CHEVRONS_UP (IconName): The :code:`circle-chevrons-up` icon. CIRCLE_DASHED (IconName): The :code:`circle-dashed` icon. CIRCLE_DOT (IconName): The :code:`circle-dot` icon. CIRCLE_DOT_FILLED (IconName): The :code:`circle-dot-filled` icon. CIRCLE_DOTTED (IconName): The :code:`circle-dotted` icon. CIRCLE_FILLED (IconName): The :code:`circle-filled` icon. CIRCLE_HALF (IconName): The :code:`circle-half` icon. CIRCLE_HALF_2 (IconName): The :code:`circle-half-2` icon. CIRCLE_HALF_VERTICAL (IconName): The :code:`circle-half-vertical` icon. CIRCLE_KEY (IconName): The :code:`circle-key` icon. CIRCLE_KEY_FILLED (IconName): The :code:`circle-key-filled` icon. CIRCLE_LETTER_A (IconName): The :code:`circle-letter-a` icon. CIRCLE_LETTER_B (IconName): The :code:`circle-letter-b` icon. CIRCLE_LETTER_C (IconName): The :code:`circle-letter-c` icon. CIRCLE_LETTER_D (IconName): The :code:`circle-letter-d` icon. CIRCLE_LETTER_E (IconName): The :code:`circle-letter-e` icon. CIRCLE_LETTER_F (IconName): The :code:`circle-letter-f` icon. CIRCLE_LETTER_G (IconName): The :code:`circle-letter-g` icon. CIRCLE_LETTER_H (IconName): The :code:`circle-letter-h` icon. CIRCLE_LETTER_I (IconName): The :code:`circle-letter-i` icon. CIRCLE_LETTER_J (IconName): The :code:`circle-letter-j` icon. CIRCLE_LETTER_K (IconName): The :code:`circle-letter-k` icon. CIRCLE_LETTER_L (IconName): The :code:`circle-letter-l` icon. CIRCLE_LETTER_M (IconName): The :code:`circle-letter-m` icon. CIRCLE_LETTER_N (IconName): The :code:`circle-letter-n` icon. 
CIRCLE_LETTER_O (IconName): The :code:`circle-letter-o` icon. CIRCLE_LETTER_P (IconName): The :code:`circle-letter-p` icon. CIRCLE_LETTER_Q (IconName): The :code:`circle-letter-q` icon. CIRCLE_LETTER_R (IconName): The :code:`circle-letter-r` icon. CIRCLE_LETTER_S (IconName): The :code:`circle-letter-s` icon. CIRCLE_LETTER_T (IconName): The :code:`circle-letter-t` icon. CIRCLE_LETTER_U (IconName): The :code:`circle-letter-u` icon. CIRCLE_LETTER_V (IconName): The :code:`circle-letter-v` icon. CIRCLE_LETTER_W (IconName): The :code:`circle-letter-w` icon. CIRCLE_LETTER_X (IconName): The :code:`circle-letter-x` icon. CIRCLE_LETTER_Y (IconName): The :code:`circle-letter-y` icon. CIRCLE_LETTER_Z (IconName): The :code:`circle-letter-z` icon. CIRCLE_MINUS (IconName): The :code:`circle-minus` icon. CIRCLE_NUMBER_0 (IconName): The :code:`circle-number-0` icon. CIRCLE_NUMBER_1 (IconName): The :code:`circle-number-1` icon. CIRCLE_NUMBER_2 (IconName): The :code:`circle-number-2` icon. CIRCLE_NUMBER_3 (IconName): The :code:`circle-number-3` icon. CIRCLE_NUMBER_4 (IconName): The :code:`circle-number-4` icon. CIRCLE_NUMBER_5 (IconName): The :code:`circle-number-5` icon. CIRCLE_NUMBER_6 (IconName): The :code:`circle-number-6` icon. CIRCLE_NUMBER_7 (IconName): The :code:`circle-number-7` icon. CIRCLE_NUMBER_8 (IconName): The :code:`circle-number-8` icon. CIRCLE_NUMBER_9 (IconName): The :code:`circle-number-9` icon. CIRCLE_OFF (IconName): The :code:`circle-off` icon. CIRCLE_PLUS (IconName): The :code:`circle-plus` icon. CIRCLE_RECTANGLE (IconName): The :code:`circle-rectangle` icon. CIRCLE_RECTANGLE_OFF (IconName): The :code:`circle-rectangle-off` icon. CIRCLE_SQUARE (IconName): The :code:`circle-square` icon. CIRCLE_TRIANGLE (IconName): The :code:`circle-triangle` icon. CIRCLE_X (IconName): The :code:`circle-x` icon. CIRCLE_X_FILLED (IconName): The :code:`circle-x-filled` icon. CIRCLES (IconName): The :code:`circles` icon. CIRCLES_FILLED (IconName): The :code:`circles-filled` icon. CIRCLES_RELATION (IconName): The :code:`circles-relation` icon. CIRCUIT_AMMETER (IconName): The :code:`circuit-ammeter` icon. CIRCUIT_BATTERY (IconName): The :code:`circuit-battery` icon. CIRCUIT_BULB (IconName): The :code:`circuit-bulb` icon. CIRCUIT_CAPACITOR (IconName): The :code:`circuit-capacitor` icon. CIRCUIT_CAPACITOR_POLARIZED (IconName): The :code:`circuit-capacitor-polarized` icon. CIRCUIT_CELL (IconName): The :code:`circuit-cell` icon. CIRCUIT_CELL_PLUS (IconName): The :code:`circuit-cell-plus` icon. CIRCUIT_CHANGEOVER (IconName): The :code:`circuit-changeover` icon. CIRCUIT_DIODE (IconName): The :code:`circuit-diode` icon. CIRCUIT_DIODE_ZENER (IconName): The :code:`circuit-diode-zener` icon. CIRCUIT_GROUND (IconName): The :code:`circuit-ground` icon. CIRCUIT_GROUND_DIGITAL (IconName): The :code:`circuit-ground-digital` icon. CIRCUIT_INDUCTOR (IconName): The :code:`circuit-inductor` icon. CIRCUIT_MOTOR (IconName): The :code:`circuit-motor` icon. CIRCUIT_PUSHBUTTON (IconName): The :code:`circuit-pushbutton` icon. CIRCUIT_RESISTOR (IconName): The :code:`circuit-resistor` icon. CIRCUIT_SWITCH_CLOSED (IconName): The :code:`circuit-switch-closed` icon. CIRCUIT_SWITCH_OPEN (IconName): The :code:`circuit-switch-open` icon. CIRCUIT_VOLTMETER (IconName): The :code:`circuit-voltmeter` icon. CLEAR_ALL (IconName): The :code:`clear-all` icon. CLEAR_FORMATTING (IconName): The :code:`clear-formatting` icon. CLICK (IconName): The :code:`click` icon. CLIPBOARD (IconName): The :code:`clipboard` icon. 
CLIPBOARD_CHECK (IconName): The :code:`clipboard-check` icon. CLIPBOARD_COPY (IconName): The :code:`clipboard-copy` icon. CLIPBOARD_DATA (IconName): The :code:`clipboard-data` icon. CLIPBOARD_HEART (IconName): The :code:`clipboard-heart` icon. CLIPBOARD_LIST (IconName): The :code:`clipboard-list` icon. CLIPBOARD_OFF (IconName): The :code:`clipboard-off` icon. CLIPBOARD_PLUS (IconName): The :code:`clipboard-plus` icon. CLIPBOARD_TEXT (IconName): The :code:`clipboard-text` icon. CLIPBOARD_TYPOGRAPHY (IconName): The :code:`clipboard-typography` icon. CLIPBOARD_X (IconName): The :code:`clipboard-x` icon. CLOCK (IconName): The :code:`clock` icon. CLOCK_2 (IconName): The :code:`clock-2` icon. CLOCK_BOLT (IconName): The :code:`clock-bolt` icon. CLOCK_CANCEL (IconName): The :code:`clock-cancel` icon. CLOCK_CHECK (IconName): The :code:`clock-check` icon. CLOCK_CODE (IconName): The :code:`clock-code` icon. CLOCK_COG (IconName): The :code:`clock-cog` icon. CLOCK_DOLLAR (IconName): The :code:`clock-dollar` icon. CLOCK_DOWN (IconName): The :code:`clock-down` icon. CLOCK_EDIT (IconName): The :code:`clock-edit` icon. CLOCK_EXCLAMATION (IconName): The :code:`clock-exclamation` icon. CLOCK_FILLED (IconName): The :code:`clock-filled` icon. CLOCK_HEART (IconName): The :code:`clock-heart` icon. CLOCK_HOUR_1 (IconName): The :code:`clock-hour-1` icon. CLOCK_HOUR_10 (IconName): The :code:`clock-hour-10` icon. CLOCK_HOUR_11 (IconName): The :code:`clock-hour-11` icon. CLOCK_HOUR_12 (IconName): The :code:`clock-hour-12` icon. CLOCK_HOUR_2 (IconName): The :code:`clock-hour-2` icon. CLOCK_HOUR_3 (IconName): The :code:`clock-hour-3` icon. CLOCK_HOUR_4 (IconName): The :code:`clock-hour-4` icon. CLOCK_HOUR_5 (IconName): The :code:`clock-hour-5` icon. CLOCK_HOUR_6 (IconName): The :code:`clock-hour-6` icon. CLOCK_HOUR_7 (IconName): The :code:`clock-hour-7` icon. CLOCK_HOUR_8 (IconName): The :code:`clock-hour-8` icon. CLOCK_HOUR_9 (IconName): The :code:`clock-hour-9` icon. CLOCK_MINUS (IconName): The :code:`clock-minus` icon. CLOCK_OFF (IconName): The :code:`clock-off` icon. CLOCK_PAUSE (IconName): The :code:`clock-pause` icon. CLOCK_PIN (IconName): The :code:`clock-pin` icon. CLOCK_PLAY (IconName): The :code:`clock-play` icon. CLOCK_PLUS (IconName): The :code:`clock-plus` icon. CLOCK_QUESTION (IconName): The :code:`clock-question` icon. CLOCK_RECORD (IconName): The :code:`clock-record` icon. CLOCK_SEARCH (IconName): The :code:`clock-search` icon. CLOCK_SHARE (IconName): The :code:`clock-share` icon. CLOCK_SHIELD (IconName): The :code:`clock-shield` icon. CLOCK_STAR (IconName): The :code:`clock-star` icon. CLOCK_STOP (IconName): The :code:`clock-stop` icon. CLOCK_UP (IconName): The :code:`clock-up` icon. CLOCK_X (IconName): The :code:`clock-x` icon. CLOTHES_RACK (IconName): The :code:`clothes-rack` icon. CLOTHES_RACK_OFF (IconName): The :code:`clothes-rack-off` icon. CLOUD (IconName): The :code:`cloud` icon. CLOUD_BOLT (IconName): The :code:`cloud-bolt` icon. CLOUD_CANCEL (IconName): The :code:`cloud-cancel` icon. CLOUD_CHECK (IconName): The :code:`cloud-check` icon. CLOUD_CODE (IconName): The :code:`cloud-code` icon. CLOUD_COG (IconName): The :code:`cloud-cog` icon. CLOUD_COMPUTING (IconName): The :code:`cloud-computing` icon. CLOUD_DATA_CONNECTION (IconName): The :code:`cloud-data-connection` icon. CLOUD_DOLLAR (IconName): The :code:`cloud-dollar` icon. CLOUD_DOWN (IconName): The :code:`cloud-down` icon. CLOUD_DOWNLOAD (IconName): The :code:`cloud-download` icon. 
CLOUD_EXCLAMATION (IconName): The :code:`cloud-exclamation` icon. CLOUD_FILLED (IconName): The :code:`cloud-filled` icon. CLOUD_FOG (IconName): The :code:`cloud-fog` icon. CLOUD_HEART (IconName): The :code:`cloud-heart` icon. CLOUD_LOCK (IconName): The :code:`cloud-lock` icon. CLOUD_LOCK_OPEN (IconName): The :code:`cloud-lock-open` icon. CLOUD_MINUS (IconName): The :code:`cloud-minus` icon. CLOUD_OFF (IconName): The :code:`cloud-off` icon. CLOUD_PAUSE (IconName): The :code:`cloud-pause` icon. CLOUD_PIN (IconName): The :code:`cloud-pin` icon. CLOUD_PLUS (IconName): The :code:`cloud-plus` icon. CLOUD_QUESTION (IconName): The :code:`cloud-question` icon. CLOUD_RAIN (IconName): The :code:`cloud-rain` icon. CLOUD_SEARCH (IconName): The :code:`cloud-search` icon. CLOUD_SHARE (IconName): The :code:`cloud-share` icon. CLOUD_SNOW (IconName): The :code:`cloud-snow` icon. CLOUD_STAR (IconName): The :code:`cloud-star` icon. CLOUD_STORM (IconName): The :code:`cloud-storm` icon. CLOUD_UP (IconName): The :code:`cloud-up` icon. CLOUD_UPLOAD (IconName): The :code:`cloud-upload` icon. CLOUD_X (IconName): The :code:`cloud-x` icon. CLOVER (IconName): The :code:`clover` icon. CLOVER_2 (IconName): The :code:`clover-2` icon. CLUBS (IconName): The :code:`clubs` icon. CLUBS_FILLED (IconName): The :code:`clubs-filled` icon. CODE (IconName): The :code:`code` icon. CODE_ASTERIX (IconName): The :code:`code-asterix` icon. CODE_CIRCLE (IconName): The :code:`code-circle` icon. CODE_CIRCLE_2 (IconName): The :code:`code-circle-2` icon. CODE_DOTS (IconName): The :code:`code-dots` icon. CODE_MINUS (IconName): The :code:`code-minus` icon. CODE_OFF (IconName): The :code:`code-off` icon. CODE_PLUS (IconName): The :code:`code-plus` icon. COFFEE (IconName): The :code:`coffee` icon. COFFEE_OFF (IconName): The :code:`coffee-off` icon. COFFIN (IconName): The :code:`coffin` icon. COIN (IconName): The :code:`coin` icon. COIN_BITCOIN (IconName): The :code:`coin-bitcoin` icon. COIN_EURO (IconName): The :code:`coin-euro` icon. COIN_MONERO (IconName): The :code:`coin-monero` icon. COIN_OFF (IconName): The :code:`coin-off` icon. COIN_POUND (IconName): The :code:`coin-pound` icon. COIN_RUPEE (IconName): The :code:`coin-rupee` icon. COIN_YEN (IconName): The :code:`coin-yen` icon. COIN_YUAN (IconName): The :code:`coin-yuan` icon. COINS (IconName): The :code:`coins` icon. COLOR_FILTER (IconName): The :code:`color-filter` icon. COLOR_PICKER (IconName): The :code:`color-picker` icon. COLOR_PICKER_OFF (IconName): The :code:`color-picker-off` icon. COLOR_SWATCH (IconName): The :code:`color-swatch` icon. COLOR_SWATCH_OFF (IconName): The :code:`color-swatch-off` icon. COLUMN_INSERT_LEFT (IconName): The :code:`column-insert-left` icon. COLUMN_INSERT_RIGHT (IconName): The :code:`column-insert-right` icon. COLUMN_REMOVE (IconName): The :code:`column-remove` icon. COLUMNS (IconName): The :code:`columns` icon. COLUMNS_1 (IconName): The :code:`columns-1` icon. COLUMNS_2 (IconName): The :code:`columns-2` icon. COLUMNS_3 (IconName): The :code:`columns-3` icon. COLUMNS_OFF (IconName): The :code:`columns-off` icon. COMET (IconName): The :code:`comet` icon. COMMAND (IconName): The :code:`command` icon. COMMAND_OFF (IconName): The :code:`command-off` icon. COMPASS (IconName): The :code:`compass` icon. COMPASS_OFF (IconName): The :code:`compass-off` icon. COMPONENTS (IconName): The :code:`components` icon. COMPONENTS_OFF (IconName): The :code:`components-off` icon. CONE (IconName): The :code:`cone` icon. CONE_2 (IconName): The :code:`cone-2` icon. 
CONE_OFF (IconName): The :code:`cone-off` icon. CONE_PLUS (IconName): The :code:`cone-plus` icon. CONFETTI (IconName): The :code:`confetti` icon. CONFETTI_OFF (IconName): The :code:`confetti-off` icon. CONFUCIUS (IconName): The :code:`confucius` icon. CONTAINER (IconName): The :code:`container` icon. CONTAINER_OFF (IconName): The :code:`container-off` icon. CONTRAST (IconName): The :code:`contrast` icon. CONTRAST_2 (IconName): The :code:`contrast-2` icon. CONTRAST_2_OFF (IconName): The :code:`contrast-2-off` icon. CONTRAST_OFF (IconName): The :code:`contrast-off` icon. COOKER (IconName): The :code:`cooker` icon. COOKIE (IconName): The :code:`cookie` icon. COOKIE_MAN (IconName): The :code:`cookie-man` icon. COOKIE_OFF (IconName): The :code:`cookie-off` icon. COPY (IconName): The :code:`copy` icon. COPY_OFF (IconName): The :code:`copy-off` icon. COPYLEFT (IconName): The :code:`copyleft` icon. COPYLEFT_FILLED (IconName): The :code:`copyleft-filled` icon. COPYLEFT_OFF (IconName): The :code:`copyleft-off` icon. COPYRIGHT (IconName): The :code:`copyright` icon. COPYRIGHT_FILLED (IconName): The :code:`copyright-filled` icon. COPYRIGHT_OFF (IconName): The :code:`copyright-off` icon. CORNER_DOWN_LEFT (IconName): The :code:`corner-down-left` icon. CORNER_DOWN_LEFT_DOUBLE (IconName): The :code:`corner-down-left-double` icon. CORNER_DOWN_RIGHT (IconName): The :code:`corner-down-right` icon. CORNER_DOWN_RIGHT_DOUBLE (IconName): The :code:`corner-down-right-double` icon. CORNER_LEFT_DOWN (IconName): The :code:`corner-left-down` icon. CORNER_LEFT_DOWN_DOUBLE (IconName): The :code:`corner-left-down-double` icon. CORNER_LEFT_UP (IconName): The :code:`corner-left-up` icon. CORNER_LEFT_UP_DOUBLE (IconName): The :code:`corner-left-up-double` icon. CORNER_RIGHT_DOWN (IconName): The :code:`corner-right-down` icon. CORNER_RIGHT_DOWN_DOUBLE (IconName): The :code:`corner-right-down-double` icon. CORNER_RIGHT_UP (IconName): The :code:`corner-right-up` icon. CORNER_RIGHT_UP_DOUBLE (IconName): The :code:`corner-right-up-double` icon. CORNER_UP_LEFT (IconName): The :code:`corner-up-left` icon. CORNER_UP_LEFT_DOUBLE (IconName): The :code:`corner-up-left-double` icon. CORNER_UP_RIGHT (IconName): The :code:`corner-up-right` icon. CORNER_UP_RIGHT_DOUBLE (IconName): The :code:`corner-up-right-double` icon. CPU (IconName): The :code:`cpu` icon. CPU_2 (IconName): The :code:`cpu-2` icon. CPU_OFF (IconName): The :code:`cpu-off` icon. CRANE (IconName): The :code:`crane` icon. CRANE_OFF (IconName): The :code:`crane-off` icon. CREATIVE_COMMONS (IconName): The :code:`creative-commons` icon. CREATIVE_COMMONS_BY (IconName): The :code:`creative-commons-by` icon. CREATIVE_COMMONS_NC (IconName): The :code:`creative-commons-nc` icon. CREATIVE_COMMONS_ND (IconName): The :code:`creative-commons-nd` icon. CREATIVE_COMMONS_OFF (IconName): The :code:`creative-commons-off` icon. CREATIVE_COMMONS_SA (IconName): The :code:`creative-commons-sa` icon. CREATIVE_COMMONS_ZERO (IconName): The :code:`creative-commons-zero` icon. CREDIT_CARD (IconName): The :code:`credit-card` icon. CREDIT_CARD_OFF (IconName): The :code:`credit-card-off` icon. CRICKET (IconName): The :code:`cricket` icon. CROP (IconName): The :code:`crop` icon. CROSS (IconName): The :code:`cross` icon. CROSS_FILLED (IconName): The :code:`cross-filled` icon. CROSS_OFF (IconName): The :code:`cross-off` icon. CROSSHAIR (IconName): The :code:`crosshair` icon. CROWN (IconName): The :code:`crown` icon. CROWN_OFF (IconName): The :code:`crown-off` icon. 
CRUTCHES (IconName): The :code:`crutches` icon. CRUTCHES_OFF (IconName): The :code:`crutches-off` icon. CRYSTAL_BALL (IconName): The :code:`crystal-ball` icon. CSV (IconName): The :code:`csv` icon. CUBE (IconName): The :code:`cube` icon. CUBE_OFF (IconName): The :code:`cube-off` icon. CUBE_PLUS (IconName): The :code:`cube-plus` icon. CUBE_SEND (IconName): The :code:`cube-send` icon. CUBE_UNFOLDED (IconName): The :code:`cube-unfolded` icon. CUP (IconName): The :code:`cup` icon. CUP_OFF (IconName): The :code:`cup-off` icon. CURLING (IconName): The :code:`curling` icon. CURLY_LOOP (IconName): The :code:`curly-loop` icon. CURRENCY (IconName): The :code:`currency` icon. CURRENCY_AFGHANI (IconName): The :code:`currency-afghani` icon. CURRENCY_BAHRAINI (IconName): The :code:`currency-bahraini` icon. CURRENCY_BAHT (IconName): The :code:`currency-baht` icon. CURRENCY_BITCOIN (IconName): The :code:`currency-bitcoin` icon. CURRENCY_CENT (IconName): The :code:`currency-cent` icon. CURRENCY_DINAR (IconName): The :code:`currency-dinar` icon. CURRENCY_DIRHAM (IconName): The :code:`currency-dirham` icon. CURRENCY_DOGECOIN (IconName): The :code:`currency-dogecoin` icon. CURRENCY_DOLLAR (IconName): The :code:`currency-dollar` icon. CURRENCY_DOLLAR_AUSTRALIAN (IconName): The :code:`currency-dollar-australian` icon. CURRENCY_DOLLAR_BRUNEI (IconName): The :code:`currency-dollar-brunei` icon. CURRENCY_DOLLAR_CANADIAN (IconName): The :code:`currency-dollar-canadian` icon. CURRENCY_DOLLAR_GUYANESE (IconName): The :code:`currency-dollar-guyanese` icon. CURRENCY_DOLLAR_OFF (IconName): The :code:`currency-dollar-off` icon. CURRENCY_DOLLAR_SINGAPORE (IconName): The :code:`currency-dollar-singapore` icon. CURRENCY_DOLLAR_ZIMBABWEAN (IconName): The :code:`currency-dollar-zimbabwean` icon. CURRENCY_DONG (IconName): The :code:`currency-dong` icon. CURRENCY_DRAM (IconName): The :code:`currency-dram` icon. CURRENCY_ETHEREUM (IconName): The :code:`currency-ethereum` icon. CURRENCY_EURO (IconName): The :code:`currency-euro` icon. CURRENCY_EURO_OFF (IconName): The :code:`currency-euro-off` icon. CURRENCY_FLORIN (IconName): The :code:`currency-florin` icon. CURRENCY_FORINT (IconName): The :code:`currency-forint` icon. CURRENCY_FRANK (IconName): The :code:`currency-frank` icon. CURRENCY_GUARANI (IconName): The :code:`currency-guarani` icon. CURRENCY_HRYVNIA (IconName): The :code:`currency-hryvnia` icon. CURRENCY_IRANIAN_RIAL (IconName): The :code:`currency-iranian-rial` icon. CURRENCY_KIP (IconName): The :code:`currency-kip` icon. CURRENCY_KRONE_CZECH (IconName): The :code:`currency-krone-czech` icon. CURRENCY_KRONE_DANISH (IconName): The :code:`currency-krone-danish` icon. CURRENCY_KRONE_SWEDISH (IconName): The :code:`currency-krone-swedish` icon. CURRENCY_LARI (IconName): The :code:`currency-lari` icon. CURRENCY_LEU (IconName): The :code:`currency-leu` icon. CURRENCY_LIRA (IconName): The :code:`currency-lira` icon. CURRENCY_LITECOIN (IconName): The :code:`currency-litecoin` icon. CURRENCY_LYD (IconName): The :code:`currency-lyd` icon. CURRENCY_MANAT (IconName): The :code:`currency-manat` icon. CURRENCY_MONERO (IconName): The :code:`currency-monero` icon. CURRENCY_NAIRA (IconName): The :code:`currency-naira` icon. CURRENCY_NANO (IconName): The :code:`currency-nano` icon. CURRENCY_OFF (IconName): The :code:`currency-off` icon. CURRENCY_PAANGA (IconName): The :code:`currency-paanga` icon. CURRENCY_PESO (IconName): The :code:`currency-peso` icon. CURRENCY_POUND (IconName): The :code:`currency-pound` icon. 
CURRENCY_POUND_OFF (IconName): The :code:`currency-pound-off` icon. CURRENCY_QUETZAL (IconName): The :code:`currency-quetzal` icon. CURRENCY_REAL (IconName): The :code:`currency-real` icon. CURRENCY_RENMINBI (IconName): The :code:`currency-renminbi` icon. CURRENCY_RIPPLE (IconName): The :code:`currency-ripple` icon. CURRENCY_RIYAL (IconName): The :code:`currency-riyal` icon. CURRENCY_RUBEL (IconName): The :code:`currency-rubel` icon. CURRENCY_RUFIYAA (IconName): The :code:`currency-rufiyaa` icon. CURRENCY_RUPEE (IconName): The :code:`currency-rupee` icon. CURRENCY_RUPEE_NEPALESE (IconName): The :code:`currency-rupee-nepalese` icon. CURRENCY_SHEKEL (IconName): The :code:`currency-shekel` icon. CURRENCY_SOLANA (IconName): The :code:`currency-solana` icon. CURRENCY_SOM (IconName): The :code:`currency-som` icon. CURRENCY_TAKA (IconName): The :code:`currency-taka` icon. CURRENCY_TENGE (IconName): The :code:`currency-tenge` icon. CURRENCY_TUGRIK (IconName): The :code:`currency-tugrik` icon. CURRENCY_WON (IconName): The :code:`currency-won` icon. CURRENCY_YEN (IconName): The :code:`currency-yen` icon. CURRENCY_YEN_OFF (IconName): The :code:`currency-yen-off` icon. CURRENCY_YUAN (IconName): The :code:`currency-yuan` icon. CURRENCY_ZLOTY (IconName): The :code:`currency-zloty` icon. CURRENT_LOCATION (IconName): The :code:`current-location` icon. CURRENT_LOCATION_OFF (IconName): The :code:`current-location-off` icon. CURSOR_OFF (IconName): The :code:`cursor-off` icon. CURSOR_TEXT (IconName): The :code:`cursor-text` icon. CUT (IconName): The :code:`cut` icon. CYLINDER (IconName): The :code:`cylinder` icon. CYLINDER_OFF (IconName): The :code:`cylinder-off` icon. CYLINDER_PLUS (IconName): The :code:`cylinder-plus` icon. DASHBOARD (IconName): The :code:`dashboard` icon. DASHBOARD_OFF (IconName): The :code:`dashboard-off` icon. DATABASE (IconName): The :code:`database` icon. DATABASE_COG (IconName): The :code:`database-cog` icon. DATABASE_DOLLAR (IconName): The :code:`database-dollar` icon. DATABASE_EDIT (IconName): The :code:`database-edit` icon. DATABASE_EXCLAMATION (IconName): The :code:`database-exclamation` icon. DATABASE_EXPORT (IconName): The :code:`database-export` icon. DATABASE_HEART (IconName): The :code:`database-heart` icon. DATABASE_IMPORT (IconName): The :code:`database-import` icon. DATABASE_LEAK (IconName): The :code:`database-leak` icon. DATABASE_MINUS (IconName): The :code:`database-minus` icon. DATABASE_OFF (IconName): The :code:`database-off` icon. DATABASE_PLUS (IconName): The :code:`database-plus` icon. DATABASE_SEARCH (IconName): The :code:`database-search` icon. DATABASE_SHARE (IconName): The :code:`database-share` icon. DATABASE_STAR (IconName): The :code:`database-star` icon. DATABASE_X (IconName): The :code:`database-x` icon. DECIMAL (IconName): The :code:`decimal` icon. DEER (IconName): The :code:`deer` icon. DELTA (IconName): The :code:`delta` icon. DENTAL (IconName): The :code:`dental` icon. DENTAL_BROKEN (IconName): The :code:`dental-broken` icon. DENTAL_OFF (IconName): The :code:`dental-off` icon. DESELECT (IconName): The :code:`deselect` icon. DETAILS (IconName): The :code:`details` icon. DETAILS_OFF (IconName): The :code:`details-off` icon. DEVICE_AIRPODS (IconName): The :code:`device-airpods` icon. DEVICE_AIRPODS_CASE (IconName): The :code:`device-airpods-case` icon. DEVICE_AIRTAG (IconName): The :code:`device-airtag` icon. DEVICE_ANALYTICS (IconName): The :code:`device-analytics` icon. DEVICE_AUDIO_TAPE (IconName): The :code:`device-audio-tape` icon. 
DEVICE_CAMERA_PHONE (IconName): The :code:`device-camera-phone` icon. DEVICE_CCTV (IconName): The :code:`device-cctv` icon. DEVICE_CCTV_OFF (IconName): The :code:`device-cctv-off` icon. DEVICE_COMPUTER_CAMERA (IconName): The :code:`device-computer-camera` icon. DEVICE_COMPUTER_CAMERA_OFF (IconName): The :code:`device-computer-camera-off` icon. DEVICE_DESKTOP (IconName): The :code:`device-desktop` icon. DEVICE_DESKTOP_ANALYTICS (IconName): The :code:`device-desktop-analytics` icon. DEVICE_DESKTOP_BOLT (IconName): The :code:`device-desktop-bolt` icon. DEVICE_DESKTOP_CANCEL (IconName): The :code:`device-desktop-cancel` icon. DEVICE_DESKTOP_CHECK (IconName): The :code:`device-desktop-check` icon. DEVICE_DESKTOP_CODE (IconName): The :code:`device-desktop-code` icon. DEVICE_DESKTOP_COG (IconName): The :code:`device-desktop-cog` icon. DEVICE_DESKTOP_DOLLAR (IconName): The :code:`device-desktop-dollar` icon. DEVICE_DESKTOP_DOWN (IconName): The :code:`device-desktop-down` icon. DEVICE_DESKTOP_EXCLAMATION (IconName): The :code:`device-desktop-exclamation` icon. DEVICE_DESKTOP_HEART (IconName): The :code:`device-desktop-heart` icon. DEVICE_DESKTOP_MINUS (IconName): The :code:`device-desktop-minus` icon. DEVICE_DESKTOP_OFF (IconName): The :code:`device-desktop-off` icon. DEVICE_DESKTOP_PAUSE (IconName): The :code:`device-desktop-pause` icon. DEVICE_DESKTOP_PIN (IconName): The :code:`device-desktop-pin` icon. DEVICE_DESKTOP_PLUS (IconName): The :code:`device-desktop-plus` icon. DEVICE_DESKTOP_QUESTION (IconName): The :code:`device-desktop-question` icon. DEVICE_DESKTOP_SEARCH (IconName): The :code:`device-desktop-search` icon. DEVICE_DESKTOP_SHARE (IconName): The :code:`device-desktop-share` icon. DEVICE_DESKTOP_STAR (IconName): The :code:`device-desktop-star` icon. DEVICE_DESKTOP_UP (IconName): The :code:`device-desktop-up` icon. DEVICE_DESKTOP_X (IconName): The :code:`device-desktop-x` icon. DEVICE_FLOPPY (IconName): The :code:`device-floppy` icon. DEVICE_GAMEPAD (IconName): The :code:`device-gamepad` icon. DEVICE_GAMEPAD_2 (IconName): The :code:`device-gamepad-2` icon. DEVICE_HEART_MONITOR (IconName): The :code:`device-heart-monitor` icon. DEVICE_HEART_MONITOR_FILLED (IconName): The :code:`device-heart-monitor-filled` icon. DEVICE_IMAC (IconName): The :code:`device-imac` icon. DEVICE_IMAC_BOLT (IconName): The :code:`device-imac-bolt` icon. DEVICE_IMAC_CANCEL (IconName): The :code:`device-imac-cancel` icon. DEVICE_IMAC_CHECK (IconName): The :code:`device-imac-check` icon. DEVICE_IMAC_CODE (IconName): The :code:`device-imac-code` icon. DEVICE_IMAC_COG (IconName): The :code:`device-imac-cog` icon. DEVICE_IMAC_DOLLAR (IconName): The :code:`device-imac-dollar` icon. DEVICE_IMAC_DOWN (IconName): The :code:`device-imac-down` icon. DEVICE_IMAC_EXCLAMATION (IconName): The :code:`device-imac-exclamation` icon. DEVICE_IMAC_HEART (IconName): The :code:`device-imac-heart` icon. DEVICE_IMAC_MINUS (IconName): The :code:`device-imac-minus` icon. DEVICE_IMAC_OFF (IconName): The :code:`device-imac-off` icon. DEVICE_IMAC_PAUSE (IconName): The :code:`device-imac-pause` icon. DEVICE_IMAC_PIN (IconName): The :code:`device-imac-pin` icon. DEVICE_IMAC_PLUS (IconName): The :code:`device-imac-plus` icon. DEVICE_IMAC_QUESTION (IconName): The :code:`device-imac-question` icon. DEVICE_IMAC_SEARCH (IconName): The :code:`device-imac-search` icon. DEVICE_IMAC_SHARE (IconName): The :code:`device-imac-share` icon. DEVICE_IMAC_STAR (IconName): The :code:`device-imac-star` icon. 
DEVICE_IMAC_UP (IconName): The :code:`device-imac-up` icon. DEVICE_IMAC_X (IconName): The :code:`device-imac-x` icon. DEVICE_IPAD (IconName): The :code:`device-ipad` icon. DEVICE_IPAD_BOLT (IconName): The :code:`device-ipad-bolt` icon. DEVICE_IPAD_CANCEL (IconName): The :code:`device-ipad-cancel` icon. DEVICE_IPAD_CHECK (IconName): The :code:`device-ipad-check` icon. DEVICE_IPAD_CODE (IconName): The :code:`device-ipad-code` icon. DEVICE_IPAD_COG (IconName): The :code:`device-ipad-cog` icon. DEVICE_IPAD_DOLLAR (IconName): The :code:`device-ipad-dollar` icon. DEVICE_IPAD_DOWN (IconName): The :code:`device-ipad-down` icon. DEVICE_IPAD_EXCLAMATION (IconName): The :code:`device-ipad-exclamation` icon. DEVICE_IPAD_HEART (IconName): The :code:`device-ipad-heart` icon. DEVICE_IPAD_HORIZONTAL (IconName): The :code:`device-ipad-horizontal` icon. DEVICE_IPAD_HORIZONTAL_BOLT (IconName): The :code:`device-ipad-horizontal-bolt` icon. DEVICE_IPAD_HORIZONTAL_CANCEL (IconName): The :code:`device-ipad-horizontal-cancel` icon. DEVICE_IPAD_HORIZONTAL_CHECK (IconName): The :code:`device-ipad-horizontal-check` icon. DEVICE_IPAD_HORIZONTAL_CODE (IconName): The :code:`device-ipad-horizontal-code` icon. DEVICE_IPAD_HORIZONTAL_COG (IconName): The :code:`device-ipad-horizontal-cog` icon. DEVICE_IPAD_HORIZONTAL_DOLLAR (IconName): The :code:`device-ipad-horizontal-dollar` icon. DEVICE_IPAD_HORIZONTAL_DOWN (IconName): The :code:`device-ipad-horizontal-down` icon. DEVICE_IPAD_HORIZONTAL_EXCLAMATION (IconName): The :code:`device-ipad-horizontal-exclamation` icon. DEVICE_IPAD_HORIZONTAL_HEART (IconName): The :code:`device-ipad-horizontal-heart` icon. DEVICE_IPAD_HORIZONTAL_MINUS (IconName): The :code:`device-ipad-horizontal-minus` icon. DEVICE_IPAD_HORIZONTAL_OFF (IconName): The :code:`device-ipad-horizontal-off` icon. DEVICE_IPAD_HORIZONTAL_PAUSE (IconName): The :code:`device-ipad-horizontal-pause` icon. DEVICE_IPAD_HORIZONTAL_PIN (IconName): The :code:`device-ipad-horizontal-pin` icon. DEVICE_IPAD_HORIZONTAL_PLUS (IconName): The :code:`device-ipad-horizontal-plus` icon. DEVICE_IPAD_HORIZONTAL_QUESTION (IconName): The :code:`device-ipad-horizontal-question` icon. DEVICE_IPAD_HORIZONTAL_SEARCH (IconName): The :code:`device-ipad-horizontal-search` icon. DEVICE_IPAD_HORIZONTAL_SHARE (IconName): The :code:`device-ipad-horizontal-share` icon. DEVICE_IPAD_HORIZONTAL_STAR (IconName): The :code:`device-ipad-horizontal-star` icon. DEVICE_IPAD_HORIZONTAL_UP (IconName): The :code:`device-ipad-horizontal-up` icon. DEVICE_IPAD_HORIZONTAL_X (IconName): The :code:`device-ipad-horizontal-x` icon. DEVICE_IPAD_MINUS (IconName): The :code:`device-ipad-minus` icon. DEVICE_IPAD_OFF (IconName): The :code:`device-ipad-off` icon. DEVICE_IPAD_PAUSE (IconName): The :code:`device-ipad-pause` icon. DEVICE_IPAD_PIN (IconName): The :code:`device-ipad-pin` icon. DEVICE_IPAD_PLUS (IconName): The :code:`device-ipad-plus` icon. DEVICE_IPAD_QUESTION (IconName): The :code:`device-ipad-question` icon. DEVICE_IPAD_SEARCH (IconName): The :code:`device-ipad-search` icon. DEVICE_IPAD_SHARE (IconName): The :code:`device-ipad-share` icon. DEVICE_IPAD_STAR (IconName): The :code:`device-ipad-star` icon. DEVICE_IPAD_UP (IconName): The :code:`device-ipad-up` icon. DEVICE_IPAD_X (IconName): The :code:`device-ipad-x` icon. DEVICE_LANDLINE_PHONE (IconName): The :code:`device-landline-phone` icon. DEVICE_LAPTOP (IconName): The :code:`device-laptop` icon. DEVICE_LAPTOP_OFF (IconName): The :code:`device-laptop-off` icon. 
DEVICE_MOBILE (IconName): The :code:`device-mobile` icon. DEVICE_MOBILE_BOLT (IconName): The :code:`device-mobile-bolt` icon. DEVICE_MOBILE_CANCEL (IconName): The :code:`device-mobile-cancel` icon. DEVICE_MOBILE_CHARGING (IconName): The :code:`device-mobile-charging` icon. DEVICE_MOBILE_CHECK (IconName): The :code:`device-mobile-check` icon. DEVICE_MOBILE_CODE (IconName): The :code:`device-mobile-code` icon. DEVICE_MOBILE_COG (IconName): The :code:`device-mobile-cog` icon. DEVICE_MOBILE_DOLLAR (IconName): The :code:`device-mobile-dollar` icon. DEVICE_MOBILE_DOWN (IconName): The :code:`device-mobile-down` icon. DEVICE_MOBILE_EXCLAMATION (IconName): The :code:`device-mobile-exclamation` icon. DEVICE_MOBILE_FILLED (IconName): The :code:`device-mobile-filled` icon. DEVICE_MOBILE_HEART (IconName): The :code:`device-mobile-heart` icon. DEVICE_MOBILE_MESSAGE (IconName): The :code:`device-mobile-message` icon. DEVICE_MOBILE_MINUS (IconName): The :code:`device-mobile-minus` icon. DEVICE_MOBILE_OFF (IconName): The :code:`device-mobile-off` icon. DEVICE_MOBILE_PAUSE (IconName): The :code:`device-mobile-pause` icon. DEVICE_MOBILE_PIN (IconName): The :code:`device-mobile-pin` icon. DEVICE_MOBILE_PLUS (IconName): The :code:`device-mobile-plus` icon. DEVICE_MOBILE_QUESTION (IconName): The :code:`device-mobile-question` icon. DEVICE_MOBILE_ROTATED (IconName): The :code:`device-mobile-rotated` icon. DEVICE_MOBILE_SEARCH (IconName): The :code:`device-mobile-search` icon. DEVICE_MOBILE_SHARE (IconName): The :code:`device-mobile-share` icon. DEVICE_MOBILE_STAR (IconName): The :code:`device-mobile-star` icon. DEVICE_MOBILE_UP (IconName): The :code:`device-mobile-up` icon. DEVICE_MOBILE_VIBRATION (IconName): The :code:`device-mobile-vibration` icon. DEVICE_MOBILE_X (IconName): The :code:`device-mobile-x` icon. DEVICE_NINTENDO (IconName): The :code:`device-nintendo` icon. DEVICE_NINTENDO_OFF (IconName): The :code:`device-nintendo-off` icon. DEVICE_REMOTE (IconName): The :code:`device-remote` icon. DEVICE_SD_CARD (IconName): The :code:`device-sd-card` icon. DEVICE_SIM (IconName): The :code:`device-sim` icon. DEVICE_SIM_1 (IconName): The :code:`device-sim-1` icon. DEVICE_SIM_2 (IconName): The :code:`device-sim-2` icon. DEVICE_SIM_3 (IconName): The :code:`device-sim-3` icon. DEVICE_SPEAKER (IconName): The :code:`device-speaker` icon. DEVICE_SPEAKER_OFF (IconName): The :code:`device-speaker-off` icon. DEVICE_TABLET (IconName): The :code:`device-tablet` icon. DEVICE_TABLET_BOLT (IconName): The :code:`device-tablet-bolt` icon. DEVICE_TABLET_CANCEL (IconName): The :code:`device-tablet-cancel` icon. DEVICE_TABLET_CHECK (IconName): The :code:`device-tablet-check` icon. DEVICE_TABLET_CODE (IconName): The :code:`device-tablet-code` icon. DEVICE_TABLET_COG (IconName): The :code:`device-tablet-cog` icon. DEVICE_TABLET_DOLLAR (IconName): The :code:`device-tablet-dollar` icon. DEVICE_TABLET_DOWN (IconName): The :code:`device-tablet-down` icon. DEVICE_TABLET_EXCLAMATION (IconName): The :code:`device-tablet-exclamation` icon. DEVICE_TABLET_FILLED (IconName): The :code:`device-tablet-filled` icon. DEVICE_TABLET_HEART (IconName): The :code:`device-tablet-heart` icon. DEVICE_TABLET_MINUS (IconName): The :code:`device-tablet-minus` icon. DEVICE_TABLET_OFF (IconName): The :code:`device-tablet-off` icon. DEVICE_TABLET_PAUSE (IconName): The :code:`device-tablet-pause` icon. DEVICE_TABLET_PIN (IconName): The :code:`device-tablet-pin` icon. DEVICE_TABLET_PLUS (IconName): The :code:`device-tablet-plus` icon. 
DEVICE_TABLET_QUESTION (IconName): The :code:`device-tablet-question` icon. DEVICE_TABLET_SEARCH (IconName): The :code:`device-tablet-search` icon. DEVICE_TABLET_SHARE (IconName): The :code:`device-tablet-share` icon. DEVICE_TABLET_STAR (IconName): The :code:`device-tablet-star` icon. DEVICE_TABLET_UP (IconName): The :code:`device-tablet-up` icon. DEVICE_TABLET_X (IconName): The :code:`device-tablet-x` icon. DEVICE_TV (IconName): The :code:`device-tv` icon. DEVICE_TV_OFF (IconName): The :code:`device-tv-off` icon. DEVICE_TV_OLD (IconName): The :code:`device-tv-old` icon. DEVICE_VISION_PRO (IconName): The :code:`device-vision-pro` icon. DEVICE_WATCH (IconName): The :code:`device-watch` icon. DEVICE_WATCH_BOLT (IconName): The :code:`device-watch-bolt` icon. DEVICE_WATCH_CANCEL (IconName): The :code:`device-watch-cancel` icon. DEVICE_WATCH_CHECK (IconName): The :code:`device-watch-check` icon. DEVICE_WATCH_CODE (IconName): The :code:`device-watch-code` icon. DEVICE_WATCH_COG (IconName): The :code:`device-watch-cog` icon. DEVICE_WATCH_DOLLAR (IconName): The :code:`device-watch-dollar` icon. DEVICE_WATCH_DOWN (IconName): The :code:`device-watch-down` icon. DEVICE_WATCH_EXCLAMATION (IconName): The :code:`device-watch-exclamation` icon. DEVICE_WATCH_HEART (IconName): The :code:`device-watch-heart` icon. DEVICE_WATCH_MINUS (IconName): The :code:`device-watch-minus` icon. DEVICE_WATCH_OFF (IconName): The :code:`device-watch-off` icon. DEVICE_WATCH_PAUSE (IconName): The :code:`device-watch-pause` icon. DEVICE_WATCH_PIN (IconName): The :code:`device-watch-pin` icon. DEVICE_WATCH_PLUS (IconName): The :code:`device-watch-plus` icon. DEVICE_WATCH_QUESTION (IconName): The :code:`device-watch-question` icon. DEVICE_WATCH_SEARCH (IconName): The :code:`device-watch-search` icon. DEVICE_WATCH_SHARE (IconName): The :code:`device-watch-share` icon. DEVICE_WATCH_STAR (IconName): The :code:`device-watch-star` icon. DEVICE_WATCH_STATS (IconName): The :code:`device-watch-stats` icon. DEVICE_WATCH_STATS_2 (IconName): The :code:`device-watch-stats-2` icon. DEVICE_WATCH_UP (IconName): The :code:`device-watch-up` icon. DEVICE_WATCH_X (IconName): The :code:`device-watch-x` icon. DEVICES (IconName): The :code:`devices` icon. DEVICES_2 (IconName): The :code:`devices-2` icon. DEVICES_BOLT (IconName): The :code:`devices-bolt` icon. DEVICES_CANCEL (IconName): The :code:`devices-cancel` icon. DEVICES_CHECK (IconName): The :code:`devices-check` icon. DEVICES_CODE (IconName): The :code:`devices-code` icon. DEVICES_COG (IconName): The :code:`devices-cog` icon. DEVICES_DOLLAR (IconName): The :code:`devices-dollar` icon. DEVICES_DOWN (IconName): The :code:`devices-down` icon. DEVICES_EXCLAMATION (IconName): The :code:`devices-exclamation` icon. DEVICES_HEART (IconName): The :code:`devices-heart` icon. DEVICES_MINUS (IconName): The :code:`devices-minus` icon. DEVICES_OFF (IconName): The :code:`devices-off` icon. DEVICES_PAUSE (IconName): The :code:`devices-pause` icon. DEVICES_PC (IconName): The :code:`devices-pc` icon. DEVICES_PC_OFF (IconName): The :code:`devices-pc-off` icon. DEVICES_PIN (IconName): The :code:`devices-pin` icon. DEVICES_PLUS (IconName): The :code:`devices-plus` icon. DEVICES_QUESTION (IconName): The :code:`devices-question` icon. DEVICES_SEARCH (IconName): The :code:`devices-search` icon. DEVICES_SHARE (IconName): The :code:`devices-share` icon. DEVICES_STAR (IconName): The :code:`devices-star` icon. DEVICES_UP (IconName): The :code:`devices-up` icon. DEVICES_X (IconName): The :code:`devices-x` icon. 
DIABOLO (IconName): The :code:`diabolo` icon. DIABOLO_OFF (IconName): The :code:`diabolo-off` icon. DIABOLO_PLUS (IconName): The :code:`diabolo-plus` icon. DIALPAD (IconName): The :code:`dialpad` icon. DIALPAD_FILLED (IconName): The :code:`dialpad-filled` icon. DIALPAD_OFF (IconName): The :code:`dialpad-off` icon. DIAMOND (IconName): The :code:`diamond` icon. DIAMOND_FILLED (IconName): The :code:`diamond-filled` icon. DIAMOND_OFF (IconName): The :code:`diamond-off` icon. DIAMONDS (IconName): The :code:`diamonds` icon. DIAMONDS_FILLED (IconName): The :code:`diamonds-filled` icon. DICE (IconName): The :code:`dice` icon. DICE_1 (IconName): The :code:`dice-1` icon. DICE_1_FILLED (IconName): The :code:`dice-1-filled` icon. DICE_2 (IconName): The :code:`dice-2` icon. DICE_2_FILLED (IconName): The :code:`dice-2-filled` icon. DICE_3 (IconName): The :code:`dice-3` icon. DICE_3_FILLED (IconName): The :code:`dice-3-filled` icon. DICE_4 (IconName): The :code:`dice-4` icon. DICE_4_FILLED (IconName): The :code:`dice-4-filled` icon. DICE_5 (IconName): The :code:`dice-5` icon. DICE_5_FILLED (IconName): The :code:`dice-5-filled` icon. DICE_6 (IconName): The :code:`dice-6` icon. DICE_6_FILLED (IconName): The :code:`dice-6-filled` icon. DICE_FILLED (IconName): The :code:`dice-filled` icon. DIMENSIONS (IconName): The :code:`dimensions` icon. DIRECTION (IconName): The :code:`direction` icon. DIRECTION_HORIZONTAL (IconName): The :code:`direction-horizontal` icon. DIRECTION_SIGN (IconName): The :code:`direction-sign` icon. DIRECTION_SIGN_FILLED (IconName): The :code:`direction-sign-filled` icon. DIRECTION_SIGN_OFF (IconName): The :code:`direction-sign-off` icon. DIRECTIONS (IconName): The :code:`directions` icon. DIRECTIONS_OFF (IconName): The :code:`directions-off` icon. DISABLED (IconName): The :code:`disabled` icon. DISABLED_2 (IconName): The :code:`disabled-2` icon. DISABLED_OFF (IconName): The :code:`disabled-off` icon. DISC (IconName): The :code:`disc` icon. DISC_GOLF (IconName): The :code:`disc-golf` icon. DISC_OFF (IconName): The :code:`disc-off` icon. DISCOUNT (IconName): The :code:`discount` icon. DISCOUNT_2 (IconName): The :code:`discount-2` icon. DISCOUNT_2_OFF (IconName): The :code:`discount-2-off` icon. DISCOUNT_CHECK (IconName): The :code:`discount-check` icon. DISCOUNT_CHECK_FILLED (IconName): The :code:`discount-check-filled` icon. DISCOUNT_OFF (IconName): The :code:`discount-off` icon. DIVIDE (IconName): The :code:`divide` icon. DNA (IconName): The :code:`dna` icon. DNA_2 (IconName): The :code:`dna-2` icon. DNA_2_OFF (IconName): The :code:`dna-2-off` icon. DNA_OFF (IconName): The :code:`dna-off` icon. DOG (IconName): The :code:`dog` icon. DOG_BOWL (IconName): The :code:`dog-bowl` icon. DOOR (IconName): The :code:`door` icon. DOOR_ENTER (IconName): The :code:`door-enter` icon. DOOR_EXIT (IconName): The :code:`door-exit` icon. DOOR_OFF (IconName): The :code:`door-off` icon. DOTS (IconName): The :code:`dots` icon. DOTS_CIRCLE_HORIZONTAL (IconName): The :code:`dots-circle-horizontal` icon. DOTS_DIAGONAL (IconName): The :code:`dots-diagonal` icon. DOTS_DIAGONAL_2 (IconName): The :code:`dots-diagonal-2` icon. DOTS_VERTICAL (IconName): The :code:`dots-vertical` icon. DOWNLOAD (IconName): The :code:`download` icon. DOWNLOAD_OFF (IconName): The :code:`download-off` icon. DRAG_DROP (IconName): The :code:`drag-drop` icon. DRAG_DROP_2 (IconName): The :code:`drag-drop-2` icon. DRONE (IconName): The :code:`drone` icon. DRONE_OFF (IconName): The :code:`drone-off` icon. 
DROP_CIRCLE (IconName): The :code:`drop-circle` icon. DROPLET (IconName): The :code:`droplet` icon. DROPLET_BOLT (IconName): The :code:`droplet-bolt` icon. DROPLET_CANCEL (IconName): The :code:`droplet-cancel` icon. DROPLET_CHECK (IconName): The :code:`droplet-check` icon. DROPLET_CODE (IconName): The :code:`droplet-code` icon. DROPLET_COG (IconName): The :code:`droplet-cog` icon. DROPLET_DOLLAR (IconName): The :code:`droplet-dollar` icon. DROPLET_DOWN (IconName): The :code:`droplet-down` icon. DROPLET_EXCLAMATION (IconName): The :code:`droplet-exclamation` icon. DROPLET_FILLED (IconName): The :code:`droplet-filled` icon. DROPLET_FILLED_2 (IconName): The :code:`droplet-filled-2` icon. DROPLET_HALF (IconName): The :code:`droplet-half` icon. DROPLET_HALF_2 (IconName): The :code:`droplet-half-2` icon. DROPLET_HALF_FILLED (IconName): The :code:`droplet-half-filled` icon. DROPLET_HEART (IconName): The :code:`droplet-heart` icon. DROPLET_MINUS (IconName): The :code:`droplet-minus` icon. DROPLET_OFF (IconName): The :code:`droplet-off` icon. DROPLET_PAUSE (IconName): The :code:`droplet-pause` icon. DROPLET_PIN (IconName): The :code:`droplet-pin` icon. DROPLET_PLUS (IconName): The :code:`droplet-plus` icon. DROPLET_QUESTION (IconName): The :code:`droplet-question` icon. DROPLET_SEARCH (IconName): The :code:`droplet-search` icon. DROPLET_SHARE (IconName): The :code:`droplet-share` icon. DROPLET_STAR (IconName): The :code:`droplet-star` icon. DROPLET_UP (IconName): The :code:`droplet-up` icon. DROPLET_X (IconName): The :code:`droplet-x` icon. DUAL_SCREEN (IconName): The :code:`dual-screen` icon. E_PASSPORT (IconName): The :code:`e-passport` icon. EAR (IconName): The :code:`ear` icon. EAR_OFF (IconName): The :code:`ear-off` icon. EASE_IN (IconName): The :code:`ease-in` icon. EASE_IN_CONTROL_POINT (IconName): The :code:`ease-in-control-point` icon. EASE_IN_OUT (IconName): The :code:`ease-in-out` icon. EASE_IN_OUT_CONTROL_POINTS (IconName): The :code:`ease-in-out-control-points` icon. EASE_OUT (IconName): The :code:`ease-out` icon. EASE_OUT_CONTROL_POINT (IconName): The :code:`ease-out-control-point` icon. EDIT (IconName): The :code:`edit` icon. EDIT_CIRCLE (IconName): The :code:`edit-circle` icon. EDIT_CIRCLE_OFF (IconName): The :code:`edit-circle-off` icon. EDIT_OFF (IconName): The :code:`edit-off` icon. EGG (IconName): The :code:`egg` icon. EGG_CRACKED (IconName): The :code:`egg-cracked` icon. EGG_FILLED (IconName): The :code:`egg-filled` icon. EGG_FRIED (IconName): The :code:`egg-fried` icon. EGG_OFF (IconName): The :code:`egg-off` icon. EGGS (IconName): The :code:`eggs` icon. ELEVATOR (IconName): The :code:`elevator` icon. ELEVATOR_OFF (IconName): The :code:`elevator-off` icon. EMERGENCY_BED (IconName): The :code:`emergency-bed` icon. EMPATHIZE (IconName): The :code:`empathize` icon. EMPATHIZE_OFF (IconName): The :code:`empathize-off` icon. EMPHASIS (IconName): The :code:`emphasis` icon. ENGINE (IconName): The :code:`engine` icon. ENGINE_OFF (IconName): The :code:`engine-off` icon. EQUAL (IconName): The :code:`equal` icon. EQUAL_DOUBLE (IconName): The :code:`equal-double` icon. EQUAL_NOT (IconName): The :code:`equal-not` icon. ERASER (IconName): The :code:`eraser` icon. ERASER_OFF (IconName): The :code:`eraser-off` icon. ERROR_404 (IconName): The :code:`error-404` icon. ERROR_404_OFF (IconName): The :code:`error-404-off` icon. EXCHANGE (IconName): The :code:`exchange` icon. EXCHANGE_OFF (IconName): The :code:`exchange-off` icon. EXCLAMATION_CIRCLE (IconName): The :code:`exclamation-circle` icon. 
EXCLAMATION_MARK (IconName): The :code:`exclamation-mark` icon. EXCLAMATION_MARK_OFF (IconName): The :code:`exclamation-mark-off` icon. EXPLICIT (IconName): The :code:`explicit` icon. EXPLICIT_OFF (IconName): The :code:`explicit-off` icon. EXPOSURE (IconName): The :code:`exposure` icon. EXPOSURE_0 (IconName): The :code:`exposure-0` icon. EXPOSURE_MINUS_1 (IconName): The :code:`exposure-minus-1` icon. EXPOSURE_MINUS_2 (IconName): The :code:`exposure-minus-2` icon. EXPOSURE_OFF (IconName): The :code:`exposure-off` icon. EXPOSURE_PLUS_1 (IconName): The :code:`exposure-plus-1` icon. EXPOSURE_PLUS_2 (IconName): The :code:`exposure-plus-2` icon. EXTERNAL_LINK (IconName): The :code:`external-link` icon. EXTERNAL_LINK_OFF (IconName): The :code:`external-link-off` icon. EYE (IconName): The :code:`eye` icon. EYE_CHECK (IconName): The :code:`eye-check` icon. EYE_CLOSED (IconName): The :code:`eye-closed` icon. EYE_COG (IconName): The :code:`eye-cog` icon. EYE_EDIT (IconName): The :code:`eye-edit` icon. EYE_EXCLAMATION (IconName): The :code:`eye-exclamation` icon. EYE_FILLED (IconName): The :code:`eye-filled` icon. EYE_HEART (IconName): The :code:`eye-heart` icon. EYE_OFF (IconName): The :code:`eye-off` icon. EYE_TABLE (IconName): The :code:`eye-table` icon. EYE_X (IconName): The :code:`eye-x` icon. EYEGLASS (IconName): The :code:`eyeglass` icon. EYEGLASS_2 (IconName): The :code:`eyeglass-2` icon. EYEGLASS_OFF (IconName): The :code:`eyeglass-off` icon. FACE_ID (IconName): The :code:`face-id` icon. FACE_ID_ERROR (IconName): The :code:`face-id-error` icon. FACE_MASK (IconName): The :code:`face-mask` icon. FACE_MASK_OFF (IconName): The :code:`face-mask-off` icon. FALL (IconName): The :code:`fall` icon. FEATHER (IconName): The :code:`feather` icon. FEATHER_OFF (IconName): The :code:`feather-off` icon. FENCE (IconName): The :code:`fence` icon. FENCE_OFF (IconName): The :code:`fence-off` icon. FIDGET_SPINNER (IconName): The :code:`fidget-spinner` icon. FILE (IconName): The :code:`file` icon. FILE_3D (IconName): The :code:`file-3d` icon. FILE_ALERT (IconName): The :code:`file-alert` icon. FILE_ANALYTICS (IconName): The :code:`file-analytics` icon. FILE_ARROW_LEFT (IconName): The :code:`file-arrow-left` icon. FILE_ARROW_RIGHT (IconName): The :code:`file-arrow-right` icon. FILE_BARCODE (IconName): The :code:`file-barcode` icon. FILE_BROKEN (IconName): The :code:`file-broken` icon. FILE_CERTIFICATE (IconName): The :code:`file-certificate` icon. FILE_CHART (IconName): The :code:`file-chart` icon. FILE_CHECK (IconName): The :code:`file-check` icon. FILE_CODE (IconName): The :code:`file-code` icon. FILE_CODE_2 (IconName): The :code:`file-code-2` icon. FILE_CV (IconName): The :code:`file-cv` icon. FILE_DATABASE (IconName): The :code:`file-database` icon. FILE_DELTA (IconName): The :code:`file-delta` icon. FILE_DESCRIPTION (IconName): The :code:`file-description` icon. FILE_DIFF (IconName): The :code:`file-diff` icon. FILE_DIGIT (IconName): The :code:`file-digit` icon. FILE_DISLIKE (IconName): The :code:`file-dislike` icon. FILE_DOLLAR (IconName): The :code:`file-dollar` icon. FILE_DOTS (IconName): The :code:`file-dots` icon. FILE_DOWNLOAD (IconName): The :code:`file-download` icon. FILE_EURO (IconName): The :code:`file-euro` icon. FILE_EXPORT (IconName): The :code:`file-export` icon. FILE_FILLED (IconName): The :code:`file-filled` icon. FILE_FUNCTION (IconName): The :code:`file-function` icon. FILE_HORIZONTAL (IconName): The :code:`file-horizontal` icon. FILE_IMPORT (IconName): The :code:`file-import` icon. 
FILE_INFINITY (IconName): The :code:`file-infinity` icon. FILE_INFO (IconName): The :code:`file-info` icon. FILE_INVOICE (IconName): The :code:`file-invoice` icon. FILE_LAMBDA (IconName): The :code:`file-lambda` icon. FILE_LIKE (IconName): The :code:`file-like` icon. FILE_MINUS (IconName): The :code:`file-minus` icon. FILE_MUSIC (IconName): The :code:`file-music` icon. FILE_OFF (IconName): The :code:`file-off` icon. FILE_ORIENTATION (IconName): The :code:`file-orientation` icon. FILE_PENCIL (IconName): The :code:`file-pencil` icon. FILE_PERCENT (IconName): The :code:`file-percent` icon. FILE_PHONE (IconName): The :code:`file-phone` icon. FILE_PLUS (IconName): The :code:`file-plus` icon. FILE_POWER (IconName): The :code:`file-power` icon. FILE_REPORT (IconName): The :code:`file-report` icon. FILE_RSS (IconName): The :code:`file-rss` icon. FILE_SCISSORS (IconName): The :code:`file-scissors` icon. FILE_SEARCH (IconName): The :code:`file-search` icon. FILE_SETTINGS (IconName): The :code:`file-settings` icon. FILE_SHREDDER (IconName): The :code:`file-shredder` icon. FILE_SIGNAL (IconName): The :code:`file-signal` icon. FILE_SPREADSHEET (IconName): The :code:`file-spreadsheet` icon. FILE_STACK (IconName): The :code:`file-stack` icon. FILE_STAR (IconName): The :code:`file-star` icon. FILE_SYMLINK (IconName): The :code:`file-symlink` icon. FILE_TEXT (IconName): The :code:`file-text` icon. FILE_TEXT_AI (IconName): The :code:`file-text-ai` icon. FILE_TIME (IconName): The :code:`file-time` icon. FILE_TYPOGRAPHY (IconName): The :code:`file-typography` icon. FILE_UNKNOWN (IconName): The :code:`file-unknown` icon. FILE_UPLOAD (IconName): The :code:`file-upload` icon. FILE_VECTOR (IconName): The :code:`file-vector` icon. FILE_X (IconName): The :code:`file-x` icon. FILE_X_FILLED (IconName): The :code:`file-x-filled` icon. FILE_ZIP (IconName): The :code:`file-zip` icon. FILES (IconName): The :code:`files` icon. FILES_OFF (IconName): The :code:`files-off` icon. FILTER (IconName): The :code:`filter` icon. FILTER_COG (IconName): The :code:`filter-cog` icon. FILTER_DOLLAR (IconName): The :code:`filter-dollar` icon. FILTER_EDIT (IconName): The :code:`filter-edit` icon. FILTER_MINUS (IconName): The :code:`filter-minus` icon. FILTER_OFF (IconName): The :code:`filter-off` icon. FILTER_PLUS (IconName): The :code:`filter-plus` icon. FILTER_STAR (IconName): The :code:`filter-star` icon. FILTER_X (IconName): The :code:`filter-x` icon. FILTERS (IconName): The :code:`filters` icon. FINGERPRINT (IconName): The :code:`fingerprint` icon. FINGERPRINT_OFF (IconName): The :code:`fingerprint-off` icon. FIRE_EXTINGUISHER (IconName): The :code:`fire-extinguisher` icon. FIRE_HYDRANT (IconName): The :code:`fire-hydrant` icon. FIRE_HYDRANT_OFF (IconName): The :code:`fire-hydrant-off` icon. FIRETRUCK (IconName): The :code:`firetruck` icon. FIRST_AID_KIT (IconName): The :code:`first-aid-kit` icon. FIRST_AID_KIT_OFF (IconName): The :code:`first-aid-kit-off` icon. FISH (IconName): The :code:`fish` icon. FISH_BONE (IconName): The :code:`fish-bone` icon. FISH_CHRISTIANITY (IconName): The :code:`fish-christianity` icon. FISH_HOOK (IconName): The :code:`fish-hook` icon. FISH_HOOK_OFF (IconName): The :code:`fish-hook-off` icon. FISH_OFF (IconName): The :code:`fish-off` icon. FLAG (IconName): The :code:`flag` icon. FLAG_2 (IconName): The :code:`flag-2` icon. FLAG_2_FILLED (IconName): The :code:`flag-2-filled` icon. FLAG_2_OFF (IconName): The :code:`flag-2-off` icon. FLAG_3 (IconName): The :code:`flag-3` icon. 
FLAG_3_FILLED (IconName): The :code:`flag-3-filled` icon. FLAG_FILLED (IconName): The :code:`flag-filled` icon. FLAG_OFF (IconName): The :code:`flag-off` icon. FLAME (IconName): The :code:`flame` icon. FLAME_OFF (IconName): The :code:`flame-off` icon. FLARE (IconName): The :code:`flare` icon. FLASK (IconName): The :code:`flask` icon. FLASK_2 (IconName): The :code:`flask-2` icon. FLASK_2_OFF (IconName): The :code:`flask-2-off` icon. FLASK_OFF (IconName): The :code:`flask-off` icon. FLIP_FLOPS (IconName): The :code:`flip-flops` icon. FLIP_HORIZONTAL (IconName): The :code:`flip-horizontal` icon. FLIP_VERTICAL (IconName): The :code:`flip-vertical` icon. FLOAT_CENTER (IconName): The :code:`float-center` icon. FLOAT_LEFT (IconName): The :code:`float-left` icon. FLOAT_NONE (IconName): The :code:`float-none` icon. FLOAT_RIGHT (IconName): The :code:`float-right` icon. FLOWER (IconName): The :code:`flower` icon. FLOWER_OFF (IconName): The :code:`flower-off` icon. FOCUS (IconName): The :code:`focus` icon. FOCUS_2 (IconName): The :code:`focus-2` icon. FOCUS_AUTO (IconName): The :code:`focus-auto` icon. FOCUS_CENTERED (IconName): The :code:`focus-centered` icon. FOLD (IconName): The :code:`fold` icon. FOLD_DOWN (IconName): The :code:`fold-down` icon. FOLD_UP (IconName): The :code:`fold-up` icon. FOLDER (IconName): The :code:`folder` icon. FOLDER_BOLT (IconName): The :code:`folder-bolt` icon. FOLDER_CANCEL (IconName): The :code:`folder-cancel` icon. FOLDER_CHECK (IconName): The :code:`folder-check` icon. FOLDER_CODE (IconName): The :code:`folder-code` icon. FOLDER_COG (IconName): The :code:`folder-cog` icon. FOLDER_DOLLAR (IconName): The :code:`folder-dollar` icon. FOLDER_DOWN (IconName): The :code:`folder-down` icon. FOLDER_EXCLAMATION (IconName): The :code:`folder-exclamation` icon. FOLDER_FILLED (IconName): The :code:`folder-filled` icon. FOLDER_HEART (IconName): The :code:`folder-heart` icon. FOLDER_MINUS (IconName): The :code:`folder-minus` icon. FOLDER_OFF (IconName): The :code:`folder-off` icon. FOLDER_OPEN (IconName): The :code:`folder-open` icon. FOLDER_PAUSE (IconName): The :code:`folder-pause` icon. FOLDER_PIN (IconName): The :code:`folder-pin` icon. FOLDER_PLUS (IconName): The :code:`folder-plus` icon. FOLDER_QUESTION (IconName): The :code:`folder-question` icon. FOLDER_SEARCH (IconName): The :code:`folder-search` icon. FOLDER_SHARE (IconName): The :code:`folder-share` icon. FOLDER_STAR (IconName): The :code:`folder-star` icon. FOLDER_SYMLINK (IconName): The :code:`folder-symlink` icon. FOLDER_UP (IconName): The :code:`folder-up` icon. FOLDER_X (IconName): The :code:`folder-x` icon. FOLDERS (IconName): The :code:`folders` icon. FOLDERS_OFF (IconName): The :code:`folders-off` icon. FORBID (IconName): The :code:`forbid` icon. FORBID_2 (IconName): The :code:`forbid-2` icon. FORKLIFT (IconName): The :code:`forklift` icon. FORMS (IconName): The :code:`forms` icon. FOUNTAIN (IconName): The :code:`fountain` icon. FOUNTAIN_OFF (IconName): The :code:`fountain-off` icon. FRAME (IconName): The :code:`frame` icon. FRAME_OFF (IconName): The :code:`frame-off` icon. FREE_RIGHTS (IconName): The :code:`free-rights` icon. FREEZE_COLUMN (IconName): The :code:`freeze-column` icon. FREEZE_ROW (IconName): The :code:`freeze-row` icon. FREEZE_ROW_COLUMN (IconName): The :code:`freeze-row-column` icon. FRIDGE (IconName): The :code:`fridge` icon. FRIDGE_OFF (IconName): The :code:`fridge-off` icon. FRIENDS (IconName): The :code:`friends` icon. FRIENDS_OFF (IconName): The :code:`friends-off` icon. 
FRUSTUM (IconName): The :code:`frustum` icon. FRUSTUM_OFF (IconName): The :code:`frustum-off` icon. FRUSTUM_PLUS (IconName): The :code:`frustum-plus` icon. FUNCTION (IconName): The :code:`function` icon. FUNCTION_OFF (IconName): The :code:`function-off` icon. GARDEN_CART (IconName): The :code:`garden-cart` icon. GARDEN_CART_OFF (IconName): The :code:`garden-cart-off` icon. GAS_STATION (IconName): The :code:`gas-station` icon. GAS_STATION_OFF (IconName): The :code:`gas-station-off` icon. GAUGE (IconName): The :code:`gauge` icon. GAUGE_OFF (IconName): The :code:`gauge-off` icon. GAVEL (IconName): The :code:`gavel` icon. GENDER_AGENDER (IconName): The :code:`gender-agender` icon. GENDER_ANDROGYNE (IconName): The :code:`gender-androgyne` icon. GENDER_BIGENDER (IconName): The :code:`gender-bigender` icon. GENDER_DEMIBOY (IconName): The :code:`gender-demiboy` icon. GENDER_DEMIGIRL (IconName): The :code:`gender-demigirl` icon. GENDER_EPICENE (IconName): The :code:`gender-epicene` icon. GENDER_FEMALE (IconName): The :code:`gender-female` icon. GENDER_FEMME (IconName): The :code:`gender-femme` icon. GENDER_GENDERFLUID (IconName): The :code:`gender-genderfluid` icon. GENDER_GENDERLESS (IconName): The :code:`gender-genderless` icon. GENDER_GENDERQUEER (IconName): The :code:`gender-genderqueer` icon. GENDER_HERMAPHRODITE (IconName): The :code:`gender-hermaphrodite` icon. GENDER_INTERGENDER (IconName): The :code:`gender-intergender` icon. GENDER_MALE (IconName): The :code:`gender-male` icon. GENDER_NEUTROIS (IconName): The :code:`gender-neutrois` icon. GENDER_THIRD (IconName): The :code:`gender-third` icon. GENDER_TRANSGENDER (IconName): The :code:`gender-transgender` icon. GENDER_TRASVESTI (IconName): The :code:`gender-trasvesti` icon. GEOMETRY (IconName): The :code:`geometry` icon. GHOST (IconName): The :code:`ghost` icon. GHOST_2 (IconName): The :code:`ghost-2` icon. GHOST_2_FILLED (IconName): The :code:`ghost-2-filled` icon. GHOST_FILLED (IconName): The :code:`ghost-filled` icon. GHOST_OFF (IconName): The :code:`ghost-off` icon. GIF (IconName): The :code:`gif` icon. GIFT (IconName): The :code:`gift` icon. GIFT_CARD (IconName): The :code:`gift-card` icon. GIFT_OFF (IconName): The :code:`gift-off` icon. GIT_BRANCH (IconName): The :code:`git-branch` icon. GIT_BRANCH_DELETED (IconName): The :code:`git-branch-deleted` icon. GIT_CHERRY_PICK (IconName): The :code:`git-cherry-pick` icon. GIT_COMMIT (IconName): The :code:`git-commit` icon. GIT_COMPARE (IconName): The :code:`git-compare` icon. GIT_FORK (IconName): The :code:`git-fork` icon. GIT_MERGE (IconName): The :code:`git-merge` icon. GIT_PULL_REQUEST (IconName): The :code:`git-pull-request` icon. GIT_PULL_REQUEST_CLOSED (IconName): The :code:`git-pull-request-closed` icon. GIT_PULL_REQUEST_DRAFT (IconName): The :code:`git-pull-request-draft` icon. GIZMO (IconName): The :code:`gizmo` icon. GLASS (IconName): The :code:`glass` icon. GLASS_FULL (IconName): The :code:`glass-full` icon. GLASS_OFF (IconName): The :code:`glass-off` icon. GLOBE (IconName): The :code:`globe` icon. GLOBE_OFF (IconName): The :code:`globe-off` icon. GO_GAME (IconName): The :code:`go-game` icon. GOLF (IconName): The :code:`golf` icon. GOLF_OFF (IconName): The :code:`golf-off` icon. GPS (IconName): The :code:`gps` icon. GRADIENTER (IconName): The :code:`gradienter` icon. GRAIN (IconName): The :code:`grain` icon. GRAPH (IconName): The :code:`graph` icon. GRAPH_OFF (IconName): The :code:`graph-off` icon. GRAVE (IconName): The :code:`grave` icon. 
GRAVE_2 (IconName): The :code:`grave-2` icon. GRID_DOTS (IconName): The :code:`grid-dots` icon. GRID_PATTERN (IconName): The :code:`grid-pattern` icon. GRILL (IconName): The :code:`grill` icon. GRILL_FORK (IconName): The :code:`grill-fork` icon. GRILL_OFF (IconName): The :code:`grill-off` icon. GRILL_SPATULA (IconName): The :code:`grill-spatula` icon. GRIP_HORIZONTAL (IconName): The :code:`grip-horizontal` icon. GRIP_VERTICAL (IconName): The :code:`grip-vertical` icon. GROWTH (IconName): The :code:`growth` icon. GUITAR_PICK (IconName): The :code:`guitar-pick` icon. GUITAR_PICK_FILLED (IconName): The :code:`guitar-pick-filled` icon. H_1 (IconName): The :code:`h-1` icon. H_2 (IconName): The :code:`h-2` icon. H_3 (IconName): The :code:`h-3` icon. H_4 (IconName): The :code:`h-4` icon. H_5 (IconName): The :code:`h-5` icon. H_6 (IconName): The :code:`h-6` icon. HAMMER (IconName): The :code:`hammer` icon. HAMMER_OFF (IconName): The :code:`hammer-off` icon. HAND_CLICK (IconName): The :code:`hand-click` icon. HAND_FINGER (IconName): The :code:`hand-finger` icon. HAND_FINGER_OFF (IconName): The :code:`hand-finger-off` icon. HAND_GRAB (IconName): The :code:`hand-grab` icon. HAND_LITTLE_FINGER (IconName): The :code:`hand-little-finger` icon. HAND_MIDDLE_FINGER (IconName): The :code:`hand-middle-finger` icon. HAND_MOVE (IconName): The :code:`hand-move` icon. HAND_OFF (IconName): The :code:`hand-off` icon. HAND_RING_FINGER (IconName): The :code:`hand-ring-finger` icon. HAND_ROCK (IconName): The :code:`hand-rock` icon. HAND_SANITIZER (IconName): The :code:`hand-sanitizer` icon. HAND_STOP (IconName): The :code:`hand-stop` icon. HAND_THREE_FINGERS (IconName): The :code:`hand-three-fingers` icon. HAND_TWO_FINGERS (IconName): The :code:`hand-two-fingers` icon. HANGER (IconName): The :code:`hanger` icon. HANGER_2 (IconName): The :code:`hanger-2` icon. HANGER_OFF (IconName): The :code:`hanger-off` icon. HASH (IconName): The :code:`hash` icon. HAZE (IconName): The :code:`haze` icon. HAZE_MOON (IconName): The :code:`haze-moon` icon. HDR (IconName): The :code:`hdr` icon. HEADING (IconName): The :code:`heading` icon. HEADING_OFF (IconName): The :code:`heading-off` icon. HEADPHONES (IconName): The :code:`headphones` icon. HEADPHONES_FILLED (IconName): The :code:`headphones-filled` icon. HEADPHONES_OFF (IconName): The :code:`headphones-off` icon. HEADSET (IconName): The :code:`headset` icon. HEADSET_OFF (IconName): The :code:`headset-off` icon. HEALTH_RECOGNITION (IconName): The :code:`health-recognition` icon. HEART (IconName): The :code:`heart` icon. HEART_BROKEN (IconName): The :code:`heart-broken` icon. HEART_FILLED (IconName): The :code:`heart-filled` icon. HEART_HANDSHAKE (IconName): The :code:`heart-handshake` icon. HEART_MINUS (IconName): The :code:`heart-minus` icon. HEART_OFF (IconName): The :code:`heart-off` icon. HEART_PLUS (IconName): The :code:`heart-plus` icon. HEART_RATE_MONITOR (IconName): The :code:`heart-rate-monitor` icon. HEARTBEAT (IconName): The :code:`heartbeat` icon. HEARTS (IconName): The :code:`hearts` icon. HEARTS_OFF (IconName): The :code:`hearts-off` icon. HELICOPTER (IconName): The :code:`helicopter` icon. HELICOPTER_LANDING (IconName): The :code:`helicopter-landing` icon. HELMET (IconName): The :code:`helmet` icon. HELMET_OFF (IconName): The :code:`helmet-off` icon. HELP (IconName): The :code:`help` icon. HELP_CIRCLE (IconName): The :code:`help-circle` icon. HELP_CIRCLE_FILLED (IconName): The :code:`help-circle-filled` icon. HELP_HEXAGON (IconName): The :code:`help-hexagon` icon. 
HELP_HEXAGON_FILLED (IconName): The :code:`help-hexagon-filled` icon. HELP_OCTAGON (IconName): The :code:`help-octagon` icon. HELP_OCTAGON_FILLED (IconName): The :code:`help-octagon-filled` icon. HELP_OFF (IconName): The :code:`help-off` icon. HELP_SMALL (IconName): The :code:`help-small` icon. HELP_SQUARE (IconName): The :code:`help-square` icon. HELP_SQUARE_FILLED (IconName): The :code:`help-square-filled` icon. HELP_SQUARE_ROUNDED (IconName): The :code:`help-square-rounded` icon. HELP_SQUARE_ROUNDED_FILLED (IconName): The :code:`help-square-rounded-filled` icon. HELP_TRIANGLE (IconName): The :code:`help-triangle` icon. HELP_TRIANGLE_FILLED (IconName): The :code:`help-triangle-filled` icon. HEMISPHERE (IconName): The :code:`hemisphere` icon. HEMISPHERE_OFF (IconName): The :code:`hemisphere-off` icon. HEMISPHERE_PLUS (IconName): The :code:`hemisphere-plus` icon. HEXAGON (IconName): The :code:`hexagon` icon. HEXAGON_0_FILLED (IconName): The :code:`hexagon-0-filled` icon. HEXAGON_1_FILLED (IconName): The :code:`hexagon-1-filled` icon. HEXAGON_2_FILLED (IconName): The :code:`hexagon-2-filled` icon. HEXAGON_3_FILLED (IconName): The :code:`hexagon-3-filled` icon. HEXAGON_3D (IconName): The :code:`hexagon-3d` icon. HEXAGON_4_FILLED (IconName): The :code:`hexagon-4-filled` icon. HEXAGON_5_FILLED (IconName): The :code:`hexagon-5-filled` icon. HEXAGON_6_FILLED (IconName): The :code:`hexagon-6-filled` icon. HEXAGON_7_FILLED (IconName): The :code:`hexagon-7-filled` icon. HEXAGON_8_FILLED (IconName): The :code:`hexagon-8-filled` icon. HEXAGON_9_FILLED (IconName): The :code:`hexagon-9-filled` icon. HEXAGON_FILLED (IconName): The :code:`hexagon-filled` icon. HEXAGON_LETTER_A (IconName): The :code:`hexagon-letter-a` icon. HEXAGON_LETTER_B (IconName): The :code:`hexagon-letter-b` icon. HEXAGON_LETTER_C (IconName): The :code:`hexagon-letter-c` icon. HEXAGON_LETTER_D (IconName): The :code:`hexagon-letter-d` icon. HEXAGON_LETTER_E (IconName): The :code:`hexagon-letter-e` icon. HEXAGON_LETTER_F (IconName): The :code:`hexagon-letter-f` icon. HEXAGON_LETTER_G (IconName): The :code:`hexagon-letter-g` icon. HEXAGON_LETTER_H (IconName): The :code:`hexagon-letter-h` icon. HEXAGON_LETTER_I (IconName): The :code:`hexagon-letter-i` icon. HEXAGON_LETTER_J (IconName): The :code:`hexagon-letter-j` icon. HEXAGON_LETTER_K (IconName): The :code:`hexagon-letter-k` icon. HEXAGON_LETTER_L (IconName): The :code:`hexagon-letter-l` icon. HEXAGON_LETTER_M (IconName): The :code:`hexagon-letter-m` icon. HEXAGON_LETTER_N (IconName): The :code:`hexagon-letter-n` icon. HEXAGON_LETTER_O (IconName): The :code:`hexagon-letter-o` icon. HEXAGON_LETTER_P (IconName): The :code:`hexagon-letter-p` icon. HEXAGON_LETTER_Q (IconName): The :code:`hexagon-letter-q` icon. HEXAGON_LETTER_R (IconName): The :code:`hexagon-letter-r` icon. HEXAGON_LETTER_S (IconName): The :code:`hexagon-letter-s` icon. HEXAGON_LETTER_T (IconName): The :code:`hexagon-letter-t` icon. HEXAGON_LETTER_U (IconName): The :code:`hexagon-letter-u` icon. HEXAGON_LETTER_V (IconName): The :code:`hexagon-letter-v` icon. HEXAGON_LETTER_W (IconName): The :code:`hexagon-letter-w` icon. HEXAGON_LETTER_X (IconName): The :code:`hexagon-letter-x` icon. HEXAGON_LETTER_Y (IconName): The :code:`hexagon-letter-y` icon. HEXAGON_LETTER_Z (IconName): The :code:`hexagon-letter-z` icon. HEXAGON_NUMBER_0 (IconName): The :code:`hexagon-number-0` icon. HEXAGON_NUMBER_1 (IconName): The :code:`hexagon-number-1` icon. HEXAGON_NUMBER_2 (IconName): The :code:`hexagon-number-2` icon. 
HEXAGON_NUMBER_3 (IconName): The :code:`hexagon-number-3` icon. HEXAGON_NUMBER_4 (IconName): The :code:`hexagon-number-4` icon. HEXAGON_NUMBER_5 (IconName): The :code:`hexagon-number-5` icon. HEXAGON_NUMBER_6 (IconName): The :code:`hexagon-number-6` icon. HEXAGON_NUMBER_7 (IconName): The :code:`hexagon-number-7` icon. HEXAGON_NUMBER_8 (IconName): The :code:`hexagon-number-8` icon. HEXAGON_NUMBER_9 (IconName): The :code:`hexagon-number-9` icon. HEXAGON_OFF (IconName): The :code:`hexagon-off` icon. HEXAGONAL_PRISM (IconName): The :code:`hexagonal-prism` icon. HEXAGONAL_PRISM_OFF (IconName): The :code:`hexagonal-prism-off` icon. HEXAGONAL_PRISM_PLUS (IconName): The :code:`hexagonal-prism-plus` icon. HEXAGONAL_PYRAMID (IconName): The :code:`hexagonal-pyramid` icon. HEXAGONAL_PYRAMID_OFF (IconName): The :code:`hexagonal-pyramid-off` icon. HEXAGONAL_PYRAMID_PLUS (IconName): The :code:`hexagonal-pyramid-plus` icon. HEXAGONS (IconName): The :code:`hexagons` icon. HEXAGONS_OFF (IconName): The :code:`hexagons-off` icon. HIERARCHY (IconName): The :code:`hierarchy` icon. HIERARCHY_2 (IconName): The :code:`hierarchy-2` icon. HIERARCHY_3 (IconName): The :code:`hierarchy-3` icon. HIERARCHY_OFF (IconName): The :code:`hierarchy-off` icon. HIGHLIGHT (IconName): The :code:`highlight` icon. HIGHLIGHT_OFF (IconName): The :code:`highlight-off` icon. HISTORY (IconName): The :code:`history` icon. HISTORY_OFF (IconName): The :code:`history-off` icon. HISTORY_TOGGLE (IconName): The :code:`history-toggle` icon. HOME (IconName): The :code:`home` icon. HOME_2 (IconName): The :code:`home-2` icon. HOME_BOLT (IconName): The :code:`home-bolt` icon. HOME_CANCEL (IconName): The :code:`home-cancel` icon. HOME_CHECK (IconName): The :code:`home-check` icon. HOME_COG (IconName): The :code:`home-cog` icon. HOME_DOLLAR (IconName): The :code:`home-dollar` icon. HOME_DOT (IconName): The :code:`home-dot` icon. HOME_DOWN (IconName): The :code:`home-down` icon. HOME_ECO (IconName): The :code:`home-eco` icon. HOME_EDIT (IconName): The :code:`home-edit` icon. HOME_EXCLAMATION (IconName): The :code:`home-exclamation` icon. HOME_HAND (IconName): The :code:`home-hand` icon. HOME_HEART (IconName): The :code:`home-heart` icon. HOME_INFINITY (IconName): The :code:`home-infinity` icon. HOME_LINK (IconName): The :code:`home-link` icon. HOME_MINUS (IconName): The :code:`home-minus` icon. HOME_MOVE (IconName): The :code:`home-move` icon. HOME_OFF (IconName): The :code:`home-off` icon. HOME_PLUS (IconName): The :code:`home-plus` icon. HOME_QUESTION (IconName): The :code:`home-question` icon. HOME_RIBBON (IconName): The :code:`home-ribbon` icon. HOME_SEARCH (IconName): The :code:`home-search` icon. HOME_SHARE (IconName): The :code:`home-share` icon. HOME_SHIELD (IconName): The :code:`home-shield` icon. HOME_SIGNAL (IconName): The :code:`home-signal` icon. HOME_STAR (IconName): The :code:`home-star` icon. HOME_STATS (IconName): The :code:`home-stats` icon. HOME_UP (IconName): The :code:`home-up` icon. HOME_X (IconName): The :code:`home-x` icon. HORSE_TOY (IconName): The :code:`horse-toy` icon. HOTEL_SERVICE (IconName): The :code:`hotel-service` icon. HOURGLASS (IconName): The :code:`hourglass` icon. HOURGLASS_EMPTY (IconName): The :code:`hourglass-empty` icon. HOURGLASS_FILLED (IconName): The :code:`hourglass-filled` icon. HOURGLASS_HIGH (IconName): The :code:`hourglass-high` icon. HOURGLASS_LOW (IconName): The :code:`hourglass-low` icon. HOURGLASS_OFF (IconName): The :code:`hourglass-off` icon. HTML (IconName): The :code:`html` icon. 
HTTP_CONNECT (IconName): The :code:`http-connect` icon. HTTP_DELETE (IconName): The :code:`http-delete` icon. HTTP_GET (IconName): The :code:`http-get` icon. HTTP_HEAD (IconName): The :code:`http-head` icon. HTTP_OPTIONS (IconName): The :code:`http-options` icon. HTTP_PATCH (IconName): The :code:`http-patch` icon. HTTP_POST (IconName): The :code:`http-post` icon. HTTP_PUT (IconName): The :code:`http-put` icon. HTTP_QUE (IconName): The :code:`http-que` icon. HTTP_TRACE (IconName): The :code:`http-trace` icon. ICE_CREAM (IconName): The :code:`ice-cream` icon. ICE_CREAM_2 (IconName): The :code:`ice-cream-2` icon. ICE_CREAM_OFF (IconName): The :code:`ice-cream-off` icon. ICE_SKATING (IconName): The :code:`ice-skating` icon. ICONS (IconName): The :code:`icons` icon. ICONS_OFF (IconName): The :code:`icons-off` icon. ID (IconName): The :code:`id` icon. ID_BADGE (IconName): The :code:`id-badge` icon. ID_BADGE_2 (IconName): The :code:`id-badge-2` icon. ID_BADGE_OFF (IconName): The :code:`id-badge-off` icon. ID_OFF (IconName): The :code:`id-off` icon. INBOX (IconName): The :code:`inbox` icon. INBOX_OFF (IconName): The :code:`inbox-off` icon. INDENT_DECREASE (IconName): The :code:`indent-decrease` icon. INDENT_INCREASE (IconName): The :code:`indent-increase` icon. INFINITY (IconName): The :code:`infinity` icon. INFINITY_OFF (IconName): The :code:`infinity-off` icon. INFO_CIRCLE (IconName): The :code:`info-circle` icon. INFO_CIRCLE_FILLED (IconName): The :code:`info-circle-filled` icon. INFO_HEXAGON (IconName): The :code:`info-hexagon` icon. INFO_HEXAGON_FILLED (IconName): The :code:`info-hexagon-filled` icon. INFO_OCTAGON (IconName): The :code:`info-octagon` icon. INFO_OCTAGON_FILLED (IconName): The :code:`info-octagon-filled` icon. INFO_SMALL (IconName): The :code:`info-small` icon. INFO_SQUARE (IconName): The :code:`info-square` icon. INFO_SQUARE_FILLED (IconName): The :code:`info-square-filled` icon. INFO_SQUARE_ROUNDED (IconName): The :code:`info-square-rounded` icon. INFO_SQUARE_ROUNDED_FILLED (IconName): The :code:`info-square-rounded-filled` icon. INFO_TRIANGLE (IconName): The :code:`info-triangle` icon. INFO_TRIANGLE_FILLED (IconName): The :code:`info-triangle-filled` icon. INNER_SHADOW_BOTTOM (IconName): The :code:`inner-shadow-bottom` icon. INNER_SHADOW_BOTTOM_FILLED (IconName): The :code:`inner-shadow-bottom-filled` icon. INNER_SHADOW_BOTTOM_LEFT (IconName): The :code:`inner-shadow-bottom-left` icon. INNER_SHADOW_BOTTOM_LEFT_FILLED (IconName): The :code:`inner-shadow-bottom-left-filled` icon. INNER_SHADOW_BOTTOM_RIGHT (IconName): The :code:`inner-shadow-bottom-right` icon. INNER_SHADOW_BOTTOM_RIGHT_FILLED (IconName): The :code:`inner-shadow-bottom-right-filled` icon. INNER_SHADOW_LEFT (IconName): The :code:`inner-shadow-left` icon. INNER_SHADOW_LEFT_FILLED (IconName): The :code:`inner-shadow-left-filled` icon. INNER_SHADOW_RIGHT (IconName): The :code:`inner-shadow-right` icon. INNER_SHADOW_RIGHT_FILLED (IconName): The :code:`inner-shadow-right-filled` icon. INNER_SHADOW_TOP (IconName): The :code:`inner-shadow-top` icon. INNER_SHADOW_TOP_FILLED (IconName): The :code:`inner-shadow-top-filled` icon. INNER_SHADOW_TOP_LEFT (IconName): The :code:`inner-shadow-top-left` icon. INNER_SHADOW_TOP_LEFT_FILLED (IconName): The :code:`inner-shadow-top-left-filled` icon. INNER_SHADOW_TOP_RIGHT (IconName): The :code:`inner-shadow-top-right` icon. INNER_SHADOW_TOP_RIGHT_FILLED (IconName): The :code:`inner-shadow-top-right-filled` icon. INPUT_SEARCH (IconName): The :code:`input-search` icon. 
IRONING (IconName): The :code:`ironing` icon. IRONING_1 (IconName): The :code:`ironing-1` icon. IRONING_2 (IconName): The :code:`ironing-2` icon. IRONING_3 (IconName): The :code:`ironing-3` icon. IRONING_OFF (IconName): The :code:`ironing-off` icon. IRONING_STEAM (IconName): The :code:`ironing-steam` icon. IRONING_STEAM_OFF (IconName): The :code:`ironing-steam-off` icon. IRREGULAR_POLYHEDRON (IconName): The :code:`irregular-polyhedron` icon. IRREGULAR_POLYHEDRON_OFF (IconName): The :code:`irregular-polyhedron-off` icon. IRREGULAR_POLYHEDRON_PLUS (IconName): The :code:`irregular-polyhedron-plus` icon. ITALIC (IconName): The :code:`italic` icon. JACKET (IconName): The :code:`jacket` icon. JETPACK (IconName): The :code:`jetpack` icon. JEWISH_STAR (IconName): The :code:`jewish-star` icon. JEWISH_STAR_FILLED (IconName): The :code:`jewish-star-filled` icon. JPG (IconName): The :code:`jpg` icon. JSON (IconName): The :code:`json` icon. JUMP_ROPE (IconName): The :code:`jump-rope` icon. KARATE (IconName): The :code:`karate` icon. KAYAK (IconName): The :code:`kayak` icon. KERING (IconName): The :code:`kering` icon. KEY (IconName): The :code:`key` icon. KEY_OFF (IconName): The :code:`key-off` icon. KEYBOARD (IconName): The :code:`keyboard` icon. KEYBOARD_HIDE (IconName): The :code:`keyboard-hide` icon. KEYBOARD_OFF (IconName): The :code:`keyboard-off` icon. KEYBOARD_SHOW (IconName): The :code:`keyboard-show` icon. KEYFRAME (IconName): The :code:`keyframe` icon. KEYFRAME_ALIGN_CENTER (IconName): The :code:`keyframe-align-center` icon. KEYFRAME_ALIGN_HORIZONTAL (IconName): The :code:`keyframe-align-horizontal` icon. KEYFRAME_ALIGN_VERTICAL (IconName): The :code:`keyframe-align-vertical` icon. KEYFRAMES (IconName): The :code:`keyframes` icon. LADDER (IconName): The :code:`ladder` icon. LADDER_OFF (IconName): The :code:`ladder-off` icon. LAMBDA (IconName): The :code:`lambda` icon. LAMP (IconName): The :code:`lamp` icon. LAMP_2 (IconName): The :code:`lamp-2` icon. LAMP_OFF (IconName): The :code:`lamp-off` icon. LANE (IconName): The :code:`lane` icon. LANGUAGE (IconName): The :code:`language` icon. LANGUAGE_HIRAGANA (IconName): The :code:`language-hiragana` icon. LANGUAGE_KATAKANA (IconName): The :code:`language-katakana` icon. LANGUAGE_OFF (IconName): The :code:`language-off` icon. LASSO (IconName): The :code:`lasso` icon. LASSO_OFF (IconName): The :code:`lasso-off` icon. LASSO_POLYGON (IconName): The :code:`lasso-polygon` icon. LAYERS_DIFFERENCE (IconName): The :code:`layers-difference` icon. LAYERS_INTERSECT (IconName): The :code:`layers-intersect` icon. LAYERS_INTERSECT_2 (IconName): The :code:`layers-intersect-2` icon. LAYERS_LINKED (IconName): The :code:`layers-linked` icon. LAYERS_OFF (IconName): The :code:`layers-off` icon. LAYERS_SUBTRACT (IconName): The :code:`layers-subtract` icon. LAYERS_UNION (IconName): The :code:`layers-union` icon. LAYOUT (IconName): The :code:`layout` icon. LAYOUT_2 (IconName): The :code:`layout-2` icon. LAYOUT_ALIGN_BOTTOM (IconName): The :code:`layout-align-bottom` icon. LAYOUT_ALIGN_CENTER (IconName): The :code:`layout-align-center` icon. LAYOUT_ALIGN_LEFT (IconName): The :code:`layout-align-left` icon. LAYOUT_ALIGN_MIDDLE (IconName): The :code:`layout-align-middle` icon. LAYOUT_ALIGN_RIGHT (IconName): The :code:`layout-align-right` icon. LAYOUT_ALIGN_TOP (IconName): The :code:`layout-align-top` icon. LAYOUT_BOARD (IconName): The :code:`layout-board` icon. LAYOUT_BOARD_SPLIT (IconName): The :code:`layout-board-split` icon. 
LAYOUT_BOTTOMBAR (IconName): The :code:`layout-bottombar` icon. LAYOUT_BOTTOMBAR_COLLAPSE (IconName): The :code:`layout-bottombar-collapse` icon. LAYOUT_BOTTOMBAR_EXPAND (IconName): The :code:`layout-bottombar-expand` icon. LAYOUT_CARDS (IconName): The :code:`layout-cards` icon. LAYOUT_COLLAGE (IconName): The :code:`layout-collage` icon. LAYOUT_COLUMNS (IconName): The :code:`layout-columns` icon. LAYOUT_DASHBOARD (IconName): The :code:`layout-dashboard` icon. LAYOUT_DISTRIBUTE_HORIZONTAL (IconName): The :code:`layout-distribute-horizontal` icon. LAYOUT_DISTRIBUTE_VERTICAL (IconName): The :code:`layout-distribute-vertical` icon. LAYOUT_GRID (IconName): The :code:`layout-grid` icon. LAYOUT_GRID_ADD (IconName): The :code:`layout-grid-add` icon. LAYOUT_GRID_REMOVE (IconName): The :code:`layout-grid-remove` icon. LAYOUT_KANBAN (IconName): The :code:`layout-kanban` icon. LAYOUT_LIST (IconName): The :code:`layout-list` icon. LAYOUT_NAVBAR (IconName): The :code:`layout-navbar` icon. LAYOUT_NAVBAR_COLLAPSE (IconName): The :code:`layout-navbar-collapse` icon. LAYOUT_NAVBAR_EXPAND (IconName): The :code:`layout-navbar-expand` icon. LAYOUT_OFF (IconName): The :code:`layout-off` icon. LAYOUT_ROWS (IconName): The :code:`layout-rows` icon. LAYOUT_SIDEBAR (IconName): The :code:`layout-sidebar` icon. LAYOUT_SIDEBAR_LEFT_COLLAPSE (IconName): The :code:`layout-sidebar-left-collapse` icon. LAYOUT_SIDEBAR_LEFT_EXPAND (IconName): The :code:`layout-sidebar-left-expand` icon. LAYOUT_SIDEBAR_RIGHT (IconName): The :code:`layout-sidebar-right` icon. LAYOUT_SIDEBAR_RIGHT_COLLAPSE (IconName): The :code:`layout-sidebar-right-collapse` icon. LAYOUT_SIDEBAR_RIGHT_EXPAND (IconName): The :code:`layout-sidebar-right-expand` icon. LEAF (IconName): The :code:`leaf` icon. LEAF_OFF (IconName): The :code:`leaf-off` icon. LEGO (IconName): The :code:`lego` icon. LEGO_OFF (IconName): The :code:`lego-off` icon. LEMON (IconName): The :code:`lemon` icon. LEMON_2 (IconName): The :code:`lemon-2` icon. LETTER_A (IconName): The :code:`letter-a` icon. LETTER_B (IconName): The :code:`letter-b` icon. LETTER_C (IconName): The :code:`letter-c` icon. LETTER_CASE (IconName): The :code:`letter-case` icon. LETTER_CASE_LOWER (IconName): The :code:`letter-case-lower` icon. LETTER_CASE_TOGGLE (IconName): The :code:`letter-case-toggle` icon. LETTER_CASE_UPPER (IconName): The :code:`letter-case-upper` icon. LETTER_D (IconName): The :code:`letter-d` icon. LETTER_E (IconName): The :code:`letter-e` icon. LETTER_F (IconName): The :code:`letter-f` icon. LETTER_G (IconName): The :code:`letter-g` icon. LETTER_H (IconName): The :code:`letter-h` icon. LETTER_I (IconName): The :code:`letter-i` icon. LETTER_J (IconName): The :code:`letter-j` icon. LETTER_K (IconName): The :code:`letter-k` icon. LETTER_L (IconName): The :code:`letter-l` icon. LETTER_M (IconName): The :code:`letter-m` icon. LETTER_N (IconName): The :code:`letter-n` icon. LETTER_O (IconName): The :code:`letter-o` icon. LETTER_P (IconName): The :code:`letter-p` icon. LETTER_Q (IconName): The :code:`letter-q` icon. LETTER_R (IconName): The :code:`letter-r` icon. LETTER_S (IconName): The :code:`letter-s` icon. LETTER_SPACING (IconName): The :code:`letter-spacing` icon. LETTER_T (IconName): The :code:`letter-t` icon. LETTER_U (IconName): The :code:`letter-u` icon. LETTER_V (IconName): The :code:`letter-v` icon. LETTER_W (IconName): The :code:`letter-w` icon. LETTER_X (IconName): The :code:`letter-x` icon. LETTER_Y (IconName): The :code:`letter-y` icon. LETTER_Z (IconName): The :code:`letter-z` icon. 
LICENSE (IconName): The :code:`license` icon. LICENSE_OFF (IconName): The :code:`license-off` icon. LIFEBUOY (IconName): The :code:`lifebuoy` icon. LIFEBUOY_OFF (IconName): The :code:`lifebuoy-off` icon. LIGHTER (IconName): The :code:`lighter` icon. LINE (IconName): The :code:`line` icon. LINE_DASHED (IconName): The :code:`line-dashed` icon. LINE_DOTTED (IconName): The :code:`line-dotted` icon. LINE_HEIGHT (IconName): The :code:`line-height` icon. LINK (IconName): The :code:`link` icon. LINK_OFF (IconName): The :code:`link-off` icon. LIST (IconName): The :code:`list` icon. LIST_CHECK (IconName): The :code:`list-check` icon. LIST_DETAILS (IconName): The :code:`list-details` icon. LIST_NUMBERS (IconName): The :code:`list-numbers` icon. LIST_SEARCH (IconName): The :code:`list-search` icon. LIST_TREE (IconName): The :code:`list-tree` icon. LIVE_PHOTO (IconName): The :code:`live-photo` icon. LIVE_PHOTO_OFF (IconName): The :code:`live-photo-off` icon. LIVE_VIEW (IconName): The :code:`live-view` icon. LOAD_BALANCER (IconName): The :code:`load-balancer` icon. LOADER (IconName): The :code:`loader` icon. LOADER_2 (IconName): The :code:`loader-2` icon. LOADER_3 (IconName): The :code:`loader-3` icon. LOADER_QUARTER (IconName): The :code:`loader-quarter` icon. LOCATION (IconName): The :code:`location` icon. LOCATION_BROKEN (IconName): The :code:`location-broken` icon. LOCATION_FILLED (IconName): The :code:`location-filled` icon. LOCATION_OFF (IconName): The :code:`location-off` icon. LOCK (IconName): The :code:`lock` icon. LOCK_ACCESS (IconName): The :code:`lock-access` icon. LOCK_ACCESS_OFF (IconName): The :code:`lock-access-off` icon. LOCK_BOLT (IconName): The :code:`lock-bolt` icon. LOCK_CANCEL (IconName): The :code:`lock-cancel` icon. LOCK_CHECK (IconName): The :code:`lock-check` icon. LOCK_CODE (IconName): The :code:`lock-code` icon. LOCK_COG (IconName): The :code:`lock-cog` icon. LOCK_DOLLAR (IconName): The :code:`lock-dollar` icon. LOCK_DOWN (IconName): The :code:`lock-down` icon. LOCK_EXCLAMATION (IconName): The :code:`lock-exclamation` icon. LOCK_HEART (IconName): The :code:`lock-heart` icon. LOCK_MINUS (IconName): The :code:`lock-minus` icon. LOCK_OFF (IconName): The :code:`lock-off` icon. LOCK_OPEN (IconName): The :code:`lock-open` icon. LOCK_OPEN_OFF (IconName): The :code:`lock-open-off` icon. LOCK_PAUSE (IconName): The :code:`lock-pause` icon. LOCK_PIN (IconName): The :code:`lock-pin` icon. LOCK_PLUS (IconName): The :code:`lock-plus` icon. LOCK_QUESTION (IconName): The :code:`lock-question` icon. LOCK_SEARCH (IconName): The :code:`lock-search` icon. LOCK_SHARE (IconName): The :code:`lock-share` icon. LOCK_SQUARE (IconName): The :code:`lock-square` icon. LOCK_SQUARE_ROUNDED (IconName): The :code:`lock-square-rounded` icon. LOCK_SQUARE_ROUNDED_FILLED (IconName): The :code:`lock-square-rounded-filled` icon. LOCK_STAR (IconName): The :code:`lock-star` icon. LOCK_UP (IconName): The :code:`lock-up` icon. LOCK_X (IconName): The :code:`lock-x` icon. LOGIC_AND (IconName): The :code:`logic-and` icon. LOGIC_BUFFER (IconName): The :code:`logic-buffer` icon. LOGIC_NAND (IconName): The :code:`logic-nand` icon. LOGIC_NOR (IconName): The :code:`logic-nor` icon. LOGIC_NOT (IconName): The :code:`logic-not` icon. LOGIC_OR (IconName): The :code:`logic-or` icon. LOGIC_XNOR (IconName): The :code:`logic-xnor` icon. LOGIC_XOR (IconName): The :code:`logic-xor` icon. LOGIN (IconName): The :code:`login` icon. LOGOUT (IconName): The :code:`logout` icon. LOGOUT_2 (IconName): The :code:`logout-2` icon. 
LOLLIPOP (IconName): The :code:`lollipop` icon. LOLLIPOP_OFF (IconName): The :code:`lollipop-off` icon. LUGGAGE (IconName): The :code:`luggage` icon. LUGGAGE_OFF (IconName): The :code:`luggage-off` icon. LUNGS (IconName): The :code:`lungs` icon. LUNGS_OFF (IconName): The :code:`lungs-off` icon. MACRO (IconName): The :code:`macro` icon. MACRO_OFF (IconName): The :code:`macro-off` icon. MAGNET (IconName): The :code:`magnet` icon. MAGNET_OFF (IconName): The :code:`magnet-off` icon. MAIL (IconName): The :code:`mail` icon. MAIL_AI (IconName): The :code:`mail-ai` icon. MAIL_BOLT (IconName): The :code:`mail-bolt` icon. MAIL_CANCEL (IconName): The :code:`mail-cancel` icon. MAIL_CHECK (IconName): The :code:`mail-check` icon. MAIL_CODE (IconName): The :code:`mail-code` icon. MAIL_COG (IconName): The :code:`mail-cog` icon. MAIL_DOLLAR (IconName): The :code:`mail-dollar` icon. MAIL_DOWN (IconName): The :code:`mail-down` icon. MAIL_EXCLAMATION (IconName): The :code:`mail-exclamation` icon. MAIL_FAST (IconName): The :code:`mail-fast` icon. MAIL_FILLED (IconName): The :code:`mail-filled` icon. MAIL_FORWARD (IconName): The :code:`mail-forward` icon. MAIL_HEART (IconName): The :code:`mail-heart` icon. MAIL_MINUS (IconName): The :code:`mail-minus` icon. MAIL_OFF (IconName): The :code:`mail-off` icon. MAIL_OPENED (IconName): The :code:`mail-opened` icon. MAIL_OPENED_FILLED (IconName): The :code:`mail-opened-filled` icon. MAIL_PAUSE (IconName): The :code:`mail-pause` icon. MAIL_PIN (IconName): The :code:`mail-pin` icon. MAIL_PLUS (IconName): The :code:`mail-plus` icon. MAIL_QUESTION (IconName): The :code:`mail-question` icon. MAIL_SEARCH (IconName): The :code:`mail-search` icon. MAIL_SHARE (IconName): The :code:`mail-share` icon. MAIL_STAR (IconName): The :code:`mail-star` icon. MAIL_UP (IconName): The :code:`mail-up` icon. MAIL_X (IconName): The :code:`mail-x` icon. MAILBOX (IconName): The :code:`mailbox` icon. MAILBOX_OFF (IconName): The :code:`mailbox-off` icon. MAN (IconName): The :code:`man` icon. MANUAL_GEARBOX (IconName): The :code:`manual-gearbox` icon. MAP (IconName): The :code:`map` icon. MAP_2 (IconName): The :code:`map-2` icon. MAP_OFF (IconName): The :code:`map-off` icon. MAP_PIN (IconName): The :code:`map-pin` icon. MAP_PIN_BOLT (IconName): The :code:`map-pin-bolt` icon. MAP_PIN_CANCEL (IconName): The :code:`map-pin-cancel` icon. MAP_PIN_CHECK (IconName): The :code:`map-pin-check` icon. MAP_PIN_CODE (IconName): The :code:`map-pin-code` icon. MAP_PIN_COG (IconName): The :code:`map-pin-cog` icon. MAP_PIN_DOLLAR (IconName): The :code:`map-pin-dollar` icon. MAP_PIN_DOWN (IconName): The :code:`map-pin-down` icon. MAP_PIN_EXCLAMATION (IconName): The :code:`map-pin-exclamation` icon. MAP_PIN_FILLED (IconName): The :code:`map-pin-filled` icon. MAP_PIN_HEART (IconName): The :code:`map-pin-heart` icon. MAP_PIN_MINUS (IconName): The :code:`map-pin-minus` icon. MAP_PIN_OFF (IconName): The :code:`map-pin-off` icon. MAP_PIN_PAUSE (IconName): The :code:`map-pin-pause` icon. MAP_PIN_PIN (IconName): The :code:`map-pin-pin` icon. MAP_PIN_PLUS (IconName): The :code:`map-pin-plus` icon. MAP_PIN_QUESTION (IconName): The :code:`map-pin-question` icon. MAP_PIN_SEARCH (IconName): The :code:`map-pin-search` icon. MAP_PIN_SHARE (IconName): The :code:`map-pin-share` icon. MAP_PIN_STAR (IconName): The :code:`map-pin-star` icon. MAP_PIN_UP (IconName): The :code:`map-pin-up` icon. MAP_PIN_X (IconName): The :code:`map-pin-x` icon. MAP_PINS (IconName): The :code:`map-pins` icon. 
MAP_SEARCH (IconName): The :code:`map-search` icon. MARKDOWN (IconName): The :code:`markdown` icon. MARKDOWN_OFF (IconName): The :code:`markdown-off` icon. MARQUEE (IconName): The :code:`marquee` icon. MARQUEE_2 (IconName): The :code:`marquee-2` icon. MARQUEE_OFF (IconName): The :code:`marquee-off` icon. MARS (IconName): The :code:`mars` icon. MASK (IconName): The :code:`mask` icon. MASK_OFF (IconName): The :code:`mask-off` icon. MASKS_THEATER (IconName): The :code:`masks-theater` icon. MASKS_THEATER_OFF (IconName): The :code:`masks-theater-off` icon. MASSAGE (IconName): The :code:`massage` icon. MATCHSTICK (IconName): The :code:`matchstick` icon. MATH (IconName): The :code:`math` icon. MATH_1_DIVIDE_2 (IconName): The :code:`math-1-divide-2` icon. MATH_1_DIVIDE_3 (IconName): The :code:`math-1-divide-3` icon. MATH_AVG (IconName): The :code:`math-avg` icon. MATH_EQUAL_GREATER (IconName): The :code:`math-equal-greater` icon. MATH_EQUAL_LOWER (IconName): The :code:`math-equal-lower` icon. MATH_FUNCTION (IconName): The :code:`math-function` icon. MATH_FUNCTION_OFF (IconName): The :code:`math-function-off` icon. MATH_FUNCTION_Y (IconName): The :code:`math-function-y` icon. MATH_GREATER (IconName): The :code:`math-greater` icon. MATH_INTEGRAL (IconName): The :code:`math-integral` icon. MATH_INTEGRAL_X (IconName): The :code:`math-integral-x` icon. MATH_INTEGRALS (IconName): The :code:`math-integrals` icon. MATH_LOWER (IconName): The :code:`math-lower` icon. MATH_MAX (IconName): The :code:`math-max` icon. MATH_MIN (IconName): The :code:`math-min` icon. MATH_NOT (IconName): The :code:`math-not` icon. MATH_OFF (IconName): The :code:`math-off` icon. MATH_PI (IconName): The :code:`math-pi` icon. MATH_PI_DIVIDE_2 (IconName): The :code:`math-pi-divide-2` icon. MATH_SYMBOLS (IconName): The :code:`math-symbols` icon. MATH_X_DIVIDE_2 (IconName): The :code:`math-x-divide-2` icon. MATH_X_DIVIDE_Y (IconName): The :code:`math-x-divide-y` icon. MATH_X_DIVIDE_Y_2 (IconName): The :code:`math-x-divide-y-2` icon. MATH_X_MINUS_X (IconName): The :code:`math-x-minus-x` icon. MATH_X_MINUS_Y (IconName): The :code:`math-x-minus-y` icon. MATH_X_PLUS_X (IconName): The :code:`math-x-plus-x` icon. MATH_X_PLUS_Y (IconName): The :code:`math-x-plus-y` icon. MATH_XY (IconName): The :code:`math-xy` icon. MATH_Y_MINUS_Y (IconName): The :code:`math-y-minus-y` icon. MATH_Y_PLUS_Y (IconName): The :code:`math-y-plus-y` icon. MAXIMIZE (IconName): The :code:`maximize` icon. MAXIMIZE_OFF (IconName): The :code:`maximize-off` icon. MEAT (IconName): The :code:`meat` icon. MEAT_OFF (IconName): The :code:`meat-off` icon. MEDAL (IconName): The :code:`medal` icon. MEDAL_2 (IconName): The :code:`medal-2` icon. MEDICAL_CROSS (IconName): The :code:`medical-cross` icon. MEDICAL_CROSS_CIRCLE (IconName): The :code:`medical-cross-circle` icon. MEDICAL_CROSS_FILLED (IconName): The :code:`medical-cross-filled` icon. MEDICAL_CROSS_OFF (IconName): The :code:`medical-cross-off` icon. MEDICINE_SYRUP (IconName): The :code:`medicine-syrup` icon. MEEPLE (IconName): The :code:`meeple` icon. MENORAH (IconName): The :code:`menorah` icon. MENU (IconName): The :code:`menu` icon. MENU_2 (IconName): The :code:`menu-2` icon. MENU_DEEP (IconName): The :code:`menu-deep` icon. MENU_ORDER (IconName): The :code:`menu-order` icon. MESSAGE (IconName): The :code:`message` icon. MESSAGE_2 (IconName): The :code:`message-2` icon. MESSAGE_2_BOLT (IconName): The :code:`message-2-bolt` icon. MESSAGE_2_CANCEL (IconName): The :code:`message-2-cancel` icon. 
MESSAGE_2_CHECK (IconName): The :code:`message-2-check` icon. MESSAGE_2_CODE (IconName): The :code:`message-2-code` icon. MESSAGE_2_COG (IconName): The :code:`message-2-cog` icon. MESSAGE_2_DOLLAR (IconName): The :code:`message-2-dollar` icon. MESSAGE_2_DOWN (IconName): The :code:`message-2-down` icon. MESSAGE_2_EXCLAMATION (IconName): The :code:`message-2-exclamation` icon. MESSAGE_2_HEART (IconName): The :code:`message-2-heart` icon. MESSAGE_2_MINUS (IconName): The :code:`message-2-minus` icon. MESSAGE_2_OFF (IconName): The :code:`message-2-off` icon. MESSAGE_2_PAUSE (IconName): The :code:`message-2-pause` icon. MESSAGE_2_PIN (IconName): The :code:`message-2-pin` icon. MESSAGE_2_PLUS (IconName): The :code:`message-2-plus` icon. MESSAGE_2_QUESTION (IconName): The :code:`message-2-question` icon. MESSAGE_2_SEARCH (IconName): The :code:`message-2-search` icon. MESSAGE_2_SHARE (IconName): The :code:`message-2-share` icon. MESSAGE_2_STAR (IconName): The :code:`message-2-star` icon. MESSAGE_2_UP (IconName): The :code:`message-2-up` icon. MESSAGE_2_X (IconName): The :code:`message-2-x` icon. MESSAGE_BOLT (IconName): The :code:`message-bolt` icon. MESSAGE_CANCEL (IconName): The :code:`message-cancel` icon. MESSAGE_CHATBOT (IconName): The :code:`message-chatbot` icon. MESSAGE_CHECK (IconName): The :code:`message-check` icon. MESSAGE_CIRCLE (IconName): The :code:`message-circle` icon. MESSAGE_CIRCLE_2 (IconName): The :code:`message-circle-2` icon. MESSAGE_CIRCLE_2_FILLED (IconName): The :code:`message-circle-2-filled` icon. MESSAGE_CIRCLE_BOLT (IconName): The :code:`message-circle-bolt` icon. MESSAGE_CIRCLE_CANCEL (IconName): The :code:`message-circle-cancel` icon. MESSAGE_CIRCLE_CHECK (IconName): The :code:`message-circle-check` icon. MESSAGE_CIRCLE_CODE (IconName): The :code:`message-circle-code` icon. MESSAGE_CIRCLE_COG (IconName): The :code:`message-circle-cog` icon. MESSAGE_CIRCLE_DOLLAR (IconName): The :code:`message-circle-dollar` icon. MESSAGE_CIRCLE_DOWN (IconName): The :code:`message-circle-down` icon. MESSAGE_CIRCLE_EXCLAMATION (IconName): The :code:`message-circle-exclamation` icon. MESSAGE_CIRCLE_HEART (IconName): The :code:`message-circle-heart` icon. MESSAGE_CIRCLE_MINUS (IconName): The :code:`message-circle-minus` icon. MESSAGE_CIRCLE_OFF (IconName): The :code:`message-circle-off` icon. MESSAGE_CIRCLE_PAUSE (IconName): The :code:`message-circle-pause` icon. MESSAGE_CIRCLE_PIN (IconName): The :code:`message-circle-pin` icon. MESSAGE_CIRCLE_PLUS (IconName): The :code:`message-circle-plus` icon. MESSAGE_CIRCLE_QUESTION (IconName): The :code:`message-circle-question` icon. MESSAGE_CIRCLE_SEARCH (IconName): The :code:`message-circle-search` icon. MESSAGE_CIRCLE_SHARE (IconName): The :code:`message-circle-share` icon. MESSAGE_CIRCLE_STAR (IconName): The :code:`message-circle-star` icon. MESSAGE_CIRCLE_UP (IconName): The :code:`message-circle-up` icon. MESSAGE_CIRCLE_X (IconName): The :code:`message-circle-x` icon. MESSAGE_CODE (IconName): The :code:`message-code` icon. MESSAGE_COG (IconName): The :code:`message-cog` icon. MESSAGE_DOLLAR (IconName): The :code:`message-dollar` icon. MESSAGE_DOTS (IconName): The :code:`message-dots` icon. MESSAGE_DOWN (IconName): The :code:`message-down` icon. MESSAGE_EXCLAMATION (IconName): The :code:`message-exclamation` icon. MESSAGE_FORWARD (IconName): The :code:`message-forward` icon. MESSAGE_HEART (IconName): The :code:`message-heart` icon. MESSAGE_LANGUAGE (IconName): The :code:`message-language` icon. 
MESSAGE_MINUS (IconName): The :code:`message-minus` icon. MESSAGE_OFF (IconName): The :code:`message-off` icon. MESSAGE_PAUSE (IconName): The :code:`message-pause` icon. MESSAGE_PIN (IconName): The :code:`message-pin` icon. MESSAGE_PLUS (IconName): The :code:`message-plus` icon. MESSAGE_QUESTION (IconName): The :code:`message-question` icon. MESSAGE_REPORT (IconName): The :code:`message-report` icon. MESSAGE_SEARCH (IconName): The :code:`message-search` icon. MESSAGE_SHARE (IconName): The :code:`message-share` icon. MESSAGE_STAR (IconName): The :code:`message-star` icon. MESSAGE_UP (IconName): The :code:`message-up` icon. MESSAGE_X (IconName): The :code:`message-x` icon. MESSAGES (IconName): The :code:`messages` icon. MESSAGES_OFF (IconName): The :code:`messages-off` icon. METEOR (IconName): The :code:`meteor` icon. METEOR_OFF (IconName): The :code:`meteor-off` icon. MICHELIN_BIB_GOURMAND (IconName): The :code:`michelin-bib-gourmand` icon. MICHELIN_STAR (IconName): The :code:`michelin-star` icon. MICHELIN_STAR_GREEN (IconName): The :code:`michelin-star-green` icon. MICKEY (IconName): The :code:`mickey` icon. MICKEY_FILLED (IconName): The :code:`mickey-filled` icon. MICROPHONE (IconName): The :code:`microphone` icon. MICROPHONE_2 (IconName): The :code:`microphone-2` icon. MICROPHONE_2_OFF (IconName): The :code:`microphone-2-off` icon. MICROPHONE_OFF (IconName): The :code:`microphone-off` icon. MICROSCOPE (IconName): The :code:`microscope` icon. MICROSCOPE_OFF (IconName): The :code:`microscope-off` icon. MICROWAVE (IconName): The :code:`microwave` icon. MICROWAVE_OFF (IconName): The :code:`microwave-off` icon. MILITARY_AWARD (IconName): The :code:`military-award` icon. MILITARY_RANK (IconName): The :code:`military-rank` icon. MILK (IconName): The :code:`milk` icon. MILK_OFF (IconName): The :code:`milk-off` icon. MILKSHAKE (IconName): The :code:`milkshake` icon. MINIMIZE (IconName): The :code:`minimize` icon. MINUS (IconName): The :code:`minus` icon. MINUS_VERTICAL (IconName): The :code:`minus-vertical` icon. MIST (IconName): The :code:`mist` icon. MIST_OFF (IconName): The :code:`mist-off` icon. MOBILEDATA (IconName): The :code:`mobiledata` icon. MOBILEDATA_OFF (IconName): The :code:`mobiledata-off` icon. MONEYBAG (IconName): The :code:`moneybag` icon. MOOD_ANGRY (IconName): The :code:`mood-angry` icon. MOOD_ANNOYED (IconName): The :code:`mood-annoyed` icon. MOOD_ANNOYED_2 (IconName): The :code:`mood-annoyed-2` icon. MOOD_BOY (IconName): The :code:`mood-boy` icon. MOOD_CHECK (IconName): The :code:`mood-check` icon. MOOD_COG (IconName): The :code:`mood-cog` icon. MOOD_CONFUZED (IconName): The :code:`mood-confuzed` icon. MOOD_CONFUZED_FILLED (IconName): The :code:`mood-confuzed-filled` icon. MOOD_CRAZY_HAPPY (IconName): The :code:`mood-crazy-happy` icon. MOOD_CRY (IconName): The :code:`mood-cry` icon. MOOD_DOLLAR (IconName): The :code:`mood-dollar` icon. MOOD_EDIT (IconName): The :code:`mood-edit` icon. MOOD_EMPTY (IconName): The :code:`mood-empty` icon. MOOD_EMPTY_FILLED (IconName): The :code:`mood-empty-filled` icon. MOOD_HAPPY (IconName): The :code:`mood-happy` icon. MOOD_HAPPY_FILLED (IconName): The :code:`mood-happy-filled` icon. MOOD_HEART (IconName): The :code:`mood-heart` icon. MOOD_KID (IconName): The :code:`mood-kid` icon. MOOD_KID_FILLED (IconName): The :code:`mood-kid-filled` icon. MOOD_LOOK_LEFT (IconName): The :code:`mood-look-left` icon. MOOD_LOOK_RIGHT (IconName): The :code:`mood-look-right` icon. MOOD_MINUS (IconName): The :code:`mood-minus` icon. 
MOOD_NERD (IconName): The :code:`mood-nerd` icon. MOOD_NERVOUS (IconName): The :code:`mood-nervous` icon. MOOD_NEUTRAL (IconName): The :code:`mood-neutral` icon. MOOD_NEUTRAL_FILLED (IconName): The :code:`mood-neutral-filled` icon. MOOD_OFF (IconName): The :code:`mood-off` icon. MOOD_PIN (IconName): The :code:`mood-pin` icon. MOOD_PLUS (IconName): The :code:`mood-plus` icon. MOOD_SAD (IconName): The :code:`mood-sad` icon. MOOD_SAD_2 (IconName): The :code:`mood-sad-2` icon. MOOD_SAD_DIZZY (IconName): The :code:`mood-sad-dizzy` icon. MOOD_SAD_FILLED (IconName): The :code:`mood-sad-filled` icon. MOOD_SAD_SQUINT (IconName): The :code:`mood-sad-squint` icon. MOOD_SEARCH (IconName): The :code:`mood-search` icon. MOOD_SHARE (IconName): The :code:`mood-share` icon. MOOD_SICK (IconName): The :code:`mood-sick` icon. MOOD_SILENCE (IconName): The :code:`mood-silence` icon. MOOD_SING (IconName): The :code:`mood-sing` icon. MOOD_SMILE (IconName): The :code:`mood-smile` icon. MOOD_SMILE_BEAM (IconName): The :code:`mood-smile-beam` icon. MOOD_SMILE_DIZZY (IconName): The :code:`mood-smile-dizzy` icon. MOOD_SMILE_FILLED (IconName): The :code:`mood-smile-filled` icon. MOOD_SUPRISED (IconName): The :code:`mood-suprised` icon. MOOD_TONGUE (IconName): The :code:`mood-tongue` icon. MOOD_TONGUE_WINK (IconName): The :code:`mood-tongue-wink` icon. MOOD_TONGUE_WINK_2 (IconName): The :code:`mood-tongue-wink-2` icon. MOOD_UNAMUSED (IconName): The :code:`mood-unamused` icon. MOOD_UP (IconName): The :code:`mood-up` icon. MOOD_WINK (IconName): The :code:`mood-wink` icon. MOOD_WINK_2 (IconName): The :code:`mood-wink-2` icon. MOOD_WRRR (IconName): The :code:`mood-wrrr` icon. MOOD_X (IconName): The :code:`mood-x` icon. MOOD_XD (IconName): The :code:`mood-xd` icon. MOON (IconName): The :code:`moon` icon. MOON_2 (IconName): The :code:`moon-2` icon. MOON_FILLED (IconName): The :code:`moon-filled` icon. MOON_OFF (IconName): The :code:`moon-off` icon. MOON_STARS (IconName): The :code:`moon-stars` icon. MOPED (IconName): The :code:`moped` icon. MOTORBIKE (IconName): The :code:`motorbike` icon. MOUNTAIN (IconName): The :code:`mountain` icon. MOUNTAIN_OFF (IconName): The :code:`mountain-off` icon. MOUSE (IconName): The :code:`mouse` icon. MOUSE_2 (IconName): The :code:`mouse-2` icon. MOUSE_OFF (IconName): The :code:`mouse-off` icon. MOUSTACHE (IconName): The :code:`moustache` icon. MOVIE (IconName): The :code:`movie` icon. MOVIE_OFF (IconName): The :code:`movie-off` icon. MUG (IconName): The :code:`mug` icon. MUG_OFF (IconName): The :code:`mug-off` icon. MULTIPLIER_0_5X (IconName): The :code:`multiplier-0-5x` icon. MULTIPLIER_1_5X (IconName): The :code:`multiplier-1-5x` icon. MULTIPLIER_1X (IconName): The :code:`multiplier-1x` icon. MULTIPLIER_2X (IconName): The :code:`multiplier-2x` icon. MUSHROOM (IconName): The :code:`mushroom` icon. MUSHROOM_FILLED (IconName): The :code:`mushroom-filled` icon. MUSHROOM_OFF (IconName): The :code:`mushroom-off` icon. MUSIC (IconName): The :code:`music` icon. MUSIC_OFF (IconName): The :code:`music-off` icon. NAVIGATION (IconName): The :code:`navigation` icon. NAVIGATION_FILLED (IconName): The :code:`navigation-filled` icon. NAVIGATION_NORTH (IconName): The :code:`navigation-north` icon. NAVIGATION_OFF (IconName): The :code:`navigation-off` icon. NEEDLE (IconName): The :code:`needle` icon. NEEDLE_THREAD (IconName): The :code:`needle-thread` icon. NETWORK (IconName): The :code:`network` icon. NETWORK_OFF (IconName): The :code:`network-off` icon. NEW_SECTION (IconName): The :code:`new-section` icon. 
NEWS (IconName): The :code:`news` icon. NEWS_OFF (IconName): The :code:`news-off` icon. NFC (IconName): The :code:`nfc` icon. NFC_OFF (IconName): The :code:`nfc-off` icon. NO_COPYRIGHT (IconName): The :code:`no-copyright` icon. NO_CREATIVE_COMMONS (IconName): The :code:`no-creative-commons` icon. NO_DERIVATIVES (IconName): The :code:`no-derivatives` icon. NORTH_STAR (IconName): The :code:`north-star` icon. NOTE (IconName): The :code:`note` icon. NOTE_OFF (IconName): The :code:`note-off` icon. NOTEBOOK (IconName): The :code:`notebook` icon. NOTEBOOK_OFF (IconName): The :code:`notebook-off` icon. NOTES (IconName): The :code:`notes` icon. NOTES_OFF (IconName): The :code:`notes-off` icon. NOTIFICATION (IconName): The :code:`notification` icon. NOTIFICATION_OFF (IconName): The :code:`notification-off` icon. NUMBER (IconName): The :code:`number` icon. NUMBER_0 (IconName): The :code:`number-0` icon. NUMBER_1 (IconName): The :code:`number-1` icon. NUMBER_2 (IconName): The :code:`number-2` icon. NUMBER_3 (IconName): The :code:`number-3` icon. NUMBER_4 (IconName): The :code:`number-4` icon. NUMBER_5 (IconName): The :code:`number-5` icon. NUMBER_6 (IconName): The :code:`number-6` icon. NUMBER_7 (IconName): The :code:`number-7` icon. NUMBER_8 (IconName): The :code:`number-8` icon. NUMBER_9 (IconName): The :code:`number-9` icon. NUMBERS (IconName): The :code:`numbers` icon. NURSE (IconName): The :code:`nurse` icon. OCTAGON (IconName): The :code:`octagon` icon. OCTAGON_FILLED (IconName): The :code:`octagon-filled` icon. OCTAGON_OFF (IconName): The :code:`octagon-off` icon. OCTAHEDRON (IconName): The :code:`octahedron` icon. OCTAHEDRON_OFF (IconName): The :code:`octahedron-off` icon. OCTAHEDRON_PLUS (IconName): The :code:`octahedron-plus` icon. OLD (IconName): The :code:`old` icon. OLYMPICS (IconName): The :code:`olympics` icon. OLYMPICS_OFF (IconName): The :code:`olympics-off` icon. OM (IconName): The :code:`om` icon. OMEGA (IconName): The :code:`omega` icon. OUTBOUND (IconName): The :code:`outbound` icon. OUTLET (IconName): The :code:`outlet` icon. OVAL (IconName): The :code:`oval` icon. OVAL_FILLED (IconName): The :code:`oval-filled` icon. OVAL_VERTICAL (IconName): The :code:`oval-vertical` icon. OVAL_VERTICAL_FILLED (IconName): The :code:`oval-vertical-filled` icon. OVERLINE (IconName): The :code:`overline` icon. PACKAGE (IconName): The :code:`package` icon. PACKAGE_EXPORT (IconName): The :code:`package-export` icon. PACKAGE_IMPORT (IconName): The :code:`package-import` icon. PACKAGE_OFF (IconName): The :code:`package-off` icon. PACKAGES (IconName): The :code:`packages` icon. PACMAN (IconName): The :code:`pacman` icon. PAGE_BREAK (IconName): The :code:`page-break` icon. PAINT (IconName): The :code:`paint` icon. PAINT_FILLED (IconName): The :code:`paint-filled` icon. PAINT_OFF (IconName): The :code:`paint-off` icon. PALETTE (IconName): The :code:`palette` icon. PALETTE_OFF (IconName): The :code:`palette-off` icon. PANORAMA_HORIZONTAL (IconName): The :code:`panorama-horizontal` icon. PANORAMA_HORIZONTAL_OFF (IconName): The :code:`panorama-horizontal-off` icon. PANORAMA_VERTICAL (IconName): The :code:`panorama-vertical` icon. PANORAMA_VERTICAL_OFF (IconName): The :code:`panorama-vertical-off` icon. PAPER_BAG (IconName): The :code:`paper-bag` icon. PAPER_BAG_OFF (IconName): The :code:`paper-bag-off` icon. PAPERCLIP (IconName): The :code:`paperclip` icon. PARACHUTE (IconName): The :code:`parachute` icon. PARACHUTE_OFF (IconName): The :code:`parachute-off` icon. 
PARENTHESES (IconName): The :code:`parentheses` icon. PARENTHESES_OFF (IconName): The :code:`parentheses-off` icon. PARKING (IconName): The :code:`parking` icon. PARKING_OFF (IconName): The :code:`parking-off` icon. PASSWORD (IconName): The :code:`password` icon. PAW (IconName): The :code:`paw` icon. PAW_FILLED (IconName): The :code:`paw-filled` icon. PAW_OFF (IconName): The :code:`paw-off` icon. PDF (IconName): The :code:`pdf` icon. PEACE (IconName): The :code:`peace` icon. PENCIL (IconName): The :code:`pencil` icon. PENCIL_MINUS (IconName): The :code:`pencil-minus` icon. PENCIL_OFF (IconName): The :code:`pencil-off` icon. PENCIL_PLUS (IconName): The :code:`pencil-plus` icon. PENNANT (IconName): The :code:`pennant` icon. PENNANT_2 (IconName): The :code:`pennant-2` icon. PENNANT_2_FILLED (IconName): The :code:`pennant-2-filled` icon. PENNANT_FILLED (IconName): The :code:`pennant-filled` icon. PENNANT_OFF (IconName): The :code:`pennant-off` icon. PENTAGON (IconName): The :code:`pentagon` icon. PENTAGON_FILLED (IconName): The :code:`pentagon-filled` icon. PENTAGON_OFF (IconName): The :code:`pentagon-off` icon. PENTAGRAM (IconName): The :code:`pentagram` icon. PEPPER (IconName): The :code:`pepper` icon. PEPPER_OFF (IconName): The :code:`pepper-off` icon. PERCENTAGE (IconName): The :code:`percentage` icon. PERFUME (IconName): The :code:`perfume` icon. PERSPECTIVE (IconName): The :code:`perspective` icon. PERSPECTIVE_OFF (IconName): The :code:`perspective-off` icon. PHONE (IconName): The :code:`phone` icon. PHONE_CALL (IconName): The :code:`phone-call` icon. PHONE_CALLING (IconName): The :code:`phone-calling` icon. PHONE_CHECK (IconName): The :code:`phone-check` icon. PHONE_FILLED (IconName): The :code:`phone-filled` icon. PHONE_INCOMING (IconName): The :code:`phone-incoming` icon. PHONE_OFF (IconName): The :code:`phone-off` icon. PHONE_OUTGOING (IconName): The :code:`phone-outgoing` icon. PHONE_PAUSE (IconName): The :code:`phone-pause` icon. PHONE_PLUS (IconName): The :code:`phone-plus` icon. PHONE_X (IconName): The :code:`phone-x` icon. PHOTO (IconName): The :code:`photo` icon. PHOTO_AI (IconName): The :code:`photo-ai` icon. PHOTO_BOLT (IconName): The :code:`photo-bolt` icon. PHOTO_CANCEL (IconName): The :code:`photo-cancel` icon. PHOTO_CHECK (IconName): The :code:`photo-check` icon. PHOTO_CODE (IconName): The :code:`photo-code` icon. PHOTO_COG (IconName): The :code:`photo-cog` icon. PHOTO_DOLLAR (IconName): The :code:`photo-dollar` icon. PHOTO_DOWN (IconName): The :code:`photo-down` icon. PHOTO_EDIT (IconName): The :code:`photo-edit` icon. PHOTO_EXCLAMATION (IconName): The :code:`photo-exclamation` icon. PHOTO_FILLED (IconName): The :code:`photo-filled` icon. PHOTO_HEART (IconName): The :code:`photo-heart` icon. PHOTO_MINUS (IconName): The :code:`photo-minus` icon. PHOTO_OFF (IconName): The :code:`photo-off` icon. PHOTO_PAUSE (IconName): The :code:`photo-pause` icon. PHOTO_PIN (IconName): The :code:`photo-pin` icon. PHOTO_PLUS (IconName): The :code:`photo-plus` icon. PHOTO_QUESTION (IconName): The :code:`photo-question` icon. PHOTO_SEARCH (IconName): The :code:`photo-search` icon. PHOTO_SENSOR (IconName): The :code:`photo-sensor` icon. PHOTO_SENSOR_2 (IconName): The :code:`photo-sensor-2` icon. PHOTO_SENSOR_3 (IconName): The :code:`photo-sensor-3` icon. PHOTO_SHARE (IconName): The :code:`photo-share` icon. PHOTO_SHIELD (IconName): The :code:`photo-shield` icon. PHOTO_STAR (IconName): The :code:`photo-star` icon. PHOTO_UP (IconName): The :code:`photo-up` icon. 
PHOTO_X (IconName): The :code:`photo-x` icon. PHYSOTHERAPIST (IconName): The :code:`physotherapist` icon. PIANO (IconName): The :code:`piano` icon. PICK (IconName): The :code:`pick` icon. PICTURE_IN_PICTURE (IconName): The :code:`picture-in-picture` icon. PICTURE_IN_PICTURE_OFF (IconName): The :code:`picture-in-picture-off` icon. PICTURE_IN_PICTURE_ON (IconName): The :code:`picture-in-picture-on` icon. PICTURE_IN_PICTURE_TOP (IconName): The :code:`picture-in-picture-top` icon. PIG (IconName): The :code:`pig` icon. PIG_MONEY (IconName): The :code:`pig-money` icon. PIG_OFF (IconName): The :code:`pig-off` icon. PILCROW (IconName): The :code:`pilcrow` icon. PILL (IconName): The :code:`pill` icon. PILL_OFF (IconName): The :code:`pill-off` icon. PILLS (IconName): The :code:`pills` icon. PIN (IconName): The :code:`pin` icon. PIN_FILLED (IconName): The :code:`pin-filled` icon. PING_PONG (IconName): The :code:`ping-pong` icon. PINNED (IconName): The :code:`pinned` icon. PINNED_FILLED (IconName): The :code:`pinned-filled` icon. PINNED_OFF (IconName): The :code:`pinned-off` icon. PIZZA (IconName): The :code:`pizza` icon. PIZZA_OFF (IconName): The :code:`pizza-off` icon. PLACEHOLDER (IconName): The :code:`placeholder` icon. PLANE (IconName): The :code:`plane` icon. PLANE_ARRIVAL (IconName): The :code:`plane-arrival` icon. PLANE_DEPARTURE (IconName): The :code:`plane-departure` icon. PLANE_INFLIGHT (IconName): The :code:`plane-inflight` icon. PLANE_OFF (IconName): The :code:`plane-off` icon. PLANE_TILT (IconName): The :code:`plane-tilt` icon. PLANET (IconName): The :code:`planet` icon. PLANET_OFF (IconName): The :code:`planet-off` icon. PLANT (IconName): The :code:`plant` icon. PLANT_2 (IconName): The :code:`plant-2` icon. PLANT_2_OFF (IconName): The :code:`plant-2-off` icon. PLANT_OFF (IconName): The :code:`plant-off` icon. PLAY_BASKETBALL (IconName): The :code:`play-basketball` icon. PLAY_CARD (IconName): The :code:`play-card` icon. PLAY_CARD_OFF (IconName): The :code:`play-card-off` icon. PLAY_FOOTBALL (IconName): The :code:`play-football` icon. PLAY_HANDBALL (IconName): The :code:`play-handball` icon. PLAY_VOLLEYBALL (IconName): The :code:`play-volleyball` icon. PLAYER_EJECT (IconName): The :code:`player-eject` icon. PLAYER_EJECT_FILLED (IconName): The :code:`player-eject-filled` icon. PLAYER_PAUSE (IconName): The :code:`player-pause` icon. PLAYER_PAUSE_FILLED (IconName): The :code:`player-pause-filled` icon. PLAYER_PLAY (IconName): The :code:`player-play` icon. PLAYER_PLAY_FILLED (IconName): The :code:`player-play-filled` icon. PLAYER_RECORD (IconName): The :code:`player-record` icon. PLAYER_RECORD_FILLED (IconName): The :code:`player-record-filled` icon. PLAYER_SKIP_BACK (IconName): The :code:`player-skip-back` icon. PLAYER_SKIP_BACK_FILLED (IconName): The :code:`player-skip-back-filled` icon. PLAYER_SKIP_FORWARD (IconName): The :code:`player-skip-forward` icon. PLAYER_SKIP_FORWARD_FILLED (IconName): The :code:`player-skip-forward-filled` icon. PLAYER_STOP (IconName): The :code:`player-stop` icon. PLAYER_STOP_FILLED (IconName): The :code:`player-stop-filled` icon. PLAYER_TRACK_NEXT (IconName): The :code:`player-track-next` icon. PLAYER_TRACK_NEXT_FILLED (IconName): The :code:`player-track-next-filled` icon. PLAYER_TRACK_PREV (IconName): The :code:`player-track-prev` icon. PLAYER_TRACK_PREV_FILLED (IconName): The :code:`player-track-prev-filled` icon. PLAYLIST (IconName): The :code:`playlist` icon. PLAYLIST_ADD (IconName): The :code:`playlist-add` icon. 
PLAYLIST_OFF (IconName): The :code:`playlist-off` icon. PLAYLIST_X (IconName): The :code:`playlist-x` icon. PLAYSTATION_CIRCLE (IconName): The :code:`playstation-circle` icon. PLAYSTATION_SQUARE (IconName): The :code:`playstation-square` icon. PLAYSTATION_TRIANGLE (IconName): The :code:`playstation-triangle` icon. PLAYSTATION_X (IconName): The :code:`playstation-x` icon. PLUG (IconName): The :code:`plug` icon. PLUG_CONNECTED (IconName): The :code:`plug-connected` icon. PLUG_CONNECTED_X (IconName): The :code:`plug-connected-x` icon. PLUG_OFF (IconName): The :code:`plug-off` icon. PLUG_X (IconName): The :code:`plug-x` icon. PLUS (IconName): The :code:`plus` icon. PLUS_EQUAL (IconName): The :code:`plus-equal` icon. PLUS_MINUS (IconName): The :code:`plus-minus` icon. PNG (IconName): The :code:`png` icon. PODIUM (IconName): The :code:`podium` icon. PODIUM_OFF (IconName): The :code:`podium-off` icon. POINT (IconName): The :code:`point` icon. POINT_FILLED (IconName): The :code:`point-filled` icon. POINT_OFF (IconName): The :code:`point-off` icon. POINTER (IconName): The :code:`pointer` icon. POINTER_BOLT (IconName): The :code:`pointer-bolt` icon. POINTER_CANCEL (IconName): The :code:`pointer-cancel` icon. POINTER_CHECK (IconName): The :code:`pointer-check` icon. POINTER_CODE (IconName): The :code:`pointer-code` icon. POINTER_COG (IconName): The :code:`pointer-cog` icon. POINTER_DOLLAR (IconName): The :code:`pointer-dollar` icon. POINTER_DOWN (IconName): The :code:`pointer-down` icon. POINTER_EXCLAMATION (IconName): The :code:`pointer-exclamation` icon. POINTER_HEART (IconName): The :code:`pointer-heart` icon. POINTER_MINUS (IconName): The :code:`pointer-minus` icon. POINTER_OFF (IconName): The :code:`pointer-off` icon. POINTER_PAUSE (IconName): The :code:`pointer-pause` icon. POINTER_PIN (IconName): The :code:`pointer-pin` icon. POINTER_PLUS (IconName): The :code:`pointer-plus` icon. POINTER_QUESTION (IconName): The :code:`pointer-question` icon. POINTER_SEARCH (IconName): The :code:`pointer-search` icon. POINTER_SHARE (IconName): The :code:`pointer-share` icon. POINTER_STAR (IconName): The :code:`pointer-star` icon. POINTER_UP (IconName): The :code:`pointer-up` icon. POINTER_X (IconName): The :code:`pointer-x` icon. POKEBALL (IconName): The :code:`pokeball` icon. POKEBALL_OFF (IconName): The :code:`pokeball-off` icon. POKER_CHIP (IconName): The :code:`poker-chip` icon. POLAROID (IconName): The :code:`polaroid` icon. POLAROID_FILLED (IconName): The :code:`polaroid-filled` icon. POLYGON (IconName): The :code:`polygon` icon. POLYGON_OFF (IconName): The :code:`polygon-off` icon. POO (IconName): The :code:`poo` icon. POOL (IconName): The :code:`pool` icon. POOL_OFF (IconName): The :code:`pool-off` icon. POWER (IconName): The :code:`power` icon. PRAY (IconName): The :code:`pray` icon. PREMIUM_RIGHTS (IconName): The :code:`premium-rights` icon. PRESCRIPTION (IconName): The :code:`prescription` icon. PRESENTATION (IconName): The :code:`presentation` icon. PRESENTATION_ANALYTICS (IconName): The :code:`presentation-analytics` icon. PRESENTATION_OFF (IconName): The :code:`presentation-off` icon. PRINTER (IconName): The :code:`printer` icon. PRINTER_OFF (IconName): The :code:`printer-off` icon. PRISM (IconName): The :code:`prism` icon. PRISM_OFF (IconName): The :code:`prism-off` icon. PRISM_PLUS (IconName): The :code:`prism-plus` icon. PRISON (IconName): The :code:`prison` icon. PROGRESS (IconName): The :code:`progress` icon. PROGRESS_ALERT (IconName): The :code:`progress-alert` icon. 
PROGRESS_BOLT (IconName): The :code:`progress-bolt` icon. PROGRESS_CHECK (IconName): The :code:`progress-check` icon. PROGRESS_DOWN (IconName): The :code:`progress-down` icon. PROGRESS_HELP (IconName): The :code:`progress-help` icon. PROGRESS_X (IconName): The :code:`progress-x` icon. PROMPT (IconName): The :code:`prompt` icon. PROPELLER (IconName): The :code:`propeller` icon. PROPELLER_OFF (IconName): The :code:`propeller-off` icon. PUMPKIN_SCARY (IconName): The :code:`pumpkin-scary` icon. PUZZLE (IconName): The :code:`puzzle` icon. PUZZLE_2 (IconName): The :code:`puzzle-2` icon. PUZZLE_FILLED (IconName): The :code:`puzzle-filled` icon. PUZZLE_OFF (IconName): The :code:`puzzle-off` icon. PYRAMID (IconName): The :code:`pyramid` icon. PYRAMID_OFF (IconName): The :code:`pyramid-off` icon. PYRAMID_PLUS (IconName): The :code:`pyramid-plus` icon. QRCODE (IconName): The :code:`qrcode` icon. QRCODE_OFF (IconName): The :code:`qrcode-off` icon. QUESTION_MARK (IconName): The :code:`question-mark` icon. QUOTE (IconName): The :code:`quote` icon. QUOTE_OFF (IconName): The :code:`quote-off` icon. RADAR (IconName): The :code:`radar` icon. RADAR_2 (IconName): The :code:`radar-2` icon. RADAR_OFF (IconName): The :code:`radar-off` icon. RADIO (IconName): The :code:`radio` icon. RADIO_OFF (IconName): The :code:`radio-off` icon. RADIOACTIVE (IconName): The :code:`radioactive` icon. RADIOACTIVE_FILLED (IconName): The :code:`radioactive-filled` icon. RADIOACTIVE_OFF (IconName): The :code:`radioactive-off` icon. RADIUS_BOTTOM_LEFT (IconName): The :code:`radius-bottom-left` icon. RADIUS_BOTTOM_RIGHT (IconName): The :code:`radius-bottom-right` icon. RADIUS_TOP_LEFT (IconName): The :code:`radius-top-left` icon. RADIUS_TOP_RIGHT (IconName): The :code:`radius-top-right` icon. RAINBOW (IconName): The :code:`rainbow` icon. RAINBOW_OFF (IconName): The :code:`rainbow-off` icon. RATING_12_PLUS (IconName): The :code:`rating-12-plus` icon. RATING_14_PLUS (IconName): The :code:`rating-14-plus` icon. RATING_16_PLUS (IconName): The :code:`rating-16-plus` icon. RATING_18_PLUS (IconName): The :code:`rating-18-plus` icon. RATING_21_PLUS (IconName): The :code:`rating-21-plus` icon. RAZOR (IconName): The :code:`razor` icon. RAZOR_ELECTRIC (IconName): The :code:`razor-electric` icon. RECEIPT (IconName): The :code:`receipt` icon. RECEIPT_2 (IconName): The :code:`receipt-2` icon. RECEIPT_OFF (IconName): The :code:`receipt-off` icon. RECEIPT_REFUND (IconName): The :code:`receipt-refund` icon. RECEIPT_TAX (IconName): The :code:`receipt-tax` icon. RECHARGING (IconName): The :code:`recharging` icon. RECORD_MAIL (IconName): The :code:`record-mail` icon. RECORD_MAIL_OFF (IconName): The :code:`record-mail-off` icon. RECTANGLE (IconName): The :code:`rectangle` icon. RECTANGLE_FILLED (IconName): The :code:`rectangle-filled` icon. RECTANGLE_ROUNDED_BOTTOM (IconName): The :code:`rectangle-rounded-bottom` icon. RECTANGLE_ROUNDED_TOP (IconName): The :code:`rectangle-rounded-top` icon. RECTANGLE_VERTICAL (IconName): The :code:`rectangle-vertical` icon. RECTANGLE_VERTICAL_FILLED (IconName): The :code:`rectangle-vertical-filled` icon. RECTANGULAR_PRISM (IconName): The :code:`rectangular-prism` icon. RECTANGULAR_PRISM_OFF (IconName): The :code:`rectangular-prism-off` icon. RECTANGULAR_PRISM_PLUS (IconName): The :code:`rectangular-prism-plus` icon. RECYCLE (IconName): The :code:`recycle` icon. RECYCLE_OFF (IconName): The :code:`recycle-off` icon. REFRESH (IconName): The :code:`refresh` icon. REFRESH_ALERT (IconName): The :code:`refresh-alert` icon. 
REFRESH_DOT (IconName): The :code:`refresh-dot` icon. REFRESH_OFF (IconName): The :code:`refresh-off` icon. REGEX (IconName): The :code:`regex` icon. REGEX_OFF (IconName): The :code:`regex-off` icon. REGISTERED (IconName): The :code:`registered` icon. RELATION_MANY_TO_MANY (IconName): The :code:`relation-many-to-many` icon. RELATION_ONE_TO_MANY (IconName): The :code:`relation-one-to-many` icon. RELATION_ONE_TO_ONE (IconName): The :code:`relation-one-to-one` icon. RELOAD (IconName): The :code:`reload` icon. REPEAT (IconName): The :code:`repeat` icon. REPEAT_OFF (IconName): The :code:`repeat-off` icon. REPEAT_ONCE (IconName): The :code:`repeat-once` icon. REPLACE (IconName): The :code:`replace` icon. REPLACE_FILLED (IconName): The :code:`replace-filled` icon. REPLACE_OFF (IconName): The :code:`replace-off` icon. REPORT (IconName): The :code:`report` icon. REPORT_ANALYTICS (IconName): The :code:`report-analytics` icon. REPORT_MEDICAL (IconName): The :code:`report-medical` icon. REPORT_MONEY (IconName): The :code:`report-money` icon. REPORT_OFF (IconName): The :code:`report-off` icon. REPORT_SEARCH (IconName): The :code:`report-search` icon. RESERVED_LINE (IconName): The :code:`reserved-line` icon. RESIZE (IconName): The :code:`resize` icon. RESTORE (IconName): The :code:`restore` icon. REWIND_BACKWARD_10 (IconName): The :code:`rewind-backward-10` icon. REWIND_BACKWARD_15 (IconName): The :code:`rewind-backward-15` icon. REWIND_BACKWARD_20 (IconName): The :code:`rewind-backward-20` icon. REWIND_BACKWARD_30 (IconName): The :code:`rewind-backward-30` icon. REWIND_BACKWARD_40 (IconName): The :code:`rewind-backward-40` icon. REWIND_BACKWARD_5 (IconName): The :code:`rewind-backward-5` icon. REWIND_BACKWARD_50 (IconName): The :code:`rewind-backward-50` icon. REWIND_BACKWARD_60 (IconName): The :code:`rewind-backward-60` icon. REWIND_FORWARD_10 (IconName): The :code:`rewind-forward-10` icon. REWIND_FORWARD_15 (IconName): The :code:`rewind-forward-15` icon. REWIND_FORWARD_20 (IconName): The :code:`rewind-forward-20` icon. REWIND_FORWARD_30 (IconName): The :code:`rewind-forward-30` icon. REWIND_FORWARD_40 (IconName): The :code:`rewind-forward-40` icon. REWIND_FORWARD_5 (IconName): The :code:`rewind-forward-5` icon. REWIND_FORWARD_50 (IconName): The :code:`rewind-forward-50` icon. REWIND_FORWARD_60 (IconName): The :code:`rewind-forward-60` icon. RIBBON_HEALTH (IconName): The :code:`ribbon-health` icon. RINGS (IconName): The :code:`rings` icon. RIPPLE (IconName): The :code:`ripple` icon. RIPPLE_OFF (IconName): The :code:`ripple-off` icon. ROAD (IconName): The :code:`road` icon. ROAD_OFF (IconName): The :code:`road-off` icon. ROAD_SIGN (IconName): The :code:`road-sign` icon. ROBOT (IconName): The :code:`robot` icon. ROBOT_OFF (IconName): The :code:`robot-off` icon. ROCKET (IconName): The :code:`rocket` icon. ROCKET_OFF (IconName): The :code:`rocket-off` icon. ROLLER_SKATING (IconName): The :code:`roller-skating` icon. ROLLERCOASTER (IconName): The :code:`rollercoaster` icon. ROLLERCOASTER_OFF (IconName): The :code:`rollercoaster-off` icon. ROSETTE (IconName): The :code:`rosette` icon. ROSETTE_FILLED (IconName): The :code:`rosette-filled` icon. ROSETTE_NUMBER_0 (IconName): The :code:`rosette-number-0` icon. ROSETTE_NUMBER_1 (IconName): The :code:`rosette-number-1` icon. ROSETTE_NUMBER_2 (IconName): The :code:`rosette-number-2` icon. ROSETTE_NUMBER_3 (IconName): The :code:`rosette-number-3` icon. ROSETTE_NUMBER_4 (IconName): The :code:`rosette-number-4` icon. 
ROSETTE_NUMBER_5 (IconName): The :code:`rosette-number-5` icon. ROSETTE_NUMBER_6 (IconName): The :code:`rosette-number-6` icon. ROSETTE_NUMBER_7 (IconName): The :code:`rosette-number-7` icon. ROSETTE_NUMBER_8 (IconName): The :code:`rosette-number-8` icon. ROSETTE_NUMBER_9 (IconName): The :code:`rosette-number-9` icon. ROTATE (IconName): The :code:`rotate` icon. ROTATE_2 (IconName): The :code:`rotate-2` icon. ROTATE_360 (IconName): The :code:`rotate-360` icon. ROTATE_CLOCKWISE (IconName): The :code:`rotate-clockwise` icon. ROTATE_CLOCKWISE_2 (IconName): The :code:`rotate-clockwise-2` icon. ROTATE_DOT (IconName): The :code:`rotate-dot` icon. ROTATE_RECTANGLE (IconName): The :code:`rotate-rectangle` icon. ROUTE (IconName): The :code:`route` icon. ROUTE_2 (IconName): The :code:`route-2` icon. ROUTE_OFF (IconName): The :code:`route-off` icon. ROUTER (IconName): The :code:`router` icon. ROUTER_OFF (IconName): The :code:`router-off` icon. ROW_INSERT_BOTTOM (IconName): The :code:`row-insert-bottom` icon. ROW_INSERT_TOP (IconName): The :code:`row-insert-top` icon. ROW_REMOVE (IconName): The :code:`row-remove` icon. RSS (IconName): The :code:`rss` icon. RUBBER_STAMP (IconName): The :code:`rubber-stamp` icon. RUBBER_STAMP_OFF (IconName): The :code:`rubber-stamp-off` icon. RULER (IconName): The :code:`ruler` icon. RULER_2 (IconName): The :code:`ruler-2` icon. RULER_2_OFF (IconName): The :code:`ruler-2-off` icon. RULER_3 (IconName): The :code:`ruler-3` icon. RULER_MEASURE (IconName): The :code:`ruler-measure` icon. RULER_OFF (IconName): The :code:`ruler-off` icon. RUN (IconName): The :code:`run` icon. S_TURN_DOWN (IconName): The :code:`s-turn-down` icon. S_TURN_LEFT (IconName): The :code:`s-turn-left` icon. S_TURN_RIGHT (IconName): The :code:`s-turn-right` icon. S_TURN_UP (IconName): The :code:`s-turn-up` icon. SAILBOAT (IconName): The :code:`sailboat` icon. SAILBOAT_2 (IconName): The :code:`sailboat-2` icon. SAILBOAT_OFF (IconName): The :code:`sailboat-off` icon. SALAD (IconName): The :code:`salad` icon. SALT (IconName): The :code:`salt` icon. SATELLITE (IconName): The :code:`satellite` icon. SATELLITE_OFF (IconName): The :code:`satellite-off` icon. SAUSAGE (IconName): The :code:`sausage` icon. SCALE (IconName): The :code:`scale` icon. SCALE_OFF (IconName): The :code:`scale-off` icon. SCALE_OUTLINE (IconName): The :code:`scale-outline` icon. SCALE_OUTLINE_OFF (IconName): The :code:`scale-outline-off` icon. SCAN (IconName): The :code:`scan` icon. SCAN_EYE (IconName): The :code:`scan-eye` icon. SCHEMA (IconName): The :code:`schema` icon. SCHEMA_OFF (IconName): The :code:`schema-off` icon. SCHOOL (IconName): The :code:`school` icon. SCHOOL_BELL (IconName): The :code:`school-bell` icon. SCHOOL_OFF (IconName): The :code:`school-off` icon. SCISSORS (IconName): The :code:`scissors` icon. SCISSORS_OFF (IconName): The :code:`scissors-off` icon. SCOOTER (IconName): The :code:`scooter` icon. SCOOTER_ELECTRIC (IconName): The :code:`scooter-electric` icon. SCOREBOARD (IconName): The :code:`scoreboard` icon. SCREEN_SHARE (IconName): The :code:`screen-share` icon. SCREEN_SHARE_OFF (IconName): The :code:`screen-share-off` icon. SCREENSHOT (IconName): The :code:`screenshot` icon. SCRIBBLE (IconName): The :code:`scribble` icon. SCRIBBLE_OFF (IconName): The :code:`scribble-off` icon. SCRIPT (IconName): The :code:`script` icon. SCRIPT_MINUS (IconName): The :code:`script-minus` icon. SCRIPT_PLUS (IconName): The :code:`script-plus` icon. SCRIPT_X (IconName): The :code:`script-x` icon. 
SCUBA_MASK (IconName): The :code:`scuba-mask` icon. SCUBA_MASK_OFF (IconName): The :code:`scuba-mask-off` icon. SDK (IconName): The :code:`sdk` icon. SEARCH (IconName): The :code:`search` icon. SEARCH_OFF (IconName): The :code:`search-off` icon. SECTION (IconName): The :code:`section` icon. SECTION_SIGN (IconName): The :code:`section-sign` icon. SEEDING (IconName): The :code:`seeding` icon. SEEDING_OFF (IconName): The :code:`seeding-off` icon. SELECT (IconName): The :code:`select` icon. SELECT_ALL (IconName): The :code:`select-all` icon. SELECTOR (IconName): The :code:`selector` icon. SEND (IconName): The :code:`send` icon. SEND_OFF (IconName): The :code:`send-off` icon. SEO (IconName): The :code:`seo` icon. SEPARATOR (IconName): The :code:`separator` icon. SEPARATOR_HORIZONTAL (IconName): The :code:`separator-horizontal` icon. SEPARATOR_VERTICAL (IconName): The :code:`separator-vertical` icon. SERVER (IconName): The :code:`server` icon. SERVER_2 (IconName): The :code:`server-2` icon. SERVER_BOLT (IconName): The :code:`server-bolt` icon. SERVER_COG (IconName): The :code:`server-cog` icon. SERVER_OFF (IconName): The :code:`server-off` icon. SERVICEMARK (IconName): The :code:`servicemark` icon. SETTINGS (IconName): The :code:`settings` icon. SETTINGS_2 (IconName): The :code:`settings-2` icon. SETTINGS_AUTOMATION (IconName): The :code:`settings-automation` icon. SETTINGS_BOLT (IconName): The :code:`settings-bolt` icon. SETTINGS_CANCEL (IconName): The :code:`settings-cancel` icon. SETTINGS_CHECK (IconName): The :code:`settings-check` icon. SETTINGS_CODE (IconName): The :code:`settings-code` icon. SETTINGS_COG (IconName): The :code:`settings-cog` icon. SETTINGS_DOLLAR (IconName): The :code:`settings-dollar` icon. SETTINGS_DOWN (IconName): The :code:`settings-down` icon. SETTINGS_EXCLAMATION (IconName): The :code:`settings-exclamation` icon. SETTINGS_FILLED (IconName): The :code:`settings-filled` icon. SETTINGS_HEART (IconName): The :code:`settings-heart` icon. SETTINGS_MINUS (IconName): The :code:`settings-minus` icon. SETTINGS_OFF (IconName): The :code:`settings-off` icon. SETTINGS_PAUSE (IconName): The :code:`settings-pause` icon. SETTINGS_PIN (IconName): The :code:`settings-pin` icon. SETTINGS_PLUS (IconName): The :code:`settings-plus` icon. SETTINGS_QUESTION (IconName): The :code:`settings-question` icon. SETTINGS_SEARCH (IconName): The :code:`settings-search` icon. SETTINGS_SHARE (IconName): The :code:`settings-share` icon. SETTINGS_STAR (IconName): The :code:`settings-star` icon. SETTINGS_UP (IconName): The :code:`settings-up` icon. SETTINGS_X (IconName): The :code:`settings-x` icon. SHADOW (IconName): The :code:`shadow` icon. SHADOW_OFF (IconName): The :code:`shadow-off` icon. SHAPE (IconName): The :code:`shape` icon. SHAPE_2 (IconName): The :code:`shape-2` icon. SHAPE_3 (IconName): The :code:`shape-3` icon. SHAPE_OFF (IconName): The :code:`shape-off` icon. SHARE (IconName): The :code:`share` icon. SHARE_2 (IconName): The :code:`share-2` icon. SHARE_3 (IconName): The :code:`share-3` icon. SHARE_OFF (IconName): The :code:`share-off` icon. SHI_JUMPING (IconName): The :code:`shi-jumping` icon. SHIELD (IconName): The :code:`shield` icon. SHIELD_BOLT (IconName): The :code:`shield-bolt` icon. SHIELD_CANCEL (IconName): The :code:`shield-cancel` icon. SHIELD_CHECK (IconName): The :code:`shield-check` icon. SHIELD_CHECK_FILLED (IconName): The :code:`shield-check-filled` icon. SHIELD_CHECKERED (IconName): The :code:`shield-checkered` icon. 
SHIELD_CHECKERED_FILLED (IconName): The :code:`shield-checkered-filled` icon. SHIELD_CHEVRON (IconName): The :code:`shield-chevron` icon. SHIELD_CODE (IconName): The :code:`shield-code` icon. SHIELD_COG (IconName): The :code:`shield-cog` icon. SHIELD_DOLLAR (IconName): The :code:`shield-dollar` icon. SHIELD_DOWN (IconName): The :code:`shield-down` icon. SHIELD_EXCLAMATION (IconName): The :code:`shield-exclamation` icon. SHIELD_FILLED (IconName): The :code:`shield-filled` icon. SHIELD_HALF (IconName): The :code:`shield-half` icon. SHIELD_HALF_FILLED (IconName): The :code:`shield-half-filled` icon. SHIELD_HEART (IconName): The :code:`shield-heart` icon. SHIELD_LOCK (IconName): The :code:`shield-lock` icon. SHIELD_LOCK_FILLED (IconName): The :code:`shield-lock-filled` icon. SHIELD_MINUS (IconName): The :code:`shield-minus` icon. SHIELD_OFF (IconName): The :code:`shield-off` icon. SHIELD_PAUSE (IconName): The :code:`shield-pause` icon. SHIELD_PIN (IconName): The :code:`shield-pin` icon. SHIELD_PLUS (IconName): The :code:`shield-plus` icon. SHIELD_QUESTION (IconName): The :code:`shield-question` icon. SHIELD_SEARCH (IconName): The :code:`shield-search` icon. SHIELD_SHARE (IconName): The :code:`shield-share` icon. SHIELD_STAR (IconName): The :code:`shield-star` icon. SHIELD_UP (IconName): The :code:`shield-up` icon. SHIELD_X (IconName): The :code:`shield-x` icon. SHIP (IconName): The :code:`ship` icon. SHIP_OFF (IconName): The :code:`ship-off` icon. SHIRT (IconName): The :code:`shirt` icon. SHIRT_FILLED (IconName): The :code:`shirt-filled` icon. SHIRT_OFF (IconName): The :code:`shirt-off` icon. SHIRT_SPORT (IconName): The :code:`shirt-sport` icon. SHOE (IconName): The :code:`shoe` icon. SHOE_OFF (IconName): The :code:`shoe-off` icon. SHOPPING_BAG (IconName): The :code:`shopping-bag` icon. SHOPPING_CART (IconName): The :code:`shopping-cart` icon. SHOPPING_CART_DISCOUNT (IconName): The :code:`shopping-cart-discount` icon. SHOPPING_CART_OFF (IconName): The :code:`shopping-cart-off` icon. SHOPPING_CART_PLUS (IconName): The :code:`shopping-cart-plus` icon. SHOPPING_CART_X (IconName): The :code:`shopping-cart-x` icon. SHOVEL (IconName): The :code:`shovel` icon. SHREDDER (IconName): The :code:`shredder` icon. SIGN_LEFT (IconName): The :code:`sign-left` icon. SIGN_LEFT_FILLED (IconName): The :code:`sign-left-filled` icon. SIGN_RIGHT (IconName): The :code:`sign-right` icon. SIGN_RIGHT_FILLED (IconName): The :code:`sign-right-filled` icon. SIGNAL_2G (IconName): The :code:`signal-2g` icon. SIGNAL_3G (IconName): The :code:`signal-3g` icon. SIGNAL_4G (IconName): The :code:`signal-4g` icon. SIGNAL_4G_PLUS (IconName): The :code:`signal-4g-plus` icon. SIGNAL_5G (IconName): The :code:`signal-5g` icon. SIGNAL_6G (IconName): The :code:`signal-6g` icon. SIGNAL_E (IconName): The :code:`signal-e` icon. SIGNAL_G (IconName): The :code:`signal-g` icon. SIGNAL_H (IconName): The :code:`signal-h` icon. SIGNAL_H_PLUS (IconName): The :code:`signal-h-plus` icon. SIGNAL_LTE (IconName): The :code:`signal-lte` icon. SIGNATURE (IconName): The :code:`signature` icon. SIGNATURE_OFF (IconName): The :code:`signature-off` icon. SITEMAP (IconName): The :code:`sitemap` icon. SITEMAP_OFF (IconName): The :code:`sitemap-off` icon. SKATEBOARD (IconName): The :code:`skateboard` icon. SKATEBOARD_OFF (IconName): The :code:`skateboard-off` icon. SKATEBOARDING (IconName): The :code:`skateboarding` icon. SKULL (IconName): The :code:`skull` icon. SLASH (IconName): The :code:`slash` icon. SLASHES (IconName): The :code:`slashes` icon. 
SLEIGH (IconName): The :code:`sleigh` icon. SLICE (IconName): The :code:`slice` icon. SLIDESHOW (IconName): The :code:`slideshow` icon. SMART_HOME (IconName): The :code:`smart-home` icon. SMART_HOME_OFF (IconName): The :code:`smart-home-off` icon. SMOKING (IconName): The :code:`smoking` icon. SMOKING_NO (IconName): The :code:`smoking-no` icon. SNOWFLAKE (IconName): The :code:`snowflake` icon. SNOWFLAKE_OFF (IconName): The :code:`snowflake-off` icon. SNOWMAN (IconName): The :code:`snowman` icon. SOCCER_FIELD (IconName): The :code:`soccer-field` icon. SOCIAL (IconName): The :code:`social` icon. SOCIAL_OFF (IconName): The :code:`social-off` icon. SOCK (IconName): The :code:`sock` icon. SOFA (IconName): The :code:`sofa` icon. SOFA_OFF (IconName): The :code:`sofa-off` icon. SOLAR_PANEL (IconName): The :code:`solar-panel` icon. SOLAR_PANEL_2 (IconName): The :code:`solar-panel-2` icon. SORT_0_9 (IconName): The :code:`sort-0-9` icon. SORT_9_0 (IconName): The :code:`sort-9-0` icon. SORT_A_Z (IconName): The :code:`sort-a-z` icon. SORT_ASCENDING (IconName): The :code:`sort-ascending` icon. SORT_ASCENDING_2 (IconName): The :code:`sort-ascending-2` icon. SORT_ASCENDING_LETTERS (IconName): The :code:`sort-ascending-letters` icon. SORT_ASCENDING_NUMBERS (IconName): The :code:`sort-ascending-numbers` icon. SORT_DESCENDING (IconName): The :code:`sort-descending` icon. SORT_DESCENDING_2 (IconName): The :code:`sort-descending-2` icon. SORT_DESCENDING_LETTERS (IconName): The :code:`sort-descending-letters` icon. SORT_DESCENDING_NUMBERS (IconName): The :code:`sort-descending-numbers` icon. SORT_Z_A (IconName): The :code:`sort-z-a` icon. SOS (IconName): The :code:`sos` icon. SOUP (IconName): The :code:`soup` icon. SOUP_OFF (IconName): The :code:`soup-off` icon. SOURCE_CODE (IconName): The :code:`source-code` icon. SPACE (IconName): The :code:`space` icon. SPACE_OFF (IconName): The :code:`space-off` icon. SPACING_HORIZONTAL (IconName): The :code:`spacing-horizontal` icon. SPACING_VERTICAL (IconName): The :code:`spacing-vertical` icon. SPADE (IconName): The :code:`spade` icon. SPADE_FILLED (IconName): The :code:`spade-filled` icon. SPARKLES (IconName): The :code:`sparkles` icon. SPEAKERPHONE (IconName): The :code:`speakerphone` icon. SPEEDBOAT (IconName): The :code:`speedboat` icon. SPHERE (IconName): The :code:`sphere` icon. SPHERE_OFF (IconName): The :code:`sphere-off` icon. SPHERE_PLUS (IconName): The :code:`sphere-plus` icon. SPIDER (IconName): The :code:`spider` icon. SPIRAL (IconName): The :code:`spiral` icon. SPIRAL_OFF (IconName): The :code:`spiral-off` icon. SPORT_BILLARD (IconName): The :code:`sport-billard` icon. SPRAY (IconName): The :code:`spray` icon. SPY (IconName): The :code:`spy` icon. SPY_OFF (IconName): The :code:`spy-off` icon. SQL (IconName): The :code:`sql` icon. SQUARE (IconName): The :code:`square` icon. SQUARE_0_FILLED (IconName): The :code:`square-0-filled` icon. SQUARE_1_FILLED (IconName): The :code:`square-1-filled` icon. SQUARE_2_FILLED (IconName): The :code:`square-2-filled` icon. SQUARE_3_FILLED (IconName): The :code:`square-3-filled` icon. SQUARE_4_FILLED (IconName): The :code:`square-4-filled` icon. SQUARE_5_FILLED (IconName): The :code:`square-5-filled` icon. SQUARE_6_FILLED (IconName): The :code:`square-6-filled` icon. SQUARE_7_FILLED (IconName): The :code:`square-7-filled` icon. SQUARE_8_FILLED (IconName): The :code:`square-8-filled` icon. SQUARE_9_FILLED (IconName): The :code:`square-9-filled` icon. SQUARE_ARROW_DOWN (IconName): The :code:`square-arrow-down` icon. 
SQUARE_ARROW_LEFT (IconName): The :code:`square-arrow-left` icon. SQUARE_ARROW_RIGHT (IconName): The :code:`square-arrow-right` icon. SQUARE_ARROW_UP (IconName): The :code:`square-arrow-up` icon. SQUARE_ASTERISK (IconName): The :code:`square-asterisk` icon. SQUARE_CHECK (IconName): The :code:`square-check` icon. SQUARE_CHECK_FILLED (IconName): The :code:`square-check-filled` icon. SQUARE_CHEVRON_DOWN (IconName): The :code:`square-chevron-down` icon. SQUARE_CHEVRON_LEFT (IconName): The :code:`square-chevron-left` icon. SQUARE_CHEVRON_RIGHT (IconName): The :code:`square-chevron-right` icon. SQUARE_CHEVRON_UP (IconName): The :code:`square-chevron-up` icon. SQUARE_CHEVRONS_DOWN (IconName): The :code:`square-chevrons-down` icon. SQUARE_CHEVRONS_LEFT (IconName): The :code:`square-chevrons-left` icon. SQUARE_CHEVRONS_RIGHT (IconName): The :code:`square-chevrons-right` icon. SQUARE_CHEVRONS_UP (IconName): The :code:`square-chevrons-up` icon. SQUARE_DOT (IconName): The :code:`square-dot` icon. SQUARE_F0 (IconName): The :code:`square-f0` icon. SQUARE_F0_FILLED (IconName): The :code:`square-f0-filled` icon. SQUARE_F1 (IconName): The :code:`square-f1` icon. SQUARE_F1_FILLED (IconName): The :code:`square-f1-filled` icon. SQUARE_F2 (IconName): The :code:`square-f2` icon. SQUARE_F2_FILLED (IconName): The :code:`square-f2-filled` icon. SQUARE_F3 (IconName): The :code:`square-f3` icon. SQUARE_F3_FILLED (IconName): The :code:`square-f3-filled` icon. SQUARE_F4 (IconName): The :code:`square-f4` icon. SQUARE_F4_FILLED (IconName): The :code:`square-f4-filled` icon. SQUARE_F5 (IconName): The :code:`square-f5` icon. SQUARE_F5_FILLED (IconName): The :code:`square-f5-filled` icon. SQUARE_F6 (IconName): The :code:`square-f6` icon. SQUARE_F6_FILLED (IconName): The :code:`square-f6-filled` icon. SQUARE_F7 (IconName): The :code:`square-f7` icon. SQUARE_F7_FILLED (IconName): The :code:`square-f7-filled` icon. SQUARE_F8 (IconName): The :code:`square-f8` icon. SQUARE_F8_FILLED (IconName): The :code:`square-f8-filled` icon. SQUARE_F9 (IconName): The :code:`square-f9` icon. SQUARE_F9_FILLED (IconName): The :code:`square-f9-filled` icon. SQUARE_FORBID (IconName): The :code:`square-forbid` icon. SQUARE_FORBID_2 (IconName): The :code:`square-forbid-2` icon. SQUARE_HALF (IconName): The :code:`square-half` icon. SQUARE_KEY (IconName): The :code:`square-key` icon. SQUARE_LETTER_A (IconName): The :code:`square-letter-a` icon. SQUARE_LETTER_B (IconName): The :code:`square-letter-b` icon. SQUARE_LETTER_C (IconName): The :code:`square-letter-c` icon. SQUARE_LETTER_D (IconName): The :code:`square-letter-d` icon. SQUARE_LETTER_E (IconName): The :code:`square-letter-e` icon. SQUARE_LETTER_F (IconName): The :code:`square-letter-f` icon. SQUARE_LETTER_G (IconName): The :code:`square-letter-g` icon. SQUARE_LETTER_H (IconName): The :code:`square-letter-h` icon. SQUARE_LETTER_I (IconName): The :code:`square-letter-i` icon. SQUARE_LETTER_J (IconName): The :code:`square-letter-j` icon. SQUARE_LETTER_K (IconName): The :code:`square-letter-k` icon. SQUARE_LETTER_L (IconName): The :code:`square-letter-l` icon. SQUARE_LETTER_M (IconName): The :code:`square-letter-m` icon. SQUARE_LETTER_N (IconName): The :code:`square-letter-n` icon. SQUARE_LETTER_O (IconName): The :code:`square-letter-o` icon. SQUARE_LETTER_P (IconName): The :code:`square-letter-p` icon. SQUARE_LETTER_Q (IconName): The :code:`square-letter-q` icon. SQUARE_LETTER_R (IconName): The :code:`square-letter-r` icon. SQUARE_LETTER_S (IconName): The :code:`square-letter-s` icon. 
SQUARE_LETTER_T (IconName): The :code:`square-letter-t` icon. SQUARE_LETTER_U (IconName): The :code:`square-letter-u` icon. SQUARE_LETTER_V (IconName): The :code:`square-letter-v` icon. SQUARE_LETTER_W (IconName): The :code:`square-letter-w` icon. SQUARE_LETTER_X (IconName): The :code:`square-letter-x` icon. SQUARE_LETTER_Y (IconName): The :code:`square-letter-y` icon. SQUARE_LETTER_Z (IconName): The :code:`square-letter-z` icon. SQUARE_MINUS (IconName): The :code:`square-minus` icon. SQUARE_NUMBER_0 (IconName): The :code:`square-number-0` icon. SQUARE_NUMBER_1 (IconName): The :code:`square-number-1` icon. SQUARE_NUMBER_2 (IconName): The :code:`square-number-2` icon. SQUARE_NUMBER_3 (IconName): The :code:`square-number-3` icon. SQUARE_NUMBER_4 (IconName): The :code:`square-number-4` icon. SQUARE_NUMBER_5 (IconName): The :code:`square-number-5` icon. SQUARE_NUMBER_6 (IconName): The :code:`square-number-6` icon. SQUARE_NUMBER_7 (IconName): The :code:`square-number-7` icon. SQUARE_NUMBER_8 (IconName): The :code:`square-number-8` icon. SQUARE_NUMBER_9 (IconName): The :code:`square-number-9` icon. SQUARE_OFF (IconName): The :code:`square-off` icon. SQUARE_PLUS (IconName): The :code:`square-plus` icon. SQUARE_ROOT (IconName): The :code:`square-root` icon. SQUARE_ROOT_2 (IconName): The :code:`square-root-2` icon. SQUARE_ROTATED (IconName): The :code:`square-rotated` icon. SQUARE_ROTATED_FILLED (IconName): The :code:`square-rotated-filled` icon. SQUARE_ROTATED_FORBID (IconName): The :code:`square-rotated-forbid` icon. SQUARE_ROTATED_FORBID_2 (IconName): The :code:`square-rotated-forbid-2` icon. SQUARE_ROTATED_OFF (IconName): The :code:`square-rotated-off` icon. SQUARE_ROUNDED (IconName): The :code:`square-rounded` icon. SQUARE_ROUNDED_ARROW_DOWN (IconName): The :code:`square-rounded-arrow-down` icon. SQUARE_ROUNDED_ARROW_DOWN_FILLED (IconName): The :code:`square-rounded-arrow-down-filled` icon. SQUARE_ROUNDED_ARROW_LEFT (IconName): The :code:`square-rounded-arrow-left` icon. SQUARE_ROUNDED_ARROW_LEFT_FILLED (IconName): The :code:`square-rounded-arrow-left-filled` icon. SQUARE_ROUNDED_ARROW_RIGHT (IconName): The :code:`square-rounded-arrow-right` icon. SQUARE_ROUNDED_ARROW_RIGHT_FILLED (IconName): The :code:`square-rounded-arrow-right-filled` icon. SQUARE_ROUNDED_ARROW_UP (IconName): The :code:`square-rounded-arrow-up` icon. SQUARE_ROUNDED_ARROW_UP_FILLED (IconName): The :code:`square-rounded-arrow-up-filled` icon. SQUARE_ROUNDED_CHECK (IconName): The :code:`square-rounded-check` icon. SQUARE_ROUNDED_CHECK_FILLED (IconName): The :code:`square-rounded-check-filled` icon. SQUARE_ROUNDED_CHEVRON_DOWN (IconName): The :code:`square-rounded-chevron-down` icon. SQUARE_ROUNDED_CHEVRON_DOWN_FILLED (IconName): The :code:`square-rounded-chevron-down-filled` icon. SQUARE_ROUNDED_CHEVRON_LEFT (IconName): The :code:`square-rounded-chevron-left` icon. SQUARE_ROUNDED_CHEVRON_LEFT_FILLED (IconName): The :code:`square-rounded-chevron-left-filled` icon. SQUARE_ROUNDED_CHEVRON_RIGHT (IconName): The :code:`square-rounded-chevron-right` icon. SQUARE_ROUNDED_CHEVRON_RIGHT_FILLED (IconName): The :code:`square-rounded-chevron-right-filled` icon. SQUARE_ROUNDED_CHEVRON_UP (IconName): The :code:`square-rounded-chevron-up` icon. SQUARE_ROUNDED_CHEVRON_UP_FILLED (IconName): The :code:`square-rounded-chevron-up-filled` icon. SQUARE_ROUNDED_CHEVRONS_DOWN (IconName): The :code:`square-rounded-chevrons-down` icon. SQUARE_ROUNDED_CHEVRONS_DOWN_FILLED (IconName): The :code:`square-rounded-chevrons-down-filled` icon. 
SQUARE_ROUNDED_CHEVRONS_LEFT (IconName): The :code:`square-rounded-chevrons-left` icon. SQUARE_ROUNDED_CHEVRONS_LEFT_FILLED (IconName): The :code:`square-rounded-chevrons-left-filled` icon. SQUARE_ROUNDED_CHEVRONS_RIGHT (IconName): The :code:`square-rounded-chevrons-right` icon. SQUARE_ROUNDED_CHEVRONS_RIGHT_FILLED (IconName): The :code:`square-rounded-chevrons-right-filled` icon. SQUARE_ROUNDED_CHEVRONS_UP (IconName): The :code:`square-rounded-chevrons-up` icon. SQUARE_ROUNDED_CHEVRONS_UP_FILLED (IconName): The :code:`square-rounded-chevrons-up-filled` icon. SQUARE_ROUNDED_FILLED (IconName): The :code:`square-rounded-filled` icon. SQUARE_ROUNDED_LETTER_A (IconName): The :code:`square-rounded-letter-a` icon. SQUARE_ROUNDED_LETTER_B (IconName): The :code:`square-rounded-letter-b` icon. SQUARE_ROUNDED_LETTER_C (IconName): The :code:`square-rounded-letter-c` icon. SQUARE_ROUNDED_LETTER_D (IconName): The :code:`square-rounded-letter-d` icon. SQUARE_ROUNDED_LETTER_E (IconName): The :code:`square-rounded-letter-e` icon. SQUARE_ROUNDED_LETTER_F (IconName): The :code:`square-rounded-letter-f` icon. SQUARE_ROUNDED_LETTER_G (IconName): The :code:`square-rounded-letter-g` icon. SQUARE_ROUNDED_LETTER_H (IconName): The :code:`square-rounded-letter-h` icon. SQUARE_ROUNDED_LETTER_I (IconName): The :code:`square-rounded-letter-i` icon. SQUARE_ROUNDED_LETTER_J (IconName): The :code:`square-rounded-letter-j` icon. SQUARE_ROUNDED_LETTER_K (IconName): The :code:`square-rounded-letter-k` icon. SQUARE_ROUNDED_LETTER_L (IconName): The :code:`square-rounded-letter-l` icon. SQUARE_ROUNDED_LETTER_M (IconName): The :code:`square-rounded-letter-m` icon. SQUARE_ROUNDED_LETTER_N (IconName): The :code:`square-rounded-letter-n` icon. SQUARE_ROUNDED_LETTER_O (IconName): The :code:`square-rounded-letter-o` icon. SQUARE_ROUNDED_LETTER_P (IconName): The :code:`square-rounded-letter-p` icon. SQUARE_ROUNDED_LETTER_Q (IconName): The :code:`square-rounded-letter-q` icon. SQUARE_ROUNDED_LETTER_R (IconName): The :code:`square-rounded-letter-r` icon. SQUARE_ROUNDED_LETTER_S (IconName): The :code:`square-rounded-letter-s` icon. SQUARE_ROUNDED_LETTER_T (IconName): The :code:`square-rounded-letter-t` icon. SQUARE_ROUNDED_LETTER_U (IconName): The :code:`square-rounded-letter-u` icon. SQUARE_ROUNDED_LETTER_V (IconName): The :code:`square-rounded-letter-v` icon. SQUARE_ROUNDED_LETTER_W (IconName): The :code:`square-rounded-letter-w` icon. SQUARE_ROUNDED_LETTER_X (IconName): The :code:`square-rounded-letter-x` icon. SQUARE_ROUNDED_LETTER_Y (IconName): The :code:`square-rounded-letter-y` icon. SQUARE_ROUNDED_LETTER_Z (IconName): The :code:`square-rounded-letter-z` icon. SQUARE_ROUNDED_MINUS (IconName): The :code:`square-rounded-minus` icon. SQUARE_ROUNDED_NUMBER_0 (IconName): The :code:`square-rounded-number-0` icon. SQUARE_ROUNDED_NUMBER_0_FILLED (IconName): The :code:`square-rounded-number-0-filled` icon. SQUARE_ROUNDED_NUMBER_1 (IconName): The :code:`square-rounded-number-1` icon. SQUARE_ROUNDED_NUMBER_1_FILLED (IconName): The :code:`square-rounded-number-1-filled` icon. SQUARE_ROUNDED_NUMBER_2 (IconName): The :code:`square-rounded-number-2` icon. SQUARE_ROUNDED_NUMBER_2_FILLED (IconName): The :code:`square-rounded-number-2-filled` icon. SQUARE_ROUNDED_NUMBER_3 (IconName): The :code:`square-rounded-number-3` icon. SQUARE_ROUNDED_NUMBER_3_FILLED (IconName): The :code:`square-rounded-number-3-filled` icon. SQUARE_ROUNDED_NUMBER_4 (IconName): The :code:`square-rounded-number-4` icon. 
SQUARE_ROUNDED_NUMBER_4_FILLED (IconName): The :code:`square-rounded-number-4-filled` icon. SQUARE_ROUNDED_NUMBER_5 (IconName): The :code:`square-rounded-number-5` icon. SQUARE_ROUNDED_NUMBER_5_FILLED (IconName): The :code:`square-rounded-number-5-filled` icon. SQUARE_ROUNDED_NUMBER_6 (IconName): The :code:`square-rounded-number-6` icon. SQUARE_ROUNDED_NUMBER_6_FILLED (IconName): The :code:`square-rounded-number-6-filled` icon. SQUARE_ROUNDED_NUMBER_7 (IconName): The :code:`square-rounded-number-7` icon. SQUARE_ROUNDED_NUMBER_7_FILLED (IconName): The :code:`square-rounded-number-7-filled` icon. SQUARE_ROUNDED_NUMBER_8 (IconName): The :code:`square-rounded-number-8` icon. SQUARE_ROUNDED_NUMBER_8_FILLED (IconName): The :code:`square-rounded-number-8-filled` icon. SQUARE_ROUNDED_NUMBER_9 (IconName): The :code:`square-rounded-number-9` icon. SQUARE_ROUNDED_NUMBER_9_FILLED (IconName): The :code:`square-rounded-number-9-filled` icon. SQUARE_ROUNDED_PLUS (IconName): The :code:`square-rounded-plus` icon. SQUARE_ROUNDED_PLUS_FILLED (IconName): The :code:`square-rounded-plus-filled` icon. SQUARE_ROUNDED_X (IconName): The :code:`square-rounded-x` icon. SQUARE_ROUNDED_X_FILLED (IconName): The :code:`square-rounded-x-filled` icon. SQUARE_TOGGLE (IconName): The :code:`square-toggle` icon. SQUARE_TOGGLE_HORIZONTAL (IconName): The :code:`square-toggle-horizontal` icon. SQUARE_X (IconName): The :code:`square-x` icon. SQUARES_DIAGONAL (IconName): The :code:`squares-diagonal` icon. SQUARES_FILLED (IconName): The :code:`squares-filled` icon. STACK (IconName): The :code:`stack` icon. STACK_2 (IconName): The :code:`stack-2` icon. STACK_3 (IconName): The :code:`stack-3` icon. STACK_POP (IconName): The :code:`stack-pop` icon. STACK_PUSH (IconName): The :code:`stack-push` icon. STAIRS (IconName): The :code:`stairs` icon. STAIRS_DOWN (IconName): The :code:`stairs-down` icon. STAIRS_UP (IconName): The :code:`stairs-up` icon. STAR (IconName): The :code:`star` icon. STAR_FILLED (IconName): The :code:`star-filled` icon. STAR_HALF (IconName): The :code:`star-half` icon. STAR_HALF_FILLED (IconName): The :code:`star-half-filled` icon. STAR_OFF (IconName): The :code:`star-off` icon. STARS (IconName): The :code:`stars` icon. STARS_FILLED (IconName): The :code:`stars-filled` icon. STARS_OFF (IconName): The :code:`stars-off` icon. STATUS_CHANGE (IconName): The :code:`status-change` icon. STEAM (IconName): The :code:`steam` icon. STEERING_WHEEL (IconName): The :code:`steering-wheel` icon. STEERING_WHEEL_OFF (IconName): The :code:`steering-wheel-off` icon. STEP_INTO (IconName): The :code:`step-into` icon. STEP_OUT (IconName): The :code:`step-out` icon. STEREO_GLASSES (IconName): The :code:`stereo-glasses` icon. STETHOSCOPE (IconName): The :code:`stethoscope` icon. STETHOSCOPE_OFF (IconName): The :code:`stethoscope-off` icon. STICKER (IconName): The :code:`sticker` icon. STORM (IconName): The :code:`storm` icon. STORM_OFF (IconName): The :code:`storm-off` icon. STRETCHING (IconName): The :code:`stretching` icon. STRETCHING_2 (IconName): The :code:`stretching-2` icon. STRIKETHROUGH (IconName): The :code:`strikethrough` icon. SUBMARINE (IconName): The :code:`submarine` icon. SUBSCRIPT (IconName): The :code:`subscript` icon. SUBTASK (IconName): The :code:`subtask` icon. SUM (IconName): The :code:`sum` icon. SUM_OFF (IconName): The :code:`sum-off` icon. SUN (IconName): The :code:`sun` icon. SUN_FILLED (IconName): The :code:`sun-filled` icon. SUN_HIGH (IconName): The :code:`sun-high` icon. 
SUN_LOW (IconName): The :code:`sun-low` icon. SUN_MOON (IconName): The :code:`sun-moon` icon. SUN_OFF (IconName): The :code:`sun-off` icon. SUN_WIND (IconName): The :code:`sun-wind` icon. SUNGLASSES (IconName): The :code:`sunglasses` icon. SUNRISE (IconName): The :code:`sunrise` icon. SUNSET (IconName): The :code:`sunset` icon. SUNSET_2 (IconName): The :code:`sunset-2` icon. SUPERSCRIPT (IconName): The :code:`superscript` icon. SVG (IconName): The :code:`svg` icon. SWIMMING (IconName): The :code:`swimming` icon. SWIPE (IconName): The :code:`swipe` icon. SWITCH (IconName): The :code:`switch` icon. SWITCH_2 (IconName): The :code:`switch-2` icon. SWITCH_3 (IconName): The :code:`switch-3` icon. SWITCH_HORIZONTAL (IconName): The :code:`switch-horizontal` icon. SWITCH_VERTICAL (IconName): The :code:`switch-vertical` icon. SWORD (IconName): The :code:`sword` icon. SWORD_OFF (IconName): The :code:`sword-off` icon. SWORDS (IconName): The :code:`swords` icon. TABLE (IconName): The :code:`table` icon. TABLE_ALIAS (IconName): The :code:`table-alias` icon. TABLE_COLUMN (IconName): The :code:`table-column` icon. TABLE_DOWN (IconName): The :code:`table-down` icon. TABLE_EXPORT (IconName): The :code:`table-export` icon. TABLE_FILLED (IconName): The :code:`table-filled` icon. TABLE_HEART (IconName): The :code:`table-heart` icon. TABLE_IMPORT (IconName): The :code:`table-import` icon. TABLE_MINUS (IconName): The :code:`table-minus` icon. TABLE_OFF (IconName): The :code:`table-off` icon. TABLE_OPTIONS (IconName): The :code:`table-options` icon. TABLE_PLUS (IconName): The :code:`table-plus` icon. TABLE_ROW (IconName): The :code:`table-row` icon. TABLE_SHARE (IconName): The :code:`table-share` icon. TABLE_SHORTCUT (IconName): The :code:`table-shortcut` icon. TAG (IconName): The :code:`tag` icon. TAG_OFF (IconName): The :code:`tag-off` icon. TAGS (IconName): The :code:`tags` icon. TAGS_OFF (IconName): The :code:`tags-off` icon. TALLYMARK_1 (IconName): The :code:`tallymark-1` icon. TALLYMARK_2 (IconName): The :code:`tallymark-2` icon. TALLYMARK_3 (IconName): The :code:`tallymark-3` icon. TALLYMARK_4 (IconName): The :code:`tallymark-4` icon. TALLYMARKS (IconName): The :code:`tallymarks` icon. TANK (IconName): The :code:`tank` icon. TARGET (IconName): The :code:`target` icon. TARGET_ARROW (IconName): The :code:`target-arrow` icon. TARGET_OFF (IconName): The :code:`target-off` icon. TEAPOT (IconName): The :code:`teapot` icon. TELESCOPE (IconName): The :code:`telescope` icon. TELESCOPE_OFF (IconName): The :code:`telescope-off` icon. TEMPERATURE (IconName): The :code:`temperature` icon. TEMPERATURE_CELSIUS (IconName): The :code:`temperature-celsius` icon. TEMPERATURE_FAHRENHEIT (IconName): The :code:`temperature-fahrenheit` icon. TEMPERATURE_MINUS (IconName): The :code:`temperature-minus` icon. TEMPERATURE_OFF (IconName): The :code:`temperature-off` icon. TEMPERATURE_PLUS (IconName): The :code:`temperature-plus` icon. TEMPLATE (IconName): The :code:`template` icon. TEMPLATE_OFF (IconName): The :code:`template-off` icon. TENT (IconName): The :code:`tent` icon. TENT_OFF (IconName): The :code:`tent-off` icon. TERMINAL (IconName): The :code:`terminal` icon. TERMINAL_2 (IconName): The :code:`terminal-2` icon. TEST_PIPE (IconName): The :code:`test-pipe` icon. TEST_PIPE_2 (IconName): The :code:`test-pipe-2` icon. TEST_PIPE_OFF (IconName): The :code:`test-pipe-off` icon. TEX (IconName): The :code:`tex` icon. TEXT_CAPTION (IconName): The :code:`text-caption` icon. TEXT_COLOR (IconName): The :code:`text-color` icon. 
TEXT_DECREASE (IconName): The :code:`text-decrease` icon. TEXT_DIRECTION_LTR (IconName): The :code:`text-direction-ltr` icon. TEXT_DIRECTION_RTL (IconName): The :code:`text-direction-rtl` icon. TEXT_INCREASE (IconName): The :code:`text-increase` icon. TEXT_ORIENTATION (IconName): The :code:`text-orientation` icon. TEXT_PLUS (IconName): The :code:`text-plus` icon. TEXT_RECOGNITION (IconName): The :code:`text-recognition` icon. TEXT_RESIZE (IconName): The :code:`text-resize` icon. TEXT_SIZE (IconName): The :code:`text-size` icon. TEXT_SPELLCHECK (IconName): The :code:`text-spellcheck` icon. TEXT_WRAP (IconName): The :code:`text-wrap` icon. TEXT_WRAP_DISABLED (IconName): The :code:`text-wrap-disabled` icon. TEXTURE (IconName): The :code:`texture` icon. THEATER (IconName): The :code:`theater` icon. THERMOMETER (IconName): The :code:`thermometer` icon. THUMB_DOWN (IconName): The :code:`thumb-down` icon. THUMB_DOWN_FILLED (IconName): The :code:`thumb-down-filled` icon. THUMB_DOWN_OFF (IconName): The :code:`thumb-down-off` icon. THUMB_UP (IconName): The :code:`thumb-up` icon. THUMB_UP_FILLED (IconName): The :code:`thumb-up-filled` icon. THUMB_UP_OFF (IconName): The :code:`thumb-up-off` icon. TIC_TAC (IconName): The :code:`tic-tac` icon. TICKET (IconName): The :code:`ticket` icon. TICKET_OFF (IconName): The :code:`ticket-off` icon. TIE (IconName): The :code:`tie` icon. TILDE (IconName): The :code:`tilde` icon. TILT_SHIFT (IconName): The :code:`tilt-shift` icon. TILT_SHIFT_OFF (IconName): The :code:`tilt-shift-off` icon. TIME_DURATION_0 (IconName): The :code:`time-duration-0` icon. TIME_DURATION_10 (IconName): The :code:`time-duration-10` icon. TIME_DURATION_15 (IconName): The :code:`time-duration-15` icon. TIME_DURATION_30 (IconName): The :code:`time-duration-30` icon. TIME_DURATION_45 (IconName): The :code:`time-duration-45` icon. TIME_DURATION_5 (IconName): The :code:`time-duration-5` icon. TIME_DURATION_60 (IconName): The :code:`time-duration-60` icon. TIME_DURATION_90 (IconName): The :code:`time-duration-90` icon. TIME_DURATION_OFF (IconName): The :code:`time-duration-off` icon. TIMELINE (IconName): The :code:`timeline` icon. TIMELINE_EVENT (IconName): The :code:`timeline-event` icon. TIMELINE_EVENT_EXCLAMATION (IconName): The :code:`timeline-event-exclamation` icon. TIMELINE_EVENT_MINUS (IconName): The :code:`timeline-event-minus` icon. TIMELINE_EVENT_PLUS (IconName): The :code:`timeline-event-plus` icon. TIMELINE_EVENT_TEXT (IconName): The :code:`timeline-event-text` icon. TIMELINE_EVENT_X (IconName): The :code:`timeline-event-x` icon. TIR (IconName): The :code:`tir` icon. TOGGLE_LEFT (IconName): The :code:`toggle-left` icon. TOGGLE_RIGHT (IconName): The :code:`toggle-right` icon. TOILET_PAPER (IconName): The :code:`toilet-paper` icon. TOILET_PAPER_OFF (IconName): The :code:`toilet-paper-off` icon. TOML (IconName): The :code:`toml` icon. TOOL (IconName): The :code:`tool` icon. TOOLS (IconName): The :code:`tools` icon. TOOLS_KITCHEN (IconName): The :code:`tools-kitchen` icon. TOOLS_KITCHEN_2 (IconName): The :code:`tools-kitchen-2` icon. TOOLS_KITCHEN_2_OFF (IconName): The :code:`tools-kitchen-2-off` icon. TOOLS_KITCHEN_OFF (IconName): The :code:`tools-kitchen-off` icon. TOOLS_OFF (IconName): The :code:`tools-off` icon. TOOLTIP (IconName): The :code:`tooltip` icon. TOPOLOGY_BUS (IconName): The :code:`topology-bus` icon. TOPOLOGY_COMPLEX (IconName): The :code:`topology-complex` icon. TOPOLOGY_FULL (IconName): The :code:`topology-full` icon. 
TOPOLOGY_FULL_HIERARCHY (IconName): The :code:`topology-full-hierarchy` icon. TOPOLOGY_RING (IconName): The :code:`topology-ring` icon. TOPOLOGY_RING_2 (IconName): The :code:`topology-ring-2` icon. TOPOLOGY_RING_3 (IconName): The :code:`topology-ring-3` icon. TOPOLOGY_STAR (IconName): The :code:`topology-star` icon. TOPOLOGY_STAR_2 (IconName): The :code:`topology-star-2` icon. TOPOLOGY_STAR_3 (IconName): The :code:`topology-star-3` icon. TOPOLOGY_STAR_RING (IconName): The :code:`topology-star-ring` icon. TOPOLOGY_STAR_RING_2 (IconName): The :code:`topology-star-ring-2` icon. TOPOLOGY_STAR_RING_3 (IconName): The :code:`topology-star-ring-3` icon. TORII (IconName): The :code:`torii` icon. TORNADO (IconName): The :code:`tornado` icon. TOURNAMENT (IconName): The :code:`tournament` icon. TOWER (IconName): The :code:`tower` icon. TOWER_OFF (IconName): The :code:`tower-off` icon. TRACK (IconName): The :code:`track` icon. TRACTOR (IconName): The :code:`tractor` icon. TRADEMARK (IconName): The :code:`trademark` icon. TRAFFIC_CONE (IconName): The :code:`traffic-cone` icon. TRAFFIC_CONE_OFF (IconName): The :code:`traffic-cone-off` icon. TRAFFIC_LIGHTS (IconName): The :code:`traffic-lights` icon. TRAFFIC_LIGHTS_OFF (IconName): The :code:`traffic-lights-off` icon. TRAIN (IconName): The :code:`train` icon. TRANSFER_IN (IconName): The :code:`transfer-in` icon. TRANSFER_OUT (IconName): The :code:`transfer-out` icon. TRANSFORM (IconName): The :code:`transform` icon. TRANSFORM_FILLED (IconName): The :code:`transform-filled` icon. TRANSITION_BOTTOM (IconName): The :code:`transition-bottom` icon. TRANSITION_LEFT (IconName): The :code:`transition-left` icon. TRANSITION_RIGHT (IconName): The :code:`transition-right` icon. TRANSITION_TOP (IconName): The :code:`transition-top` icon. TRASH (IconName): The :code:`trash` icon. TRASH_FILLED (IconName): The :code:`trash-filled` icon. TRASH_OFF (IconName): The :code:`trash-off` icon. TRASH_X (IconName): The :code:`trash-x` icon. TRASH_X_FILLED (IconName): The :code:`trash-x-filled` icon. TREADMILL (IconName): The :code:`treadmill` icon. TREE (IconName): The :code:`tree` icon. TREES (IconName): The :code:`trees` icon. TREKKING (IconName): The :code:`trekking` icon. TRENDING_DOWN (IconName): The :code:`trending-down` icon. TRENDING_DOWN_2 (IconName): The :code:`trending-down-2` icon. TRENDING_DOWN_3 (IconName): The :code:`trending-down-3` icon. TRENDING_UP (IconName): The :code:`trending-up` icon. TRENDING_UP_2 (IconName): The :code:`trending-up-2` icon. TRENDING_UP_3 (IconName): The :code:`trending-up-3` icon. TRIANGLE (IconName): The :code:`triangle` icon. TRIANGLE_FILLED (IconName): The :code:`triangle-filled` icon. TRIANGLE_INVERTED (IconName): The :code:`triangle-inverted` icon. TRIANGLE_INVERTED_FILLED (IconName): The :code:`triangle-inverted-filled` icon. TRIANGLE_OFF (IconName): The :code:`triangle-off` icon. TRIANGLE_SQUARE_CIRCLE (IconName): The :code:`triangle-square-circle` icon. TRIANGLES (IconName): The :code:`triangles` icon. TRIDENT (IconName): The :code:`trident` icon. TROLLEY (IconName): The :code:`trolley` icon. TROPHY (IconName): The :code:`trophy` icon. TROPHY_FILLED (IconName): The :code:`trophy-filled` icon. TROPHY_OFF (IconName): The :code:`trophy-off` icon. TROWEL (IconName): The :code:`trowel` icon. TRUCK (IconName): The :code:`truck` icon. TRUCK_DELIVERY (IconName): The :code:`truck-delivery` icon. TRUCK_LOADING (IconName): The :code:`truck-loading` icon. TRUCK_OFF (IconName): The :code:`truck-off` icon. 
TRUCK_RETURN (IconName): The :code:`truck-return` icon. TXT (IconName): The :code:`txt` icon. TYPOGRAPHY (IconName): The :code:`typography` icon. TYPOGRAPHY_OFF (IconName): The :code:`typography-off` icon. UFO (IconName): The :code:`ufo` icon. UFO_OFF (IconName): The :code:`ufo-off` icon. UMBRELLA (IconName): The :code:`umbrella` icon. UMBRELLA_FILLED (IconName): The :code:`umbrella-filled` icon. UMBRELLA_OFF (IconName): The :code:`umbrella-off` icon. UNDERLINE (IconName): The :code:`underline` icon. UNLINK (IconName): The :code:`unlink` icon. UPLOAD (IconName): The :code:`upload` icon. URGENT (IconName): The :code:`urgent` icon. USB (IconName): The :code:`usb` icon. USER (IconName): The :code:`user` icon. USER_BOLT (IconName): The :code:`user-bolt` icon. USER_CANCEL (IconName): The :code:`user-cancel` icon. USER_CHECK (IconName): The :code:`user-check` icon. USER_CIRCLE (IconName): The :code:`user-circle` icon. USER_CODE (IconName): The :code:`user-code` icon. USER_COG (IconName): The :code:`user-cog` icon. USER_DOLLAR (IconName): The :code:`user-dollar` icon. USER_DOWN (IconName): The :code:`user-down` icon. USER_EDIT (IconName): The :code:`user-edit` icon. USER_EXCLAMATION (IconName): The :code:`user-exclamation` icon. USER_HEART (IconName): The :code:`user-heart` icon. USER_MINUS (IconName): The :code:`user-minus` icon. USER_OFF (IconName): The :code:`user-off` icon. USER_PAUSE (IconName): The :code:`user-pause` icon. USER_PIN (IconName): The :code:`user-pin` icon. USER_PLUS (IconName): The :code:`user-plus` icon. USER_QUESTION (IconName): The :code:`user-question` icon. USER_SEARCH (IconName): The :code:`user-search` icon. USER_SHARE (IconName): The :code:`user-share` icon. USER_SHIELD (IconName): The :code:`user-shield` icon. USER_STAR (IconName): The :code:`user-star` icon. USER_UP (IconName): The :code:`user-up` icon. USER_X (IconName): The :code:`user-x` icon. USERS (IconName): The :code:`users` icon. USERS_GROUP (IconName): The :code:`users-group` icon. USERS_MINUS (IconName): The :code:`users-minus` icon. USERS_PLUS (IconName): The :code:`users-plus` icon. UV_INDEX (IconName): The :code:`uv-index` icon. UX_CIRCLE (IconName): The :code:`ux-circle` icon. VACCINE (IconName): The :code:`vaccine` icon. VACCINE_BOTTLE (IconName): The :code:`vaccine-bottle` icon. VACCINE_BOTTLE_OFF (IconName): The :code:`vaccine-bottle-off` icon. VACCINE_OFF (IconName): The :code:`vaccine-off` icon. VACUUM_CLEANER (IconName): The :code:`vacuum-cleaner` icon. VARIABLE (IconName): The :code:`variable` icon. VARIABLE_MINUS (IconName): The :code:`variable-minus` icon. VARIABLE_OFF (IconName): The :code:`variable-off` icon. VARIABLE_PLUS (IconName): The :code:`variable-plus` icon. VECTOR (IconName): The :code:`vector` icon. VECTOR_BEZIER (IconName): The :code:`vector-bezier` icon. VECTOR_BEZIER_2 (IconName): The :code:`vector-bezier-2` icon. VECTOR_BEZIER_ARC (IconName): The :code:`vector-bezier-arc` icon. VECTOR_BEZIER_CIRCLE (IconName): The :code:`vector-bezier-circle` icon. VECTOR_OFF (IconName): The :code:`vector-off` icon. VECTOR_SPLINE (IconName): The :code:`vector-spline` icon. VECTOR_TRIANGLE (IconName): The :code:`vector-triangle` icon. VECTOR_TRIANGLE_OFF (IconName): The :code:`vector-triangle-off` icon. VENUS (IconName): The :code:`venus` icon. VERSIONS (IconName): The :code:`versions` icon. VERSIONS_FILLED (IconName): The :code:`versions-filled` icon. VERSIONS_OFF (IconName): The :code:`versions-off` icon. VIDEO (IconName): The :code:`video` icon. 
VIDEO_MINUS (IconName): The :code:`video-minus` icon. VIDEO_OFF (IconName): The :code:`video-off` icon. VIDEO_PLUS (IconName): The :code:`video-plus` icon. VIEW_360 (IconName): The :code:`view-360` icon. VIEW_360_OFF (IconName): The :code:`view-360-off` icon. VIEWFINDER (IconName): The :code:`viewfinder` icon. VIEWFINDER_OFF (IconName): The :code:`viewfinder-off` icon. VIEWPORT_NARROW (IconName): The :code:`viewport-narrow` icon. VIEWPORT_WIDE (IconName): The :code:`viewport-wide` icon. VINYL (IconName): The :code:`vinyl` icon. VIP (IconName): The :code:`vip` icon. VIP_OFF (IconName): The :code:`vip-off` icon. VIRUS (IconName): The :code:`virus` icon. VIRUS_OFF (IconName): The :code:`virus-off` icon. VIRUS_SEARCH (IconName): The :code:`virus-search` icon. VOCABULARY (IconName): The :code:`vocabulary` icon. VOCABULARY_OFF (IconName): The :code:`vocabulary-off` icon. VOLCANO (IconName): The :code:`volcano` icon. VOLUME (IconName): The :code:`volume` icon. VOLUME_2 (IconName): The :code:`volume-2` icon. VOLUME_3 (IconName): The :code:`volume-3` icon. VOLUME_OFF (IconName): The :code:`volume-off` icon. WALK (IconName): The :code:`walk` icon. WALL (IconName): The :code:`wall` icon. WALL_OFF (IconName): The :code:`wall-off` icon. WALLET (IconName): The :code:`wallet` icon. WALLET_OFF (IconName): The :code:`wallet-off` icon. WALLPAPER (IconName): The :code:`wallpaper` icon. WALLPAPER_OFF (IconName): The :code:`wallpaper-off` icon. WAND (IconName): The :code:`wand` icon. WAND_OFF (IconName): The :code:`wand-off` icon. WASH (IconName): The :code:`wash` icon. WASH_DRY (IconName): The :code:`wash-dry` icon. WASH_DRY_1 (IconName): The :code:`wash-dry-1` icon. WASH_DRY_2 (IconName): The :code:`wash-dry-2` icon. WASH_DRY_3 (IconName): The :code:`wash-dry-3` icon. WASH_DRY_A (IconName): The :code:`wash-dry-a` icon. WASH_DRY_DIP (IconName): The :code:`wash-dry-dip` icon. WASH_DRY_F (IconName): The :code:`wash-dry-f` icon. WASH_DRY_FLAT (IconName): The :code:`wash-dry-flat` icon. WASH_DRY_HANG (IconName): The :code:`wash-dry-hang` icon. WASH_DRY_OFF (IconName): The :code:`wash-dry-off` icon. WASH_DRY_P (IconName): The :code:`wash-dry-p` icon. WASH_DRY_SHADE (IconName): The :code:`wash-dry-shade` icon. WASH_DRY_W (IconName): The :code:`wash-dry-w` icon. WASH_DRYCLEAN (IconName): The :code:`wash-dryclean` icon. WASH_DRYCLEAN_OFF (IconName): The :code:`wash-dryclean-off` icon. WASH_ECO (IconName): The :code:`wash-eco` icon. WASH_GENTLE (IconName): The :code:`wash-gentle` icon. WASH_HAND (IconName): The :code:`wash-hand` icon. WASH_MACHINE (IconName): The :code:`wash-machine` icon. WASH_OFF (IconName): The :code:`wash-off` icon. WASH_PRESS (IconName): The :code:`wash-press` icon. WASH_TEMPERATURE_1 (IconName): The :code:`wash-temperature-1` icon. WASH_TEMPERATURE_2 (IconName): The :code:`wash-temperature-2` icon. WASH_TEMPERATURE_3 (IconName): The :code:`wash-temperature-3` icon. WASH_TEMPERATURE_4 (IconName): The :code:`wash-temperature-4` icon. WASH_TEMPERATURE_5 (IconName): The :code:`wash-temperature-5` icon. WASH_TEMPERATURE_6 (IconName): The :code:`wash-temperature-6` icon. WASH_TUMBLE_DRY (IconName): The :code:`wash-tumble-dry` icon. WASH_TUMBLE_OFF (IconName): The :code:`wash-tumble-off` icon. WATERPOLO (IconName): The :code:`waterpolo` icon. WAVE_SAW_TOOL (IconName): The :code:`wave-saw-tool` icon. WAVE_SINE (IconName): The :code:`wave-sine` icon. WAVE_SQUARE (IconName): The :code:`wave-square` icon. WEBHOOK (IconName): The :code:`webhook` icon. WEBHOOK_OFF (IconName): The :code:`webhook-off` icon. 
WEIGHT (IconName): The :code:`weight` icon. WHEELCHAIR (IconName): The :code:`wheelchair` icon. WHEELCHAIR_OFF (IconName): The :code:`wheelchair-off` icon. WHIRL (IconName): The :code:`whirl` icon. WIFI (IconName): The :code:`wifi` icon. WIFI_0 (IconName): The :code:`wifi-0` icon. WIFI_1 (IconName): The :code:`wifi-1` icon. WIFI_2 (IconName): The :code:`wifi-2` icon. WIFI_OFF (IconName): The :code:`wifi-off` icon. WIND (IconName): The :code:`wind` icon. WIND_OFF (IconName): The :code:`wind-off` icon. WINDMILL (IconName): The :code:`windmill` icon. WINDMILL_FILLED (IconName): The :code:`windmill-filled` icon. WINDMILL_OFF (IconName): The :code:`windmill-off` icon. WINDOW (IconName): The :code:`window` icon. WINDOW_MAXIMIZE (IconName): The :code:`window-maximize` icon. WINDOW_MINIMIZE (IconName): The :code:`window-minimize` icon. WINDOW_OFF (IconName): The :code:`window-off` icon. WINDSOCK (IconName): The :code:`windsock` icon. WIPER (IconName): The :code:`wiper` icon. WIPER_WASH (IconName): The :code:`wiper-wash` icon. WOMAN (IconName): The :code:`woman` icon. WOOD (IconName): The :code:`wood` icon. WORLD (IconName): The :code:`world` icon. WORLD_BOLT (IconName): The :code:`world-bolt` icon. WORLD_CANCEL (IconName): The :code:`world-cancel` icon. WORLD_CHECK (IconName): The :code:`world-check` icon. WORLD_CODE (IconName): The :code:`world-code` icon. WORLD_COG (IconName): The :code:`world-cog` icon. WORLD_DOLLAR (IconName): The :code:`world-dollar` icon. WORLD_DOWN (IconName): The :code:`world-down` icon. WORLD_DOWNLOAD (IconName): The :code:`world-download` icon. WORLD_EXCLAMATION (IconName): The :code:`world-exclamation` icon. WORLD_HEART (IconName): The :code:`world-heart` icon. WORLD_LATITUDE (IconName): The :code:`world-latitude` icon. WORLD_LONGITUDE (IconName): The :code:`world-longitude` icon. WORLD_MINUS (IconName): The :code:`world-minus` icon. WORLD_OFF (IconName): The :code:`world-off` icon. WORLD_PAUSE (IconName): The :code:`world-pause` icon. WORLD_PIN (IconName): The :code:`world-pin` icon. WORLD_PLUS (IconName): The :code:`world-plus` icon. WORLD_QUESTION (IconName): The :code:`world-question` icon. WORLD_SEARCH (IconName): The :code:`world-search` icon. WORLD_SHARE (IconName): The :code:`world-share` icon. WORLD_STAR (IconName): The :code:`world-star` icon. WORLD_UP (IconName): The :code:`world-up` icon. WORLD_UPLOAD (IconName): The :code:`world-upload` icon. WORLD_WWW (IconName): The :code:`world-www` icon. WORLD_X (IconName): The :code:`world-x` icon. WRECKING_BALL (IconName): The :code:`wrecking-ball` icon. WRITING (IconName): The :code:`writing` icon. WRITING_OFF (IconName): The :code:`writing-off` icon. WRITING_SIGN (IconName): The :code:`writing-sign` icon. WRITING_SIGN_OFF (IconName): The :code:`writing-sign-off` icon. X (IconName): The :code:`x` icon. XBOX_A (IconName): The :code:`xbox-a` icon. XBOX_B (IconName): The :code:`xbox-b` icon. XBOX_X (IconName): The :code:`xbox-x` icon. XBOX_Y (IconName): The :code:`xbox-y` icon. XD (IconName): The :code:`xd` icon. YIN_YANG (IconName): The :code:`yin-yang` icon. YIN_YANG_FILLED (IconName): The :code:`yin-yang-filled` icon. YOGA (IconName): The :code:`yoga` icon. ZEPPELIN (IconName): The :code:`zeppelin` icon. ZEPPELIN_OFF (IconName): The :code:`zeppelin-off` icon. ZIP (IconName): The :code:`zip` icon. ZODIAC_AQUARIUS (IconName): The :code:`zodiac-aquarius` icon. ZODIAC_ARIES (IconName): The :code:`zodiac-aries` icon. ZODIAC_CANCER (IconName): The :code:`zodiac-cancer` icon. 
ZODIAC_CAPRICORN (IconName): The :code:`zodiac-capricorn` icon. ZODIAC_GEMINI (IconName): The :code:`zodiac-gemini` icon. ZODIAC_LEO (IconName): The :code:`zodiac-leo` icon. ZODIAC_LIBRA (IconName): The :code:`zodiac-libra` icon. ZODIAC_PISCES (IconName): The :code:`zodiac-pisces` icon. ZODIAC_SAGITTARIUS (IconName): The :code:`zodiac-sagittarius` icon. ZODIAC_SCORPIO (IconName): The :code:`zodiac-scorpio` icon. ZODIAC_TAURUS (IconName): The :code:`zodiac-taurus` icon. ZODIAC_VIRGO (IconName): The :code:`zodiac-virgo` icon. ZOOM_CANCEL (IconName): The :code:`zoom-cancel` icon. ZOOM_CHECK (IconName): The :code:`zoom-check` icon. ZOOM_CHECK_FILLED (IconName): The :code:`zoom-check-filled` icon. ZOOM_CODE (IconName): The :code:`zoom-code` icon. ZOOM_EXCLAMATION (IconName): The :code:`zoom-exclamation` icon. ZOOM_FILLED (IconName): The :code:`zoom-filled` icon. ZOOM_IN (IconName): The :code:`zoom-in` icon. ZOOM_IN_AREA (IconName): The :code:`zoom-in-area` icon. ZOOM_IN_AREA_FILLED (IconName): The :code:`zoom-in-area-filled` icon. ZOOM_IN_FILLED (IconName): The :code:`zoom-in-filled` icon. ZOOM_MONEY (IconName): The :code:`zoom-money` icon. ZOOM_OUT (IconName): The :code:`zoom-out` icon. ZOOM_OUT_AREA (IconName): The :code:`zoom-out-area` icon. ZOOM_OUT_FILLED (IconName): The :code:`zoom-out-filled` icon. ZOOM_PAN (IconName): The :code:`zoom-pan` icon. ZOOM_QUESTION (IconName): The :code:`zoom-question` icon. ZOOM_REPLACE (IconName): The :code:`zoom-replace` icon. ZOOM_RESET (IconName): The :code:`zoom-reset` icon. ZZZ (IconName): The :code:`zzz` icon. ZZZ_OFF (IconName): The :code:`zzz-off` icon. """ ================================================ FILE: viser/src/viser/_icons_enum.pyi ================================================ # Automatically generated by `_icons_generate_enum.py` # See https://tabler-icons.io/ from typing import NewType IconName = NewType("IconName", str) """Name of an icon. Should be generated via `viser.Icon.*`.""" class Icon: """'Enum' class for referencing Tabler icons. We don't subclass enum.Enum for performance reasons -- importing an enum with thousands of names can result in import times in the hundreds of milliseconds. 
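A minimal usage sketch (added here for illustration only, not part of the generated stub; it assumes :code:`viser.Icon` is importable from the package top level, as the :code:`IconName` note above suggests)::

        import viser

        # Icon names are plain class attributes holding IconName values, so
        # lookup is ordinary attribute access with no enum machinery on the
        # import path.
        icon = viser.Icon.PHOTO_X  # IconName("photo-x"); just a str at runtime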
""" ICON_123: IconName = IconName("123") ICON_24_HOURS: IconName = IconName("24-hours") ICON_2FA: IconName = IconName("2fa") ICON_360: IconName = IconName("360") ICON_360_VIEW: IconName = IconName("360-view") ICON_3D_CUBE_SPHERE: IconName = IconName("3d-cube-sphere") ICON_3D_CUBE_SPHERE_OFF: IconName = IconName("3d-cube-sphere-off") ICON_3D_ROTATE: IconName = IconName("3d-rotate") A_B: IconName = IconName("a-b") A_B_2: IconName = IconName("a-b-2") A_B_OFF: IconName = IconName("a-b-off") ABACUS: IconName = IconName("abacus") ABACUS_OFF: IconName = IconName("abacus-off") ABC: IconName = IconName("abc") ACCESS_POINT: IconName = IconName("access-point") ACCESS_POINT_OFF: IconName = IconName("access-point-off") ACCESSIBLE: IconName = IconName("accessible") ACCESSIBLE_OFF: IconName = IconName("accessible-off") ACCESSIBLE_OFF_FILLED: IconName = IconName("accessible-off-filled") ACTIVITY: IconName = IconName("activity") ACTIVITY_HEARTBEAT: IconName = IconName("activity-heartbeat") AD: IconName = IconName("ad") AD_2: IconName = IconName("ad-2") AD_CIRCLE: IconName = IconName("ad-circle") AD_CIRCLE_FILLED: IconName = IconName("ad-circle-filled") AD_CIRCLE_OFF: IconName = IconName("ad-circle-off") AD_FILLED: IconName = IconName("ad-filled") AD_OFF: IconName = IconName("ad-off") ADDRESS_BOOK: IconName = IconName("address-book") ADDRESS_BOOK_OFF: IconName = IconName("address-book-off") ADJUSTMENTS: IconName = IconName("adjustments") ADJUSTMENTS_ALT: IconName = IconName("adjustments-alt") ADJUSTMENTS_BOLT: IconName = IconName("adjustments-bolt") ADJUSTMENTS_CANCEL: IconName = IconName("adjustments-cancel") ADJUSTMENTS_CHECK: IconName = IconName("adjustments-check") ADJUSTMENTS_CODE: IconName = IconName("adjustments-code") ADJUSTMENTS_COG: IconName = IconName("adjustments-cog") ADJUSTMENTS_DOLLAR: IconName = IconName("adjustments-dollar") ADJUSTMENTS_DOWN: IconName = IconName("adjustments-down") ADJUSTMENTS_EXCLAMATION: IconName = IconName("adjustments-exclamation") ADJUSTMENTS_FILLED: IconName = IconName("adjustments-filled") ADJUSTMENTS_HEART: IconName = IconName("adjustments-heart") ADJUSTMENTS_HORIZONTAL: IconName = IconName("adjustments-horizontal") ADJUSTMENTS_MINUS: IconName = IconName("adjustments-minus") ADJUSTMENTS_OFF: IconName = IconName("adjustments-off") ADJUSTMENTS_PAUSE: IconName = IconName("adjustments-pause") ADJUSTMENTS_PIN: IconName = IconName("adjustments-pin") ADJUSTMENTS_PLUS: IconName = IconName("adjustments-plus") ADJUSTMENTS_QUESTION: IconName = IconName("adjustments-question") ADJUSTMENTS_SEARCH: IconName = IconName("adjustments-search") ADJUSTMENTS_SHARE: IconName = IconName("adjustments-share") ADJUSTMENTS_STAR: IconName = IconName("adjustments-star") ADJUSTMENTS_UP: IconName = IconName("adjustments-up") ADJUSTMENTS_X: IconName = IconName("adjustments-x") AERIAL_LIFT: IconName = IconName("aerial-lift") AFFILIATE: IconName = IconName("affiliate") AFFILIATE_FILLED: IconName = IconName("affiliate-filled") AIR_BALLOON: IconName = IconName("air-balloon") AIR_CONDITIONING: IconName = IconName("air-conditioning") AIR_CONDITIONING_DISABLED: IconName = IconName("air-conditioning-disabled") ALARM: IconName = IconName("alarm") ALARM_FILLED: IconName = IconName("alarm-filled") ALARM_MINUS: IconName = IconName("alarm-minus") ALARM_MINUS_FILLED: IconName = IconName("alarm-minus-filled") ALARM_OFF: IconName = IconName("alarm-off") ALARM_PLUS: IconName = IconName("alarm-plus") ALARM_PLUS_FILLED: IconName = IconName("alarm-plus-filled") ALARM_SNOOZE: IconName = IconName("alarm-snooze") 
ALARM_SNOOZE_FILLED: IconName = IconName("alarm-snooze-filled") ALBUM: IconName = IconName("album") ALBUM_OFF: IconName = IconName("album-off") ALERT_CIRCLE: IconName = IconName("alert-circle") ALERT_CIRCLE_FILLED: IconName = IconName("alert-circle-filled") ALERT_HEXAGON: IconName = IconName("alert-hexagon") ALERT_HEXAGON_FILLED: IconName = IconName("alert-hexagon-filled") ALERT_OCTAGON: IconName = IconName("alert-octagon") ALERT_OCTAGON_FILLED: IconName = IconName("alert-octagon-filled") ALERT_SMALL: IconName = IconName("alert-small") ALERT_SQUARE: IconName = IconName("alert-square") ALERT_SQUARE_FILLED: IconName = IconName("alert-square-filled") ALERT_SQUARE_ROUNDED: IconName = IconName("alert-square-rounded") ALERT_SQUARE_ROUNDED_FILLED: IconName = IconName("alert-square-rounded-filled") ALERT_TRIANGLE: IconName = IconName("alert-triangle") ALERT_TRIANGLE_FILLED: IconName = IconName("alert-triangle-filled") ALIEN: IconName = IconName("alien") ALIEN_FILLED: IconName = IconName("alien-filled") ALIGN_BOX_BOTTOM_CENTER: IconName = IconName("align-box-bottom-center") ALIGN_BOX_BOTTOM_CENTER_FILLED: IconName = IconName( "align-box-bottom-center-filled" ) ALIGN_BOX_BOTTOM_LEFT: IconName = IconName("align-box-bottom-left") ALIGN_BOX_BOTTOM_LEFT_FILLED: IconName = IconName("align-box-bottom-left-filled") ALIGN_BOX_BOTTOM_RIGHT: IconName = IconName("align-box-bottom-right") ALIGN_BOX_BOTTOM_RIGHT_FILLED: IconName = IconName("align-box-bottom-right-filled") ALIGN_BOX_CENTER_BOTTOM: IconName = IconName("align-box-center-bottom") ALIGN_BOX_CENTER_MIDDLE: IconName = IconName("align-box-center-middle") ALIGN_BOX_CENTER_MIDDLE_FILLED: IconName = IconName( "align-box-center-middle-filled" ) ALIGN_BOX_CENTER_STRETCH: IconName = IconName("align-box-center-stretch") ALIGN_BOX_CENTER_TOP: IconName = IconName("align-box-center-top") ALIGN_BOX_LEFT_BOTTOM: IconName = IconName("align-box-left-bottom") ALIGN_BOX_LEFT_BOTTOM_FILLED: IconName = IconName("align-box-left-bottom-filled") ALIGN_BOX_LEFT_MIDDLE: IconName = IconName("align-box-left-middle") ALIGN_BOX_LEFT_MIDDLE_FILLED: IconName = IconName("align-box-left-middle-filled") ALIGN_BOX_LEFT_STRETCH: IconName = IconName("align-box-left-stretch") ALIGN_BOX_LEFT_TOP: IconName = IconName("align-box-left-top") ALIGN_BOX_LEFT_TOP_FILLED: IconName = IconName("align-box-left-top-filled") ALIGN_BOX_RIGHT_BOTTOM: IconName = IconName("align-box-right-bottom") ALIGN_BOX_RIGHT_BOTTOM_FILLED: IconName = IconName("align-box-right-bottom-filled") ALIGN_BOX_RIGHT_MIDDLE: IconName = IconName("align-box-right-middle") ALIGN_BOX_RIGHT_MIDDLE_FILLED: IconName = IconName("align-box-right-middle-filled") ALIGN_BOX_RIGHT_STRETCH: IconName = IconName("align-box-right-stretch") ALIGN_BOX_RIGHT_TOP: IconName = IconName("align-box-right-top") ALIGN_BOX_RIGHT_TOP_FILLED: IconName = IconName("align-box-right-top-filled") ALIGN_BOX_TOP_CENTER: IconName = IconName("align-box-top-center") ALIGN_BOX_TOP_CENTER_FILLED: IconName = IconName("align-box-top-center-filled") ALIGN_BOX_TOP_LEFT: IconName = IconName("align-box-top-left") ALIGN_BOX_TOP_LEFT_FILLED: IconName = IconName("align-box-top-left-filled") ALIGN_BOX_TOP_RIGHT: IconName = IconName("align-box-top-right") ALIGN_BOX_TOP_RIGHT_FILLED: IconName = IconName("align-box-top-right-filled") ALIGN_CENTER: IconName = IconName("align-center") ALIGN_JUSTIFIED: IconName = IconName("align-justified") ALIGN_LEFT: IconName = IconName("align-left") ALIGN_RIGHT: IconName = IconName("align-right") ALPHA: IconName = IconName("alpha") 
ALPHABET_CYRILLIC: IconName = IconName("alphabet-cyrillic") ALPHABET_GREEK: IconName = IconName("alphabet-greek") ALPHABET_LATIN: IconName = IconName("alphabet-latin") AMBULANCE: IconName = IconName("ambulance") AMPERSAND: IconName = IconName("ampersand") ANALYZE: IconName = IconName("analyze") ANALYZE_FILLED: IconName = IconName("analyze-filled") ANALYZE_OFF: IconName = IconName("analyze-off") ANCHOR: IconName = IconName("anchor") ANCHOR_OFF: IconName = IconName("anchor-off") ANGLE: IconName = IconName("angle") ANKH: IconName = IconName("ankh") ANTENNA: IconName = IconName("antenna") ANTENNA_BARS_1: IconName = IconName("antenna-bars-1") ANTENNA_BARS_2: IconName = IconName("antenna-bars-2") ANTENNA_BARS_3: IconName = IconName("antenna-bars-3") ANTENNA_BARS_4: IconName = IconName("antenna-bars-4") ANTENNA_BARS_5: IconName = IconName("antenna-bars-5") ANTENNA_BARS_OFF: IconName = IconName("antenna-bars-off") ANTENNA_OFF: IconName = IconName("antenna-off") APERTURE: IconName = IconName("aperture") APERTURE_OFF: IconName = IconName("aperture-off") API: IconName = IconName("api") API_APP: IconName = IconName("api-app") API_APP_OFF: IconName = IconName("api-app-off") API_OFF: IconName = IconName("api-off") APP_WINDOW: IconName = IconName("app-window") APP_WINDOW_FILLED: IconName = IconName("app-window-filled") APPLE: IconName = IconName("apple") APPS: IconName = IconName("apps") APPS_FILLED: IconName = IconName("apps-filled") APPS_OFF: IconName = IconName("apps-off") ARCHIVE: IconName = IconName("archive") ARCHIVE_FILLED: IconName = IconName("archive-filled") ARCHIVE_OFF: IconName = IconName("archive-off") ARMCHAIR: IconName = IconName("armchair") ARMCHAIR_2: IconName = IconName("armchair-2") ARMCHAIR_2_OFF: IconName = IconName("armchair-2-off") ARMCHAIR_OFF: IconName = IconName("armchair-off") ARROW_AUTOFIT_CONTENT: IconName = IconName("arrow-autofit-content") ARROW_AUTOFIT_CONTENT_FILLED: IconName = IconName("arrow-autofit-content-filled") ARROW_AUTOFIT_DOWN: IconName = IconName("arrow-autofit-down") ARROW_AUTOFIT_HEIGHT: IconName = IconName("arrow-autofit-height") ARROW_AUTOFIT_LEFT: IconName = IconName("arrow-autofit-left") ARROW_AUTOFIT_RIGHT: IconName = IconName("arrow-autofit-right") ARROW_AUTOFIT_UP: IconName = IconName("arrow-autofit-up") ARROW_AUTOFIT_WIDTH: IconName = IconName("arrow-autofit-width") ARROW_BACK: IconName = IconName("arrow-back") ARROW_BACK_UP: IconName = IconName("arrow-back-up") ARROW_BACK_UP_DOUBLE: IconName = IconName("arrow-back-up-double") ARROW_BADGE_DOWN: IconName = IconName("arrow-badge-down") ARROW_BADGE_DOWN_FILLED: IconName = IconName("arrow-badge-down-filled") ARROW_BADGE_LEFT: IconName = IconName("arrow-badge-left") ARROW_BADGE_LEFT_FILLED: IconName = IconName("arrow-badge-left-filled") ARROW_BADGE_RIGHT: IconName = IconName("arrow-badge-right") ARROW_BADGE_RIGHT_FILLED: IconName = IconName("arrow-badge-right-filled") ARROW_BADGE_UP: IconName = IconName("arrow-badge-up") ARROW_BADGE_UP_FILLED: IconName = IconName("arrow-badge-up-filled") ARROW_BAR_BOTH: IconName = IconName("arrow-bar-both") ARROW_BAR_DOWN: IconName = IconName("arrow-bar-down") ARROW_BAR_LEFT: IconName = IconName("arrow-bar-left") ARROW_BAR_RIGHT: IconName = IconName("arrow-bar-right") ARROW_BAR_TO_DOWN: IconName = IconName("arrow-bar-to-down") ARROW_BAR_TO_LEFT: IconName = IconName("arrow-bar-to-left") ARROW_BAR_TO_RIGHT: IconName = IconName("arrow-bar-to-right") ARROW_BAR_TO_UP: IconName = IconName("arrow-bar-to-up") ARROW_BAR_UP: IconName = IconName("arrow-bar-up") ARROW_BEAR_LEFT: 
IconName = IconName("arrow-bear-left") ARROW_BEAR_LEFT_2: IconName = IconName("arrow-bear-left-2") ARROW_BEAR_RIGHT: IconName = IconName("arrow-bear-right") ARROW_BEAR_RIGHT_2: IconName = IconName("arrow-bear-right-2") ARROW_BIG_DOWN: IconName = IconName("arrow-big-down") ARROW_BIG_DOWN_FILLED: IconName = IconName("arrow-big-down-filled") ARROW_BIG_DOWN_LINE: IconName = IconName("arrow-big-down-line") ARROW_BIG_DOWN_LINE_FILLED: IconName = IconName("arrow-big-down-line-filled") ARROW_BIG_DOWN_LINES: IconName = IconName("arrow-big-down-lines") ARROW_BIG_DOWN_LINES_FILLED: IconName = IconName("arrow-big-down-lines-filled") ARROW_BIG_LEFT: IconName = IconName("arrow-big-left") ARROW_BIG_LEFT_FILLED: IconName = IconName("arrow-big-left-filled") ARROW_BIG_LEFT_LINE: IconName = IconName("arrow-big-left-line") ARROW_BIG_LEFT_LINE_FILLED: IconName = IconName("arrow-big-left-line-filled") ARROW_BIG_LEFT_LINES: IconName = IconName("arrow-big-left-lines") ARROW_BIG_LEFT_LINES_FILLED: IconName = IconName("arrow-big-left-lines-filled") ARROW_BIG_RIGHT: IconName = IconName("arrow-big-right") ARROW_BIG_RIGHT_FILLED: IconName = IconName("arrow-big-right-filled") ARROW_BIG_RIGHT_LINE: IconName = IconName("arrow-big-right-line") ARROW_BIG_RIGHT_LINE_FILLED: IconName = IconName("arrow-big-right-line-filled") ARROW_BIG_RIGHT_LINES: IconName = IconName("arrow-big-right-lines") ARROW_BIG_RIGHT_LINES_FILLED: IconName = IconName("arrow-big-right-lines-filled") ARROW_BIG_UP: IconName = IconName("arrow-big-up") ARROW_BIG_UP_FILLED: IconName = IconName("arrow-big-up-filled") ARROW_BIG_UP_LINE: IconName = IconName("arrow-big-up-line") ARROW_BIG_UP_LINE_FILLED: IconName = IconName("arrow-big-up-line-filled") ARROW_BIG_UP_LINES: IconName = IconName("arrow-big-up-lines") ARROW_BIG_UP_LINES_FILLED: IconName = IconName("arrow-big-up-lines-filled") ARROW_BOUNCE: IconName = IconName("arrow-bounce") ARROW_CAPSULE: IconName = IconName("arrow-capsule") ARROW_CURVE_LEFT: IconName = IconName("arrow-curve-left") ARROW_CURVE_RIGHT: IconName = IconName("arrow-curve-right") ARROW_DOWN: IconName = IconName("arrow-down") ARROW_DOWN_BAR: IconName = IconName("arrow-down-bar") ARROW_DOWN_CIRCLE: IconName = IconName("arrow-down-circle") ARROW_DOWN_LEFT: IconName = IconName("arrow-down-left") ARROW_DOWN_LEFT_CIRCLE: IconName = IconName("arrow-down-left-circle") ARROW_DOWN_RHOMBUS: IconName = IconName("arrow-down-rhombus") ARROW_DOWN_RIGHT: IconName = IconName("arrow-down-right") ARROW_DOWN_RIGHT_CIRCLE: IconName = IconName("arrow-down-right-circle") ARROW_DOWN_SQUARE: IconName = IconName("arrow-down-square") ARROW_DOWN_TAIL: IconName = IconName("arrow-down-tail") ARROW_ELBOW_LEFT: IconName = IconName("arrow-elbow-left") ARROW_ELBOW_RIGHT: IconName = IconName("arrow-elbow-right") ARROW_FORK: IconName = IconName("arrow-fork") ARROW_FORWARD: IconName = IconName("arrow-forward") ARROW_FORWARD_UP: IconName = IconName("arrow-forward-up") ARROW_FORWARD_UP_DOUBLE: IconName = IconName("arrow-forward-up-double") ARROW_GUIDE: IconName = IconName("arrow-guide") ARROW_ITERATION: IconName = IconName("arrow-iteration") ARROW_LEFT: IconName = IconName("arrow-left") ARROW_LEFT_BAR: IconName = IconName("arrow-left-bar") ARROW_LEFT_CIRCLE: IconName = IconName("arrow-left-circle") ARROW_LEFT_RHOMBUS: IconName = IconName("arrow-left-rhombus") ARROW_LEFT_RIGHT: IconName = IconName("arrow-left-right") ARROW_LEFT_SQUARE: IconName = IconName("arrow-left-square") ARROW_LEFT_TAIL: IconName = IconName("arrow-left-tail") ARROW_LOOP_LEFT: IconName = 
IconName("arrow-loop-left") ARROW_LOOP_LEFT_2: IconName = IconName("arrow-loop-left-2") ARROW_LOOP_RIGHT: IconName = IconName("arrow-loop-right") ARROW_LOOP_RIGHT_2: IconName = IconName("arrow-loop-right-2") ARROW_MERGE: IconName = IconName("arrow-merge") ARROW_MERGE_BOTH: IconName = IconName("arrow-merge-both") ARROW_MERGE_LEFT: IconName = IconName("arrow-merge-left") ARROW_MERGE_RIGHT: IconName = IconName("arrow-merge-right") ARROW_MOVE_DOWN: IconName = IconName("arrow-move-down") ARROW_MOVE_LEFT: IconName = IconName("arrow-move-left") ARROW_MOVE_RIGHT: IconName = IconName("arrow-move-right") ARROW_MOVE_UP: IconName = IconName("arrow-move-up") ARROW_NARROW_DOWN: IconName = IconName("arrow-narrow-down") ARROW_NARROW_LEFT: IconName = IconName("arrow-narrow-left") ARROW_NARROW_RIGHT: IconName = IconName("arrow-narrow-right") ARROW_NARROW_UP: IconName = IconName("arrow-narrow-up") ARROW_RAMP_LEFT: IconName = IconName("arrow-ramp-left") ARROW_RAMP_LEFT_2: IconName = IconName("arrow-ramp-left-2") ARROW_RAMP_LEFT_3: IconName = IconName("arrow-ramp-left-3") ARROW_RAMP_RIGHT: IconName = IconName("arrow-ramp-right") ARROW_RAMP_RIGHT_2: IconName = IconName("arrow-ramp-right-2") ARROW_RAMP_RIGHT_3: IconName = IconName("arrow-ramp-right-3") ARROW_RIGHT: IconName = IconName("arrow-right") ARROW_RIGHT_BAR: IconName = IconName("arrow-right-bar") ARROW_RIGHT_CIRCLE: IconName = IconName("arrow-right-circle") ARROW_RIGHT_RHOMBUS: IconName = IconName("arrow-right-rhombus") ARROW_RIGHT_SQUARE: IconName = IconName("arrow-right-square") ARROW_RIGHT_TAIL: IconName = IconName("arrow-right-tail") ARROW_ROTARY_FIRST_LEFT: IconName = IconName("arrow-rotary-first-left") ARROW_ROTARY_FIRST_RIGHT: IconName = IconName("arrow-rotary-first-right") ARROW_ROTARY_LAST_LEFT: IconName = IconName("arrow-rotary-last-left") ARROW_ROTARY_LAST_RIGHT: IconName = IconName("arrow-rotary-last-right") ARROW_ROTARY_LEFT: IconName = IconName("arrow-rotary-left") ARROW_ROTARY_RIGHT: IconName = IconName("arrow-rotary-right") ARROW_ROTARY_STRAIGHT: IconName = IconName("arrow-rotary-straight") ARROW_ROUNDABOUT_LEFT: IconName = IconName("arrow-roundabout-left") ARROW_ROUNDABOUT_RIGHT: IconName = IconName("arrow-roundabout-right") ARROW_SHARP_TURN_LEFT: IconName = IconName("arrow-sharp-turn-left") ARROW_SHARP_TURN_RIGHT: IconName = IconName("arrow-sharp-turn-right") ARROW_UP: IconName = IconName("arrow-up") ARROW_UP_BAR: IconName = IconName("arrow-up-bar") ARROW_UP_CIRCLE: IconName = IconName("arrow-up-circle") ARROW_UP_LEFT: IconName = IconName("arrow-up-left") ARROW_UP_LEFT_CIRCLE: IconName = IconName("arrow-up-left-circle") ARROW_UP_RHOMBUS: IconName = IconName("arrow-up-rhombus") ARROW_UP_RIGHT: IconName = IconName("arrow-up-right") ARROW_UP_RIGHT_CIRCLE: IconName = IconName("arrow-up-right-circle") ARROW_UP_SQUARE: IconName = IconName("arrow-up-square") ARROW_UP_TAIL: IconName = IconName("arrow-up-tail") ARROW_WAVE_LEFT_DOWN: IconName = IconName("arrow-wave-left-down") ARROW_WAVE_LEFT_UP: IconName = IconName("arrow-wave-left-up") ARROW_WAVE_RIGHT_DOWN: IconName = IconName("arrow-wave-right-down") ARROW_WAVE_RIGHT_UP: IconName = IconName("arrow-wave-right-up") ARROW_ZIG_ZAG: IconName = IconName("arrow-zig-zag") ARROWS_CROSS: IconName = IconName("arrows-cross") ARROWS_DIAGONAL: IconName = IconName("arrows-diagonal") ARROWS_DIAGONAL_2: IconName = IconName("arrows-diagonal-2") ARROWS_DIAGONAL_MINIMIZE: IconName = IconName("arrows-diagonal-minimize") ARROWS_DIAGONAL_MINIMIZE_2: IconName = IconName("arrows-diagonal-minimize-2") ARROWS_DIFF: 
IconName = IconName("arrows-diff") ARROWS_DOUBLE_NE_SW: IconName = IconName("arrows-double-ne-sw") ARROWS_DOUBLE_NW_SE: IconName = IconName("arrows-double-nw-se") ARROWS_DOUBLE_SE_NW: IconName = IconName("arrows-double-se-nw") ARROWS_DOUBLE_SW_NE: IconName = IconName("arrows-double-sw-ne") ARROWS_DOWN: IconName = IconName("arrows-down") ARROWS_DOWN_UP: IconName = IconName("arrows-down-up") ARROWS_EXCHANGE: IconName = IconName("arrows-exchange") ARROWS_EXCHANGE_2: IconName = IconName("arrows-exchange-2") ARROWS_HORIZONTAL: IconName = IconName("arrows-horizontal") ARROWS_JOIN: IconName = IconName("arrows-join") ARROWS_JOIN_2: IconName = IconName("arrows-join-2") ARROWS_LEFT: IconName = IconName("arrows-left") ARROWS_LEFT_DOWN: IconName = IconName("arrows-left-down") ARROWS_LEFT_RIGHT: IconName = IconName("arrows-left-right") ARROWS_MAXIMIZE: IconName = IconName("arrows-maximize") ARROWS_MINIMIZE: IconName = IconName("arrows-minimize") ARROWS_MOVE: IconName = IconName("arrows-move") ARROWS_MOVE_HORIZONTAL: IconName = IconName("arrows-move-horizontal") ARROWS_MOVE_VERTICAL: IconName = IconName("arrows-move-vertical") ARROWS_RANDOM: IconName = IconName("arrows-random") ARROWS_RIGHT: IconName = IconName("arrows-right") ARROWS_RIGHT_DOWN: IconName = IconName("arrows-right-down") ARROWS_RIGHT_LEFT: IconName = IconName("arrows-right-left") ARROWS_SHUFFLE: IconName = IconName("arrows-shuffle") ARROWS_SHUFFLE_2: IconName = IconName("arrows-shuffle-2") ARROWS_SORT: IconName = IconName("arrows-sort") ARROWS_SPLIT: IconName = IconName("arrows-split") ARROWS_SPLIT_2: IconName = IconName("arrows-split-2") ARROWS_TRANSFER_DOWN: IconName = IconName("arrows-transfer-down") ARROWS_TRANSFER_UP: IconName = IconName("arrows-transfer-up") ARROWS_UP: IconName = IconName("arrows-up") ARROWS_UP_DOWN: IconName = IconName("arrows-up-down") ARROWS_UP_LEFT: IconName = IconName("arrows-up-left") ARROWS_UP_RIGHT: IconName = IconName("arrows-up-right") ARROWS_VERTICAL: IconName = IconName("arrows-vertical") ARTBOARD: IconName = IconName("artboard") ARTBOARD_FILLED: IconName = IconName("artboard-filled") ARTBOARD_OFF: IconName = IconName("artboard-off") ARTICLE: IconName = IconName("article") ARTICLE_FILLED_FILLED: IconName = IconName("article-filled-filled") ARTICLE_OFF: IconName = IconName("article-off") ASPECT_RATIO: IconName = IconName("aspect-ratio") ASPECT_RATIO_FILLED: IconName = IconName("aspect-ratio-filled") ASPECT_RATIO_OFF: IconName = IconName("aspect-ratio-off") ASSEMBLY: IconName = IconName("assembly") ASSEMBLY_OFF: IconName = IconName("assembly-off") ASSET: IconName = IconName("asset") ASTERISK: IconName = IconName("asterisk") ASTERISK_SIMPLE: IconName = IconName("asterisk-simple") AT: IconName = IconName("at") AT_OFF: IconName = IconName("at-off") ATOM: IconName = IconName("atom") ATOM_2: IconName = IconName("atom-2") ATOM_2_FILLED: IconName = IconName("atom-2-filled") ATOM_OFF: IconName = IconName("atom-off") AUGMENTED_REALITY: IconName = IconName("augmented-reality") AUGMENTED_REALITY_2: IconName = IconName("augmented-reality-2") AUGMENTED_REALITY_OFF: IconName = IconName("augmented-reality-off") AWARD: IconName = IconName("award") AWARD_FILLED: IconName = IconName("award-filled") AWARD_OFF: IconName = IconName("award-off") AXE: IconName = IconName("axe") AXIS_X: IconName = IconName("axis-x") AXIS_Y: IconName = IconName("axis-y") BABY_BOTTLE: IconName = IconName("baby-bottle") BABY_CARRIAGE: IconName = IconName("baby-carriage") BACKHOE: IconName = IconName("backhoe") BACKPACK: IconName = 
IconName("backpack") BACKPACK_OFF: IconName = IconName("backpack-off") BACKSLASH: IconName = IconName("backslash") BACKSPACE: IconName = IconName("backspace") BACKSPACE_FILLED: IconName = IconName("backspace-filled") BADGE: IconName = IconName("badge") BADGE_3D: IconName = IconName("badge-3d") BADGE_4K: IconName = IconName("badge-4k") BADGE_8K: IconName = IconName("badge-8k") BADGE_AD: IconName = IconName("badge-ad") BADGE_AR: IconName = IconName("badge-ar") BADGE_CC: IconName = IconName("badge-cc") BADGE_FILLED: IconName = IconName("badge-filled") BADGE_HD: IconName = IconName("badge-hd") BADGE_OFF: IconName = IconName("badge-off") BADGE_SD: IconName = IconName("badge-sd") BADGE_TM: IconName = IconName("badge-tm") BADGE_VO: IconName = IconName("badge-vo") BADGE_VR: IconName = IconName("badge-vr") BADGE_WC: IconName = IconName("badge-wc") BADGES: IconName = IconName("badges") BADGES_FILLED: IconName = IconName("badges-filled") BADGES_OFF: IconName = IconName("badges-off") BAGUETTE: IconName = IconName("baguette") BALL_AMERICAN_FOOTBALL: IconName = IconName("ball-american-football") BALL_AMERICAN_FOOTBALL_OFF: IconName = IconName("ball-american-football-off") BALL_BASEBALL: IconName = IconName("ball-baseball") BALL_BASKETBALL: IconName = IconName("ball-basketball") BALL_BOWLING: IconName = IconName("ball-bowling") BALL_FOOTBALL: IconName = IconName("ball-football") BALL_FOOTBALL_OFF: IconName = IconName("ball-football-off") BALL_TENNIS: IconName = IconName("ball-tennis") BALL_VOLLEYBALL: IconName = IconName("ball-volleyball") BALLOON: IconName = IconName("balloon") BALLOON_FILLED: IconName = IconName("balloon-filled") BALLOON_OFF: IconName = IconName("balloon-off") BALLPEN: IconName = IconName("ballpen") BALLPEN_FILLED: IconName = IconName("ballpen-filled") BALLPEN_OFF: IconName = IconName("ballpen-off") BAN: IconName = IconName("ban") BANDAGE: IconName = IconName("bandage") BANDAGE_FILLED: IconName = IconName("bandage-filled") BANDAGE_OFF: IconName = IconName("bandage-off") BARBELL: IconName = IconName("barbell") BARBELL_OFF: IconName = IconName("barbell-off") BARCODE: IconName = IconName("barcode") BARCODE_OFF: IconName = IconName("barcode-off") BARREL: IconName = IconName("barrel") BARREL_OFF: IconName = IconName("barrel-off") BARRIER_BLOCK: IconName = IconName("barrier-block") BARRIER_BLOCK_OFF: IconName = IconName("barrier-block-off") BASELINE: IconName = IconName("baseline") BASELINE_DENSITY_LARGE: IconName = IconName("baseline-density-large") BASELINE_DENSITY_MEDIUM: IconName = IconName("baseline-density-medium") BASELINE_DENSITY_SMALL: IconName = IconName("baseline-density-small") BASKET: IconName = IconName("basket") BASKET_FILLED: IconName = IconName("basket-filled") BASKET_OFF: IconName = IconName("basket-off") BAT: IconName = IconName("bat") BATH: IconName = IconName("bath") BATH_FILLED: IconName = IconName("bath-filled") BATH_OFF: IconName = IconName("bath-off") BATTERY: IconName = IconName("battery") BATTERY_1: IconName = IconName("battery-1") BATTERY_1_FILLED: IconName = IconName("battery-1-filled") BATTERY_2: IconName = IconName("battery-2") BATTERY_2_FILLED: IconName = IconName("battery-2-filled") BATTERY_3: IconName = IconName("battery-3") BATTERY_3_FILLED: IconName = IconName("battery-3-filled") BATTERY_4: IconName = IconName("battery-4") BATTERY_4_FILLED: IconName = IconName("battery-4-filled") BATTERY_AUTOMOTIVE: IconName = IconName("battery-automotive") BATTERY_CHARGING: IconName = IconName("battery-charging") BATTERY_CHARGING_2: IconName = 
IconName("battery-charging-2") BATTERY_ECO: IconName = IconName("battery-eco") BATTERY_FILLED: IconName = IconName("battery-filled") BATTERY_OFF: IconName = IconName("battery-off") BEACH: IconName = IconName("beach") BEACH_OFF: IconName = IconName("beach-off") BED: IconName = IconName("bed") BED_FILLED: IconName = IconName("bed-filled") BED_OFF: IconName = IconName("bed-off") BEER: IconName = IconName("beer") BEER_FILLED: IconName = IconName("beer-filled") BEER_OFF: IconName = IconName("beer-off") BELL: IconName = IconName("bell") BELL_BOLT: IconName = IconName("bell-bolt") BELL_CANCEL: IconName = IconName("bell-cancel") BELL_CHECK: IconName = IconName("bell-check") BELL_CODE: IconName = IconName("bell-code") BELL_COG: IconName = IconName("bell-cog") BELL_DOLLAR: IconName = IconName("bell-dollar") BELL_DOWN: IconName = IconName("bell-down") BELL_EXCLAMATION: IconName = IconName("bell-exclamation") BELL_FILLED: IconName = IconName("bell-filled") BELL_HEART: IconName = IconName("bell-heart") BELL_MINUS: IconName = IconName("bell-minus") BELL_MINUS_FILLED: IconName = IconName("bell-minus-filled") BELL_OFF: IconName = IconName("bell-off") BELL_PAUSE: IconName = IconName("bell-pause") BELL_PIN: IconName = IconName("bell-pin") BELL_PLUS: IconName = IconName("bell-plus") BELL_PLUS_FILLED: IconName = IconName("bell-plus-filled") BELL_QUESTION: IconName = IconName("bell-question") BELL_RINGING: IconName = IconName("bell-ringing") BELL_RINGING_2: IconName = IconName("bell-ringing-2") BELL_RINGING_2_FILLED: IconName = IconName("bell-ringing-2-filled") BELL_RINGING_FILLED: IconName = IconName("bell-ringing-filled") BELL_SCHOOL: IconName = IconName("bell-school") BELL_SEARCH: IconName = IconName("bell-search") BELL_SHARE: IconName = IconName("bell-share") BELL_STAR: IconName = IconName("bell-star") BELL_UP: IconName = IconName("bell-up") BELL_X: IconName = IconName("bell-x") BELL_X_FILLED: IconName = IconName("bell-x-filled") BELL_Z: IconName = IconName("bell-z") BELL_Z_FILLED: IconName = IconName("bell-z-filled") BETA: IconName = IconName("beta") BIBLE: IconName = IconName("bible") BIKE: IconName = IconName("bike") BIKE_OFF: IconName = IconName("bike-off") BINARY: IconName = IconName("binary") BINARY_OFF: IconName = IconName("binary-off") BINARY_TREE: IconName = IconName("binary-tree") BINARY_TREE_2: IconName = IconName("binary-tree-2") BIOHAZARD: IconName = IconName("biohazard") BIOHAZARD_OFF: IconName = IconName("biohazard-off") BLADE: IconName = IconName("blade") BLADE_FILLED: IconName = IconName("blade-filled") BLEACH: IconName = IconName("bleach") BLEACH_CHLORINE: IconName = IconName("bleach-chlorine") BLEACH_NO_CHLORINE: IconName = IconName("bleach-no-chlorine") BLEACH_OFF: IconName = IconName("bleach-off") BLOCKQUOTE: IconName = IconName("blockquote") BLUETOOTH: IconName = IconName("bluetooth") BLUETOOTH_CONNECTED: IconName = IconName("bluetooth-connected") BLUETOOTH_OFF: IconName = IconName("bluetooth-off") BLUETOOTH_X: IconName = IconName("bluetooth-x") BLUR: IconName = IconName("blur") BLUR_OFF: IconName = IconName("blur-off") BMP: IconName = IconName("bmp") BOLD: IconName = IconName("bold") BOLD_OFF: IconName = IconName("bold-off") BOLT: IconName = IconName("bolt") BOLT_OFF: IconName = IconName("bolt-off") BOMB: IconName = IconName("bomb") BOMB_FILLED: IconName = IconName("bomb-filled") BONE: IconName = IconName("bone") BONE_OFF: IconName = IconName("bone-off") BONG: IconName = IconName("bong") BONG_OFF: IconName = IconName("bong-off") BOOK: IconName = IconName("book") BOOK_2: IconName = 
IconName("book-2") BOOK_DOWNLOAD: IconName = IconName("book-download") BOOK_FILLED: IconName = IconName("book-filled") BOOK_OFF: IconName = IconName("book-off") BOOK_UPLOAD: IconName = IconName("book-upload") BOOKMARK: IconName = IconName("bookmark") BOOKMARK_EDIT: IconName = IconName("bookmark-edit") BOOKMARK_FILLED: IconName = IconName("bookmark-filled") BOOKMARK_MINUS: IconName = IconName("bookmark-minus") BOOKMARK_OFF: IconName = IconName("bookmark-off") BOOKMARK_PLUS: IconName = IconName("bookmark-plus") BOOKMARK_QUESTION: IconName = IconName("bookmark-question") BOOKMARKS: IconName = IconName("bookmarks") BOOKMARKS_OFF: IconName = IconName("bookmarks-off") BOOKS: IconName = IconName("books") BOOKS_OFF: IconName = IconName("books-off") BORDER_ALL: IconName = IconName("border-all") BORDER_BOTTOM: IconName = IconName("border-bottom") BORDER_CORNERS: IconName = IconName("border-corners") BORDER_HORIZONTAL: IconName = IconName("border-horizontal") BORDER_INNER: IconName = IconName("border-inner") BORDER_LEFT: IconName = IconName("border-left") BORDER_NONE: IconName = IconName("border-none") BORDER_OUTER: IconName = IconName("border-outer") BORDER_RADIUS: IconName = IconName("border-radius") BORDER_RIGHT: IconName = IconName("border-right") BORDER_SIDES: IconName = IconName("border-sides") BORDER_STYLE: IconName = IconName("border-style") BORDER_STYLE_2: IconName = IconName("border-style-2") BORDER_TOP: IconName = IconName("border-top") BORDER_VERTICAL: IconName = IconName("border-vertical") BOTTLE: IconName = IconName("bottle") BOTTLE_FILLED: IconName = IconName("bottle-filled") BOTTLE_OFF: IconName = IconName("bottle-off") BOUNCE_LEFT: IconName = IconName("bounce-left") BOUNCE_RIGHT: IconName = IconName("bounce-right") BOW: IconName = IconName("bow") BOWL: IconName = IconName("bowl") BOX: IconName = IconName("box") BOX_ALIGN_BOTTOM: IconName = IconName("box-align-bottom") BOX_ALIGN_BOTTOM_FILLED: IconName = IconName("box-align-bottom-filled") BOX_ALIGN_BOTTOM_LEFT: IconName = IconName("box-align-bottom-left") BOX_ALIGN_BOTTOM_LEFT_FILLED: IconName = IconName("box-align-bottom-left-filled") BOX_ALIGN_BOTTOM_RIGHT: IconName = IconName("box-align-bottom-right") BOX_ALIGN_BOTTOM_RIGHT_FILLED: IconName = IconName("box-align-bottom-right-filled") BOX_ALIGN_LEFT: IconName = IconName("box-align-left") BOX_ALIGN_LEFT_FILLED: IconName = IconName("box-align-left-filled") BOX_ALIGN_RIGHT: IconName = IconName("box-align-right") BOX_ALIGN_RIGHT_FILLED: IconName = IconName("box-align-right-filled") BOX_ALIGN_TOP: IconName = IconName("box-align-top") BOX_ALIGN_TOP_FILLED: IconName = IconName("box-align-top-filled") BOX_ALIGN_TOP_LEFT: IconName = IconName("box-align-top-left") BOX_ALIGN_TOP_LEFT_FILLED: IconName = IconName("box-align-top-left-filled") BOX_ALIGN_TOP_RIGHT: IconName = IconName("box-align-top-right") BOX_ALIGN_TOP_RIGHT_FILLED: IconName = IconName("box-align-top-right-filled") BOX_MARGIN: IconName = IconName("box-margin") BOX_MODEL: IconName = IconName("box-model") BOX_MODEL_2: IconName = IconName("box-model-2") BOX_MODEL_2_OFF: IconName = IconName("box-model-2-off") BOX_MODEL_OFF: IconName = IconName("box-model-off") BOX_MULTIPLE: IconName = IconName("box-multiple") BOX_MULTIPLE_0: IconName = IconName("box-multiple-0") BOX_MULTIPLE_1: IconName = IconName("box-multiple-1") BOX_MULTIPLE_2: IconName = IconName("box-multiple-2") BOX_MULTIPLE_3: IconName = IconName("box-multiple-3") BOX_MULTIPLE_4: IconName = IconName("box-multiple-4") BOX_MULTIPLE_5: IconName = IconName("box-multiple-5") 
BOX_MULTIPLE_6: IconName = IconName("box-multiple-6") BOX_MULTIPLE_7: IconName = IconName("box-multiple-7") BOX_MULTIPLE_8: IconName = IconName("box-multiple-8") BOX_MULTIPLE_9: IconName = IconName("box-multiple-9") BOX_OFF: IconName = IconName("box-off") BOX_PADDING: IconName = IconName("box-padding") BOX_SEAM: IconName = IconName("box-seam") BRACES: IconName = IconName("braces") BRACES_OFF: IconName = IconName("braces-off") BRACKETS: IconName = IconName("brackets") BRACKETS_CONTAIN: IconName = IconName("brackets-contain") BRACKETS_CONTAIN_END: IconName = IconName("brackets-contain-end") BRACKETS_CONTAIN_START: IconName = IconName("brackets-contain-start") BRACKETS_OFF: IconName = IconName("brackets-off") BRAILLE: IconName = IconName("braille") BRAIN: IconName = IconName("brain") BRAND_4CHAN: IconName = IconName("brand-4chan") BRAND_ABSTRACT: IconName = IconName("brand-abstract") BRAND_ADOBE: IconName = IconName("brand-adobe") BRAND_ADONIS_JS: IconName = IconName("brand-adonis-js") BRAND_AIRBNB: IconName = IconName("brand-airbnb") BRAND_AIRTABLE: IconName = IconName("brand-airtable") BRAND_ALGOLIA: IconName = IconName("brand-algolia") BRAND_ALIPAY: IconName = IconName("brand-alipay") BRAND_ALPINE_JS: IconName = IconName("brand-alpine-js") BRAND_AMAZON: IconName = IconName("brand-amazon") BRAND_AMD: IconName = IconName("brand-amd") BRAND_AMIGO: IconName = IconName("brand-amigo") BRAND_AMONG_US: IconName = IconName("brand-among-us") BRAND_ANDROID: IconName = IconName("brand-android") BRAND_ANGULAR: IconName = IconName("brand-angular") BRAND_ANSIBLE: IconName = IconName("brand-ansible") BRAND_AO3: IconName = IconName("brand-ao3") BRAND_APPGALLERY: IconName = IconName("brand-appgallery") BRAND_APPLE: IconName = IconName("brand-apple") BRAND_APPLE_ARCADE: IconName = IconName("brand-apple-arcade") BRAND_APPLE_PODCAST: IconName = IconName("brand-apple-podcast") BRAND_APPSTORE: IconName = IconName("brand-appstore") BRAND_ASANA: IconName = IconName("brand-asana") BRAND_AWS: IconName = IconName("brand-aws") BRAND_AZURE: IconName = IconName("brand-azure") BRAND_BACKBONE: IconName = IconName("brand-backbone") BRAND_BADOO: IconName = IconName("brand-badoo") BRAND_BAIDU: IconName = IconName("brand-baidu") BRAND_BANDCAMP: IconName = IconName("brand-bandcamp") BRAND_BANDLAB: IconName = IconName("brand-bandlab") BRAND_BEATS: IconName = IconName("brand-beats") BRAND_BEHANCE: IconName = IconName("brand-behance") BRAND_BILIBILI: IconName = IconName("brand-bilibili") BRAND_BINANCE: IconName = IconName("brand-binance") BRAND_BING: IconName = IconName("brand-bing") BRAND_BITBUCKET: IconName = IconName("brand-bitbucket") BRAND_BLACKBERRY: IconName = IconName("brand-blackberry") BRAND_BLENDER: IconName = IconName("brand-blender") BRAND_BLOGGER: IconName = IconName("brand-blogger") BRAND_BOOKING: IconName = IconName("brand-booking") BRAND_BOOTSTRAP: IconName = IconName("brand-bootstrap") BRAND_BULMA: IconName = IconName("brand-bulma") BRAND_BUMBLE: IconName = IconName("brand-bumble") BRAND_BUNPO: IconName = IconName("brand-bunpo") BRAND_C_SHARP: IconName = IconName("brand-c-sharp") BRAND_CAKE: IconName = IconName("brand-cake") BRAND_CAKEPHP: IconName = IconName("brand-cakephp") BRAND_CAMPAIGNMONITOR: IconName = IconName("brand-campaignmonitor") BRAND_CARBON: IconName = IconName("brand-carbon") BRAND_CASHAPP: IconName = IconName("brand-cashapp") BRAND_CHROME: IconName = IconName("brand-chrome") BRAND_CINEMA_4D: IconName = IconName("brand-cinema-4d") BRAND_CITYMAPPER: IconName = IconName("brand-citymapper") 
BRAND_CLOUDFLARE: IconName = IconName("brand-cloudflare") BRAND_CODECOV: IconName = IconName("brand-codecov") BRAND_CODEPEN: IconName = IconName("brand-codepen") BRAND_CODESANDBOX: IconName = IconName("brand-codesandbox") BRAND_COHOST: IconName = IconName("brand-cohost") BRAND_COINBASE: IconName = IconName("brand-coinbase") BRAND_COMEDY_CENTRAL: IconName = IconName("brand-comedy-central") BRAND_COREOS: IconName = IconName("brand-coreos") BRAND_COUCHDB: IconName = IconName("brand-couchdb") BRAND_COUCHSURFING: IconName = IconName("brand-couchsurfing") BRAND_CPP: IconName = IconName("brand-cpp") BRAND_CRAFT: IconName = IconName("brand-craft") BRAND_CRUNCHBASE: IconName = IconName("brand-crunchbase") BRAND_CSS3: IconName = IconName("brand-css3") BRAND_CTEMPLAR: IconName = IconName("brand-ctemplar") BRAND_CUCUMBER: IconName = IconName("brand-cucumber") BRAND_CUPRA: IconName = IconName("brand-cupra") BRAND_CYPRESS: IconName = IconName("brand-cypress") BRAND_D3: IconName = IconName("brand-d3") BRAND_DAYS_COUNTER: IconName = IconName("brand-days-counter") BRAND_DCOS: IconName = IconName("brand-dcos") BRAND_DEBIAN: IconName = IconName("brand-debian") BRAND_DEEZER: IconName = IconName("brand-deezer") BRAND_DELIVEROO: IconName = IconName("brand-deliveroo") BRAND_DENO: IconName = IconName("brand-deno") BRAND_DENODO: IconName = IconName("brand-denodo") BRAND_DEVIANTART: IconName = IconName("brand-deviantart") BRAND_DIGG: IconName = IconName("brand-digg") BRAND_DINGTALK: IconName = IconName("brand-dingtalk") BRAND_DISCORD: IconName = IconName("brand-discord") BRAND_DISCORD_FILLED: IconName = IconName("brand-discord-filled") BRAND_DISNEY: IconName = IconName("brand-disney") BRAND_DISQUS: IconName = IconName("brand-disqus") BRAND_DJANGO: IconName = IconName("brand-django") BRAND_DOCKER: IconName = IconName("brand-docker") BRAND_DOCTRINE: IconName = IconName("brand-doctrine") BRAND_DOLBY_DIGITAL: IconName = IconName("brand-dolby-digital") BRAND_DOUBAN: IconName = IconName("brand-douban") BRAND_DRIBBBLE: IconName = IconName("brand-dribbble") BRAND_DRIBBBLE_FILLED: IconName = IconName("brand-dribbble-filled") BRAND_DROPS: IconName = IconName("brand-drops") BRAND_DRUPAL: IconName = IconName("brand-drupal") BRAND_EDGE: IconName = IconName("brand-edge") BRAND_ELASTIC: IconName = IconName("brand-elastic") BRAND_ELECTRONIC_ARTS: IconName = IconName("brand-electronic-arts") BRAND_EMBER: IconName = IconName("brand-ember") BRAND_ENVATO: IconName = IconName("brand-envato") BRAND_ETSY: IconName = IconName("brand-etsy") BRAND_EVERNOTE: IconName = IconName("brand-evernote") BRAND_FACEBOOK: IconName = IconName("brand-facebook") BRAND_FACEBOOK_FILLED: IconName = IconName("brand-facebook-filled") BRAND_FEEDLY: IconName = IconName("brand-feedly") BRAND_FIGMA: IconName = IconName("brand-figma") BRAND_FILEZILLA: IconName = IconName("brand-filezilla") BRAND_FINDER: IconName = IconName("brand-finder") BRAND_FIREBASE: IconName = IconName("brand-firebase") BRAND_FIREFOX: IconName = IconName("brand-firefox") BRAND_FIVERR: IconName = IconName("brand-fiverr") BRAND_FLICKR: IconName = IconName("brand-flickr") BRAND_FLIGHTRADAR24: IconName = IconName("brand-flightradar24") BRAND_FLIPBOARD: IconName = IconName("brand-flipboard") BRAND_FLUTTER: IconName = IconName("brand-flutter") BRAND_FORTNITE: IconName = IconName("brand-fortnite") BRAND_FOURSQUARE: IconName = IconName("brand-foursquare") BRAND_FRAMER: IconName = IconName("brand-framer") BRAND_FRAMER_MOTION: IconName = IconName("brand-framer-motion") BRAND_FUNIMATION: IconName = 
IconName("brand-funimation") BRAND_GATSBY: IconName = IconName("brand-gatsby") BRAND_GIT: IconName = IconName("brand-git") BRAND_GITHUB: IconName = IconName("brand-github") BRAND_GITHUB_COPILOT: IconName = IconName("brand-github-copilot") BRAND_GITHUB_FILLED: IconName = IconName("brand-github-filled") BRAND_GITLAB: IconName = IconName("brand-gitlab") BRAND_GMAIL: IconName = IconName("brand-gmail") BRAND_GOLANG: IconName = IconName("brand-golang") BRAND_GOOGLE: IconName = IconName("brand-google") BRAND_GOOGLE_ANALYTICS: IconName = IconName("brand-google-analytics") BRAND_GOOGLE_BIG_QUERY: IconName = IconName("brand-google-big-query") BRAND_GOOGLE_DRIVE: IconName = IconName("brand-google-drive") BRAND_GOOGLE_FIT: IconName = IconName("brand-google-fit") BRAND_GOOGLE_HOME: IconName = IconName("brand-google-home") BRAND_GOOGLE_MAPS: IconName = IconName("brand-google-maps") BRAND_GOOGLE_ONE: IconName = IconName("brand-google-one") BRAND_GOOGLE_PHOTOS: IconName = IconName("brand-google-photos") BRAND_GOOGLE_PLAY: IconName = IconName("brand-google-play") BRAND_GOOGLE_PODCASTS: IconName = IconName("brand-google-podcasts") BRAND_GRAMMARLY: IconName = IconName("brand-grammarly") BRAND_GRAPHQL: IconName = IconName("brand-graphql") BRAND_GRAVATAR: IconName = IconName("brand-gravatar") BRAND_GRINDR: IconName = IconName("brand-grindr") BRAND_GUARDIAN: IconName = IconName("brand-guardian") BRAND_GUMROAD: IconName = IconName("brand-gumroad") BRAND_HBO: IconName = IconName("brand-hbo") BRAND_HEADLESSUI: IconName = IconName("brand-headlessui") BRAND_HEXO: IconName = IconName("brand-hexo") BRAND_HIPCHAT: IconName = IconName("brand-hipchat") BRAND_HTML5: IconName = IconName("brand-html5") BRAND_INERTIA: IconName = IconName("brand-inertia") BRAND_INSTAGRAM: IconName = IconName("brand-instagram") BRAND_INTERCOM: IconName = IconName("brand-intercom") BRAND_ITCH: IconName = IconName("brand-itch") BRAND_JAVASCRIPT: IconName = IconName("brand-javascript") BRAND_JUEJIN: IconName = IconName("brand-juejin") BRAND_KBIN: IconName = IconName("brand-kbin") BRAND_KICK: IconName = IconName("brand-kick") BRAND_KICKSTARTER: IconName = IconName("brand-kickstarter") BRAND_KOTLIN: IconName = IconName("brand-kotlin") BRAND_LARAVEL: IconName = IconName("brand-laravel") BRAND_LASTFM: IconName = IconName("brand-lastfm") BRAND_LEETCODE: IconName = IconName("brand-leetcode") BRAND_LETTERBOXD: IconName = IconName("brand-letterboxd") BRAND_LINE: IconName = IconName("brand-line") BRAND_LINKEDIN: IconName = IconName("brand-linkedin") BRAND_LINKTREE: IconName = IconName("brand-linktree") BRAND_LINQPAD: IconName = IconName("brand-linqpad") BRAND_LOOM: IconName = IconName("brand-loom") BRAND_MAILGUN: IconName = IconName("brand-mailgun") BRAND_MANTINE: IconName = IconName("brand-mantine") BRAND_MASTERCARD: IconName = IconName("brand-mastercard") BRAND_MASTODON: IconName = IconName("brand-mastodon") BRAND_MATRIX: IconName = IconName("brand-matrix") BRAND_MCDONALDS: IconName = IconName("brand-mcdonalds") BRAND_MEDIUM: IconName = IconName("brand-medium") BRAND_MERCEDES: IconName = IconName("brand-mercedes") BRAND_MESSENGER: IconName = IconName("brand-messenger") BRAND_META: IconName = IconName("brand-meta") BRAND_MICROSOFT_TEAMS: IconName = IconName("brand-microsoft-teams") BRAND_MINECRAFT: IconName = IconName("brand-minecraft") BRAND_MINIPROGRAM: IconName = IconName("brand-miniprogram") BRAND_MIXPANEL: IconName = IconName("brand-mixpanel") BRAND_MONDAY: IconName = IconName("brand-monday") BRAND_MONGODB: IconName = IconName("brand-mongodb") 
BRAND_MY_OPPO: IconName = IconName("brand-my-oppo") BRAND_MYSQL: IconName = IconName("brand-mysql") BRAND_NATIONAL_GEOGRAPHIC: IconName = IconName("brand-national-geographic") BRAND_NEM: IconName = IconName("brand-nem") BRAND_NETBEANS: IconName = IconName("brand-netbeans") BRAND_NETEASE_MUSIC: IconName = IconName("brand-netease-music") BRAND_NETFLIX: IconName = IconName("brand-netflix") BRAND_NEXO: IconName = IconName("brand-nexo") BRAND_NEXTCLOUD: IconName = IconName("brand-nextcloud") BRAND_NEXTJS: IconName = IconName("brand-nextjs") BRAND_NODEJS: IconName = IconName("brand-nodejs") BRAND_NORD_VPN: IconName = IconName("brand-nord-vpn") BRAND_NOTION: IconName = IconName("brand-notion") BRAND_NPM: IconName = IconName("brand-npm") BRAND_NUXT: IconName = IconName("brand-nuxt") BRAND_NYTIMES: IconName = IconName("brand-nytimes") BRAND_OAUTH: IconName = IconName("brand-oauth") BRAND_OFFICE: IconName = IconName("brand-office") BRAND_OK_RU: IconName = IconName("brand-ok-ru") BRAND_ONEDRIVE: IconName = IconName("brand-onedrive") BRAND_ONLYFANS: IconName = IconName("brand-onlyfans") BRAND_OPEN_SOURCE: IconName = IconName("brand-open-source") BRAND_OPENAI: IconName = IconName("brand-openai") BRAND_OPENVPN: IconName = IconName("brand-openvpn") BRAND_OPERA: IconName = IconName("brand-opera") BRAND_PAGEKIT: IconName = IconName("brand-pagekit") BRAND_PATREON: IconName = IconName("brand-patreon") BRAND_PAYPAL: IconName = IconName("brand-paypal") BRAND_PAYPAL_FILLED: IconName = IconName("brand-paypal-filled") BRAND_PAYPAY: IconName = IconName("brand-paypay") BRAND_PEANUT: IconName = IconName("brand-peanut") BRAND_PEPSI: IconName = IconName("brand-pepsi") BRAND_PHP: IconName = IconName("brand-php") BRAND_PICSART: IconName = IconName("brand-picsart") BRAND_PINTEREST: IconName = IconName("brand-pinterest") BRAND_PLANETSCALE: IconName = IconName("brand-planetscale") BRAND_POCKET: IconName = IconName("brand-pocket") BRAND_POLYMER: IconName = IconName("brand-polymer") BRAND_POWERSHELL: IconName = IconName("brand-powershell") BRAND_PRISMA: IconName = IconName("brand-prisma") BRAND_PRODUCTHUNT: IconName = IconName("brand-producthunt") BRAND_PUSHBULLET: IconName = IconName("brand-pushbullet") BRAND_PUSHOVER: IconName = IconName("brand-pushover") BRAND_PYTHON: IconName = IconName("brand-python") BRAND_QQ: IconName = IconName("brand-qq") BRAND_RADIX_UI: IconName = IconName("brand-radix-ui") BRAND_REACT: IconName = IconName("brand-react") BRAND_REACT_NATIVE: IconName = IconName("brand-react-native") BRAND_REASON: IconName = IconName("brand-reason") BRAND_REDDIT: IconName = IconName("brand-reddit") BRAND_REDHAT: IconName = IconName("brand-redhat") BRAND_REDUX: IconName = IconName("brand-redux") BRAND_REVOLUT: IconName = IconName("brand-revolut") BRAND_RUMBLE: IconName = IconName("brand-rumble") BRAND_RUST: IconName = IconName("brand-rust") BRAND_SAFARI: IconName = IconName("brand-safari") BRAND_SAMSUNGPASS: IconName = IconName("brand-samsungpass") BRAND_SASS: IconName = IconName("brand-sass") BRAND_SENTRY: IconName = IconName("brand-sentry") BRAND_SHARIK: IconName = IconName("brand-sharik") BRAND_SHAZAM: IconName = IconName("brand-shazam") BRAND_SHOPEE: IconName = IconName("brand-shopee") BRAND_SKETCH: IconName = IconName("brand-sketch") BRAND_SKYPE: IconName = IconName("brand-skype") BRAND_SLACK: IconName = IconName("brand-slack") BRAND_SNAPCHAT: IconName = IconName("brand-snapchat") BRAND_SNAPSEED: IconName = IconName("brand-snapseed") BRAND_SNOWFLAKE: IconName = IconName("brand-snowflake") BRAND_SOCKET_IO: 
IconName = IconName("brand-socket-io") BRAND_SOLIDJS: IconName = IconName("brand-solidjs") BRAND_SOUNDCLOUD: IconName = IconName("brand-soundcloud") BRAND_SPACEHEY: IconName = IconName("brand-spacehey") BRAND_SPEEDTEST: IconName = IconName("brand-speedtest") BRAND_SPOTIFY: IconName = IconName("brand-spotify") BRAND_STACKOVERFLOW: IconName = IconName("brand-stackoverflow") BRAND_STACKSHARE: IconName = IconName("brand-stackshare") BRAND_STEAM: IconName = IconName("brand-steam") BRAND_STORJ: IconName = IconName("brand-storj") BRAND_STORYBOOK: IconName = IconName("brand-storybook") BRAND_STORYTEL: IconName = IconName("brand-storytel") BRAND_STRAVA: IconName = IconName("brand-strava") BRAND_STRIPE: IconName = IconName("brand-stripe") BRAND_SUBLIME_TEXT: IconName = IconName("brand-sublime-text") BRAND_SUGARIZER: IconName = IconName("brand-sugarizer") BRAND_SUPABASE: IconName = IconName("brand-supabase") BRAND_SUPERHUMAN: IconName = IconName("brand-superhuman") BRAND_SUPERNOVA: IconName = IconName("brand-supernova") BRAND_SURFSHARK: IconName = IconName("brand-surfshark") BRAND_SVELTE: IconName = IconName("brand-svelte") BRAND_SWIFT: IconName = IconName("brand-swift") BRAND_SYMFONY: IconName = IconName("brand-symfony") BRAND_TABLER: IconName = IconName("brand-tabler") BRAND_TAILWIND: IconName = IconName("brand-tailwind") BRAND_TAOBAO: IconName = IconName("brand-taobao") BRAND_TED: IconName = IconName("brand-ted") BRAND_TELEGRAM: IconName = IconName("brand-telegram") BRAND_TERRAFORM: IconName = IconName("brand-terraform") BRAND_TETHER: IconName = IconName("brand-tether") BRAND_THREEJS: IconName = IconName("brand-threejs") BRAND_TIDAL: IconName = IconName("brand-tidal") BRAND_TIKTO_FILLED: IconName = IconName("brand-tikto-filled") BRAND_TIKTOK: IconName = IconName("brand-tiktok") BRAND_TINDER: IconName = IconName("brand-tinder") BRAND_TOPBUZZ: IconName = IconName("brand-topbuzz") BRAND_TORCHAIN: IconName = IconName("brand-torchain") BRAND_TOYOTA: IconName = IconName("brand-toyota") BRAND_TRELLO: IconName = IconName("brand-trello") BRAND_TRIPADVISOR: IconName = IconName("brand-tripadvisor") BRAND_TUMBLR: IconName = IconName("brand-tumblr") BRAND_TWILIO: IconName = IconName("brand-twilio") BRAND_TWITCH: IconName = IconName("brand-twitch") BRAND_TWITTER: IconName = IconName("brand-twitter") BRAND_TWITTER_FILLED: IconName = IconName("brand-twitter-filled") BRAND_TYPESCRIPT: IconName = IconName("brand-typescript") BRAND_UBER: IconName = IconName("brand-uber") BRAND_UBUNTU: IconName = IconName("brand-ubuntu") BRAND_UNITY: IconName = IconName("brand-unity") BRAND_UNSPLASH: IconName = IconName("brand-unsplash") BRAND_UPWORK: IconName = IconName("brand-upwork") BRAND_VALORANT: IconName = IconName("brand-valorant") BRAND_VERCEL: IconName = IconName("brand-vercel") BRAND_VIMEO: IconName = IconName("brand-vimeo") BRAND_VINTED: IconName = IconName("brand-vinted") BRAND_VISA: IconName = IconName("brand-visa") BRAND_VISUAL_STUDIO: IconName = IconName("brand-visual-studio") BRAND_VITE: IconName = IconName("brand-vite") BRAND_VIVALDI: IconName = IconName("brand-vivaldi") BRAND_VK: IconName = IconName("brand-vk") BRAND_VLC: IconName = IconName("brand-vlc") BRAND_VOLKSWAGEN: IconName = IconName("brand-volkswagen") BRAND_VSCO: IconName = IconName("brand-vsco") BRAND_VSCODE: IconName = IconName("brand-vscode") BRAND_VUE: IconName = IconName("brand-vue") BRAND_WALMART: IconName = IconName("brand-walmart") BRAND_WAZE: IconName = IconName("brand-waze") BRAND_WEBFLOW: IconName = IconName("brand-webflow") BRAND_WECHAT: 
IconName = IconName("brand-wechat") BRAND_WEIBO: IconName = IconName("brand-weibo") BRAND_WHATSAPP: IconName = IconName("brand-whatsapp") BRAND_WIKIPEDIA: IconName = IconName("brand-wikipedia") BRAND_WINDOWS: IconName = IconName("brand-windows") BRAND_WINDY: IconName = IconName("brand-windy") BRAND_WISH: IconName = IconName("brand-wish") BRAND_WIX: IconName = IconName("brand-wix") BRAND_WORDPRESS: IconName = IconName("brand-wordpress") BRAND_XAMARIN: IconName = IconName("brand-xamarin") BRAND_XBOX: IconName = IconName("brand-xbox") BRAND_XING: IconName = IconName("brand-xing") BRAND_YAHOO: IconName = IconName("brand-yahoo") BRAND_YANDEX: IconName = IconName("brand-yandex") BRAND_YATSE: IconName = IconName("brand-yatse") BRAND_YCOMBINATOR: IconName = IconName("brand-ycombinator") BRAND_YOUTUBE: IconName = IconName("brand-youtube") BRAND_YOUTUBE_KIDS: IconName = IconName("brand-youtube-kids") BRAND_ZALANDO: IconName = IconName("brand-zalando") BRAND_ZAPIER: IconName = IconName("brand-zapier") BRAND_ZEIT: IconName = IconName("brand-zeit") BRAND_ZHIHU: IconName = IconName("brand-zhihu") BRAND_ZOOM: IconName = IconName("brand-zoom") BRAND_ZULIP: IconName = IconName("brand-zulip") BRAND_ZWIFT: IconName = IconName("brand-zwift") BREAD: IconName = IconName("bread") BREAD_OFF: IconName = IconName("bread-off") BRIEFCASE: IconName = IconName("briefcase") BRIEFCASE_OFF: IconName = IconName("briefcase-off") BRIGHTNESS: IconName = IconName("brightness") BRIGHTNESS_2: IconName = IconName("brightness-2") BRIGHTNESS_DOWN: IconName = IconName("brightness-down") BRIGHTNESS_HALF: IconName = IconName("brightness-half") BRIGHTNESS_OFF: IconName = IconName("brightness-off") BRIGHTNESS_UP: IconName = IconName("brightness-up") BROADCAST: IconName = IconName("broadcast") BROADCAST_OFF: IconName = IconName("broadcast-off") BROWSER: IconName = IconName("browser") BROWSER_CHECK: IconName = IconName("browser-check") BROWSER_OFF: IconName = IconName("browser-off") BROWSER_PLUS: IconName = IconName("browser-plus") BROWSER_X: IconName = IconName("browser-x") BRUSH: IconName = IconName("brush") BRUSH_OFF: IconName = IconName("brush-off") BUCKET: IconName = IconName("bucket") BUCKET_DROPLET: IconName = IconName("bucket-droplet") BUCKET_OFF: IconName = IconName("bucket-off") BUG: IconName = IconName("bug") BUG_OFF: IconName = IconName("bug-off") BUILDING: IconName = IconName("building") BUILDING_ARCH: IconName = IconName("building-arch") BUILDING_BANK: IconName = IconName("building-bank") BUILDING_BRIDGE: IconName = IconName("building-bridge") BUILDING_BRIDGE_2: IconName = IconName("building-bridge-2") BUILDING_BROADCAST_TOWER: IconName = IconName("building-broadcast-tower") BUILDING_CAROUSEL: IconName = IconName("building-carousel") BUILDING_CASTLE: IconName = IconName("building-castle") BUILDING_CHURCH: IconName = IconName("building-church") BUILDING_CIRCUS: IconName = IconName("building-circus") BUILDING_COMMUNITY: IconName = IconName("building-community") BUILDING_COTTAGE: IconName = IconName("building-cottage") BUILDING_ESTATE: IconName = IconName("building-estate") BUILDING_FACTORY: IconName = IconName("building-factory") BUILDING_FACTORY_2: IconName = IconName("building-factory-2") BUILDING_FORTRESS: IconName = IconName("building-fortress") BUILDING_HOSPITAL: IconName = IconName("building-hospital") BUILDING_LIGHTHOUSE: IconName = IconName("building-lighthouse") BUILDING_MONUMENT: IconName = IconName("building-monument") BUILDING_MOSQUE: IconName = IconName("building-mosque") BUILDING_PAVILION: IconName = 
IconName("building-pavilion") BUILDING_SKYSCRAPER: IconName = IconName("building-skyscraper") BUILDING_STADIUM: IconName = IconName("building-stadium") BUILDING_STORE: IconName = IconName("building-store") BUILDING_TUNNEL: IconName = IconName("building-tunnel") BUILDING_WAREHOUSE: IconName = IconName("building-warehouse") BUILDING_WIND_TURBINE: IconName = IconName("building-wind-turbine") BULB: IconName = IconName("bulb") BULB_FILLED: IconName = IconName("bulb-filled") BULB_OFF: IconName = IconName("bulb-off") BULLDOZER: IconName = IconName("bulldozer") BUS: IconName = IconName("bus") BUS_OFF: IconName = IconName("bus-off") BUS_STOP: IconName = IconName("bus-stop") BUSINESSPLAN: IconName = IconName("businessplan") BUTTERFLY: IconName = IconName("butterfly") CACTUS: IconName = IconName("cactus") CACTUS_OFF: IconName = IconName("cactus-off") CAKE: IconName = IconName("cake") CAKE_OFF: IconName = IconName("cake-off") CALCULATOR: IconName = IconName("calculator") CALCULATOR_OFF: IconName = IconName("calculator-off") CALENDAR: IconName = IconName("calendar") CALENDAR_BOLT: IconName = IconName("calendar-bolt") CALENDAR_CANCEL: IconName = IconName("calendar-cancel") CALENDAR_CHECK: IconName = IconName("calendar-check") CALENDAR_CODE: IconName = IconName("calendar-code") CALENDAR_COG: IconName = IconName("calendar-cog") CALENDAR_DOLLAR: IconName = IconName("calendar-dollar") CALENDAR_DOWN: IconName = IconName("calendar-down") CALENDAR_DUE: IconName = IconName("calendar-due") CALENDAR_EVENT: IconName = IconName("calendar-event") CALENDAR_EXCLAMATION: IconName = IconName("calendar-exclamation") CALENDAR_HEART: IconName = IconName("calendar-heart") CALENDAR_MINUS: IconName = IconName("calendar-minus") CALENDAR_OFF: IconName = IconName("calendar-off") CALENDAR_PAUSE: IconName = IconName("calendar-pause") CALENDAR_PIN: IconName = IconName("calendar-pin") CALENDAR_PLUS: IconName = IconName("calendar-plus") CALENDAR_QUESTION: IconName = IconName("calendar-question") CALENDAR_REPEAT: IconName = IconName("calendar-repeat") CALENDAR_SEARCH: IconName = IconName("calendar-search") CALENDAR_SHARE: IconName = IconName("calendar-share") CALENDAR_STAR: IconName = IconName("calendar-star") CALENDAR_STATS: IconName = IconName("calendar-stats") CALENDAR_TIME: IconName = IconName("calendar-time") CALENDAR_UP: IconName = IconName("calendar-up") CALENDAR_X: IconName = IconName("calendar-x") CAMERA: IconName = IconName("camera") CAMERA_BOLT: IconName = IconName("camera-bolt") CAMERA_CANCEL: IconName = IconName("camera-cancel") CAMERA_CHECK: IconName = IconName("camera-check") CAMERA_CODE: IconName = IconName("camera-code") CAMERA_COG: IconName = IconName("camera-cog") CAMERA_DOLLAR: IconName = IconName("camera-dollar") CAMERA_DOWN: IconName = IconName("camera-down") CAMERA_EXCLAMATION: IconName = IconName("camera-exclamation") CAMERA_FILLED: IconName = IconName("camera-filled") CAMERA_HEART: IconName = IconName("camera-heart") CAMERA_MINUS: IconName = IconName("camera-minus") CAMERA_OFF: IconName = IconName("camera-off") CAMERA_PAUSE: IconName = IconName("camera-pause") CAMERA_PIN: IconName = IconName("camera-pin") CAMERA_PLUS: IconName = IconName("camera-plus") CAMERA_QUESTION: IconName = IconName("camera-question") CAMERA_ROTATE: IconName = IconName("camera-rotate") CAMERA_SEARCH: IconName = IconName("camera-search") CAMERA_SELFIE: IconName = IconName("camera-selfie") CAMERA_SHARE: IconName = IconName("camera-share") CAMERA_STAR: IconName = IconName("camera-star") CAMERA_UP: IconName = IconName("camera-up") CAMERA_X: 
IconName = IconName("camera-x") CAMPER: IconName = IconName("camper") CAMPFIRE: IconName = IconName("campfire") CANDLE: IconName = IconName("candle") CANDY: IconName = IconName("candy") CANDY_OFF: IconName = IconName("candy-off") CANE: IconName = IconName("cane") CANNABIS: IconName = IconName("cannabis") CAPSULE: IconName = IconName("capsule") CAPSULE_HORIZONTAL: IconName = IconName("capsule-horizontal") CAPTURE: IconName = IconName("capture") CAPTURE_OFF: IconName = IconName("capture-off") CAR: IconName = IconName("car") CAR_CRANE: IconName = IconName("car-crane") CAR_CRASH: IconName = IconName("car-crash") CAR_OFF: IconName = IconName("car-off") CAR_TURBINE: IconName = IconName("car-turbine") CARAVAN: IconName = IconName("caravan") CARDBOARDS: IconName = IconName("cardboards") CARDBOARDS_OFF: IconName = IconName("cardboards-off") CARDS: IconName = IconName("cards") CARET_DOWN: IconName = IconName("caret-down") CARET_LEFT: IconName = IconName("caret-left") CARET_RIGHT: IconName = IconName("caret-right") CARET_UP: IconName = IconName("caret-up") CAROUSEL_HORIZONTAL: IconName = IconName("carousel-horizontal") CAROUSEL_HORIZONTAL_FILLED: IconName = IconName("carousel-horizontal-filled") CAROUSEL_VERTICAL: IconName = IconName("carousel-vertical") CAROUSEL_VERTICAL_FILLED: IconName = IconName("carousel-vertical-filled") CARROT: IconName = IconName("carrot") CARROT_OFF: IconName = IconName("carrot-off") CASH: IconName = IconName("cash") CASH_BANKNOTE: IconName = IconName("cash-banknote") CASH_BANKNOTE_OFF: IconName = IconName("cash-banknote-off") CASH_OFF: IconName = IconName("cash-off") CAST: IconName = IconName("cast") CAST_OFF: IconName = IconName("cast-off") CAT: IconName = IconName("cat") CATEGORY: IconName = IconName("category") CATEGORY_2: IconName = IconName("category-2") CE: IconName = IconName("ce") CE_OFF: IconName = IconName("ce-off") CELL: IconName = IconName("cell") CELL_SIGNAL_1: IconName = IconName("cell-signal-1") CELL_SIGNAL_2: IconName = IconName("cell-signal-2") CELL_SIGNAL_3: IconName = IconName("cell-signal-3") CELL_SIGNAL_4: IconName = IconName("cell-signal-4") CELL_SIGNAL_5: IconName = IconName("cell-signal-5") CELL_SIGNAL_OFF: IconName = IconName("cell-signal-off") CERTIFICATE: IconName = IconName("certificate") CERTIFICATE_2: IconName = IconName("certificate-2") CERTIFICATE_2_OFF: IconName = IconName("certificate-2-off") CERTIFICATE_OFF: IconName = IconName("certificate-off") CHAIR_DIRECTOR: IconName = IconName("chair-director") CHALKBOARD: IconName = IconName("chalkboard") CHALKBOARD_OFF: IconName = IconName("chalkboard-off") CHARGING_PILE: IconName = IconName("charging-pile") CHART_ARCS: IconName = IconName("chart-arcs") CHART_ARCS_3: IconName = IconName("chart-arcs-3") CHART_AREA: IconName = IconName("chart-area") CHART_AREA_FILLED: IconName = IconName("chart-area-filled") CHART_AREA_LINE: IconName = IconName("chart-area-line") CHART_AREA_LINE_FILLED: IconName = IconName("chart-area-line-filled") CHART_ARROWS: IconName = IconName("chart-arrows") CHART_ARROWS_VERTICAL: IconName = IconName("chart-arrows-vertical") CHART_BAR: IconName = IconName("chart-bar") CHART_BAR_OFF: IconName = IconName("chart-bar-off") CHART_BUBBLE: IconName = IconName("chart-bubble") CHART_BUBBLE_FILLED: IconName = IconName("chart-bubble-filled") CHART_CANDLE: IconName = IconName("chart-candle") CHART_CANDLE_FILLED: IconName = IconName("chart-candle-filled") CHART_CIRCLES: IconName = IconName("chart-circles") CHART_DONUT: IconName = IconName("chart-donut") CHART_DONUT_2: IconName = 
IconName("chart-donut-2") CHART_DONUT_3: IconName = IconName("chart-donut-3") CHART_DONUT_4: IconName = IconName("chart-donut-4") CHART_DONUT_FILLED: IconName = IconName("chart-donut-filled") CHART_DOTS: IconName = IconName("chart-dots") CHART_DOTS_2: IconName = IconName("chart-dots-2") CHART_DOTS_3: IconName = IconName("chart-dots-3") CHART_GRID_DOTS: IconName = IconName("chart-grid-dots") CHART_HISTOGRAM: IconName = IconName("chart-histogram") CHART_INFOGRAPHIC: IconName = IconName("chart-infographic") CHART_LINE: IconName = IconName("chart-line") CHART_PIE: IconName = IconName("chart-pie") CHART_PIE_2: IconName = IconName("chart-pie-2") CHART_PIE_3: IconName = IconName("chart-pie-3") CHART_PIE_4: IconName = IconName("chart-pie-4") CHART_PIE_FILLED: IconName = IconName("chart-pie-filled") CHART_PIE_OFF: IconName = IconName("chart-pie-off") CHART_PPF: IconName = IconName("chart-ppf") CHART_RADAR: IconName = IconName("chart-radar") CHART_SANKEY: IconName = IconName("chart-sankey") CHART_TREEMAP: IconName = IconName("chart-treemap") CHECK: IconName = IconName("check") CHECKBOX: IconName = IconName("checkbox") CHECKLIST: IconName = IconName("checklist") CHECKS: IconName = IconName("checks") CHECKUP_LIST: IconName = IconName("checkup-list") CHEESE: IconName = IconName("cheese") CHEF_HAT: IconName = IconName("chef-hat") CHEF_HAT_OFF: IconName = IconName("chef-hat-off") CHERRY: IconName = IconName("cherry") CHERRY_FILLED: IconName = IconName("cherry-filled") CHESS: IconName = IconName("chess") CHESS_BISHOP: IconName = IconName("chess-bishop") CHESS_BISHOP_FILLED: IconName = IconName("chess-bishop-filled") CHESS_FILLED: IconName = IconName("chess-filled") CHESS_KING: IconName = IconName("chess-king") CHESS_KING_FILLED: IconName = IconName("chess-king-filled") CHESS_KNIGHT: IconName = IconName("chess-knight") CHESS_KNIGHT_FILLED: IconName = IconName("chess-knight-filled") CHESS_QUEEN: IconName = IconName("chess-queen") CHESS_QUEEN_FILLED: IconName = IconName("chess-queen-filled") CHESS_ROOK: IconName = IconName("chess-rook") CHESS_ROOK_FILLED: IconName = IconName("chess-rook-filled") CHEVRON_COMPACT_DOWN: IconName = IconName("chevron-compact-down") CHEVRON_COMPACT_LEFT: IconName = IconName("chevron-compact-left") CHEVRON_COMPACT_RIGHT: IconName = IconName("chevron-compact-right") CHEVRON_COMPACT_UP: IconName = IconName("chevron-compact-up") CHEVRON_DOWN: IconName = IconName("chevron-down") CHEVRON_DOWN_LEFT: IconName = IconName("chevron-down-left") CHEVRON_DOWN_RIGHT: IconName = IconName("chevron-down-right") CHEVRON_LEFT: IconName = IconName("chevron-left") CHEVRON_LEFT_PIPE: IconName = IconName("chevron-left-pipe") CHEVRON_RIGHT: IconName = IconName("chevron-right") CHEVRON_RIGHT_PIPE: IconName = IconName("chevron-right-pipe") CHEVRON_UP: IconName = IconName("chevron-up") CHEVRON_UP_LEFT: IconName = IconName("chevron-up-left") CHEVRON_UP_RIGHT: IconName = IconName("chevron-up-right") CHEVRONS_DOWN: IconName = IconName("chevrons-down") CHEVRONS_DOWN_LEFT: IconName = IconName("chevrons-down-left") CHEVRONS_DOWN_RIGHT: IconName = IconName("chevrons-down-right") CHEVRONS_LEFT: IconName = IconName("chevrons-left") CHEVRONS_RIGHT: IconName = IconName("chevrons-right") CHEVRONS_UP: IconName = IconName("chevrons-up") CHEVRONS_UP_LEFT: IconName = IconName("chevrons-up-left") CHEVRONS_UP_RIGHT: IconName = IconName("chevrons-up-right") CHISEL: IconName = IconName("chisel") CHRISTMAS_TREE: IconName = IconName("christmas-tree") CHRISTMAS_TREE_OFF: IconName = IconName("christmas-tree-off") CIRCLE: IconName = 
IconName("circle") CIRCLE_0_FILLED: IconName = IconName("circle-0-filled") CIRCLE_1_FILLED: IconName = IconName("circle-1-filled") CIRCLE_2_FILLED: IconName = IconName("circle-2-filled") CIRCLE_3_FILLED: IconName = IconName("circle-3-filled") CIRCLE_4_FILLED: IconName = IconName("circle-4-filled") CIRCLE_5_FILLED: IconName = IconName("circle-5-filled") CIRCLE_6_FILLED: IconName = IconName("circle-6-filled") CIRCLE_7_FILLED: IconName = IconName("circle-7-filled") CIRCLE_8_FILLED: IconName = IconName("circle-8-filled") CIRCLE_9_FILLED: IconName = IconName("circle-9-filled") CIRCLE_ARROW_DOWN: IconName = IconName("circle-arrow-down") CIRCLE_ARROW_DOWN_FILLED: IconName = IconName("circle-arrow-down-filled") CIRCLE_ARROW_DOWN_LEFT: IconName = IconName("circle-arrow-down-left") CIRCLE_ARROW_DOWN_LEFT_FILLED: IconName = IconName("circle-arrow-down-left-filled") CIRCLE_ARROW_DOWN_RIGHT: IconName = IconName("circle-arrow-down-right") CIRCLE_ARROW_DOWN_RIGHT_FILLED: IconName = IconName( "circle-arrow-down-right-filled" ) CIRCLE_ARROW_LEFT: IconName = IconName("circle-arrow-left") CIRCLE_ARROW_LEFT_FILLED: IconName = IconName("circle-arrow-left-filled") CIRCLE_ARROW_RIGHT: IconName = IconName("circle-arrow-right") CIRCLE_ARROW_RIGHT_FILLED: IconName = IconName("circle-arrow-right-filled") CIRCLE_ARROW_UP: IconName = IconName("circle-arrow-up") CIRCLE_ARROW_UP_FILLED: IconName = IconName("circle-arrow-up-filled") CIRCLE_ARROW_UP_LEFT: IconName = IconName("circle-arrow-up-left") CIRCLE_ARROW_UP_LEFT_FILLED: IconName = IconName("circle-arrow-up-left-filled") CIRCLE_ARROW_UP_RIGHT: IconName = IconName("circle-arrow-up-right") CIRCLE_ARROW_UP_RIGHT_FILLED: IconName = IconName("circle-arrow-up-right-filled") CIRCLE_CARET_DOWN: IconName = IconName("circle-caret-down") CIRCLE_CARET_LEFT: IconName = IconName("circle-caret-left") CIRCLE_CARET_RIGHT: IconName = IconName("circle-caret-right") CIRCLE_CARET_UP: IconName = IconName("circle-caret-up") CIRCLE_CHECK: IconName = IconName("circle-check") CIRCLE_CHECK_FILLED: IconName = IconName("circle-check-filled") CIRCLE_CHEVRON_DOWN: IconName = IconName("circle-chevron-down") CIRCLE_CHEVRON_LEFT: IconName = IconName("circle-chevron-left") CIRCLE_CHEVRON_RIGHT: IconName = IconName("circle-chevron-right") CIRCLE_CHEVRON_UP: IconName = IconName("circle-chevron-up") CIRCLE_CHEVRONS_DOWN: IconName = IconName("circle-chevrons-down") CIRCLE_CHEVRONS_LEFT: IconName = IconName("circle-chevrons-left") CIRCLE_CHEVRONS_RIGHT: IconName = IconName("circle-chevrons-right") CIRCLE_CHEVRONS_UP: IconName = IconName("circle-chevrons-up") CIRCLE_DASHED: IconName = IconName("circle-dashed") CIRCLE_DOT: IconName = IconName("circle-dot") CIRCLE_DOT_FILLED: IconName = IconName("circle-dot-filled") CIRCLE_DOTTED: IconName = IconName("circle-dotted") CIRCLE_FILLED: IconName = IconName("circle-filled") CIRCLE_HALF: IconName = IconName("circle-half") CIRCLE_HALF_2: IconName = IconName("circle-half-2") CIRCLE_HALF_VERTICAL: IconName = IconName("circle-half-vertical") CIRCLE_KEY: IconName = IconName("circle-key") CIRCLE_KEY_FILLED: IconName = IconName("circle-key-filled") CIRCLE_LETTER_A: IconName = IconName("circle-letter-a") CIRCLE_LETTER_B: IconName = IconName("circle-letter-b") CIRCLE_LETTER_C: IconName = IconName("circle-letter-c") CIRCLE_LETTER_D: IconName = IconName("circle-letter-d") CIRCLE_LETTER_E: IconName = IconName("circle-letter-e") CIRCLE_LETTER_F: IconName = IconName("circle-letter-f") CIRCLE_LETTER_G: IconName = IconName("circle-letter-g") CIRCLE_LETTER_H: IconName = 
IconName("circle-letter-h") CIRCLE_LETTER_I: IconName = IconName("circle-letter-i") CIRCLE_LETTER_J: IconName = IconName("circle-letter-j") CIRCLE_LETTER_K: IconName = IconName("circle-letter-k") CIRCLE_LETTER_L: IconName = IconName("circle-letter-l") CIRCLE_LETTER_M: IconName = IconName("circle-letter-m") CIRCLE_LETTER_N: IconName = IconName("circle-letter-n") CIRCLE_LETTER_O: IconName = IconName("circle-letter-o") CIRCLE_LETTER_P: IconName = IconName("circle-letter-p") CIRCLE_LETTER_Q: IconName = IconName("circle-letter-q") CIRCLE_LETTER_R: IconName = IconName("circle-letter-r") CIRCLE_LETTER_S: IconName = IconName("circle-letter-s") CIRCLE_LETTER_T: IconName = IconName("circle-letter-t") CIRCLE_LETTER_U: IconName = IconName("circle-letter-u") CIRCLE_LETTER_V: IconName = IconName("circle-letter-v") CIRCLE_LETTER_W: IconName = IconName("circle-letter-w") CIRCLE_LETTER_X: IconName = IconName("circle-letter-x") CIRCLE_LETTER_Y: IconName = IconName("circle-letter-y") CIRCLE_LETTER_Z: IconName = IconName("circle-letter-z") CIRCLE_MINUS: IconName = IconName("circle-minus") CIRCLE_NUMBER_0: IconName = IconName("circle-number-0") CIRCLE_NUMBER_1: IconName = IconName("circle-number-1") CIRCLE_NUMBER_2: IconName = IconName("circle-number-2") CIRCLE_NUMBER_3: IconName = IconName("circle-number-3") CIRCLE_NUMBER_4: IconName = IconName("circle-number-4") CIRCLE_NUMBER_5: IconName = IconName("circle-number-5") CIRCLE_NUMBER_6: IconName = IconName("circle-number-6") CIRCLE_NUMBER_7: IconName = IconName("circle-number-7") CIRCLE_NUMBER_8: IconName = IconName("circle-number-8") CIRCLE_NUMBER_9: IconName = IconName("circle-number-9") CIRCLE_OFF: IconName = IconName("circle-off") CIRCLE_PLUS: IconName = IconName("circle-plus") CIRCLE_RECTANGLE: IconName = IconName("circle-rectangle") CIRCLE_RECTANGLE_OFF: IconName = IconName("circle-rectangle-off") CIRCLE_SQUARE: IconName = IconName("circle-square") CIRCLE_TRIANGLE: IconName = IconName("circle-triangle") CIRCLE_X: IconName = IconName("circle-x") CIRCLE_X_FILLED: IconName = IconName("circle-x-filled") CIRCLES: IconName = IconName("circles") CIRCLES_FILLED: IconName = IconName("circles-filled") CIRCLES_RELATION: IconName = IconName("circles-relation") CIRCUIT_AMMETER: IconName = IconName("circuit-ammeter") CIRCUIT_BATTERY: IconName = IconName("circuit-battery") CIRCUIT_BULB: IconName = IconName("circuit-bulb") CIRCUIT_CAPACITOR: IconName = IconName("circuit-capacitor") CIRCUIT_CAPACITOR_POLARIZED: IconName = IconName("circuit-capacitor-polarized") CIRCUIT_CELL: IconName = IconName("circuit-cell") CIRCUIT_CELL_PLUS: IconName = IconName("circuit-cell-plus") CIRCUIT_CHANGEOVER: IconName = IconName("circuit-changeover") CIRCUIT_DIODE: IconName = IconName("circuit-diode") CIRCUIT_DIODE_ZENER: IconName = IconName("circuit-diode-zener") CIRCUIT_GROUND: IconName = IconName("circuit-ground") CIRCUIT_GROUND_DIGITAL: IconName = IconName("circuit-ground-digital") CIRCUIT_INDUCTOR: IconName = IconName("circuit-inductor") CIRCUIT_MOTOR: IconName = IconName("circuit-motor") CIRCUIT_PUSHBUTTON: IconName = IconName("circuit-pushbutton") CIRCUIT_RESISTOR: IconName = IconName("circuit-resistor") CIRCUIT_SWITCH_CLOSED: IconName = IconName("circuit-switch-closed") CIRCUIT_SWITCH_OPEN: IconName = IconName("circuit-switch-open") CIRCUIT_VOLTMETER: IconName = IconName("circuit-voltmeter") CLEAR_ALL: IconName = IconName("clear-all") CLEAR_FORMATTING: IconName = IconName("clear-formatting") CLICK: IconName = IconName("click") CLIPBOARD: IconName = IconName("clipboard") 
CLIPBOARD_CHECK: IconName = IconName("clipboard-check") CLIPBOARD_COPY: IconName = IconName("clipboard-copy") CLIPBOARD_DATA: IconName = IconName("clipboard-data") CLIPBOARD_HEART: IconName = IconName("clipboard-heart") CLIPBOARD_LIST: IconName = IconName("clipboard-list") CLIPBOARD_OFF: IconName = IconName("clipboard-off") CLIPBOARD_PLUS: IconName = IconName("clipboard-plus") CLIPBOARD_TEXT: IconName = IconName("clipboard-text") CLIPBOARD_TYPOGRAPHY: IconName = IconName("clipboard-typography") CLIPBOARD_X: IconName = IconName("clipboard-x") CLOCK: IconName = IconName("clock") CLOCK_2: IconName = IconName("clock-2") CLOCK_BOLT: IconName = IconName("clock-bolt") CLOCK_CANCEL: IconName = IconName("clock-cancel") CLOCK_CHECK: IconName = IconName("clock-check") CLOCK_CODE: IconName = IconName("clock-code") CLOCK_COG: IconName = IconName("clock-cog") CLOCK_DOLLAR: IconName = IconName("clock-dollar") CLOCK_DOWN: IconName = IconName("clock-down") CLOCK_EDIT: IconName = IconName("clock-edit") CLOCK_EXCLAMATION: IconName = IconName("clock-exclamation") CLOCK_FILLED: IconName = IconName("clock-filled") CLOCK_HEART: IconName = IconName("clock-heart") CLOCK_HOUR_1: IconName = IconName("clock-hour-1") CLOCK_HOUR_10: IconName = IconName("clock-hour-10") CLOCK_HOUR_11: IconName = IconName("clock-hour-11") CLOCK_HOUR_12: IconName = IconName("clock-hour-12") CLOCK_HOUR_2: IconName = IconName("clock-hour-2") CLOCK_HOUR_3: IconName = IconName("clock-hour-3") CLOCK_HOUR_4: IconName = IconName("clock-hour-4") CLOCK_HOUR_5: IconName = IconName("clock-hour-5") CLOCK_HOUR_6: IconName = IconName("clock-hour-6") CLOCK_HOUR_7: IconName = IconName("clock-hour-7") CLOCK_HOUR_8: IconName = IconName("clock-hour-8") CLOCK_HOUR_9: IconName = IconName("clock-hour-9") CLOCK_MINUS: IconName = IconName("clock-minus") CLOCK_OFF: IconName = IconName("clock-off") CLOCK_PAUSE: IconName = IconName("clock-pause") CLOCK_PIN: IconName = IconName("clock-pin") CLOCK_PLAY: IconName = IconName("clock-play") CLOCK_PLUS: IconName = IconName("clock-plus") CLOCK_QUESTION: IconName = IconName("clock-question") CLOCK_RECORD: IconName = IconName("clock-record") CLOCK_SEARCH: IconName = IconName("clock-search") CLOCK_SHARE: IconName = IconName("clock-share") CLOCK_SHIELD: IconName = IconName("clock-shield") CLOCK_STAR: IconName = IconName("clock-star") CLOCK_STOP: IconName = IconName("clock-stop") CLOCK_UP: IconName = IconName("clock-up") CLOCK_X: IconName = IconName("clock-x") CLOTHES_RACK: IconName = IconName("clothes-rack") CLOTHES_RACK_OFF: IconName = IconName("clothes-rack-off") CLOUD: IconName = IconName("cloud") CLOUD_BOLT: IconName = IconName("cloud-bolt") CLOUD_CANCEL: IconName = IconName("cloud-cancel") CLOUD_CHECK: IconName = IconName("cloud-check") CLOUD_CODE: IconName = IconName("cloud-code") CLOUD_COG: IconName = IconName("cloud-cog") CLOUD_COMPUTING: IconName = IconName("cloud-computing") CLOUD_DATA_CONNECTION: IconName = IconName("cloud-data-connection") CLOUD_DOLLAR: IconName = IconName("cloud-dollar") CLOUD_DOWN: IconName = IconName("cloud-down") CLOUD_DOWNLOAD: IconName = IconName("cloud-download") CLOUD_EXCLAMATION: IconName = IconName("cloud-exclamation") CLOUD_FILLED: IconName = IconName("cloud-filled") CLOUD_FOG: IconName = IconName("cloud-fog") CLOUD_HEART: IconName = IconName("cloud-heart") CLOUD_LOCK: IconName = IconName("cloud-lock") CLOUD_LOCK_OPEN: IconName = IconName("cloud-lock-open") CLOUD_MINUS: IconName = IconName("cloud-minus") CLOUD_OFF: IconName = IconName("cloud-off") CLOUD_PAUSE: IconName = 
IconName("cloud-pause") CLOUD_PIN: IconName = IconName("cloud-pin") CLOUD_PLUS: IconName = IconName("cloud-plus") CLOUD_QUESTION: IconName = IconName("cloud-question") CLOUD_RAIN: IconName = IconName("cloud-rain") CLOUD_SEARCH: IconName = IconName("cloud-search") CLOUD_SHARE: IconName = IconName("cloud-share") CLOUD_SNOW: IconName = IconName("cloud-snow") CLOUD_STAR: IconName = IconName("cloud-star") CLOUD_STORM: IconName = IconName("cloud-storm") CLOUD_UP: IconName = IconName("cloud-up") CLOUD_UPLOAD: IconName = IconName("cloud-upload") CLOUD_X: IconName = IconName("cloud-x") CLOVER: IconName = IconName("clover") CLOVER_2: IconName = IconName("clover-2") CLUBS: IconName = IconName("clubs") CLUBS_FILLED: IconName = IconName("clubs-filled") CODE: IconName = IconName("code") CODE_ASTERIX: IconName = IconName("code-asterix") CODE_CIRCLE: IconName = IconName("code-circle") CODE_CIRCLE_2: IconName = IconName("code-circle-2") CODE_DOTS: IconName = IconName("code-dots") CODE_MINUS: IconName = IconName("code-minus") CODE_OFF: IconName = IconName("code-off") CODE_PLUS: IconName = IconName("code-plus") COFFEE: IconName = IconName("coffee") COFFEE_OFF: IconName = IconName("coffee-off") COFFIN: IconName = IconName("coffin") COIN: IconName = IconName("coin") COIN_BITCOIN: IconName = IconName("coin-bitcoin") COIN_EURO: IconName = IconName("coin-euro") COIN_MONERO: IconName = IconName("coin-monero") COIN_OFF: IconName = IconName("coin-off") COIN_POUND: IconName = IconName("coin-pound") COIN_RUPEE: IconName = IconName("coin-rupee") COIN_YEN: IconName = IconName("coin-yen") COIN_YUAN: IconName = IconName("coin-yuan") COINS: IconName = IconName("coins") COLOR_FILTER: IconName = IconName("color-filter") COLOR_PICKER: IconName = IconName("color-picker") COLOR_PICKER_OFF: IconName = IconName("color-picker-off") COLOR_SWATCH: IconName = IconName("color-swatch") COLOR_SWATCH_OFF: IconName = IconName("color-swatch-off") COLUMN_INSERT_LEFT: IconName = IconName("column-insert-left") COLUMN_INSERT_RIGHT: IconName = IconName("column-insert-right") COLUMN_REMOVE: IconName = IconName("column-remove") COLUMNS: IconName = IconName("columns") COLUMNS_1: IconName = IconName("columns-1") COLUMNS_2: IconName = IconName("columns-2") COLUMNS_3: IconName = IconName("columns-3") COLUMNS_OFF: IconName = IconName("columns-off") COMET: IconName = IconName("comet") COMMAND: IconName = IconName("command") COMMAND_OFF: IconName = IconName("command-off") COMPASS: IconName = IconName("compass") COMPASS_OFF: IconName = IconName("compass-off") COMPONENTS: IconName = IconName("components") COMPONENTS_OFF: IconName = IconName("components-off") CONE: IconName = IconName("cone") CONE_2: IconName = IconName("cone-2") CONE_OFF: IconName = IconName("cone-off") CONE_PLUS: IconName = IconName("cone-plus") CONFETTI: IconName = IconName("confetti") CONFETTI_OFF: IconName = IconName("confetti-off") CONFUCIUS: IconName = IconName("confucius") CONTAINER: IconName = IconName("container") CONTAINER_OFF: IconName = IconName("container-off") CONTRAST: IconName = IconName("contrast") CONTRAST_2: IconName = IconName("contrast-2") CONTRAST_2_OFF: IconName = IconName("contrast-2-off") CONTRAST_OFF: IconName = IconName("contrast-off") COOKER: IconName = IconName("cooker") COOKIE: IconName = IconName("cookie") COOKIE_MAN: IconName = IconName("cookie-man") COOKIE_OFF: IconName = IconName("cookie-off") COPY: IconName = IconName("copy") COPY_OFF: IconName = IconName("copy-off") COPYLEFT: IconName = IconName("copyleft") COPYLEFT_FILLED: IconName = 
IconName("copyleft-filled") COPYLEFT_OFF: IconName = IconName("copyleft-off") COPYRIGHT: IconName = IconName("copyright") COPYRIGHT_FILLED: IconName = IconName("copyright-filled") COPYRIGHT_OFF: IconName = IconName("copyright-off") CORNER_DOWN_LEFT: IconName = IconName("corner-down-left") CORNER_DOWN_LEFT_DOUBLE: IconName = IconName("corner-down-left-double") CORNER_DOWN_RIGHT: IconName = IconName("corner-down-right") CORNER_DOWN_RIGHT_DOUBLE: IconName = IconName("corner-down-right-double") CORNER_LEFT_DOWN: IconName = IconName("corner-left-down") CORNER_LEFT_DOWN_DOUBLE: IconName = IconName("corner-left-down-double") CORNER_LEFT_UP: IconName = IconName("corner-left-up") CORNER_LEFT_UP_DOUBLE: IconName = IconName("corner-left-up-double") CORNER_RIGHT_DOWN: IconName = IconName("corner-right-down") CORNER_RIGHT_DOWN_DOUBLE: IconName = IconName("corner-right-down-double") CORNER_RIGHT_UP: IconName = IconName("corner-right-up") CORNER_RIGHT_UP_DOUBLE: IconName = IconName("corner-right-up-double") CORNER_UP_LEFT: IconName = IconName("corner-up-left") CORNER_UP_LEFT_DOUBLE: IconName = IconName("corner-up-left-double") CORNER_UP_RIGHT: IconName = IconName("corner-up-right") CORNER_UP_RIGHT_DOUBLE: IconName = IconName("corner-up-right-double") CPU: IconName = IconName("cpu") CPU_2: IconName = IconName("cpu-2") CPU_OFF: IconName = IconName("cpu-off") CRANE: IconName = IconName("crane") CRANE_OFF: IconName = IconName("crane-off") CREATIVE_COMMONS: IconName = IconName("creative-commons") CREATIVE_COMMONS_BY: IconName = IconName("creative-commons-by") CREATIVE_COMMONS_NC: IconName = IconName("creative-commons-nc") CREATIVE_COMMONS_ND: IconName = IconName("creative-commons-nd") CREATIVE_COMMONS_OFF: IconName = IconName("creative-commons-off") CREATIVE_COMMONS_SA: IconName = IconName("creative-commons-sa") CREATIVE_COMMONS_ZERO: IconName = IconName("creative-commons-zero") CREDIT_CARD: IconName = IconName("credit-card") CREDIT_CARD_OFF: IconName = IconName("credit-card-off") CRICKET: IconName = IconName("cricket") CROP: IconName = IconName("crop") CROSS: IconName = IconName("cross") CROSS_FILLED: IconName = IconName("cross-filled") CROSS_OFF: IconName = IconName("cross-off") CROSSHAIR: IconName = IconName("crosshair") CROWN: IconName = IconName("crown") CROWN_OFF: IconName = IconName("crown-off") CRUTCHES: IconName = IconName("crutches") CRUTCHES_OFF: IconName = IconName("crutches-off") CRYSTAL_BALL: IconName = IconName("crystal-ball") CSV: IconName = IconName("csv") CUBE: IconName = IconName("cube") CUBE_OFF: IconName = IconName("cube-off") CUBE_PLUS: IconName = IconName("cube-plus") CUBE_SEND: IconName = IconName("cube-send") CUBE_UNFOLDED: IconName = IconName("cube-unfolded") CUP: IconName = IconName("cup") CUP_OFF: IconName = IconName("cup-off") CURLING: IconName = IconName("curling") CURLY_LOOP: IconName = IconName("curly-loop") CURRENCY: IconName = IconName("currency") CURRENCY_AFGHANI: IconName = IconName("currency-afghani") CURRENCY_BAHRAINI: IconName = IconName("currency-bahraini") CURRENCY_BAHT: IconName = IconName("currency-baht") CURRENCY_BITCOIN: IconName = IconName("currency-bitcoin") CURRENCY_CENT: IconName = IconName("currency-cent") CURRENCY_DINAR: IconName = IconName("currency-dinar") CURRENCY_DIRHAM: IconName = IconName("currency-dirham") CURRENCY_DOGECOIN: IconName = IconName("currency-dogecoin") CURRENCY_DOLLAR: IconName = IconName("currency-dollar") CURRENCY_DOLLAR_AUSTRALIAN: IconName = IconName("currency-dollar-australian") CURRENCY_DOLLAR_BRUNEI: IconName = 
IconName("currency-dollar-brunei") CURRENCY_DOLLAR_CANADIAN: IconName = IconName("currency-dollar-canadian") CURRENCY_DOLLAR_GUYANESE: IconName = IconName("currency-dollar-guyanese") CURRENCY_DOLLAR_OFF: IconName = IconName("currency-dollar-off") CURRENCY_DOLLAR_SINGAPORE: IconName = IconName("currency-dollar-singapore") CURRENCY_DOLLAR_ZIMBABWEAN: IconName = IconName("currency-dollar-zimbabwean") CURRENCY_DONG: IconName = IconName("currency-dong") CURRENCY_DRAM: IconName = IconName("currency-dram") CURRENCY_ETHEREUM: IconName = IconName("currency-ethereum") CURRENCY_EURO: IconName = IconName("currency-euro") CURRENCY_EURO_OFF: IconName = IconName("currency-euro-off") CURRENCY_FLORIN: IconName = IconName("currency-florin") CURRENCY_FORINT: IconName = IconName("currency-forint") CURRENCY_FRANK: IconName = IconName("currency-frank") CURRENCY_GUARANI: IconName = IconName("currency-guarani") CURRENCY_HRYVNIA: IconName = IconName("currency-hryvnia") CURRENCY_IRANIAN_RIAL: IconName = IconName("currency-iranian-rial") CURRENCY_KIP: IconName = IconName("currency-kip") CURRENCY_KRONE_CZECH: IconName = IconName("currency-krone-czech") CURRENCY_KRONE_DANISH: IconName = IconName("currency-krone-danish") CURRENCY_KRONE_SWEDISH: IconName = IconName("currency-krone-swedish") CURRENCY_LARI: IconName = IconName("currency-lari") CURRENCY_LEU: IconName = IconName("currency-leu") CURRENCY_LIRA: IconName = IconName("currency-lira") CURRENCY_LITECOIN: IconName = IconName("currency-litecoin") CURRENCY_LYD: IconName = IconName("currency-lyd") CURRENCY_MANAT: IconName = IconName("currency-manat") CURRENCY_MONERO: IconName = IconName("currency-monero") CURRENCY_NAIRA: IconName = IconName("currency-naira") CURRENCY_NANO: IconName = IconName("currency-nano") CURRENCY_OFF: IconName = IconName("currency-off") CURRENCY_PAANGA: IconName = IconName("currency-paanga") CURRENCY_PESO: IconName = IconName("currency-peso") CURRENCY_POUND: IconName = IconName("currency-pound") CURRENCY_POUND_OFF: IconName = IconName("currency-pound-off") CURRENCY_QUETZAL: IconName = IconName("currency-quetzal") CURRENCY_REAL: IconName = IconName("currency-real") CURRENCY_RENMINBI: IconName = IconName("currency-renminbi") CURRENCY_RIPPLE: IconName = IconName("currency-ripple") CURRENCY_RIYAL: IconName = IconName("currency-riyal") CURRENCY_RUBEL: IconName = IconName("currency-rubel") CURRENCY_RUFIYAA: IconName = IconName("currency-rufiyaa") CURRENCY_RUPEE: IconName = IconName("currency-rupee") CURRENCY_RUPEE_NEPALESE: IconName = IconName("currency-rupee-nepalese") CURRENCY_SHEKEL: IconName = IconName("currency-shekel") CURRENCY_SOLANA: IconName = IconName("currency-solana") CURRENCY_SOM: IconName = IconName("currency-som") CURRENCY_TAKA: IconName = IconName("currency-taka") CURRENCY_TENGE: IconName = IconName("currency-tenge") CURRENCY_TUGRIK: IconName = IconName("currency-tugrik") CURRENCY_WON: IconName = IconName("currency-won") CURRENCY_YEN: IconName = IconName("currency-yen") CURRENCY_YEN_OFF: IconName = IconName("currency-yen-off") CURRENCY_YUAN: IconName = IconName("currency-yuan") CURRENCY_ZLOTY: IconName = IconName("currency-zloty") CURRENT_LOCATION: IconName = IconName("current-location") CURRENT_LOCATION_OFF: IconName = IconName("current-location-off") CURSOR_OFF: IconName = IconName("cursor-off") CURSOR_TEXT: IconName = IconName("cursor-text") CUT: IconName = IconName("cut") CYLINDER: IconName = IconName("cylinder") CYLINDER_OFF: IconName = IconName("cylinder-off") CYLINDER_PLUS: IconName = IconName("cylinder-plus") DASHBOARD: IconName 
= IconName("dashboard") DASHBOARD_OFF: IconName = IconName("dashboard-off") DATABASE: IconName = IconName("database") DATABASE_COG: IconName = IconName("database-cog") DATABASE_DOLLAR: IconName = IconName("database-dollar") DATABASE_EDIT: IconName = IconName("database-edit") DATABASE_EXCLAMATION: IconName = IconName("database-exclamation") DATABASE_EXPORT: IconName = IconName("database-export") DATABASE_HEART: IconName = IconName("database-heart") DATABASE_IMPORT: IconName = IconName("database-import") DATABASE_LEAK: IconName = IconName("database-leak") DATABASE_MINUS: IconName = IconName("database-minus") DATABASE_OFF: IconName = IconName("database-off") DATABASE_PLUS: IconName = IconName("database-plus") DATABASE_SEARCH: IconName = IconName("database-search") DATABASE_SHARE: IconName = IconName("database-share") DATABASE_STAR: IconName = IconName("database-star") DATABASE_X: IconName = IconName("database-x") DECIMAL: IconName = IconName("decimal") DEER: IconName = IconName("deer") DELTA: IconName = IconName("delta") DENTAL: IconName = IconName("dental") DENTAL_BROKEN: IconName = IconName("dental-broken") DENTAL_OFF: IconName = IconName("dental-off") DESELECT: IconName = IconName("deselect") DETAILS: IconName = IconName("details") DETAILS_OFF: IconName = IconName("details-off") DEVICE_AIRPODS: IconName = IconName("device-airpods") DEVICE_AIRPODS_CASE: IconName = IconName("device-airpods-case") DEVICE_AIRTAG: IconName = IconName("device-airtag") DEVICE_ANALYTICS: IconName = IconName("device-analytics") DEVICE_AUDIO_TAPE: IconName = IconName("device-audio-tape") DEVICE_CAMERA_PHONE: IconName = IconName("device-camera-phone") DEVICE_CCTV: IconName = IconName("device-cctv") DEVICE_CCTV_OFF: IconName = IconName("device-cctv-off") DEVICE_COMPUTER_CAMERA: IconName = IconName("device-computer-camera") DEVICE_COMPUTER_CAMERA_OFF: IconName = IconName("device-computer-camera-off") DEVICE_DESKTOP: IconName = IconName("device-desktop") DEVICE_DESKTOP_ANALYTICS: IconName = IconName("device-desktop-analytics") DEVICE_DESKTOP_BOLT: IconName = IconName("device-desktop-bolt") DEVICE_DESKTOP_CANCEL: IconName = IconName("device-desktop-cancel") DEVICE_DESKTOP_CHECK: IconName = IconName("device-desktop-check") DEVICE_DESKTOP_CODE: IconName = IconName("device-desktop-code") DEVICE_DESKTOP_COG: IconName = IconName("device-desktop-cog") DEVICE_DESKTOP_DOLLAR: IconName = IconName("device-desktop-dollar") DEVICE_DESKTOP_DOWN: IconName = IconName("device-desktop-down") DEVICE_DESKTOP_EXCLAMATION: IconName = IconName("device-desktop-exclamation") DEVICE_DESKTOP_HEART: IconName = IconName("device-desktop-heart") DEVICE_DESKTOP_MINUS: IconName = IconName("device-desktop-minus") DEVICE_DESKTOP_OFF: IconName = IconName("device-desktop-off") DEVICE_DESKTOP_PAUSE: IconName = IconName("device-desktop-pause") DEVICE_DESKTOP_PIN: IconName = IconName("device-desktop-pin") DEVICE_DESKTOP_PLUS: IconName = IconName("device-desktop-plus") DEVICE_DESKTOP_QUESTION: IconName = IconName("device-desktop-question") DEVICE_DESKTOP_SEARCH: IconName = IconName("device-desktop-search") DEVICE_DESKTOP_SHARE: IconName = IconName("device-desktop-share") DEVICE_DESKTOP_STAR: IconName = IconName("device-desktop-star") DEVICE_DESKTOP_UP: IconName = IconName("device-desktop-up") DEVICE_DESKTOP_X: IconName = IconName("device-desktop-x") DEVICE_FLOPPY: IconName = IconName("device-floppy") DEVICE_GAMEPAD: IconName = IconName("device-gamepad") DEVICE_GAMEPAD_2: IconName = IconName("device-gamepad-2") DEVICE_HEART_MONITOR: IconName = 
IconName("device-heart-monitor") DEVICE_HEART_MONITOR_FILLED: IconName = IconName("device-heart-monitor-filled") DEVICE_IMAC: IconName = IconName("device-imac") DEVICE_IMAC_BOLT: IconName = IconName("device-imac-bolt") DEVICE_IMAC_CANCEL: IconName = IconName("device-imac-cancel") DEVICE_IMAC_CHECK: IconName = IconName("device-imac-check") DEVICE_IMAC_CODE: IconName = IconName("device-imac-code") DEVICE_IMAC_COG: IconName = IconName("device-imac-cog") DEVICE_IMAC_DOLLAR: IconName = IconName("device-imac-dollar") DEVICE_IMAC_DOWN: IconName = IconName("device-imac-down") DEVICE_IMAC_EXCLAMATION: IconName = IconName("device-imac-exclamation") DEVICE_IMAC_HEART: IconName = IconName("device-imac-heart") DEVICE_IMAC_MINUS: IconName = IconName("device-imac-minus") DEVICE_IMAC_OFF: IconName = IconName("device-imac-off") DEVICE_IMAC_PAUSE: IconName = IconName("device-imac-pause") DEVICE_IMAC_PIN: IconName = IconName("device-imac-pin") DEVICE_IMAC_PLUS: IconName = IconName("device-imac-plus") DEVICE_IMAC_QUESTION: IconName = IconName("device-imac-question") DEVICE_IMAC_SEARCH: IconName = IconName("device-imac-search") DEVICE_IMAC_SHARE: IconName = IconName("device-imac-share") DEVICE_IMAC_STAR: IconName = IconName("device-imac-star") DEVICE_IMAC_UP: IconName = IconName("device-imac-up") DEVICE_IMAC_X: IconName = IconName("device-imac-x") DEVICE_IPAD: IconName = IconName("device-ipad") DEVICE_IPAD_BOLT: IconName = IconName("device-ipad-bolt") DEVICE_IPAD_CANCEL: IconName = IconName("device-ipad-cancel") DEVICE_IPAD_CHECK: IconName = IconName("device-ipad-check") DEVICE_IPAD_CODE: IconName = IconName("device-ipad-code") DEVICE_IPAD_COG: IconName = IconName("device-ipad-cog") DEVICE_IPAD_DOLLAR: IconName = IconName("device-ipad-dollar") DEVICE_IPAD_DOWN: IconName = IconName("device-ipad-down") DEVICE_IPAD_EXCLAMATION: IconName = IconName("device-ipad-exclamation") DEVICE_IPAD_HEART: IconName = IconName("device-ipad-heart") DEVICE_IPAD_HORIZONTAL: IconName = IconName("device-ipad-horizontal") DEVICE_IPAD_HORIZONTAL_BOLT: IconName = IconName("device-ipad-horizontal-bolt") DEVICE_IPAD_HORIZONTAL_CANCEL: IconName = IconName("device-ipad-horizontal-cancel") DEVICE_IPAD_HORIZONTAL_CHECK: IconName = IconName("device-ipad-horizontal-check") DEVICE_IPAD_HORIZONTAL_CODE: IconName = IconName("device-ipad-horizontal-code") DEVICE_IPAD_HORIZONTAL_COG: IconName = IconName("device-ipad-horizontal-cog") DEVICE_IPAD_HORIZONTAL_DOLLAR: IconName = IconName("device-ipad-horizontal-dollar") DEVICE_IPAD_HORIZONTAL_DOWN: IconName = IconName("device-ipad-horizontal-down") DEVICE_IPAD_HORIZONTAL_EXCLAMATION: IconName = IconName( "device-ipad-horizontal-exclamation" ) DEVICE_IPAD_HORIZONTAL_HEART: IconName = IconName("device-ipad-horizontal-heart") DEVICE_IPAD_HORIZONTAL_MINUS: IconName = IconName("device-ipad-horizontal-minus") DEVICE_IPAD_HORIZONTAL_OFF: IconName = IconName("device-ipad-horizontal-off") DEVICE_IPAD_HORIZONTAL_PAUSE: IconName = IconName("device-ipad-horizontal-pause") DEVICE_IPAD_HORIZONTAL_PIN: IconName = IconName("device-ipad-horizontal-pin") DEVICE_IPAD_HORIZONTAL_PLUS: IconName = IconName("device-ipad-horizontal-plus") DEVICE_IPAD_HORIZONTAL_QUESTION: IconName = IconName( "device-ipad-horizontal-question" ) DEVICE_IPAD_HORIZONTAL_SEARCH: IconName = IconName("device-ipad-horizontal-search") DEVICE_IPAD_HORIZONTAL_SHARE: IconName = IconName("device-ipad-horizontal-share") DEVICE_IPAD_HORIZONTAL_STAR: IconName = IconName("device-ipad-horizontal-star") DEVICE_IPAD_HORIZONTAL_UP: IconName = 
IconName("device-ipad-horizontal-up") DEVICE_IPAD_HORIZONTAL_X: IconName = IconName("device-ipad-horizontal-x") DEVICE_IPAD_MINUS: IconName = IconName("device-ipad-minus") DEVICE_IPAD_OFF: IconName = IconName("device-ipad-off") DEVICE_IPAD_PAUSE: IconName = IconName("device-ipad-pause") DEVICE_IPAD_PIN: IconName = IconName("device-ipad-pin") DEVICE_IPAD_PLUS: IconName = IconName("device-ipad-plus") DEVICE_IPAD_QUESTION: IconName = IconName("device-ipad-question") DEVICE_IPAD_SEARCH: IconName = IconName("device-ipad-search") DEVICE_IPAD_SHARE: IconName = IconName("device-ipad-share") DEVICE_IPAD_STAR: IconName = IconName("device-ipad-star") DEVICE_IPAD_UP: IconName = IconName("device-ipad-up") DEVICE_IPAD_X: IconName = IconName("device-ipad-x") DEVICE_LANDLINE_PHONE: IconName = IconName("device-landline-phone") DEVICE_LAPTOP: IconName = IconName("device-laptop") DEVICE_LAPTOP_OFF: IconName = IconName("device-laptop-off") DEVICE_MOBILE: IconName = IconName("device-mobile") DEVICE_MOBILE_BOLT: IconName = IconName("device-mobile-bolt") DEVICE_MOBILE_CANCEL: IconName = IconName("device-mobile-cancel") DEVICE_MOBILE_CHARGING: IconName = IconName("device-mobile-charging") DEVICE_MOBILE_CHECK: IconName = IconName("device-mobile-check") DEVICE_MOBILE_CODE: IconName = IconName("device-mobile-code") DEVICE_MOBILE_COG: IconName = IconName("device-mobile-cog") DEVICE_MOBILE_DOLLAR: IconName = IconName("device-mobile-dollar") DEVICE_MOBILE_DOWN: IconName = IconName("device-mobile-down") DEVICE_MOBILE_EXCLAMATION: IconName = IconName("device-mobile-exclamation") DEVICE_MOBILE_FILLED: IconName = IconName("device-mobile-filled") DEVICE_MOBILE_HEART: IconName = IconName("device-mobile-heart") DEVICE_MOBILE_MESSAGE: IconName = IconName("device-mobile-message") DEVICE_MOBILE_MINUS: IconName = IconName("device-mobile-minus") DEVICE_MOBILE_OFF: IconName = IconName("device-mobile-off") DEVICE_MOBILE_PAUSE: IconName = IconName("device-mobile-pause") DEVICE_MOBILE_PIN: IconName = IconName("device-mobile-pin") DEVICE_MOBILE_PLUS: IconName = IconName("device-mobile-plus") DEVICE_MOBILE_QUESTION: IconName = IconName("device-mobile-question") DEVICE_MOBILE_ROTATED: IconName = IconName("device-mobile-rotated") DEVICE_MOBILE_SEARCH: IconName = IconName("device-mobile-search") DEVICE_MOBILE_SHARE: IconName = IconName("device-mobile-share") DEVICE_MOBILE_STAR: IconName = IconName("device-mobile-star") DEVICE_MOBILE_UP: IconName = IconName("device-mobile-up") DEVICE_MOBILE_VIBRATION: IconName = IconName("device-mobile-vibration") DEVICE_MOBILE_X: IconName = IconName("device-mobile-x") DEVICE_NINTENDO: IconName = IconName("device-nintendo") DEVICE_NINTENDO_OFF: IconName = IconName("device-nintendo-off") DEVICE_REMOTE: IconName = IconName("device-remote") DEVICE_SD_CARD: IconName = IconName("device-sd-card") DEVICE_SIM: IconName = IconName("device-sim") DEVICE_SIM_1: IconName = IconName("device-sim-1") DEVICE_SIM_2: IconName = IconName("device-sim-2") DEVICE_SIM_3: IconName = IconName("device-sim-3") DEVICE_SPEAKER: IconName = IconName("device-speaker") DEVICE_SPEAKER_OFF: IconName = IconName("device-speaker-off") DEVICE_TABLET: IconName = IconName("device-tablet") DEVICE_TABLET_BOLT: IconName = IconName("device-tablet-bolt") DEVICE_TABLET_CANCEL: IconName = IconName("device-tablet-cancel") DEVICE_TABLET_CHECK: IconName = IconName("device-tablet-check") DEVICE_TABLET_CODE: IconName = IconName("device-tablet-code") DEVICE_TABLET_COG: IconName = IconName("device-tablet-cog") DEVICE_TABLET_DOLLAR: IconName = 
IconName("device-tablet-dollar") DEVICE_TABLET_DOWN: IconName = IconName("device-tablet-down") DEVICE_TABLET_EXCLAMATION: IconName = IconName("device-tablet-exclamation") DEVICE_TABLET_FILLED: IconName = IconName("device-tablet-filled") DEVICE_TABLET_HEART: IconName = IconName("device-tablet-heart") DEVICE_TABLET_MINUS: IconName = IconName("device-tablet-minus") DEVICE_TABLET_OFF: IconName = IconName("device-tablet-off") DEVICE_TABLET_PAUSE: IconName = IconName("device-tablet-pause") DEVICE_TABLET_PIN: IconName = IconName("device-tablet-pin") DEVICE_TABLET_PLUS: IconName = IconName("device-tablet-plus") DEVICE_TABLET_QUESTION: IconName = IconName("device-tablet-question") DEVICE_TABLET_SEARCH: IconName = IconName("device-tablet-search") DEVICE_TABLET_SHARE: IconName = IconName("device-tablet-share") DEVICE_TABLET_STAR: IconName = IconName("device-tablet-star") DEVICE_TABLET_UP: IconName = IconName("device-tablet-up") DEVICE_TABLET_X: IconName = IconName("device-tablet-x") DEVICE_TV: IconName = IconName("device-tv") DEVICE_TV_OFF: IconName = IconName("device-tv-off") DEVICE_TV_OLD: IconName = IconName("device-tv-old") DEVICE_VISION_PRO: IconName = IconName("device-vision-pro") DEVICE_WATCH: IconName = IconName("device-watch") DEVICE_WATCH_BOLT: IconName = IconName("device-watch-bolt") DEVICE_WATCH_CANCEL: IconName = IconName("device-watch-cancel") DEVICE_WATCH_CHECK: IconName = IconName("device-watch-check") DEVICE_WATCH_CODE: IconName = IconName("device-watch-code") DEVICE_WATCH_COG: IconName = IconName("device-watch-cog") DEVICE_WATCH_DOLLAR: IconName = IconName("device-watch-dollar") DEVICE_WATCH_DOWN: IconName = IconName("device-watch-down") DEVICE_WATCH_EXCLAMATION: IconName = IconName("device-watch-exclamation") DEVICE_WATCH_HEART: IconName = IconName("device-watch-heart") DEVICE_WATCH_MINUS: IconName = IconName("device-watch-minus") DEVICE_WATCH_OFF: IconName = IconName("device-watch-off") DEVICE_WATCH_PAUSE: IconName = IconName("device-watch-pause") DEVICE_WATCH_PIN: IconName = IconName("device-watch-pin") DEVICE_WATCH_PLUS: IconName = IconName("device-watch-plus") DEVICE_WATCH_QUESTION: IconName = IconName("device-watch-question") DEVICE_WATCH_SEARCH: IconName = IconName("device-watch-search") DEVICE_WATCH_SHARE: IconName = IconName("device-watch-share") DEVICE_WATCH_STAR: IconName = IconName("device-watch-star") DEVICE_WATCH_STATS: IconName = IconName("device-watch-stats") DEVICE_WATCH_STATS_2: IconName = IconName("device-watch-stats-2") DEVICE_WATCH_UP: IconName = IconName("device-watch-up") DEVICE_WATCH_X: IconName = IconName("device-watch-x") DEVICES: IconName = IconName("devices") DEVICES_2: IconName = IconName("devices-2") DEVICES_BOLT: IconName = IconName("devices-bolt") DEVICES_CANCEL: IconName = IconName("devices-cancel") DEVICES_CHECK: IconName = IconName("devices-check") DEVICES_CODE: IconName = IconName("devices-code") DEVICES_COG: IconName = IconName("devices-cog") DEVICES_DOLLAR: IconName = IconName("devices-dollar") DEVICES_DOWN: IconName = IconName("devices-down") DEVICES_EXCLAMATION: IconName = IconName("devices-exclamation") DEVICES_HEART: IconName = IconName("devices-heart") DEVICES_MINUS: IconName = IconName("devices-minus") DEVICES_OFF: IconName = IconName("devices-off") DEVICES_PAUSE: IconName = IconName("devices-pause") DEVICES_PC: IconName = IconName("devices-pc") DEVICES_PC_OFF: IconName = IconName("devices-pc-off") DEVICES_PIN: IconName = IconName("devices-pin") DEVICES_PLUS: IconName = IconName("devices-plus") DEVICES_QUESTION: IconName = 
IconName("devices-question") DEVICES_SEARCH: IconName = IconName("devices-search") DEVICES_SHARE: IconName = IconName("devices-share") DEVICES_STAR: IconName = IconName("devices-star") DEVICES_UP: IconName = IconName("devices-up") DEVICES_X: IconName = IconName("devices-x") DIABOLO: IconName = IconName("diabolo") DIABOLO_OFF: IconName = IconName("diabolo-off") DIABOLO_PLUS: IconName = IconName("diabolo-plus") DIALPAD: IconName = IconName("dialpad") DIALPAD_FILLED: IconName = IconName("dialpad-filled") DIALPAD_OFF: IconName = IconName("dialpad-off") DIAMOND: IconName = IconName("diamond") DIAMOND_FILLED: IconName = IconName("diamond-filled") DIAMOND_OFF: IconName = IconName("diamond-off") DIAMONDS: IconName = IconName("diamonds") DIAMONDS_FILLED: IconName = IconName("diamonds-filled") DICE: IconName = IconName("dice") DICE_1: IconName = IconName("dice-1") DICE_1_FILLED: IconName = IconName("dice-1-filled") DICE_2: IconName = IconName("dice-2") DICE_2_FILLED: IconName = IconName("dice-2-filled") DICE_3: IconName = IconName("dice-3") DICE_3_FILLED: IconName = IconName("dice-3-filled") DICE_4: IconName = IconName("dice-4") DICE_4_FILLED: IconName = IconName("dice-4-filled") DICE_5: IconName = IconName("dice-5") DICE_5_FILLED: IconName = IconName("dice-5-filled") DICE_6: IconName = IconName("dice-6") DICE_6_FILLED: IconName = IconName("dice-6-filled") DICE_FILLED: IconName = IconName("dice-filled") DIMENSIONS: IconName = IconName("dimensions") DIRECTION: IconName = IconName("direction") DIRECTION_HORIZONTAL: IconName = IconName("direction-horizontal") DIRECTION_SIGN: IconName = IconName("direction-sign") DIRECTION_SIGN_FILLED: IconName = IconName("direction-sign-filled") DIRECTION_SIGN_OFF: IconName = IconName("direction-sign-off") DIRECTIONS: IconName = IconName("directions") DIRECTIONS_OFF: IconName = IconName("directions-off") DISABLED: IconName = IconName("disabled") DISABLED_2: IconName = IconName("disabled-2") DISABLED_OFF: IconName = IconName("disabled-off") DISC: IconName = IconName("disc") DISC_GOLF: IconName = IconName("disc-golf") DISC_OFF: IconName = IconName("disc-off") DISCOUNT: IconName = IconName("discount") DISCOUNT_2: IconName = IconName("discount-2") DISCOUNT_2_OFF: IconName = IconName("discount-2-off") DISCOUNT_CHECK: IconName = IconName("discount-check") DISCOUNT_CHECK_FILLED: IconName = IconName("discount-check-filled") DISCOUNT_OFF: IconName = IconName("discount-off") DIVIDE: IconName = IconName("divide") DNA: IconName = IconName("dna") DNA_2: IconName = IconName("dna-2") DNA_2_OFF: IconName = IconName("dna-2-off") DNA_OFF: IconName = IconName("dna-off") DOG: IconName = IconName("dog") DOG_BOWL: IconName = IconName("dog-bowl") DOOR: IconName = IconName("door") DOOR_ENTER: IconName = IconName("door-enter") DOOR_EXIT: IconName = IconName("door-exit") DOOR_OFF: IconName = IconName("door-off") DOTS: IconName = IconName("dots") DOTS_CIRCLE_HORIZONTAL: IconName = IconName("dots-circle-horizontal") DOTS_DIAGONAL: IconName = IconName("dots-diagonal") DOTS_DIAGONAL_2: IconName = IconName("dots-diagonal-2") DOTS_VERTICAL: IconName = IconName("dots-vertical") DOWNLOAD: IconName = IconName("download") DOWNLOAD_OFF: IconName = IconName("download-off") DRAG_DROP: IconName = IconName("drag-drop") DRAG_DROP_2: IconName = IconName("drag-drop-2") DRONE: IconName = IconName("drone") DRONE_OFF: IconName = IconName("drone-off") DROP_CIRCLE: IconName = IconName("drop-circle") DROPLET: IconName = IconName("droplet") DROPLET_BOLT: IconName = IconName("droplet-bolt") DROPLET_CANCEL: IconName = 
IconName("droplet-cancel") DROPLET_CHECK: IconName = IconName("droplet-check") DROPLET_CODE: IconName = IconName("droplet-code") DROPLET_COG: IconName = IconName("droplet-cog") DROPLET_DOLLAR: IconName = IconName("droplet-dollar") DROPLET_DOWN: IconName = IconName("droplet-down") DROPLET_EXCLAMATION: IconName = IconName("droplet-exclamation") DROPLET_FILLED: IconName = IconName("droplet-filled") DROPLET_FILLED_2: IconName = IconName("droplet-filled-2") DROPLET_HALF: IconName = IconName("droplet-half") DROPLET_HALF_2: IconName = IconName("droplet-half-2") DROPLET_HALF_FILLED: IconName = IconName("droplet-half-filled") DROPLET_HEART: IconName = IconName("droplet-heart") DROPLET_MINUS: IconName = IconName("droplet-minus") DROPLET_OFF: IconName = IconName("droplet-off") DROPLET_PAUSE: IconName = IconName("droplet-pause") DROPLET_PIN: IconName = IconName("droplet-pin") DROPLET_PLUS: IconName = IconName("droplet-plus") DROPLET_QUESTION: IconName = IconName("droplet-question") DROPLET_SEARCH: IconName = IconName("droplet-search") DROPLET_SHARE: IconName = IconName("droplet-share") DROPLET_STAR: IconName = IconName("droplet-star") DROPLET_UP: IconName = IconName("droplet-up") DROPLET_X: IconName = IconName("droplet-x") DUAL_SCREEN: IconName = IconName("dual-screen") E_PASSPORT: IconName = IconName("e-passport") EAR: IconName = IconName("ear") EAR_OFF: IconName = IconName("ear-off") EASE_IN: IconName = IconName("ease-in") EASE_IN_CONTROL_POINT: IconName = IconName("ease-in-control-point") EASE_IN_OUT: IconName = IconName("ease-in-out") EASE_IN_OUT_CONTROL_POINTS: IconName = IconName("ease-in-out-control-points") EASE_OUT: IconName = IconName("ease-out") EASE_OUT_CONTROL_POINT: IconName = IconName("ease-out-control-point") EDIT: IconName = IconName("edit") EDIT_CIRCLE: IconName = IconName("edit-circle") EDIT_CIRCLE_OFF: IconName = IconName("edit-circle-off") EDIT_OFF: IconName = IconName("edit-off") EGG: IconName = IconName("egg") EGG_CRACKED: IconName = IconName("egg-cracked") EGG_FILLED: IconName = IconName("egg-filled") EGG_FRIED: IconName = IconName("egg-fried") EGG_OFF: IconName = IconName("egg-off") EGGS: IconName = IconName("eggs") ELEVATOR: IconName = IconName("elevator") ELEVATOR_OFF: IconName = IconName("elevator-off") EMERGENCY_BED: IconName = IconName("emergency-bed") EMPATHIZE: IconName = IconName("empathize") EMPATHIZE_OFF: IconName = IconName("empathize-off") EMPHASIS: IconName = IconName("emphasis") ENGINE: IconName = IconName("engine") ENGINE_OFF: IconName = IconName("engine-off") EQUAL: IconName = IconName("equal") EQUAL_DOUBLE: IconName = IconName("equal-double") EQUAL_NOT: IconName = IconName("equal-not") ERASER: IconName = IconName("eraser") ERASER_OFF: IconName = IconName("eraser-off") ERROR_404: IconName = IconName("error-404") ERROR_404_OFF: IconName = IconName("error-404-off") EXCHANGE: IconName = IconName("exchange") EXCHANGE_OFF: IconName = IconName("exchange-off") EXCLAMATION_CIRCLE: IconName = IconName("exclamation-circle") EXCLAMATION_MARK: IconName = IconName("exclamation-mark") EXCLAMATION_MARK_OFF: IconName = IconName("exclamation-mark-off") EXPLICIT: IconName = IconName("explicit") EXPLICIT_OFF: IconName = IconName("explicit-off") EXPOSURE: IconName = IconName("exposure") EXPOSURE_0: IconName = IconName("exposure-0") EXPOSURE_MINUS_1: IconName = IconName("exposure-minus-1") EXPOSURE_MINUS_2: IconName = IconName("exposure-minus-2") EXPOSURE_OFF: IconName = IconName("exposure-off") EXPOSURE_PLUS_1: IconName = IconName("exposure-plus-1") EXPOSURE_PLUS_2: IconName = 
IconName("exposure-plus-2") EXTERNAL_LINK: IconName = IconName("external-link") EXTERNAL_LINK_OFF: IconName = IconName("external-link-off") EYE: IconName = IconName("eye") EYE_CHECK: IconName = IconName("eye-check") EYE_CLOSED: IconName = IconName("eye-closed") EYE_COG: IconName = IconName("eye-cog") EYE_EDIT: IconName = IconName("eye-edit") EYE_EXCLAMATION: IconName = IconName("eye-exclamation") EYE_FILLED: IconName = IconName("eye-filled") EYE_HEART: IconName = IconName("eye-heart") EYE_OFF: IconName = IconName("eye-off") EYE_TABLE: IconName = IconName("eye-table") EYE_X: IconName = IconName("eye-x") EYEGLASS: IconName = IconName("eyeglass") EYEGLASS_2: IconName = IconName("eyeglass-2") EYEGLASS_OFF: IconName = IconName("eyeglass-off") FACE_ID: IconName = IconName("face-id") FACE_ID_ERROR: IconName = IconName("face-id-error") FACE_MASK: IconName = IconName("face-mask") FACE_MASK_OFF: IconName = IconName("face-mask-off") FALL: IconName = IconName("fall") FEATHER: IconName = IconName("feather") FEATHER_OFF: IconName = IconName("feather-off") FENCE: IconName = IconName("fence") FENCE_OFF: IconName = IconName("fence-off") FIDGET_SPINNER: IconName = IconName("fidget-spinner") FILE: IconName = IconName("file") FILE_3D: IconName = IconName("file-3d") FILE_ALERT: IconName = IconName("file-alert") FILE_ANALYTICS: IconName = IconName("file-analytics") FILE_ARROW_LEFT: IconName = IconName("file-arrow-left") FILE_ARROW_RIGHT: IconName = IconName("file-arrow-right") FILE_BARCODE: IconName = IconName("file-barcode") FILE_BROKEN: IconName = IconName("file-broken") FILE_CERTIFICATE: IconName = IconName("file-certificate") FILE_CHART: IconName = IconName("file-chart") FILE_CHECK: IconName = IconName("file-check") FILE_CODE: IconName = IconName("file-code") FILE_CODE_2: IconName = IconName("file-code-2") FILE_CV: IconName = IconName("file-cv") FILE_DATABASE: IconName = IconName("file-database") FILE_DELTA: IconName = IconName("file-delta") FILE_DESCRIPTION: IconName = IconName("file-description") FILE_DIFF: IconName = IconName("file-diff") FILE_DIGIT: IconName = IconName("file-digit") FILE_DISLIKE: IconName = IconName("file-dislike") FILE_DOLLAR: IconName = IconName("file-dollar") FILE_DOTS: IconName = IconName("file-dots") FILE_DOWNLOAD: IconName = IconName("file-download") FILE_EURO: IconName = IconName("file-euro") FILE_EXPORT: IconName = IconName("file-export") FILE_FILLED: IconName = IconName("file-filled") FILE_FUNCTION: IconName = IconName("file-function") FILE_HORIZONTAL: IconName = IconName("file-horizontal") FILE_IMPORT: IconName = IconName("file-import") FILE_INFINITY: IconName = IconName("file-infinity") FILE_INFO: IconName = IconName("file-info") FILE_INVOICE: IconName = IconName("file-invoice") FILE_LAMBDA: IconName = IconName("file-lambda") FILE_LIKE: IconName = IconName("file-like") FILE_MINUS: IconName = IconName("file-minus") FILE_MUSIC: IconName = IconName("file-music") FILE_OFF: IconName = IconName("file-off") FILE_ORIENTATION: IconName = IconName("file-orientation") FILE_PENCIL: IconName = IconName("file-pencil") FILE_PERCENT: IconName = IconName("file-percent") FILE_PHONE: IconName = IconName("file-phone") FILE_PLUS: IconName = IconName("file-plus") FILE_POWER: IconName = IconName("file-power") FILE_REPORT: IconName = IconName("file-report") FILE_RSS: IconName = IconName("file-rss") FILE_SCISSORS: IconName = IconName("file-scissors") FILE_SEARCH: IconName = IconName("file-search") FILE_SETTINGS: IconName = IconName("file-settings") FILE_SHREDDER: IconName = IconName("file-shredder") 
FILE_SIGNAL: IconName = IconName("file-signal") FILE_SPREADSHEET: IconName = IconName("file-spreadsheet") FILE_STACK: IconName = IconName("file-stack") FILE_STAR: IconName = IconName("file-star") FILE_SYMLINK: IconName = IconName("file-symlink") FILE_TEXT: IconName = IconName("file-text") FILE_TEXT_AI: IconName = IconName("file-text-ai") FILE_TIME: IconName = IconName("file-time") FILE_TYPOGRAPHY: IconName = IconName("file-typography") FILE_UNKNOWN: IconName = IconName("file-unknown") FILE_UPLOAD: IconName = IconName("file-upload") FILE_VECTOR: IconName = IconName("file-vector") FILE_X: IconName = IconName("file-x") FILE_X_FILLED: IconName = IconName("file-x-filled") FILE_ZIP: IconName = IconName("file-zip") FILES: IconName = IconName("files") FILES_OFF: IconName = IconName("files-off") FILTER: IconName = IconName("filter") FILTER_COG: IconName = IconName("filter-cog") FILTER_DOLLAR: IconName = IconName("filter-dollar") FILTER_EDIT: IconName = IconName("filter-edit") FILTER_MINUS: IconName = IconName("filter-minus") FILTER_OFF: IconName = IconName("filter-off") FILTER_PLUS: IconName = IconName("filter-plus") FILTER_STAR: IconName = IconName("filter-star") FILTER_X: IconName = IconName("filter-x") FILTERS: IconName = IconName("filters") FINGERPRINT: IconName = IconName("fingerprint") FINGERPRINT_OFF: IconName = IconName("fingerprint-off") FIRE_EXTINGUISHER: IconName = IconName("fire-extinguisher") FIRE_HYDRANT: IconName = IconName("fire-hydrant") FIRE_HYDRANT_OFF: IconName = IconName("fire-hydrant-off") FIRETRUCK: IconName = IconName("firetruck") FIRST_AID_KIT: IconName = IconName("first-aid-kit") FIRST_AID_KIT_OFF: IconName = IconName("first-aid-kit-off") FISH: IconName = IconName("fish") FISH_BONE: IconName = IconName("fish-bone") FISH_CHRISTIANITY: IconName = IconName("fish-christianity") FISH_HOOK: IconName = IconName("fish-hook") FISH_HOOK_OFF: IconName = IconName("fish-hook-off") FISH_OFF: IconName = IconName("fish-off") FLAG: IconName = IconName("flag") FLAG_2: IconName = IconName("flag-2") FLAG_2_FILLED: IconName = IconName("flag-2-filled") FLAG_2_OFF: IconName = IconName("flag-2-off") FLAG_3: IconName = IconName("flag-3") FLAG_3_FILLED: IconName = IconName("flag-3-filled") FLAG_FILLED: IconName = IconName("flag-filled") FLAG_OFF: IconName = IconName("flag-off") FLAME: IconName = IconName("flame") FLAME_OFF: IconName = IconName("flame-off") FLARE: IconName = IconName("flare") FLASK: IconName = IconName("flask") FLASK_2: IconName = IconName("flask-2") FLASK_2_OFF: IconName = IconName("flask-2-off") FLASK_OFF: IconName = IconName("flask-off") FLIP_FLOPS: IconName = IconName("flip-flops") FLIP_HORIZONTAL: IconName = IconName("flip-horizontal") FLIP_VERTICAL: IconName = IconName("flip-vertical") FLOAT_CENTER: IconName = IconName("float-center") FLOAT_LEFT: IconName = IconName("float-left") FLOAT_NONE: IconName = IconName("float-none") FLOAT_RIGHT: IconName = IconName("float-right") FLOWER: IconName = IconName("flower") FLOWER_OFF: IconName = IconName("flower-off") FOCUS: IconName = IconName("focus") FOCUS_2: IconName = IconName("focus-2") FOCUS_AUTO: IconName = IconName("focus-auto") FOCUS_CENTERED: IconName = IconName("focus-centered") FOLD: IconName = IconName("fold") FOLD_DOWN: IconName = IconName("fold-down") FOLD_UP: IconName = IconName("fold-up") FOLDER: IconName = IconName("folder") FOLDER_BOLT: IconName = IconName("folder-bolt") FOLDER_CANCEL: IconName = IconName("folder-cancel") FOLDER_CHECK: IconName = IconName("folder-check") FOLDER_CODE: IconName = IconName("folder-code") 
FOLDER_COG: IconName = IconName("folder-cog") FOLDER_DOLLAR: IconName = IconName("folder-dollar") FOLDER_DOWN: IconName = IconName("folder-down") FOLDER_EXCLAMATION: IconName = IconName("folder-exclamation") FOLDER_FILLED: IconName = IconName("folder-filled") FOLDER_HEART: IconName = IconName("folder-heart") FOLDER_MINUS: IconName = IconName("folder-minus") FOLDER_OFF: IconName = IconName("folder-off") FOLDER_OPEN: IconName = IconName("folder-open") FOLDER_PAUSE: IconName = IconName("folder-pause") FOLDER_PIN: IconName = IconName("folder-pin") FOLDER_PLUS: IconName = IconName("folder-plus") FOLDER_QUESTION: IconName = IconName("folder-question") FOLDER_SEARCH: IconName = IconName("folder-search") FOLDER_SHARE: IconName = IconName("folder-share") FOLDER_STAR: IconName = IconName("folder-star") FOLDER_SYMLINK: IconName = IconName("folder-symlink") FOLDER_UP: IconName = IconName("folder-up") FOLDER_X: IconName = IconName("folder-x") FOLDERS: IconName = IconName("folders") FOLDERS_OFF: IconName = IconName("folders-off") FORBID: IconName = IconName("forbid") FORBID_2: IconName = IconName("forbid-2") FORKLIFT: IconName = IconName("forklift") FORMS: IconName = IconName("forms") FOUNTAIN: IconName = IconName("fountain") FOUNTAIN_OFF: IconName = IconName("fountain-off") FRAME: IconName = IconName("frame") FRAME_OFF: IconName = IconName("frame-off") FREE_RIGHTS: IconName = IconName("free-rights") FREEZE_COLUMN: IconName = IconName("freeze-column") FREEZE_ROW: IconName = IconName("freeze-row") FREEZE_ROW_COLUMN: IconName = IconName("freeze-row-column") FRIDGE: IconName = IconName("fridge") FRIDGE_OFF: IconName = IconName("fridge-off") FRIENDS: IconName = IconName("friends") FRIENDS_OFF: IconName = IconName("friends-off") FRUSTUM: IconName = IconName("frustum") FRUSTUM_OFF: IconName = IconName("frustum-off") FRUSTUM_PLUS: IconName = IconName("frustum-plus") FUNCTION: IconName = IconName("function") FUNCTION_OFF: IconName = IconName("function-off") GARDEN_CART: IconName = IconName("garden-cart") GARDEN_CART_OFF: IconName = IconName("garden-cart-off") GAS_STATION: IconName = IconName("gas-station") GAS_STATION_OFF: IconName = IconName("gas-station-off") GAUGE: IconName = IconName("gauge") GAUGE_OFF: IconName = IconName("gauge-off") GAVEL: IconName = IconName("gavel") GENDER_AGENDER: IconName = IconName("gender-agender") GENDER_ANDROGYNE: IconName = IconName("gender-androgyne") GENDER_BIGENDER: IconName = IconName("gender-bigender") GENDER_DEMIBOY: IconName = IconName("gender-demiboy") GENDER_DEMIGIRL: IconName = IconName("gender-demigirl") GENDER_EPICENE: IconName = IconName("gender-epicene") GENDER_FEMALE: IconName = IconName("gender-female") GENDER_FEMME: IconName = IconName("gender-femme") GENDER_GENDERFLUID: IconName = IconName("gender-genderfluid") GENDER_GENDERLESS: IconName = IconName("gender-genderless") GENDER_GENDERQUEER: IconName = IconName("gender-genderqueer") GENDER_HERMAPHRODITE: IconName = IconName("gender-hermaphrodite") GENDER_INTERGENDER: IconName = IconName("gender-intergender") GENDER_MALE: IconName = IconName("gender-male") GENDER_NEUTROIS: IconName = IconName("gender-neutrois") GENDER_THIRD: IconName = IconName("gender-third") GENDER_TRANSGENDER: IconName = IconName("gender-transgender") GENDER_TRASVESTI: IconName = IconName("gender-trasvesti") GEOMETRY: IconName = IconName("geometry") GHOST: IconName = IconName("ghost") GHOST_2: IconName = IconName("ghost-2") GHOST_2_FILLED: IconName = IconName("ghost-2-filled") GHOST_FILLED: IconName = IconName("ghost-filled") GHOST_OFF: 
IconName = IconName("ghost-off") GIF: IconName = IconName("gif") GIFT: IconName = IconName("gift") GIFT_CARD: IconName = IconName("gift-card") GIFT_OFF: IconName = IconName("gift-off") GIT_BRANCH: IconName = IconName("git-branch") GIT_BRANCH_DELETED: IconName = IconName("git-branch-deleted") GIT_CHERRY_PICK: IconName = IconName("git-cherry-pick") GIT_COMMIT: IconName = IconName("git-commit") GIT_COMPARE: IconName = IconName("git-compare") GIT_FORK: IconName = IconName("git-fork") GIT_MERGE: IconName = IconName("git-merge") GIT_PULL_REQUEST: IconName = IconName("git-pull-request") GIT_PULL_REQUEST_CLOSED: IconName = IconName("git-pull-request-closed") GIT_PULL_REQUEST_DRAFT: IconName = IconName("git-pull-request-draft") GIZMO: IconName = IconName("gizmo") GLASS: IconName = IconName("glass") GLASS_FULL: IconName = IconName("glass-full") GLASS_OFF: IconName = IconName("glass-off") GLOBE: IconName = IconName("globe") GLOBE_OFF: IconName = IconName("globe-off") GO_GAME: IconName = IconName("go-game") GOLF: IconName = IconName("golf") GOLF_OFF: IconName = IconName("golf-off") GPS: IconName = IconName("gps") GRADIENTER: IconName = IconName("gradienter") GRAIN: IconName = IconName("grain") GRAPH: IconName = IconName("graph") GRAPH_OFF: IconName = IconName("graph-off") GRAVE: IconName = IconName("grave") GRAVE_2: IconName = IconName("grave-2") GRID_DOTS: IconName = IconName("grid-dots") GRID_PATTERN: IconName = IconName("grid-pattern") GRILL: IconName = IconName("grill") GRILL_FORK: IconName = IconName("grill-fork") GRILL_OFF: IconName = IconName("grill-off") GRILL_SPATULA: IconName = IconName("grill-spatula") GRIP_HORIZONTAL: IconName = IconName("grip-horizontal") GRIP_VERTICAL: IconName = IconName("grip-vertical") GROWTH: IconName = IconName("growth") GUITAR_PICK: IconName = IconName("guitar-pick") GUITAR_PICK_FILLED: IconName = IconName("guitar-pick-filled") H_1: IconName = IconName("h-1") H_2: IconName = IconName("h-2") H_3: IconName = IconName("h-3") H_4: IconName = IconName("h-4") H_5: IconName = IconName("h-5") H_6: IconName = IconName("h-6") HAMMER: IconName = IconName("hammer") HAMMER_OFF: IconName = IconName("hammer-off") HAND_CLICK: IconName = IconName("hand-click") HAND_FINGER: IconName = IconName("hand-finger") HAND_FINGER_OFF: IconName = IconName("hand-finger-off") HAND_GRAB: IconName = IconName("hand-grab") HAND_LITTLE_FINGER: IconName = IconName("hand-little-finger") HAND_MIDDLE_FINGER: IconName = IconName("hand-middle-finger") HAND_MOVE: IconName = IconName("hand-move") HAND_OFF: IconName = IconName("hand-off") HAND_RING_FINGER: IconName = IconName("hand-ring-finger") HAND_ROCK: IconName = IconName("hand-rock") HAND_SANITIZER: IconName = IconName("hand-sanitizer") HAND_STOP: IconName = IconName("hand-stop") HAND_THREE_FINGERS: IconName = IconName("hand-three-fingers") HAND_TWO_FINGERS: IconName = IconName("hand-two-fingers") HANGER: IconName = IconName("hanger") HANGER_2: IconName = IconName("hanger-2") HANGER_OFF: IconName = IconName("hanger-off") HASH: IconName = IconName("hash") HAZE: IconName = IconName("haze") HAZE_MOON: IconName = IconName("haze-moon") HDR: IconName = IconName("hdr") HEADING: IconName = IconName("heading") HEADING_OFF: IconName = IconName("heading-off") HEADPHONES: IconName = IconName("headphones") HEADPHONES_FILLED: IconName = IconName("headphones-filled") HEADPHONES_OFF: IconName = IconName("headphones-off") HEADSET: IconName = IconName("headset") HEADSET_OFF: IconName = IconName("headset-off") HEALTH_RECOGNITION: IconName = IconName("health-recognition") 
HEART: IconName = IconName("heart") HEART_BROKEN: IconName = IconName("heart-broken") HEART_FILLED: IconName = IconName("heart-filled") HEART_HANDSHAKE: IconName = IconName("heart-handshake") HEART_MINUS: IconName = IconName("heart-minus") HEART_OFF: IconName = IconName("heart-off") HEART_PLUS: IconName = IconName("heart-plus") HEART_RATE_MONITOR: IconName = IconName("heart-rate-monitor") HEARTBEAT: IconName = IconName("heartbeat") HEARTS: IconName = IconName("hearts") HEARTS_OFF: IconName = IconName("hearts-off") HELICOPTER: IconName = IconName("helicopter") HELICOPTER_LANDING: IconName = IconName("helicopter-landing") HELMET: IconName = IconName("helmet") HELMET_OFF: IconName = IconName("helmet-off") HELP: IconName = IconName("help") HELP_CIRCLE: IconName = IconName("help-circle") HELP_CIRCLE_FILLED: IconName = IconName("help-circle-filled") HELP_HEXAGON: IconName = IconName("help-hexagon") HELP_HEXAGON_FILLED: IconName = IconName("help-hexagon-filled") HELP_OCTAGON: IconName = IconName("help-octagon") HELP_OCTAGON_FILLED: IconName = IconName("help-octagon-filled") HELP_OFF: IconName = IconName("help-off") HELP_SMALL: IconName = IconName("help-small") HELP_SQUARE: IconName = IconName("help-square") HELP_SQUARE_FILLED: IconName = IconName("help-square-filled") HELP_SQUARE_ROUNDED: IconName = IconName("help-square-rounded") HELP_SQUARE_ROUNDED_FILLED: IconName = IconName("help-square-rounded-filled") HELP_TRIANGLE: IconName = IconName("help-triangle") HELP_TRIANGLE_FILLED: IconName = IconName("help-triangle-filled") HEMISPHERE: IconName = IconName("hemisphere") HEMISPHERE_OFF: IconName = IconName("hemisphere-off") HEMISPHERE_PLUS: IconName = IconName("hemisphere-plus") HEXAGON: IconName = IconName("hexagon") HEXAGON_0_FILLED: IconName = IconName("hexagon-0-filled") HEXAGON_1_FILLED: IconName = IconName("hexagon-1-filled") HEXAGON_2_FILLED: IconName = IconName("hexagon-2-filled") HEXAGON_3_FILLED: IconName = IconName("hexagon-3-filled") HEXAGON_3D: IconName = IconName("hexagon-3d") HEXAGON_4_FILLED: IconName = IconName("hexagon-4-filled") HEXAGON_5_FILLED: IconName = IconName("hexagon-5-filled") HEXAGON_6_FILLED: IconName = IconName("hexagon-6-filled") HEXAGON_7_FILLED: IconName = IconName("hexagon-7-filled") HEXAGON_8_FILLED: IconName = IconName("hexagon-8-filled") HEXAGON_9_FILLED: IconName = IconName("hexagon-9-filled") HEXAGON_FILLED: IconName = IconName("hexagon-filled") HEXAGON_LETTER_A: IconName = IconName("hexagon-letter-a") HEXAGON_LETTER_B: IconName = IconName("hexagon-letter-b") HEXAGON_LETTER_C: IconName = IconName("hexagon-letter-c") HEXAGON_LETTER_D: IconName = IconName("hexagon-letter-d") HEXAGON_LETTER_E: IconName = IconName("hexagon-letter-e") HEXAGON_LETTER_F: IconName = IconName("hexagon-letter-f") HEXAGON_LETTER_G: IconName = IconName("hexagon-letter-g") HEXAGON_LETTER_H: IconName = IconName("hexagon-letter-h") HEXAGON_LETTER_I: IconName = IconName("hexagon-letter-i") HEXAGON_LETTER_J: IconName = IconName("hexagon-letter-j") HEXAGON_LETTER_K: IconName = IconName("hexagon-letter-k") HEXAGON_LETTER_L: IconName = IconName("hexagon-letter-l") HEXAGON_LETTER_M: IconName = IconName("hexagon-letter-m") HEXAGON_LETTER_N: IconName = IconName("hexagon-letter-n") HEXAGON_LETTER_O: IconName = IconName("hexagon-letter-o") HEXAGON_LETTER_P: IconName = IconName("hexagon-letter-p") HEXAGON_LETTER_Q: IconName = IconName("hexagon-letter-q") HEXAGON_LETTER_R: IconName = IconName("hexagon-letter-r") HEXAGON_LETTER_S: IconName = IconName("hexagon-letter-s") HEXAGON_LETTER_T: IconName = 
IconName("hexagon-letter-t") HEXAGON_LETTER_U: IconName = IconName("hexagon-letter-u") HEXAGON_LETTER_V: IconName = IconName("hexagon-letter-v") HEXAGON_LETTER_W: IconName = IconName("hexagon-letter-w") HEXAGON_LETTER_X: IconName = IconName("hexagon-letter-x") HEXAGON_LETTER_Y: IconName = IconName("hexagon-letter-y") HEXAGON_LETTER_Z: IconName = IconName("hexagon-letter-z") HEXAGON_NUMBER_0: IconName = IconName("hexagon-number-0") HEXAGON_NUMBER_1: IconName = IconName("hexagon-number-1") HEXAGON_NUMBER_2: IconName = IconName("hexagon-number-2") HEXAGON_NUMBER_3: IconName = IconName("hexagon-number-3") HEXAGON_NUMBER_4: IconName = IconName("hexagon-number-4") HEXAGON_NUMBER_5: IconName = IconName("hexagon-number-5") HEXAGON_NUMBER_6: IconName = IconName("hexagon-number-6") HEXAGON_NUMBER_7: IconName = IconName("hexagon-number-7") HEXAGON_NUMBER_8: IconName = IconName("hexagon-number-8") HEXAGON_NUMBER_9: IconName = IconName("hexagon-number-9") HEXAGON_OFF: IconName = IconName("hexagon-off") HEXAGONAL_PRISM: IconName = IconName("hexagonal-prism") HEXAGONAL_PRISM_OFF: IconName = IconName("hexagonal-prism-off") HEXAGONAL_PRISM_PLUS: IconName = IconName("hexagonal-prism-plus") HEXAGONAL_PYRAMID: IconName = IconName("hexagonal-pyramid") HEXAGONAL_PYRAMID_OFF: IconName = IconName("hexagonal-pyramid-off") HEXAGONAL_PYRAMID_PLUS: IconName = IconName("hexagonal-pyramid-plus") HEXAGONS: IconName = IconName("hexagons") HEXAGONS_OFF: IconName = IconName("hexagons-off") HIERARCHY: IconName = IconName("hierarchy") HIERARCHY_2: IconName = IconName("hierarchy-2") HIERARCHY_3: IconName = IconName("hierarchy-3") HIERARCHY_OFF: IconName = IconName("hierarchy-off") HIGHLIGHT: IconName = IconName("highlight") HIGHLIGHT_OFF: IconName = IconName("highlight-off") HISTORY: IconName = IconName("history") HISTORY_OFF: IconName = IconName("history-off") HISTORY_TOGGLE: IconName = IconName("history-toggle") HOME: IconName = IconName("home") HOME_2: IconName = IconName("home-2") HOME_BOLT: IconName = IconName("home-bolt") HOME_CANCEL: IconName = IconName("home-cancel") HOME_CHECK: IconName = IconName("home-check") HOME_COG: IconName = IconName("home-cog") HOME_DOLLAR: IconName = IconName("home-dollar") HOME_DOT: IconName = IconName("home-dot") HOME_DOWN: IconName = IconName("home-down") HOME_ECO: IconName = IconName("home-eco") HOME_EDIT: IconName = IconName("home-edit") HOME_EXCLAMATION: IconName = IconName("home-exclamation") HOME_HAND: IconName = IconName("home-hand") HOME_HEART: IconName = IconName("home-heart") HOME_INFINITY: IconName = IconName("home-infinity") HOME_LINK: IconName = IconName("home-link") HOME_MINUS: IconName = IconName("home-minus") HOME_MOVE: IconName = IconName("home-move") HOME_OFF: IconName = IconName("home-off") HOME_PLUS: IconName = IconName("home-plus") HOME_QUESTION: IconName = IconName("home-question") HOME_RIBBON: IconName = IconName("home-ribbon") HOME_SEARCH: IconName = IconName("home-search") HOME_SHARE: IconName = IconName("home-share") HOME_SHIELD: IconName = IconName("home-shield") HOME_SIGNAL: IconName = IconName("home-signal") HOME_STAR: IconName = IconName("home-star") HOME_STATS: IconName = IconName("home-stats") HOME_UP: IconName = IconName("home-up") HOME_X: IconName = IconName("home-x") HORSE_TOY: IconName = IconName("horse-toy") HOTEL_SERVICE: IconName = IconName("hotel-service") HOURGLASS: IconName = IconName("hourglass") HOURGLASS_EMPTY: IconName = IconName("hourglass-empty") HOURGLASS_FILLED: IconName = IconName("hourglass-filled") HOURGLASS_HIGH: IconName = 
IconName("hourglass-high") HOURGLASS_LOW: IconName = IconName("hourglass-low") HOURGLASS_OFF: IconName = IconName("hourglass-off") HTML: IconName = IconName("html") HTTP_CONNECT: IconName = IconName("http-connect") HTTP_DELETE: IconName = IconName("http-delete") HTTP_GET: IconName = IconName("http-get") HTTP_HEAD: IconName = IconName("http-head") HTTP_OPTIONS: IconName = IconName("http-options") HTTP_PATCH: IconName = IconName("http-patch") HTTP_POST: IconName = IconName("http-post") HTTP_PUT: IconName = IconName("http-put") HTTP_QUE: IconName = IconName("http-que") HTTP_TRACE: IconName = IconName("http-trace") ICE_CREAM: IconName = IconName("ice-cream") ICE_CREAM_2: IconName = IconName("ice-cream-2") ICE_CREAM_OFF: IconName = IconName("ice-cream-off") ICE_SKATING: IconName = IconName("ice-skating") ICONS: IconName = IconName("icons") ICONS_OFF: IconName = IconName("icons-off") ID: IconName = IconName("id") ID_BADGE: IconName = IconName("id-badge") ID_BADGE_2: IconName = IconName("id-badge-2") ID_BADGE_OFF: IconName = IconName("id-badge-off") ID_OFF: IconName = IconName("id-off") INBOX: IconName = IconName("inbox") INBOX_OFF: IconName = IconName("inbox-off") INDENT_DECREASE: IconName = IconName("indent-decrease") INDENT_INCREASE: IconName = IconName("indent-increase") INFINITY: IconName = IconName("infinity") INFINITY_OFF: IconName = IconName("infinity-off") INFO_CIRCLE: IconName = IconName("info-circle") INFO_CIRCLE_FILLED: IconName = IconName("info-circle-filled") INFO_HEXAGON: IconName = IconName("info-hexagon") INFO_HEXAGON_FILLED: IconName = IconName("info-hexagon-filled") INFO_OCTAGON: IconName = IconName("info-octagon") INFO_OCTAGON_FILLED: IconName = IconName("info-octagon-filled") INFO_SMALL: IconName = IconName("info-small") INFO_SQUARE: IconName = IconName("info-square") INFO_SQUARE_FILLED: IconName = IconName("info-square-filled") INFO_SQUARE_ROUNDED: IconName = IconName("info-square-rounded") INFO_SQUARE_ROUNDED_FILLED: IconName = IconName("info-square-rounded-filled") INFO_TRIANGLE: IconName = IconName("info-triangle") INFO_TRIANGLE_FILLED: IconName = IconName("info-triangle-filled") INNER_SHADOW_BOTTOM: IconName = IconName("inner-shadow-bottom") INNER_SHADOW_BOTTOM_FILLED: IconName = IconName("inner-shadow-bottom-filled") INNER_SHADOW_BOTTOM_LEFT: IconName = IconName("inner-shadow-bottom-left") INNER_SHADOW_BOTTOM_LEFT_FILLED: IconName = IconName( "inner-shadow-bottom-left-filled" ) INNER_SHADOW_BOTTOM_RIGHT: IconName = IconName("inner-shadow-bottom-right") INNER_SHADOW_BOTTOM_RIGHT_FILLED: IconName = IconName( "inner-shadow-bottom-right-filled" ) INNER_SHADOW_LEFT: IconName = IconName("inner-shadow-left") INNER_SHADOW_LEFT_FILLED: IconName = IconName("inner-shadow-left-filled") INNER_SHADOW_RIGHT: IconName = IconName("inner-shadow-right") INNER_SHADOW_RIGHT_FILLED: IconName = IconName("inner-shadow-right-filled") INNER_SHADOW_TOP: IconName = IconName("inner-shadow-top") INNER_SHADOW_TOP_FILLED: IconName = IconName("inner-shadow-top-filled") INNER_SHADOW_TOP_LEFT: IconName = IconName("inner-shadow-top-left") INNER_SHADOW_TOP_LEFT_FILLED: IconName = IconName("inner-shadow-top-left-filled") INNER_SHADOW_TOP_RIGHT: IconName = IconName("inner-shadow-top-right") INNER_SHADOW_TOP_RIGHT_FILLED: IconName = IconName("inner-shadow-top-right-filled") INPUT_SEARCH: IconName = IconName("input-search") IRONING: IconName = IconName("ironing") IRONING_1: IconName = IconName("ironing-1") IRONING_2: IconName = IconName("ironing-2") IRONING_3: IconName = IconName("ironing-3") IRONING_OFF: 
IconName = IconName("ironing-off") IRONING_STEAM: IconName = IconName("ironing-steam") IRONING_STEAM_OFF: IconName = IconName("ironing-steam-off") IRREGULAR_POLYHEDRON: IconName = IconName("irregular-polyhedron") IRREGULAR_POLYHEDRON_OFF: IconName = IconName("irregular-polyhedron-off") IRREGULAR_POLYHEDRON_PLUS: IconName = IconName("irregular-polyhedron-plus") ITALIC: IconName = IconName("italic") JACKET: IconName = IconName("jacket") JETPACK: IconName = IconName("jetpack") JEWISH_STAR: IconName = IconName("jewish-star") JEWISH_STAR_FILLED: IconName = IconName("jewish-star-filled") JPG: IconName = IconName("jpg") JSON: IconName = IconName("json") JUMP_ROPE: IconName = IconName("jump-rope") KARATE: IconName = IconName("karate") KAYAK: IconName = IconName("kayak") KERING: IconName = IconName("kering") KEY: IconName = IconName("key") KEY_OFF: IconName = IconName("key-off") KEYBOARD: IconName = IconName("keyboard") KEYBOARD_HIDE: IconName = IconName("keyboard-hide") KEYBOARD_OFF: IconName = IconName("keyboard-off") KEYBOARD_SHOW: IconName = IconName("keyboard-show") KEYFRAME: IconName = IconName("keyframe") KEYFRAME_ALIGN_CENTER: IconName = IconName("keyframe-align-center") KEYFRAME_ALIGN_HORIZONTAL: IconName = IconName("keyframe-align-horizontal") KEYFRAME_ALIGN_VERTICAL: IconName = IconName("keyframe-align-vertical") KEYFRAMES: IconName = IconName("keyframes") LADDER: IconName = IconName("ladder") LADDER_OFF: IconName = IconName("ladder-off") LAMBDA: IconName = IconName("lambda") LAMP: IconName = IconName("lamp") LAMP_2: IconName = IconName("lamp-2") LAMP_OFF: IconName = IconName("lamp-off") LANE: IconName = IconName("lane") LANGUAGE: IconName = IconName("language") LANGUAGE_HIRAGANA: IconName = IconName("language-hiragana") LANGUAGE_KATAKANA: IconName = IconName("language-katakana") LANGUAGE_OFF: IconName = IconName("language-off") LASSO: IconName = IconName("lasso") LASSO_OFF: IconName = IconName("lasso-off") LASSO_POLYGON: IconName = IconName("lasso-polygon") LAYERS_DIFFERENCE: IconName = IconName("layers-difference") LAYERS_INTERSECT: IconName = IconName("layers-intersect") LAYERS_INTERSECT_2: IconName = IconName("layers-intersect-2") LAYERS_LINKED: IconName = IconName("layers-linked") LAYERS_OFF: IconName = IconName("layers-off") LAYERS_SUBTRACT: IconName = IconName("layers-subtract") LAYERS_UNION: IconName = IconName("layers-union") LAYOUT: IconName = IconName("layout") LAYOUT_2: IconName = IconName("layout-2") LAYOUT_ALIGN_BOTTOM: IconName = IconName("layout-align-bottom") LAYOUT_ALIGN_CENTER: IconName = IconName("layout-align-center") LAYOUT_ALIGN_LEFT: IconName = IconName("layout-align-left") LAYOUT_ALIGN_MIDDLE: IconName = IconName("layout-align-middle") LAYOUT_ALIGN_RIGHT: IconName = IconName("layout-align-right") LAYOUT_ALIGN_TOP: IconName = IconName("layout-align-top") LAYOUT_BOARD: IconName = IconName("layout-board") LAYOUT_BOARD_SPLIT: IconName = IconName("layout-board-split") LAYOUT_BOTTOMBAR: IconName = IconName("layout-bottombar") LAYOUT_BOTTOMBAR_COLLAPSE: IconName = IconName("layout-bottombar-collapse") LAYOUT_BOTTOMBAR_EXPAND: IconName = IconName("layout-bottombar-expand") LAYOUT_CARDS: IconName = IconName("layout-cards") LAYOUT_COLLAGE: IconName = IconName("layout-collage") LAYOUT_COLUMNS: IconName = IconName("layout-columns") LAYOUT_DASHBOARD: IconName = IconName("layout-dashboard") LAYOUT_DISTRIBUTE_HORIZONTAL: IconName = IconName("layout-distribute-horizontal") LAYOUT_DISTRIBUTE_VERTICAL: IconName = IconName("layout-distribute-vertical") LAYOUT_GRID: IconName = 
IconName("layout-grid") LAYOUT_GRID_ADD: IconName = IconName("layout-grid-add") LAYOUT_GRID_REMOVE: IconName = IconName("layout-grid-remove") LAYOUT_KANBAN: IconName = IconName("layout-kanban") LAYOUT_LIST: IconName = IconName("layout-list") LAYOUT_NAVBAR: IconName = IconName("layout-navbar") LAYOUT_NAVBAR_COLLAPSE: IconName = IconName("layout-navbar-collapse") LAYOUT_NAVBAR_EXPAND: IconName = IconName("layout-navbar-expand") LAYOUT_OFF: IconName = IconName("layout-off") LAYOUT_ROWS: IconName = IconName("layout-rows") LAYOUT_SIDEBAR: IconName = IconName("layout-sidebar") LAYOUT_SIDEBAR_LEFT_COLLAPSE: IconName = IconName("layout-sidebar-left-collapse") LAYOUT_SIDEBAR_LEFT_EXPAND: IconName = IconName("layout-sidebar-left-expand") LAYOUT_SIDEBAR_RIGHT: IconName = IconName("layout-sidebar-right") LAYOUT_SIDEBAR_RIGHT_COLLAPSE: IconName = IconName("layout-sidebar-right-collapse") LAYOUT_SIDEBAR_RIGHT_EXPAND: IconName = IconName("layout-sidebar-right-expand") LEAF: IconName = IconName("leaf") LEAF_OFF: IconName = IconName("leaf-off") LEGO: IconName = IconName("lego") LEGO_OFF: IconName = IconName("lego-off") LEMON: IconName = IconName("lemon") LEMON_2: IconName = IconName("lemon-2") LETTER_A: IconName = IconName("letter-a") LETTER_B: IconName = IconName("letter-b") LETTER_C: IconName = IconName("letter-c") LETTER_CASE: IconName = IconName("letter-case") LETTER_CASE_LOWER: IconName = IconName("letter-case-lower") LETTER_CASE_TOGGLE: IconName = IconName("letter-case-toggle") LETTER_CASE_UPPER: IconName = IconName("letter-case-upper") LETTER_D: IconName = IconName("letter-d") LETTER_E: IconName = IconName("letter-e") LETTER_F: IconName = IconName("letter-f") LETTER_G: IconName = IconName("letter-g") LETTER_H: IconName = IconName("letter-h") LETTER_I: IconName = IconName("letter-i") LETTER_J: IconName = IconName("letter-j") LETTER_K: IconName = IconName("letter-k") LETTER_L: IconName = IconName("letter-l") LETTER_M: IconName = IconName("letter-m") LETTER_N: IconName = IconName("letter-n") LETTER_O: IconName = IconName("letter-o") LETTER_P: IconName = IconName("letter-p") LETTER_Q: IconName = IconName("letter-q") LETTER_R: IconName = IconName("letter-r") LETTER_S: IconName = IconName("letter-s") LETTER_SPACING: IconName = IconName("letter-spacing") LETTER_T: IconName = IconName("letter-t") LETTER_U: IconName = IconName("letter-u") LETTER_V: IconName = IconName("letter-v") LETTER_W: IconName = IconName("letter-w") LETTER_X: IconName = IconName("letter-x") LETTER_Y: IconName = IconName("letter-y") LETTER_Z: IconName = IconName("letter-z") LICENSE: IconName = IconName("license") LICENSE_OFF: IconName = IconName("license-off") LIFEBUOY: IconName = IconName("lifebuoy") LIFEBUOY_OFF: IconName = IconName("lifebuoy-off") LIGHTER: IconName = IconName("lighter") LINE: IconName = IconName("line") LINE_DASHED: IconName = IconName("line-dashed") LINE_DOTTED: IconName = IconName("line-dotted") LINE_HEIGHT: IconName = IconName("line-height") LINK: IconName = IconName("link") LINK_OFF: IconName = IconName("link-off") LIST: IconName = IconName("list") LIST_CHECK: IconName = IconName("list-check") LIST_DETAILS: IconName = IconName("list-details") LIST_NUMBERS: IconName = IconName("list-numbers") LIST_SEARCH: IconName = IconName("list-search") LIST_TREE: IconName = IconName("list-tree") LIVE_PHOTO: IconName = IconName("live-photo") LIVE_PHOTO_OFF: IconName = IconName("live-photo-off") LIVE_VIEW: IconName = IconName("live-view") LOAD_BALANCER: IconName = IconName("load-balancer") LOADER: IconName = IconName("loader") 
LOADER_2: IconName = IconName("loader-2") LOADER_3: IconName = IconName("loader-3") LOADER_QUARTER: IconName = IconName("loader-quarter") LOCATION: IconName = IconName("location") LOCATION_BROKEN: IconName = IconName("location-broken") LOCATION_FILLED: IconName = IconName("location-filled") LOCATION_OFF: IconName = IconName("location-off") LOCK: IconName = IconName("lock") LOCK_ACCESS: IconName = IconName("lock-access") LOCK_ACCESS_OFF: IconName = IconName("lock-access-off") LOCK_BOLT: IconName = IconName("lock-bolt") LOCK_CANCEL: IconName = IconName("lock-cancel") LOCK_CHECK: IconName = IconName("lock-check") LOCK_CODE: IconName = IconName("lock-code") LOCK_COG: IconName = IconName("lock-cog") LOCK_DOLLAR: IconName = IconName("lock-dollar") LOCK_DOWN: IconName = IconName("lock-down") LOCK_EXCLAMATION: IconName = IconName("lock-exclamation") LOCK_HEART: IconName = IconName("lock-heart") LOCK_MINUS: IconName = IconName("lock-minus") LOCK_OFF: IconName = IconName("lock-off") LOCK_OPEN: IconName = IconName("lock-open") LOCK_OPEN_OFF: IconName = IconName("lock-open-off") LOCK_PAUSE: IconName = IconName("lock-pause") LOCK_PIN: IconName = IconName("lock-pin") LOCK_PLUS: IconName = IconName("lock-plus") LOCK_QUESTION: IconName = IconName("lock-question") LOCK_SEARCH: IconName = IconName("lock-search") LOCK_SHARE: IconName = IconName("lock-share") LOCK_SQUARE: IconName = IconName("lock-square") LOCK_SQUARE_ROUNDED: IconName = IconName("lock-square-rounded") LOCK_SQUARE_ROUNDED_FILLED: IconName = IconName("lock-square-rounded-filled") LOCK_STAR: IconName = IconName("lock-star") LOCK_UP: IconName = IconName("lock-up") LOCK_X: IconName = IconName("lock-x") LOGIC_AND: IconName = IconName("logic-and") LOGIC_BUFFER: IconName = IconName("logic-buffer") LOGIC_NAND: IconName = IconName("logic-nand") LOGIC_NOR: IconName = IconName("logic-nor") LOGIC_NOT: IconName = IconName("logic-not") LOGIC_OR: IconName = IconName("logic-or") LOGIC_XNOR: IconName = IconName("logic-xnor") LOGIC_XOR: IconName = IconName("logic-xor") LOGIN: IconName = IconName("login") LOGOUT: IconName = IconName("logout") LOGOUT_2: IconName = IconName("logout-2") LOLLIPOP: IconName = IconName("lollipop") LOLLIPOP_OFF: IconName = IconName("lollipop-off") LUGGAGE: IconName = IconName("luggage") LUGGAGE_OFF: IconName = IconName("luggage-off") LUNGS: IconName = IconName("lungs") LUNGS_OFF: IconName = IconName("lungs-off") MACRO: IconName = IconName("macro") MACRO_OFF: IconName = IconName("macro-off") MAGNET: IconName = IconName("magnet") MAGNET_OFF: IconName = IconName("magnet-off") MAIL: IconName = IconName("mail") MAIL_AI: IconName = IconName("mail-ai") MAIL_BOLT: IconName = IconName("mail-bolt") MAIL_CANCEL: IconName = IconName("mail-cancel") MAIL_CHECK: IconName = IconName("mail-check") MAIL_CODE: IconName = IconName("mail-code") MAIL_COG: IconName = IconName("mail-cog") MAIL_DOLLAR: IconName = IconName("mail-dollar") MAIL_DOWN: IconName = IconName("mail-down") MAIL_EXCLAMATION: IconName = IconName("mail-exclamation") MAIL_FAST: IconName = IconName("mail-fast") MAIL_FILLED: IconName = IconName("mail-filled") MAIL_FORWARD: IconName = IconName("mail-forward") MAIL_HEART: IconName = IconName("mail-heart") MAIL_MINUS: IconName = IconName("mail-minus") MAIL_OFF: IconName = IconName("mail-off") MAIL_OPENED: IconName = IconName("mail-opened") MAIL_OPENED_FILLED: IconName = IconName("mail-opened-filled") MAIL_PAUSE: IconName = IconName("mail-pause") MAIL_PIN: IconName = IconName("mail-pin") MAIL_PLUS: IconName = IconName("mail-plus") MAIL_QUESTION: 
IconName = IconName("mail-question") MAIL_SEARCH: IconName = IconName("mail-search") MAIL_SHARE: IconName = IconName("mail-share") MAIL_STAR: IconName = IconName("mail-star") MAIL_UP: IconName = IconName("mail-up") MAIL_X: IconName = IconName("mail-x") MAILBOX: IconName = IconName("mailbox") MAILBOX_OFF: IconName = IconName("mailbox-off") MAN: IconName = IconName("man") MANUAL_GEARBOX: IconName = IconName("manual-gearbox") MAP: IconName = IconName("map") MAP_2: IconName = IconName("map-2") MAP_OFF: IconName = IconName("map-off") MAP_PIN: IconName = IconName("map-pin") MAP_PIN_BOLT: IconName = IconName("map-pin-bolt") MAP_PIN_CANCEL: IconName = IconName("map-pin-cancel") MAP_PIN_CHECK: IconName = IconName("map-pin-check") MAP_PIN_CODE: IconName = IconName("map-pin-code") MAP_PIN_COG: IconName = IconName("map-pin-cog") MAP_PIN_DOLLAR: IconName = IconName("map-pin-dollar") MAP_PIN_DOWN: IconName = IconName("map-pin-down") MAP_PIN_EXCLAMATION: IconName = IconName("map-pin-exclamation") MAP_PIN_FILLED: IconName = IconName("map-pin-filled") MAP_PIN_HEART: IconName = IconName("map-pin-heart") MAP_PIN_MINUS: IconName = IconName("map-pin-minus") MAP_PIN_OFF: IconName = IconName("map-pin-off") MAP_PIN_PAUSE: IconName = IconName("map-pin-pause") MAP_PIN_PIN: IconName = IconName("map-pin-pin") MAP_PIN_PLUS: IconName = IconName("map-pin-plus") MAP_PIN_QUESTION: IconName = IconName("map-pin-question") MAP_PIN_SEARCH: IconName = IconName("map-pin-search") MAP_PIN_SHARE: IconName = IconName("map-pin-share") MAP_PIN_STAR: IconName = IconName("map-pin-star") MAP_PIN_UP: IconName = IconName("map-pin-up") MAP_PIN_X: IconName = IconName("map-pin-x") MAP_PINS: IconName = IconName("map-pins") MAP_SEARCH: IconName = IconName("map-search") MARKDOWN: IconName = IconName("markdown") MARKDOWN_OFF: IconName = IconName("markdown-off") MARQUEE: IconName = IconName("marquee") MARQUEE_2: IconName = IconName("marquee-2") MARQUEE_OFF: IconName = IconName("marquee-off") MARS: IconName = IconName("mars") MASK: IconName = IconName("mask") MASK_OFF: IconName = IconName("mask-off") MASKS_THEATER: IconName = IconName("masks-theater") MASKS_THEATER_OFF: IconName = IconName("masks-theater-off") MASSAGE: IconName = IconName("massage") MATCHSTICK: IconName = IconName("matchstick") MATH: IconName = IconName("math") MATH_1_DIVIDE_2: IconName = IconName("math-1-divide-2") MATH_1_DIVIDE_3: IconName = IconName("math-1-divide-3") MATH_AVG: IconName = IconName("math-avg") MATH_EQUAL_GREATER: IconName = IconName("math-equal-greater") MATH_EQUAL_LOWER: IconName = IconName("math-equal-lower") MATH_FUNCTION: IconName = IconName("math-function") MATH_FUNCTION_OFF: IconName = IconName("math-function-off") MATH_FUNCTION_Y: IconName = IconName("math-function-y") MATH_GREATER: IconName = IconName("math-greater") MATH_INTEGRAL: IconName = IconName("math-integral") MATH_INTEGRAL_X: IconName = IconName("math-integral-x") MATH_INTEGRALS: IconName = IconName("math-integrals") MATH_LOWER: IconName = IconName("math-lower") MATH_MAX: IconName = IconName("math-max") MATH_MIN: IconName = IconName("math-min") MATH_NOT: IconName = IconName("math-not") MATH_OFF: IconName = IconName("math-off") MATH_PI: IconName = IconName("math-pi") MATH_PI_DIVIDE_2: IconName = IconName("math-pi-divide-2") MATH_SYMBOLS: IconName = IconName("math-symbols") MATH_X_DIVIDE_2: IconName = IconName("math-x-divide-2") MATH_X_DIVIDE_Y: IconName = IconName("math-x-divide-y") MATH_X_DIVIDE_Y_2: IconName = IconName("math-x-divide-y-2") MATH_X_MINUS_X: IconName = IconName("math-x-minus-x") 
MATH_X_MINUS_Y: IconName = IconName("math-x-minus-y") MATH_X_PLUS_X: IconName = IconName("math-x-plus-x") MATH_X_PLUS_Y: IconName = IconName("math-x-plus-y") MATH_XY: IconName = IconName("math-xy") MATH_Y_MINUS_Y: IconName = IconName("math-y-minus-y") MATH_Y_PLUS_Y: IconName = IconName("math-y-plus-y") MAXIMIZE: IconName = IconName("maximize") MAXIMIZE_OFF: IconName = IconName("maximize-off") MEAT: IconName = IconName("meat") MEAT_OFF: IconName = IconName("meat-off") MEDAL: IconName = IconName("medal") MEDAL_2: IconName = IconName("medal-2") MEDICAL_CROSS: IconName = IconName("medical-cross") MEDICAL_CROSS_CIRCLE: IconName = IconName("medical-cross-circle") MEDICAL_CROSS_FILLED: IconName = IconName("medical-cross-filled") MEDICAL_CROSS_OFF: IconName = IconName("medical-cross-off") MEDICINE_SYRUP: IconName = IconName("medicine-syrup") MEEPLE: IconName = IconName("meeple") MENORAH: IconName = IconName("menorah") MENU: IconName = IconName("menu") MENU_2: IconName = IconName("menu-2") MENU_DEEP: IconName = IconName("menu-deep") MENU_ORDER: IconName = IconName("menu-order") MESSAGE: IconName = IconName("message") MESSAGE_2: IconName = IconName("message-2") MESSAGE_2_BOLT: IconName = IconName("message-2-bolt") MESSAGE_2_CANCEL: IconName = IconName("message-2-cancel") MESSAGE_2_CHECK: IconName = IconName("message-2-check") MESSAGE_2_CODE: IconName = IconName("message-2-code") MESSAGE_2_COG: IconName = IconName("message-2-cog") MESSAGE_2_DOLLAR: IconName = IconName("message-2-dollar") MESSAGE_2_DOWN: IconName = IconName("message-2-down") MESSAGE_2_EXCLAMATION: IconName = IconName("message-2-exclamation") MESSAGE_2_HEART: IconName = IconName("message-2-heart") MESSAGE_2_MINUS: IconName = IconName("message-2-minus") MESSAGE_2_OFF: IconName = IconName("message-2-off") MESSAGE_2_PAUSE: IconName = IconName("message-2-pause") MESSAGE_2_PIN: IconName = IconName("message-2-pin") MESSAGE_2_PLUS: IconName = IconName("message-2-plus") MESSAGE_2_QUESTION: IconName = IconName("message-2-question") MESSAGE_2_SEARCH: IconName = IconName("message-2-search") MESSAGE_2_SHARE: IconName = IconName("message-2-share") MESSAGE_2_STAR: IconName = IconName("message-2-star") MESSAGE_2_UP: IconName = IconName("message-2-up") MESSAGE_2_X: IconName = IconName("message-2-x") MESSAGE_BOLT: IconName = IconName("message-bolt") MESSAGE_CANCEL: IconName = IconName("message-cancel") MESSAGE_CHATBOT: IconName = IconName("message-chatbot") MESSAGE_CHECK: IconName = IconName("message-check") MESSAGE_CIRCLE: IconName = IconName("message-circle") MESSAGE_CIRCLE_2: IconName = IconName("message-circle-2") MESSAGE_CIRCLE_2_FILLED: IconName = IconName("message-circle-2-filled") MESSAGE_CIRCLE_BOLT: IconName = IconName("message-circle-bolt") MESSAGE_CIRCLE_CANCEL: IconName = IconName("message-circle-cancel") MESSAGE_CIRCLE_CHECK: IconName = IconName("message-circle-check") MESSAGE_CIRCLE_CODE: IconName = IconName("message-circle-code") MESSAGE_CIRCLE_COG: IconName = IconName("message-circle-cog") MESSAGE_CIRCLE_DOLLAR: IconName = IconName("message-circle-dollar") MESSAGE_CIRCLE_DOWN: IconName = IconName("message-circle-down") MESSAGE_CIRCLE_EXCLAMATION: IconName = IconName("message-circle-exclamation") MESSAGE_CIRCLE_HEART: IconName = IconName("message-circle-heart") MESSAGE_CIRCLE_MINUS: IconName = IconName("message-circle-minus") MESSAGE_CIRCLE_OFF: IconName = IconName("message-circle-off") MESSAGE_CIRCLE_PAUSE: IconName = IconName("message-circle-pause") MESSAGE_CIRCLE_PIN: IconName = IconName("message-circle-pin") MESSAGE_CIRCLE_PLUS: 
IconName = IconName("message-circle-plus") MESSAGE_CIRCLE_QUESTION: IconName = IconName("message-circle-question") MESSAGE_CIRCLE_SEARCH: IconName = IconName("message-circle-search") MESSAGE_CIRCLE_SHARE: IconName = IconName("message-circle-share") MESSAGE_CIRCLE_STAR: IconName = IconName("message-circle-star") MESSAGE_CIRCLE_UP: IconName = IconName("message-circle-up") MESSAGE_CIRCLE_X: IconName = IconName("message-circle-x") MESSAGE_CODE: IconName = IconName("message-code") MESSAGE_COG: IconName = IconName("message-cog") MESSAGE_DOLLAR: IconName = IconName("message-dollar") MESSAGE_DOTS: IconName = IconName("message-dots") MESSAGE_DOWN: IconName = IconName("message-down") MESSAGE_EXCLAMATION: IconName = IconName("message-exclamation") MESSAGE_FORWARD: IconName = IconName("message-forward") MESSAGE_HEART: IconName = IconName("message-heart") MESSAGE_LANGUAGE: IconName = IconName("message-language") MESSAGE_MINUS: IconName = IconName("message-minus") MESSAGE_OFF: IconName = IconName("message-off") MESSAGE_PAUSE: IconName = IconName("message-pause") MESSAGE_PIN: IconName = IconName("message-pin") MESSAGE_PLUS: IconName = IconName("message-plus") MESSAGE_QUESTION: IconName = IconName("message-question") MESSAGE_REPORT: IconName = IconName("message-report") MESSAGE_SEARCH: IconName = IconName("message-search") MESSAGE_SHARE: IconName = IconName("message-share") MESSAGE_STAR: IconName = IconName("message-star") MESSAGE_UP: IconName = IconName("message-up") MESSAGE_X: IconName = IconName("message-x") MESSAGES: IconName = IconName("messages") MESSAGES_OFF: IconName = IconName("messages-off") METEOR: IconName = IconName("meteor") METEOR_OFF: IconName = IconName("meteor-off") MICHELIN_BIB_GOURMAND: IconName = IconName("michelin-bib-gourmand") MICHELIN_STAR: IconName = IconName("michelin-star") MICHELIN_STAR_GREEN: IconName = IconName("michelin-star-green") MICKEY: IconName = IconName("mickey") MICKEY_FILLED: IconName = IconName("mickey-filled") MICROPHONE: IconName = IconName("microphone") MICROPHONE_2: IconName = IconName("microphone-2") MICROPHONE_2_OFF: IconName = IconName("microphone-2-off") MICROPHONE_OFF: IconName = IconName("microphone-off") MICROSCOPE: IconName = IconName("microscope") MICROSCOPE_OFF: IconName = IconName("microscope-off") MICROWAVE: IconName = IconName("microwave") MICROWAVE_OFF: IconName = IconName("microwave-off") MILITARY_AWARD: IconName = IconName("military-award") MILITARY_RANK: IconName = IconName("military-rank") MILK: IconName = IconName("milk") MILK_OFF: IconName = IconName("milk-off") MILKSHAKE: IconName = IconName("milkshake") MINIMIZE: IconName = IconName("minimize") MINUS: IconName = IconName("minus") MINUS_VERTICAL: IconName = IconName("minus-vertical") MIST: IconName = IconName("mist") MIST_OFF: IconName = IconName("mist-off") MOBILEDATA: IconName = IconName("mobiledata") MOBILEDATA_OFF: IconName = IconName("mobiledata-off") MONEYBAG: IconName = IconName("moneybag") MOOD_ANGRY: IconName = IconName("mood-angry") MOOD_ANNOYED: IconName = IconName("mood-annoyed") MOOD_ANNOYED_2: IconName = IconName("mood-annoyed-2") MOOD_BOY: IconName = IconName("mood-boy") MOOD_CHECK: IconName = IconName("mood-check") MOOD_COG: IconName = IconName("mood-cog") MOOD_CONFUZED: IconName = IconName("mood-confuzed") MOOD_CONFUZED_FILLED: IconName = IconName("mood-confuzed-filled") MOOD_CRAZY_HAPPY: IconName = IconName("mood-crazy-happy") MOOD_CRY: IconName = IconName("mood-cry") MOOD_DOLLAR: IconName = IconName("mood-dollar") MOOD_EDIT: IconName = IconName("mood-edit") MOOD_EMPTY: 
IconName = IconName("mood-empty") MOOD_EMPTY_FILLED: IconName = IconName("mood-empty-filled") MOOD_HAPPY: IconName = IconName("mood-happy") MOOD_HAPPY_FILLED: IconName = IconName("mood-happy-filled") MOOD_HEART: IconName = IconName("mood-heart") MOOD_KID: IconName = IconName("mood-kid") MOOD_KID_FILLED: IconName = IconName("mood-kid-filled") MOOD_LOOK_LEFT: IconName = IconName("mood-look-left") MOOD_LOOK_RIGHT: IconName = IconName("mood-look-right") MOOD_MINUS: IconName = IconName("mood-minus") MOOD_NERD: IconName = IconName("mood-nerd") MOOD_NERVOUS: IconName = IconName("mood-nervous") MOOD_NEUTRAL: IconName = IconName("mood-neutral") MOOD_NEUTRAL_FILLED: IconName = IconName("mood-neutral-filled") MOOD_OFF: IconName = IconName("mood-off") MOOD_PIN: IconName = IconName("mood-pin") MOOD_PLUS: IconName = IconName("mood-plus") MOOD_SAD: IconName = IconName("mood-sad") MOOD_SAD_2: IconName = IconName("mood-sad-2") MOOD_SAD_DIZZY: IconName = IconName("mood-sad-dizzy") MOOD_SAD_FILLED: IconName = IconName("mood-sad-filled") MOOD_SAD_SQUINT: IconName = IconName("mood-sad-squint") MOOD_SEARCH: IconName = IconName("mood-search") MOOD_SHARE: IconName = IconName("mood-share") MOOD_SICK: IconName = IconName("mood-sick") MOOD_SILENCE: IconName = IconName("mood-silence") MOOD_SING: IconName = IconName("mood-sing") MOOD_SMILE: IconName = IconName("mood-smile") MOOD_SMILE_BEAM: IconName = IconName("mood-smile-beam") MOOD_SMILE_DIZZY: IconName = IconName("mood-smile-dizzy") MOOD_SMILE_FILLED: IconName = IconName("mood-smile-filled") MOOD_SUPRISED: IconName = IconName("mood-suprised") MOOD_TONGUE: IconName = IconName("mood-tongue") MOOD_TONGUE_WINK: IconName = IconName("mood-tongue-wink") MOOD_TONGUE_WINK_2: IconName = IconName("mood-tongue-wink-2") MOOD_UNAMUSED: IconName = IconName("mood-unamused") MOOD_UP: IconName = IconName("mood-up") MOOD_WINK: IconName = IconName("mood-wink") MOOD_WINK_2: IconName = IconName("mood-wink-2") MOOD_WRRR: IconName = IconName("mood-wrrr") MOOD_X: IconName = IconName("mood-x") MOOD_XD: IconName = IconName("mood-xd") MOON: IconName = IconName("moon") MOON_2: IconName = IconName("moon-2") MOON_FILLED: IconName = IconName("moon-filled") MOON_OFF: IconName = IconName("moon-off") MOON_STARS: IconName = IconName("moon-stars") MOPED: IconName = IconName("moped") MOTORBIKE: IconName = IconName("motorbike") MOUNTAIN: IconName = IconName("mountain") MOUNTAIN_OFF: IconName = IconName("mountain-off") MOUSE: IconName = IconName("mouse") MOUSE_2: IconName = IconName("mouse-2") MOUSE_OFF: IconName = IconName("mouse-off") MOUSTACHE: IconName = IconName("moustache") MOVIE: IconName = IconName("movie") MOVIE_OFF: IconName = IconName("movie-off") MUG: IconName = IconName("mug") MUG_OFF: IconName = IconName("mug-off") MULTIPLIER_0_5X: IconName = IconName("multiplier-0-5x") MULTIPLIER_1_5X: IconName = IconName("multiplier-1-5x") MULTIPLIER_1X: IconName = IconName("multiplier-1x") MULTIPLIER_2X: IconName = IconName("multiplier-2x") MUSHROOM: IconName = IconName("mushroom") MUSHROOM_FILLED: IconName = IconName("mushroom-filled") MUSHROOM_OFF: IconName = IconName("mushroom-off") MUSIC: IconName = IconName("music") MUSIC_OFF: IconName = IconName("music-off") NAVIGATION: IconName = IconName("navigation") NAVIGATION_FILLED: IconName = IconName("navigation-filled") NAVIGATION_NORTH: IconName = IconName("navigation-north") NAVIGATION_OFF: IconName = IconName("navigation-off") NEEDLE: IconName = IconName("needle") NEEDLE_THREAD: IconName = IconName("needle-thread") NETWORK: IconName = 
IconName("network") NETWORK_OFF: IconName = IconName("network-off") NEW_SECTION: IconName = IconName("new-section") NEWS: IconName = IconName("news") NEWS_OFF: IconName = IconName("news-off") NFC: IconName = IconName("nfc") NFC_OFF: IconName = IconName("nfc-off") NO_COPYRIGHT: IconName = IconName("no-copyright") NO_CREATIVE_COMMONS: IconName = IconName("no-creative-commons") NO_DERIVATIVES: IconName = IconName("no-derivatives") NORTH_STAR: IconName = IconName("north-star") NOTE: IconName = IconName("note") NOTE_OFF: IconName = IconName("note-off") NOTEBOOK: IconName = IconName("notebook") NOTEBOOK_OFF: IconName = IconName("notebook-off") NOTES: IconName = IconName("notes") NOTES_OFF: IconName = IconName("notes-off") NOTIFICATION: IconName = IconName("notification") NOTIFICATION_OFF: IconName = IconName("notification-off") NUMBER: IconName = IconName("number") NUMBER_0: IconName = IconName("number-0") NUMBER_1: IconName = IconName("number-1") NUMBER_2: IconName = IconName("number-2") NUMBER_3: IconName = IconName("number-3") NUMBER_4: IconName = IconName("number-4") NUMBER_5: IconName = IconName("number-5") NUMBER_6: IconName = IconName("number-6") NUMBER_7: IconName = IconName("number-7") NUMBER_8: IconName = IconName("number-8") NUMBER_9: IconName = IconName("number-9") NUMBERS: IconName = IconName("numbers") NURSE: IconName = IconName("nurse") OCTAGON: IconName = IconName("octagon") OCTAGON_FILLED: IconName = IconName("octagon-filled") OCTAGON_OFF: IconName = IconName("octagon-off") OCTAHEDRON: IconName = IconName("octahedron") OCTAHEDRON_OFF: IconName = IconName("octahedron-off") OCTAHEDRON_PLUS: IconName = IconName("octahedron-plus") OLD: IconName = IconName("old") OLYMPICS: IconName = IconName("olympics") OLYMPICS_OFF: IconName = IconName("olympics-off") OM: IconName = IconName("om") OMEGA: IconName = IconName("omega") OUTBOUND: IconName = IconName("outbound") OUTLET: IconName = IconName("outlet") OVAL: IconName = IconName("oval") OVAL_FILLED: IconName = IconName("oval-filled") OVAL_VERTICAL: IconName = IconName("oval-vertical") OVAL_VERTICAL_FILLED: IconName = IconName("oval-vertical-filled") OVERLINE: IconName = IconName("overline") PACKAGE: IconName = IconName("package") PACKAGE_EXPORT: IconName = IconName("package-export") PACKAGE_IMPORT: IconName = IconName("package-import") PACKAGE_OFF: IconName = IconName("package-off") PACKAGES: IconName = IconName("packages") PACMAN: IconName = IconName("pacman") PAGE_BREAK: IconName = IconName("page-break") PAINT: IconName = IconName("paint") PAINT_FILLED: IconName = IconName("paint-filled") PAINT_OFF: IconName = IconName("paint-off") PALETTE: IconName = IconName("palette") PALETTE_OFF: IconName = IconName("palette-off") PANORAMA_HORIZONTAL: IconName = IconName("panorama-horizontal") PANORAMA_HORIZONTAL_OFF: IconName = IconName("panorama-horizontal-off") PANORAMA_VERTICAL: IconName = IconName("panorama-vertical") PANORAMA_VERTICAL_OFF: IconName = IconName("panorama-vertical-off") PAPER_BAG: IconName = IconName("paper-bag") PAPER_BAG_OFF: IconName = IconName("paper-bag-off") PAPERCLIP: IconName = IconName("paperclip") PARACHUTE: IconName = IconName("parachute") PARACHUTE_OFF: IconName = IconName("parachute-off") PARENTHESES: IconName = IconName("parentheses") PARENTHESES_OFF: IconName = IconName("parentheses-off") PARKING: IconName = IconName("parking") PARKING_OFF: IconName = IconName("parking-off") PASSWORD: IconName = IconName("password") PAW: IconName = IconName("paw") PAW_FILLED: IconName = IconName("paw-filled") PAW_OFF: IconName = 
IconName("paw-off") PDF: IconName = IconName("pdf") PEACE: IconName = IconName("peace") PENCIL: IconName = IconName("pencil") PENCIL_MINUS: IconName = IconName("pencil-minus") PENCIL_OFF: IconName = IconName("pencil-off") PENCIL_PLUS: IconName = IconName("pencil-plus") PENNANT: IconName = IconName("pennant") PENNANT_2: IconName = IconName("pennant-2") PENNANT_2_FILLED: IconName = IconName("pennant-2-filled") PENNANT_FILLED: IconName = IconName("pennant-filled") PENNANT_OFF: IconName = IconName("pennant-off") PENTAGON: IconName = IconName("pentagon") PENTAGON_FILLED: IconName = IconName("pentagon-filled") PENTAGON_OFF: IconName = IconName("pentagon-off") PENTAGRAM: IconName = IconName("pentagram") PEPPER: IconName = IconName("pepper") PEPPER_OFF: IconName = IconName("pepper-off") PERCENTAGE: IconName = IconName("percentage") PERFUME: IconName = IconName("perfume") PERSPECTIVE: IconName = IconName("perspective") PERSPECTIVE_OFF: IconName = IconName("perspective-off") PHONE: IconName = IconName("phone") PHONE_CALL: IconName = IconName("phone-call") PHONE_CALLING: IconName = IconName("phone-calling") PHONE_CHECK: IconName = IconName("phone-check") PHONE_FILLED: IconName = IconName("phone-filled") PHONE_INCOMING: IconName = IconName("phone-incoming") PHONE_OFF: IconName = IconName("phone-off") PHONE_OUTGOING: IconName = IconName("phone-outgoing") PHONE_PAUSE: IconName = IconName("phone-pause") PHONE_PLUS: IconName = IconName("phone-plus") PHONE_X: IconName = IconName("phone-x") PHOTO: IconName = IconName("photo") PHOTO_AI: IconName = IconName("photo-ai") PHOTO_BOLT: IconName = IconName("photo-bolt") PHOTO_CANCEL: IconName = IconName("photo-cancel") PHOTO_CHECK: IconName = IconName("photo-check") PHOTO_CODE: IconName = IconName("photo-code") PHOTO_COG: IconName = IconName("photo-cog") PHOTO_DOLLAR: IconName = IconName("photo-dollar") PHOTO_DOWN: IconName = IconName("photo-down") PHOTO_EDIT: IconName = IconName("photo-edit") PHOTO_EXCLAMATION: IconName = IconName("photo-exclamation") PHOTO_FILLED: IconName = IconName("photo-filled") PHOTO_HEART: IconName = IconName("photo-heart") PHOTO_MINUS: IconName = IconName("photo-minus") PHOTO_OFF: IconName = IconName("photo-off") PHOTO_PAUSE: IconName = IconName("photo-pause") PHOTO_PIN: IconName = IconName("photo-pin") PHOTO_PLUS: IconName = IconName("photo-plus") PHOTO_QUESTION: IconName = IconName("photo-question") PHOTO_SEARCH: IconName = IconName("photo-search") PHOTO_SENSOR: IconName = IconName("photo-sensor") PHOTO_SENSOR_2: IconName = IconName("photo-sensor-2") PHOTO_SENSOR_3: IconName = IconName("photo-sensor-3") PHOTO_SHARE: IconName = IconName("photo-share") PHOTO_SHIELD: IconName = IconName("photo-shield") PHOTO_STAR: IconName = IconName("photo-star") PHOTO_UP: IconName = IconName("photo-up") PHOTO_X: IconName = IconName("photo-x") PHYSOTHERAPIST: IconName = IconName("physotherapist") PIANO: IconName = IconName("piano") PICK: IconName = IconName("pick") PICTURE_IN_PICTURE: IconName = IconName("picture-in-picture") PICTURE_IN_PICTURE_OFF: IconName = IconName("picture-in-picture-off") PICTURE_IN_PICTURE_ON: IconName = IconName("picture-in-picture-on") PICTURE_IN_PICTURE_TOP: IconName = IconName("picture-in-picture-top") PIG: IconName = IconName("pig") PIG_MONEY: IconName = IconName("pig-money") PIG_OFF: IconName = IconName("pig-off") PILCROW: IconName = IconName("pilcrow") PILL: IconName = IconName("pill") PILL_OFF: IconName = IconName("pill-off") PILLS: IconName = IconName("pills") PIN: IconName = IconName("pin") PIN_FILLED: IconName = 
IconName("pin-filled") PING_PONG: IconName = IconName("ping-pong") PINNED: IconName = IconName("pinned") PINNED_FILLED: IconName = IconName("pinned-filled") PINNED_OFF: IconName = IconName("pinned-off") PIZZA: IconName = IconName("pizza") PIZZA_OFF: IconName = IconName("pizza-off") PLACEHOLDER: IconName = IconName("placeholder") PLANE: IconName = IconName("plane") PLANE_ARRIVAL: IconName = IconName("plane-arrival") PLANE_DEPARTURE: IconName = IconName("plane-departure") PLANE_INFLIGHT: IconName = IconName("plane-inflight") PLANE_OFF: IconName = IconName("plane-off") PLANE_TILT: IconName = IconName("plane-tilt") PLANET: IconName = IconName("planet") PLANET_OFF: IconName = IconName("planet-off") PLANT: IconName = IconName("plant") PLANT_2: IconName = IconName("plant-2") PLANT_2_OFF: IconName = IconName("plant-2-off") PLANT_OFF: IconName = IconName("plant-off") PLAY_BASKETBALL: IconName = IconName("play-basketball") PLAY_CARD: IconName = IconName("play-card") PLAY_CARD_OFF: IconName = IconName("play-card-off") PLAY_FOOTBALL: IconName = IconName("play-football") PLAY_HANDBALL: IconName = IconName("play-handball") PLAY_VOLLEYBALL: IconName = IconName("play-volleyball") PLAYER_EJECT: IconName = IconName("player-eject") PLAYER_EJECT_FILLED: IconName = IconName("player-eject-filled") PLAYER_PAUSE: IconName = IconName("player-pause") PLAYER_PAUSE_FILLED: IconName = IconName("player-pause-filled") PLAYER_PLAY: IconName = IconName("player-play") PLAYER_PLAY_FILLED: IconName = IconName("player-play-filled") PLAYER_RECORD: IconName = IconName("player-record") PLAYER_RECORD_FILLED: IconName = IconName("player-record-filled") PLAYER_SKIP_BACK: IconName = IconName("player-skip-back") PLAYER_SKIP_BACK_FILLED: IconName = IconName("player-skip-back-filled") PLAYER_SKIP_FORWARD: IconName = IconName("player-skip-forward") PLAYER_SKIP_FORWARD_FILLED: IconName = IconName("player-skip-forward-filled") PLAYER_STOP: IconName = IconName("player-stop") PLAYER_STOP_FILLED: IconName = IconName("player-stop-filled") PLAYER_TRACK_NEXT: IconName = IconName("player-track-next") PLAYER_TRACK_NEXT_FILLED: IconName = IconName("player-track-next-filled") PLAYER_TRACK_PREV: IconName = IconName("player-track-prev") PLAYER_TRACK_PREV_FILLED: IconName = IconName("player-track-prev-filled") PLAYLIST: IconName = IconName("playlist") PLAYLIST_ADD: IconName = IconName("playlist-add") PLAYLIST_OFF: IconName = IconName("playlist-off") PLAYLIST_X: IconName = IconName("playlist-x") PLAYSTATION_CIRCLE: IconName = IconName("playstation-circle") PLAYSTATION_SQUARE: IconName = IconName("playstation-square") PLAYSTATION_TRIANGLE: IconName = IconName("playstation-triangle") PLAYSTATION_X: IconName = IconName("playstation-x") PLUG: IconName = IconName("plug") PLUG_CONNECTED: IconName = IconName("plug-connected") PLUG_CONNECTED_X: IconName = IconName("plug-connected-x") PLUG_OFF: IconName = IconName("plug-off") PLUG_X: IconName = IconName("plug-x") PLUS: IconName = IconName("plus") PLUS_EQUAL: IconName = IconName("plus-equal") PLUS_MINUS: IconName = IconName("plus-minus") PNG: IconName = IconName("png") PODIUM: IconName = IconName("podium") PODIUM_OFF: IconName = IconName("podium-off") POINT: IconName = IconName("point") POINT_FILLED: IconName = IconName("point-filled") POINT_OFF: IconName = IconName("point-off") POINTER: IconName = IconName("pointer") POINTER_BOLT: IconName = IconName("pointer-bolt") POINTER_CANCEL: IconName = IconName("pointer-cancel") POINTER_CHECK: IconName = IconName("pointer-check") POINTER_CODE: IconName = 
IconName("pointer-code") POINTER_COG: IconName = IconName("pointer-cog") POINTER_DOLLAR: IconName = IconName("pointer-dollar") POINTER_DOWN: IconName = IconName("pointer-down") POINTER_EXCLAMATION: IconName = IconName("pointer-exclamation") POINTER_HEART: IconName = IconName("pointer-heart") POINTER_MINUS: IconName = IconName("pointer-minus") POINTER_OFF: IconName = IconName("pointer-off") POINTER_PAUSE: IconName = IconName("pointer-pause") POINTER_PIN: IconName = IconName("pointer-pin") POINTER_PLUS: IconName = IconName("pointer-plus") POINTER_QUESTION: IconName = IconName("pointer-question") POINTER_SEARCH: IconName = IconName("pointer-search") POINTER_SHARE: IconName = IconName("pointer-share") POINTER_STAR: IconName = IconName("pointer-star") POINTER_UP: IconName = IconName("pointer-up") POINTER_X: IconName = IconName("pointer-x") POKEBALL: IconName = IconName("pokeball") POKEBALL_OFF: IconName = IconName("pokeball-off") POKER_CHIP: IconName = IconName("poker-chip") POLAROID: IconName = IconName("polaroid") POLAROID_FILLED: IconName = IconName("polaroid-filled") POLYGON: IconName = IconName("polygon") POLYGON_OFF: IconName = IconName("polygon-off") POO: IconName = IconName("poo") POOL: IconName = IconName("pool") POOL_OFF: IconName = IconName("pool-off") POWER: IconName = IconName("power") PRAY: IconName = IconName("pray") PREMIUM_RIGHTS: IconName = IconName("premium-rights") PRESCRIPTION: IconName = IconName("prescription") PRESENTATION: IconName = IconName("presentation") PRESENTATION_ANALYTICS: IconName = IconName("presentation-analytics") PRESENTATION_OFF: IconName = IconName("presentation-off") PRINTER: IconName = IconName("printer") PRINTER_OFF: IconName = IconName("printer-off") PRISM: IconName = IconName("prism") PRISM_OFF: IconName = IconName("prism-off") PRISM_PLUS: IconName = IconName("prism-plus") PRISON: IconName = IconName("prison") PROGRESS: IconName = IconName("progress") PROGRESS_ALERT: IconName = IconName("progress-alert") PROGRESS_BOLT: IconName = IconName("progress-bolt") PROGRESS_CHECK: IconName = IconName("progress-check") PROGRESS_DOWN: IconName = IconName("progress-down") PROGRESS_HELP: IconName = IconName("progress-help") PROGRESS_X: IconName = IconName("progress-x") PROMPT: IconName = IconName("prompt") PROPELLER: IconName = IconName("propeller") PROPELLER_OFF: IconName = IconName("propeller-off") PUMPKIN_SCARY: IconName = IconName("pumpkin-scary") PUZZLE: IconName = IconName("puzzle") PUZZLE_2: IconName = IconName("puzzle-2") PUZZLE_FILLED: IconName = IconName("puzzle-filled") PUZZLE_OFF: IconName = IconName("puzzle-off") PYRAMID: IconName = IconName("pyramid") PYRAMID_OFF: IconName = IconName("pyramid-off") PYRAMID_PLUS: IconName = IconName("pyramid-plus") QRCODE: IconName = IconName("qrcode") QRCODE_OFF: IconName = IconName("qrcode-off") QUESTION_MARK: IconName = IconName("question-mark") QUOTE: IconName = IconName("quote") QUOTE_OFF: IconName = IconName("quote-off") RADAR: IconName = IconName("radar") RADAR_2: IconName = IconName("radar-2") RADAR_OFF: IconName = IconName("radar-off") RADIO: IconName = IconName("radio") RADIO_OFF: IconName = IconName("radio-off") RADIOACTIVE: IconName = IconName("radioactive") RADIOACTIVE_FILLED: IconName = IconName("radioactive-filled") RADIOACTIVE_OFF: IconName = IconName("radioactive-off") RADIUS_BOTTOM_LEFT: IconName = IconName("radius-bottom-left") RADIUS_BOTTOM_RIGHT: IconName = IconName("radius-bottom-right") RADIUS_TOP_LEFT: IconName = IconName("radius-top-left") RADIUS_TOP_RIGHT: IconName = 
IconName("radius-top-right") RAINBOW: IconName = IconName("rainbow") RAINBOW_OFF: IconName = IconName("rainbow-off") RATING_12_PLUS: IconName = IconName("rating-12-plus") RATING_14_PLUS: IconName = IconName("rating-14-plus") RATING_16_PLUS: IconName = IconName("rating-16-plus") RATING_18_PLUS: IconName = IconName("rating-18-plus") RATING_21_PLUS: IconName = IconName("rating-21-plus") RAZOR: IconName = IconName("razor") RAZOR_ELECTRIC: IconName = IconName("razor-electric") RECEIPT: IconName = IconName("receipt") RECEIPT_2: IconName = IconName("receipt-2") RECEIPT_OFF: IconName = IconName("receipt-off") RECEIPT_REFUND: IconName = IconName("receipt-refund") RECEIPT_TAX: IconName = IconName("receipt-tax") RECHARGING: IconName = IconName("recharging") RECORD_MAIL: IconName = IconName("record-mail") RECORD_MAIL_OFF: IconName = IconName("record-mail-off") RECTANGLE: IconName = IconName("rectangle") RECTANGLE_FILLED: IconName = IconName("rectangle-filled") RECTANGLE_ROUNDED_BOTTOM: IconName = IconName("rectangle-rounded-bottom") RECTANGLE_ROUNDED_TOP: IconName = IconName("rectangle-rounded-top") RECTANGLE_VERTICAL: IconName = IconName("rectangle-vertical") RECTANGLE_VERTICAL_FILLED: IconName = IconName("rectangle-vertical-filled") RECTANGULAR_PRISM: IconName = IconName("rectangular-prism") RECTANGULAR_PRISM_OFF: IconName = IconName("rectangular-prism-off") RECTANGULAR_PRISM_PLUS: IconName = IconName("rectangular-prism-plus") RECYCLE: IconName = IconName("recycle") RECYCLE_OFF: IconName = IconName("recycle-off") REFRESH: IconName = IconName("refresh") REFRESH_ALERT: IconName = IconName("refresh-alert") REFRESH_DOT: IconName = IconName("refresh-dot") REFRESH_OFF: IconName = IconName("refresh-off") REGEX: IconName = IconName("regex") REGEX_OFF: IconName = IconName("regex-off") REGISTERED: IconName = IconName("registered") RELATION_MANY_TO_MANY: IconName = IconName("relation-many-to-many") RELATION_ONE_TO_MANY: IconName = IconName("relation-one-to-many") RELATION_ONE_TO_ONE: IconName = IconName("relation-one-to-one") RELOAD: IconName = IconName("reload") REPEAT: IconName = IconName("repeat") REPEAT_OFF: IconName = IconName("repeat-off") REPEAT_ONCE: IconName = IconName("repeat-once") REPLACE: IconName = IconName("replace") REPLACE_FILLED: IconName = IconName("replace-filled") REPLACE_OFF: IconName = IconName("replace-off") REPORT: IconName = IconName("report") REPORT_ANALYTICS: IconName = IconName("report-analytics") REPORT_MEDICAL: IconName = IconName("report-medical") REPORT_MONEY: IconName = IconName("report-money") REPORT_OFF: IconName = IconName("report-off") REPORT_SEARCH: IconName = IconName("report-search") RESERVED_LINE: IconName = IconName("reserved-line") RESIZE: IconName = IconName("resize") RESTORE: IconName = IconName("restore") REWIND_BACKWARD_10: IconName = IconName("rewind-backward-10") REWIND_BACKWARD_15: IconName = IconName("rewind-backward-15") REWIND_BACKWARD_20: IconName = IconName("rewind-backward-20") REWIND_BACKWARD_30: IconName = IconName("rewind-backward-30") REWIND_BACKWARD_40: IconName = IconName("rewind-backward-40") REWIND_BACKWARD_5: IconName = IconName("rewind-backward-5") REWIND_BACKWARD_50: IconName = IconName("rewind-backward-50") REWIND_BACKWARD_60: IconName = IconName("rewind-backward-60") REWIND_FORWARD_10: IconName = IconName("rewind-forward-10") REWIND_FORWARD_15: IconName = IconName("rewind-forward-15") REWIND_FORWARD_20: IconName = IconName("rewind-forward-20") REWIND_FORWARD_30: IconName = IconName("rewind-forward-30") REWIND_FORWARD_40: IconName = 
IconName("rewind-forward-40") REWIND_FORWARD_5: IconName = IconName("rewind-forward-5") REWIND_FORWARD_50: IconName = IconName("rewind-forward-50") REWIND_FORWARD_60: IconName = IconName("rewind-forward-60") RIBBON_HEALTH: IconName = IconName("ribbon-health") RINGS: IconName = IconName("rings") RIPPLE: IconName = IconName("ripple") RIPPLE_OFF: IconName = IconName("ripple-off") ROAD: IconName = IconName("road") ROAD_OFF: IconName = IconName("road-off") ROAD_SIGN: IconName = IconName("road-sign") ROBOT: IconName = IconName("robot") ROBOT_OFF: IconName = IconName("robot-off") ROCKET: IconName = IconName("rocket") ROCKET_OFF: IconName = IconName("rocket-off") ROLLER_SKATING: IconName = IconName("roller-skating") ROLLERCOASTER: IconName = IconName("rollercoaster") ROLLERCOASTER_OFF: IconName = IconName("rollercoaster-off") ROSETTE: IconName = IconName("rosette") ROSETTE_FILLED: IconName = IconName("rosette-filled") ROSETTE_NUMBER_0: IconName = IconName("rosette-number-0") ROSETTE_NUMBER_1: IconName = IconName("rosette-number-1") ROSETTE_NUMBER_2: IconName = IconName("rosette-number-2") ROSETTE_NUMBER_3: IconName = IconName("rosette-number-3") ROSETTE_NUMBER_4: IconName = IconName("rosette-number-4") ROSETTE_NUMBER_5: IconName = IconName("rosette-number-5") ROSETTE_NUMBER_6: IconName = IconName("rosette-number-6") ROSETTE_NUMBER_7: IconName = IconName("rosette-number-7") ROSETTE_NUMBER_8: IconName = IconName("rosette-number-8") ROSETTE_NUMBER_9: IconName = IconName("rosette-number-9") ROTATE: IconName = IconName("rotate") ROTATE_2: IconName = IconName("rotate-2") ROTATE_360: IconName = IconName("rotate-360") ROTATE_CLOCKWISE: IconName = IconName("rotate-clockwise") ROTATE_CLOCKWISE_2: IconName = IconName("rotate-clockwise-2") ROTATE_DOT: IconName = IconName("rotate-dot") ROTATE_RECTANGLE: IconName = IconName("rotate-rectangle") ROUTE: IconName = IconName("route") ROUTE_2: IconName = IconName("route-2") ROUTE_OFF: IconName = IconName("route-off") ROUTER: IconName = IconName("router") ROUTER_OFF: IconName = IconName("router-off") ROW_INSERT_BOTTOM: IconName = IconName("row-insert-bottom") ROW_INSERT_TOP: IconName = IconName("row-insert-top") ROW_REMOVE: IconName = IconName("row-remove") RSS: IconName = IconName("rss") RUBBER_STAMP: IconName = IconName("rubber-stamp") RUBBER_STAMP_OFF: IconName = IconName("rubber-stamp-off") RULER: IconName = IconName("ruler") RULER_2: IconName = IconName("ruler-2") RULER_2_OFF: IconName = IconName("ruler-2-off") RULER_3: IconName = IconName("ruler-3") RULER_MEASURE: IconName = IconName("ruler-measure") RULER_OFF: IconName = IconName("ruler-off") RUN: IconName = IconName("run") S_TURN_DOWN: IconName = IconName("s-turn-down") S_TURN_LEFT: IconName = IconName("s-turn-left") S_TURN_RIGHT: IconName = IconName("s-turn-right") S_TURN_UP: IconName = IconName("s-turn-up") SAILBOAT: IconName = IconName("sailboat") SAILBOAT_2: IconName = IconName("sailboat-2") SAILBOAT_OFF: IconName = IconName("sailboat-off") SALAD: IconName = IconName("salad") SALT: IconName = IconName("salt") SATELLITE: IconName = IconName("satellite") SATELLITE_OFF: IconName = IconName("satellite-off") SAUSAGE: IconName = IconName("sausage") SCALE: IconName = IconName("scale") SCALE_OFF: IconName = IconName("scale-off") SCALE_OUTLINE: IconName = IconName("scale-outline") SCALE_OUTLINE_OFF: IconName = IconName("scale-outline-off") SCAN: IconName = IconName("scan") SCAN_EYE: IconName = IconName("scan-eye") SCHEMA: IconName = IconName("schema") SCHEMA_OFF: IconName = IconName("schema-off") SCHOOL: IconName = 
IconName("school") SCHOOL_BELL: IconName = IconName("school-bell") SCHOOL_OFF: IconName = IconName("school-off") SCISSORS: IconName = IconName("scissors") SCISSORS_OFF: IconName = IconName("scissors-off") SCOOTER: IconName = IconName("scooter") SCOOTER_ELECTRIC: IconName = IconName("scooter-electric") SCOREBOARD: IconName = IconName("scoreboard") SCREEN_SHARE: IconName = IconName("screen-share") SCREEN_SHARE_OFF: IconName = IconName("screen-share-off") SCREENSHOT: IconName = IconName("screenshot") SCRIBBLE: IconName = IconName("scribble") SCRIBBLE_OFF: IconName = IconName("scribble-off") SCRIPT: IconName = IconName("script") SCRIPT_MINUS: IconName = IconName("script-minus") SCRIPT_PLUS: IconName = IconName("script-plus") SCRIPT_X: IconName = IconName("script-x") SCUBA_MASK: IconName = IconName("scuba-mask") SCUBA_MASK_OFF: IconName = IconName("scuba-mask-off") SDK: IconName = IconName("sdk") SEARCH: IconName = IconName("search") SEARCH_OFF: IconName = IconName("search-off") SECTION: IconName = IconName("section") SECTION_SIGN: IconName = IconName("section-sign") SEEDING: IconName = IconName("seeding") SEEDING_OFF: IconName = IconName("seeding-off") SELECT: IconName = IconName("select") SELECT_ALL: IconName = IconName("select-all") SELECTOR: IconName = IconName("selector") SEND: IconName = IconName("send") SEND_OFF: IconName = IconName("send-off") SEO: IconName = IconName("seo") SEPARATOR: IconName = IconName("separator") SEPARATOR_HORIZONTAL: IconName = IconName("separator-horizontal") SEPARATOR_VERTICAL: IconName = IconName("separator-vertical") SERVER: IconName = IconName("server") SERVER_2: IconName = IconName("server-2") SERVER_BOLT: IconName = IconName("server-bolt") SERVER_COG: IconName = IconName("server-cog") SERVER_OFF: IconName = IconName("server-off") SERVICEMARK: IconName = IconName("servicemark") SETTINGS: IconName = IconName("settings") SETTINGS_2: IconName = IconName("settings-2") SETTINGS_AUTOMATION: IconName = IconName("settings-automation") SETTINGS_BOLT: IconName = IconName("settings-bolt") SETTINGS_CANCEL: IconName = IconName("settings-cancel") SETTINGS_CHECK: IconName = IconName("settings-check") SETTINGS_CODE: IconName = IconName("settings-code") SETTINGS_COG: IconName = IconName("settings-cog") SETTINGS_DOLLAR: IconName = IconName("settings-dollar") SETTINGS_DOWN: IconName = IconName("settings-down") SETTINGS_EXCLAMATION: IconName = IconName("settings-exclamation") SETTINGS_FILLED: IconName = IconName("settings-filled") SETTINGS_HEART: IconName = IconName("settings-heart") SETTINGS_MINUS: IconName = IconName("settings-minus") SETTINGS_OFF: IconName = IconName("settings-off") SETTINGS_PAUSE: IconName = IconName("settings-pause") SETTINGS_PIN: IconName = IconName("settings-pin") SETTINGS_PLUS: IconName = IconName("settings-plus") SETTINGS_QUESTION: IconName = IconName("settings-question") SETTINGS_SEARCH: IconName = IconName("settings-search") SETTINGS_SHARE: IconName = IconName("settings-share") SETTINGS_STAR: IconName = IconName("settings-star") SETTINGS_UP: IconName = IconName("settings-up") SETTINGS_X: IconName = IconName("settings-x") SHADOW: IconName = IconName("shadow") SHADOW_OFF: IconName = IconName("shadow-off") SHAPE: IconName = IconName("shape") SHAPE_2: IconName = IconName("shape-2") SHAPE_3: IconName = IconName("shape-3") SHAPE_OFF: IconName = IconName("shape-off") SHARE: IconName = IconName("share") SHARE_2: IconName = IconName("share-2") SHARE_3: IconName = IconName("share-3") SHARE_OFF: IconName = IconName("share-off") SHI_JUMPING: IconName = 
IconName("shi-jumping") SHIELD: IconName = IconName("shield") SHIELD_BOLT: IconName = IconName("shield-bolt") SHIELD_CANCEL: IconName = IconName("shield-cancel") SHIELD_CHECK: IconName = IconName("shield-check") SHIELD_CHECK_FILLED: IconName = IconName("shield-check-filled") SHIELD_CHECKERED: IconName = IconName("shield-checkered") SHIELD_CHECKERED_FILLED: IconName = IconName("shield-checkered-filled") SHIELD_CHEVRON: IconName = IconName("shield-chevron") SHIELD_CODE: IconName = IconName("shield-code") SHIELD_COG: IconName = IconName("shield-cog") SHIELD_DOLLAR: IconName = IconName("shield-dollar") SHIELD_DOWN: IconName = IconName("shield-down") SHIELD_EXCLAMATION: IconName = IconName("shield-exclamation") SHIELD_FILLED: IconName = IconName("shield-filled") SHIELD_HALF: IconName = IconName("shield-half") SHIELD_HALF_FILLED: IconName = IconName("shield-half-filled") SHIELD_HEART: IconName = IconName("shield-heart") SHIELD_LOCK: IconName = IconName("shield-lock") SHIELD_LOCK_FILLED: IconName = IconName("shield-lock-filled") SHIELD_MINUS: IconName = IconName("shield-minus") SHIELD_OFF: IconName = IconName("shield-off") SHIELD_PAUSE: IconName = IconName("shield-pause") SHIELD_PIN: IconName = IconName("shield-pin") SHIELD_PLUS: IconName = IconName("shield-plus") SHIELD_QUESTION: IconName = IconName("shield-question") SHIELD_SEARCH: IconName = IconName("shield-search") SHIELD_SHARE: IconName = IconName("shield-share") SHIELD_STAR: IconName = IconName("shield-star") SHIELD_UP: IconName = IconName("shield-up") SHIELD_X: IconName = IconName("shield-x") SHIP: IconName = IconName("ship") SHIP_OFF: IconName = IconName("ship-off") SHIRT: IconName = IconName("shirt") SHIRT_FILLED: IconName = IconName("shirt-filled") SHIRT_OFF: IconName = IconName("shirt-off") SHIRT_SPORT: IconName = IconName("shirt-sport") SHOE: IconName = IconName("shoe") SHOE_OFF: IconName = IconName("shoe-off") SHOPPING_BAG: IconName = IconName("shopping-bag") SHOPPING_CART: IconName = IconName("shopping-cart") SHOPPING_CART_DISCOUNT: IconName = IconName("shopping-cart-discount") SHOPPING_CART_OFF: IconName = IconName("shopping-cart-off") SHOPPING_CART_PLUS: IconName = IconName("shopping-cart-plus") SHOPPING_CART_X: IconName = IconName("shopping-cart-x") SHOVEL: IconName = IconName("shovel") SHREDDER: IconName = IconName("shredder") SIGN_LEFT: IconName = IconName("sign-left") SIGN_LEFT_FILLED: IconName = IconName("sign-left-filled") SIGN_RIGHT: IconName = IconName("sign-right") SIGN_RIGHT_FILLED: IconName = IconName("sign-right-filled") SIGNAL_2G: IconName = IconName("signal-2g") SIGNAL_3G: IconName = IconName("signal-3g") SIGNAL_4G: IconName = IconName("signal-4g") SIGNAL_4G_PLUS: IconName = IconName("signal-4g-plus") SIGNAL_5G: IconName = IconName("signal-5g") SIGNAL_6G: IconName = IconName("signal-6g") SIGNAL_E: IconName = IconName("signal-e") SIGNAL_G: IconName = IconName("signal-g") SIGNAL_H: IconName = IconName("signal-h") SIGNAL_H_PLUS: IconName = IconName("signal-h-plus") SIGNAL_LTE: IconName = IconName("signal-lte") SIGNATURE: IconName = IconName("signature") SIGNATURE_OFF: IconName = IconName("signature-off") SITEMAP: IconName = IconName("sitemap") SITEMAP_OFF: IconName = IconName("sitemap-off") SKATEBOARD: IconName = IconName("skateboard") SKATEBOARD_OFF: IconName = IconName("skateboard-off") SKATEBOARDING: IconName = IconName("skateboarding") SKULL: IconName = IconName("skull") SLASH: IconName = IconName("slash") SLASHES: IconName = IconName("slashes") SLEIGH: IconName = IconName("sleigh") SLICE: IconName = 
IconName("slice") SLIDESHOW: IconName = IconName("slideshow") SMART_HOME: IconName = IconName("smart-home") SMART_HOME_OFF: IconName = IconName("smart-home-off") SMOKING: IconName = IconName("smoking") SMOKING_NO: IconName = IconName("smoking-no") SNOWFLAKE: IconName = IconName("snowflake") SNOWFLAKE_OFF: IconName = IconName("snowflake-off") SNOWMAN: IconName = IconName("snowman") SOCCER_FIELD: IconName = IconName("soccer-field") SOCIAL: IconName = IconName("social") SOCIAL_OFF: IconName = IconName("social-off") SOCK: IconName = IconName("sock") SOFA: IconName = IconName("sofa") SOFA_OFF: IconName = IconName("sofa-off") SOLAR_PANEL: IconName = IconName("solar-panel") SOLAR_PANEL_2: IconName = IconName("solar-panel-2") SORT_0_9: IconName = IconName("sort-0-9") SORT_9_0: IconName = IconName("sort-9-0") SORT_A_Z: IconName = IconName("sort-a-z") SORT_ASCENDING: IconName = IconName("sort-ascending") SORT_ASCENDING_2: IconName = IconName("sort-ascending-2") SORT_ASCENDING_LETTERS: IconName = IconName("sort-ascending-letters") SORT_ASCENDING_NUMBERS: IconName = IconName("sort-ascending-numbers") SORT_DESCENDING: IconName = IconName("sort-descending") SORT_DESCENDING_2: IconName = IconName("sort-descending-2") SORT_DESCENDING_LETTERS: IconName = IconName("sort-descending-letters") SORT_DESCENDING_NUMBERS: IconName = IconName("sort-descending-numbers") SORT_Z_A: IconName = IconName("sort-z-a") SOS: IconName = IconName("sos") SOUP: IconName = IconName("soup") SOUP_OFF: IconName = IconName("soup-off") SOURCE_CODE: IconName = IconName("source-code") SPACE: IconName = IconName("space") SPACE_OFF: IconName = IconName("space-off") SPACING_HORIZONTAL: IconName = IconName("spacing-horizontal") SPACING_VERTICAL: IconName = IconName("spacing-vertical") SPADE: IconName = IconName("spade") SPADE_FILLED: IconName = IconName("spade-filled") SPARKLES: IconName = IconName("sparkles") SPEAKERPHONE: IconName = IconName("speakerphone") SPEEDBOAT: IconName = IconName("speedboat") SPHERE: IconName = IconName("sphere") SPHERE_OFF: IconName = IconName("sphere-off") SPHERE_PLUS: IconName = IconName("sphere-plus") SPIDER: IconName = IconName("spider") SPIRAL: IconName = IconName("spiral") SPIRAL_OFF: IconName = IconName("spiral-off") SPORT_BILLARD: IconName = IconName("sport-billard") SPRAY: IconName = IconName("spray") SPY: IconName = IconName("spy") SPY_OFF: IconName = IconName("spy-off") SQL: IconName = IconName("sql") SQUARE: IconName = IconName("square") SQUARE_0_FILLED: IconName = IconName("square-0-filled") SQUARE_1_FILLED: IconName = IconName("square-1-filled") SQUARE_2_FILLED: IconName = IconName("square-2-filled") SQUARE_3_FILLED: IconName = IconName("square-3-filled") SQUARE_4_FILLED: IconName = IconName("square-4-filled") SQUARE_5_FILLED: IconName = IconName("square-5-filled") SQUARE_6_FILLED: IconName = IconName("square-6-filled") SQUARE_7_FILLED: IconName = IconName("square-7-filled") SQUARE_8_FILLED: IconName = IconName("square-8-filled") SQUARE_9_FILLED: IconName = IconName("square-9-filled") SQUARE_ARROW_DOWN: IconName = IconName("square-arrow-down") SQUARE_ARROW_LEFT: IconName = IconName("square-arrow-left") SQUARE_ARROW_RIGHT: IconName = IconName("square-arrow-right") SQUARE_ARROW_UP: IconName = IconName("square-arrow-up") SQUARE_ASTERISK: IconName = IconName("square-asterisk") SQUARE_CHECK: IconName = IconName("square-check") SQUARE_CHECK_FILLED: IconName = IconName("square-check-filled") SQUARE_CHEVRON_DOWN: IconName = IconName("square-chevron-down") SQUARE_CHEVRON_LEFT: IconName = 
IconName("square-chevron-left") SQUARE_CHEVRON_RIGHT: IconName = IconName("square-chevron-right") SQUARE_CHEVRON_UP: IconName = IconName("square-chevron-up") SQUARE_CHEVRONS_DOWN: IconName = IconName("square-chevrons-down") SQUARE_CHEVRONS_LEFT: IconName = IconName("square-chevrons-left") SQUARE_CHEVRONS_RIGHT: IconName = IconName("square-chevrons-right") SQUARE_CHEVRONS_UP: IconName = IconName("square-chevrons-up") SQUARE_DOT: IconName = IconName("square-dot") SQUARE_F0: IconName = IconName("square-f0") SQUARE_F0_FILLED: IconName = IconName("square-f0-filled") SQUARE_F1: IconName = IconName("square-f1") SQUARE_F1_FILLED: IconName = IconName("square-f1-filled") SQUARE_F2: IconName = IconName("square-f2") SQUARE_F2_FILLED: IconName = IconName("square-f2-filled") SQUARE_F3: IconName = IconName("square-f3") SQUARE_F3_FILLED: IconName = IconName("square-f3-filled") SQUARE_F4: IconName = IconName("square-f4") SQUARE_F4_FILLED: IconName = IconName("square-f4-filled") SQUARE_F5: IconName = IconName("square-f5") SQUARE_F5_FILLED: IconName = IconName("square-f5-filled") SQUARE_F6: IconName = IconName("square-f6") SQUARE_F6_FILLED: IconName = IconName("square-f6-filled") SQUARE_F7: IconName = IconName("square-f7") SQUARE_F7_FILLED: IconName = IconName("square-f7-filled") SQUARE_F8: IconName = IconName("square-f8") SQUARE_F8_FILLED: IconName = IconName("square-f8-filled") SQUARE_F9: IconName = IconName("square-f9") SQUARE_F9_FILLED: IconName = IconName("square-f9-filled") SQUARE_FORBID: IconName = IconName("square-forbid") SQUARE_FORBID_2: IconName = IconName("square-forbid-2") SQUARE_HALF: IconName = IconName("square-half") SQUARE_KEY: IconName = IconName("square-key") SQUARE_LETTER_A: IconName = IconName("square-letter-a") SQUARE_LETTER_B: IconName = IconName("square-letter-b") SQUARE_LETTER_C: IconName = IconName("square-letter-c") SQUARE_LETTER_D: IconName = IconName("square-letter-d") SQUARE_LETTER_E: IconName = IconName("square-letter-e") SQUARE_LETTER_F: IconName = IconName("square-letter-f") SQUARE_LETTER_G: IconName = IconName("square-letter-g") SQUARE_LETTER_H: IconName = IconName("square-letter-h") SQUARE_LETTER_I: IconName = IconName("square-letter-i") SQUARE_LETTER_J: IconName = IconName("square-letter-j") SQUARE_LETTER_K: IconName = IconName("square-letter-k") SQUARE_LETTER_L: IconName = IconName("square-letter-l") SQUARE_LETTER_M: IconName = IconName("square-letter-m") SQUARE_LETTER_N: IconName = IconName("square-letter-n") SQUARE_LETTER_O: IconName = IconName("square-letter-o") SQUARE_LETTER_P: IconName = IconName("square-letter-p") SQUARE_LETTER_Q: IconName = IconName("square-letter-q") SQUARE_LETTER_R: IconName = IconName("square-letter-r") SQUARE_LETTER_S: IconName = IconName("square-letter-s") SQUARE_LETTER_T: IconName = IconName("square-letter-t") SQUARE_LETTER_U: IconName = IconName("square-letter-u") SQUARE_LETTER_V: IconName = IconName("square-letter-v") SQUARE_LETTER_W: IconName = IconName("square-letter-w") SQUARE_LETTER_X: IconName = IconName("square-letter-x") SQUARE_LETTER_Y: IconName = IconName("square-letter-y") SQUARE_LETTER_Z: IconName = IconName("square-letter-z") SQUARE_MINUS: IconName = IconName("square-minus") SQUARE_NUMBER_0: IconName = IconName("square-number-0") SQUARE_NUMBER_1: IconName = IconName("square-number-1") SQUARE_NUMBER_2: IconName = IconName("square-number-2") SQUARE_NUMBER_3: IconName = IconName("square-number-3") SQUARE_NUMBER_4: IconName = IconName("square-number-4") SQUARE_NUMBER_5: IconName = IconName("square-number-5") SQUARE_NUMBER_6: 
IconName = IconName("square-number-6") SQUARE_NUMBER_7: IconName = IconName("square-number-7") SQUARE_NUMBER_8: IconName = IconName("square-number-8") SQUARE_NUMBER_9: IconName = IconName("square-number-9") SQUARE_OFF: IconName = IconName("square-off") SQUARE_PLUS: IconName = IconName("square-plus") SQUARE_ROOT: IconName = IconName("square-root") SQUARE_ROOT_2: IconName = IconName("square-root-2") SQUARE_ROTATED: IconName = IconName("square-rotated") SQUARE_ROTATED_FILLED: IconName = IconName("square-rotated-filled") SQUARE_ROTATED_FORBID: IconName = IconName("square-rotated-forbid") SQUARE_ROTATED_FORBID_2: IconName = IconName("square-rotated-forbid-2") SQUARE_ROTATED_OFF: IconName = IconName("square-rotated-off") SQUARE_ROUNDED: IconName = IconName("square-rounded") SQUARE_ROUNDED_ARROW_DOWN: IconName = IconName("square-rounded-arrow-down") SQUARE_ROUNDED_ARROW_DOWN_FILLED: IconName = IconName( "square-rounded-arrow-down-filled" ) SQUARE_ROUNDED_ARROW_LEFT: IconName = IconName("square-rounded-arrow-left") SQUARE_ROUNDED_ARROW_LEFT_FILLED: IconName = IconName( "square-rounded-arrow-left-filled" ) SQUARE_ROUNDED_ARROW_RIGHT: IconName = IconName("square-rounded-arrow-right") SQUARE_ROUNDED_ARROW_RIGHT_FILLED: IconName = IconName( "square-rounded-arrow-right-filled" ) SQUARE_ROUNDED_ARROW_UP: IconName = IconName("square-rounded-arrow-up") SQUARE_ROUNDED_ARROW_UP_FILLED: IconName = IconName( "square-rounded-arrow-up-filled" ) SQUARE_ROUNDED_CHECK: IconName = IconName("square-rounded-check") SQUARE_ROUNDED_CHECK_FILLED: IconName = IconName("square-rounded-check-filled") SQUARE_ROUNDED_CHEVRON_DOWN: IconName = IconName("square-rounded-chevron-down") SQUARE_ROUNDED_CHEVRON_DOWN_FILLED: IconName = IconName( "square-rounded-chevron-down-filled" ) SQUARE_ROUNDED_CHEVRON_LEFT: IconName = IconName("square-rounded-chevron-left") SQUARE_ROUNDED_CHEVRON_LEFT_FILLED: IconName = IconName( "square-rounded-chevron-left-filled" ) SQUARE_ROUNDED_CHEVRON_RIGHT: IconName = IconName("square-rounded-chevron-right") SQUARE_ROUNDED_CHEVRON_RIGHT_FILLED: IconName = IconName( "square-rounded-chevron-right-filled" ) SQUARE_ROUNDED_CHEVRON_UP: IconName = IconName("square-rounded-chevron-up") SQUARE_ROUNDED_CHEVRON_UP_FILLED: IconName = IconName( "square-rounded-chevron-up-filled" ) SQUARE_ROUNDED_CHEVRONS_DOWN: IconName = IconName("square-rounded-chevrons-down") SQUARE_ROUNDED_CHEVRONS_DOWN_FILLED: IconName = IconName( "square-rounded-chevrons-down-filled" ) SQUARE_ROUNDED_CHEVRONS_LEFT: IconName = IconName("square-rounded-chevrons-left") SQUARE_ROUNDED_CHEVRONS_LEFT_FILLED: IconName = IconName( "square-rounded-chevrons-left-filled" ) SQUARE_ROUNDED_CHEVRONS_RIGHT: IconName = IconName("square-rounded-chevrons-right") SQUARE_ROUNDED_CHEVRONS_RIGHT_FILLED: IconName = IconName( "square-rounded-chevrons-right-filled" ) SQUARE_ROUNDED_CHEVRONS_UP: IconName = IconName("square-rounded-chevrons-up") SQUARE_ROUNDED_CHEVRONS_UP_FILLED: IconName = IconName( "square-rounded-chevrons-up-filled" ) SQUARE_ROUNDED_FILLED: IconName = IconName("square-rounded-filled") SQUARE_ROUNDED_LETTER_A: IconName = IconName("square-rounded-letter-a") SQUARE_ROUNDED_LETTER_B: IconName = IconName("square-rounded-letter-b") SQUARE_ROUNDED_LETTER_C: IconName = IconName("square-rounded-letter-c") SQUARE_ROUNDED_LETTER_D: IconName = IconName("square-rounded-letter-d") SQUARE_ROUNDED_LETTER_E: IconName = IconName("square-rounded-letter-e") SQUARE_ROUNDED_LETTER_F: IconName = IconName("square-rounded-letter-f") SQUARE_ROUNDED_LETTER_G: IconName = 
IconName("square-rounded-letter-g") SQUARE_ROUNDED_LETTER_H: IconName = IconName("square-rounded-letter-h") SQUARE_ROUNDED_LETTER_I: IconName = IconName("square-rounded-letter-i") SQUARE_ROUNDED_LETTER_J: IconName = IconName("square-rounded-letter-j") SQUARE_ROUNDED_LETTER_K: IconName = IconName("square-rounded-letter-k") SQUARE_ROUNDED_LETTER_L: IconName = IconName("square-rounded-letter-l") SQUARE_ROUNDED_LETTER_M: IconName = IconName("square-rounded-letter-m") SQUARE_ROUNDED_LETTER_N: IconName = IconName("square-rounded-letter-n") SQUARE_ROUNDED_LETTER_O: IconName = IconName("square-rounded-letter-o") SQUARE_ROUNDED_LETTER_P: IconName = IconName("square-rounded-letter-p") SQUARE_ROUNDED_LETTER_Q: IconName = IconName("square-rounded-letter-q") SQUARE_ROUNDED_LETTER_R: IconName = IconName("square-rounded-letter-r") SQUARE_ROUNDED_LETTER_S: IconName = IconName("square-rounded-letter-s") SQUARE_ROUNDED_LETTER_T: IconName = IconName("square-rounded-letter-t") SQUARE_ROUNDED_LETTER_U: IconName = IconName("square-rounded-letter-u") SQUARE_ROUNDED_LETTER_V: IconName = IconName("square-rounded-letter-v") SQUARE_ROUNDED_LETTER_W: IconName = IconName("square-rounded-letter-w") SQUARE_ROUNDED_LETTER_X: IconName = IconName("square-rounded-letter-x") SQUARE_ROUNDED_LETTER_Y: IconName = IconName("square-rounded-letter-y") SQUARE_ROUNDED_LETTER_Z: IconName = IconName("square-rounded-letter-z") SQUARE_ROUNDED_MINUS: IconName = IconName("square-rounded-minus") SQUARE_ROUNDED_NUMBER_0: IconName = IconName("square-rounded-number-0") SQUARE_ROUNDED_NUMBER_0_FILLED: IconName = IconName( "square-rounded-number-0-filled" ) SQUARE_ROUNDED_NUMBER_1: IconName = IconName("square-rounded-number-1") SQUARE_ROUNDED_NUMBER_1_FILLED: IconName = IconName( "square-rounded-number-1-filled" ) SQUARE_ROUNDED_NUMBER_2: IconName = IconName("square-rounded-number-2") SQUARE_ROUNDED_NUMBER_2_FILLED: IconName = IconName( "square-rounded-number-2-filled" ) SQUARE_ROUNDED_NUMBER_3: IconName = IconName("square-rounded-number-3") SQUARE_ROUNDED_NUMBER_3_FILLED: IconName = IconName( "square-rounded-number-3-filled" ) SQUARE_ROUNDED_NUMBER_4: IconName = IconName("square-rounded-number-4") SQUARE_ROUNDED_NUMBER_4_FILLED: IconName = IconName( "square-rounded-number-4-filled" ) SQUARE_ROUNDED_NUMBER_5: IconName = IconName("square-rounded-number-5") SQUARE_ROUNDED_NUMBER_5_FILLED: IconName = IconName( "square-rounded-number-5-filled" ) SQUARE_ROUNDED_NUMBER_6: IconName = IconName("square-rounded-number-6") SQUARE_ROUNDED_NUMBER_6_FILLED: IconName = IconName( "square-rounded-number-6-filled" ) SQUARE_ROUNDED_NUMBER_7: IconName = IconName("square-rounded-number-7") SQUARE_ROUNDED_NUMBER_7_FILLED: IconName = IconName( "square-rounded-number-7-filled" ) SQUARE_ROUNDED_NUMBER_8: IconName = IconName("square-rounded-number-8") SQUARE_ROUNDED_NUMBER_8_FILLED: IconName = IconName( "square-rounded-number-8-filled" ) SQUARE_ROUNDED_NUMBER_9: IconName = IconName("square-rounded-number-9") SQUARE_ROUNDED_NUMBER_9_FILLED: IconName = IconName( "square-rounded-number-9-filled" ) SQUARE_ROUNDED_PLUS: IconName = IconName("square-rounded-plus") SQUARE_ROUNDED_PLUS_FILLED: IconName = IconName("square-rounded-plus-filled") SQUARE_ROUNDED_X: IconName = IconName("square-rounded-x") SQUARE_ROUNDED_X_FILLED: IconName = IconName("square-rounded-x-filled") SQUARE_TOGGLE: IconName = IconName("square-toggle") SQUARE_TOGGLE_HORIZONTAL: IconName = IconName("square-toggle-horizontal") SQUARE_X: IconName = IconName("square-x") SQUARES_DIAGONAL: IconName = 
IconName("squares-diagonal") SQUARES_FILLED: IconName = IconName("squares-filled") STACK: IconName = IconName("stack") STACK_2: IconName = IconName("stack-2") STACK_3: IconName = IconName("stack-3") STACK_POP: IconName = IconName("stack-pop") STACK_PUSH: IconName = IconName("stack-push") STAIRS: IconName = IconName("stairs") STAIRS_DOWN: IconName = IconName("stairs-down") STAIRS_UP: IconName = IconName("stairs-up") STAR: IconName = IconName("star") STAR_FILLED: IconName = IconName("star-filled") STAR_HALF: IconName = IconName("star-half") STAR_HALF_FILLED: IconName = IconName("star-half-filled") STAR_OFF: IconName = IconName("star-off") STARS: IconName = IconName("stars") STARS_FILLED: IconName = IconName("stars-filled") STARS_OFF: IconName = IconName("stars-off") STATUS_CHANGE: IconName = IconName("status-change") STEAM: IconName = IconName("steam") STEERING_WHEEL: IconName = IconName("steering-wheel") STEERING_WHEEL_OFF: IconName = IconName("steering-wheel-off") STEP_INTO: IconName = IconName("step-into") STEP_OUT: IconName = IconName("step-out") STEREO_GLASSES: IconName = IconName("stereo-glasses") STETHOSCOPE: IconName = IconName("stethoscope") STETHOSCOPE_OFF: IconName = IconName("stethoscope-off") STICKER: IconName = IconName("sticker") STORM: IconName = IconName("storm") STORM_OFF: IconName = IconName("storm-off") STRETCHING: IconName = IconName("stretching") STRETCHING_2: IconName = IconName("stretching-2") STRIKETHROUGH: IconName = IconName("strikethrough") SUBMARINE: IconName = IconName("submarine") SUBSCRIPT: IconName = IconName("subscript") SUBTASK: IconName = IconName("subtask") SUM: IconName = IconName("sum") SUM_OFF: IconName = IconName("sum-off") SUN: IconName = IconName("sun") SUN_FILLED: IconName = IconName("sun-filled") SUN_HIGH: IconName = IconName("sun-high") SUN_LOW: IconName = IconName("sun-low") SUN_MOON: IconName = IconName("sun-moon") SUN_OFF: IconName = IconName("sun-off") SUN_WIND: IconName = IconName("sun-wind") SUNGLASSES: IconName = IconName("sunglasses") SUNRISE: IconName = IconName("sunrise") SUNSET: IconName = IconName("sunset") SUNSET_2: IconName = IconName("sunset-2") SUPERSCRIPT: IconName = IconName("superscript") SVG: IconName = IconName("svg") SWIMMING: IconName = IconName("swimming") SWIPE: IconName = IconName("swipe") SWITCH: IconName = IconName("switch") SWITCH_2: IconName = IconName("switch-2") SWITCH_3: IconName = IconName("switch-3") SWITCH_HORIZONTAL: IconName = IconName("switch-horizontal") SWITCH_VERTICAL: IconName = IconName("switch-vertical") SWORD: IconName = IconName("sword") SWORD_OFF: IconName = IconName("sword-off") SWORDS: IconName = IconName("swords") TABLE: IconName = IconName("table") TABLE_ALIAS: IconName = IconName("table-alias") TABLE_COLUMN: IconName = IconName("table-column") TABLE_DOWN: IconName = IconName("table-down") TABLE_EXPORT: IconName = IconName("table-export") TABLE_FILLED: IconName = IconName("table-filled") TABLE_HEART: IconName = IconName("table-heart") TABLE_IMPORT: IconName = IconName("table-import") TABLE_MINUS: IconName = IconName("table-minus") TABLE_OFF: IconName = IconName("table-off") TABLE_OPTIONS: IconName = IconName("table-options") TABLE_PLUS: IconName = IconName("table-plus") TABLE_ROW: IconName = IconName("table-row") TABLE_SHARE: IconName = IconName("table-share") TABLE_SHORTCUT: IconName = IconName("table-shortcut") TAG: IconName = IconName("tag") TAG_OFF: IconName = IconName("tag-off") TAGS: IconName = IconName("tags") TAGS_OFF: IconName = IconName("tags-off") TALLYMARK_1: IconName = 
IconName("tallymark-1") TALLYMARK_2: IconName = IconName("tallymark-2") TALLYMARK_3: IconName = IconName("tallymark-3") TALLYMARK_4: IconName = IconName("tallymark-4") TALLYMARKS: IconName = IconName("tallymarks") TANK: IconName = IconName("tank") TARGET: IconName = IconName("target") TARGET_ARROW: IconName = IconName("target-arrow") TARGET_OFF: IconName = IconName("target-off") TEAPOT: IconName = IconName("teapot") TELESCOPE: IconName = IconName("telescope") TELESCOPE_OFF: IconName = IconName("telescope-off") TEMPERATURE: IconName = IconName("temperature") TEMPERATURE_CELSIUS: IconName = IconName("temperature-celsius") TEMPERATURE_FAHRENHEIT: IconName = IconName("temperature-fahrenheit") TEMPERATURE_MINUS: IconName = IconName("temperature-minus") TEMPERATURE_OFF: IconName = IconName("temperature-off") TEMPERATURE_PLUS: IconName = IconName("temperature-plus") TEMPLATE: IconName = IconName("template") TEMPLATE_OFF: IconName = IconName("template-off") TENT: IconName = IconName("tent") TENT_OFF: IconName = IconName("tent-off") TERMINAL: IconName = IconName("terminal") TERMINAL_2: IconName = IconName("terminal-2") TEST_PIPE: IconName = IconName("test-pipe") TEST_PIPE_2: IconName = IconName("test-pipe-2") TEST_PIPE_OFF: IconName = IconName("test-pipe-off") TEX: IconName = IconName("tex") TEXT_CAPTION: IconName = IconName("text-caption") TEXT_COLOR: IconName = IconName("text-color") TEXT_DECREASE: IconName = IconName("text-decrease") TEXT_DIRECTION_LTR: IconName = IconName("text-direction-ltr") TEXT_DIRECTION_RTL: IconName = IconName("text-direction-rtl") TEXT_INCREASE: IconName = IconName("text-increase") TEXT_ORIENTATION: IconName = IconName("text-orientation") TEXT_PLUS: IconName = IconName("text-plus") TEXT_RECOGNITION: IconName = IconName("text-recognition") TEXT_RESIZE: IconName = IconName("text-resize") TEXT_SIZE: IconName = IconName("text-size") TEXT_SPELLCHECK: IconName = IconName("text-spellcheck") TEXT_WRAP: IconName = IconName("text-wrap") TEXT_WRAP_DISABLED: IconName = IconName("text-wrap-disabled") TEXTURE: IconName = IconName("texture") THEATER: IconName = IconName("theater") THERMOMETER: IconName = IconName("thermometer") THUMB_DOWN: IconName = IconName("thumb-down") THUMB_DOWN_FILLED: IconName = IconName("thumb-down-filled") THUMB_DOWN_OFF: IconName = IconName("thumb-down-off") THUMB_UP: IconName = IconName("thumb-up") THUMB_UP_FILLED: IconName = IconName("thumb-up-filled") THUMB_UP_OFF: IconName = IconName("thumb-up-off") TIC_TAC: IconName = IconName("tic-tac") TICKET: IconName = IconName("ticket") TICKET_OFF: IconName = IconName("ticket-off") TIE: IconName = IconName("tie") TILDE: IconName = IconName("tilde") TILT_SHIFT: IconName = IconName("tilt-shift") TILT_SHIFT_OFF: IconName = IconName("tilt-shift-off") TIME_DURATION_0: IconName = IconName("time-duration-0") TIME_DURATION_10: IconName = IconName("time-duration-10") TIME_DURATION_15: IconName = IconName("time-duration-15") TIME_DURATION_30: IconName = IconName("time-duration-30") TIME_DURATION_45: IconName = IconName("time-duration-45") TIME_DURATION_5: IconName = IconName("time-duration-5") TIME_DURATION_60: IconName = IconName("time-duration-60") TIME_DURATION_90: IconName = IconName("time-duration-90") TIME_DURATION_OFF: IconName = IconName("time-duration-off") TIMELINE: IconName = IconName("timeline") TIMELINE_EVENT: IconName = IconName("timeline-event") TIMELINE_EVENT_EXCLAMATION: IconName = IconName("timeline-event-exclamation") TIMELINE_EVENT_MINUS: IconName = IconName("timeline-event-minus") TIMELINE_EVENT_PLUS: 
IconName = IconName("timeline-event-plus") TIMELINE_EVENT_TEXT: IconName = IconName("timeline-event-text") TIMELINE_EVENT_X: IconName = IconName("timeline-event-x") TIR: IconName = IconName("tir") TOGGLE_LEFT: IconName = IconName("toggle-left") TOGGLE_RIGHT: IconName = IconName("toggle-right") TOILET_PAPER: IconName = IconName("toilet-paper") TOILET_PAPER_OFF: IconName = IconName("toilet-paper-off") TOML: IconName = IconName("toml") TOOL: IconName = IconName("tool") TOOLS: IconName = IconName("tools") TOOLS_KITCHEN: IconName = IconName("tools-kitchen") TOOLS_KITCHEN_2: IconName = IconName("tools-kitchen-2") TOOLS_KITCHEN_2_OFF: IconName = IconName("tools-kitchen-2-off") TOOLS_KITCHEN_OFF: IconName = IconName("tools-kitchen-off") TOOLS_OFF: IconName = IconName("tools-off") TOOLTIP: IconName = IconName("tooltip") TOPOLOGY_BUS: IconName = IconName("topology-bus") TOPOLOGY_COMPLEX: IconName = IconName("topology-complex") TOPOLOGY_FULL: IconName = IconName("topology-full") TOPOLOGY_FULL_HIERARCHY: IconName = IconName("topology-full-hierarchy") TOPOLOGY_RING: IconName = IconName("topology-ring") TOPOLOGY_RING_2: IconName = IconName("topology-ring-2") TOPOLOGY_RING_3: IconName = IconName("topology-ring-3") TOPOLOGY_STAR: IconName = IconName("topology-star") TOPOLOGY_STAR_2: IconName = IconName("topology-star-2") TOPOLOGY_STAR_3: IconName = IconName("topology-star-3") TOPOLOGY_STAR_RING: IconName = IconName("topology-star-ring") TOPOLOGY_STAR_RING_2: IconName = IconName("topology-star-ring-2") TOPOLOGY_STAR_RING_3: IconName = IconName("topology-star-ring-3") TORII: IconName = IconName("torii") TORNADO: IconName = IconName("tornado") TOURNAMENT: IconName = IconName("tournament") TOWER: IconName = IconName("tower") TOWER_OFF: IconName = IconName("tower-off") TRACK: IconName = IconName("track") TRACTOR: IconName = IconName("tractor") TRADEMARK: IconName = IconName("trademark") TRAFFIC_CONE: IconName = IconName("traffic-cone") TRAFFIC_CONE_OFF: IconName = IconName("traffic-cone-off") TRAFFIC_LIGHTS: IconName = IconName("traffic-lights") TRAFFIC_LIGHTS_OFF: IconName = IconName("traffic-lights-off") TRAIN: IconName = IconName("train") TRANSFER_IN: IconName = IconName("transfer-in") TRANSFER_OUT: IconName = IconName("transfer-out") TRANSFORM: IconName = IconName("transform") TRANSFORM_FILLED: IconName = IconName("transform-filled") TRANSITION_BOTTOM: IconName = IconName("transition-bottom") TRANSITION_LEFT: IconName = IconName("transition-left") TRANSITION_RIGHT: IconName = IconName("transition-right") TRANSITION_TOP: IconName = IconName("transition-top") TRASH: IconName = IconName("trash") TRASH_FILLED: IconName = IconName("trash-filled") TRASH_OFF: IconName = IconName("trash-off") TRASH_X: IconName = IconName("trash-x") TRASH_X_FILLED: IconName = IconName("trash-x-filled") TREADMILL: IconName = IconName("treadmill") TREE: IconName = IconName("tree") TREES: IconName = IconName("trees") TREKKING: IconName = IconName("trekking") TRENDING_DOWN: IconName = IconName("trending-down") TRENDING_DOWN_2: IconName = IconName("trending-down-2") TRENDING_DOWN_3: IconName = IconName("trending-down-3") TRENDING_UP: IconName = IconName("trending-up") TRENDING_UP_2: IconName = IconName("trending-up-2") TRENDING_UP_3: IconName = IconName("trending-up-3") TRIANGLE: IconName = IconName("triangle") TRIANGLE_FILLED: IconName = IconName("triangle-filled") TRIANGLE_INVERTED: IconName = IconName("triangle-inverted") TRIANGLE_INVERTED_FILLED: IconName = IconName("triangle-inverted-filled") TRIANGLE_OFF: IconName = 
IconName("triangle-off") TRIANGLE_SQUARE_CIRCLE: IconName = IconName("triangle-square-circle") TRIANGLES: IconName = IconName("triangles") TRIDENT: IconName = IconName("trident") TROLLEY: IconName = IconName("trolley") TROPHY: IconName = IconName("trophy") TROPHY_FILLED: IconName = IconName("trophy-filled") TROPHY_OFF: IconName = IconName("trophy-off") TROWEL: IconName = IconName("trowel") TRUCK: IconName = IconName("truck") TRUCK_DELIVERY: IconName = IconName("truck-delivery") TRUCK_LOADING: IconName = IconName("truck-loading") TRUCK_OFF: IconName = IconName("truck-off") TRUCK_RETURN: IconName = IconName("truck-return") TXT: IconName = IconName("txt") TYPOGRAPHY: IconName = IconName("typography") TYPOGRAPHY_OFF: IconName = IconName("typography-off") UFO: IconName = IconName("ufo") UFO_OFF: IconName = IconName("ufo-off") UMBRELLA: IconName = IconName("umbrella") UMBRELLA_FILLED: IconName = IconName("umbrella-filled") UMBRELLA_OFF: IconName = IconName("umbrella-off") UNDERLINE: IconName = IconName("underline") UNLINK: IconName = IconName("unlink") UPLOAD: IconName = IconName("upload") URGENT: IconName = IconName("urgent") USB: IconName = IconName("usb") USER: IconName = IconName("user") USER_BOLT: IconName = IconName("user-bolt") USER_CANCEL: IconName = IconName("user-cancel") USER_CHECK: IconName = IconName("user-check") USER_CIRCLE: IconName = IconName("user-circle") USER_CODE: IconName = IconName("user-code") USER_COG: IconName = IconName("user-cog") USER_DOLLAR: IconName = IconName("user-dollar") USER_DOWN: IconName = IconName("user-down") USER_EDIT: IconName = IconName("user-edit") USER_EXCLAMATION: IconName = IconName("user-exclamation") USER_HEART: IconName = IconName("user-heart") USER_MINUS: IconName = IconName("user-minus") USER_OFF: IconName = IconName("user-off") USER_PAUSE: IconName = IconName("user-pause") USER_PIN: IconName = IconName("user-pin") USER_PLUS: IconName = IconName("user-plus") USER_QUESTION: IconName = IconName("user-question") USER_SEARCH: IconName = IconName("user-search") USER_SHARE: IconName = IconName("user-share") USER_SHIELD: IconName = IconName("user-shield") USER_STAR: IconName = IconName("user-star") USER_UP: IconName = IconName("user-up") USER_X: IconName = IconName("user-x") USERS: IconName = IconName("users") USERS_GROUP: IconName = IconName("users-group") USERS_MINUS: IconName = IconName("users-minus") USERS_PLUS: IconName = IconName("users-plus") UV_INDEX: IconName = IconName("uv-index") UX_CIRCLE: IconName = IconName("ux-circle") VACCINE: IconName = IconName("vaccine") VACCINE_BOTTLE: IconName = IconName("vaccine-bottle") VACCINE_BOTTLE_OFF: IconName = IconName("vaccine-bottle-off") VACCINE_OFF: IconName = IconName("vaccine-off") VACUUM_CLEANER: IconName = IconName("vacuum-cleaner") VARIABLE: IconName = IconName("variable") VARIABLE_MINUS: IconName = IconName("variable-minus") VARIABLE_OFF: IconName = IconName("variable-off") VARIABLE_PLUS: IconName = IconName("variable-plus") VECTOR: IconName = IconName("vector") VECTOR_BEZIER: IconName = IconName("vector-bezier") VECTOR_BEZIER_2: IconName = IconName("vector-bezier-2") VECTOR_BEZIER_ARC: IconName = IconName("vector-bezier-arc") VECTOR_BEZIER_CIRCLE: IconName = IconName("vector-bezier-circle") VECTOR_OFF: IconName = IconName("vector-off") VECTOR_SPLINE: IconName = IconName("vector-spline") VECTOR_TRIANGLE: IconName = IconName("vector-triangle") VECTOR_TRIANGLE_OFF: IconName = IconName("vector-triangle-off") VENUS: IconName = IconName("venus") VERSIONS: IconName = IconName("versions") 
VERSIONS_FILLED: IconName = IconName("versions-filled") VERSIONS_OFF: IconName = IconName("versions-off") VIDEO: IconName = IconName("video") VIDEO_MINUS: IconName = IconName("video-minus") VIDEO_OFF: IconName = IconName("video-off") VIDEO_PLUS: IconName = IconName("video-plus") VIEW_360: IconName = IconName("view-360") VIEW_360_OFF: IconName = IconName("view-360-off") VIEWFINDER: IconName = IconName("viewfinder") VIEWFINDER_OFF: IconName = IconName("viewfinder-off") VIEWPORT_NARROW: IconName = IconName("viewport-narrow") VIEWPORT_WIDE: IconName = IconName("viewport-wide") VINYL: IconName = IconName("vinyl") VIP: IconName = IconName("vip") VIP_OFF: IconName = IconName("vip-off") VIRUS: IconName = IconName("virus") VIRUS_OFF: IconName = IconName("virus-off") VIRUS_SEARCH: IconName = IconName("virus-search") VOCABULARY: IconName = IconName("vocabulary") VOCABULARY_OFF: IconName = IconName("vocabulary-off") VOLCANO: IconName = IconName("volcano") VOLUME: IconName = IconName("volume") VOLUME_2: IconName = IconName("volume-2") VOLUME_3: IconName = IconName("volume-3") VOLUME_OFF: IconName = IconName("volume-off") WALK: IconName = IconName("walk") WALL: IconName = IconName("wall") WALL_OFF: IconName = IconName("wall-off") WALLET: IconName = IconName("wallet") WALLET_OFF: IconName = IconName("wallet-off") WALLPAPER: IconName = IconName("wallpaper") WALLPAPER_OFF: IconName = IconName("wallpaper-off") WAND: IconName = IconName("wand") WAND_OFF: IconName = IconName("wand-off") WASH: IconName = IconName("wash") WASH_DRY: IconName = IconName("wash-dry") WASH_DRY_1: IconName = IconName("wash-dry-1") WASH_DRY_2: IconName = IconName("wash-dry-2") WASH_DRY_3: IconName = IconName("wash-dry-3") WASH_DRY_A: IconName = IconName("wash-dry-a") WASH_DRY_DIP: IconName = IconName("wash-dry-dip") WASH_DRY_F: IconName = IconName("wash-dry-f") WASH_DRY_FLAT: IconName = IconName("wash-dry-flat") WASH_DRY_HANG: IconName = IconName("wash-dry-hang") WASH_DRY_OFF: IconName = IconName("wash-dry-off") WASH_DRY_P: IconName = IconName("wash-dry-p") WASH_DRY_SHADE: IconName = IconName("wash-dry-shade") WASH_DRY_W: IconName = IconName("wash-dry-w") WASH_DRYCLEAN: IconName = IconName("wash-dryclean") WASH_DRYCLEAN_OFF: IconName = IconName("wash-dryclean-off") WASH_ECO: IconName = IconName("wash-eco") WASH_GENTLE: IconName = IconName("wash-gentle") WASH_HAND: IconName = IconName("wash-hand") WASH_MACHINE: IconName = IconName("wash-machine") WASH_OFF: IconName = IconName("wash-off") WASH_PRESS: IconName = IconName("wash-press") WASH_TEMPERATURE_1: IconName = IconName("wash-temperature-1") WASH_TEMPERATURE_2: IconName = IconName("wash-temperature-2") WASH_TEMPERATURE_3: IconName = IconName("wash-temperature-3") WASH_TEMPERATURE_4: IconName = IconName("wash-temperature-4") WASH_TEMPERATURE_5: IconName = IconName("wash-temperature-5") WASH_TEMPERATURE_6: IconName = IconName("wash-temperature-6") WASH_TUMBLE_DRY: IconName = IconName("wash-tumble-dry") WASH_TUMBLE_OFF: IconName = IconName("wash-tumble-off") WATERPOLO: IconName = IconName("waterpolo") WAVE_SAW_TOOL: IconName = IconName("wave-saw-tool") WAVE_SINE: IconName = IconName("wave-sine") WAVE_SQUARE: IconName = IconName("wave-square") WEBHOOK: IconName = IconName("webhook") WEBHOOK_OFF: IconName = IconName("webhook-off") WEIGHT: IconName = IconName("weight") WHEELCHAIR: IconName = IconName("wheelchair") WHEELCHAIR_OFF: IconName = IconName("wheelchair-off") WHIRL: IconName = IconName("whirl") WIFI: IconName = IconName("wifi") WIFI_0: IconName = IconName("wifi-0") WIFI_1: IconName 
= IconName("wifi-1") WIFI_2: IconName = IconName("wifi-2") WIFI_OFF: IconName = IconName("wifi-off") WIND: IconName = IconName("wind") WIND_OFF: IconName = IconName("wind-off") WINDMILL: IconName = IconName("windmill") WINDMILL_FILLED: IconName = IconName("windmill-filled") WINDMILL_OFF: IconName = IconName("windmill-off") WINDOW: IconName = IconName("window") WINDOW_MAXIMIZE: IconName = IconName("window-maximize") WINDOW_MINIMIZE: IconName = IconName("window-minimize") WINDOW_OFF: IconName = IconName("window-off") WINDSOCK: IconName = IconName("windsock") WIPER: IconName = IconName("wiper") WIPER_WASH: IconName = IconName("wiper-wash") WOMAN: IconName = IconName("woman") WOOD: IconName = IconName("wood") WORLD: IconName = IconName("world") WORLD_BOLT: IconName = IconName("world-bolt") WORLD_CANCEL: IconName = IconName("world-cancel") WORLD_CHECK: IconName = IconName("world-check") WORLD_CODE: IconName = IconName("world-code") WORLD_COG: IconName = IconName("world-cog") WORLD_DOLLAR: IconName = IconName("world-dollar") WORLD_DOWN: IconName = IconName("world-down") WORLD_DOWNLOAD: IconName = IconName("world-download") WORLD_EXCLAMATION: IconName = IconName("world-exclamation") WORLD_HEART: IconName = IconName("world-heart") WORLD_LATITUDE: IconName = IconName("world-latitude") WORLD_LONGITUDE: IconName = IconName("world-longitude") WORLD_MINUS: IconName = IconName("world-minus") WORLD_OFF: IconName = IconName("world-off") WORLD_PAUSE: IconName = IconName("world-pause") WORLD_PIN: IconName = IconName("world-pin") WORLD_PLUS: IconName = IconName("world-plus") WORLD_QUESTION: IconName = IconName("world-question") WORLD_SEARCH: IconName = IconName("world-search") WORLD_SHARE: IconName = IconName("world-share") WORLD_STAR: IconName = IconName("world-star") WORLD_UP: IconName = IconName("world-up") WORLD_UPLOAD: IconName = IconName("world-upload") WORLD_WWW: IconName = IconName("world-www") WORLD_X: IconName = IconName("world-x") WRECKING_BALL: IconName = IconName("wrecking-ball") WRITING: IconName = IconName("writing") WRITING_OFF: IconName = IconName("writing-off") WRITING_SIGN: IconName = IconName("writing-sign") WRITING_SIGN_OFF: IconName = IconName("writing-sign-off") X: IconName = IconName("x") XBOX_A: IconName = IconName("xbox-a") XBOX_B: IconName = IconName("xbox-b") XBOX_X: IconName = IconName("xbox-x") XBOX_Y: IconName = IconName("xbox-y") XD: IconName = IconName("xd") YIN_YANG: IconName = IconName("yin-yang") YIN_YANG_FILLED: IconName = IconName("yin-yang-filled") YOGA: IconName = IconName("yoga") ZEPPELIN: IconName = IconName("zeppelin") ZEPPELIN_OFF: IconName = IconName("zeppelin-off") ZIP: IconName = IconName("zip") ZODIAC_AQUARIUS: IconName = IconName("zodiac-aquarius") ZODIAC_ARIES: IconName = IconName("zodiac-aries") ZODIAC_CANCER: IconName = IconName("zodiac-cancer") ZODIAC_CAPRICORN: IconName = IconName("zodiac-capricorn") ZODIAC_GEMINI: IconName = IconName("zodiac-gemini") ZODIAC_LEO: IconName = IconName("zodiac-leo") ZODIAC_LIBRA: IconName = IconName("zodiac-libra") ZODIAC_PISCES: IconName = IconName("zodiac-pisces") ZODIAC_SAGITTARIUS: IconName = IconName("zodiac-sagittarius") ZODIAC_SCORPIO: IconName = IconName("zodiac-scorpio") ZODIAC_TAURUS: IconName = IconName("zodiac-taurus") ZODIAC_VIRGO: IconName = IconName("zodiac-virgo") ZOOM_CANCEL: IconName = IconName("zoom-cancel") ZOOM_CHECK: IconName = IconName("zoom-check") ZOOM_CHECK_FILLED: IconName = IconName("zoom-check-filled") ZOOM_CODE: IconName = IconName("zoom-code") ZOOM_EXCLAMATION: IconName = 
IconName("zoom-exclamation") ZOOM_FILLED: IconName = IconName("zoom-filled") ZOOM_IN: IconName = IconName("zoom-in") ZOOM_IN_AREA: IconName = IconName("zoom-in-area") ZOOM_IN_AREA_FILLED: IconName = IconName("zoom-in-area-filled") ZOOM_IN_FILLED: IconName = IconName("zoom-in-filled") ZOOM_MONEY: IconName = IconName("zoom-money") ZOOM_OUT: IconName = IconName("zoom-out") ZOOM_OUT_AREA: IconName = IconName("zoom-out-area") ZOOM_OUT_FILLED: IconName = IconName("zoom-out-filled") ZOOM_PAN: IconName = IconName("zoom-pan") ZOOM_QUESTION: IconName = IconName("zoom-question") ZOOM_REPLACE: IconName = IconName("zoom-replace") ZOOM_RESET: IconName = IconName("zoom-reset") ZZZ: IconName = IconName("zzz") ZZZ_OFF: IconName = IconName("zzz-off")

================================================
FILE: viser/src/viser/_icons_generate_enum.py
================================================
"""Helper script for dumping Tabler icon names into a Literal type."""

import tarfile
from pathlib import Path

HERE_DIR = Path(__file__).absolute().parent
ICON_DIR = HERE_DIR / "_icons"


def enum_name_from_icon(name: str) -> str:
    """Capitalize an icon name for use as an enum name."""
    name = name.upper()
    name = name.replace("-", "_")
    if name[0].isdigit():
        name = "ICON_" + name
    return name


if __name__ == "__main__":
    with tarfile.open(ICON_DIR / "tabler-icons.tar") as tar:
        icon_names = sorted([name.partition(".svg")[0] for name in tar.getnames()])

    # Generate stub file. This is used by type checkers.
    (HERE_DIR / "_icons_enum.pyi").write_text(
        "\n".join(
            [
                "# Automatically generated by `_icons_generate_enum.py`",
                "# See https://tabler-icons.io/",
                "import enum",
                "from typing import NewType",
                "",
                "IconName = NewType('IconName', str)",
                '"""Name of an icon. Should be generated via `viser.Icon.*`."""',
                "",
                "class Icon:",
                '    """\'Enum\' class for referencing Tabler icons.',
                "",
                "    We don't subclass enum.Enum for performance reasons -- importing an enum with",
                "    thousands of names can result in import times in the hundreds of milliseconds.",
                '    """',
                "",
            ]
            + [
                # Prefix all icon names with ICON_, since some of them start with
                # numbers and can't directly be used as Python names.
                f"    {enum_name_from_icon(icon)}: IconName = IconName('{icon}')"
                for icon in icon_names
            ]
        )
    )

    # Generate source. This is used at runtime + by Sphinx for documentation.
    (HERE_DIR / "_icons_enum.py").write_text(
        "\n".join(
            [
                "# Automatically generated by `_icons_generate_enum.py`",
                "# See https://tabler-icons.io/",
                "from typing import NewType",
                "",
                "IconName = NewType('IconName', str)",
                '"""Name of an icon. Should be generated via `viser.Icon.*`."""',
                "",
                "",
                "class _IconStringConverter(type):",
                "    def __getattr__(self, __name: str) -> IconName:",
                '        if not __name.startswith("_"):',
                '            return IconName(__name.lower().replace("_", "-"))',
                "        else:",
                "            raise AttributeError()",
                "",
                "",
                "class Icon(metaclass=_IconStringConverter):",
                '    """\'Enum\' class for referencing Tabler icons.',
                "",
                "    We don't subclass enum.Enum for performance reasons -- importing an enum with",
                "    thousands of names can result in import times in the hundreds of milliseconds.",
                "",
                "    Attributes:",
            ]
            + [
                # Prefix all icon names with ICON_, since some of them start with
                # numbers and can't directly be used as Python names.
                f"        {enum_name_from_icon(icon)} (IconName): The :code:`{icon}` icon."
                for icon in icon_names
            ]
            + ['    """']
        )
    )


================================================
FILE: viser/src/viser/_messages.py
================================================
"""Message type definitions.
For synchronization with the TypeScript definitions, see `_typescript_interface_gen.py.`""" from __future__ import annotations import dataclasses import uuid from typing import ( Any, Callable, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union, ) import numpy as onp import numpy.typing as onpt from typing_extensions import Annotated, Literal, NotRequired, TypedDict, override from . import infra, theme GuiSliderMark = TypedDict("GuiSliderMark", {"value": float, "label": NotRequired[str]}) Color = Literal[ "dark", "gray", "red", "pink", "grape", "violet", "indigo", "blue", "cyan", "green", "lime", "yellow", "orange", "teal", ] class Message(infra.Message): _tags: ClassVar[Tuple[str, ...]] = tuple() @override def redundancy_key(self) -> str: """Returns a unique key for this message, used for detecting redundant messages. For example: if we send 1000 GuiSetValue messages for the same GUI element, we should only keep the latest messages. """ parts = [type(self).__name__] # Scene node manipulation messages all have a "name" field. node_name = getattr(self, "name", None) if node_name is not None: parts.append(node_name) # GUI and notification messages all have an "id" field. node_name = getattr(self, "id", None) if node_name is not None: parts.append(node_name) return "_".join(parts) T = TypeVar("T", bound=Type[Message]) def tag_class(tag: str) -> Callable[[T], T]: """Decorator for tagging a class with a `type` field.""" def wrapper(cls: T) -> T: cls._tags = (cls._tags or ()) + (tag,) return cls return wrapper @dataclasses.dataclass class RunJavascriptMessage(Message): """Message for running some arbitrary Javascript on the client. We use this to set up the Plotly.js package, via the plotly.min.js source code.""" source: str @override def redundancy_key(self) -> str: # Never cull these messages. return str(uuid.uuid4()) @dataclasses.dataclass class NotificationMessage(Message): """Notification message.""" mode: Literal["show", "update"] id: str title: str body: str loading: bool with_close_button: bool auto_close: Union[int, Literal[False]] color: Optional[Color] @dataclasses.dataclass class RemoveNotificationMessage(Message): """Remove a specific notification.""" id: str @dataclasses.dataclass class ViewerCameraMessage(Message): """Message for a posed viewer camera. Pose is in the form T_world_camera, OpenCV convention, +Z forward.""" wxyz: Tuple[float, float, float, float] position: Tuple[float, float, float] fov: float aspect: float look_at: Tuple[float, float, float] up_direction: Tuple[float, float, float] # The list of scene pointer events supported by the viser frontend. ScenePointerEventType = Literal["click", "rect-select"] @dataclasses.dataclass class ScenePointerMessage(Message): """Message for a raycast-like pointer in the scene. origin is the viewing camera position, in world coordinates. direction is the vector if a ray is projected from the camera through the clicked pixel, """ # Later we can add `double_click`, `move`, `down`, `up`, etc. event_type: ScenePointerEventType ray_origin: Optional[Tuple[float, float, float]] ray_direction: Optional[Tuple[float, float, float]] screen_pos: Tuple[Tuple[float, float], ...] 
@dataclasses.dataclass class ScenePointerEnableMessage(Message): """Message to enable/disable scene click events.""" enable: bool event_type: ScenePointerEventType @override def redundancy_key(self) -> str: return ( type(self).__name__ + "-" + self.event_type + "-" + str(self.enable).lower() ) @dataclasses.dataclass class CameraFrustumMessage(Message): """Variant of CameraMessage used for visualizing camera frustums. OpenCV convention, +Z forward.""" name: str fov: float aspect: float scale: float color: int thickness: float image_media_type: Optional[Literal["image/jpeg", "image/png"]] image_binary: Optional[bytes] @dataclasses.dataclass class GlbMessage(Message): """GlTF message.""" name: str glb_data: bytes scale: float @dataclasses.dataclass class FrameMessage(Message): """Coordinate frame message.""" name: str show_axes: bool axes_length: float axes_radius: float origin_radius: float @dataclasses.dataclass class BatchedAxesMessage(Message): """Batched axes message. Positions and orientations should follow a `T_parent_local` convention, which corresponds to the R matrix and t vector in `p_parent = [R | t] p_local`.""" name: str wxyzs_batched: onpt.NDArray[onp.float32] positions_batched: onpt.NDArray[onp.float32] axes_length: float axes_radius: float @dataclasses.dataclass class GridMessage(Message): """Grid message. Helpful for visualizing things like ground planes.""" name: str width: float height: float width_segments: int height_segments: int plane: Literal["xz", "xy", "yx", "yz", "zx", "zy"] cell_color: int cell_thickness: float cell_size: float section_color: int section_thickness: float section_size: float @dataclasses.dataclass class LabelMessage(Message): """Add a 2D label to the scene.""" name: str text: str @dataclasses.dataclass class Gui3DMessage(Message): """Add a 3D gui element to the scene.""" order: float name: str container_id: str @dataclasses.dataclass class PointCloudMessage(Message): """Point cloud message. Positions are internally canonicalized to float32, colors to uint8. Float color inputs should be in the range [0,1], int color inputs should be in the range [0,255].""" name: str points: onpt.NDArray[onp.float32] colors: onpt.NDArray[onp.uint8] point_size: float point_ball_norm: float def __post_init__(self): # Check shapes. assert self.points.shape == self.colors.shape assert self.points.shape[-1] == 3 # Check dtypes. assert self.points.dtype == onp.float32 assert self.colors.dtype == onp.uint8 @dataclasses.dataclass class MeshBoneMessage(Message): """Message for a bone of a skinned mesh.""" name: str @dataclasses.dataclass class MeshMessage(Message): """Mesh message. Vertices are internally canonicalized to float32, faces to uint32.""" name: str vertices: onpt.NDArray[onp.float32] faces: onpt.NDArray[onp.uint32] color: Optional[int] vertex_colors: Optional[onpt.NDArray[onp.uint8]] wireframe: bool opacity: Optional[float] flat_shading: bool side: Literal["front", "back", "double"] material: Literal["standard", "toon3", "toon5"] def __post_init__(self): # Check shapes. assert self.vertices.shape[-1] == 3 assert self.faces.shape[-1] == 3 @dataclasses.dataclass class SkinnedMeshMessage(MeshMessage): """Mesh message. Vertices are internally canonicalized to float32, faces to uint32.""" bone_wxyzs: Tuple[Tuple[float, float, float, float], ...] bone_positions: Tuple[Tuple[float, float, float], ...] skin_indices: onpt.NDArray[onp.uint16] skin_weights: onpt.NDArray[onp.float32] def __post_init__(self): # Check shapes. 
assert self.vertices.shape[-1] == 3 assert self.faces.shape[-1] == 3 assert self.skin_weights is not None assert ( self.skin_indices.shape == self.skin_weights.shape == (self.vertices.shape[0], 4) ) @dataclasses.dataclass class SetBoneOrientationMessage(Message): """Server -> client message to set a skinned mesh bone's orientation. As with all other messages, transforms take the `T_parent_local` convention.""" name: str bone_index: int wxyz: Tuple[float, float, float, float] @override def redundancy_key(self) -> str: return type(self).__name__ + "-" + self.name + "-" + str(self.bone_index) @dataclasses.dataclass class SetBonePositionMessage(Message): """Server -> client message to set a skinned mesh bone's position. As with all other messages, transforms take the `T_parent_local` convention.""" name: str bone_index: int position: Tuple[float, float, float] @override def redundancy_key(self) -> str: return type(self).__name__ + "-" + self.name + "-" + str(self.bone_index) @dataclasses.dataclass class TransformControlsMessage(Message): """Message for transform gizmos.""" name: str scale: float line_width: float fixed: bool auto_transform: bool active_axes: Tuple[bool, bool, bool] disable_axes: bool disable_sliders: bool disable_rotations: bool translation_limits: Tuple[ Tuple[float, float], Tuple[float, float], Tuple[float, float] ] rotation_limits: Tuple[ Tuple[float, float], Tuple[float, float], Tuple[float, float] ] depth_test: bool opacity: float @dataclasses.dataclass class SetCameraPositionMessage(Message): """Server -> client message to set the camera's position.""" position: Tuple[float, float, float] @dataclasses.dataclass class SetCameraUpDirectionMessage(Message): """Server -> client message to set the camera's up direction.""" position: Tuple[float, float, float] @dataclasses.dataclass class SetCameraLookAtMessage(Message): """Server -> client message to set the camera's look-at point.""" look_at: Tuple[float, float, float] @dataclasses.dataclass class SetCameraFovMessage(Message): """Server -> client message to set the camera's field of view.""" fov: float @dataclasses.dataclass class SetOrientationMessage(Message): """Server -> client message to set a scene node's orientation. As with all other messages, transforms take the `T_parent_local` convention.""" name: str wxyz: Tuple[float, float, float, float] @dataclasses.dataclass class SetPositionMessage(Message): """Server -> client message to set a scene node's position. As with all other messages, transforms take the `T_parent_local` convention.""" name: str position: Tuple[float, float, float] @dataclasses.dataclass class TransformControlsUpdateMessage(Message): """Client -> server message when a transform control is updated. 
As with all other messages, transforms take the `T_parent_local` convention.""" name: str wxyz: Tuple[float, float, float, float] position: Tuple[float, float, float] @dataclasses.dataclass class BackgroundImageMessage(Message): """Message for rendering a background image.""" media_type: Literal["image/jpeg", "image/png"] rgb_bytes: bytes depth_bytes: Optional[bytes] @dataclasses.dataclass class ImageMessage(Message): """Message for rendering 2D images.""" name: str media_type: Literal["image/jpeg", "image/png"] data: bytes render_width: float render_height: float @dataclasses.dataclass class RemoveSceneNodeMessage(Message): """Remove a particular node from the scene.""" name: str @dataclasses.dataclass class SetSceneNodeVisibilityMessage(Message): """Set the visibility of a particular node in the scene.""" name: str visible: bool @dataclasses.dataclass class SetSceneNodeClickableMessage(Message): """Set the clickability of a particular node in the scene.""" name: str clickable: bool @dataclasses.dataclass class SceneNodeClickMessage(Message): """Message for clicked objects.""" name: str instance_index: Optional[int] """Instance index. Currently only used for batched axes.""" ray_origin: Tuple[float, float, float] ray_direction: Tuple[float, float, float] screen_pos: Tuple[float, float] @dataclasses.dataclass class ResetSceneMessage(Message): """Reset scene.""" @dataclasses.dataclass class ResetGuiMessage(Message): """Reset GUI.""" @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddFolderMessage(Message): order: float id: str label: str container_id: str expand_by_default: bool visible: bool @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddMarkdownMessage(Message): order: float id: str markdown: str container_id: str visible: bool @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddProgressBarMessage(Message): order: float id: str value: float animated: bool color: Optional[Color] container_id: str visible: bool @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddPlotlyMessage(Message): order: float id: str plotly_json_str: str aspect: float container_id: str visible: bool @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddTabGroupMessage(Message): order: float id: str container_id: str tab_labels: Tuple[str, ...] tab_icons_html: Tuple[Union[str, None], ...] tab_container_ids: Tuple[str, ...] visible: bool @dataclasses.dataclass class _GuiAddInputBase(Message): """Base message type containing fields commonly used by GUI inputs.""" order: float id: str label: str container_id: str hint: Optional[str] value: Any visible: bool disabled: bool @dataclasses.dataclass class GuiModalMessage(Message): order: float id: str title: str @dataclasses.dataclass class GuiCloseModalMessage(Message): id: str @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddButtonMessage(_GuiAddInputBase): # All GUI elements currently need an `value` field. # This makes our job on the frontend easier. 
value: bool color: Optional[Color] icon_html: Optional[str] @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddUploadButtonMessage(_GuiAddInputBase): color: Optional[Color] icon_html: Optional[str] mime_type: str @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddSliderMessage(_GuiAddInputBase): min: float max: float step: Optional[float] value: float precision: int marks: Optional[Tuple[GuiSliderMark, ...]] = None @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddMultiSliderMessage(_GuiAddInputBase): min: float max: float step: Optional[float] min_range: Optional[float] precision: int fixed_endpoints: bool = False marks: Optional[Tuple[GuiSliderMark, ...]] = None @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddNumberMessage(_GuiAddInputBase): value: float precision: int step: float min: Optional[float] max: Optional[float] @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddRgbMessage(_GuiAddInputBase): value: Tuple[int, int, int] @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddRgbaMessage(_GuiAddInputBase): value: Tuple[int, int, int, int] @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddCheckboxMessage(_GuiAddInputBase): value: bool @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddVector2Message(_GuiAddInputBase): value: Tuple[float, float] min: Optional[Tuple[float, float]] max: Optional[Tuple[float, float]] step: float precision: int @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddVector3Message(_GuiAddInputBase): value: Tuple[float, float, float] min: Optional[Tuple[float, float, float]] max: Optional[Tuple[float, float, float]] step: float precision: int @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddTextMessage(_GuiAddInputBase): value: str @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddDropdownMessage(_GuiAddInputBase): value: str options: Tuple[str, ...] @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddButtonGroupMessage(_GuiAddInputBase): value: str options: Tuple[str, ...] @dataclasses.dataclass class GuiRemoveMessage(Message): """Sent server->client to remove a GUI element.""" id: str @dataclasses.dataclass class GuiUpdateMessage(Message): """Sent client<->server when any property of a GUI component is changed.""" id: str updates: Annotated[ Dict[str, Any], infra.TypeScriptAnnotationOverride("Partial"), ] """Mapping from property name to new value.""" @override def redundancy_key(self) -> str: return ( type(self).__name__ + "-" + self.id + "-" + ",".join(list(self.updates.keys())) ) @dataclasses.dataclass class ThemeConfigurationMessage(Message): """Message from server->client to configure parts of the GUI.""" titlebar_content: Optional[theme.TitlebarConfig] control_layout: Literal["floating", "collapsible", "fixed"] control_width: Literal["small", "medium", "large"] show_logo: bool show_share_button: bool dark_mode: bool colors: Optional[Tuple[str, str, str, str, str, str, str, str, str, str]] @dataclasses.dataclass class CatmullRomSplineMessage(Message): """Message from server->client carrying Catmull-Rom spline information.""" name: str positions: Tuple[Tuple[float, float, float], ...] 
curve_type: Literal["centripetal", "chordal", "catmullrom"] tension: float closed: bool line_width: float color: int segments: Optional[int] @dataclasses.dataclass class CubicBezierSplineMessage(Message): """Message from server->client carrying Cubic Bezier spline information.""" name: str positions: Tuple[Tuple[float, float, float], ...] control_points: Tuple[Tuple[float, float, float], ...] line_width: float color: int segments: Optional[int] @dataclasses.dataclass class GaussianSplatsMessage(Message): """Message from server->client carrying splattable Gaussians.""" name: str # Memory layout is borrowed from: # https://github.com/antimatter15/splat buffer: onpt.NDArray[onp.uint32] """Our buffer will contain: - x as f32 - y as f32 - z as f32 - (unused) - cov1 (f16), cov2 (f16) - cov3 (f16), cov4 (f16) - cov5 (f16), cov6 (f16) - rgba (int32) Where cov1-6 are the upper triangular elements of the covariance matrix.""" @dataclasses.dataclass class GetRenderRequestMessage(Message): """Message from server->client requesting a render of the current viewport.""" format: Literal["image/jpeg", "image/png"] height: int width: int quality: int @dataclasses.dataclass class GetRenderResponseMessage(Message): """Message from client->server carrying a render.""" payload: bytes @dataclasses.dataclass class FileTransferStart(Message): """Signal that a file is about to be sent.""" source_component_id: Optional[str] """Origin GUI component, used for client->server file uploads.""" transfer_uuid: str filename: str mime_type: str part_count: int size_bytes: int @override def redundancy_key(self) -> str: return type(self).__name__ + "-" + self.transfer_uuid @dataclasses.dataclass class FileTransferPart(Message): """Send a file for clients to download or upload files from client.""" # TODO: it would make sense to rename all "id" instances to "uuid" for GUI component ids. 
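# Taken together with FileTransferStart above and FileTransferPartAck below, a
# transfer appears to proceed as: one FileTransferStart carrying metadata and
# part_count, then FileTransferPart chunks sharing the same transfer_uuid, with
# FileTransferPartAck messages reporting cumulative transferred_bytes back to
# the sender.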
source_component_id: Optional[str] transfer_uuid: str part: int content: bytes @override def redundancy_key(self) -> str: return type(self).__name__ + "-" + self.transfer_uuid + "-" + str(self.part) @dataclasses.dataclass class FileTransferPartAck(Message): """Send a file for clients to download or upload files from client.""" source_component_id: Optional[str] transfer_uuid: str transferred_bytes: int total_bytes: int @override def redundancy_key(self) -> str: return ( type(self).__name__ + "-" + self.transfer_uuid + "-" + str(self.transferred_bytes) ) @dataclasses.dataclass class ShareUrlRequest(Message): """Message from client->server to connect to the share URL server.""" @dataclasses.dataclass class ShareUrlUpdated(Message): """Message from server->client to indicate that the share URL has been updated.""" share_url: Optional[str] @dataclasses.dataclass class ShareUrlDisconnect(Message): """Message from client->server to disconnect from the share URL server.""" @dataclasses.dataclass class SetGuiPanelLabelMessage(Message): """Message from server->client to set the label of the GUI panel.""" label: Optional[str] ================================================ FILE: viser/src/viser/_notification_handle.py ================================================ from __future__ import annotations import dataclasses from typing import Literal from ._gui_api import Color from ._messages import NotificationMessage, RemoveNotificationMessage from .infra._infra import WebsockClientConnection @dataclasses.dataclass class _NotificationHandleState: websock_interface: WebsockClientConnection id: str title: str body: str loading: bool with_close_button: bool auto_close: int | Literal[False] color: Color | None @dataclasses.dataclass class NotificationHandle: """Handle for a notification in our visualizer.""" _impl: _NotificationHandleState def _sync_with_client(self, first: bool = False) -> None: m = NotificationMessage( "show" if first else "update", self._impl.id, self._impl.title, self._impl.body, self._impl.loading, self._impl.with_close_button, self._impl.auto_close, self._impl.color, ) self._impl.websock_interface.queue_message(m) @property def title(self) -> str: """Title to display on the notification.""" return self._impl.title @title.setter def title(self, title: str) -> None: if title == self._impl.title: return self._impl.title = title self._sync_with_client() @property def body(self) -> str: """Message to display on the notification body.""" return self._impl.body @body.setter def body(self, body: str) -> None: if body == self._impl.body: return self._impl.body = body self._sync_with_client() @property def loading(self) -> bool: """Whether the notification shows loading icon.""" return self._impl.loading @loading.setter def loading(self, loading: bool) -> None: if loading == self._impl.loading: return self._impl.loading = loading self._sync_with_client() @property def with_close_button(self) -> bool: """Whether the notification can be manually closed.""" return self._impl.with_close_button @with_close_button.setter def with_close_button(self, with_close_button: bool) -> None: if with_close_button == self._impl.with_close_button: return self._impl.with_close_button = with_close_button self._sync_with_client() @property def auto_close(self) -> int | Literal[False]: """Time in ms before the notification automatically closes; otherwise False such that the notification never closes on its own.""" return self._impl.auto_close @auto_close.setter def auto_close(self, auto_close: int | 
Literal[False]) -> None: if auto_close == self._impl.auto_close: return self._impl.auto_close = auto_close self._sync_with_client() @property def color(self) -> Color | None: """Color of the notification.""" return self._impl.color @color.setter def color(self, color: Color | None) -> None: if color == self._impl.color: return self._impl.color = color self._sync_with_client() def remove(self) -> None: self._impl.websock_interface.queue_message( RemoveNotificationMessage(self._impl.id) ) ================================================ FILE: viser/src/viser/_scene_api.py ================================================ from __future__ import annotations import io import time import warnings from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Callable, Tuple, TypeVar, Union, cast, get_args import imageio.v3 as iio import numpy as onp import numpy.typing as onpt from typing_extensions import Literal, ParamSpec, TypeAlias, assert_never from . import _messages from . import transforms as tf from ._scene_handles import ( BatchedAxesHandle, BoneState, CameraFrustumHandle, FrameHandle, GaussianSplatHandle, GlbHandle, Gui3dContainerHandle, ImageHandle, LabelHandle, MeshHandle, MeshSkinnedBoneHandle, MeshSkinnedHandle, PointCloudHandle, SceneNodeHandle, SceneNodePointerEvent, ScenePointerEvent, TransformControlsHandle, _SceneNodeHandleState, _TransformControlsState, ) if TYPE_CHECKING: import trimesh from ._viser import ClientHandle, ViserServer from .infra import ClientId P = ParamSpec("P") def _colors_to_uint8(colors: onp.ndarray) -> onpt.NDArray[onp.uint8]: """Convert intensity values to uint8. We assume the range [0,1] for floats, and [0,255] for integers. Accepts any shape.""" if colors.dtype != onp.uint8: if onp.issubdtype(colors.dtype, onp.floating): colors = onp.clip(colors * 255.0, 0, 255).astype(onp.uint8) if onp.issubdtype(colors.dtype, onp.integer): colors = onp.clip(colors, 0, 255).astype(onp.uint8) return colors RgbTupleOrArray: TypeAlias = Union[ Tuple[int, int, int], Tuple[float, float, float], onp.ndarray ] def _encode_rgb(rgb: RgbTupleOrArray) -> int: if isinstance(rgb, onp.ndarray): assert rgb.shape == (3,) rgb_fixed = tuple( value if onp.issubdtype(type(value), onp.integer) else int(value * 255) for value in rgb ) assert len(rgb_fixed) == 3 return int(rgb_fixed[0] * (256**2) + rgb_fixed[1] * 256 + rgb_fixed[2]) def _encode_image_binary( image: onp.ndarray, format: Literal["png", "jpeg"], jpeg_quality: int | None = None, ) -> tuple[Literal["image/png", "image/jpeg"], bytes]: media_type: Literal["image/png", "image/jpeg"] image = _colors_to_uint8(image) with io.BytesIO() as data_buffer: if format == "png": media_type = "image/png" iio.imwrite(data_buffer, image, extension=".png") elif format == "jpeg": media_type = "image/jpeg" iio.imwrite( data_buffer, image[..., :3], # Strip alpha. extension=".jpeg", quality=75 if jpeg_quality is None else jpeg_quality, ) else: assert_never(format) binary = data_buffer.getvalue() return media_type, binary TVector = TypeVar("TVector", bound=tuple) def cast_vector(vector: TVector | onp.ndarray, length: int) -> TVector: if not isinstance(vector, tuple): assert cast(onp.ndarray, vector).shape == ( length, ), f"Expected vector of shape {(length,)}, but got {vector.shape} instead" return cast(TVector, tuple(map(float, vector))) class SceneApi: """Interface for adding 3D primitives to the scene. 
Used by both our global server object, for sharing the same GUI elements with all clients, and by individual client handles.""" def __init__( self, owner: ViserServer | ClientHandle, # Who do I belong to? thread_executor: ThreadPoolExecutor, ) -> None: from ._viser import ViserServer self._owner = owner """Entity that owns this API.""" self._websock_interface = ( owner._websock_server if isinstance(owner, ViserServer) else owner._websock_connection ) """Interface for sending and listening to messages.""" self.world_axes: FrameHandle = FrameHandle( _SceneNodeHandleState( "/WorldAxes", self, wxyz=onp.array([1.0, 0.0, 0.0, 0.0]), position=onp.zeros(3), ) ) """Handle for the world axes, which are created by default.""" # Hide world axes on initialization. if isinstance(owner, ViserServer): self.world_axes.visible = False self._handle_from_transform_controls_name: dict[ str, TransformControlsHandle ] = {} self._handle_from_node_name: dict[str, SceneNodeHandle] = {} self._scene_pointer_cb: Callable[[ScenePointerEvent], None] | None = None self._scene_pointer_done_cb: Callable[[], None] = lambda: None self._scene_pointer_event_type: _messages.ScenePointerEventType | None = None self._websock_interface.register_handler( _messages.TransformControlsUpdateMessage, self._handle_transform_controls_updates, ) self._websock_interface.register_handler( _messages.SceneNodeClickMessage, self._handle_node_click_updates, ) self._websock_interface.register_handler( _messages.ScenePointerMessage, self._handle_scene_pointer_updates, ) self._thread_executor = thread_executor def set_up_direction( self, direction: Literal["+x", "+y", "+z", "-x", "-y", "-z"] | tuple[float, float, float] | onp.ndarray, ) -> None: """Set the global up direction of the scene. By default we follow +Z-up (similar to Blender, 3DS Max, ROS, etc), the most common alternative is +Y (OpenGL, Maya, etc). Args: direction: New up direction. Can either be a string (one of +x, +y, +z, -x, -y, -z) or a length-3 direction vector. """ if isinstance(direction, str): direction = { "+x": (1, 0, 0), "+y": (0, 1, 0), "+z": (0, 0, 1), "-x": (-1, 0, 0), "-y": (0, -1, 0), "-z": (0, 0, -1), }[direction] assert not isinstance(direction, str) default_three_up = onp.array([0.0, 1.0, 0.0]) direction = onp.asarray(direction) def rotate_between(before: onp.ndarray, after: onp.ndarray) -> tf.SO3: assert before.shape == after.shape == (3,) before = before / onp.linalg.norm(before) after = after / onp.linalg.norm(after) angle = onp.arccos(onp.clip(onp.dot(before, after), -1, 1)) axis = onp.cross(before, after) if onp.allclose(axis, onp.zeros(3), rtol=1e-3, atol=1e-5): unit_vector = onp.arange(3) == onp.argmin(onp.abs(before)) axis = onp.cross(before, unit_vector) axis = axis / onp.linalg.norm(axis) return tf.SO3.exp(angle * axis) R_threeworld_world = rotate_between(direction, default_three_up) # Rotate the world frame such that: # If we set +Y to up, +X and +Z should face the camera. # If we set +Z to up, +X and +Y should face the camera. # In App.tsx, the camera is initialized at [-3, 3, -3] in the threejs # coordinate frame. desired_fwd = onp.array([-1.0, 0.0, -1.0]) / onp.sqrt(2.0) current_fwd = R_threeworld_world @ (onp.ones(3) / onp.sqrt(3.0)) current_fwd = current_fwd * onp.array([1.0, 0.0, 1.0]) current_fwd = current_fwd / onp.linalg.norm(current_fwd) R_threeworld_world = ( tf.SO3.from_y_radians( # Rotate around the null space / up direction. 
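# The arctan2 term below recovers the signed angle between current_fwd and
# desired_fwd about the +Y (up) axis: the Y component of their cross product
# over their dot product.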
onp.arctan2( onp.cross(current_fwd, desired_fwd)[1], onp.dot(current_fwd, desired_fwd), ), ) @ R_threeworld_world ) if not onp.any(onp.isnan(R_threeworld_world.wxyz)): # Set the orientation of the root node. self._websock_interface.queue_message( _messages.SetOrientationMessage( "", cast_vector(R_threeworld_world.wxyz, 4) ) ) def set_global_visibility(self, visible: bool) -> None: """Set visibility for all scene nodes. If set to False, all scene nodes will be hidden. This can be useful when we've called :meth:`SceneApi.set_background_image()`, and want to hide everything except for the background. Args: visible: Whether or not all scene nodes should be visible. """ self._websock_interface.queue_message( _messages.SetSceneNodeVisibilityMessage("", visible) ) def add_glb( self, name: str, glb_data: bytes, scale=1.0, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> GlbHandle: """Add a general 3D asset via binary glTF (GLB). For glTF files, it's often simpler to use `trimesh.load()` with `.add_mesh_trimesh()`. This will call `.add_glb()` under the hood. For glTF features not supported by trimesh, glTF to GLB conversion can also be done programatically with libraries like `pygltflib`. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. glb_data: A binary payload. scale: A scale for resizing the GLB asset. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation to parent frame from local frame (t_pl). visible: Whether or not this scene node is initially visible. Returns: Handle for manipulating scene node. """ self._websock_interface.queue_message( _messages.GlbMessage(name, glb_data, scale) ) return GlbHandle._make(self, name, wxyz, position, visible) def add_spline_catmull_rom( self, name: str, positions: tuple[tuple[float, float, float], ...] | onp.ndarray, curve_type: Literal["centripetal", "chordal", "catmullrom"] = "centripetal", tension: float = 0.5, closed: bool = False, line_width: float = 1, color: RgbTupleOrArray = (20, 20, 20), segments: int | None = None, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> SceneNodeHandle: """Add a spline to the scene using Catmull-Rom interpolation. This method creates a spline based on a set of positions and interpolates them using the Catmull-Rom algorithm. This can be used to create smooth curves. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. positions: A tuple of 3D positions (x, y, z) defining the spline's path. curve_type: Type of the curve ('centripetal', 'chordal', 'catmullrom'). tension: Tension of the curve. Affects the tightness of the curve. closed: Boolean indicating if the spline is closed (forms a loop). line_width: Width of the spline line. color: Color of the spline as an RGB tuple. segments: Number of segments to divide the spline into. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation to parent frame from local frame (t_pl). visible: Whether or not this scene node is initially visible. Returns: Handle for manipulating scene node. 
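Example (illustrative; ``server`` here stands for a running :class:`ViserServer`):
    server.scene.add_spline_catmull_rom(
        "/curve",
        positions=((0.0, 0.0, 0.0), (1.0, 0.0, 0.5), (2.0, 0.0, 0.0)),
        tension=0.5,
        line_width=3.0,
        color=(255, 0, 0),
    )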
""" if isinstance(positions, onp.ndarray): assert len(positions.shape) == 2 and positions.shape[1] == 3 positions = tuple(map(tuple, positions)) # type: ignore assert len(positions[0]) == 3 assert isinstance(positions, tuple) self._websock_interface.queue_message( _messages.CatmullRomSplineMessage( name, positions, curve_type, tension, closed, line_width, _encode_rgb(color), segments=segments, ) ) return SceneNodeHandle._make(self, name, wxyz, position, visible) def add_spline_cubic_bezier( self, name: str, positions: tuple[tuple[float, float, float], ...] | onp.ndarray, control_points: tuple[tuple[float, float, float], ...] | onp.ndarray, line_width: float = 1, color: RgbTupleOrArray = (20, 20, 20), segments: int | None = None, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> SceneNodeHandle: """Add a spline to the scene using Cubic Bezier interpolation. This method allows for the creation of a cubic Bezier spline based on given positions and control points. It is useful for creating complex, smooth, curving shapes. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. positions: A tuple of 3D positions (x, y, z) defining the spline's key points. control_points: A tuple of control points for Bezier curve shaping. line_width: Width of the spline line. color: Color of the spline as an RGB tuple. segments: Number of segments to divide the spline into. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation to parent frame from local frame (t_pl). visible: Whether or not this scene node is initially visible. Returns: Handle for manipulating scene node. """ if isinstance(positions, onp.ndarray): assert len(positions.shape) == 2 and positions.shape[1] == 3 positions = tuple(map(tuple, positions)) # type: ignore if isinstance(control_points, onp.ndarray): assert len(control_points.shape) == 2 and control_points.shape[1] == 3 control_points = tuple(map(tuple, control_points)) # type: ignore assert isinstance(positions, tuple) assert isinstance(control_points, tuple) assert len(control_points) == (2 * len(positions) - 2) self._websock_interface.queue_message( _messages.CubicBezierSplineMessage( name, positions, control_points, line_width, _encode_rgb(color), segments=segments, ) ) return SceneNodeHandle._make(self, name, wxyz, position, visible) def add_camera_frustum( self, name: str, fov: float, aspect: float, scale: float = 0.3, color: RgbTupleOrArray = (20, 20, 20), image: onp.ndarray | None = None, format: Literal["png", "jpeg"] = "jpeg", jpeg_quality: int | None = None, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, thickness: float = 1.0, ) -> CameraFrustumHandle: """Add a camera frustum to the scene for visualization. This method adds a frustum representation, typically used to visualize the field of view of a camera. It's helpful for understanding the perspective and coverage of a camera in the 3D space. Like all cameras in the viser Python API, frustums follow the OpenCV [+Z forward, +X right, +Y down] convention. fov is vertical in radians; aspect is width over height Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. fov: Field of view of the camera (in radians). 
aspect: Aspect ratio of the camera (width over height). scale: Scale factor for the size of the frustum. color: Color of the frustum as an RGB tuple. image: Optional image to be displayed on the frustum. format: Format of the provided image ('png' or 'jpeg'). jpeg_quality: Quality of the jpeg image (if jpeg format is used). wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation to parent frame from local frame (t_pl). visible: Whether or not this scene node is initially visible. thickness: Thickness of the frustum lines. Returns: Handle for manipulating scene node. """ if image is not None: media_type, binary = _encode_image_binary( image, format, jpeg_quality=jpeg_quality ) else: media_type = None binary = None self._websock_interface.queue_message( _messages.CameraFrustumMessage( name=name, fov=fov, aspect=aspect, scale=scale, thickness=thickness, # (255, 255, 255) => 0xffffff, etc color=_encode_rgb(color), image_media_type=media_type, image_binary=binary, ) ) return CameraFrustumHandle._make(self, name, wxyz, position, visible) def add_frame( self, name: str, show_axes: bool = True, axes_length: float = 0.5, axes_radius: float = 0.025, origin_radius: float | None = None, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> FrameHandle: """Add a coordinate frame to the scene. This method is used for adding a visual representation of a coordinate frame, which can help in understanding the orientation and position of objects in 3D space. For cases where we want to visualize many coordinate frames, like trajectories containing thousands or tens of thousands of frames, batching and calling :meth:`add_batched_axes()` may be a better choice than calling :meth:`add_frame()` in a loop. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. show_axes: Boolean to indicate whether to show the frame as a set of axes + origin sphere. axes_length: Length of each axis. axes_radius: Radius of each axis. origin_radius: Radius of the origin sphere. If not set, defaults to `2 * axes_radius`. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation to parent frame from local frame (t_pl). visible: Whether or not this scene node is initially visible. Returns: Handle for manipulating scene node. """ if origin_radius is None: origin_radius = axes_radius * 2 self._websock_interface.queue_message( _messages.FrameMessage( name=name, show_axes=show_axes, axes_length=axes_length, axes_radius=axes_radius, origin_radius=origin_radius, ) ) return FrameHandle._make(self, name, wxyz, position, visible) def add_batched_axes( self, name: str, batched_wxyzs: tuple[tuple[float, float, float, float], ...] | onp.ndarray, batched_positions: tuple[tuple[float, float, float], ...] | onp.ndarray, axes_length: float = 0.5, axes_radius: float = 0.025, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> BatchedAxesHandle: """Visualize batched sets of coordinate frame axes. The functionality of :meth:`add_batched_axes()` overlaps significantly with :meth:`add_frame()` when `show_axes=True`. The primary difference is that :meth:`add_batched_axes()` supports multiple axes via the `batched_wxyzs` (shape Nx4) and `batched_positions` (shape Nx3) arguments.
Axes that are batched and rendered via a single call to `add_batched_axes()` are instanced on the client; this will be much faster to render than `add_frame()` called in a loop. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. batched_wxyzs: Float array of shape (N,4). batched_positions: Float array of shape (N,3). axes_length: Length of each axis. axes_radius: Radius of each axis. wxyz: Quaternion rotation to parent frame from local frame (R_pl). This will be applied to all axes. position: Translation to parent frame from local frame (t_pl). This will be applied to all axes. visible: Whether or not this scene node is initially visible. Returns: Handle for manipulating scene node. """ batched_wxyzs = onp.asarray(batched_wxyzs) batched_positions = onp.asarray(batched_positions) num_axes = batched_wxyzs.shape[0] assert batched_wxyzs.shape == (num_axes, 4) assert batched_positions.shape == (num_axes, 3) self._websock_interface.queue_message( _messages.BatchedAxesMessage( name=name, wxyzs_batched=batched_wxyzs.astype(onp.float32), positions_batched=batched_positions.astype(onp.float32), axes_length=axes_length, axes_radius=axes_radius, ) ) return BatchedAxesHandle._make(self, name, wxyz, position, visible) def add_grid( self, name: str, width: float = 10.0, height: float = 10.0, width_segments: int = 10, height_segments: int = 10, plane: Literal["xz", "xy", "yx", "yz", "zx", "zy"] = "xy", cell_color: RgbTupleOrArray = (200, 200, 200), cell_thickness: float = 1.0, cell_size: float = 0.5, section_color: RgbTupleOrArray = (140, 140, 140), section_thickness: float = 1.0, section_size: float = 1.0, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> SceneNodeHandle: """Add a 2D grid to the scene. This can be useful as a size, orientation, or ground plane reference. Args: name: Name of the grid. width: Width of the grid. height: Height of the grid. width_segments: Number of segments along the width. height_segments: Number of segments along the height. plane: The plane in which the grid is oriented (e.g., 'xy', 'yz'). cell_color: Color of the grid cells as an RGB tuple. cell_thickness: Thickness of the grid lines. cell_size: Size of each cell in the grid. section_color: Color of the grid sections as an RGB tuple. section_thickness: Thickness of the section lines. section_size: Size of each section in the grid. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation to parent frame from local frame (t_pl). visible: Whether or not this scene node is initially visible. Returns: Handle for manipulating scene node. """ self._websock_interface.queue_message( _messages.GridMessage( name=name, width=width, height=height, width_segments=width_segments, height_segments=height_segments, plane=plane, cell_color=_encode_rgb(cell_color), cell_thickness=cell_thickness, cell_size=cell_size, section_color=_encode_rgb(section_color), section_thickness=section_thickness, section_size=section_size, ) ) return SceneNodeHandle._make(self, name, wxyz, position, visible) def add_label( self, name: str, text: str, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> LabelHandle: """Add a 2D label to the scene. 
This method creates a text label in the 3D scene, which can be used to annotate or provide information about specific points or objects. Args: name: Name of the label. text: Text content of the label. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation to parent frame from local frame (t_pl). visible: Whether or not this scene node is initially visible. Returns: Handle for manipulating scene node. """ self._websock_interface.queue_message(_messages.LabelMessage(name, text)) return LabelHandle._make(self, name, wxyz, position, visible=visible) def add_point_cloud( self, name: str, points: onp.ndarray, colors: onp.ndarray | tuple[float, float, float], point_size: float = 0.1, point_shape: Literal[ "square", "diamond", "circle", "rounded", "sparkle" ] = "square", wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> PointCloudHandle: """Add a point cloud to the scene. Args: name: Name of scene node. Determines location in kinematic tree. points: Location of points. Should have shape (N, 3). colors: Colors of points. Should have shape (N, 3) or (3,). point_size: Size of each point. point_shape: Shape to draw each point. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation to parent frame from local frame (t_pl). visible: Whether or not this scene node is initially visible. Returns: Handle for manipulating scene node. """ colors_cast = _colors_to_uint8(onp.asarray(colors)) assert ( len(points.shape) == 2 and points.shape[-1] == 3 ), "Shape of points should be (N, 3)." assert colors_cast.shape in { points.shape, (3,), }, "Shape of colors should be (N, 3) or (3,)." if colors_cast.shape == (3,): colors_cast = onp.tile(colors_cast[None, :], reps=(points.shape[0], 1)) self._websock_interface.queue_message( _messages.PointCloudMessage( name=name, points=points.astype(onp.float32), colors=colors_cast, point_size=point_size, point_ball_norm={ "square": float("inf"), "diamond": 1.0, "circle": 2.0, "rounded": 3.0, "sparkle": 0.6, }[point_shape], ) ) return PointCloudHandle._make(self, name, wxyz, position, visible) def add_mesh_skinned( self, name: str, vertices: onp.ndarray, faces: onp.ndarray, bone_wxyzs: tuple[tuple[float, float, float, float], ...] | onp.ndarray, bone_positions: tuple[tuple[float, float, float], ...] | onp.ndarray, skin_weights: onp.ndarray, color: RgbTupleOrArray = (90, 200, 255), wireframe: bool = False, opacity: float | None = None, material: Literal["standard", "toon3", "toon5"] = "standard", flat_shading: bool = False, side: Literal["front", "back", "double"] = "front", wxyz: Tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> MeshSkinnedHandle: """Add a skinned mesh to the scene, which we can deform using a set of bone transformations. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. vertices: A numpy array of vertex positions. Should have shape (V, 3). faces: A numpy array of faces, where each face is represented by indices of vertices. Should have shape (F,) bone_wxyzs: Nested tuple or array of initial bone orientations. bone_positions: Nested tuple or array of initial bone positions. skin_weights: A numpy array of skin weights. Should have shape (V, B) where B is the number of bones. 
Only the top 4 bone weights for each vertex will be used. color: Color of the mesh as an RGB tuple. wireframe: Boolean indicating if the mesh should be rendered as a wireframe. opacity: Opacity of the mesh. None means opaque. material: Material type of the mesh ('standard', 'toon3', 'toon5'). This argument is ignored when wireframe=True. flat_shading: Whether to do flat shading. This argument is ignored when wireframe=True. side: Side of the surface to render ('front', 'back', 'double'). wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation from parent frame to local frame (t_pl). visible: Whether or not this mesh is initially visible. Returns: Handle for manipulating scene node. """ if wireframe and material != "standard": warnings.warn( f"Invalid combination of {wireframe=} and {material=}. Material argument will be ignored.", stacklevel=2, ) if wireframe and flat_shading: warnings.warn( f"Invalid combination of {wireframe=} and {flat_shading=}. Flat shading argument will be ignored.", stacklevel=2, ) num_bones = len(bone_wxyzs) assert skin_weights.shape == (vertices.shape[0], num_bones) # Take the four biggest indices. top4_skin_indices = onp.argsort(skin_weights, axis=-1)[:, -4:] top4_skin_weights = skin_weights[ onp.arange(vertices.shape[0])[:, None], top4_skin_indices ] assert ( top4_skin_weights.shape == top4_skin_indices.shape == (vertices.shape[0], 4) ) bone_wxyzs = onp.asarray(bone_wxyzs) bone_positions = onp.asarray(bone_positions) assert bone_wxyzs.shape == (num_bones, 4) assert bone_positions.shape == (num_bones, 3) self._websock_interface.queue_message( _messages.SkinnedMeshMessage( name, vertices.astype(onp.float32), faces.astype(onp.uint32), # (255, 255, 255) => 0xffffff, etc color=_encode_rgb(color), vertex_colors=None, wireframe=wireframe, opacity=opacity, flat_shading=flat_shading, side=side, material=material, bone_wxyzs=tuple( ( float(wxyz[0]), float(wxyz[1]), float(wxyz[2]), float(wxyz[3]), ) for wxyz in bone_wxyzs.astype(onp.float32) ), bone_positions=tuple( (float(xyz[0]), float(xyz[1]), float(xyz[2])) for xyz in bone_positions.astype(onp.float32) ), skin_indices=top4_skin_indices.astype(onp.uint16), skin_weights=top4_skin_weights.astype(onp.float32), ) ) handle = MeshHandle._make(self, name, wxyz, position, visible) return MeshSkinnedHandle( handle._impl, bones=tuple( MeshSkinnedBoneHandle( _impl=BoneState( name=name, websock_interface=self._websock_interface, bone_index=i, wxyz=bone_wxyzs[i], position=bone_positions[i], ) ) for i in range(num_bones) ), ) def add_mesh_simple( self, name: str, vertices: onp.ndarray, faces: onp.ndarray, color: RgbTupleOrArray = (90, 200, 255), wireframe: bool = False, opacity: float | None = None, material: Literal["standard", "toon3", "toon5"] = "standard", flat_shading: bool = False, side: Literal["front", "back", "double"] = "front", wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> MeshHandle: """Add a mesh to the scene. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. vertices: A numpy array of vertex positions. Should have shape (V, 3). faces: A numpy array of faces, where each face is represented by indices of vertices. Should have shape (F,) color: Color of the mesh as an RGB tuple. wireframe: Boolean indicating if the mesh should be rendered as a wireframe. opacity: Opacity of the mesh. None means opaque. 
material: Material type of the mesh ('standard', 'toon3', 'toon5'). This argument is ignored when wireframe=True. flat_shading: Whether to do flat shading. This argument is ignored when wireframe=True. side: Side of the surface to render ('front', 'back', 'double'). wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation from parent frame to local frame (t_pl). visible: Whether or not this mesh is initially visible. Returns: Handle for manipulating scene node. """ if wireframe and material != "standard": warnings.warn( f"Invalid combination of {wireframe=} and {material=}. Material argument will be ignored.", stacklevel=2, ) if wireframe and flat_shading: warnings.warn( f"Invalid combination of {wireframe=} and {flat_shading=}. Flat shading argument will be ignored.", stacklevel=2, ) self._websock_interface.queue_message( _messages.MeshMessage( name, vertices.astype(onp.float32), faces.astype(onp.uint32), # (255, 255, 255) => 0xffffff, etc color=_encode_rgb(color), vertex_colors=None, wireframe=wireframe, opacity=opacity, flat_shading=flat_shading, side=side, material=material, ) ) return MeshHandle._make(self, name, wxyz, position, visible) def add_mesh_trimesh( self, name: str, mesh: trimesh.Trimesh, scale: float = 1.0, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> GlbHandle: """Add a trimesh mesh to the scene. Internally calls `self.add_glb()`. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. mesh: A trimesh mesh object. scale: A scale for resizing the mesh. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation to parent frame from local frame (t_pl). visible: Whether or not this scene node is initially visible. Returns: Handle for manipulating scene node. """ with io.BytesIO() as data_buffer: mesh.export(data_buffer, file_type="glb") glb_data = data_buffer.getvalue() return self.add_glb( name, glb_data=glb_data, scale=scale, wxyz=wxyz, position=position, visible=visible, ) def _add_gaussian_splats( self, name: str, centers: onp.ndarray, covariances: onp.ndarray, rgbs: onp.ndarray, opacities: onp.ndarray, wxyz: Tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> GaussianSplatHandle: """Add a model to render using Gaussian Splatting. **Work-in-progress.** This feature is experimental and still under development. It may be changed or removed. Arguments: name: Scene node name. centers: Centers of Gaussians. (N, 3). covariances: Second moment for each Gaussian. (N, 3, 3). rgbs: Color for each Gaussian. (N, 3). opacities: Opacity for each Gaussian. (N, 1). wxyz: R_parent_local transformation. position: t_parent_local transformation. visibile: Initial visibility of scene node. Returns: Scene node handle. """ num_gaussians = centers.shape[0] assert centers.shape == (num_gaussians, 3) assert rgbs.shape == (num_gaussians, 3) assert opacities.shape == (num_gaussians, 1) assert covariances.shape == (num_gaussians, 3, 3) # Get cholesky factor of covariance. This helps retain precision when # we convert to float16. 
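# Only the six upper-triangular entries of each 3x3 Cholesky factor are kept
# (flattened indices [0, 1, 2, 4, 5, 8]); below they are cast to float16 and
# packed alongside the float32 centers and uint8 RGBA values into the uint32
# buffer consumed by GaussianSplatsMessage.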
cov_cholesky_triu = ( onp.linalg.cholesky(covariances.astype(onp.float64) + onp.ones(3) * 1e-7) .swapaxes(-1, -2) # tril => triu .reshape((-1, 9))[:, onp.array([0, 1, 2, 4, 5, 8])] ) buffer = onp.concatenate( [ # First texelFetch. # - xyz (96 bits): centers. centers.astype(onp.float32).view(onp.uint8), # - w (32 bits): this is reserved for use by the renderer. onp.zeros((num_gaussians, 4), dtype=onp.uint8), # Second texelFetch. # - xyz (96 bits): upper-triangular Cholesky factor of covariance. cov_cholesky_triu.astype(onp.float16).copy().view(onp.uint8), # - w (32 bits): rgba. _colors_to_uint8(rgbs), _colors_to_uint8(opacities), ], axis=-1, ).view(onp.uint32) assert buffer.shape == (num_gaussians, 8) self._websock_interface.queue_message( _messages.GaussianSplatsMessage( name=name, buffer=buffer, ) ) node_handle = GaussianSplatHandle._make(self, name, wxyz, position, visible) return node_handle def add_box( self, name: str, color: RgbTupleOrArray, dimensions: tuple[float, float, float] | onp.ndarray = (1.0, 1.0, 1.0), wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> MeshHandle: """Add a box to the scene. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. color: Color of the box as an RGB tuple. dimensions: Dimensions of the box (x, y, z). wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation from parent frame to local frame (t_pl). visible: Whether or not this box is initially visible. Returns: Handle for manipulating scene node. """ import trimesh.creation mesh = trimesh.creation.box(dimensions) return self.add_mesh_simple( name=name, vertices=mesh.vertices, faces=mesh.faces, color=color, flat_shading=True, position=position, wxyz=wxyz, visible=visible, ) def add_icosphere( self, name: str, radius: float, color: RgbTupleOrArray, subdivisions: int = 3, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> MeshHandle: """Add an icosphere to the scene. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. radius: Radius of the icosphere. color: Color of the icosphere as an RGB tuple. subdivisions: Number of subdivisions to use when creating the icosphere. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation from parent frame to local frame (t_pl). visible: Whether or not this icosphere is initially visible. Returns: Handle for manipulating scene node. """ import trimesh.creation mesh = trimesh.creation.icosphere(subdivisions=subdivisions, radius=radius) # We use add_mesh_simple() because it lets us do smooth shading; # add_mesh_trimesh() currently does not. return self.add_mesh_simple( name=name, vertices=mesh.vertices, faces=mesh.faces, color=color, flat_shading=False, position=position, wxyz=wxyz, visible=visible, ) def set_background_image( self, image: onp.ndarray, format: Literal["png", "jpeg"] = "jpeg", jpeg_quality: int | None = None, depth: onp.ndarray | None = None, ) -> None: """Set a background image for the scene, optionally with depth compositing. Args: image: The image to set as the background. Should have shape (H, W, 3). format: Format to transport and display the image using ('png' or 'jpeg'). 
jpeg_quality: Quality of the jpeg image (if jpeg format is used). depth: Optional depth image to use to composite background with scene elements. """ media_type, rgb_bytes = _encode_image_binary( image, format, jpeg_quality=jpeg_quality ) # Encode depth if provided. We use a 3-channel PNG to represent a fixed point # depth at each pixel. depth_bytes = None if depth is not None: # Convert to fixed-point. # We'll support from 0 -> (2^24 - 1) / 100_000. # # This translates to a range of [0, 167.77215], with a precision of 1e-5. assert len(depth.shape) == 2 or ( len(depth.shape) == 3 and depth.shape[2] == 1 ), "Depth should have shape (H,W) or (H,W,1)." depth = onp.clip(depth * 100_000, 0, 2**24 - 1).astype(onp.uint32) assert depth is not None # Appease mypy. intdepth: onp.ndarray = depth.reshape((*depth.shape[:2], 1)).view(onp.uint8) assert intdepth.shape == (*depth.shape[:2], 4) with io.BytesIO() as data_buffer: iio.imwrite(data_buffer, intdepth[:, :, :3], extension=".png") depth_bytes = data_buffer.getvalue() self._websock_interface.queue_message( _messages.BackgroundImageMessage( media_type=media_type, rgb_bytes=rgb_bytes, depth_bytes=depth_bytes, ) ) def add_image( self, name: str, image: onp.ndarray, render_width: float, render_height: float, format: Literal["png", "jpeg"] = "jpeg", jpeg_quality: int | None = None, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> ImageHandle: """Add a 2D image to the scene. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. image: A numpy array representing the image. render_width: Width at which the image should be rendered in the scene. render_height: Height at which the image should be rendered in the scene. format: Format to transport and display the image using ('png' or 'jpeg'). jpeg_quality: Quality of the jpeg image (if jpeg format is used). wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation from parent frame to local frame (t_pl). visible: Whether or not this image is initially visible. Returns: Handle for manipulating scene node. """ media_type, binary = _encode_image_binary( image, format, jpeg_quality=jpeg_quality ) self._websock_interface.queue_message( _messages.ImageMessage( name=name, media_type=media_type, data=binary, render_width=render_width, render_height=render_height, ) ) return ImageHandle._make(self, name, wxyz, position, visible) def add_transform_controls( self, name: str, scale: float = 1.0, line_width: float = 2.5, fixed: bool = False, auto_transform: bool = True, active_axes: tuple[bool, bool, bool] = (True, True, True), disable_axes: bool = False, disable_sliders: bool = False, disable_rotations: bool = False, translation_limits: tuple[ tuple[float, float], tuple[float, float], tuple[float, float] ] = ((-1000.0, 1000.0), (-1000.0, 1000.0), (-1000.0, 1000.0)), rotation_limits: tuple[ tuple[float, float], tuple[float, float], tuple[float, float] ] = ((-1000.0, 1000.0), (-1000.0, 1000.0), (-1000.0, 1000.0)), depth_test: bool = True, opacity: float = 1.0, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> TransformControlsHandle: """Add a transform gizmo for interacting with the scene. 
This method adds a transform control (gizmo) to the scene, allowing for interactive manipulation of objects in terms of their position, rotation, and scale. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. scale: Scale of the transform controls. line_width: Width of the lines used in the gizmo. fixed: Boolean indicating if the gizmo should be fixed in position. auto_transform: Whether the transform should be applied automatically. active_axes: tuple of booleans indicating active axes. disable_axes: Boolean to disable axes interaction. disable_sliders: Boolean to disable slider interaction. disable_rotations: Boolean to disable rotation interaction. translation_limits: Limits for translation. rotation_limits: Limits for rotation. depth_test: Boolean indicating if depth testing should be used when rendering. opacity: Opacity of the gizmo. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation from parent frame to local frame (t_pl). visible: Whether or not this gizmo is initially visible. Returns: Handle for manipulating (and reading state of) scene node. """ self._websock_interface.queue_message( _messages.TransformControlsMessage( name=name, scale=scale, line_width=line_width, fixed=fixed, auto_transform=auto_transform, active_axes=active_axes, disable_axes=disable_axes, disable_sliders=disable_sliders, disable_rotations=disable_rotations, translation_limits=translation_limits, rotation_limits=rotation_limits, depth_test=depth_test, opacity=opacity, ) ) def sync_cb(client_id: ClientId, state: TransformControlsHandle) -> None: message_orientation = _messages.SetOrientationMessage( name=name, wxyz=tuple(map(float, state._impl.wxyz)), # type: ignore ) message_orientation.excluded_self_client = client_id self._websock_interface.queue_message(message_orientation) message_position = _messages.SetPositionMessage( name=name, position=tuple(map(float, state._impl.position)), # type: ignore ) message_position.excluded_self_client = client_id self._websock_interface.queue_message(message_position) node_handle = SceneNodeHandle._make(self, name, wxyz, position, visible) state_aux = _TransformControlsState( last_updated=time.time(), update_cb=[], sync_cb=sync_cb, ) handle = TransformControlsHandle(node_handle._impl, state_aux) self._handle_from_transform_controls_name[name] = handle return handle def reset(self) -> None: """Reset the scene.""" self._websock_interface.queue_message(_messages.ResetSceneMessage()) def _get_client_handle(self, client_id: ClientId) -> ClientHandle: """Private helper for getting a client handle from its ID.""" # Avoid circular imports. from ._viser import ViserServer # Implementation-wise, note that MessageApi is never directly instantiated. # Instead, it serves as a mixin/base class for either ViserServer, which # maintains a registry of connected clients, or ClientHandle, which should # only ever be dealing with its own client_id. if isinstance(self._owner, ViserServer): # TODO: there's a potential race condition here when the client disconnects. # This probably applies to multiple other parts of the code, we should # revisit all of the cases where we index into connected_clients. 
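# Concretely: the client could disconnect between the message arriving and this
# lookup, in which case client_id would no longer be a valid key here.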
return self._owner._connected_clients[client_id] else: assert client_id == self._owner.client_id return self._owner def _handle_transform_controls_updates( self, client_id: ClientId, message: _messages.TransformControlsUpdateMessage ) -> None: """Callback for handling transform gizmo messages.""" handle = self._handle_from_transform_controls_name.get(message.name, None) if handle is None: return # Update state. wxyz = onp.array(message.wxyz) position = onp.array(message.position) with self._owner.atomic(): handle._impl.wxyz = wxyz handle._impl.position = position handle._impl_aux.last_updated = time.time() # Trigger callbacks. for cb in handle._impl_aux.update_cb: cb(handle) if handle._impl_aux.sync_cb is not None: handle._impl_aux.sync_cb(client_id, handle) def _handle_node_click_updates( self, client_id: ClientId, message: _messages.SceneNodeClickMessage ) -> None: """Callback for handling click messages.""" handle = self._handle_from_node_name.get(message.name, None) if handle is None or handle._impl.click_cb is None: return for cb in handle._impl.click_cb: event = SceneNodePointerEvent( client=self._get_client_handle(client_id), client_id=client_id, event="click", target=handle, ray_origin=message.ray_origin, ray_direction=message.ray_direction, screen_pos=message.screen_pos, instance_index=message.instance_index, ) cb(event) # type: ignore def _handle_scene_pointer_updates( self, client_id: ClientId, message: _messages.ScenePointerMessage ): """Callback for handling click messages.""" event = ScenePointerEvent( client=self._get_client_handle(client_id), client_id=client_id, event_type=message.event_type, ray_origin=message.ray_origin, ray_direction=message.ray_direction, screen_pos=message.screen_pos, ) # Call the callback if it exists, and the after-run callback. if self._scene_pointer_cb is None: return self._scene_pointer_cb(event) def on_pointer_event( self, event_type: Literal["click", "rect-select"] ) -> Callable[ [Callable[[ScenePointerEvent], None]], Callable[[ScenePointerEvent], None] ]: """Add a callback for scene pointer events. Args: event_type: event to listen to. """ # Ensure the event type is valid. assert event_type in get_args(_messages.ScenePointerEventType) from ._viser import ClientHandle, ViserServer def cleanup_previous_event(target: ViserServer | ClientHandle): # If the server or client does not have a scene pointer callback, return. if target.scene._scene_pointer_cb is None: return # Remove callback. target.scene.remove_pointer_callback() def decorator( func: Callable[[ScenePointerEvent], None], ) -> Callable[[ScenePointerEvent], None]: # Check if another scene pointer event was previously registered. # If so, we need to clear the previous event and register the new one. cleanup_previous_event(self._owner) # If called on the server handle, remove all clients' callbacks. if isinstance(self._owner, ViserServer): for client in self._owner.get_clients().values(): cleanup_previous_event(client) # If called on the client handle, and server handle has a callback, remove the server's callback. # (If the server has a callback, none of the clients should have callbacks.) 
elif isinstance(self._owner, ClientHandle): server = self._owner._viser_server cleanup_previous_event(server) self._scene_pointer_cb = func self._scene_pointer_event_type = event_type self._websock_interface.queue_message( _messages.ScenePointerEnableMessage(enable=True, event_type=event_type) ) return func return decorator def on_pointer_callback_removed( self, func: Callable[[], None], ) -> Callable[[], None]: """Add a callback to run automatically when the callback for a scene pointer event is removed. This will be triggered exactly once, either manually (via :meth:`remove_pointer_callback()`) or automatically (if the scene pointer event is overridden with another call to :meth:`on_pointer_event()`). Args: func: Callback for when scene pointer events are removed. """ self._scene_pointer_done_cb = func return func def remove_pointer_callback( self, ) -> None: """Remove the currently attached scene pointer event. This will trigger any callback attached to `.on_scene_pointer_removed()`.""" if self._scene_pointer_cb is None: warnings.warn( "No scene pointer callback exists for this server/client, ignoring.", stacklevel=2, ) return # Notify client that the listener has been removed. event_type = self._scene_pointer_event_type assert event_type is not None self._websock_interface.queue_message( _messages.ScenePointerEnableMessage(enable=False, event_type=event_type) ) self._owner.flush() # Run cleanup callback. self._scene_pointer_done_cb() # Reset the callback and event type, on the python side. self._scene_pointer_cb = None self._scene_pointer_done_cb = lambda: None self._scene_pointer_event_type = None def add_3d_gui_container( self, name: str, wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> Gui3dContainerHandle: """Add a 3D gui container to the scene. The returned container handle can be used as a context to place GUI elements into the 3D scene. Args: name: A scene tree name. Names in the format of /parent/child can be used to define a kinematic tree. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation to parent frame from local frame (t_pl). visible: Whether or not this scene node is initially visible. Returns: Handle for manipulating scene node. Can be used as a context to place GUI elements inside of the container. """ # Avoids circular import. from ._gui_api import _make_unique_id # New name to make the type checker happy; ViserServer and ClientHandle inherit # from both GuiApi and MessageApi. The pattern below is unideal. gui_api = self._owner.gui # Remove the 3D GUI container if it already exists. This will make sure # contained GUI elements are removed, preventing potential memory leaks. if name in self._handle_from_node_name: self._handle_from_node_name[name].remove() container_id = _make_unique_id() self._websock_interface.queue_message( _messages.Gui3DMessage( order=time.time(), name=name, container_id=container_id, ) ) node_handle = SceneNodeHandle._make(self, name, wxyz, position, visible=visible) return Gui3dContainerHandle(node_handle._impl, gui_api, container_id) ================================================ FILE: viser/src/viser/_scene_handles.py ================================================ from __future__ import annotations import dataclasses from typing import TYPE_CHECKING, Callable, Generic, Literal, TypeVar import numpy as onp from . 
import _messages from .infra._infra import WebsockClientConnection, WebsockServer if TYPE_CHECKING: from ._gui_api import GuiApi from ._gui_handles import SupportsRemoveProtocol from ._scene_api import SceneApi from ._viser import ClientHandle from .infra import ClientId @dataclasses.dataclass(frozen=True) class ScenePointerEvent: """Event passed to pointer callbacks for the scene (currently only clicks).""" client: ClientHandle """Client that triggered this event.""" client_id: int """ID of client that triggered this event.""" event_type: _messages.ScenePointerEventType """Type of event that was triggered. Currently we only support clicks and box selections.""" ray_origin: tuple[float, float, float] | None """Origin of 3D ray corresponding to this click, in world coordinates.""" ray_direction: tuple[float, float, float] | None """Direction of 3D ray corresponding to this click, in world coordinates.""" screen_pos: tuple[tuple[float, float], ...] """Screen position of the click on the screen (OpenCV image coordinates, 0 to 1). (0, 0) is the upper-left corner, (1, 1) is the bottom-right corner. For a box selection, this includes the min- and max- corners of the box.""" @property def event(self): """Deprecated. Use `event_type` instead.""" return self.event_type TSceneNodeHandle = TypeVar("TSceneNodeHandle", bound="SceneNodeHandle") @dataclasses.dataclass class _SceneNodeHandleState: name: str api: SceneApi wxyz: onp.ndarray = dataclasses.field( default_factory=lambda: onp.array([1.0, 0.0, 0.0, 0.0]) ) position: onp.ndarray = dataclasses.field( default_factory=lambda: onp.array([0.0, 0.0, 0.0]) ) visible: bool = True # TODO: we should remove SceneNodeHandle as an argument here. click_cb: list[Callable[[SceneNodePointerEvent[SceneNodeHandle]], None]] | None = ( None ) @dataclasses.dataclass class SceneNodeHandle: """Handle base class for interacting with scene nodes.""" _impl: _SceneNodeHandleState @classmethod def _make( cls: type[TSceneNodeHandle], api: SceneApi, name: str, wxyz: tuple[float, float, float, float] | onp.ndarray, position: tuple[float, float, float] | onp.ndarray, visible: bool, ) -> TSceneNodeHandle: out = cls(_SceneNodeHandleState(name, api)) api._handle_from_node_name[name] = out out.wxyz = wxyz out.position = position # Toggle visibility to make sure we send a # SetSceneNodeVisibilityMessage to the client. out._impl.visible = not visible out.visible = visible return out @property def wxyz(self) -> onp.ndarray: """Orientation of the scene node. This is the quaternion representation of the R in `p_parent = [R | t] p_local`. Synchronized to clients automatically when assigned. """ return self._impl.wxyz @wxyz.setter def wxyz(self, wxyz: tuple[float, float, float, float] | onp.ndarray) -> None: from ._scene_api import cast_vector wxyz_cast = cast_vector(wxyz, 4) self._impl.wxyz = onp.asarray(wxyz) self._impl.api._websock_interface.queue_message( _messages.SetOrientationMessage(self._impl.name, wxyz_cast) ) @property def position(self) -> onp.ndarray: """Position of the scene node. This is equivalent to the t in `p_parent = [R | t] p_local`. Synchronized to clients automatically when assigned. 
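Example (illustrative; ``handle`` here stands for any existing SceneNodeHandle):
    handle.position = (1.0, 0.0, 0.0)  # queues a SetPositionMessage for clients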
""" return self._impl.position @position.setter def position(self, position: tuple[float, float, float] | onp.ndarray) -> None: from ._scene_api import cast_vector position_cast = cast_vector(position, 3) self._impl.position = onp.asarray(position) self._impl.api._websock_interface.queue_message( _messages.SetPositionMessage(self._impl.name, position_cast) ) @property def visible(self) -> bool: """Whether the scene node is visible or not. Synchronized to clients automatically when assigned.""" return self._impl.visible @visible.setter def visible(self, visible: bool) -> None: if visible == self._impl.visible: return self._impl.api._websock_interface.queue_message( _messages.SetSceneNodeVisibilityMessage(self._impl.name, visible) ) self._impl.visible = visible def remove(self) -> None: """Remove the node from the scene.""" self._impl.api._websock_interface.queue_message( _messages.RemoveSceneNodeMessage(self._impl.name) ) @dataclasses.dataclass(frozen=True) class SceneNodePointerEvent(Generic[TSceneNodeHandle]): """Event passed to pointer callbacks for scene nodes (currently only clicks).""" client: ClientHandle """Client that triggered this event.""" client_id: int """ID of client that triggered this event.""" event: Literal["click"] """Type of event that was triggered. Currently we only support clicks.""" target: TSceneNodeHandle """Scene node that was clicked.""" ray_origin: tuple[float, float, float] """Origin of 3D ray corresponding to this click, in world coordinates.""" ray_direction: tuple[float, float, float] """Direction of 3D ray corresponding to this click, in world coordinates.""" screen_pos: tuple[float, float] """Screen position of the click on the screen (OpenCV image coordinates, 0 to 1). (0, 0) is the upper-left corner, (1, 1) is the bottom-right corner.""" instance_index: int | None """Instance ID of the clicked object, if applicable. Currently this is `None` for all objects except for the output of :meth:`SceneApi.add_batched_axes()`.""" @dataclasses.dataclass class _ClickableSceneNodeHandle(SceneNodeHandle): def on_click( self: TSceneNodeHandle, func: Callable[[SceneNodePointerEvent[TSceneNodeHandle]], None], ) -> Callable[[SceneNodePointerEvent[TSceneNodeHandle]], None]: """Attach a callback for when a scene node is clicked.""" self._impl.api._websock_interface.queue_message( _messages.SetSceneNodeClickableMessage(self._impl.name, True) ) if self._impl.click_cb is None: self._impl.click_cb = [] self._impl.click_cb.append(func) # type: ignore return func @dataclasses.dataclass class CameraFrustumHandle(_ClickableSceneNodeHandle): """Handle for camera frustums.""" @dataclasses.dataclass class PointCloudHandle(SceneNodeHandle): """Handle for point clouds. Does not support click events.""" @dataclasses.dataclass class BatchedAxesHandle(_ClickableSceneNodeHandle): """Handle for batched coordinate frames.""" @dataclasses.dataclass class FrameHandle(_ClickableSceneNodeHandle): """Handle for coordinate frames.""" @dataclasses.dataclass class MeshHandle(_ClickableSceneNodeHandle): """Handle for mesh objects.""" @dataclasses.dataclass class GaussianSplatHandle(_ClickableSceneNodeHandle): """Handle for Gaussian splatting objects. **Work-in-progress.** Gaussian rendering is still under development. """ @dataclasses.dataclass class MeshSkinnedHandle(_ClickableSceneNodeHandle): """Handle for skinned mesh objects.""" bones: tuple[MeshSkinnedBoneHandle, ...] """Bones of the skinned mesh. 
These handles can be used for reading and writing poses, which are defined relative to the mesh root.""" @dataclasses.dataclass class BoneState: name: str websock_interface: WebsockServer | WebsockClientConnection bone_index: int wxyz: onp.ndarray position: onp.ndarray @dataclasses.dataclass class MeshSkinnedBoneHandle: """Handle for reading and writing the poses of bones in a skinned mesh.""" _impl: BoneState @property def wxyz(self) -> onp.ndarray: """Orientation of the bone. This is the quaternion representation of the R in `p_parent = [R | t] p_local`. Synchronized to clients automatically when assigned. """ return self._impl.wxyz @wxyz.setter def wxyz(self, wxyz: tuple[float, float, float, float] | onp.ndarray) -> None: from ._scene_api import cast_vector wxyz_cast = cast_vector(wxyz, 4) self._impl.wxyz = onp.asarray(wxyz) self._impl.websock_interface.queue_message( _messages.SetBoneOrientationMessage( self._impl.name, self._impl.bone_index, wxyz_cast ) ) @property def position(self) -> onp.ndarray: """Position of the bone. This is equivalent to the t in `p_parent = [R | t] p_local`. Synchronized to clients automatically when assigned. """ return self._impl.position @position.setter def position(self, position: tuple[float, float, float] | onp.ndarray) -> None: from ._scene_api import cast_vector position_cast = cast_vector(position, 3) self._impl.position = onp.asarray(position) self._impl.websock_interface.queue_message( _messages.SetBonePositionMessage( self._impl.name, self._impl.bone_index, position_cast ) ) @dataclasses.dataclass class GlbHandle(_ClickableSceneNodeHandle): """Handle for GLB objects.""" @dataclasses.dataclass class ImageHandle(_ClickableSceneNodeHandle): """Handle for 2D images, rendered in 3D.""" @dataclasses.dataclass class LabelHandle(SceneNodeHandle): """Handle for 2D label objects. Does not support click events.""" @dataclasses.dataclass class _TransformControlsState: last_updated: float update_cb: list[Callable[[TransformControlsHandle], None]] sync_cb: None | Callable[[ClientId, TransformControlsHandle], None] = None @dataclasses.dataclass class TransformControlsHandle(_ClickableSceneNodeHandle): """Handle for interacting with transform control gizmos.""" _impl_aux: _TransformControlsState @property def update_timestamp(self) -> float: return self._impl_aux.last_updated def on_update( self, func: Callable[[TransformControlsHandle], None] ) -> Callable[[TransformControlsHandle], None]: """Attach a callback for when the gizmo is moved.""" self._impl_aux.update_cb.append(func) return func @dataclasses.dataclass class Gui3dContainerHandle(SceneNodeHandle): """Use as a context to place GUI elements into a 3D GUI container.""" _gui_api: GuiApi _container_id: str _container_id_restore: str | None = None _children: dict[str, SupportsRemoveProtocol] = dataclasses.field( default_factory=dict ) def __enter__(self) -> Gui3dContainerHandle: self._container_id_restore = self._gui_api._get_container_id() self._gui_api._set_container_id(self._container_id) return self def __exit__(self, *args) -> None: del args assert self._container_id_restore is not None self._gui_api._set_container_id(self._container_id_restore) self._container_id_restore = None def __post_init__(self) -> None: self._gui_api._container_handle_from_id[self._container_id] = self def remove(self) -> None: """Permanently remove this GUI container from the visualizer.""" # Call scene node remove. super().remove() # Clean up contained GUI elements. 
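# The handles above are driven from user code roughly as follows. This is a
# minimal sketch, assuming the standard `SceneApi.add_frame` factory (not shown
# in this excerpt); assigning `.wxyz` / `.position` queues pose messages, and
# `.on_click` registers a `SceneNodePointerEvent` callback:
#
#   import time
#   import viser
#
#   server = viser.ViserServer()
#   frame = server.scene.add_frame("/demo", position=(0.0, 0.0, 0.0))
#
#   @frame.on_click
#   def _(event) -> None:
#       # `event.target` is the clicked handle; nudge it upward on each click.
#       event.target.position = event.target.position + (0.0, 0.0, 0.1)
#
#   while True:
#       time.sleep(1.0)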
for child in tuple(self._children.values()): child.remove() self._gui_api._container_handle_from_id.pop(self._container_id) ================================================ FILE: viser/src/viser/_tunnel.py ================================================ from __future__ import annotations import asyncio import multiprocessing as mp import threading from functools import lru_cache from multiprocessing.managers import DictProxy from pathlib import Path from typing import Callable, Literal import rich @lru_cache def _is_multiprocess_ok() -> bool: import __main__ if hasattr(__main__, "__file__"): src = Path(__main__.__file__).read_text() return "\nif __name__" in src and "__main__" in src else: return True class ViserTunnel: """Tunneling utility for internal use. This is chaotic academic software, and we'd appreciate if you refrained from red-teaming it. :) """ def __init__(self, share_domain: str, local_port: int) -> None: self._share_domain = share_domain self._local_port = local_port # Heuristic for `if __name__ == "__main__"` check. self._multiprocess_ok = _is_multiprocess_ok() if not self._multiprocess_ok: rich.print( "[bold](viser)[/bold] No `if __name__ == '__main__'` check found; creating share URL tunnel in a thread" ) self._process: mp.Process | None = None self._thread: threading.Thread | None = None self._event_loop: asyncio.AbstractEventLoop | None = None self._shared_state: DictProxy | dict if self._multiprocess_ok: manager = mp.Manager() self._connect_event = manager.Event() self._disconnect_event = manager.Event() self._close_event = None # Only used for threads. For processes, we just kill the tunnel process. self._shared_state = manager.dict() else: self._connect_event = threading.Event() self._disconnect_event = threading.Event() self._close_event = asyncio.Event() self._shared_state = {} self._shared_state["status"] = "ready" self._shared_state["url"] = None def on_disconnect(self, callback: Callable[[], None]) -> None: def call_on_disconnect() -> None: try: self._disconnect_event.wait() except EOFError: return callback() threading.Thread(target=call_on_disconnect, daemon=True).start() def on_connect(self, callback: Callable[[int], None]) -> None: """Establish the tunnel connection. Returns URL if tunnel succeeds, otherwise None.""" assert self._process is None self._shared_state["status"] = "connecting" def wait_job() -> None: try: self._connect_event.wait() except EOFError: return callback(self._shared_state["max_conn_count"]) threading.Thread(target=wait_job, daemon=True).start() # Note that this will generally require an __name__ == "__main__" check # on the origin script. if self._multiprocess_ok: self._process = mp.Process( target=_connect_job, daemon=True, args=( self._connect_event, self._disconnect_event, self._close_event, self._share_domain, self._local_port, self._shared_state, None, ), ) self._process.start() else: self._thread = threading.Thread( target=_connect_job, daemon=True, args=( self._connect_event, self._disconnect_event, self._close_event, self._share_domain, self._local_port, self._shared_state, self, ), ) self._thread.start() def get_url(self) -> str | None: """Get tunnel URL. 
None if not connected (or connection failed).""" return self._shared_state["url"] def get_status( self, ) -> Literal["ready", "connecting", "failed", "connected", "closed"]: return self._shared_state["status"] def close(self) -> None: """Close the tunnel.""" if self._process is not None: self._process.kill() self._process.join() self._disconnect_event.set() if self._thread is not None: assert self._event_loop is not None @self._event_loop.call_soon_threadsafe def _() -> None: assert self._close_event is not None self._close_event.set() self._thread.join() self._disconnect_event.set() def _connect_job( connect_event: threading.Event, disconnect_event: threading.Event, close_event: asyncio.Event | None, # Only for threads. share_domain: str, local_port: int, shared_state: DictProxy | dict, event_loop_target: ViserTunnel | None, # Only for threads. ) -> None: event_loop = asyncio.new_event_loop() asyncio.set_event_loop(event_loop) if event_loop_target is not None: event_loop_target._event_loop = event_loop if close_event is None: close_event = asyncio.Event() try: event_loop.run_until_complete( _make_tunnel( connect_event, disconnect_event, close_event, share_domain, local_port, shared_state, ) ) event_loop.close() except KeyboardInterrupt: event_loop.call_soon_threadsafe(close_event.set) tasks = asyncio.all_tasks(event_loop) event_loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) event_loop.close() async def _make_tunnel( connect_event: threading.Event, disconnect_event: threading.Event, close_event: asyncio.Event | None, share_domain: str, local_port: int, shared_state: DictProxy | dict, ) -> None: share_domain = "share.viser.studio" import requests try: response = requests.request( "GET", url=f"https://{share_domain}/?request_forward", headers={"Content-Type": "application/json"}, ) if response.status_code != 200: shared_state["status"] = "failed" return except requests.exceptions.ConnectionError: shared_state["status"] = "failed" return except Exception as e: shared_state["status"] = "failed" raise e res = response.json() shared_state["url"] = res["url"] shared_state["max_conn_count"] = res["max_conn_count"] shared_state["status"] = "connected" connect_event.set() await asyncio.gather( *[ asyncio.create_task( _simple_proxy( "127.0.0.1", local_port, share_domain, res["port"], close_event if close_event is not None else asyncio.Event(), ) ) for _ in range(res["max_conn_count"]) ] ) shared_state["url"] = None shared_state["status"] = "closed" disconnect_event.set() async def _simple_proxy( local_host: str, local_port: int, remote_host: str, remote_port: int, close_event: asyncio.Event, ) -> None: """Establish a connection to the tunnel server.""" async def close_writer(writer: asyncio.StreamWriter) -> None: """Utility for closing a writer and waiting until done, while suppressing errors from broken connections.""" try: if not writer.is_closing(): writer.close() await writer.wait_closed() except ConnectionError: pass async def relay(r: asyncio.StreamReader, w: asyncio.StreamWriter) -> None: """Simple data passthrough from one stream to another.""" try: while True: data = await r.read(4096) if len(data) == 0: # Done! 
break w.write(data) await w.drain() except Exception: pass finally: await close_writer(w) while True: local_w = None remote_w = None try: local_r, local_w = await asyncio.open_connection(local_host, local_port) remote_r, remote_w = await asyncio.open_connection(remote_host, remote_port) await asyncio.wait( [ asyncio.gather( asyncio.create_task(relay(local_r, remote_w)), asyncio.create_task(relay(remote_r, local_w)), ), asyncio.create_task(close_event.wait()), ], return_when=asyncio.FIRST_COMPLETED, ) except Exception: pass finally: # Be extra sure that connections are closed. if local_w is not None: await close_writer(local_w) if remote_w is not None: await close_writer(remote_w) if close_event.is_set(): break # Throttle connection attempts. await asyncio.sleep(0.1) ================================================ FILE: viser/src/viser/_viser.py ================================================ from __future__ import annotations import dataclasses import io import mimetypes import threading import time import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, ContextManager import imageio.v3 as iio import numpy as onp import numpy.typing as npt import rich from rich import box, style from rich.panel import Panel from rich.table import Table from typing_extensions import Literal from . import _client_autobuild, _messages, infra from . import transforms as tf from ._gui_api import Color, GuiApi, _make_unique_id from ._notification_handle import NotificationHandle, _NotificationHandleState from ._scene_api import SceneApi, cast_vector from ._tunnel import ViserTunnel from .infra._infra import RecordHandle class _BackwardsCompatibilityShim: """Shims for backward compatibility with viser API from version `<=0.1.30`.""" def __getattr__(self, name: str) -> Any: fixed_name = { # Map from old method names (viser v0.1.*) to new methods names. "reset_scene": "reset", "set_global_scene_node_visibility": "set_global_visibility", "on_scene_pointer": "on_pointer_event", "on_scene_pointer_removed": "on_pointer_callback_removed", "remove_scene_pointer_callback": "remove_pointer_callback", "add_mesh": "add_mesh_simple", }.get(name, name) if hasattr(self.scene, fixed_name): warnings.warn( f"{type(self).__name__}.{name} has been deprecated, use {type(self).__name__}.scene.{fixed_name} instead. Alternatively, pin to `viser<0.2.0`.", category=DeprecationWarning, stacklevel=2, ) return object.__getattribute__(self.scene, fixed_name) fixed_name = name.replace("add_gui_", "add_").replace("set_gui_", "set_") if hasattr(self.gui, fixed_name): warnings.warn( f"{type(self).__name__}.{name} has been deprecated, use {type(self).__name__}.gui.{fixed_name} instead. Alternatively, pin to `viser<0.2.0`.", category=DeprecationWarning, stacklevel=2, ) return object.__getattribute__(self.gui, fixed_name) raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'" ) @dataclasses.dataclass class _CameraHandleState: """Information about a client's camera state.""" client: ClientHandle wxyz: npt.NDArray[onp.float64] position: npt.NDArray[onp.float64] fov: float aspect: float look_at: npt.NDArray[onp.float64] up_direction: npt.NDArray[onp.float64] update_timestamp: float camera_cb: list[Callable[[CameraHandle], None]] class CameraHandle: """A handle for reading and writing the camera state of a particular client. 
Typically accessed via :attr:`ClientHandle.camera`.""" def __init__(self, client: ClientHandle) -> None: self._state = _CameraHandleState( client, wxyz=onp.zeros(4), position=onp.zeros(3), fov=0.0, aspect=0.0, look_at=onp.zeros(3), up_direction=onp.zeros(3), update_timestamp=0.0, camera_cb=[], ) @property def client(self) -> ClientHandle: """Client that this camera corresponds to.""" return self._state.client @property def wxyz(self) -> npt.NDArray[onp.float64]: """Corresponds to the R in `P_world = [R | t] p_camera`. Synchronized automatically when assigned.""" assert self._state.update_timestamp != 0.0 return self._state.wxyz # Note: asymmetric properties are supported in Pyright, but not yet in mypy. # - https://github.com/python/mypy/issues/3004 # - https://github.com/python/mypy/pull/11643 @wxyz.setter def wxyz(self, wxyz: tuple[float, float, float, float] | onp.ndarray) -> None: R_world_camera = tf.SO3(onp.asarray(wxyz)).as_matrix() look_distance = onp.linalg.norm(self.look_at - self.position) # We're following OpenCV conventions: look_direction is +Z, up_direction is -Y, # right_direction is +X. look_direction = R_world_camera[:, 2] up_direction = -R_world_camera[:, 1] right_direction = R_world_camera[:, 0] # Minimize our impact on the orbit controls by keeping the new up direction as # close to the old one as possible. projected_up_direction = ( self.up_direction - float(self.up_direction @ right_direction) * right_direction ) up_cosine = float(up_direction @ projected_up_direction) if abs(up_cosine) < 0.05: projected_up_direction = up_direction elif up_cosine < 0.0: projected_up_direction = up_direction new_look_at = look_direction * look_distance + self.position # Update lookat and up direction. self.look_at = new_look_at self.up_direction = projected_up_direction # The internal camera orientation should be set in the look_at / # up_direction setters. We can uncomment this assert to check this. # assert onp.allclose(self._state.wxyz, wxyz) or onp.allclose( # self._state.wxyz, -wxyz # ) @property def position(self) -> npt.NDArray[onp.float64]: """Corresponds to the t in `P_world = [R | t] p_camera`. Synchronized automatically when assigned. The `look_at` point and `up_direction` vectors are maintained when updating `position`, which means that updates to `position` will often also affect `wxyz`. """ assert self._state.update_timestamp != 0.0 return self._state.position @position.setter def position(self, position: tuple[float, float, float] | onp.ndarray) -> None: offset = onp.asarray(position) - onp.array(self.position) # type: ignore self._state.position = onp.asarray(position) self.look_at = onp.array(self.look_at) + offset self._state.update_timestamp = time.time() self._state.client._websock_connection.queue_message( _messages.SetCameraPositionMessage(cast_vector(position, 3)) ) def _update_wxyz(self) -> None: """Compute and update the camera orientation from the internal look_at, position, and up vectors.""" z = self._state.look_at - self._state.position z /= onp.linalg.norm(z) y = tf.SO3.exp(z * onp.pi) @ self._state.up_direction y = y - onp.dot(z, y) * z y /= onp.linalg.norm(y) x = onp.cross(y, z) self._state.wxyz = tf.SO3.from_matrix(onp.stack([x, y, z], axis=1)).wxyz @property def fov(self) -> float: """Vertical field of view of the camera, in radians. 
Synchronized automatically when assigned.""" assert self._state.update_timestamp != 0.0 return self._state.fov @fov.setter def fov(self, fov: float) -> None: self._state.fov = fov self._state.update_timestamp = time.time() self._state.client._websock_connection.queue_message( _messages.SetCameraFovMessage(fov) ) @property def aspect(self) -> float: """Canvas width divided by height. Not assignable.""" assert self._state.update_timestamp != 0.0 return self._state.aspect @property def update_timestamp(self) -> float: assert self._state.update_timestamp != 0.0 return self._state.update_timestamp @property def look_at(self) -> npt.NDArray[onp.float64]: """Look at point for the camera. Synchronized automatically when set.""" assert self._state.update_timestamp != 0.0 return self._state.look_at @look_at.setter def look_at(self, look_at: tuple[float, float, float] | onp.ndarray) -> None: self._state.look_at = onp.asarray(look_at) self._state.update_timestamp = time.time() self._update_wxyz() self._state.client._websock_connection.queue_message( _messages.SetCameraLookAtMessage(cast_vector(look_at, 3)) ) @property def up_direction(self) -> npt.NDArray[onp.float64]: """Up direction for the camera. Synchronized automatically when set.""" assert self._state.update_timestamp != 0.0 return self._state.up_direction @up_direction.setter def up_direction( self, up_direction: tuple[float, float, float] | onp.ndarray ) -> None: self._state.up_direction = onp.asarray(up_direction) self._update_wxyz() self._state.update_timestamp = time.time() self._state.client._websock_connection.queue_message( _messages.SetCameraUpDirectionMessage(cast_vector(up_direction, 3)) ) def on_update( self, callback: Callable[[CameraHandle], None] ) -> Callable[[CameraHandle], None]: """Attach a callback to run when a new camera message is received.""" self._state.camera_cb.append(callback) return callback def get_render( self, height: int, width: int, transport_format: Literal["png", "jpeg"] = "jpeg" ) -> onp.ndarray: """Request a render from a client, block until it's done and received, then return it as a numpy array. Args: height: Height of rendered image. Should be <= the browser height. width: Width of rendered image. Should be <= the browser width. transport_format: Image transport format. JPEG will return a lossy (H, W, 3) RGB array. PNG will return a lossless (H, W, 4) RGBA array, but can cause memory issues on the frontend if called too quickly for higher-resolution images. """ # Listen for a render reseponse message, which should contain the rendered # image. render_ready_event = threading.Event() out: onp.ndarray | None = None connection = self.client._websock_connection def got_render_cb( client_id: int, message: _messages.GetRenderResponseMessage ) -> None: del client_id connection.unregister_handler( _messages.GetRenderResponseMessage, got_render_cb ) nonlocal out out = iio.imread( io.BytesIO(message.payload), extension=f".{transport_format}", ) render_ready_event.set() connection.register_handler(_messages.GetRenderResponseMessage, got_render_cb) self.client._websock_connection.queue_message( _messages.GetRenderRequestMessage( "image/jpeg" if transport_format == "jpeg" else "image/png", height=height, width=width, # Only used for JPEG. The main reason to use a lower quality version # value is (unfortunately) to make life easier for the Javascript # garbage collector. 
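# A usage sketch for `get_render` (hedged; per the docstring above, JPEG yields
# an (H, W, 3) array and PNG an (H, W, 4) array), where `server` is a
# `viser.ViserServer` instance:
#
#   @server.on_client_connect
#   def _(client: viser.ClientHandle) -> None:
#       image = client.camera.get_render(height=480, width=640, transport_format="jpeg")
#       print("Received render with shape", image.shape)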
quality=80, ) ) render_ready_event.wait() assert out is not None return out # Don't inherit from _BackwardsCompatibilityShim during type checking, because # this will unnecessarily suppress type errors. (from the overriding of # __getattr__). class ClientHandle(_BackwardsCompatibilityShim if not TYPE_CHECKING else object): """A handle is created for each client that connects to a server. Handles can be used to communicate with just one client, as well as for reading and writing of camera state. Similar to :class:`ViserServer`, client handles also expose scene and GUI interfaces at :attr:`ClientHandle.scene` and :attr:`ClientHandle.gui`. If these are used, for example via a client's :meth:`SceneApi.add_point_cloud()` method, created elements are local to only one specific client. """ def __init__( self, conn: infra.WebsockClientConnection, server: ViserServer ) -> None: # Private attributes. self._websock_connection = conn self._viser_server = server # Public attributes. self.scene: SceneApi = SceneApi( self, thread_executor=server._websock_server._thread_executor ) """Handle for interacting with the 3D scene.""" self.gui: GuiApi = GuiApi( self, thread_executor=server._websock_server._thread_executor ) """Handle for interacting with the GUI.""" self.client_id: int = conn.client_id """Unique ID for this client.""" self.camera: CameraHandle = CameraHandle(self) """Handle for reading from and manipulating the client's viewport camera.""" def flush(self) -> None: """Flush the outgoing message buffer. Any buffered messages will immediately be sent. (by default they are windowed)""" self._viser_server._websock_server.flush_client(self.client_id) def atomic(self) -> ContextManager[None]: """Returns a context where: all outgoing messages are grouped and applied by clients atomically. This should be treated as a soft constraint that's helpful for things like animations, or when we want position and orientation updates to happen synchronously. Returns: Context manager. """ return self._websock_connection.atomic() def send_file_download( self, filename: str, content: bytes, chunk_size: int = 1024 * 1024 ) -> None: """Send a file for a client or clients to download. Args: filename: Name of the file to send. Used to infer MIME type. content: Content of the file. chunk_size: Number of bytes to send at a time. """ mime_type = mimetypes.guess_type(filename, strict=False)[0] if mime_type is None: mime_type = "application/octet-stream" parts = [ content[i * chunk_size : (i + 1) * chunk_size] for i in range(int(onp.ceil(len(content) / chunk_size))) ] uuid = _make_unique_id() self._websock_connection.queue_message( _messages.FileTransferStart( source_component_id=None, transfer_uuid=uuid, filename=filename, mime_type=mime_type, part_count=len(parts), size_bytes=len(content), ) ) for i, part in enumerate(parts): self._websock_connection.queue_message( _messages.FileTransferPart( None, transfer_uuid=uuid, part=i, content=part, ) ) self.flush() def add_notification( self, title: str, body: str, loading: bool = False, with_close_button: bool = True, auto_close: int | Literal[False] = False, color: Color | None = None, ) -> NotificationHandle: """Add a notification to the client's interface. This method creates a new notification that will be displayed at the top left corner of the client's viewer. Notifications are useful for providing alerts or status updates to users. Args: title: Title to display on the notification. body: Message to display on the notification body. 
loading: Whether the notification shows loading icon. with_close_button: Whether the notification can be manually closed. auto_close: Time in ms before the notification automatically closes; otherwise False such that the notification never closes on its own. Returns: A handle that can be used to interact with the GUI element. """ handle = NotificationHandle( _NotificationHandleState( websock_interface=self._websock_connection, id=_make_unique_id(), title=title, body=body, loading=loading, with_close_button=with_close_button, auto_close=auto_close, color=color, ) ) handle._sync_with_client(first=True) return handle class ViserServer(_BackwardsCompatibilityShim if not TYPE_CHECKING else object): """:class:`ViserServer` is the main class for working with viser. On instantiation, it (a) launches a thread with a web server and (b) provides a high-level API for interactive 3D visualization. **Core API.** Clients can connect via a web browser, and will be shown two components: a 3D scene and a 2D GUI panel. Methods belonging to :attr:`ViserServer.scene` can be used to add 3D primitives to the scene. Methods belonging to :attr:`ViserServer.gui` can be used to add 2D GUI elements. **Shared state.** Elements added to the server object, for example via a server's :meth:`SceneApi.add_point_cloud` or :meth:`GuiApi.add_button`, will have state that's shared and synchronized automatically between all connected clients. To show elements that are local to a single client, see :attr:`ClientHandle.scene` and :attr:`ClientHandle.gui`. Args: host: Host to bind server to. port: Port to bind server to. label: Label shown at the top of the GUI panel. """ # Hide deprecated arguments from docstring and type checkers. def __init__( self, host: str = "0.0.0.0", port: int = 8080, label: str | None = None, verbose: bool = True, **_deprecated_kwargs, ): # Create server. server = infra.WebsockServer( host=host, port=port, message_class=_messages.Message, http_server_root=Path(__file__).absolute().parent / "client" / "build", verbose=verbose, client_api_version=1, ) self._websock_server = server _client_autobuild.ensure_client_is_built() self._connection = server self._connected_clients: dict[int, ClientHandle] = {} self._client_lock = threading.Lock() self._client_connect_cb: list[Callable[[ClientHandle], None]] = [] self._client_disconnect_cb: list[Callable[[ClientHandle], None]] = [] # For new clients, register and add a handler for camera messages. @server.on_client_connect def _(conn: infra.WebsockClientConnection) -> None: client = ClientHandle(conn, server=self) first = True def handle_camera_message( client_id: infra.ClientId, message: _messages.ViewerCameraMessage ) -> None: nonlocal first assert client_id == client.client_id # Update the client's camera. with client.atomic(): client.camera._state = _CameraHandleState( client, onp.array(message.wxyz), onp.array(message.position), message.fov, message.aspect, onp.array(message.look_at), onp.array(message.up_direction), time.time(), camera_cb=client.camera._state.camera_cb, ) # We consider a client to be connected after the first camera message is # received. if first: first = False with self._client_lock: self._connected_clients[conn.client_id] = client for cb in self._client_connect_cb: cb(client) for camera_cb in client.camera._state.camera_cb: camera_cb(client.camera) conn.register_handler(_messages.ViewerCameraMessage, handle_camera_message) # Remove clients when they disconnect. 
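# From a user script, this server class is typically driven as in the hedged
# sketch below, which exercises the shared scene/GUI state described in its
# docstring plus the per-client notification API defined above. The
# `add_point_cloud` / `add_button` argument names are assumed rather than shown
# in this excerpt:
#
#   import time
#   import numpy as np
#   import viser
#
#   server = viser.ViserServer(port=8080)
#
#   # Shared state: every connected browser sees the same point cloud and button.
#   server.scene.add_point_cloud(
#       "/cloud",
#       points=np.random.uniform(-1.0, 1.0, size=(500, 3)),
#       colors=np.random.randint(0, 256, size=(500, 3)),
#   )
#   button = server.gui.add_button("Print clients")
#
#   @button.on_click
#   def _(_event) -> None:
#       print("Connected client IDs:", list(server.get_clients().keys()))
#
#   @server.on_client_connect
#   def _(client: viser.ClientHandle) -> None:
#       client.add_notification(title="Connected", body=f"Hello, client {client.client_id}!")
#
#   while True:
#       time.sleep(10.0)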
@server.on_client_disconnect def _(conn: infra.WebsockClientConnection) -> None: with self._client_lock: if conn.client_id not in self._connected_clients: return handle = self._connected_clients.pop(conn.client_id) for cb in self._client_disconnect_cb: cb(handle) # Start the server. server.start() self.scene: SceneApi = SceneApi(self, thread_executor=server._thread_executor) """Handle for interacting with the 3D scene.""" self.gui: GuiApi = GuiApi(self, thread_executor=server._thread_executor) """Handle for interacting with the GUI.""" server.register_handler( _messages.ShareUrlDisconnect, lambda client_id, msg: self.disconnect_share_url(), ) server.register_handler( _messages.ShareUrlRequest, lambda client_id, msg: self.request_share_url() ) # Form status print. port = server._port # Port may have changed. http_url = f"http://{host}:{port}" ws_url = f"ws://{host}:{port}" table = Table( title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True), ) table.add_row("HTTP", http_url) table.add_row("Websocket", ws_url) rich.print(Panel(table, title="[bold]viser[/bold]", expand=False)) self._share_tunnel: ViserTunnel | None = None # Create share tunnel if requested. # This is deprecated: we should use get_share_url() instead. share = _deprecated_kwargs.get("share", False) if share: self.request_share_url() self.scene.reset() self.gui.reset() self.gui.set_panel_label(label) def get_host(self) -> str: """Returns the host address of the Viser server. Returns: Host address as string. """ return self._websock_server._host def get_port(self) -> int: """Returns the port of the Viser server. This could be different from the originally requested one. Returns: Port as integer. """ return self._websock_server._port def request_share_url(self, verbose: bool = True) -> str | None: """Request a share URL for the Viser server, which allows for public access. On the first call, will block until a connecting with the share URL server is established. Afterwards, the URL will be returned directly. This is an experimental feature that relies on an external server; it shouldn't be relied on for critical applications. Returns: Share URL as string, or None if connection fails or is closed. """ if self._share_tunnel is not None: # Tunnel already exists. while self._share_tunnel.get_status() in ("ready", "connecting"): time.sleep(0.05) return self._share_tunnel.get_url() else: # Create a new tunnel!. 
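# From a user script, the share-URL flow implemented here reduces to a few
# calls (hedged sketch; as documented above, the tunnel is experimental and
# relies on an external server):
#
#   server = viser.ViserServer()
#   url = server.request_share_url()  # Blocks until the tunnel connects or fails.
#   if url is not None:
#       print("Share URL:", url)
#   ...
#   server.disconnect_share_url()  # Tear the tunnel down when done.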
if verbose: rich.print("[bold](viser)[/bold] Share URL requested!") connect_event = threading.Event() self._share_tunnel = ViserTunnel( "share.viser.studio", self._websock_server._port ) @self._share_tunnel.on_disconnect def _() -> None: rich.print("[bold](viser)[/bold] Disconnected from share URL") self._share_tunnel = None self._websock_server.unsafe_send_message( _messages.ShareUrlUpdated(None) ) @self._share_tunnel.on_connect def _(max_clients: int) -> None: assert self._share_tunnel is not None share_url = self._share_tunnel.get_url() if verbose: if share_url is None: rich.print("[bold](viser)[/bold] Could not generate share URL") else: rich.print( f"[bold](viser)[/bold] Generated share URL (expires in 24 hours, max {max_clients} clients): {share_url}" ) self._websock_server.unsafe_send_message( _messages.ShareUrlUpdated(share_url) ) connect_event.set() connect_event.wait() url = self._share_tunnel.get_url() return url def disconnect_share_url(self) -> None: """Disconnect from the share URL server.""" if self._share_tunnel is not None: self._share_tunnel.close() else: rich.print( "[bold](viser)[/bold] Tried to disconnect from share URL, but already disconnected" ) def stop(self) -> None: """Stop the Viser server and associated threads and tunnels.""" self._websock_server.stop() if self._share_tunnel is not None: self._share_tunnel.close() def get_clients(self) -> dict[int, ClientHandle]: """Creates and returns a copy of the mapping from connected client IDs to handles. Returns: Dictionary of clients. """ with self._client_lock: return self._connected_clients.copy() def on_client_connect( self, cb: Callable[[ClientHandle], None] ) -> Callable[[ClientHandle], None]: """Attach a callback to run for newly connected clients.""" with self._client_lock: clients = self._connected_clients.copy().values() self._client_connect_cb.append(cb) # Trigger callback on any already-connected clients. # If we have: # # server = viser.ViserServer() # server.on_client_connect(...) # # This makes sure that the the callback is applied to any clients that # connect between the two lines. for client in clients: cb(client) return cb def on_client_disconnect( self, cb: Callable[[ClientHandle], None] ) -> Callable[[ClientHandle], None]: """Attach a callback to run when clients disconnect.""" self._client_disconnect_cb.append(cb) return cb def flush(self) -> None: """Flush the outgoing message buffer. Any buffered messages will immediately be sent. (by default they are windowed)""" self._websock_server.flush() def atomic(self) -> ContextManager[None]: """Returns a context where: all outgoing messages are grouped and applied by clients atomically. This should be treated as a soft constraint that's helpful for things like animations, or when we want position and orientation updates to happen synchronously. Returns: Context manager. """ return self._websock_server.atomic() def send_file_download( self, filename: str, content: bytes, chunk_size: int = 1024 * 1024 ) -> None: """Send a file for a client or clients to download. Args: filename: Name of the file to send. Used to infer MIME type. content: Content of the file. chunk_size: Number of bytes to send at a time. """ for client in self.get_clients().values(): client.send_file_download(filename, content, chunk_size) def _start_scene_recording(self) -> RecordHandle: """Start recording outgoing messages for playback or embedding. Includes only the scene. **Work-in-progress.** This API may be changed or removed. 
""" recorder = self._websock_server.start_recording( # Don't record GUI messages. This feels brittle. filter=lambda message: "Gui" not in type(message).__name__ ) # Insert current scene state. for message in self._websock_server._broadcast_buffer.message_from_id.values(): recorder._insert_message(message) return recorder ================================================ FILE: viser/src/viser/client/.eslintrc.js ================================================ module.exports = { settings: { react: { version: "detect", // React version. "detect" automatically picks the version you have installed. // You can also use `16.0`, `16.3`, etc, if you want to override the detected value. // It will default to "latest" and warn if missing, and to "detect" in the future }, }, env: { browser: true, es2021: true, }, extends: [ "eslint:recommended", "plugin:react/recommended", "plugin:@typescript-eslint/recommended", ], overrides: [], parser: "@typescript-eslint/parser", parserOptions: { ecmaVersion: "latest", sourceType: "module", }, plugins: ["react", "@typescript-eslint", "react-refresh"], ignorePatterns: ["build/", ".eslintrc.js"], rules: { // https://github.com/jsx-eslint/eslint-plugin-react/issues/3423 "react/no-unknown-property": "off", "no-constant-condition": "off", // Suppress errors for missing 'import React' in files. "react/react-in-jsx-scope": "off", "@typescript-eslint/ban-ts-comment": "off", "@typescript-eslint/no-explicit-any": "off", "@typescript-eslint/no-non-null-assertion": "off", "react/prop-types": [ "error", { skipUndeclared: true, }, ], "react-refresh/only-export-components": "warn", }, }; ================================================ FILE: viser/src/viser/client/.gitignore ================================================ # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. # dependencies /node_modules /.pnp .pnp.js # testing /coverage # misc .DS_Store .env.local .env.development.local .env.test.local .env.production.local npm-debug.log* yarn-debug.log* yarn-error.log* ================================================ FILE: viser/src/viser/client/index.html ================================================ Viser
================================================ FILE: viser/src/viser/client/package.json ================================================ { "name": "viser", "version": "0.1.0", "private": true, "dependencies": { "@mantine/core": "^7.6.2", "@mantine/dates": "^7.6.2", "@mantine/hooks": "^7.6.2", "@mantine/notifications": "^7.6.2", "@mantine/vanilla-extract": "^7.6.2", "@mdx-js/mdx": "^3.0.1", "@mdx-js/react": "^3.0.1", "@msgpack/msgpack": "^3.0.0-beta2", "@react-three/drei": "^9.64.0", "@react-three/fiber": "^8.12.0", "@tabler/icons-react": "^3.1.0", "@types/node": "^20.11.30", "@types/react": "^18.0.33", "@types/react-dom": "^18.0.11", "@types/three": "^0.162.0", "@vanilla-extract/css": "^1.14.1", "@vitejs/plugin-react": "^4.0.1", "await-lock": "^2.2.2", "clsx": "^2.1.0", "colortranslator": "^4.1.0", "dayjs": "^1.11.10", "detect-browser": "^5.3.0", "fflate": "^0.8.2", "hold-event": "^1.1.0", "immer": "^10.0.4", "its-fine": "^1.2.5", "mantine-react-table": "^2.0.0-beta.0", "postcss": "^8.4.38", "prettier": "^3.0.3", "react": "^18.2.0", "react-dom": "^18.2.0", "react-error-boundary": "^4.0.10", "react-qr-code": "^2.0.12", "react-router-dom": "^6.10.0", "rehype-color-chips": "^0.1.3", "remark-gfm": "^4.0.0", "three": "^0.162.0", "vite": "^5.2.6", "vite-plugin-svgr": "^4.2.0", "vite-tsconfig-paths": "^4.2.0", "web-vitals": "^3.3.1", "ws": "^8.13.0", "zustand": "^4.3.7" }, "scripts": { "start": "vite --host", "build": "tsc && vite build", "serve": "vite preview" }, "eslintConfig": { "extends": [ "react-app" ] }, "browserslist": { "production": [ "last 2 chrome versions", "last 2 firefox versions", "last 2 safari versions" ], "development": [ "last 1 chrome version", "last 1 firefox version", "last 1 safari version" ] }, "devDependencies": { "@types/msgpack": "0.0.34", "@types/uuid": "^9.0.8", "@types/wicg-file-system-access": "^2023.10.5", "@typescript-eslint/eslint-plugin": "^7.4.0", "@typescript-eslint/parser": "^7.4.0", "@typescript-eslint/typescript-estree": "^7.4.0", "@vanilla-extract/vite-plugin": "^4.0.6", "browserslist-to-esbuild": "^2.1.1", "eslint": "^8.43.0", "eslint-plugin-react": "^7.32.2", "eslint-plugin-react-refresh": "^0.4.1", "postcss-preset-mantine": "^1.13.0", "typescript": "^5.0.4", "vite-plugin-eslint": "^1.8.1" } } ================================================ FILE: viser/src/viser/client/postcss.config.cjs ================================================ module.exports = { plugins: { "postcss-preset-mantine": {}, "postcss-simple-vars": { variables: { "mantine-breakpoint-xs": "36em", "mantine-breakpoint-sm": "48em", "mantine-breakpoint-md": "62em", "mantine-breakpoint-lg": "75em", "mantine-breakpoint-xl": "88em", }, }, }, }; ================================================ FILE: viser/src/viser/client/public/manifest.json ================================================ { "short_name": "Viser", "name": "Viser", "icons": [ { "src": "favicon.svg", "sizes": "any", "type": "image/x-icon" } ], "start_url": ".", "display": "standalone", "theme_color": "#000000", "background_color": "#ffffff" } ================================================ FILE: viser/src/viser/client/src/App.css.ts ================================================ import { globalStyle } from "@vanilla-extract/css"; globalStyle(".mantine-ScrollArea-scrollbar", { zIndex: 100, }); ================================================ FILE: viser/src/viser/client/src/App.tsx ================================================ // @refresh reset import "@mantine/core/styles.css"; import 
"@mantine/notifications/styles.css"; import "./App.css"; import { Notifications } from "@mantine/notifications"; import { CameraControls, Environment, PerformanceMonitor, Stats, } from "@react-three/drei"; import * as THREE from "three"; import { Canvas, useThree, useFrame } from "@react-three/fiber"; import { SynchronizedCameraControls } from "./CameraControls"; import { Anchor, Box, ColorSchemeScript, Image, MantineProvider, Modal, Tooltip, createTheme, useMantineTheme, } from "@mantine/core"; import React, { useEffect } from "react"; import { SceneNodeThreeObject, UseSceneTree } from "./SceneTree"; import "./index.css"; import ControlPanel from "./ControlPanel/ControlPanel"; import { UseGui, useGuiState } from "./ControlPanel/GuiState"; import { searchParamKey } from "./SearchParamsUtils"; import { WebsocketMessageProducer } from "./WebsocketInterface"; import { Titlebar } from "./Titlebar"; import { ViserModal } from "./Modal"; import { useSceneTreeState } from "./SceneTreeState"; import { GetRenderRequestMessage, Message } from "./WebsocketMessages"; import { useThrottledMessageSender } from "./WebsocketFunctions"; import { useDisclosure } from "@mantine/hooks"; import { rayToViserCoords } from "./WorldTransformUtils"; import { ndcFromPointerXy, opencvXyFromPointerXy } from "./ClickUtils"; import { theme } from "./AppTheme"; import { FrameSynchronizedMessageHandler } from "./MessageHandler"; import { PlaybackFromFile } from "./FilePlayback"; import { SplatRenderContext } from "./Splatting/GaussianSplats"; import { BrowserWarning } from "./BrowserWarning"; export type ViewerContextContents = { messageSource: "websocket" | "file_playback"; // Zustand hooks. useSceneTree: UseSceneTree; useGui: UseGui; // Useful references. // TODO: there's really no reason these all need to be their own ref objects. // We could have just one ref to a global mutable struct. sendMessageRef: React.MutableRefObject<(message: Message) => void>; canvasRef: React.MutableRefObject; sceneRef: React.MutableRefObject; cameraRef: React.MutableRefObject; backgroundMaterialRef: React.MutableRefObject; cameraControlRef: React.MutableRefObject; sendCameraRef: React.MutableRefObject<(() => void) | null>; resetCameraViewRef: React.MutableRefObject<(() => void) | null>; // Scene node attributes. // This is intentionally placed outside of the Zustand state to reduce overhead. nodeAttributesFromName: React.MutableRefObject<{ [name: string]: | undefined | { poseUpdateState?: "updated" | "needsUpdate" | "waitForMakeObject"; wxyz?: [number, number, number, number]; position?: [number, number, number]; visibility?: boolean; // Visibility state from the server. overrideVisibility?: boolean; // Override from the GUI. }; }>; nodeRefFromName: React.MutableRefObject<{ [name: string]: undefined | THREE.Object3D; }>; messageQueueRef: React.MutableRefObject; // Requested a render. getRenderRequestState: React.MutableRefObject< "ready" | "triggered" | "pause" | "in_progress" >; getRenderRequest: React.MutableRefObject; // Track click drag events. scenePointerInfo: React.MutableRefObject<{ enabled: false | "click" | "rect-select"; // Enable box events. dragStart: [number, number]; // First mouse position. dragEnd: [number, number]; // Final mouse position. isDragging: boolean; }>; // 2D canvas for drawing -- can be used to give feedback on cursor movement, or more. canvas2dRef: React.MutableRefObject; // Poses for bones in skinned meshes. 
skinnedMeshState: React.MutableRefObject<{ [name: string]: { initialized: boolean; poses: { wxyz: [number, number, number, number]; position: [number, number, number]; }[]; }; }>; }; export const ViewerContext = React.createContext( null, ); THREE.ColorManagement.enabled = true; function ViewerRoot() { // What websocket server should we connect to? function getDefaultServerFromUrl() { // https://localhost:8080/ => ws://localhost:8080 // https://localhost:8080/?server=some_url => ws://localhost:8080 let server = window.location.href; server = server.replace("http://", "ws://"); server = server.replace("https://", "wss://"); server = server.split("?")[0]; if (server.endsWith("/")) server = server.slice(0, -1); return server; } const servers = new URLSearchParams(window.location.search).getAll( searchParamKey, ); const initialServer = servers.length >= 1 ? servers[0] : getDefaultServerFromUrl(); // Playback mode for embedding viser. const searchParams = new URLSearchParams(window.location.search); const playbackPath = searchParams.get("playbackPath"); const darkMode = searchParams.get("darkMode") !== null; const showStats = searchParams.get("showStats") !== null; // Values that can be globally accessed by components in a viewer. const nodeRefFromName = React.useRef<{ [name: string]: undefined | THREE.Object3D; }>({}); const viewer: ViewerContextContents = { messageSource: playbackPath === null ? "websocket" : "file_playback", useSceneTree: useSceneTreeState(nodeRefFromName), useGui: useGuiState(initialServer), sendMessageRef: React.useRef( playbackPath == null ? (message) => console.log( `Tried to send ${message.type} but websocket is not connected!`, ) : () => null, ), canvasRef: React.useRef(null), sceneRef: React.useRef(null), cameraRef: React.useRef(null), backgroundMaterialRef: React.useRef(null), cameraControlRef: React.useRef(null), sendCameraRef: React.useRef(null), resetCameraViewRef: React.useRef(null), // Scene node attributes that aren't placed in the zustand state for performance reasons. nodeAttributesFromName: React.useRef({ "": { wxyz: (() => { const quat = new THREE.Quaternion().setFromEuler( new THREE.Euler(Math.PI / 2, Math.PI, -Math.PI / 2), ); return [quat.w, quat.x, quat.y, quat.z]; })(), }, }), nodeRefFromName: nodeRefFromName, messageQueueRef: React.useRef([]), getRenderRequestState: React.useRef("ready"), getRenderRequest: React.useRef(null), scenePointerInfo: React.useRef({ enabled: false, dragStart: [0, 0], dragEnd: [0, 0], isDragging: false, }), canvas2dRef: React.useRef(null), skinnedMeshState: React.useRef({}), }; // Set dark default if specified in URL. if (darkMode) viewer.useGui.getState().theme.dark_mode = darkMode; return ( {viewer.messageSource === "websocket" ? ( ) : null} {viewer.messageSource === "file_playback" ? ( ) : null} {showStats ? : null} ); } function ViewerContents({ children }: { children: React.ReactNode }) { const viewer = React.useContext(ViewerContext)!; const darkMode = viewer.useGui((state) => state.theme.dark_mode); const colors = viewer.useGui((state) => state.theme.colors); const controlLayout = viewer.useGui((state) => state.theme.control_layout); return ( <> {children} ({ backgroundColor: darkMode ? theme.colors.dark[9] : "#fff", flexGrow: 1, overflow: "hidden", height: "100%", })} > {viewer.useGui((state) => state.theme.show_logo) && viewer.messageSource == "websocket" ? ( ) : null} {viewer.messageSource == "websocket" ? 
( ) : null} ); } function ViewerCanvas({ children }: { children: React.ReactNode }) { const viewer = React.useContext(ViewerContext)!; const sendClickThrottled = useThrottledMessageSender(20); const theme = useMantineTheme(); const initDistanceScale = parseFloat( new URLSearchParams(window.location.search).get("initDistanceScale") ?? "1.0", ); return ( { const pointerInfo = viewer.scenePointerInfo.current!; // Only handle pointer events if enabled. if (pointerInfo.enabled === false) return; // Keep track of the first click position. const canvasBbox = viewer.canvasRef.current!.getBoundingClientRect(); pointerInfo.dragStart = [ e.clientX - canvasBbox.left, e.clientY - canvasBbox.top, ]; pointerInfo.dragEnd = pointerInfo.dragStart; // Check if pointer position is in bounds. if (ndcFromPointerXy(viewer, pointerInfo.dragEnd) === null) return; // Only allow one drag event at a time. if (pointerInfo.isDragging) return; pointerInfo.isDragging = true; // Disable camera controls -- we don't want the camera to move while we're dragging. viewer.cameraControlRef.current!.enabled = false; const ctx = viewer.canvas2dRef.current!.getContext("2d")!; ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height); }} onPointerMove={(e) => { const pointerInfo = viewer.scenePointerInfo.current!; // Only handle if click events are enabled, and if pointer is down (i.e., dragging). if (pointerInfo.enabled === false || !pointerInfo.isDragging) return; // Check if pointer position is in boudns. const canvasBbox = viewer.canvasRef.current!.getBoundingClientRect(); const pointerXy: [number, number] = [ e.clientX - canvasBbox.left, e.clientY - canvasBbox.top, ]; if (ndcFromPointerXy(viewer, pointerXy) === null) return; // Check if mouse position has changed sufficiently from last position. // Uses 3px as a threshood, similar to drag detection in // `SceneNodeClickMessage` from `SceneTree.tsx`. pointerInfo.dragEnd = pointerXy; if ( Math.abs(pointerInfo.dragEnd[0] - pointerInfo.dragStart[0]) <= 3 && Math.abs(pointerInfo.dragEnd[1] - pointerInfo.dragStart[1]) <= 3 ) return; // If we're listening for scene box events, draw the box on the 2D canvas for user feedback. if (pointerInfo.enabled === "rect-select") { const ctx = viewer.canvas2dRef.current!.getContext("2d")!; ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height); ctx.beginPath(); ctx.fillStyle = theme.primaryColor; ctx.strokeStyle = "blue"; ctx.globalAlpha = 0.2; ctx.fillRect( pointerInfo.dragStart[0], pointerInfo.dragStart[1], pointerInfo.dragEnd[0] - pointerInfo.dragStart[0], pointerInfo.dragEnd[1] - pointerInfo.dragStart[1], ); ctx.globalAlpha = 1.0; ctx.stroke(); } }} onPointerUp={() => { const pointerInfo = viewer.scenePointerInfo.current!; // Re-enable camera controls! Was disabled in `onPointerDown`, to allow // for mouse drag w/o camera movement. viewer.cameraControlRef.current!.enabled = true; // Only handle if click events are enabled, and if pointer was down (i.e., dragging). if (pointerInfo.enabled === false || !pointerInfo.isDragging) return; const ctx = viewer.canvas2dRef.current!.getContext("2d")!; ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height); // If there's only one pointer, send a click message. // The message will return origin/direction lists of length 1. if (pointerInfo.enabled === "click") { const raycaster = new THREE.Raycaster(); // Raycaster expects NDC coordinates, so we convert the click event to NDC. 
const mouseVector = ndcFromPointerXy(viewer, pointerInfo.dragEnd); if (mouseVector === null) return; raycaster.setFromCamera(mouseVector, viewer.cameraRef.current!); const ray = rayToViserCoords(viewer, raycaster.ray); // Send OpenCV image coordinates to the server (normalized). const mouseVectorOpenCV = opencvXyFromPointerXy( viewer, pointerInfo.dragEnd, ); sendClickThrottled({ type: "ScenePointerMessage", event_type: "click", ray_origin: [ray.origin.x, ray.origin.y, ray.origin.z], ray_direction: [ray.direction.x, ray.direction.y, ray.direction.z], screen_pos: [[mouseVectorOpenCV.x, mouseVectorOpenCV.y]], }); } else if (pointerInfo.enabled === "rect-select") { // If the ScenePointerEvent had mouse drag movement, we will send a "box" message: // Use the first and last mouse positions to create a box. // Again, click should be in openCV image coordinates (normalized). const firstMouseVector = opencvXyFromPointerXy( viewer, pointerInfo.dragStart, ); const lastMouseVector = opencvXyFromPointerXy( viewer, pointerInfo.dragEnd, ); const x_min = Math.min(firstMouseVector.x, lastMouseVector.x); const x_max = Math.max(firstMouseVector.x, lastMouseVector.x); const y_min = Math.min(firstMouseVector.y, lastMouseVector.y); const y_max = Math.max(firstMouseVector.y, lastMouseVector.y); // Send the upper-left and lower-right corners of the box. const screenBoxList: [number, number][] = [ [x_min, y_min], [x_max, y_max], ]; sendClickThrottled({ type: "ScenePointerMessage", event_type: "rect-select", ray_origin: null, ray_direction: null, screen_pos: screenBoxList, }); } // Release drag lock. pointerInfo.isDragging = false; }} > {children} ); } function AdaptiveDpr() { const setDpr = useThree((state) => state.setDpr); return ( { const max = Math.min(refreshrate * 0.9, 85); const min = Math.max(max * 0.5, 38); return [min, max]; }} onChange={({ factor, fps, refreshrate }) => { const dpr = window.devicePixelRatio * (0.2 + 0.8 * factor); console.log( `[Performance] Setting DPR to ${dpr}; FPS=${fps}/${refreshrate}`, ); setDpr(dpr); }} /> ); } /* HTML Canvas, for drawing 2D. */ function Viewer2DCanvas() { const viewer = React.useContext(ViewerContext)!; useEffect(() => { // Create a resize observer to resize the CSS canvas when the window is resized. const resizeObserver = new ResizeObserver((entries) => { const { width, height } = entries[0].contentRect; canvas.width = width; canvas.height = height; }); // Observe the canvas. const canvas = viewer.canvas2dRef.current!; resizeObserver.observe(canvas); // Cleanup return () => resizeObserver.disconnect(); }); return ( ); } /* Background image with support for depth compositing. */ function BackgroundImage() { // Create a fragment shader that composites depth using depth and rgb const vertShader = ` varying vec2 vUv; void main() { vUv = uv; gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0); } `.trim(); const fragShader = ` #include precision highp float; precision highp int; varying vec2 vUv; uniform sampler2D colorMap; uniform sampler2D depthMap; uniform float cameraNear; uniform float cameraFar; uniform bool enabled; uniform bool hasDepth; float readDepth(sampler2D depthMap, vec2 coord) { vec4 rgbPacked = texture(depthMap, coord); // For the k-th channel, coefficients are calculated as: 255 * 1e-5 * 2^(8 * k). // Note that: [0, 255] channels are scaled to [0, 1], and we multiply by 1e5 on the server side. 
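// Worked out for k = 0, 1, 2: 255 * 1e-5 = 0.00255, then 0.00255 * 256 = 0.6528,
// and 0.6528 * 256 = 167.1168, matching the three channel coefficients below.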
float depth = rgbPacked.r * 0.00255 + rgbPacked.g * 0.6528 + rgbPacked.b * 167.1168; return depth; } void main() { if (!enabled) { // discard the pixel if we're not enabled discard; } vec4 color = texture(colorMap, vUv); gl_FragColor = vec4(color.rgb, 1.0); float bufDepth; if(hasDepth){ float depth = readDepth(depthMap, vUv); bufDepth = viewZToPerspectiveDepth(-depth, cameraNear, cameraFar); } else { // If no depth enabled, set depth to 1.0 (infinity) to treat it like a background image. bufDepth = 1.0; } gl_FragDepth = bufDepth; }`.trim(); // initialize the rgb texture with all white and depth at infinity const backgroundMaterial = new THREE.ShaderMaterial({ fragmentShader: fragShader, vertexShader: vertShader, uniforms: { enabled: { value: false }, depthMap: { value: null }, colorMap: { value: null }, cameraNear: { value: null }, cameraFar: { value: null }, hasDepth: { value: false }, }, }); const { backgroundMaterialRef } = React.useContext(ViewerContext)!; backgroundMaterialRef.current = backgroundMaterial; const backgroundMesh = React.useRef(null); useFrame(({ camera }) => { // Logic ahead relies on perspective camera assumption. if (!(camera instanceof THREE.PerspectiveCamera)) { console.error( "Camera is not a perspective camera, cannot render background image", ); return; } // Update the position of the mesh based on the camera position. const lookdir = camera.getWorldDirection(new THREE.Vector3()); backgroundMesh.current!.position.set( camera.position.x, camera.position.y, camera.position.z, ); backgroundMesh.current!.position.addScaledVector(lookdir, 1.0); backgroundMesh.current!.quaternion.copy(camera.quaternion); // Resize the mesh based on focal length. const f = camera.getFocalLength(); backgroundMesh.current!.scale.set( camera.getFilmWidth() / f, camera.getFilmHeight() / f, 1.0, ); // Set near/far uniforms. backgroundMaterial.uniforms.cameraNear.value = camera.near; backgroundMaterial.uniforms.cameraFar.value = camera.far; }); return ( ); } /** Component for helping us set the scene reference. */ function SceneContextSetter() { const { sceneRef, cameraRef } = React.useContext(ViewerContext)!; sceneRef.current = useThree((state) => state.scene); cameraRef.current = useThree( (state) => state.camera as THREE.PerspectiveCamera, ); return <>; } export function Root() { return (
); } /** Logo. When clicked, opens an info modal. */ function ViserLogo() { const [aboutModalOpened, { open: openAbout, close: closeAbout }] = useDisclosure(false); return ( <>

Viser is a 3D visualization toolkit developed at UC Berkeley.

Nerfstudio   •   GitHub   •   Documentation

); } ================================================ FILE: viser/src/viser/client/src/AppTheme.ts ================================================ import { Checkbox, ColorInput, Select, TextInput, NumberInput, Paper, ActionIcon, Button, createTheme, } from "@mantine/core"; import { themeToVars } from "@mantine/vanilla-extract"; export const theme = createTheme({ fontFamily: "Inter", autoContrast: true, components: { Checkbox: Checkbox.extend({ defaultProps: { radius: "xs", }, }), ColorInput: ColorInput.extend({ defaultProps: { radius: "xs", }, }), Select: Select.extend({ defaultProps: { radius: "sm", }, }), TextInput: TextInput.extend({ defaultProps: { radius: "xs", }, }), NumberInput: NumberInput.extend({ defaultProps: { radius: "xs", }, }), Paper: Paper.extend({ defaultProps: { radius: "xs", shadow: "0", }, }), ActionIcon: ActionIcon.extend({ defaultProps: { variant: "subtle", color: "gray", radius: "xs", }, }), Button: Button.extend({ defaultProps: { radius: "xs", fw: 450, }, }), }, }); export const vars = themeToVars(theme); ================================================ FILE: viser/src/viser/client/src/BrowserWarning.tsx ================================================ import { notifications } from "@mantine/notifications"; import { detect } from "detect-browser"; import { useEffect } from "react"; export function BrowserWarning() { useEffect(() => { const browser = detect(); // Browser version are based loosely on support for SIMD, OffscreenCanvas. // // https://caniuse.com/?search=simd // https://caniuse.com/?search=OffscreenCanvas if (browser === null || browser.version === null) { console.log("Failed to detect browser"); notifications.show({ title: "Could not detect browser version", message: "Your browser version could not be detected. It may not be supported.", autoClose: false, color: "red", }); } else { const version = parseFloat(browser.version); console.log(`Detected ${browser.name} version ${version}`); if ( (browser.name === "chrome" && version < 91) || (browser.name === "edge" && version < 91) || (browser.name === "firefox" && version < 89) || (browser.name === "opera" && version < 77) || (browser.name === "safari" && version < 16.4) ) notifications.show({ title: "Unsuppported browser", message: `Your browser (${ browser.name.slice(0, 1).toUpperCase() + browser.name.slice(1) }/${ browser.version }) is outdated, which may cause problems. Consider updating.`, autoClose: false, color: "red", }); } }); return null; } ================================================ FILE: viser/src/viser/client/src/CameraControls.tsx ================================================ import { ViewerContext } from "./App"; import { CameraControls } from "@react-three/drei"; import { useThree } from "@react-three/fiber"; import * as holdEvent from "hold-event"; import React, { useContext, useRef } from "react"; import { PerspectiveCamera } from "three"; import * as THREE from "three"; import { computeT_threeworld_world } from "./WorldTransformUtils"; import { useThrottledMessageSender } from "./WebsocketFunctions"; export function SynchronizedCameraControls() { const viewer = useContext(ViewerContext)!; const camera = useThree((state) => state.camera as PerspectiveCamera); const sendCameraThrottled = useThrottledMessageSender(20); // Helper for resetting camera poses. 
const initialCameraRef = useRef<{ camera: PerspectiveCamera; lookAt: THREE.Vector3; } | null>(null); viewer.resetCameraViewRef.current = () => { viewer.cameraControlRef.current!.setLookAt( initialCameraRef.current!.camera.position.x, initialCameraRef.current!.camera.position.y, initialCameraRef.current!.camera.position.z, initialCameraRef.current!.lookAt.x, initialCameraRef.current!.lookAt.y, initialCameraRef.current!.lookAt.z, true, ); viewer.cameraRef.current!.up.set( initialCameraRef.current!.camera.up.x, initialCameraRef.current!.camera.up.y, initialCameraRef.current!.camera.up.z, ); viewer.cameraControlRef.current!.updateCameraUp(); }; // Callback for sending cameras. // It makes the code more chaotic, but we preallocate a bunch of things to // minimize garbage collection! const R_threecam_cam = new THREE.Quaternion().setFromEuler( new THREE.Euler(Math.PI, 0.0, 0.0), ); const R_world_threeworld = new THREE.Quaternion(); const tmpMatrix4 = new THREE.Matrix4(); const lookAt = new THREE.Vector3(); const R_world_camera = new THREE.Quaternion(); const t_world_camera = new THREE.Vector3(); const scale = new THREE.Vector3(); const sendCamera = React.useCallback(() => { const three_camera = camera; const camera_control = viewer.cameraControlRef.current; if (camera_control === null) { // Camera controls not yet ready, let's re-try later. setTimeout(sendCamera, 10); return; } // We put Z up to match the scene tree, and convert threejs camera convention // to the OpenCV one. const T_world_threeworld = computeT_threeworld_world(viewer).invert(); const T_world_camera = T_world_threeworld.clone() .multiply( tmpMatrix4 .makeRotationFromQuaternion(three_camera.quaternion) .setPosition(three_camera.position), ) .multiply(tmpMatrix4.makeRotationFromQuaternion(R_threecam_cam)); R_world_threeworld.setFromRotationMatrix(T_world_threeworld); camera_control.getTarget(lookAt).applyQuaternion(R_world_threeworld); const up = three_camera.up.clone().applyQuaternion(R_world_threeworld); //Store initial camera values if (initialCameraRef.current === null) { initialCameraRef.current = { camera: three_camera.clone(), lookAt: camera_control.getTarget(new THREE.Vector3()), }; } T_world_camera.decompose(t_world_camera, R_world_camera, scale); sendCameraThrottled({ type: "ViewerCameraMessage", wxyz: [ R_world_camera.w, R_world_camera.x, R_world_camera.y, R_world_camera.z, ], position: t_world_camera.toArray(), aspect: three_camera.aspect, fov: (three_camera.fov * Math.PI) / 180.0, look_at: [lookAt.x, lookAt.y, lookAt.z], up_direction: [up.x, up.y, up.z], }); }, [camera, sendCameraThrottled]); // Send camera for new connections. // We add a small delay to give the server time to add a callback. const connected = viewer.useGui((state) => state.websocketConnected); const initHeightOffset = parseFloat( new URLSearchParams(window.location.search).get("initHeightOffset") ?? "0.05", ); // Add the height offset to the initial camera position React.useEffect(() => { const cameraControls = viewer.cameraControlRef.current!; const lookAt = new THREE.Vector3(); cameraControls.getTarget(lookAt); viewer.cameraControlRef.current!.setLookAt( camera.position.x, camera.position.y + (camera.position.y / 0.1) * initHeightOffset, // init_scale: camera.position.y / 0.1 camera.position.z, lookAt.x, lookAt.y + (camera.position.y / 0.1) * initHeightOffset, lookAt.z, true, ); viewer.sendCameraRef.current = sendCamera; if (!connected) return; setTimeout(() => sendCamera(), 50); }, [connected, sendCamera]); // Send camera for 3D viewport changes. 
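// A ResizeObserver on the R3F canvas re-sends the camera whenever the
// viewport changes size, so the server always sees an up-to-date aspect
// ratio; the observer is disconnected in the effect's cleanup.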
const canvas = viewer.canvasRef.current!; // R3F canvas. React.useEffect(() => { // Create a resize observer to resize the CSS canvas when the window is resized. const resizeObserver = new ResizeObserver(() => { sendCamera(); }); resizeObserver.observe(canvas); // Cleanup. return () => resizeObserver.disconnect(); }, [canvas]); // Keyboard controls. React.useEffect(() => { const cameraControls = viewer.cameraControlRef.current!; const wKey = new holdEvent.KeyboardKeyHold("KeyW", 20); const aKey = new holdEvent.KeyboardKeyHold("KeyA", 20); const sKey = new holdEvent.KeyboardKeyHold("KeyS", 20); const dKey = new holdEvent.KeyboardKeyHold("KeyD", 20); const qKey = new holdEvent.KeyboardKeyHold("KeyQ", 20); const eKey = new holdEvent.KeyboardKeyHold("KeyE", 20); // TODO: these event listeners are currently never removed, even if this // component gets unmounted. aKey.addEventListener("holding", (event) => { cameraControls.truck(-0.0002 * event?.deltaTime, 0, true); }); dKey.addEventListener("holding", (event) => { cameraControls.truck(0.0002 * event?.deltaTime, 0, true); }); wKey.addEventListener("holding", (event) => { cameraControls.forward(0.0002 * event?.deltaTime, true); }); sKey.addEventListener("holding", (event) => { cameraControls.forward(-0.0002 * event?.deltaTime, true); }); qKey.addEventListener("holding", (event) => { cameraControls.elevate(0.0002 * event?.deltaTime, true); }); eKey.addEventListener("holding", (event) => { cameraControls.elevate(-0.0002 * event?.deltaTime, true); }); const leftKey = new holdEvent.KeyboardKeyHold("ArrowLeft", 20); const rightKey = new holdEvent.KeyboardKeyHold("ArrowRight", 20); const upKey = new holdEvent.KeyboardKeyHold("ArrowUp", 20); const downKey = new holdEvent.KeyboardKeyHold("ArrowDown", 20); leftKey.addEventListener("holding", (event) => { cameraControls.rotate( -0.05 * THREE.MathUtils.DEG2RAD * event?.deltaTime, 0, true, ); }); rightKey.addEventListener("holding", (event) => { cameraControls.rotate( 0.05 * THREE.MathUtils.DEG2RAD * event?.deltaTime, 0, true, ); }); upKey.addEventListener("holding", (event) => { cameraControls.rotate( 0, -0.05 * THREE.MathUtils.DEG2RAD * event?.deltaTime, true, ); }); downKey.addEventListener("holding", (event) => { cameraControls.rotate( 0, 0.05 * THREE.MathUtils.DEG2RAD * event?.deltaTime, true, ); }); // TODO: we currently don't remove any event listeners. This is a bit messy // because KeyboardKeyHold attaches listeners directly to the // document/window; it's unclear if we can remove these. return () => { return; }; }, [CameraControls]); return ( ); } ================================================ FILE: viser/src/viser/client/src/ClickUtils.tsx ================================================ import * as THREE from "three"; import { ViewerContextContents } from "./App"; /** Turn a click event into a normalized device coordinate (NDC) vector. * Normalizes click coordinates to be between -1 and 1, with (0, 0) being the center of the screen. * * Returns null if input is not valid. */ export function ndcFromPointerXy( viewer: ViewerContextContents, xy: [number, number], ): THREE.Vector2 | null { const mouseVector = new THREE.Vector2(); mouseVector.x = 2 * ((xy[0] + 0.5) / viewer.canvasRef.current!.clientWidth) - 1; mouseVector.y = 1 - 2 * ((xy[1] + 0.5) / viewer.canvasRef.current!.clientHeight); return mouseVector.x < 1 && mouseVector.x > -1 && mouseVector.y < 1 && mouseVector.y > -1 ? mouseVector : null; } /** Turn a click event to normalized OpenCV coordinate (NDC) vector. 
* Normalizes click coordinates to be between (0, 0) as upper-left corner, * and (1, 1) as lower-right corner, with (0.5, 0.5) being the center of the screen. * Uses offsetX/Y, and clientWidth/Height to get the coordinates. */ export function opencvXyFromPointerXy( viewer: ViewerContextContents, xy: [number, number], ): THREE.Vector2 { const mouseVector = new THREE.Vector2(); mouseVector.x = (xy[0] + 0.5) / viewer.canvasRef.current!.clientWidth; mouseVector.y = (xy[1] + 0.5) / viewer.canvasRef.current!.clientHeight; return mouseVector; } ================================================ FILE: viser/src/viser/client/src/ControlPanel/BottomPanel.tsx ================================================ import { Box, Collapse, Paper, useMantineColorScheme } from "@mantine/core"; import React from "react"; import { useDisclosure } from "@mantine/hooks"; const BottomPanelContext = React.createContext; expanded: boolean; toggleExpanded: () => void; }>(null); /** A bottom panel is used to display the controls on mobile devices. */ export default function BottomPanel({ children, }: { children: string | React.ReactNode; }) { const panelWrapperRef = React.useRef(null); const [expanded, { toggle: toggleExpanded }] = useDisclosure(true); return ( ({ borderTopWidth: "1px", borderTopStyle: "solid", borderColor: useMantineColorScheme().colorScheme == "dark" ? theme.colors.dark[4] : theme.colors.gray[3], boxSizing: "border-box", width: "100%", zIndex: 10, position: "fixed", bottom: 0, left: 0, margin: 0, overflow: "scroll", minHeight: "3.5em", maxHeight: "60%", transition: "height 0.3s linear", })} ref={panelWrapperRef} > {children} ); } BottomPanel.Handle = function BottomPanelHandle({ children, }: { children: string | React.ReactNode; }) { const panelContext = React.useContext(BottomPanelContext)!; return ( ({ borderBottomWidth: panelContext.expanded ? "1px" : undefined, borderBottomStyle: "solid", borderColor: useMantineColorScheme().colorScheme == "dark" ? theme.colors.dark[4] : theme.colors.gray[3], cursor: "pointer", position: "relative", fontWeight: 400, userSelect: "none", display: "flex", alignItems: "center", padding: "0 0.8em", height: "3.5em", })} onClick={() => { panelContext.toggleExpanded(); }} > {children} ); }; /** Contents of a panel. */ BottomPanel.Contents = function BottomPanelContents({ children, }: { children: string | React.ReactNode; }) { const panelContext = React.useContext(BottomPanelContext)!; return {children}; }; /** Hides contents when panel is collapsed. */ BottomPanel.HideWhenCollapsed = function BottomPanelHideWhenCollapsed({ children, }: { children: React.ReactNode; }) { const expanded = React.useContext(BottomPanelContext)?.expanded ?? true; return expanded ? 
children : null; }; ================================================ FILE: viser/src/viser/client/src/ControlPanel/ControlPanel.tsx ================================================ import { useDisclosure, useMediaQuery } from "@mantine/hooks"; import GeneratedGuiContainer from "./Generated"; import { ViewerContext } from "../App"; import QRCode from "react-qr-code"; import ServerControls from "./ServerControls"; import { ActionIcon, Anchor, Box, Button, Collapse, CopyButton, Flex, Loader, Modal, Stack, Text, TextInput, Tooltip, Transition, useMantineColorScheme, useMantineTheme, } from "@mantine/core"; import { IconAdjustments, IconCloudCheck, IconArrowBack, IconShare, IconCopy, IconCheck, IconPlugConnectedX, IconQrcode, IconQrcodeOff, } from "@tabler/icons-react"; import React from "react"; import BottomPanel from "./BottomPanel"; import FloatingPanel from "./FloatingPanel"; import { ThemeConfigurationMessage } from "../WebsocketMessages"; import SidebarPanel from "./SidebarPanel"; // Must match constant in Python. const ROOT_CONTAINER_ID = "root"; export default function ControlPanel(props: { control_layout: ThemeConfigurationMessage["control_layout"]; }) { const theme = useMantineTheme(); const useMobileView = useMediaQuery(`(max-width: ${theme.breakpoints.xs})`); // TODO: will result in unnecessary re-renders. const viewer = React.useContext(ViewerContext)!; const showGenerated = viewer.useGui( (state) => "root" in state.guiIdSetFromContainerId, ); const [showSettings, { toggle }] = useDisclosure(false); const controlWidthString = viewer.useGui( (state) => state.theme.control_width, ); const controlWidth = ( controlWidthString == "small" ? "16em" : controlWidthString == "medium" ? "20em" : controlWidthString == "large" ? "24em" : null )!; const generatedServerToggleButton = ( { evt.stopPropagation(); toggle(); }} style={{ display: showGenerated ? undefined : "none", transform: "translateY(0.05em)", }} > {showSettings ? ( ) : ( )} ); const panelContents = ( <> ); if (useMobileView) { /* Mobile layout. */ return ( {generatedServerToggleButton} {panelContents} ); } else if (props.control_layout === "floating") { /* Floating layout. */ return ( {generatedServerToggleButton} {panelContents} ); } else { /* Sidebar view. */ return ( {generatedServerToggleButton} {panelContents} ); } } /* Icon and label telling us the current status of the websocket connection. */ function ConnectionStatus() { const { useGui } = React.useContext(ViewerContext)!; const connected = useGui((state) => state.websocketConnected); const label = useGui((state) => state.label); return ( <>
{/* Spacer. */} {(styles) => ( )} {(styles) => ( )} {label !== "" ? label : connected ? "Connected" : "Connecting..."} ); } function ShareButton() { const viewer = React.useContext(ViewerContext)!; const connected = viewer.useGui((state) => state.websocketConnected); const shareUrl = viewer.useGui((state) => state.shareUrl); const setShareUrl = viewer.useGui((state) => state.setShareUrl); const [doingSomething, setDoingSomething] = React.useState(false); const [shareModalOpened, { open: openShareModal, close: closeShareModal }] = useDisclosure(false); const [showQrCode, { toggle: toggleShowQrcode }] = useDisclosure(); // Turn off loader when share URL is set. React.useEffect(() => { if (shareUrl !== null) { setDoingSomething(false); } }, [shareUrl]); React.useEffect(() => { if (!connected && shareModalOpened) closeShareModal(); }, [connected, shareModalOpened]); if (viewer.useGui((state) => state.theme).show_share_button === false) return null; const colorScheme = useMantineColorScheme().colorScheme; return ( <> { evt.stopPropagation(); openShareModal(); }} style={{ transform: "translateY(0.05em)", }} disabled={!connected} > evt.stopPropagation()} onMouseDown={(evt) => evt.stopPropagation()} onMouseMove={(evt) => evt.stopPropagation()} onMouseUp={(evt) => evt.stopPropagation()} styles={{ title: { fontWeight: 600 } }} > {shareUrl === null ? ( <> {/* (val === null ? null : setPlaybackSpeed(val))} radius="xs" data={["0.5x", "1x", "2x", "4x", "8x"]} styles={{ wrapper: { width: "3.25em" }, }} comboboxProps={{ zIndex: 5, width: "5.25em" }} /> ); } } ================================================ FILE: viser/src/viser/client/src/Markdown.tsx ================================================ import React from "react"; import * as runtime from "react/jsx-runtime"; import * as provider from "@mdx-js/react"; import { evaluate } from "@mdx-js/mdx"; import { type MDXComponents } from "mdx/types"; import { ReactNode, useEffect, useState } from "react"; import remarkGfm from "remark-gfm"; import rehypeColorChips from "rehype-color-chips"; import { Anchor, Blockquote, Code, Image, List, ListProps, Table, Text, Title, TitleOrder, } from "@mantine/core"; import { visit } from "unist-util-visit"; import { Transformer } from "unified"; import { Root } from "hast"; // Custom Rehype to clean up code blocks (Mantine makes these annoying to style) // Adds "block" to any code non-inline code block, which gets directly passed into // the Mantine Code component. 
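// Illustrative sketch (hypothetical input, not from the repo): a fenced
// markdown block compiles to a pre > code pair in the HAST tree; the visitor
// below matches the code element whose parent is pre and sets
// node.properties.block = true, which MdxCode later spreads onto Mantine's
// Code component so it renders as a block rather than inline code.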
function rehypeCodeblock(): void | Transformer { return (tree) => { visit(tree, "element", (node, _i, parent) => { if (node.tagName !== "code") return; if (parent && parent.type === "element" && parent.tagName === "pre") { node.properties = { block: true, ...node.properties }; } }); }; } // Custom classes to pipe MDX into Mantine Components // Some of them separate the children into a separate prop since Mantine requires a child // and MDX always makes children optional, so destructuring props doesn't work function MdxText(props: React.ComponentPropsWithoutRef) { return ; } function MdxAnchor(props: React.ComponentPropsWithoutRef) { return ; } function MdxTitle( props: React.ComponentPropsWithoutRef, order: TitleOrder, ) { return ; } function MdxList( props: Omit, "children" | "type">, children: React.ComponentPropsWithoutRef["children"], type: ListProps["type"], ) { // Account for GFM Checkboxes if (props.className == "contains-task-list") { return ( {children} ); } return ( {children} ); } function MdxListItem( props: Omit, "children">, children: React.ComponentPropsWithoutRef["children"], ) { return {children}; } // A possible improvement is to use Mantine Prism to add code highlighting support. function MdxCode( props: Omit, "children">, children: React.ComponentPropsWithoutRef["children"], ) { return {children}; } function MdxBlockquote( props: React.ComponentPropsWithoutRef, ) { return
; } function MdxCite( props: React.DetailedHTMLProps< React.HTMLAttributes, HTMLElement >, ) { return ( ); } function MdxTable(props: React.ComponentPropsWithoutRef) { return ; } function MdxImage(props: React.ComponentPropsWithoutRef) { return ; } const components: MDXComponents = { p: (props) => MdxText(props), a: (props) => MdxAnchor(props), h1: (props) => MdxTitle(props, 1), h2: (props) => MdxTitle(props, 2), h3: (props) => MdxTitle(props, 3), h4: (props) => MdxTitle(props, 4), h5: (props) => MdxTitle(props, 5), h6: (props) => MdxTitle(props, 6), ul: (props) => MdxList(props, props.children ?? "", "unordered"), ol: (props) => MdxList(props, props.children ?? "", "ordered"), li: (props) => MdxListItem(props, props.children ?? ""), code: (props) => MdxCode(props, props.children ?? ""), pre: (props) => <>{props.children}, blockquote: (props) => MdxBlockquote(props), Cite: (props) => MdxCite(props), table: (props) => MdxTable(props), img: (props) => MdxImage(props), "*": () => <>, }; async function parseMarkdown(markdown: string) { // @ts-ignore (necessary since JSX runtime isn't properly typed according to the internet) const { default: Content } = await evaluate(markdown, { ...runtime, ...provider, development: false, remarkPlugins: [remarkGfm], rehypePlugins: [rehypeCodeblock, rehypeColorChips], }); return Content; } /** * Parses and renders markdown on the client. This is generally a bad practice. * NOTE: Only run on markdown you trust. * It might be worth looking into sandboxing all markdown so that it can't run JS. */ export default function Markdown(props: { children?: string }) { const [child, setChild] = useState(null); useEffect(() => { try { parseMarkdown(props.children ?? "").then((Content) => { setChild(); }); } catch { setChild(Error Parsing Markdown...); } }, [props.children]); return child; } ================================================ FILE: viser/src/viser/client/src/MessageHandler.tsx ================================================ import { CatmullRomLine, CubicBezierLine, Grid, Html } from "@react-three/drei"; import { useContextBridge } from "its-fine"; import { notifications } from "@mantine/notifications"; import React, { useContext } from "react"; import * as THREE from "three"; import { TextureLoader } from "three"; import { ViewerContext } from "./App"; import { SceneNode } from "./SceneTree"; import { CameraFrustum, CoordinateFrame, InstancedAxes, GlbAsset, OutlinesIfHovered, PointCloud, } from "./ThreeAssets"; import { FileTransferPart, FileTransferStart, Message, } from "./WebsocketMessages"; import { PivotControls } from "@react-three/drei"; import { isTexture, makeThrottledMessageSender } from "./WebsocketFunctions"; import { isGuiConfig } from "./ControlPanel/GuiState"; import { useFrame } from "@react-three/fiber"; import GeneratedGuiContainer from "./ControlPanel/Generated"; import { Paper, Progress } from "@mantine/core"; import { IconCheck } from "@tabler/icons-react"; import { computeT_threeworld_world } from "./WorldTransformUtils"; import { SplatObject } from "./Splatting/GaussianSplats"; /** Convert raw RGB color buffers to linear color buffers. **/ function threeColorBufferFromUint8Buffer(colors: ArrayBuffer) { return new THREE.Float32BufferAttribute( new Float32Array(new Uint8Array(colors)).map((value) => { value = value / 255.0; if (value <= 0.04045) { return value / 12.92; } else { return Math.pow((value + 0.055) / 1.055, 2.4); } }), 3, ); } /** Returns a handler for all incoming messages. 
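 * The returned closure switches on message.type: GUI configuration messages
 * are intercepted first (isGuiConfig) and routed to the GUI store, while the
 * remaining cases map onto scene-tree, camera, notification, and
 * file-transfer updates below. addSceneNodeMakeParents() inserts dummy
 * coordinate-frame ancestors so a node can be added before its parents exist.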
*/ function useMessageHandler() { const viewer = useContext(ViewerContext)!; const ContextBridge = useContextBridge(); // We could reduce the redundancy here if we wanted to. // https://github.com/nerfstudio-project/viser/issues/39 const removeSceneNode = viewer.useSceneTree((state) => state.removeSceneNode); const resetScene = viewer.useSceneTree((state) => state.resetScene); const addSceneNode = viewer.useSceneTree((state) => state.addSceneNode); const resetGui = viewer.useGui((state) => state.resetGui); const setTheme = viewer.useGui((state) => state.setTheme); const setShareUrl = viewer.useGui((state) => state.setShareUrl); const addGui = viewer.useGui((state) => state.addGui); const addModal = viewer.useGui((state) => state.addModal); const removeModal = viewer.useGui((state) => state.removeModal); const removeGui = viewer.useGui((state) => state.removeGui); const updateGuiProps = viewer.useGui((state) => state.updateGuiProps); const setClickable = viewer.useSceneTree((state) => state.setClickable); const updateUploadState = viewer.useGui((state) => state.updateUploadState); // Same as addSceneNode, but make a parent in the form of a dummy coordinate // frame if it doesn't exist yet. function addSceneNodeMakeParents(node: SceneNode) { // Make sure scene node is in attributes. const attrs = viewer.nodeAttributesFromName.current; attrs[node.name] = { overrideVisibility: attrs[node.name]?.overrideVisibility, }; // Don't update the pose of the object until we've made a new one! attrs[node.name]!.poseUpdateState = "waitForMakeObject"; // Make sure parents exists. const nodeFromName = viewer.useSceneTree.getState().nodeFromName; const parentName = node.name.split("/").slice(0, -1).join("/"); if (!(parentName in nodeFromName)) { addSceneNodeMakeParents( new SceneNode(parentName, (ref) => ( )), ); } addSceneNode(node); } const fileDownloadHandler = useFileDownloadHandler(); // Return message handler. return (message: Message) => { if (isGuiConfig(message)) { addGui(message); return; } switch (message.type) { // Set the share URL. case "ShareUrlUpdated": { setShareUrl(message.share_url); return; } // Request a render. case "GetRenderRequestMessage": { viewer.getRenderRequest.current = message; viewer.getRenderRequestState.current = "triggered"; return; } // Set the GUI panel label. case "SetGuiPanelLabelMessage": { viewer.useGui.setState({ label: message.label ?? "" }); return; } // Configure the theme. case "ThemeConfigurationMessage": { setTheme(message); return; } // Run some arbitrary Javascript. // This is used for plotting, where the Python server will send over a // copy of plotly.min.js for the currently-installed version of plotly. case "RunJavascriptMessage": { eval(message.source); return; } // Add a notification. case "NotificationMessage": { if (message.mode === "show") { notifications.show({ id: message.id, title: message.title, message: message.body, withCloseButton: message.with_close_button, loading: message.loading, autoClose: message.auto_close, color: message.color ?? undefined, }); } else if (message.mode === "update") { notifications.update({ id: message.id, title: message.title, message: message.body, withCloseButton: message.with_close_button, loading: message.loading, autoClose: message.auto_close, color: message.color ?? undefined, }); } return; } // Remove a specific notification. case "RemoveNotificationMessage": { notifications.hide(message.id); return; } // Enable/disable whether scene pointer events are sent. 
case "ScenePointerEnableMessage": { // Update scene click enable state. viewer.scenePointerInfo.current!.enabled = message.enable ? message.event_type : false; // Update cursor to indicate whether the scene can be clicked. viewer.canvasRef.current!.style.cursor = message.enable ? "pointer" : "auto"; return; } // Add a coordinate frame. case "FrameMessage": { addSceneNodeMakeParents( new SceneNode(message.name, (ref) => ( )), ); return; } // Add axes to visualize. case "BatchedAxesMessage": { addSceneNodeMakeParents( new SceneNode( message.name, (ref) => ( // Minor naming discrepancy: I think "batched" will be clearer to // folks on the Python side, but instanced is somewhat more // precise. ), undefined, undefined, undefined, // Compute click instance index from instance ID. Each visualized // frame has 1 instance for each of 3 line segments. (instanceId) => Math.floor(instanceId! / 3), ), ); return; } case "GridMessage": { addSceneNodeMakeParents( new SceneNode(message.name, (ref) => ( )), ); return; } // Add a point cloud. case "PointCloudMessage": { addSceneNodeMakeParents( new SceneNode(message.name, (ref) => ( val / 255.0, )} /> )), ); return; } case "GuiModalMessage": { addModal(message); return; } case "GuiCloseModalMessage": { removeModal(message.id); return; } // Add mesh case "SkinnedMeshMessage": case "MeshMessage": { const geometry = new THREE.BufferGeometry(); const generateGradientMap = (shades: 3 | 5) => { const texture = new THREE.DataTexture( Uint8Array.from( shades == 3 ? [0, 0, 0, 255, 128, 128, 128, 255, 255, 255, 255, 255] : [ 0, 0, 0, 255, 64, 64, 64, 255, 128, 128, 128, 255, 192, 192, 192, 255, 255, 255, 255, 255, ], ), shades, 1, THREE.RGBAFormat, ); texture.needsUpdate = true; return texture; }; const standardArgs = { color: message.color ?? undefined, vertexColors: message.vertex_colors !== null, wireframe: message.wireframe, transparent: message.opacity !== null, opacity: message.opacity ?? 1.0, // Flat shading only makes sense for non-wireframe materials. flatShading: message.flat_shading && !message.wireframe, side: { front: THREE.FrontSide, back: THREE.BackSide, double: THREE.DoubleSide, }[message.side], }; const assertUnreachable = (x: never): never => { throw new Error(`Should never get here! ${x}`); }; const material = message.material == "standard" || message.wireframe ? new THREE.MeshStandardMaterial(standardArgs) : message.material == "toon3" ? new THREE.MeshToonMaterial({ gradientMap: generateGradientMap(3), ...standardArgs, }) : message.material == "toon5" ? new THREE.MeshToonMaterial({ gradientMap: generateGradientMap(5), ...standardArgs, }) : assertUnreachable(message.material); geometry.setAttribute( "position", new THREE.Float32BufferAttribute( new Float32Array( message.vertices.buffer.slice( message.vertices.byteOffset, message.vertices.byteOffset + message.vertices.byteLength, ), ), 3, ), ); if (message.vertex_colors !== null) { geometry.setAttribute( "color", threeColorBufferFromUint8Buffer(message.vertex_colors), ); } geometry.setIndex( new THREE.Uint32BufferAttribute( new Uint32Array( message.faces.buffer.slice( message.faces.byteOffset, message.faces.byteOffset + message.faces.byteLength, ), ), 1, ), ); geometry.computeVertexNormals(); geometry.computeBoundingSphere(); const cleanupMesh = () => { // TODO: we can switch to the react-three-fiber , // , etc components to avoid manual // disposal. geometry.dispose(); material.dispose(); }; if (message.type === "MeshMessage") // Normal mesh. 
addSceneNodeMakeParents( new SceneNode( message.name, (ref) => { return ( ); }, cleanupMesh, ), ); else if (message.type === "SkinnedMeshMessage") { // Skinned mesh. const bones: THREE.Bone[] = []; for (let i = 0; i < message.bone_wxyzs!.length; i++) { bones.push(new THREE.Bone()); } const xyzw_quat = new THREE.Quaternion(); const boneInverses: THREE.Matrix4[] = []; viewer.skinnedMeshState.current[message.name] = { initialized: false, poses: [], }; bones.forEach((bone, i) => { const wxyz = message.bone_wxyzs[i]; const position = message.bone_positions[i]; xyzw_quat.set(wxyz[1], wxyz[2], wxyz[3], wxyz[0]); const boneInverse = new THREE.Matrix4(); boneInverse.makeRotationFromQuaternion(xyzw_quat); boneInverse.setPosition(position[0], position[1], position[2]); boneInverse.invert(); boneInverses.push(boneInverse); bone.quaternion.copy(xyzw_quat); bone.position.set(position[0], position[1], position[2]); bone.matrixAutoUpdate = false; bone.matrixWorldAutoUpdate = false; viewer.skinnedMeshState.current[message.name].poses.push({ wxyz: wxyz, position: position, }); }); const skeleton = new THREE.Skeleton(bones, boneInverses); geometry.setAttribute( "skinIndex", new THREE.Uint16BufferAttribute( new Uint16Array( message.skin_indices.buffer.slice( message.skin_indices.byteOffset, message.skin_indices.byteOffset + message.skin_indices.byteLength, ), ), 4, ), ); geometry.setAttribute( "skinWeight", new THREE.Float32BufferAttribute( new Float32Array( message.skin_weights!.buffer.slice( message.skin_weights!.byteOffset, message.skin_weights!.byteOffset + message.skin_weights!.byteLength, ), ), 4, ), ); addSceneNodeMakeParents( new SceneNode( message.name, (ref) => { return ( ); }, () => { delete viewer.skinnedMeshState.current[message.name]; skeleton.dispose(); cleanupMesh(); }, false, // everyFrameCallback: update bone transforms. () => { const parentNode = viewer.nodeRefFromName.current[message.name]; if (parentNode === undefined) return; const state = viewer.skinnedMeshState.current[message.name]; bones.forEach((bone, i) => { if (!state.initialized) { parentNode.add(bone); } const wxyz = state.initialized ? state.poses[i].wxyz : message.bone_wxyzs[i]; const position = state.initialized ? state.poses[i].position : message.bone_positions[i]; xyzw_quat.set(wxyz[1], wxyz[2], wxyz[3], wxyz[0]); bone.matrix.makeRotationFromQuaternion(xyzw_quat); bone.matrix.setPosition( position[0], position[1], position[2], ); bone.updateMatrixWorld(); }); if (!state.initialized) { state.initialized = true; } }, ), ); } return; } // Set the bone poses. case "SetBoneOrientationMessage": { const bonePoses = viewer.skinnedMeshState.current; bonePoses[message.name].poses[message.bone_index].wxyz = message.wxyz; break; } case "SetBonePositionMessage": { const bonePoses = viewer.skinnedMeshState.current; bonePoses[message.name].poses[message.bone_index].position = message.position; break; } // Add a camera frustum. 
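// If the message carries an image, its bytes are wrapped in a Blob, loaded
// through TextureLoader via a temporary object URL (revoked once the texture
// is ready), and the texture is disposed in the scene node's cleanup callback.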
case "CameraFrustumMessage": { let texture = undefined; if ( message.image_media_type !== null && message.image_binary !== null ) { const image_url = URL.createObjectURL( new Blob([message.image_binary]), ); texture = new TextureLoader().load(image_url, () => URL.revokeObjectURL(image_url), ); } addSceneNodeMakeParents( new SceneNode( message.name, (ref) => ( ), () => texture?.dispose(), ), ); return; } case "TransformControlsMessage": { const name = message.name; const sendDragMessage = makeThrottledMessageSender(viewer, 50); addSceneNodeMakeParents( new SceneNode( message.name, (ref) => ( e.stopPropagation()}> { const attrs = viewer.nodeAttributesFromName.current; if (attrs[message.name] === undefined) { attrs[message.name] = {}; } const wxyz = new THREE.Quaternion(); wxyz.setFromRotationMatrix(l); const position = new THREE.Vector3().setFromMatrixPosition( l, ); const nodeAttributes = attrs[message.name]!; nodeAttributes.wxyz = [wxyz.w, wxyz.x, wxyz.y, wxyz.z]; nodeAttributes.position = position.toArray(); sendDragMessage({ type: "TransformControlsUpdateMessage", name: name, wxyz: nodeAttributes.wxyz, position: nodeAttributes.position, }); }} /> ), undefined, true, // unmountWhenInvisible ), ); return; } case "SetCameraLookAtMessage": { const cameraControls = viewer.cameraControlRef.current!; const T_threeworld_world = computeT_threeworld_world(viewer); const target = new THREE.Vector3( message.look_at[0], message.look_at[1], message.look_at[2], ); target.applyMatrix4(T_threeworld_world); cameraControls.setTarget(target.x, target.y, target.z, false); return; } case "SetCameraUpDirectionMessage": { const camera = viewer.cameraRef.current!; const cameraControls = viewer.cameraControlRef.current!; const T_threeworld_world = computeT_threeworld_world(viewer); const updir = new THREE.Vector3( message.position[0], message.position[1], message.position[2], ) .normalize() .applyQuaternion( new THREE.Quaternion().setFromRotationMatrix(T_threeworld_world), ); camera.up.set(updir.x, updir.y, updir.z); // Back up position. const prevPosition = new THREE.Vector3(); cameraControls.getPosition(prevPosition); cameraControls.updateCameraUp(); // Restore position, which can get unexpectedly mutated in updateCameraUp(). cameraControls.setPosition( prevPosition.x, prevPosition.y, prevPosition.z, false, ); return; } case "SetCameraPositionMessage": { const cameraControls = viewer.cameraControlRef.current!; // Set the camera position. Due to the look-at, note that this will // shift the orientation as-well. 
const position_cmd = new THREE.Vector3( message.position[0], message.position[1], message.position[2], ); const T_threeworld_world = computeT_threeworld_world(viewer); position_cmd.applyMatrix4(T_threeworld_world); cameraControls.setPosition( position_cmd.x, position_cmd.y, position_cmd.z, ); return; } case "SetCameraFovMessage": { const camera = viewer.cameraRef.current!; // tan(fov / 2.0) = 0.5 * film height / focal length // focal length = 0.5 * film height / tan(fov / 2.0) camera.setFocalLength( (0.5 * camera.getFilmHeight()) / Math.tan(message.fov / 2.0), ); viewer.sendCameraRef.current !== null && viewer.sendCameraRef.current(); return; } case "SetOrientationMessage": { const attr = viewer.nodeAttributesFromName.current; if (attr[message.name] === undefined) attr[message.name] = {}; attr[message.name]!.wxyz = message.wxyz; if (attr[message.name]!.poseUpdateState == "updated") attr[message.name]!.poseUpdateState = "needsUpdate"; break; } case "SetPositionMessage": { const attr = viewer.nodeAttributesFromName.current; if (attr[message.name] === undefined) attr[message.name] = {}; attr[message.name]!.position = message.position; if (attr[message.name]!.poseUpdateState == "updated") attr[message.name]!.poseUpdateState = "needsUpdate"; break; } case "SetSceneNodeVisibilityMessage": { const attr = viewer.nodeAttributesFromName.current; if (attr[message.name] === undefined) attr[message.name] = {}; attr[message.name]!.visibility = message.visible; break; } // Add a background image. case "BackgroundImageMessage": { const rgb_url = URL.createObjectURL( new Blob([message.rgb_bytes], { type: message.media_type, }), ); new TextureLoader().load(rgb_url, (texture) => { URL.revokeObjectURL(rgb_url); const oldBackgroundTexture = viewer.backgroundMaterialRef.current!.uniforms.colorMap.value; viewer.backgroundMaterialRef.current!.uniforms.colorMap.value = texture; if (isTexture(oldBackgroundTexture)) oldBackgroundTexture.dispose(); viewer.useGui.setState({ backgroundAvailable: true }); }); viewer.backgroundMaterialRef.current!.uniforms.enabled.value = true; viewer.backgroundMaterialRef.current!.uniforms.hasDepth.value = message.depth_bytes !== null; if (message.depth_bytes !== null) { // If depth is available set the texture const depth_url = URL.createObjectURL( new Blob([message.depth_bytes], { type: message.media_type, }), ); new TextureLoader().load(depth_url, (texture) => { URL.revokeObjectURL(depth_url); const oldDepthTexture = viewer.backgroundMaterialRef.current?.uniforms.depthMap.value; viewer.backgroundMaterialRef.current!.uniforms.depthMap.value = texture; if (isTexture(oldDepthTexture)) oldDepthTexture.dispose(); }); } return; } // Add a 2D label. case "LabelMessage": { addSceneNodeMakeParents( new SceneNode( message.name, (ref) => { // We wrap with because Html doesn't implement THREE.Object3D. return (
{message.text}
); }, undefined, true, ), ); return; } case "Gui3DMessage": { addSceneNodeMakeParents( new SceneNode( message.name, (ref) => { // We wrap with because Html doesn't implement // THREE.Object3D. The initial position is intended to be // off-screen; it will be overwritten with the actual position // after the component is mounted. return ( { evt.stopPropagation(); }} > ); }, undefined, true, ), ); return; } // Add an image. case "ImageMessage": { // This current implementation may flicker when the image is updated, // because the texture is not necessarily done loading before the // component is mounted. We could fix this by passing an `onLoad` // callback into `TextureLoader`, but this would require work because // `addSceneNodeMakeParents` needs to be called immediately: it // overwrites position/wxyz attributes, and we don't want this to // happen after later messages are received. const image_url = URL.createObjectURL( new Blob([message.data], { type: message.media_type, }), ); const texture = new TextureLoader().load( image_url, () => URL.revokeObjectURL(image_url), // Revoke URL on load. ); addSceneNodeMakeParents( new SceneNode( message.name, (ref) => { return ( ); }, () => texture.dispose(), ), ); return; } // Remove a scene node and its children by name. case "RemoveSceneNodeMessage": { console.log("Removing scene node:", message.name); removeSceneNode(message.name); const attrs = viewer.nodeAttributesFromName.current; delete attrs[message.name]; return; } // Set the clickability of a particular scene node. case "SetSceneNodeClickableMessage": { // This setTimeout is totally unnecessary, but can help surface some race // conditions. setTimeout(() => setClickable(message.name, message.clickable), 50); return; } // Reset the entire scene, removing all scene nodes. case "ResetSceneMessage": { resetScene(); const oldBackground = viewer.sceneRef.current?.background; viewer.sceneRef.current!.background = null; if (isTexture(oldBackground)) oldBackground.dispose(); viewer.useGui.setState({ backgroundAvailable: false }); // Disable the depth texture rendering viewer.backgroundMaterialRef.current!.uniforms.enabled.value = false; return; } // Reset the GUI state. case "ResetGuiMessage": { resetGui(); return; } // Update props of a GUI component case "GuiUpdateMessage": { updateGuiProps(message.id, message.updates); return; } // Remove a GUI input. case "GuiRemoveMessage": { removeGui(message.id); return; } // Add a glTF/GLB asset. 
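// The GLB payload in the message is rendered through the GlbAsset component;
// the spline messages handled next are drawn with drei's CatmullRomLine and
// CubicBezierLine helpers.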
case "GlbMessage": { addSceneNodeMakeParents( new SceneNode(message.name, (ref) => { return ( ); }), ); return; } case "CatmullRomSplineMessage": { addSceneNodeMakeParents( new SceneNode(message.name, (ref) => { return ( ); }), ); return; } case "CubicBezierSplineMessage": { addSceneNodeMakeParents( new SceneNode(message.name, (ref) => { return ( {[...Array(message.positions.length - 1).keys()].map((i) => ( ))} ); }), ); return; } case "GaussianSplatsMessage": { addSceneNodeMakeParents( new SceneNode(message.name, (ref) => { return ( ); }), ); return; } case "FileTransferStart": case "FileTransferPart": { fileDownloadHandler(message); return; } case "FileTransferPartAck": { updateUploadState({ componentId: message.source_component_id!, uploadedBytes: message.transferred_bytes, totalBytes: message.total_bytes, }); return; } default: { console.log("Received message did not match any known types:", message); return; } } }; } function useFileDownloadHandler() { const downloadStatesRef = React.useRef<{ [uuid: string]: { metadata: FileTransferStart; notificationId: string; parts: Uint8Array[]; bytesDownloaded: number; displayFilesize: string; }; }>({}); return (message: FileTransferStart | FileTransferPart) => { const notificationId = "download-" + message.transfer_uuid; // Create or update download state. switch (message.type) { case "FileTransferStart": { let displaySize = message.size_bytes; const displayUnits = ["B", "K", "M", "G", "T", "P"]; let displayUnitIndex = 0; while ( displaySize >= 100 && displayUnitIndex < displayUnits.length - 1 ) { displaySize /= 1024; displayUnitIndex += 1; } downloadStatesRef.current[message.transfer_uuid] = { metadata: message, notificationId: notificationId, parts: [], bytesDownloaded: 0, displayFilesize: `${displaySize.toFixed(1)}${ displayUnits[displayUnitIndex] }`, }; break; } case "FileTransferPart": { const downloadState = downloadStatesRef.current[message.transfer_uuid]; if (message.part != downloadState.parts.length) { console.error( "A file download message was dropped; this should never happen!", ); } downloadState.parts.push(message.content); downloadState.bytesDownloaded += message.content.length; break; } } // Show notification. const downloadState = downloadStatesRef.current[message.transfer_uuid]; const progressValue = (100.0 * downloadState.bytesDownloaded) / downloadState.metadata.size_bytes; const isDone = downloadState.bytesDownloaded == downloadState.metadata.size_bytes; (downloadState.bytesDownloaded == 0 ? notifications.show : notifications.update)({ title: (isDone ? "Downloaded " : "Downloading ") + `${downloadState.metadata.filename} (${downloadState.displayFilesize})`, message: , id: notificationId, autoClose: isDone, withCloseButton: isDone, loading: !isDone, icon: isDone ? : undefined, }); // If done: download file and clear state. if (isDone) { const link = document.createElement("a"); link.href = window.URL.createObjectURL( new Blob(downloadState.parts, { type: downloadState.metadata.mime_type, }), ); link.download = downloadState.metadata.filename; link.click(); link.remove(); delete downloadStatesRef.current[message.transfer_uuid]; } }; } export function FrameSynchronizedMessageHandler() { const handleMessage = useMessageHandler(); const viewer = useContext(ViewerContext)!; const messageQueueRef = viewer.messageQueueRef; useFrame(() => { // Send a render along if it was requested! 
if (viewer.getRenderRequestState.current === "triggered") { viewer.getRenderRequestState.current = "pause"; } else if (viewer.getRenderRequestState.current === "pause") { const sourceCanvas = viewer.canvasRef.current!; const targetWidth = viewer.getRenderRequest.current!.width; const targetHeight = viewer.getRenderRequest.current!.height; // We'll save a render to an intermediate canvas with the requested dimensions. const renderBufferCanvas = new OffscreenCanvas(targetWidth, targetHeight); const ctx = renderBufferCanvas.getContext("2d")!; ctx.reset(); // Use a white background for JPEGs, which don't have an alpha channel. if (viewer.getRenderRequest.current?.format === "image/jpeg") { ctx.fillStyle = "white"; ctx.fillRect(0, 0, renderBufferCanvas.width, renderBufferCanvas.height); } // Determine offsets for the source canvas. We'll always center our renders. // https://developer.mozilla.org/en-US/docs/Web/API/CanvasRenderingContext2D/drawImage let sourceWidth = sourceCanvas.width; let sourceHeight = sourceCanvas.height; const sourceAspect = sourceWidth / sourceHeight; const targetAspect = targetWidth / targetHeight; if (sourceAspect > targetAspect) { // The source is wider than the target. // We need to shrink the width. sourceWidth = Math.round(targetAspect * sourceHeight); } else if (sourceAspect < targetAspect) { // The source is narrower than the target. // We need to shrink the height. sourceHeight = Math.round(sourceWidth / targetAspect); } console.log( `Sending render; requested aspect ratio was ${targetAspect} (dimensions: ${targetWidth}/${targetHeight}), copying from aspect ratio ${ sourceWidth / sourceHeight } (dimensions: ${sourceWidth}/${sourceHeight}).`, ); ctx.drawImage( sourceCanvas, (sourceCanvas.width - sourceWidth) / 2.0, (sourceCanvas.height - sourceHeight) / 2.0, sourceWidth, sourceHeight, 0, 0, targetWidth, targetHeight, ); viewer.getRenderRequestState.current = "in_progress"; // Encode the image, then send it. renderBufferCanvas .convertToBlob({ type: viewer.getRenderRequest.current!.format, quality: viewer.getRenderRequest.current!.quality / 100.0, }) .then(async (blob) => { if (blob === null) { console.error("Render failed"); viewer.getRenderRequestState.current = "ready"; return; } const payload = new Uint8Array(await blob.arrayBuffer()); viewer.sendMessageRef.current({ type: "GetRenderResponseMessage", payload: payload, }); viewer.getRenderRequestState.current = "ready"; }); } // Handle messages, but only if we're not trying to render something. if (viewer.getRenderRequestState.current === "ready") { // Handle messages before every frame. // Placing this directly in ws.onmessage can cause race conditions! // // If a render is requested, note that we don't handle any more messages // until the render is done. const requestRenderIndex = messageQueueRef.current.findIndex( (message) => message.type === "GetRenderRequestMessage", ); const numMessages = requestRenderIndex !== -1 ?
requestRenderIndex + 1 : messageQueueRef.current.length; const processBatch = messageQueueRef.current.splice(0, numMessages); processBatch.forEach(handleMessage); } }); return null; } ================================================ FILE: viser/src/viser/client/src/Modal.tsx ================================================ import { ViewerContext } from "./App"; import { GuiModalMessage } from "./WebsocketMessages"; import GeneratedGuiContainer from "./ControlPanel/Generated"; import { Modal } from "@mantine/core"; import { useContext } from "react"; export function ViserModal() { const viewer = useContext(ViewerContext)!; const modalList = viewer.useGui((state) => state.modals); const modals = modalList.map((conf, index) => { return ; }); return modals; } function GeneratedModal({ conf, index, }: { conf: GuiModalMessage; index: number; }) { return ( { // To make memory management easier, we should only close modals from // the server. // Otherwise, the client would need to communicate to the server that // the modal was deleted and contained GUI elements were cleared. }} withCloseButton={false} centered zIndex={100 + index} > ); } ================================================ FILE: viser/src/viser/client/src/Outlines.tsx ================================================ /** This is a modified version of drei's component. The primary * change is to add support for ref forwarding. https://github.com/pmndrs/drei * */ import * as THREE from "three"; import * as React from "react"; import { extend, applyProps, ReactThreeFiber, useThree, } from "@react-three/fiber"; import { toCreasedNormals } from "three-stdlib"; import { version } from "@react-three/drei/helpers/constants"; import { shaderMaterial } from "@react-three/drei"; const OutlinesMaterial = /* @__PURE__ */ shaderMaterial( { screenspace: false, color: /* @__PURE__ */ new THREE.Color("black"), opacity: 1, thickness: 0.05, size: /* @__PURE__ */ new THREE.Vector2(), }, `#include #include #include uniform float thickness; uniform float screenspace; uniform vec2 size; void main() { #if defined (USE_SKINNING) #include #include #include #include #include #endif #include #include #include #include vec4 tNormal = vec4(normal, 0.0); vec4 tPosition = vec4(transformed, 1.0); #ifdef USE_INSTANCING tNormal = instanceMatrix * tNormal; tPosition = instanceMatrix * tPosition; #endif if (screenspace == 0.0) { vec3 newPosition = tPosition.xyz + tNormal.xyz * thickness; gl_Position = projectionMatrix * modelViewMatrix * vec4(newPosition, 1.0); } else { vec4 clipPosition = projectionMatrix * modelViewMatrix * tPosition; vec4 clipNormal = projectionMatrix * modelViewMatrix * tNormal; vec2 offset = normalize(clipNormal.xy) * thickness / size * clipPosition.w * 2.0; clipPosition.xy += offset; gl_Position = clipPosition; } }`, `uniform vec3 color; uniform float opacity; void main(){ gl_FragColor = vec4(color, opacity); #include #include <${version >= 154 ? 
"colorspace_fragment" : "encodings_fragment"}> }`, ); type OutlinesProps = JSX.IntrinsicElements["group"] & { /** Outline color, default: black */ color?: ReactThreeFiber.Color; /** Line thickness is independent of zoom, default: false */ screenspace?: boolean; /** Outline opacity, default: 1 */ opacity?: number; /** Outline transparency, default: false */ transparent?: boolean; /** Outline thickness, default 0.05 */ thickness?: number; /** Geometry crease angle (0 === no crease), default: Math.PI */ angle?: number; toneMapped?: boolean; polygonOffset?: boolean; polygonOffsetFactor?: number; renderOrder?: number; }; export const Outlines = React.forwardRef( function Outlines( { color = "black", opacity = 1, transparent = false, screenspace = false, toneMapped = true, polygonOffset = false, polygonOffsetFactor = 0, renderOrder = 0, thickness = 0.05, angle = Math.PI, ...props }, ref, ) { const localRef = React.useRef(null); const [material] = React.useState( () => new OutlinesMaterial({ side: THREE.BackSide }), ); const gl = useThree((state) => state.gl); const contextSize = gl.getDrawingBufferSize(new THREE.Vector2()); React.useMemo(() => extend({ OutlinesMaterial }), []); const oldAngle = React.useRef(0); const oldGeometry = React.useRef(); React.useLayoutEffect(() => { const group = localRef.current; if (!group) return; const parent = group.parent as THREE.Mesh & THREE.SkinnedMesh & THREE.InstancedMesh; if (parent && parent.geometry) { if ( oldAngle.current !== angle || oldGeometry.current !== parent.geometry ) { oldAngle.current = angle; oldGeometry.current = parent.geometry; // Remove old mesh let mesh = group.children[0] as any; if (mesh) { if (angle) mesh.geometry.dispose(); group.remove(mesh); } if (parent.skeleton) { mesh = new THREE.SkinnedMesh(); mesh.material = material; mesh.bind(parent.skeleton, parent.bindMatrix); group.add(mesh); } else if (parent.isInstancedMesh) { mesh = new THREE.InstancedMesh( parent.geometry, material, parent.count, ); mesh.instanceMatrix = parent.instanceMatrix; group.add(mesh); } else { mesh = new THREE.Mesh(); mesh.material = material; group.add(mesh); } mesh.geometry = angle ? 
toCreasedNormals(parent.geometry, angle) : parent.geometry; } } }); React.useLayoutEffect(() => { const group = localRef.current; if (!group) return; const mesh = group.children[0] as THREE.Mesh< THREE.BufferGeometry, THREE.Material >; if (mesh) { mesh.renderOrder = renderOrder; applyProps(mesh.material as any, { transparent, thickness, color, opacity, size: contextSize, screenspace, toneMapped, polygonOffset, polygonOffsetFactor, }); } }); React.useEffect(() => { return () => { // Dispose everything on unmount const group = localRef.current; if (!group) return; const mesh = group.children[0] as THREE.Mesh< THREE.BufferGeometry, THREE.Material >; if (mesh) { if (angle) mesh.geometry.dispose(); group.remove(mesh); } }; }, []); return ( { localRef.current = obj; if (typeof ref === "function") ref(obj!); else if (ref) ref.current = obj; }} {...props} /> ); }, ); ================================================ FILE: viser/src/viser/client/src/SceneTree.tsx ================================================ import { useCursor } from "@react-three/drei"; import { createPortal, useFrame } from "@react-three/fiber"; import React from "react"; import * as THREE from "three"; import { ViewerContext } from "./App"; import { useThrottledMessageSender } from "./WebsocketFunctions"; import { Html } from "@react-three/drei"; import { immerable } from "immer"; import { useSceneTreeState } from "./SceneTreeState"; import { ErrorBoundary } from "react-error-boundary"; import { rayToViserCoords } from "./WorldTransformUtils"; import { HoverableContext } from "./ThreeAssets"; import { opencvXyFromPointerXy } from "./ClickUtils"; export type MakeObject = ( ref: React.Ref, ) => React.ReactNode; /** Scenes will consist of nodes, which form a tree. */ export class SceneNode { [immerable] = true; public children: string[]; public clickable: boolean; constructor( public readonly name: string, public readonly makeObject: MakeObject, public readonly cleanup?: () => void, /** unmountWhenInvisible is used to unmount components when they * should be hidden. * * https://github.com/pmndrs/drei/issues/1323 */ public readonly unmountWhenInvisible?: boolean, public readonly everyFrameCallback?: () => void, /** For click events on instanced nodes, like batched axes, we want to keep track of which. */ public readonly computeClickInstanceIndexFromInstanceId?: ( instanceId: number | undefined, ) => number | null, ) { this.children = []; this.clickable = false; } } /** Type corresponding to a zustand-style useSceneTree hook. */ export type UseSceneTree = ReturnType; function SceneNodeThreeChildren(props: { name: string; parent: THREE.Object3D; }) { const viewer = React.useContext(ViewerContext)!; const [children, setChildren] = React.useState( viewer.useSceneTree.getState().nodeFromName[props.name]?.children ?? [], ); React.useEffect(() => { let updateQueued = false; return viewer.useSceneTree.subscribe((state) => { // Do nothing if an update is already queued. if (updateQueued) return; // Do nothing if children haven't changed. const newChildren = state.nodeFromName[props.name]?.children; if ( newChildren === undefined || newChildren === children || // Note that this won't check for elementwise equality! (newChildren.length === 0 && children.length == 0) ) return; // Queue a (throttled) children update. updateQueued = true; setTimeout( () => { updateQueued = false; const newChildren = viewer.useSceneTree.getState().nodeFromName[props.name]!.children!; setChildren(newChildren); }, // Throttle more when we have a lot of children... 
newChildren.length <= 16 ? 10 : newChildren.length <= 128 ? 50 : 200, ); }); }, []); // Create a group of children inside of the parent object. return createPortal( {children && children.map((child_id) => ( ))} , props.parent, ); } /** Component for updating attributes of a scene node. */ function SceneNodeLabel(props: { name: string }) { const viewer = React.useContext(ViewerContext)!; const labelVisible = viewer.useSceneTree( (state) => state.labelVisibleFromName[props.name], ); return labelVisible ? ( {props.name} ) : null; } /** Component containing the three.js object and children for a particular scene node. */ export function SceneNodeThreeObject(props: { name: string; parent: THREE.Object3D | null; }) { const viewer = React.useContext(ViewerContext)!; const makeObject = viewer.useSceneTree( (state) => state.nodeFromName[props.name]?.makeObject, ); const unmountWhenInvisible = viewer.useSceneTree( (state) => state.nodeFromName[props.name]?.unmountWhenInvisible, ); const everyFrameCallback = viewer.useSceneTree( (state) => state.nodeFromName[props.name]?.everyFrameCallback, ); const computeClickInstanceIndexFromInstanceId = viewer.useSceneTree( (state) => state.nodeFromName[props.name]?.computeClickInstanceIndexFromInstanceId, ); const [unmount, setUnmount] = React.useState(false); const clickable = viewer.useSceneTree((state) => state.nodeFromName[props.name]?.clickable) ?? false; const [obj, setRef] = React.useState(null); // Update global registry of node objects. // This is used for updating bone transforms in skinned meshes. React.useEffect(() => { if (obj !== null) viewer.nodeRefFromName.current[props.name] = obj; }, [obj]); // Create object + children. // // For not-fully-understood reasons, wrapping makeObject with useMemo() fixes // stability issues (eg breaking runtime errors) associated with // PivotControls. const objNode = React.useMemo(() => { if (makeObject === undefined) return null; // Pose will need to be updated. const attrs = viewer.nodeAttributesFromName.current; if (!(props.name in attrs)) { attrs[props.name] = {}; } attrs[props.name]!.poseUpdateState = "needsUpdate"; return makeObject(setRef); }, [makeObject]); const children = obj === null ? null : ( ); // Helper for transient visibility checks. Checks the .visible attribute of // both this object and ancestors. // // This is used for (1) suppressing click events and (2) unmounting when // unmountWhenInvisible is true. The latter is used for components. function isDisplayed() { // We avoid checking obj.visible because obj may be unmounted when // unmountWhenInvisible=true. const attrs = viewer.nodeAttributesFromName.current[props.name]; const visibility = (attrs?.overrideVisibility === undefined ? attrs?.visibility : attrs.overrideVisibility) ?? true; if (visibility === false) return false; if (props.parent === null) return true; // Check visibility of parents + ancestors. let visible = props.parent.visible; if (visible) { props.parent.traverseAncestors((ancestor) => { visible = visible && ancestor.visible; }); } return visible; } // Pose needs to be updated whenever component is remounted. React.useEffect(() => { const attrs = viewer.nodeAttributesFromName.current[props.name]; if (attrs !== undefined) attrs.poseUpdateState = "needsUpdate"; }); // Update attributes on a per-frame basis. Currently does redundant work, // although this shouldn't be a bottleneck. 
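// Each frame this hook runs the node's everyFrameCallback, handles the
// unmountWhenInvisible workaround, copies visibility / wxyz / position from
// nodeAttributesFromName onto the three.js object when poseUpdateState is
// "needsUpdate", and forces matrix updates for objects with auto-update
// disabled. The -10000 priority makes it run before other useFrame hooks
// that depend on these transforms.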
useFrame( () => { const attrs = viewer.nodeAttributesFromName.current[props.name]; everyFrameCallback && everyFrameCallback(); // Unmount when invisible. // Examples: components, PivotControls. // // This is a workaround for situations where just setting `visible` doesn't // work (like ), or to prevent invisible elements from being // interacted with (). // // https://github.com/pmndrs/drei/issues/1323 if (unmountWhenInvisible) { const displayed = isDisplayed(); if (displayed && unmount) { if (obj !== null) obj.visible = false; setUnmount(false); } if (!displayed && !unmount) { setUnmount(true); } } if (obj === null) return; if (attrs === undefined) return; const visibility = (attrs?.overrideVisibility === undefined ? attrs?.visibility : attrs.overrideVisibility) ?? true; obj.visible = visibility; if (attrs.poseUpdateState == "needsUpdate") { attrs.poseUpdateState = "updated"; const wxyz = attrs.wxyz; if (wxyz !== undefined) { obj.quaternion.set(wxyz[1], wxyz[2], wxyz[3], wxyz[0]); } const position = attrs.position; if (position !== undefined) { obj.position.set(position[0], position[1], position[2]); } // Update matrices if necessary. This is necessary for PivotControls. if (!obj.matrixAutoUpdate) obj.updateMatrix(); if (!obj.matrixWorldAutoUpdate) obj.updateMatrixWorld(); } }, // Other useFrame hooks may depend on transforms + visibility. So it's best // to call this hook early. -10000, ); // Clicking logic. const sendClicksThrottled = useThrottledMessageSender(50); const [hovered, setHovered] = React.useState(false); useCursor(hovered); const hoveredRef = React.useRef(false); if (!clickable && hovered) setHovered(false); const dragInfo = React.useRef({ dragging: false, startClientX: 0, startClientY: 0, }); if (objNode === undefined || unmount) { return <>{children}; } else if (clickable) { return ( <> { // This sometimes (but very rarely) catches a race condition when // we remove scene nodes. I would guess it's related to portaling, // but the issue is unnoticeable with ErrorBoundary in-place so not // debugging further for now... console.error( "There was an error rendering a scene node object:", objNode, ); return null; }} > { if (!isDisplayed()) return; e.stopPropagation(); const state = dragInfo.current; const canvasBbox = viewer.canvasRef.current!.getBoundingClientRect(); state.startClientX = e.clientX - canvasBbox.left; state.startClientY = e.clientY - canvasBbox.top; state.dragging = false; }} onPointerMove={(e) => { if (!isDisplayed()) return; e.stopPropagation(); const state = dragInfo.current; const canvasBbox = viewer.canvasRef.current!.getBoundingClientRect(); const deltaX = e.clientX - canvasBbox.left - state.startClientX; const deltaY = e.clientY - canvasBbox.top - state.startClientY; // Minimum motion. if (Math.abs(deltaX) <= 3 && Math.abs(deltaY) <= 3) return; state.dragging = true; }} onPointerUp={(e) => { if (!isDisplayed()) return; e.stopPropagation(); const state = dragInfo.current; if (state.dragging) return; // Convert ray to viser coordinates. const ray = rayToViserCoords(viewer, e.ray); // Send OpenCV image coordinates to the server (normalized). const canvasBbox = viewer.canvasRef.current!.getBoundingClientRect(); const mouseVectorOpenCV = opencvXyFromPointerXy(viewer, [ e.clientX - canvasBbox.left, e.clientY - canvasBbox.top, ]); sendClicksThrottled({ type: "SceneNodeClickMessage", name: props.name, instance_index: computeClickInstanceIndexFromInstanceId === undefined ? 
null : computeClickInstanceIndexFromInstanceId(e.instanceId), // Note that the threejs up is +Y, but we expose a +Z up. ray_origin: [ray.origin.x, ray.origin.y, ray.origin.z], ray_direction: [ ray.direction.x, ray.direction.y, ray.direction.z, ], screen_pos: [mouseVectorOpenCV.x, mouseVectorOpenCV.y], }); }} onPointerOver={(e) => { if (!isDisplayed()) return; e.stopPropagation(); setHovered(true); hoveredRef.current = true; }} onPointerOut={() => { if (!isDisplayed()) return; setHovered(false); hoveredRef.current = false; }} > {objNode} {children} ); } else { return ( <> {/* This does nothing, but switching between clickable vs not causes strange transform behavior without it. */} {objNode} {children} ); } } ================================================ FILE: viser/src/viser/client/src/SceneTreeState.tsx ================================================ import React from "react"; import { MakeObject, SceneNode } from "./SceneTree"; import { CoordinateFrame } from "./ThreeAssets"; import * as THREE from "three"; import { create } from "zustand"; import { subscribeWithSelector } from "zustand/middleware"; import { immer } from "zustand/middleware/immer"; interface SceneTreeState { nodeFromName: { [name: string]: undefined | SceneNode }; // Putting this into SceneNode makes the scene tree table much harder to implement. labelVisibleFromName: { [name: string]: boolean }; } export interface SceneTreeActions extends SceneTreeState { setClickable(name: string, clickable: boolean): void; addSceneNode(nodes: SceneNode): void; removeSceneNode(name: string): void; resetScene(): void; setLabelVisibility(name: string, labelVisibility: boolean): void; } // Create default scene tree state. const rootAxesTemplate: MakeObject = (ref) => ( ); const rootNodeTemplate = new SceneNode("", (ref) => ( )) as SceneNode; const rootAxesNode = new SceneNode( "/WorldAxes", rootAxesTemplate, ) as SceneNode; rootNodeTemplate.children.push("/WorldAxes"); /** Declare a scene state, and return a hook for accessing it. Note that we put effort into avoiding a global state! */ export function useSceneTreeState( nodeRefFromName: React.MutableRefObject<{ [name: string]: undefined | THREE.Object3D; }>, ) { return React.useState(() => create( subscribeWithSelector( immer((set) => ({ nodeFromName: { "": rootNodeTemplate, "/WorldAxes": rootAxesNode }, labelVisibleFromName: {}, setClickable: (name, clickable) => set((state) => { const node = state.nodeFromName[name]; if (node !== undefined) node.clickable = clickable; }), addSceneNode: (node) => set((state) => { const existingNode = state.nodeFromName[node.name]; if (existingNode !== undefined) { // Node already exists. delete nodeRefFromName.current[node.name]; existingNode.cleanup && existingNode.cleanup(); // Free resources. state.nodeFromName[node.name] = { ...node, children: existingNode.children, }; } else { // Node doesn't exist yet! // TODO: this assumes the parent exists. We could probably merge this with addSceneNodeMakeParents. const parent_name = node.name.split("/").slice(0, -1).join("/"); state.nodeFromName[node.name] = node; state.nodeFromName[parent_name]!.children.push(node.name); } }), removeSceneNode: (name) => set((state) => { if (!(name in state.nodeFromName)) { console.log("Skipping scene node removal for " + name); return; } // Remove this scene node and all children. 
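              // For example (hypothetical node names): removing "/robot" also
              // removes descendants like "/robot/arm" and "/robot/arm/wrist".
              // findChildrenRecursive() below collects the whole subtree before
              // cleanup() is called and the entries are deleted.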
              const removeNames: string[] = [];
              function findChildrenRecursive(name: string) {
                removeNames.push(name);
                state.nodeFromName[name]!.children.forEach(
                  findChildrenRecursive,
                );
              }
              findChildrenRecursive(name);

              removeNames.forEach((removeName) => {
                const node = state.nodeFromName[removeName]!;
                node.cleanup && node.cleanup(); // Free resources.
                delete state.nodeFromName[removeName];
                delete nodeRefFromName.current[removeName];
              });

              // Remove node from parent's children list.
              const parent_name = name.split("/").slice(0, -1).join("/");
              state.nodeFromName[parent_name]!.children = state.nodeFromName[
                parent_name
              ]!.children.filter((child_name) => child_name !== name);
            }),
          resetScene: () =>
            set((state) => {
              // For scene resets: we need to retain the object references created for the root and world frame nodes.
              for (const key of Object.keys(state.nodeFromName)) {
                if (key !== "" && key !== "/WorldAxes")
                  delete state.nodeFromName[key];
              }
              Object.values(state.nodeFromName).forEach((node) => {
                // Free resources.
                if (node === undefined || node.cleanup === undefined) return;
                node.cleanup();
              });
              state.nodeFromName[""] = rootNodeTemplate;
              state.nodeFromName["/WorldAxes"] = rootAxesNode;
            }),
          setLabelVisibility: (name, labelVisibility) =>
            set((state) => {
              state.labelVisibleFromName[name] = labelVisibility;
            }),
        })),
      ),
    ),
  )[0];
}


================================================
FILE: viser/src/viser/client/src/SearchParamsUtils.tsx
================================================
/** Utilities for interacting with the URL search parameters.
 *
 * This lets us specify the websocket server + port from the URL.
 */

export const searchParamKey = "websocket";

export function syncSearchParamServer(server: string) {
  const searchParams = new URLSearchParams(window.location.search);

  // No need to update the URL bar if the websocket port matches the HTTP port.
  // So if we navigate to http://localhost:8081, this should by default connect to ws://localhost:8081.
  const isDefaultServer =
    window.location.host.includes(
      server.replace("ws://", "").replace("/", ""),
    ) ||
    window.location.host.includes(
      server.replace("wss://", "").replace("/", ""),
    );
  if (isDefaultServer && searchParams.has(searchParamKey)) {
    searchParams.delete(searchParamKey);
  } else if (!isDefaultServer) {
    searchParams.set(searchParamKey, server);
  }
  window.history.replaceState(
    null,
    "Viser",
    // We could use URLSearchParams.toString() to build this string, but that
    // would escape it. We're going to just not escape the string. :)
    searchParams.size === 0
      ? window.location.href.split("?")[0]
      : "?" +
          Array.from(searchParams.entries())
            .map(([k, v]) => `${k}=${v}`)
            .join("&"),
  );
}


================================================
FILE: viser/src/viser/client/src/Splatting/GaussianSplats.tsx
================================================
/** Gaussian splatting implementation for viser.
 *
 * This borrows heavily from existing open-source implementations. Particularly
 * useful references:
 * - https://github.com/quadjr/aframe-gaussian-splatting
 * - https://github.com/antimatter15/splat
 * - https://github.com/pmndrs/drei
 * - https://github.com/vincent-lecrubier-skydio/react-three-fiber-gaussian-splat
 *
 * Usage should look like:
 *
 *
 *
 *
 *
 *
 *
 * Where `buffer` contains serialized Gaussian attributes. SplatObjects are
 * globally sorted by a worker (with some help from WebAssembly + SIMD
 * intrinsics), and then rendered as a single threejs mesh. Unlike other R3F
 * implementations that we're aware of, this enables correct compositing
*/ import React from "react"; import * as THREE from "three"; import SplatSortWorker from "./SplatSortWorker?worker"; import { useFrame, useThree } from "@react-three/fiber"; import { shaderMaterial } from "@react-three/drei"; import { SorterWorkerIncoming } from "./SplatSortWorker"; import { create } from "zustand"; import { Object3D } from "three"; /**Global splat state.*/ interface SplatState { groupBufferFromId: { [id: string]: Uint32Array }; nodeRefFromId: React.MutableRefObject<{ [name: string]: undefined | Object3D; }>; setBuffer: (id: string, buffer: Uint32Array) => void; removeBuffer: (id: string) => void; } /**Hook for creating global splat state.*/ function useGaussianSplatStore() { const nodeRefFromId = React.useRef({}); return React.useState(() => create((set) => ({ groupBufferFromId: {}, nodeRefFromId: nodeRefFromId, setBuffer: (id, buffer) => { return set((state) => ({ groupBufferFromId: { ...state.groupBufferFromId, [id]: buffer }, })); }, removeBuffer: (id) => { return set((state) => { // eslint-disable-next-line @typescript-eslint/no-unused-vars const { [id]: _, ...buffers } = state.groupBufferFromId; return { groupBufferFromId: buffers }; }); }, })), )[0]; } const GaussianSplatsContext = React.createContext | null>(null); /**Provider for creating splat rendering context.*/ export function SplatRenderContext({ children, }: { children: React.ReactNode; }) { const store = useGaussianSplatStore(); return ( {children} ); } const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( { numGaussians: 0, focal: 100.0, viewport: [640, 480], near: 1.0, far: 100.0, depthTest: true, depthWrite: false, transparent: true, textureBuffer: null, textureT_camera_groups: null, transitionInState: 0.0, }, `precision highp usampler2D; // Most important: ints must be 32-bit. precision mediump float; // Index from the splat sorter. attribute uint sortedIndex; // Buffers for splat data; each Gaussian gets 4 floats and 4 int32s. We just // copy quadjr for this. uniform usampler2D textureBuffer; // We could also use a uniform to store transforms, but this would be more // limiting in terms of the # of groups we can have. uniform sampler2D textureT_camera_groups; // Various other uniforms... uniform uint numGaussians; uniform vec2 focal; uniform vec2 viewport; uniform float near; uniform float far; // Fade in state between [0, 1]. uniform float transitionInState; out vec4 vRgba; out vec2 vPosition; // Function to fetch and construct the i-th transform matrix using texelFetch mat4 getGroupTransform(uint i) { // Calculate the base index for the i-th transform. uint baseIndex = i * 3u; // Fetch the texels that represent the first 3 rows of the transform. We // choose to use row-major here, since it lets us exclude the fourth row of // the matrix. vec4 row0 = texelFetch(textureT_camera_groups, ivec2(baseIndex + 0u, 0), 0); vec4 row1 = texelFetch(textureT_camera_groups, ivec2(baseIndex + 1u, 0), 0); vec4 row2 = texelFetch(textureT_camera_groups, ivec2(baseIndex + 2u, 0), 0); // Construct the mat4 with the fetched rows. mat4 transform = mat4(row0, row1, row2, vec4(0.0, 0.0, 0.0, 1.0)); return transpose(transform); } void main () { // Get position + scale from float buffer. ivec2 texSize = textureSize(textureBuffer, 0); uint texStart = sortedIndex << 1u; ivec2 texPos0 = ivec2(texStart % uint(texSize.x), texStart / uint(texSize.x)); // Fetch from textures. 
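    // Each Gaussian occupies two RGBA32UI texels (8 uint32s total), matching
    // the layout written in mergeGaussianGroups():
    //   texel 0: xyz = center coordinates (float32 bits), w = group index
    //   texel 1: xyz = covariance Cholesky terms (6 x float16, packed in
    //            pairs), w = RGBA color (4 x uint8)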
uvec4 floatBufferData = texelFetch(textureBuffer, texPos0, 0); mat4 T_camera_group = getGroupTransform(floatBufferData.w); // Any early return will discard the fragment. gl_Position = vec4(0.0, 0.0, 2.0, 1.0); // Get center wrt camera. modelViewMatrix is T_cam_world. vec3 center = uintBitsToFloat(floatBufferData.xyz); vec4 c_cam = T_camera_group * vec4(center, 1); if (-c_cam.z < near || -c_cam.z > far) return; vec4 pos2d = projectionMatrix * c_cam; float clip = 1.1 * pos2d.w; if (pos2d.x < -clip || pos2d.x > clip || pos2d.y < -clip || pos2d.y > clip) return; // Read covariance terms. ivec2 texPos1 = ivec2((texStart + 1u) % uint(texSize.x), (texStart + 1u) / uint(texSize.x)); uvec4 intBufferData = texelFetch(textureBuffer, texPos1, 0); // Get covariance terms from int buffer. uint rgbaUint32 = intBufferData.w; vec2 chol01 = unpackHalf2x16(intBufferData.x); vec2 chol23 = unpackHalf2x16(intBufferData.y); vec2 chol45 = unpackHalf2x16(intBufferData.z); // Transition in. float startTime = 0.8 * float(sortedIndex) / float(numGaussians); float cov_scale = smoothstep(startTime, startTime + 0.2, transitionInState); // Do the actual splatting. mat3 chol = mat3( chol01.x, chol01.y, chol23.x, 0., chol23.y, chol45.x, 0., 0., chol45.y ); mat3 cov3d = chol * transpose(chol) * cov_scale; mat3 J = mat3( // Matrices are column-major. focal.x / c_cam.z, 0., 0.0, 0., focal.y / c_cam.z, 0.0, -(focal.x * c_cam.x) / (c_cam.z * c_cam.z), -(focal.y * c_cam.y) / (c_cam.z * c_cam.z), 0. ); mat3 A = J * mat3(T_camera_group); mat3 cov_proj = A * cov3d * transpose(A); float diag1 = cov_proj[0][0] + 0.3; float offDiag = cov_proj[0][1]; float diag2 = cov_proj[1][1] + 0.3; // Eigendecomposition. float mid = 0.5 * (diag1 + diag2); float radius = length(vec2((diag1 - diag2) / 2.0, offDiag)); float lambda1 = mid + radius; float lambda2 = mid - radius; if (lambda2 < 0.0) return; vec2 diagonalVector = normalize(vec2(offDiag, lambda1 - diag1)); vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector; vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x); vRgba = vec4( float(rgbaUint32 & uint(0xFF)) / 255.0, float((rgbaUint32 >> uint(8)) & uint(0xFF)) / 255.0, float((rgbaUint32 >> uint(16)) & uint(0xFF)) / 255.0, float(rgbaUint32 >> uint(24)) / 255.0 ); // Throw the Gaussian off the screen if it's too close, too far, or too small. float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag); if (weightedDeterminant < 0.5) return; vPosition = position.xy; gl_Position = vec4( vec2(pos2d) / pos2d.w + position.x * v1 / viewport * 2.0 + position.y * v2 / viewport * 2.0, pos2d.z / pos2d.w, 1.); } `, `precision mediump float; uniform vec2 viewport; uniform vec2 focal; in vec4 vRgba; in vec2 vPosition; void main () { float A = -dot(vPosition, vPosition); if (A < -4.0) discard; float B = exp(A) * vRgba.a; if (B < 0.01) discard; // alphaTest. 
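  // vPosition is the quad-local offset (the quad vertices are at +/-2, set in
  // useGaussianMeshProps), so B = exp(-r^2) * alpha is the Gaussian falloff;
  // fragments with r > 2 or with B below 0.01 were discarded above.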
gl_FragColor = vec4(vRgba.rgb, B); }`, ); export const SplatObject = React.forwardRef< THREE.Group, { buffer: Uint32Array; } >(function SplatObject({ buffer }, ref) { const splatContext = React.useContext(GaussianSplatsContext)!; const setBuffer = splatContext((state) => state.setBuffer); const removeBuffer = splatContext((state) => state.removeBuffer); const nodeRefFromId = splatContext((state) => state.nodeRefFromId); const name = React.useMemo(() => crypto.randomUUID(), [buffer]); const [obj, setRef] = React.useState(null); React.useEffect(() => { if (obj === null) return; setBuffer(name, buffer); if (ref !== null) { if ("current" in ref) { ref.current = obj; } else { ref(obj); } } nodeRefFromId.current[name] = obj; return () => { removeBuffer(name); delete nodeRefFromId.current[name]; }; }, [obj]); return ; }); /** External interface. Component should be added to the root of canvas. */ function SplatRenderer() { const splatContext = React.useContext(GaussianSplatsContext)!; const groupBufferFromId = splatContext((state) => state.groupBufferFromId); const nodeRefFromId = splatContext((state) => state.nodeRefFromId); // Consolidate Gaussian groups into a single buffer. const merged = mergeGaussianGroups(groupBufferFromId); const meshProps = useGaussianMeshProps( merged.gaussianBuffer, merged.numGroups, ); // Create sorting worker. const sortWorker = new SplatSortWorker(); let initializedBufferTexture = false; sortWorker.onmessage = (e) => { // Update rendering order. const sortedIndices = e.data.sortedIndices as Uint32Array; meshProps.sortedIndexAttribute.set(sortedIndices); meshProps.sortedIndexAttribute.needsUpdate = true; // Trigger initial render. if (!initializedBufferTexture) { meshProps.material.uniforms.numGaussians.value = merged.numGaussians; meshProps.textureBuffer.needsUpdate = true; initializedBufferTexture = true; } }; function postToWorker(message: SorterWorkerIncoming) { sortWorker.postMessage(message); } postToWorker({ setBuffer: merged.gaussianBuffer, setGroupIndices: merged.groupIndices, }); // Cleanup. React.useEffect(() => { return () => { meshProps.textureBuffer.dispose(); meshProps.geometry.dispose(); meshProps.material.dispose(); postToWorker({ close: true }); }; }); // Per-frame updates. This is in charge of synchronizing transforms and // triggering sorting. // // We pre-allocate matrices to make life easier for the garbage collector. const meshRef = React.useRef(null); const tmpT_camera_group = new THREE.Matrix4(); const Tz_camera_groups = new Float32Array(merged.numGroups * 4); const prevRowMajorT_camera_groups = meshProps.rowMajorT_camera_groups .slice() .fill(0); const prevVisibles: boolean[] = []; useFrame((state, delta) => { const mesh = meshRef.current; if (mesh === null || sortWorker === null) return; // Update camera parameter uniforms. const dpr = state.viewport.dpr; const fovY = ((state.camera as THREE.PerspectiveCamera).fov * Math.PI) / 180.0; const fovX = 2 * Math.atan(Math.tan(fovY / 2) * state.viewport.aspect); const fy = (dpr * state.size.height) / (2 * Math.tan(fovY / 2)); const fx = (dpr * state.size.width) / (2 * Math.tan(fovX / 2)); if (meshProps.material === undefined) return; const uniforms = meshProps.material.uniforms; uniforms.transitionInState.value = Math.min( uniforms.transitionInState.value + delta * 2.0, 1.0, ); uniforms.focal.value = [fx, fy]; uniforms.near.value = state.camera.near; uniforms.far.value = state.camera.far; uniforms.viewport.value = [state.size.width * dpr, state.size.height * dpr]; // Update group transforms. 
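    // Two per-group buffers are maintained below: Tz_camera_groups packs the
    // third (z) row of each camera-from-group transform for the depth sorter,
    // while rowMajorT_camera_groups packs the top three rows (row-major) into
    // textureT_camera_groups for the vertex shader's getGroupTransform().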
const T_camera_world = state.camera.matrixWorldInverse; const groupVisibles: boolean[] = []; let visibilitiesChanged = false; for (const [groupIndex, name] of Object.keys(groupBufferFromId).entries()) { const node = nodeRefFromId.current[name]; if (node === undefined) continue; tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld); const colMajorElements = tmpT_camera_group.elements; Tz_camera_groups.set( [ colMajorElements[2], colMajorElements[6], colMajorElements[10], colMajorElements[14], ], groupIndex * 4, ); const rowMajorElements = tmpT_camera_group.transpose().elements; meshProps.rowMajorT_camera_groups.set( rowMajorElements.slice(0, 12), groupIndex * 12, ); // Determine visibility. If the parent has unmountWhenInvisible=true, the // first frame after showing a hidden parent can have visible=true with // an incorrect matrixWorld transform. There might be a better fix, but // `prevVisible` is an easy workaround for this. let visibleNow = node.visible && node.parent !== null; if (visibleNow) { node.traverseAncestors((ancestor) => { visibleNow = visibleNow && ancestor.visible; }); } groupVisibles.push(visibleNow && prevVisibles[groupIndex] === true); if (prevVisibles[groupIndex] !== visibleNow) { prevVisibles[groupIndex] = visibleNow; visibilitiesChanged = true; } } const groupsMovedWrtCam = !meshProps.rowMajorT_camera_groups.every( (v, i) => v === prevRowMajorT_camera_groups[i], ); if (groupsMovedWrtCam) { // Gaussians need to be re-sorted. postToWorker({ setTz_camera_groups: Tz_camera_groups, }); } if (groupsMovedWrtCam || visibilitiesChanged) { // If a group is not visible, we'll throw it off the screen with some Big // Numbers. It's important that this only impacts the coordinates used // for the shader and not for the sorter; that way when we "show" a group // of Gaussians the correct rendering order is immediately available. for (const [i, visible] of groupVisibles.entries()) { if (!visible) { meshProps.rowMajorT_camera_groups[i * 12 + 3] = 1e10; meshProps.rowMajorT_camera_groups[i * 12 + 7] = 1e10; meshProps.rowMajorT_camera_groups[i * 12 + 11] = 1e10; } } prevRowMajorT_camera_groups.set(meshProps.rowMajorT_camera_groups); meshProps.textureT_camera_groups.needsUpdate = true; } }, -100 /* This should be called early to reduce group transform artifacts. */); return ( ); } /**Consolidate groups of Gaussians into a single buffer, to make it possible * for them to be sorted globally.*/ function mergeGaussianGroups(groupBufferFromName: { [name: string]: Uint32Array; }) { // Create geometry. Each Gaussian will be rendered as a quad. 
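  // Worked example with hypothetical sizes: two groups of 100 and 50 Gaussians
  // arrive as Uint32Arrays of length 800 and 400 (8 uint32s per Gaussian). The
  // merged gaussianBuffer then has length 1200, numGaussians is 150, and
  // groupIndices is a length-150 array of 0s followed by 1s.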
  let totalBufferLength = 0;
  for (const buffer of Object.values(groupBufferFromName)) {
    totalBufferLength += buffer.length;
  }
  const numGaussians = totalBufferLength / 8;
  const gaussianBuffer = new Uint32Array(totalBufferLength);
  const groupIndices = new Uint32Array(numGaussians);

  let offset = 0;
  for (const [groupIndex, groupBuffer] of Object.values(
    groupBufferFromName,
  ).entries()) {
    groupIndices.fill(
      groupIndex,
      offset / 8,
      (offset + groupBuffer.length) / 8,
    );
    gaussianBuffer.set(groupBuffer, offset);

    // Each Gaussian is allocated
    // - 12 bytes for center x, y, z (float32)
    // - 4 bytes for group index (uint32); we're filling this in now
    //
    // - 12 bytes for covariance (6 terms, float16)
    // - 4 bytes for RGBA (uint8)
    for (let i = 0; i < groupBuffer.length; i += 8) {
      gaussianBuffer[offset + i + 3] = groupIndex;
    }
    offset += groupBuffer.length;
  }

  const numGroups = Object.keys(groupBufferFromName).length;
  return { numGaussians, gaussianBuffer, numGroups, groupIndices };
}

/**Hook to generate properties for rendering Gaussians via a three.js mesh.*/
function useGaussianMeshProps(gaussianBuffer: Uint32Array, numGroups: number) {
  const numGaussians = gaussianBuffer.length / 8;
  const maxTextureSize = useThree((state) => state.gl).capabilities
    .maxTextureSize;

  // Create instanced geometry.
  const geometry = new THREE.InstancedBufferGeometry();
  geometry.instanceCount = numGaussians;
  geometry.setIndex(
    new THREE.BufferAttribute(new Uint32Array([0, 2, 1, 0, 3, 2]), 1),
  );
  geometry.setAttribute(
    "position",
    new THREE.BufferAttribute(
      new Float32Array([-2, -2, 2, -2, 2, 2, -2, 2]),
      2,
    ),
  );

  // Rendering order for Gaussians.
  const sortedIndexAttribute = new THREE.InstancedBufferAttribute(
    new Uint32Array(numGaussians),
    1,
  );
  sortedIndexAttribute.setUsage(THREE.DynamicDrawUsage);
  geometry.setAttribute("sortedIndex", sortedIndexAttribute);

  // Create texture buffers.
  const textureWidth = Math.min(numGaussians * 2, maxTextureSize);
  const textureHeight = Math.ceil((numGaussians * 2) / textureWidth);
  const bufferPadded = new Uint32Array(textureWidth * textureHeight * 4);
  bufferPadded.set(gaussianBuffer);
  const textureBuffer = new THREE.DataTexture(
    bufferPadded,
    textureWidth,
    textureHeight,
    THREE.RGBAIntegerFormat,
    THREE.UnsignedIntType,
  );
  textureBuffer.internalFormat = "RGBA32UI";
  textureBuffer.needsUpdate = true;

  const rowMajorT_camera_groups = new Float32Array(numGroups * 12);
  const textureT_camera_groups = new THREE.DataTexture(
    rowMajorT_camera_groups,
    (numGroups * 12) / 4,
    1,
    THREE.RGBAFormat,
    THREE.FloatType,
  );
  textureT_camera_groups.internalFormat = "RGBA32F";
  textureT_camera_groups.needsUpdate = true;

  const material = new GaussianSplatMaterial({
    // @ts-ignore
    textureBuffer: textureBuffer,
    textureT_camera_groups: textureT_camera_groups,
    numGaussians: 0,
    transitionInState: 0.0,
  });

  return {
    geometry,
    material,
    textureBuffer,
    sortedIndexAttribute,
    textureT_camera_groups,
    rowMajorT_camera_groups,
  };
}


================================================
FILE: viser/src/viser/client/src/Splatting/SplatSortWorker.ts
================================================
/** Worker for sorting splats.
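 *
 * A sketch of the message protocol, inferred from the SorterWorkerIncoming
 * type below and from SplatRenderer's postToWorker() calls elsewhere in this
 * client (field semantics beyond what the code shows are assumptions):
 *   - { setBuffer, setGroupIndices }: construct the WASM Sorter over the
 *     merged Gaussian buffer and per-Gaussian group indices.
 *   - { setTz_camera_groups }: update per-group camera-space depth rows and
 *     schedule a throttled re-sort; the worker replies with { sortedIndices }.
 *   - { close: true }: terminate the worker.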
*/ import MakeSorterModulePromise from "./WasmSorter/Sorter.mjs"; export type SorterWorkerIncoming = | { setBuffer: Uint32Array; setGroupIndices: Uint32Array; } | { setTz_camera_groups: Float32Array; } | { close: true }; { let sorter: any = null; let Tz_camera_groups: Float32Array | null = null; let sortRunning = false; const throttledSort = () => { if (sorter === null || Tz_camera_groups === null) { setTimeout(throttledSort, 1); return; } if (sortRunning) return; sortRunning = true; const lastView = Tz_camera_groups; // Important: we clone the output so we can transfer the buffer to the main // thread. Compared to relying on postMessage for copying, this reduces // backlog artifacts. const sortedIndices = ( sorter.sort(Tz_camera_groups) as Uint32Array ).slice(); // @ts-ignore self.postMessage({ sortedIndices: sortedIndices }, [sortedIndices.buffer]); setTimeout(() => { sortRunning = false; if (Tz_camera_groups === null) return; if ( !lastView.every( // Cast is needed because of closure... (val, i) => val === (Tz_camera_groups as Float32Array)[i], ) ) { throttledSort(); } }, 0); }; const SorterModulePromise = MakeSorterModulePromise(); self.onmessage = async (e) => { const data = e.data as SorterWorkerIncoming; if ("setBuffer" in data) { // Instantiate sorter with buffers populated. sorter = new (await SorterModulePromise).Sorter( data.setBuffer, data.setGroupIndices, ); } else if ("setTz_camera_groups" in data) { // Update object transforms. Tz_camera_groups = data.setTz_camera_groups; throttledSort(); } else if ("close" in data) { // Done! self.close(); } }; } ================================================ FILE: viser/src/viser/client/src/Splatting/WasmSorter/Sorter.mjs ================================================ var Module = (() => { var _scriptName = import.meta.url; return ( async function(moduleArg = {}) { var moduleRtn; var Module=moduleArg;var readyPromiseResolve,readyPromiseReject;var readyPromise=new Promise((resolve,reject)=>{readyPromiseResolve=resolve;readyPromiseReject=reject});var ENVIRONMENT_IS_WEB=typeof window=="object";var ENVIRONMENT_IS_WORKER=typeof importScripts=="function";var ENVIRONMENT_IS_NODE=typeof process=="object"&&typeof process.versions=="object"&&typeof process.versions.node=="string";if(ENVIRONMENT_IS_NODE){const{createRequire:createRequire}=await import("module");var require=createRequire(import.meta.url)}var moduleOverrides=Object.assign({},Module);var arguments_=[];var thisProgram="./this.program";var quit_=(status,toThrow)=>{throw toThrow};var scriptDirectory="";function locateFile(path){if(Module["locateFile"]){return Module["locateFile"](path,scriptDirectory)}return scriptDirectory+path}var read_,readAsync,readBinary;if(ENVIRONMENT_IS_NODE){var fs=require("fs");var nodePath=require("path");scriptDirectory=require("url").fileURLToPath(new URL("./",import.meta.url));read_=(filename,binary)=>{filename=isFileURI(filename)?new URL(filename):nodePath.normalize(filename);return fs.readFileSync(filename,binary?undefined:"utf8")};readBinary=filename=>{var ret=read_(filename,true);if(!ret.buffer){ret=new Uint8Array(ret)}return ret};readAsync=(filename,onload,onerror,binary=true)=>{filename=isFileURI(filename)?new URL(filename):nodePath.normalize(filename);fs.readFile(filename,binary?undefined:"utf8",(err,data)=>{if(err)onerror(err);else onload(binary?data.buffer:data)})};if(!Module["thisProgram"]&&process.argv.length>1){thisProgram=process.argv[1].replace(/\\/g,"/")}arguments_=process.argv.slice(2);quit_=(status,toThrow)=>{process.exitCode=status;throw 
toThrow}}else if(ENVIRONMENT_IS_WEB||ENVIRONMENT_IS_WORKER){if(ENVIRONMENT_IS_WORKER){scriptDirectory=self.location.href}else if(typeof document!="undefined"&&document.currentScript){scriptDirectory=document.currentScript.src}if(_scriptName){scriptDirectory=_scriptName}if(scriptDirectory.startsWith("blob:")){scriptDirectory=""}else{scriptDirectory=scriptDirectory.substr(0,scriptDirectory.replace(/[?#].*/,"").lastIndexOf("/")+1)}{read_=url=>{var xhr=new XMLHttpRequest;xhr.open("GET",url,false);xhr.send(null);return xhr.responseText};if(ENVIRONMENT_IS_WORKER){readBinary=url=>{var xhr=new XMLHttpRequest;xhr.open("GET",url,false);xhr.responseType="arraybuffer";xhr.send(null);return new Uint8Array(xhr.response)}}readAsync=(url,onload,onerror)=>{if(isFileURI(url)){var xhr=new XMLHttpRequest;xhr.open("GET",url,true);xhr.responseType="arraybuffer";xhr.onload=()=>{if(xhr.status==200||xhr.status==0&&xhr.response){onload(xhr.response);return}onerror()};xhr.onerror=onerror;xhr.send(null);return}fetch(url,{credentials:"same-origin"}).then(response=>{if(response.ok){return response.arrayBuffer()}return Promise.reject(new Error(response.status+" : "+response.url))}).then(onload,onerror)}}}else{}var out=Module["print"]||console.log.bind(console);var err=Module["printErr"]||console.error.bind(console);Object.assign(Module,moduleOverrides);moduleOverrides=null;if(Module["arguments"])arguments_=Module["arguments"];if(Module["thisProgram"])thisProgram=Module["thisProgram"];if(Module["quit"])quit_=Module["quit"];var wasmBinary;if(Module["wasmBinary"])wasmBinary=Module["wasmBinary"];var wasmMemory;var ABORT=false;var EXITSTATUS;var HEAP8,HEAPU8,HEAP16,HEAPU16,HEAP32,HEAPU32,HEAPF32,HEAPF64;function updateMemoryViews(){var b=wasmMemory.buffer;Module["HEAP8"]=HEAP8=new Int8Array(b);Module["HEAP16"]=HEAP16=new Int16Array(b);Module["HEAPU8"]=HEAPU8=new Uint8Array(b);Module["HEAPU16"]=HEAPU16=new Uint16Array(b);Module["HEAP32"]=HEAP32=new Int32Array(b);Module["HEAPU32"]=HEAPU32=new Uint32Array(b);Module["HEAPF32"]=HEAPF32=new Float32Array(b);Module["HEAPF64"]=HEAPF64=new Float64Array(b)}var __ATPRERUN__=[];var __ATINIT__=[];var __ATPOSTRUN__=[];var runtimeInitialized=false;function preRun(){if(Module["preRun"]){if(typeof Module["preRun"]=="function")Module["preRun"]=[Module["preRun"]];while(Module["preRun"].length){addOnPreRun(Module["preRun"].shift())}}callRuntimeCallbacks(__ATPRERUN__)}function initRuntime(){runtimeInitialized=true;callRuntimeCallbacks(__ATINIT__)}function postRun(){if(Module["postRun"]){if(typeof Module["postRun"]=="function")Module["postRun"]=[Module["postRun"]];while(Module["postRun"].length){addOnPostRun(Module["postRun"].shift())}}callRuntimeCallbacks(__ATPOSTRUN__)}function addOnPreRun(cb){__ATPRERUN__.unshift(cb)}function addOnInit(cb){__ATINIT__.unshift(cb)}function addOnPostRun(cb){__ATPOSTRUN__.unshift(cb)}var runDependencies=0;var runDependencyWatcher=null;var dependenciesFulfilled=null;function addRunDependency(id){runDependencies++;Module["monitorRunDependencies"]?.(runDependencies)}function removeRunDependency(id){runDependencies--;Module["monitorRunDependencies"]?.(runDependencies);if(runDependencies==0){if(runDependencyWatcher!==null){clearInterval(runDependencyWatcher);runDependencyWatcher=null}if(dependenciesFulfilled){var callback=dependenciesFulfilled;dependenciesFulfilled=null;callback()}}}function abort(what){Module["onAbort"]?.(what);what="Aborted("+what+")";err(what);ABORT=true;EXITSTATUS=1;what+=". 
Build with -sASSERTIONS for more info.";var e=new WebAssembly.RuntimeError(what);readyPromiseReject(e);throw e}var dataURIPrefix="data:application/octet-stream;base64,";var isDataURI=filename=>filename.startsWith(dataURIPrefix);var isFileURI=filename=>filename.startsWith("file://");function findWasmBinary(){if(Module["locateFile"]){var f="Sorter.wasm";if(!isDataURI(f)){return locateFile(f)}return f}return new URL("Sorter.wasm",import.meta.url).href}var wasmBinaryFile;function getBinarySync(file){if(file==wasmBinaryFile&&wasmBinary){return new Uint8Array(wasmBinary)}if(readBinary){return readBinary(file)}throw"both async and sync fetching of the wasm failed"}function getBinaryPromise(binaryFile){if(!wasmBinary){return new Promise((resolve,reject)=>{readAsync(binaryFile,response=>resolve(new Uint8Array(response)),error=>{try{resolve(getBinarySync(binaryFile))}catch(e){reject(e)}})})}return Promise.resolve().then(()=>getBinarySync(binaryFile))}function instantiateArrayBuffer(binaryFile,imports,receiver){return getBinaryPromise(binaryFile).then(binary=>WebAssembly.instantiate(binary,imports)).then(receiver,reason=>{err(`failed to asynchronously prepare wasm: ${reason}`);abort(reason)})}function instantiateAsync(binary,binaryFile,imports,callback){if(!binary&&typeof WebAssembly.instantiateStreaming=="function"&&!isDataURI(binaryFile)&&!isFileURI(binaryFile)&&!ENVIRONMENT_IS_NODE&&typeof fetch=="function"){return fetch(binaryFile,{credentials:"same-origin"}).then(response=>{var result=WebAssembly.instantiateStreaming(response,imports);return result.then(callback,function(reason){err(`wasm streaming compile failed: ${reason}`);err("falling back to ArrayBuffer instantiation");return instantiateArrayBuffer(binaryFile,imports,callback)})})}return instantiateArrayBuffer(binaryFile,imports,callback)}function getWasmImports(){return{a:wasmImports}}function createWasm(){var info=getWasmImports();function receiveInstance(instance,module){wasmExports=instance.exports;wasmMemory=wasmExports["z"];updateMemoryViews();wasmTable=wasmExports["C"];addOnInit(wasmExports["A"]);removeRunDependency("wasm-instantiate");return wasmExports}addRunDependency("wasm-instantiate");function receiveInstantiationResult(result){receiveInstance(result["instance"])}if(Module["instantiateWasm"]){try{return Module["instantiateWasm"](info,receiveInstance)}catch(e){err(`Module.instantiateWasm callback failed with error: ${e}`);readyPromiseReject(e)}}if(!wasmBinaryFile)wasmBinaryFile=findWasmBinary();instantiateAsync(wasmBinary,wasmBinaryFile,info,receiveInstantiationResult).catch(readyPromiseReject);return{}}var callRuntimeCallbacks=callbacks=>{while(callbacks.length>0){callbacks.shift()(Module)}};var noExitRuntime=Module["noExitRuntime"]||true;class ExceptionInfo{constructor(excPtr){this.excPtr=excPtr;this.ptr=excPtr-24}set_type(type){HEAPU32[this.ptr+4>>2]=type}get_type(){return HEAPU32[this.ptr+4>>2]}set_destructor(destructor){HEAPU32[this.ptr+8>>2]=destructor}get_destructor(){return HEAPU32[this.ptr+8>>2]}set_caught(caught){caught=caught?1:0;HEAP8[this.ptr+12]=caught}get_caught(){return HEAP8[this.ptr+12]!=0}set_rethrown(rethrown){rethrown=rethrown?1:0;HEAP8[this.ptr+13]=rethrown}get_rethrown(){return HEAP8[this.ptr+13]!=0}init(type,destructor){this.set_adjusted_ptr(0);this.set_type(type);this.set_destructor(destructor)}set_adjusted_ptr(adjustedPtr){HEAPU32[this.ptr+16>>2]=adjustedPtr}get_adjusted_ptr(){return HEAPU32[this.ptr+16>>2]}get_exception_ptr(){var isPointer=___cxa_is_pointer_type(this.get_type());if(isPointer){return 
HEAPU32[this.excPtr>>2]}var adjusted=this.get_adjusted_ptr();if(adjusted!==0)return adjusted;return this.excPtr}}var exceptionLast=0;var uncaughtExceptionCount=0;var ___cxa_throw=(ptr,type,destructor)=>{var info=new ExceptionInfo(ptr);info.init(type,destructor);exceptionLast=ptr;uncaughtExceptionCount++;throw exceptionLast};var __abort_js=()=>{abort("")};var __embind_register_bigint=(primitiveType,name,size,minRange,maxRange)=>{};var embind_init_charCodes=()=>{var codes=new Array(256);for(var i=0;i<256;++i){codes[i]=String.fromCharCode(i)}embind_charCodes=codes};var embind_charCodes;var readLatin1String=ptr=>{var ret="";var c=ptr;while(HEAPU8[c]){ret+=embind_charCodes[HEAPU8[c++]]}return ret};var awaitingDependencies={};var registeredTypes={};var typeDependencies={};var BindingError;var throwBindingError=message=>{throw new BindingError(message)};var InternalError;var throwInternalError=message=>{throw new InternalError(message)};var whenDependentTypesAreResolved=(myTypes,dependentTypes,getTypeConverters)=>{myTypes.forEach(function(type){typeDependencies[type]=dependentTypes});function onComplete(typeConverters){var myTypeConverters=getTypeConverters(typeConverters);if(myTypeConverters.length!==myTypes.length){throwInternalError("Mismatched type converter count")}for(var i=0;i{if(registeredTypes.hasOwnProperty(dt)){typeConverters[i]=registeredTypes[dt]}else{unregisteredTypes.push(dt);if(!awaitingDependencies.hasOwnProperty(dt)){awaitingDependencies[dt]=[]}awaitingDependencies[dt].push(()=>{typeConverters[i]=registeredTypes[dt];++registered;if(registered===unregisteredTypes.length){onComplete(typeConverters)}})}});if(0===unregisteredTypes.length){onComplete(typeConverters)}};function sharedRegisterType(rawType,registeredInstance,options={}){var name=registeredInstance.name;if(!rawType){throwBindingError(`type "${name}" must have a positive integer typeid pointer`)}if(registeredTypes.hasOwnProperty(rawType)){if(options.ignoreDuplicateRegistrations){return}else{throwBindingError(`Cannot register type '${name}' twice`)}}registeredTypes[rawType]=registeredInstance;delete typeDependencies[rawType];if(awaitingDependencies.hasOwnProperty(rawType)){var callbacks=awaitingDependencies[rawType];delete awaitingDependencies[rawType];callbacks.forEach(cb=>cb())}}function registerType(rawType,registeredInstance,options={}){if(!("argPackAdvance"in registeredInstance)){throw new TypeError("registerType registeredInstance requires argPackAdvance")}return sharedRegisterType(rawType,registeredInstance,options)}var GenericWireTypeSize=8;var __embind_register_bool=(rawType,name,trueValue,falseValue)=>{name=readLatin1String(name);registerType(rawType,{name:name,fromWireType:function(wt){return!!wt},toWireType:function(destructors,o){return o?trueValue:falseValue},argPackAdvance:GenericWireTypeSize,readValueFromPointer:function(pointer){return this["fromWireType"](HEAPU8[pointer])},destructorFunction:null})};var shallowCopyInternalPointer=o=>({count:o.count,deleteScheduled:o.deleteScheduled,preservePointerOnDelete:o.preservePointerOnDelete,ptr:o.ptr,ptrType:o.ptrType,smartPtr:o.smartPtr,smartPtrType:o.smartPtrType});var throwInstanceAlreadyDeleted=obj=>{function getInstanceTypeName(handle){return handle.$$.ptrType.registeredClass.name}throwBindingError(getInstanceTypeName(obj)+" instance already deleted")};var finalizationRegistry=false;var detachFinalizer=handle=>{};var runDestructor=$$=>{if($$.smartPtr){$$.smartPtrType.rawDestructor($$.smartPtr)}else{$$.ptrType.registeredClass.rawDestructor($$.ptr)}};var 
releaseClassHandle=$$=>{$$.count.value-=1;var toDelete=0===$$.count.value;if(toDelete){runDestructor($$)}};var downcastPointer=(ptr,ptrClass,desiredClass)=>{if(ptrClass===desiredClass){return ptr}if(undefined===desiredClass.baseClass){return null}var rv=downcastPointer(ptr,ptrClass,desiredClass.baseClass);if(rv===null){return null}return desiredClass.downcast(rv)};var registeredPointers={};var getInheritedInstanceCount=()=>Object.keys(registeredInstances).length;var getLiveInheritedInstances=()=>{var rv=[];for(var k in registeredInstances){if(registeredInstances.hasOwnProperty(k)){rv.push(registeredInstances[k])}}return rv};var deletionQueue=[];var flushPendingDeletes=()=>{while(deletionQueue.length){var obj=deletionQueue.pop();obj.$$.deleteScheduled=false;obj["delete"]()}};var delayFunction;var setDelayFunction=fn=>{delayFunction=fn;if(deletionQueue.length&&delayFunction){delayFunction(flushPendingDeletes)}};var init_embind=()=>{Module["getInheritedInstanceCount"]=getInheritedInstanceCount;Module["getLiveInheritedInstances"]=getLiveInheritedInstances;Module["flushPendingDeletes"]=flushPendingDeletes;Module["setDelayFunction"]=setDelayFunction};var registeredInstances={};var getBasestPointer=(class_,ptr)=>{if(ptr===undefined){throwBindingError("ptr should not be undefined")}while(class_.baseClass){ptr=class_.upcast(ptr);class_=class_.baseClass}return ptr};var getInheritedInstance=(class_,ptr)=>{ptr=getBasestPointer(class_,ptr);return registeredInstances[ptr]};var makeClassHandle=(prototype,record)=>{if(!record.ptrType||!record.ptr){throwInternalError("makeClassHandle requires ptr and ptrType")}var hasSmartPtrType=!!record.smartPtrType;var hasSmartPtr=!!record.smartPtr;if(hasSmartPtrType!==hasSmartPtr){throwInternalError("Both smartPtrType and smartPtr must be specified")}record.count={value:1};return attachFinalizer(Object.create(prototype,{$$:{value:record,writable:true}}))};function RegisteredPointer_fromWireType(ptr){var rawPointer=this.getPointee(ptr);if(!rawPointer){this.destructor(ptr);return null}var registeredInstance=getInheritedInstance(this.registeredClass,rawPointer);if(undefined!==registeredInstance){if(0===registeredInstance.$$.count.value){registeredInstance.$$.ptr=rawPointer;registeredInstance.$$.smartPtr=ptr;return registeredInstance["clone"]()}else{var rv=registeredInstance["clone"]();this.destructor(ptr);return rv}}function makeDefaultHandle(){if(this.isSmartPointer){return makeClassHandle(this.registeredClass.instancePrototype,{ptrType:this.pointeeType,ptr:rawPointer,smartPtrType:this,smartPtr:ptr})}else{return makeClassHandle(this.registeredClass.instancePrototype,{ptrType:this,ptr:ptr})}}var actualType=this.registeredClass.getActualType(rawPointer);var registeredPointerRecord=registeredPointers[actualType];if(!registeredPointerRecord){return makeDefaultHandle.call(this)}var toType;if(this.isConst){toType=registeredPointerRecord.constPointerType}else{toType=registeredPointerRecord.pointerType}var dp=downcastPointer(rawPointer,this.registeredClass,toType.registeredClass);if(dp===null){return makeDefaultHandle.call(this)}if(this.isSmartPointer){return makeClassHandle(toType.registeredClass.instancePrototype,{ptrType:toType,ptr:dp,smartPtrType:this,smartPtr:ptr})}else{return makeClassHandle(toType.registeredClass.instancePrototype,{ptrType:toType,ptr:dp})}}var attachFinalizer=handle=>{if("undefined"===typeof FinalizationRegistry){attachFinalizer=handle=>handle;return handle}finalizationRegistry=new 
FinalizationRegistry(info=>{releaseClassHandle(info.$$)});attachFinalizer=handle=>{var $$=handle.$$;var hasSmartPtr=!!$$.smartPtr;if(hasSmartPtr){var info={$$:$$};finalizationRegistry.register(handle,info,handle)}return handle};detachFinalizer=handle=>finalizationRegistry.unregister(handle);return attachFinalizer(handle)};var init_ClassHandle=()=>{Object.assign(ClassHandle.prototype,{isAliasOf(other){if(!(this instanceof ClassHandle)){return false}if(!(other instanceof ClassHandle)){return false}var leftClass=this.$$.ptrType.registeredClass;var left=this.$$.ptr;other.$$=other.$$;var rightClass=other.$$.ptrType.registeredClass;var right=other.$$.ptr;while(leftClass.baseClass){left=leftClass.upcast(left);leftClass=leftClass.baseClass}while(rightClass.baseClass){right=rightClass.upcast(right);rightClass=rightClass.baseClass}return leftClass===rightClass&&left===right},clone(){if(!this.$$.ptr){throwInstanceAlreadyDeleted(this)}if(this.$$.preservePointerOnDelete){this.$$.count.value+=1;return this}else{var clone=attachFinalizer(Object.create(Object.getPrototypeOf(this),{$$:{value:shallowCopyInternalPointer(this.$$)}}));clone.$$.count.value+=1;clone.$$.deleteScheduled=false;return clone}},delete(){if(!this.$$.ptr){throwInstanceAlreadyDeleted(this)}if(this.$$.deleteScheduled&&!this.$$.preservePointerOnDelete){throwBindingError("Object already scheduled for deletion")}detachFinalizer(this);releaseClassHandle(this.$$);if(!this.$$.preservePointerOnDelete){this.$$.smartPtr=undefined;this.$$.ptr=undefined}},isDeleted(){return!this.$$.ptr},deleteLater(){if(!this.$$.ptr){throwInstanceAlreadyDeleted(this)}if(this.$$.deleteScheduled&&!this.$$.preservePointerOnDelete){throwBindingError("Object already scheduled for deletion")}deletionQueue.push(this);if(deletionQueue.length===1&&delayFunction){delayFunction(flushPendingDeletes)}this.$$.deleteScheduled=true;return this}})};function ClassHandle(){}var createNamedFunction=(name,body)=>Object.defineProperty(body,"name",{value:name});var ensureOverloadTable=(proto,methodName,humanName)=>{if(undefined===proto[methodName].overloadTable){var prevFunc=proto[methodName];proto[methodName]=function(...args){if(!proto[methodName].overloadTable.hasOwnProperty(args.length)){throwBindingError(`Function '${humanName}' called with an invalid number of arguments (${args.length}) - expects one of (${proto[methodName].overloadTable})!`)}return proto[methodName].overloadTable[args.length].apply(this,args)};proto[methodName].overloadTable=[];proto[methodName].overloadTable[prevFunc.argCount]=prevFunc}};var exposePublicSymbol=(name,value,numArguments)=>{if(Module.hasOwnProperty(name)){if(undefined===numArguments||undefined!==Module[name].overloadTable&&undefined!==Module[name].overloadTable[numArguments]){throwBindingError(`Cannot register public name '${name}' twice`)}ensureOverloadTable(Module,name,name);if(Module.hasOwnProperty(numArguments)){throwBindingError(`Cannot register multiple overloads of a function with the same number of arguments (${numArguments})!`)}Module[name].overloadTable[numArguments]=value}else{Module[name]=value;if(undefined!==numArguments){Module[name].numArguments=numArguments}}};var char_0=48;var char_9=57;var makeLegalFunctionName=name=>{if(undefined===name){return"_unknown"}name=name.replace(/[^a-zA-Z0-9_]/g,"$");var f=name.charCodeAt(0);if(f>=char_0&&f<=char_9){return`_${name}`}return name};function 
RegisteredClass(name,constructor,instancePrototype,rawDestructor,baseClass,getActualType,upcast,downcast){this.name=name;this.constructor=constructor;this.instancePrototype=instancePrototype;this.rawDestructor=rawDestructor;this.baseClass=baseClass;this.getActualType=getActualType;this.upcast=upcast;this.downcast=downcast;this.pureVirtualFunctions=[]}var upcastPointer=(ptr,ptrClass,desiredClass)=>{while(ptrClass!==desiredClass){if(!ptrClass.upcast){throwBindingError(`Expected null or instance of ${desiredClass.name}, got an instance of ${ptrClass.name}`)}ptr=ptrClass.upcast(ptr);ptrClass=ptrClass.baseClass}return ptr};function constNoSmartPtrRawPointerToWireType(destructors,handle){if(handle===null){if(this.isReference){throwBindingError(`null is not a valid ${this.name}`)}return 0}if(!handle.$$){throwBindingError(`Cannot pass "${embindRepr(handle)}" as a ${this.name}`)}if(!handle.$$.ptr){throwBindingError(`Cannot pass deleted object as a pointer of type ${this.name}`)}var handleClass=handle.$$.ptrType.registeredClass;var ptr=upcastPointer(handle.$$.ptr,handleClass,this.registeredClass);return ptr}function genericPointerToWireType(destructors,handle){var ptr;if(handle===null){if(this.isReference){throwBindingError(`null is not a valid ${this.name}`)}if(this.isSmartPointer){ptr=this.rawConstructor();if(destructors!==null){destructors.push(this.rawDestructor,ptr)}return ptr}else{return 0}}if(!handle||!handle.$$){throwBindingError(`Cannot pass "${embindRepr(handle)}" as a ${this.name}`)}if(!handle.$$.ptr){throwBindingError(`Cannot pass deleted object as a pointer of type ${this.name}`)}if(!this.isConst&&handle.$$.ptrType.isConst){throwBindingError(`Cannot convert argument of type ${handle.$$.smartPtrType?handle.$$.smartPtrType.name:handle.$$.ptrType.name} to parameter type ${this.name}`)}var handleClass=handle.$$.ptrType.registeredClass;ptr=upcastPointer(handle.$$.ptr,handleClass,this.registeredClass);if(this.isSmartPointer){if(undefined===handle.$$.smartPtr){throwBindingError("Passing raw pointer to smart pointer is illegal")}switch(this.sharingPolicy){case 0:if(handle.$$.smartPtrType===this){ptr=handle.$$.smartPtr}else{throwBindingError(`Cannot convert argument of type ${handle.$$.smartPtrType?handle.$$.smartPtrType.name:handle.$$.ptrType.name} to parameter type ${this.name}`)}break;case 1:ptr=handle.$$.smartPtr;break;case 2:if(handle.$$.smartPtrType===this){ptr=handle.$$.smartPtr}else{var clonedHandle=handle["clone"]();ptr=this.rawShare(ptr,Emval.toHandle(()=>clonedHandle["delete"]()));if(destructors!==null){destructors.push(this.rawDestructor,ptr)}}break;default:throwBindingError("Unsupporting sharing policy")}}return ptr}function nonConstNoSmartPtrRawPointerToWireType(destructors,handle){if(handle===null){if(this.isReference){throwBindingError(`null is not a valid ${this.name}`)}return 0}if(!handle.$$){throwBindingError(`Cannot pass "${embindRepr(handle)}" as a ${this.name}`)}if(!handle.$$.ptr){throwBindingError(`Cannot pass deleted object as a pointer of type ${this.name}`)}if(handle.$$.ptrType.isConst){throwBindingError(`Cannot convert argument of type ${handle.$$.ptrType.name} to parameter type ${this.name}`)}var handleClass=handle.$$.ptrType.registeredClass;var ptr=upcastPointer(handle.$$.ptr,handleClass,this.registeredClass);return ptr}function readPointer(pointer){return this["fromWireType"](HEAPU32[pointer>>2])}var init_RegisteredPointer=()=>{Object.assign(RegisteredPointer.prototype,{getPointee(ptr){if(this.rawGetPointee){ptr=this.rawGetPointee(ptr)}return 
ptr},destructor(ptr){this.rawDestructor?.(ptr)},argPackAdvance:GenericWireTypeSize,readValueFromPointer:readPointer,fromWireType:RegisteredPointer_fromWireType})};function RegisteredPointer(name,registeredClass,isReference,isConst,isSmartPointer,pointeeType,sharingPolicy,rawGetPointee,rawConstructor,rawShare,rawDestructor){this.name=name;this.registeredClass=registeredClass;this.isReference=isReference;this.isConst=isConst;this.isSmartPointer=isSmartPointer;this.pointeeType=pointeeType;this.sharingPolicy=sharingPolicy;this.rawGetPointee=rawGetPointee;this.rawConstructor=rawConstructor;this.rawShare=rawShare;this.rawDestructor=rawDestructor;if(!isSmartPointer&®isteredClass.baseClass===undefined){if(isConst){this["toWireType"]=constNoSmartPtrRawPointerToWireType;this.destructorFunction=null}else{this["toWireType"]=nonConstNoSmartPtrRawPointerToWireType;this.destructorFunction=null}}else{this["toWireType"]=genericPointerToWireType}}var replacePublicSymbol=(name,value,numArguments)=>{if(!Module.hasOwnProperty(name)){throwInternalError("Replacing nonexistent public symbol")}if(undefined!==Module[name].overloadTable&&undefined!==numArguments){Module[name].overloadTable[numArguments]=value}else{Module[name]=value;Module[name].argCount=numArguments}};var dynCallLegacy=(sig,ptr,args)=>{sig=sig.replace(/p/g,"i");var f=Module["dynCall_"+sig];return f(ptr,...args)};var wasmTableMirror=[];var wasmTable;var getWasmTableEntry=funcPtr=>{var func=wasmTableMirror[funcPtr];if(!func){if(funcPtr>=wasmTableMirror.length)wasmTableMirror.length=funcPtr+1;wasmTableMirror[funcPtr]=func=wasmTable.get(funcPtr)}return func};var dynCall=(sig,ptr,args=[])=>{if(sig.includes("j")){return dynCallLegacy(sig,ptr,args)}var rtn=getWasmTableEntry(ptr)(...args);return rtn};var getDynCaller=(sig,ptr)=>(...args)=>dynCall(sig,ptr,args);var embind__requireFunction=(signature,rawFunction)=>{signature=readLatin1String(signature);function makeDynCaller(){if(signature.includes("j")){return getDynCaller(signature,rawFunction)}return getWasmTableEntry(rawFunction)}var fp=makeDynCaller();if(typeof fp!="function"){throwBindingError(`unknown function pointer with signature ${signature}: ${rawFunction}`)}return fp};var extendError=(baseErrorType,errorName)=>{var errorClass=createNamedFunction(errorName,function(message){this.name=errorName;this.message=message;var stack=new Error(message).stack;if(stack!==undefined){this.stack=this.toString()+"\n"+stack.replace(/^Error(:[^\n]*)?\n/,"")}});errorClass.prototype=Object.create(baseErrorType.prototype);errorClass.prototype.constructor=errorClass;errorClass.prototype.toString=function(){if(this.message===undefined){return this.name}else{return`${this.name}: ${this.message}`}};return errorClass};var UnboundTypeError;var getTypeName=type=>{var ptr=___getTypeName(type);var rv=readLatin1String(ptr);_free(ptr);return rv};var throwUnboundTypeError=(message,types)=>{var unboundTypes=[];var seen={};function visit(type){if(seen[type]){return}if(registeredTypes[type]){return}if(typeDependencies[type]){typeDependencies[type].forEach(visit);return}unboundTypes.push(type);seen[type]=true}types.forEach(visit);throw new UnboundTypeError(`${message}: `+unboundTypes.map(getTypeName).join([", "]))};var 
__embind_register_class=(rawType,rawPointerType,rawConstPointerType,baseClassRawType,getActualTypeSignature,getActualType,upcastSignature,upcast,downcastSignature,downcast,name,destructorSignature,rawDestructor)=>{name=readLatin1String(name);getActualType=embind__requireFunction(getActualTypeSignature,getActualType);upcast&&=embind__requireFunction(upcastSignature,upcast);downcast&&=embind__requireFunction(downcastSignature,downcast);rawDestructor=embind__requireFunction(destructorSignature,rawDestructor);var legalFunctionName=makeLegalFunctionName(name);exposePublicSymbol(legalFunctionName,function(){throwUnboundTypeError(`Cannot construct ${name} due to unbound types`,[baseClassRawType])});whenDependentTypesAreResolved([rawType,rawPointerType,rawConstPointerType],baseClassRawType?[baseClassRawType]:[],base=>{base=base[0];var baseClass;var basePrototype;if(baseClassRawType){baseClass=base.registeredClass;basePrototype=baseClass.instancePrototype}else{basePrototype=ClassHandle.prototype}var constructor=createNamedFunction(name,function(...args){if(Object.getPrototypeOf(this)!==instancePrototype){throw new BindingError("Use 'new' to construct "+name)}if(undefined===registeredClass.constructor_body){throw new BindingError(name+" has no accessible constructor")}var body=registeredClass.constructor_body[args.length];if(undefined===body){throw new BindingError(`Tried to invoke ctor of ${name} with invalid number of parameters (${args.length}) - expected (${Object.keys(registeredClass.constructor_body).toString()}) parameters instead!`)}return body.apply(this,args)});var instancePrototype=Object.create(basePrototype,{constructor:{value:constructor}});constructor.prototype=instancePrototype;var registeredClass=new RegisteredClass(name,constructor,instancePrototype,rawDestructor,baseClass,getActualType,upcast,downcast);if(registeredClass.baseClass){registeredClass.baseClass.__derivedClasses??=[];registeredClass.baseClass.__derivedClasses.push(registeredClass)}var referenceConverter=new RegisteredPointer(name,registeredClass,true,false,false);var pointerConverter=new RegisteredPointer(name+"*",registeredClass,false,false,false);var constPointerConverter=new RegisteredPointer(name+" const*",registeredClass,false,true,false);registeredPointers[rawType]={pointerType:pointerConverter,constPointerType:constPointerConverter};replacePublicSymbol(legalFunctionName,constructor);return[referenceConverter,pointerConverter,constPointerConverter]})};var heap32VectorToArray=(count,firstElement)=>{var array=[];for(var i=0;i>2])}return array};var runDestructors=destructors=>{while(destructors.length){var ptr=destructors.pop();var del=destructors.pop();del(ptr)}};function usesDestructorStack(argTypes){for(var i=1;i0?", ":"")+argsListWired}invokerFnBody+=(returns||isAsync?"var rv = ":"")+"invoker(fn"+(argsListWired.length>0?", ":"")+argsListWired+");\n";if(needsDestructorStack){invokerFnBody+="runDestructors(destructors);\n"}else{for(var i=isClassMethodFunc?1:2;i{var rawArgTypes=heap32VectorToArray(argCount,rawArgTypesAddr);invoker=embind__requireFunction(invokerSignature,invoker);whenDependentTypesAreResolved([],[rawClassType],classType=>{classType=classType[0];var humanName=`constructor ${classType.name}`;if(undefined===classType.registeredClass.constructor_body){classType.registeredClass.constructor_body=[]}if(undefined!==classType.registeredClass.constructor_body[argCount-1]){throw new BindingError(`Cannot register multiple constructors with identical number of parameters (${argCount-1}) for class 
'${classType.name}'! Overload resolution is currently only performed using the parameter count, not actual type info!`)}classType.registeredClass.constructor_body[argCount-1]=()=>{throwUnboundTypeError(`Cannot construct ${classType.name} due to unbound types`,rawArgTypes)};whenDependentTypesAreResolved([],rawArgTypes,argTypes=>{argTypes.splice(1,0,null);classType.registeredClass.constructor_body[argCount-1]=craftInvokerFunction(humanName,argTypes,null,invoker,rawConstructor);return[]});return[]})};var getFunctionName=signature=>{signature=signature.trim();const argsIndex=signature.indexOf("(");if(argsIndex!==-1){return signature.substr(0,argsIndex)}else{return signature}};var __embind_register_class_function=(rawClassType,methodName,argCount,rawArgTypesAddr,invokerSignature,rawInvoker,context,isPureVirtual,isAsync)=>{var rawArgTypes=heap32VectorToArray(argCount,rawArgTypesAddr);methodName=readLatin1String(methodName);methodName=getFunctionName(methodName);rawInvoker=embind__requireFunction(invokerSignature,rawInvoker);whenDependentTypesAreResolved([],[rawClassType],classType=>{classType=classType[0];var humanName=`${classType.name}.${methodName}`;if(methodName.startsWith("@@")){methodName=Symbol[methodName.substring(2)]}if(isPureVirtual){classType.registeredClass.pureVirtualFunctions.push(methodName)}function unboundTypesHandler(){throwUnboundTypeError(`Cannot call ${humanName} due to unbound types`,rawArgTypes)}var proto=classType.registeredClass.instancePrototype;var method=proto[methodName];if(undefined===method||undefined===method.overloadTable&&method.className!==classType.name&&method.argCount===argCount-2){unboundTypesHandler.argCount=argCount-2;unboundTypesHandler.className=classType.name;proto[methodName]=unboundTypesHandler}else{ensureOverloadTable(proto,methodName,humanName);proto[methodName].overloadTable[argCount-2]=unboundTypesHandler}whenDependentTypesAreResolved([],rawArgTypes,argTypes=>{var memberFunction=craftInvokerFunction(humanName,argTypes,classType,rawInvoker,context,isAsync);if(undefined===proto[methodName].overloadTable){memberFunction.argCount=argCount-2;proto[methodName]=memberFunction}else{proto[methodName].overloadTable[argCount-2]=memberFunction}return[]});return[]})};var emval_freelist=[];var emval_handles=[];var __emval_decref=handle=>{if(handle>9&&0===--emval_handles[handle+1]){emval_handles[handle]=undefined;emval_freelist.push(handle)}};var count_emval_handles=()=>emval_handles.length/2-5-emval_freelist.length;var init_emval=()=>{emval_handles.push(0,1,undefined,1,null,1,true,1,false,1);Module["count_emval_handles"]=count_emval_handles};var Emval={toValue:handle=>{if(!handle){throwBindingError("Cannot use deleted val. 
handle = "+handle)}return emval_handles[handle]},toHandle:value=>{switch(value){case undefined:return 2;case null:return 4;case true:return 6;case false:return 8;default:{const handle=emval_freelist.pop()||emval_handles.length;emval_handles[handle]=value;emval_handles[handle+1]=1;return handle}}}};var EmValType={name:"emscripten::val",fromWireType:handle=>{var rv=Emval.toValue(handle);__emval_decref(handle);return rv},toWireType:(destructors,value)=>Emval.toHandle(value),argPackAdvance:GenericWireTypeSize,readValueFromPointer:readPointer,destructorFunction:null};var __embind_register_emval=rawType=>registerType(rawType,EmValType);var embindRepr=v=>{if(v===null){return"null"}var t=typeof v;if(t==="object"||t==="array"||t==="function"){return v.toString()}else{return""+v}};var floatReadValueFromPointer=(name,width)=>{switch(width){case 4:return function(pointer){return this["fromWireType"](HEAPF32[pointer>>2])};case 8:return function(pointer){return this["fromWireType"](HEAPF64[pointer>>3])};default:throw new TypeError(`invalid float width (${width}): ${name}`)}};var __embind_register_float=(rawType,name,size)=>{name=readLatin1String(name);registerType(rawType,{name:name,fromWireType:value=>value,toWireType:(destructors,value)=>value,argPackAdvance:GenericWireTypeSize,readValueFromPointer:floatReadValueFromPointer(name,size),destructorFunction:null})};var integerReadValueFromPointer=(name,width,signed)=>{switch(width){case 1:return signed?pointer=>HEAP8[pointer]:pointer=>HEAPU8[pointer];case 2:return signed?pointer=>HEAP16[pointer>>1]:pointer=>HEAPU16[pointer>>1];case 4:return signed?pointer=>HEAP32[pointer>>2]:pointer=>HEAPU32[pointer>>2];default:throw new TypeError(`invalid integer width (${width}): ${name}`)}};var __embind_register_integer=(primitiveType,name,size,minRange,maxRange)=>{name=readLatin1String(name);if(maxRange===-1){maxRange=4294967295}var fromWireType=value=>value;if(minRange===0){var bitshift=32-8*size;fromWireType=value=>value<>>bitshift}var isUnsignedType=name.includes("unsigned");var checkAssertions=(value,toTypeName)=>{};var toWireType;if(isUnsignedType){toWireType=function(destructors,value){checkAssertions(value,this.name);return value>>>0}}else{toWireType=function(destructors,value){checkAssertions(value,this.name);return value}}registerType(primitiveType,{name:name,fromWireType:fromWireType,toWireType:toWireType,argPackAdvance:GenericWireTypeSize,readValueFromPointer:integerReadValueFromPointer(name,size,minRange!==0),destructorFunction:null})};var __embind_register_memory_view=(rawType,dataTypeIndex,name)=>{var typeMapping=[Int8Array,Uint8Array,Int16Array,Uint16Array,Int32Array,Uint32Array,Float32Array,Float64Array];var TA=typeMapping[dataTypeIndex];function decodeMemoryView(handle){var size=HEAPU32[handle>>2];var data=HEAPU32[handle+4>>2];return new TA(HEAP8.buffer,data,size)}name=readLatin1String(name);registerType(rawType,{name:name,fromWireType:decodeMemoryView,argPackAdvance:GenericWireTypeSize,readValueFromPointer:decodeMemoryView},{ignoreDuplicateRegistrations:true})};var stringToUTF8Array=(str,heap,outIdx,maxBytesToWrite)=>{if(!(maxBytesToWrite>0))return 0;var startIdx=outIdx;var endIdx=outIdx+maxBytesToWrite-1;for(var i=0;i=55296&&u<=57343){var u1=str.charCodeAt(++i);u=65536+((u&1023)<<10)|u1&1023}if(u<=127){if(outIdx>=endIdx)break;heap[outIdx++]=u}else if(u<=2047){if(outIdx+1>=endIdx)break;heap[outIdx++]=192|u>>6;heap[outIdx++]=128|u&63}else 
if(u<=65535){if(outIdx+2>=endIdx)break;heap[outIdx++]=224|u>>12;heap[outIdx++]=128|u>>6&63;heap[outIdx++]=128|u&63}else{if(outIdx+3>=endIdx)break;heap[outIdx++]=240|u>>18;heap[outIdx++]=128|u>>12&63;heap[outIdx++]=128|u>>6&63;heap[outIdx++]=128|u&63}}heap[outIdx]=0;return outIdx-startIdx};var stringToUTF8=(str,outPtr,maxBytesToWrite)=>stringToUTF8Array(str,HEAPU8,outPtr,maxBytesToWrite);var lengthBytesUTF8=str=>{var len=0;for(var i=0;i=55296&&c<=57343){len+=4;++i}else{len+=3}}return len};var UTF8Decoder=typeof TextDecoder!="undefined"?new TextDecoder("utf8"):undefined;var UTF8ArrayToString=(heapOrArray,idx,maxBytesToRead)=>{var endIdx=idx+maxBytesToRead;var endPtr=idx;while(heapOrArray[endPtr]&&!(endPtr>=endIdx))++endPtr;if(endPtr-idx>16&&heapOrArray.buffer&&UTF8Decoder){return UTF8Decoder.decode(heapOrArray.subarray(idx,endPtr))}var str="";while(idx>10,56320|ch&1023)}}return str};var UTF8ToString=(ptr,maxBytesToRead)=>ptr?UTF8ArrayToString(HEAPU8,ptr,maxBytesToRead):"";var __embind_register_std_string=(rawType,name)=>{name=readLatin1String(name);var stdStringIsUTF8=name==="std::string";registerType(rawType,{name:name,fromWireType(value){var length=HEAPU32[value>>2];var payload=value+4;var str;if(stdStringIsUTF8){var decodeStartPtr=payload;for(var i=0;i<=length;++i){var currentBytePtr=payload+i;if(i==length||HEAPU8[currentBytePtr]==0){var maxRead=currentBytePtr-decodeStartPtr;var stringSegment=UTF8ToString(decodeStartPtr,maxRead);if(str===undefined){str=stringSegment}else{str+=String.fromCharCode(0);str+=stringSegment}decodeStartPtr=currentBytePtr+1}}}else{var a=new Array(length);for(var i=0;i>2]=length;if(stdStringIsUTF8&&valueIsOfTypeString){stringToUTF8(value,ptr,length+1)}else{if(valueIsOfTypeString){for(var i=0;i255){_free(ptr);throwBindingError("String has UTF-16 code units that do not fit in 8 bits")}HEAPU8[ptr+i]=charCode}}else{for(var i=0;i{var endPtr=ptr;var idx=endPtr>>1;var maxIdx=idx+maxBytesToRead/2;while(!(idx>=maxIdx)&&HEAPU16[idx])++idx;endPtr=idx<<1;if(endPtr-ptr>32&&UTF16Decoder)return UTF16Decoder.decode(HEAPU8.subarray(ptr,endPtr));var str="";for(var i=0;!(i>=maxBytesToRead/2);++i){var codeUnit=HEAP16[ptr+i*2>>1];if(codeUnit==0)break;str+=String.fromCharCode(codeUnit)}return str};var stringToUTF16=(str,outPtr,maxBytesToWrite)=>{maxBytesToWrite??=2147483647;if(maxBytesToWrite<2)return 0;maxBytesToWrite-=2;var startPtr=outPtr;var numCharsToWrite=maxBytesToWrite>1]=codeUnit;outPtr+=2}HEAP16[outPtr>>1]=0;return outPtr-startPtr};var lengthBytesUTF16=str=>str.length*2;var UTF32ToString=(ptr,maxBytesToRead)=>{var i=0;var str="";while(!(i>=maxBytesToRead/4)){var utf32=HEAP32[ptr+i*4>>2];if(utf32==0)break;++i;if(utf32>=65536){var ch=utf32-65536;str+=String.fromCharCode(55296|ch>>10,56320|ch&1023)}else{str+=String.fromCharCode(utf32)}}return str};var stringToUTF32=(str,outPtr,maxBytesToWrite)=>{maxBytesToWrite??=2147483647;if(maxBytesToWrite<4)return 0;var startPtr=outPtr;var endPtr=startPtr+maxBytesToWrite-4;for(var i=0;i=55296&&codeUnit<=57343){var trailSurrogate=str.charCodeAt(++i);codeUnit=65536+((codeUnit&1023)<<10)|trailSurrogate&1023}HEAP32[outPtr>>2]=codeUnit;outPtr+=4;if(outPtr+4>endPtr)break}HEAP32[outPtr>>2]=0;return outPtr-startPtr};var lengthBytesUTF32=str=>{var len=0;for(var i=0;i=55296&&codeUnit<=57343)++i;len+=4}return len};var __embind_register_std_wstring=(rawType,charSize,name)=>{name=readLatin1String(name);var 
decodeString,encodeString,readCharAt,lengthBytesUTF;if(charSize===2){decodeString=UTF16ToString;encodeString=stringToUTF16;lengthBytesUTF=lengthBytesUTF16;readCharAt=pointer=>HEAPU16[pointer>>1]}else if(charSize===4){decodeString=UTF32ToString;encodeString=stringToUTF32;lengthBytesUTF=lengthBytesUTF32;readCharAt=pointer=>HEAPU32[pointer>>2]}registerType(rawType,{name:name,fromWireType:value=>{var length=HEAPU32[value>>2];var str;var decodeStartPtr=value+4;for(var i=0;i<=length;++i){var currentBytePtr=value+4+i*charSize;if(i==length||readCharAt(currentBytePtr)==0){var maxReadBytes=currentBytePtr-decodeStartPtr;var stringSegment=decodeString(decodeStartPtr,maxReadBytes);if(str===undefined){str=stringSegment}else{str+=String.fromCharCode(0);str+=stringSegment}decodeStartPtr=currentBytePtr+charSize}}_free(value);return str},toWireType:(destructors,value)=>{if(!(typeof value=="string")){throwBindingError(`Cannot pass non-string to C++ string type ${name}`)}var length=lengthBytesUTF(value);var ptr=_malloc(4+length+charSize);HEAPU32[ptr>>2]=length/charSize;encodeString(value,ptr+4,length+charSize);if(destructors!==null){destructors.push(_free,ptr)}return ptr},argPackAdvance:GenericWireTypeSize,readValueFromPointer:readPointer,destructorFunction(ptr){_free(ptr)}})};var __embind_register_void=(rawType,name)=>{name=readLatin1String(name);registerType(rawType,{isVoid:true,name:name,argPackAdvance:0,fromWireType:()=>undefined,toWireType:(destructors,o)=>undefined})};var __emscripten_memcpy_js=(dest,src,num)=>HEAPU8.copyWithin(dest,src,src+num);var requireRegisteredType=(rawType,humanName)=>{var impl=registeredTypes[rawType];if(undefined===impl){throwBindingError(`${humanName} has unknown type ${getTypeName(rawType)}`)}return impl};var emval_returnValue=(returnType,destructorsRef,handle)=>{var destructors=[];var result=returnType["toWireType"](destructors,handle);if(destructors.length){HEAPU32[destructorsRef>>2]=Emval.toHandle(destructors)}return result};var __emval_as=(handle,returnType,destructorsRef)=>{handle=Emval.toValue(handle);returnType=requireRegisteredType(returnType,"emval::as");return emval_returnValue(returnType,destructorsRef,handle)};var emval_symbols={};var getStringOrSymbol=address=>{var symbol=emval_symbols[address];if(symbol===undefined){return readLatin1String(address)}return symbol};var emval_methodCallers=[];var __emval_call_method=(caller,objHandle,methodName,destructorsRef,args)=>{caller=emval_methodCallers[caller];objHandle=Emval.toValue(objHandle);methodName=getStringOrSymbol(methodName);return caller(objHandle,objHandle[methodName],destructorsRef,args)};var emval_addMethodCaller=caller=>{var id=emval_methodCallers.length;emval_methodCallers.push(caller);return id};var emval_lookupTypes=(argCount,argTypes)=>{var a=new Array(argCount);for(var i=0;i>2],"parameter "+i)}return a};var reflectConstruct=Reflect.construct;var __emval_get_method_caller=(argCount,argTypes,kind)=>{var types=emval_lookupTypes(argCount,argTypes);var retType=types.shift();argCount--;var functionBody=`return function (obj, func, destructorsRef, args) {\n`;var offset=0;var argsList=[];if(kind===0){argsList.push("obj")}var params=["retType"];var args=[retType];for(var i=0;it.name).join(", ")}) => ${retType.name}>`;return emval_addMethodCaller(createNamedFunction(functionName,invokerFunction))};var __emval_get_property=(handle,key)=>{handle=Emval.toValue(handle);key=Emval.toValue(key);return Emval.toHandle(handle[key])};var __emval_incref=handle=>{if(handle>9){emval_handles[handle+1]+=1}};var 
__emval_new_cstring=v=>Emval.toHandle(getStringOrSymbol(v));var __emval_run_destructors=handle=>{var destructors=Emval.toValue(handle);runDestructors(destructors);__emval_decref(handle)};var __emval_take_value=(type,arg)=>{type=requireRegisteredType(type,"_emval_take_value");var v=type["readValueFromPointer"](arg);return Emval.toHandle(v)};var getHeapMax=()=>1073741824;var growMemory=size=>{var b=wasmMemory.buffer;var pages=(size-b.byteLength+65535)/65536;try{wasmMemory.grow(pages);updateMemoryViews();return 1}catch(e){}};var _emscripten_resize_heap=requestedSize=>{var oldSize=HEAPU8.length;requestedSize>>>=0;var maxHeapSize=getHeapMax();if(requestedSize>maxHeapSize){return false}var alignUp=(x,multiple)=>x+(multiple-x%multiple)%multiple;for(var cutDown=1;cutDown<=4;cutDown*=2){var overGrownHeapSize=oldSize*(1+.2/cutDown);overGrownHeapSize=Math.min(overGrownHeapSize,requestedSize+100663296);var newSize=Math.min(maxHeapSize,alignUp(Math.max(requestedSize,overGrownHeapSize),65536));var replacement=growMemory(newSize);if(replacement){return true}}return false};embind_init_charCodes();BindingError=Module["BindingError"]=class BindingError extends Error{constructor(message){super(message);this.name="BindingError"}};InternalError=Module["InternalError"]=class InternalError extends Error{constructor(message){super(message);this.name="InternalError"}};init_ClassHandle();init_embind();init_RegisteredPointer();UnboundTypeError=Module["UnboundTypeError"]=extendError(Error,"UnboundTypeError");init_emval();var wasmImports={n:___cxa_throw,t:__abort_js,s:__embind_register_bigint,x:__embind_register_bool,r:__embind_register_class,q:__embind_register_class_constructor,k:__embind_register_class_function,w:__embind_register_emval,m:__embind_register_float,c:__embind_register_integer,a:__embind_register_memory_view,l:__embind_register_std_string,f:__embind_register_std_wstring,y:__embind_register_void,v:__emscripten_memcpy_js,h:__emval_as,o:__emval_call_method,b:__emval_decref,p:__emval_get_method_caller,i:__emval_get_property,g:__emval_incref,j:__emval_new_cstring,d:__emval_run_destructors,e:__emval_take_value,u:_emscripten_resize_heap};var wasmExports=createWasm();var ___wasm_call_ctors=()=>(___wasm_call_ctors=wasmExports["A"])();var ___getTypeName=a0=>(___getTypeName=wasmExports["B"])(a0);var _malloc=a0=>(_malloc=wasmExports["D"])(a0);var _free=a0=>(_free=wasmExports["E"])(a0);var ___cxa_is_pointer_type=a0=>(___cxa_is_pointer_type=wasmExports["F"])(a0);Module["addOnPostRun"]=addOnPostRun;var calledRun;dependenciesFulfilled=function runCaller(){if(!calledRun)run();if(!calledRun)dependenciesFulfilled=runCaller};function run(){if(runDependencies>0){return}preRun();if(runDependencies>0){return}function doRun(){if(calledRun)return;calledRun=true;Module["calledRun"]=true;if(ABORT)return;initRuntime();readyPromiseResolve(Module);if(Module["onRuntimeInitialized"])Module["onRuntimeInitialized"]();postRun()}if(Module["setStatus"]){Module["setStatus"]("Running...");setTimeout(function(){setTimeout(function(){Module["setStatus"]("")},1);doRun()},1)}else{doRun()}}if(Module["preInit"]){if(typeof Module["preInit"]=="function")Module["preInit"]=[Module["preInit"]];while(Module["preInit"].length>0){Module["preInit"].pop()()}}run();moduleRtn=readyPromise; return moduleRtn; } ); })(); export default Module; ================================================ FILE: viser/src/viser/client/src/Splatting/WasmSorter/build.sh ================================================ #!/usr/bin/env bash emcc --bind -O3 sorter.cpp -o Sorter.mjs 
-s WASM=1 -s NO_EXIT_RUNTIME=1 -s "EXPORTED_RUNTIME_METHODS=['addOnPostRun']" -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=1GB -s STACK_SIZE=2097152 -msimd128;


================================================
FILE: viser/src/viser/client/src/Splatting/WasmSorter/sorter.cpp
================================================
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <emscripten/bind.h>
#include <emscripten/val.h>
#include <wasm_simd128.h>

/** SIMD dot product between two 4D vectors. */
__attribute__((always_inline)) inline float
dot_f32x4(const v128_t &a, const v128_t &b) {
    v128_t product = wasm_f32x4_mul(a, b);
    v128_t temp = wasm_f32x4_add(
        product, wasm_i32x4_shuffle(product, product, 1, 0, 3, 2)
    );
    v128_t tmp =
        wasm_f32x4_add(temp, wasm_i32x4_shuffle(temp, temp, 2, 3, 0, 1));
    return wasm_f32x4_extract_lane(tmp, 0);
}

// Function to find the minimum value across a v128_t i32x4 vector.
__attribute__((always_inline)) inline int32_t min_i32x4(v128_t vector) {
    int32_t elem0 = wasm_i32x4_extract_lane(vector, 0);
    int32_t elem1 = wasm_i32x4_extract_lane(vector, 1);
    int32_t elem2 = wasm_i32x4_extract_lane(vector, 2);
    int32_t elem3 = wasm_i32x4_extract_lane(vector, 3);
    return std::min({elem0, elem1, elem2, elem3});
}

// Function to find the maximum value across a v128_t i32x4 vector.
__attribute__((always_inline)) inline int32_t max_i32x4(v128_t vector) {
    int32_t elem0 = wasm_i32x4_extract_lane(vector, 0);
    int32_t elem1 = wasm_i32x4_extract_lane(vector, 1);
    int32_t elem2 = wasm_i32x4_extract_lane(vector, 2);
    int32_t elem3 = wasm_i32x4_extract_lane(vector, 3);
    return std::max({elem0, elem1, elem2, elem3});
}

class Sorter {
    std::vector<v128_t> centers_homog; // Centers as homogeneous coordinates.
    std::vector<uint32_t> group_indices;
    std::vector<uint32_t> sorted_indices;

  public:
    Sorter(
        const emscripten::val &buffer,
        const emscripten::val &group_indices_val
    ) {
        const std::vector<uint32_t> bufferVec =
            emscripten::convertJSArrayToNumberVector<uint32_t>(buffer);
        const float *floatBuffer =
            reinterpret_cast<const float *>(bufferVec.data());
        const int32_t num_gaussians = bufferVec.size() / 8;
        sorted_indices.resize(num_gaussians);
        centers_homog.resize(num_gaussians);
        for (int32_t i = 0; i < num_gaussians; i++) {
            centers_homog[i] = wasm_f32x4_make(
                floatBuffer[i * 8 + 0],
                floatBuffer[i * 8 + 1],
                floatBuffer[i * 8 + 2],
                1.0
            );
        }
        group_indices = emscripten::convertJSArrayToNumberVector<uint32_t>(
            group_indices_val
        );
    };

    // Run sorting using the newest view projection matrix. Mutates internal
    // buffers.
    emscripten::val sort(const emscripten::val &Tz_cam_groups_val) {
        const auto Tz_cam_groups_buffer =
            emscripten::convertJSArrayToNumberVector<float>(Tz_cam_groups_val);
        const int32_t num_gaussians = centers_homog.size();

        // We do a 16-bit counting sort. This is mostly translated from Kevin
        // Kwok's Javascript implementation:
        // https://github.com/antimatter15/splat/blob/main/main.js
        //
        // Note: we want to sort from minimum Z (high depth) to maximum Z (low
        // depth).
        const int32_t padded_length = std::ceil(num_gaussians / 4.0);
        std::vector<v128_t> gaussian_zs(padded_length);
        std::array<int32_t, 256 * 256> counts0({0});
        std::array<int32_t, 256 * 256> starts0({0});

        const int32_t num_groups = Tz_cam_groups_buffer.size() / 4;
        std::vector<v128_t> Tz_cam_groups(num_groups);
        const v128_t row3 = wasm_f32x4_make(0.0, 0.0, 0.0, 1.0);
        for (int32_t i = 0; i < num_groups; i++) {
            Tz_cam_groups[i] = wasm_v128_load(&Tz_cam_groups_buffer[i * 4]);
        }

        v128_t min_z_i32x4;
        v128_t max_z_i32x4;
        const v128_t splat4096 = wasm_f32x4_splat(4096.0);
        for (int32_t i = 0; i < padded_length; i++) {
            // This should get inlined.
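            // Each iteration handles four gaussians at once: dot each
            // homogeneous center against the camera-space Z row of its
            // group's transform, then pack the four depths into a single
            // f32x4 value so the min/max tracking and binning below stay
            // vectorized.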
            int32_t gaussianIndex = i * 4;
            const float z0 = dot_f32x4(
                Tz_cam_groups[group_indices[gaussianIndex]],
                centers_homog[gaussianIndex]
            );
            gaussianIndex++;
            const float z1 = dot_f32x4(
                Tz_cam_groups[group_indices[gaussianIndex]],
                centers_homog[gaussianIndex]
            );
            gaussianIndex++;
            const float z2 = dot_f32x4(
                Tz_cam_groups[group_indices[gaussianIndex]],
                centers_homog[gaussianIndex]
            );
            gaussianIndex++;
            const float z3 = dot_f32x4(
                Tz_cam_groups[group_indices[gaussianIndex]],
                centers_homog[gaussianIndex]
            );
            const v128_t cam_z = wasm_f32x4_make(z0, z1, z2, z3);

            // OpenGL camera convention: -Z is forward.
            const v128_t depth = wasm_f32x4_neg(cam_z);
            const v128_t z_int =
                wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_mul(cam_z, splat4096));
            gaussian_zs[i] = z_int;

            if (i == 0) {
                min_z_i32x4 = z_int;
                max_z_i32x4 = z_int;
            } else {
                // Currently, we incorrectly include padding elements in the
                // min/max.
                min_z_i32x4 = wasm_i32x4_min(min_z_i32x4, z_int);
                max_z_i32x4 = wasm_i32x4_max(max_z_i32x4, z_int);
            }
        }
        min_z_i32x4 = wasm_i32x4_splat(min_i32x4(min_z_i32x4));
        max_z_i32x4 = wasm_i32x4_splat(max_i32x4(max_z_i32x4));

        const v128_t z_inv = wasm_f32x4_div(
            wasm_f32x4_splat(256 * 256 - 1),
            wasm_f32x4_add(
                wasm_f32x4_convert_i32x4(
                    wasm_i32x4_sub(max_z_i32x4, min_z_i32x4)
                ),
                wasm_f32x4_splat(1e-5f)
            )
        );
        for (int32_t i = 0; i < padded_length; i++) {
            const v128_t z_bin = wasm_i32x4_trunc_sat_f32x4(wasm_f32x4_mul(
                wasm_f32x4_convert_i32x4(
                    wasm_i32x4_sub(gaussian_zs[i], min_z_i32x4)
                ),
                z_inv
            ));
            gaussian_zs[i] = z_bin;

            counts0[wasm_i32x4_extract_lane(z_bin, 0)]++;
            if (i == padded_length - 1) {
                if (i * 4 + 1 < num_gaussians)
                    counts0[wasm_i32x4_extract_lane(z_bin, 1)]++;
                if (i * 4 + 2 < num_gaussians)
                    counts0[wasm_i32x4_extract_lane(z_bin, 2)]++;
                if (i * 4 + 3 < num_gaussians)
                    counts0[wasm_i32x4_extract_lane(z_bin, 3)]++;
            } else {
                counts0[wasm_i32x4_extract_lane(z_bin, 1)]++;
                counts0[wasm_i32x4_extract_lane(z_bin, 2)]++;
                counts0[wasm_i32x4_extract_lane(z_bin, 3)]++;
            }
        }
        for (int32_t i = 1; i < 256 * 256; i++) {
            starts0[i] = starts0[i - 1] + counts0[i - 1];
        }

        // Update and return sorted indices.
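        // Scatter pass of the counting sort: starts0[b] holds the next free
        // output slot for depth bin b, so post-incrementing it as we write
        // keeps gaussians within the same bin in their original order.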
        for (int32_t i = 0; i < num_gaussians; i++)
            sorted_indices[starts0[((int32_t *)&gaussian_zs[0])[i]]++] = i;

        return emscripten::val(emscripten::typed_memory_view(
            sorted_indices.size(), &(sorted_indices[0])
        ));
    }
};

EMSCRIPTEN_BINDINGS(c) {
    emscripten::class_<Sorter>("Sorter")
        .constructor<emscripten::val, emscripten::val>()
        .function("sort", &Sorter::sort, emscripten::allow_raw_pointers());
};


================================================
FILE: viser/src/viser/client/src/ThreeAssets.tsx
================================================
import { Instance, Instances, shaderMaterial } from "@react-three/drei";
import { createPortal, useFrame, useThree } from "@react-three/fiber";
import { Outlines } from "./Outlines";
import React from "react";
import * as THREE from "three";
import { GLTF, GLTFLoader } from "three/examples/jsm/loaders/GLTFLoader";
import {
  MeshBasicMaterial,
  MeshDepthMaterial,
  MeshDistanceMaterial,
  MeshLambertMaterial,
  MeshMatcapMaterial,
  MeshNormalMaterial,
  MeshPhongMaterial,
  MeshPhysicalMaterial,
  MeshStandardMaterial,
  MeshToonMaterial,
  ShadowMaterial,
  SpriteMaterial,
  RawShaderMaterial,
  ShaderMaterial,
  PointsMaterial,
  LineBasicMaterial,
  LineDashedMaterial,
} from "three";
import { DRACOLoader } from "three/examples/jsm/loaders/DRACOLoader";

type AllPossibleThreeJSMaterials =
  | MeshBasicMaterial
  | MeshDepthMaterial
  | MeshDistanceMaterial
  | MeshLambertMaterial
  | MeshMatcapMaterial
  | MeshNormalMaterial
  | MeshPhongMaterial
  | MeshPhysicalMaterial
  | MeshStandardMaterial
  | MeshToonMaterial
  | ShadowMaterial
  | SpriteMaterial
  | RawShaderMaterial
  | ShaderMaterial
  | PointsMaterial
  | LineBasicMaterial
  | LineDashedMaterial;

const originGeom = new THREE.SphereGeometry(1.0);
const originMaterial = new THREE.MeshBasicMaterial({ color: 0xecec00 });

const PointCloudMaterial = /* @__PURE__ */ shaderMaterial(
  { scale: 1.0, point_ball_norm: 0.0 },
  `
  varying vec3 vPosition;
  varying vec3 vColor; // in the vertex shader
  uniform float scale;

  void main() {
      vPosition = position;
      vColor = color;
      vec4 world_pos = modelViewMatrix * vec4(position, 1.0);
      gl_Position = projectionMatrix * world_pos;
      gl_PointSize = (scale / -world_pos.z);
  }
  `,
  `varying vec3 vPosition;
  varying vec3 vColor;
  uniform float point_ball_norm;

  void main() {
      if (point_ball_norm < 1000.0) {
          float r = pow(
            pow(abs(gl_PointCoord.x - 0.5), point_ball_norm)
              + pow(abs(gl_PointCoord.y - 0.5), point_ball_norm),
            1.0 / point_ball_norm);
          if (r > 0.5) discard;
      }
      gl_FragColor = vec4(vColor, 1.0);
  }
  `,
);

export const PointCloud = React.forwardRef<
  THREE.Points,
  {
    pointSize: number;
    /** We visualize each point as a 2D ball, which is defined by some norm. */
    pointBallNorm: number;
    points: Float32Array;
    colors: Float32Array;
  }
>(function PointCloud(props, ref) {
  const getThreeState = useThree((state) => state.get);

  const geometry = new THREE.BufferGeometry();
  geometry.setAttribute(
    "position",
    new THREE.Float32BufferAttribute(props.points, 3),
  );
  geometry.computeBoundingSphere();
  geometry.setAttribute(
    "color",
    new THREE.Float32BufferAttribute(props.colors, 3),
  );

  const [material] = React.useState(
    () => new PointCloudMaterial({ vertexColors: true }),
  );
  material.uniforms.scale.value = 10.0;
  material.uniforms.point_ball_norm.value = props.pointBallNorm;

  React.useEffect(() => {
    return () => {
      material.dispose();
      geometry.dispose();
    };
  });

  const rendererSize = new THREE.Vector2();
  useFrame(() => {
    // Match point scale to behavior of THREE.PointsMaterial().
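    // The vertex shader computes gl_PointSize = scale / -world_pos.z, so the
    // `scale` uniform set below is the point's on-screen size at unit depth:
    // (pointSize / tan(fov / 2)) times the canvas height in physical pixels
    // (logical height * device pixel ratio).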
if (material === undefined) return; // point px height / actual height = point meters height / frustum meters height // frustum meters height = math.tan(fov / 2.0) * z // point px height = (point meters height / math.tan(fov / 2.0) * actual height) / z material.uniforms.scale.value = (props.pointSize / Math.tan( (((getThreeState().camera as THREE.PerspectiveCamera).fov / 180.0) * Math.PI) / 2.0, )) * getThreeState().gl.getSize(rendererSize).height * getThreeState().gl.getPixelRatio(); }); return ; }); /** Component for rendering the contents of GLB files. */ export const GlbAsset = React.forwardRef< THREE.Group, { glb_data: Uint8Array; scale: number } >(function GlbAsset({ glb_data, scale }, ref) { // We track both the GLTF asset itself and all meshes within it. Meshes are // used for hover effects. const [gltf, setGltf] = React.useState(); const [meshes, setMeshes] = React.useState([]); // glTF/GLB files support animations. const mixerRef = React.useRef(null); React.useEffect(() => { const loader = new GLTFLoader(); // We use a CDN for Draco. We could move this locally if we want to use Viser offline. const dracoLoader = new DRACOLoader(); dracoLoader.setDecoderPath("https://www.gstatic.com/draco/v1/decoders/"); loader.setDRACOLoader(dracoLoader); loader.parse( glb_data.buffer, "", (gltf) => { if (gltf.animations && gltf.animations.length) { mixerRef.current = new THREE.AnimationMixer(gltf.scene); gltf.animations.forEach((clip) => { mixerRef.current!.clipAction(clip).play(); }); } const meshes: THREE.Mesh[] = []; gltf?.scene.traverse((obj) => { if (obj instanceof THREE.Mesh) meshes.push(obj); }); setMeshes(meshes); setGltf(gltf); }, (error) => { console.log("Error loading GLB!"); console.log(error); }, ); return () => { if (mixerRef.current) mixerRef.current.stopAllAction(); function disposeNode(node: any) { if (node instanceof THREE.Mesh) { if (node.geometry) { node.geometry.dispose(); } if (node.material) { if (Array.isArray(node.material)) { node.material.forEach((material) => { disposeMaterial(material); }); } else { disposeMaterial(node.material); } } } } function disposeMaterial(material: AllPossibleThreeJSMaterials) { if ("map" in material) material.map?.dispose(); if ("lightMap" in material) material.lightMap?.dispose(); if ("bumpMap" in material) material.bumpMap?.dispose(); if ("normalMap" in material) material.normalMap?.dispose(); if ("specularMap" in material) material.specularMap?.dispose(); if ("envMap" in material) material.envMap?.dispose(); if ("alphaMap" in material) material.alphaMap?.dispose(); if ("aoMap" in material) material.aoMap?.dispose(); if ("displacementMap" in material) material.displacementMap?.dispose(); if ("emissiveMap" in material) material.emissiveMap?.dispose(); if ("gradientMap" in material) material.gradientMap?.dispose(); if ("metalnessMap" in material) material.metalnessMap?.dispose(); if ("roughnessMap" in material) material.roughnessMap?.dispose(); material.dispose(); // disposes any programs associated with the material } // Attempt to free resources. gltf?.scene.traverse(disposeNode); }; }, [glb_data]); useFrame((_, delta) => { if (mixerRef.current) { mixerRef.current.update(delta); } }); return ( {gltf === undefined ? null : ( <> {meshes.map((mesh) => createPortal(, mesh), )} )} ); }); /** Helper for adding coordinate frames as scene nodes. 
*/ export const CoordinateFrame = React.forwardRef< THREE.Group, { showAxes?: boolean; axesLength?: number; axesRadius?: number; originRadius?: number; } >(function CoordinateFrame( { showAxes = true, axesLength = 0.5, axesRadius = 0.0125, originRadius = undefined, }, ref, ) { originRadius = originRadius ?? axesRadius * 2; return ( {showAxes && ( <> )} ); }); /** Helper for adding batched/instanced coordinate frames as scene nodes. */ export const InstancedAxes = React.forwardRef< THREE.Group, { wxyzsBatched: Float32Array; positionsBatched: Float32Array; axes_length?: number; axes_radius?: number; } >(function InstancedAxes( { wxyzsBatched: instance_wxyzs, positionsBatched: instance_positions, axes_length = 0.5, axes_radius = 0.0125, }, ref, ) { const axesRef = React.useRef(null); const cylinderGeom = new THREE.CylinderGeometry( axes_radius, axes_radius, axes_length, 16, ); const material = new MeshBasicMaterial(); // Dispose when done. React.useEffect(() => { return () => { cylinderGeom.dispose(); material.dispose(); }; }); // Update instance matrices and colors. React.useEffect(() => { // Pre-allocate to avoid garbage collector from running during loop. const T_world_frame = new THREE.Matrix4(); const T_world_framex = new THREE.Matrix4(); const T_world_framey = new THREE.Matrix4(); const T_world_framez = new THREE.Matrix4(); const T_frame_framex = new THREE.Matrix4() .makeRotationFromEuler(new THREE.Euler(0.0, 0.0, (3.0 * Math.PI) / 2.0)) .setPosition(0.5 * axes_length, 0.0, 0.0); const T_frame_framey = new THREE.Matrix4() .makeRotationFromEuler(new THREE.Euler(0.0, 0.0, 0.0)) .setPosition(0.0, 0.5 * axes_length, 0.0); const T_frame_framez = new THREE.Matrix4() .makeRotationFromEuler(new THREE.Euler(Math.PI / 2.0, 0.0, 0.0)) .setPosition(0.0, 0.0, 0.5 * axes_length); const tmpQuat = new THREE.Quaternion(); const red = new THREE.Color(0xcc0000); const green = new THREE.Color(0x00cc00); const blue = new THREE.Color(0x0000cc); for (let i = 0; i < instance_wxyzs.length / 4; i++) { T_world_frame.makeRotationFromQuaternion( tmpQuat.set( instance_wxyzs[i * 4 + 1], instance_wxyzs[i * 4 + 2], instance_wxyzs[i * 4 + 3], instance_wxyzs[i * 4 + 0], ), ).setPosition( instance_positions[i * 3 + 0], instance_positions[i * 3 + 1], instance_positions[i * 3 + 2], ); T_world_framex.copy(T_world_frame).multiply(T_frame_framex); T_world_framey.copy(T_world_frame).multiply(T_frame_framey); T_world_framez.copy(T_world_frame).multiply(T_frame_framez); axesRef.current!.setMatrixAt(i * 3 + 0, T_world_framex); axesRef.current!.setMatrixAt(i * 3 + 1, T_world_framey); axesRef.current!.setMatrixAt(i * 3 + 2, T_world_framez); axesRef.current!.setColorAt(i * 3 + 0, red); axesRef.current!.setColorAt(i * 3 + 1, green); axesRef.current!.setColorAt(i * 3 + 2, blue); } axesRef.current!.instanceMatrix.needsUpdate = true; axesRef.current!.instanceColor!.needsUpdate = true; }, [instance_wxyzs, instance_positions]); return ( ); }); /** Helper for visualizing camera frustums. 
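 *
 * Frustum edge lengths are derived from `fov` and `aspect` (y = tan(fov / 2),
 * x = y * aspect, z = 1) and then normalized by Math.cbrt((x * y * z) / 3),
 * so the overall frustum volume stays constant before the `scale` prop is
 * applied.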
*/ export const CameraFrustum = React.forwardRef< THREE.Group, { fov: number; aspect: number; scale: number; color: number; thickness?: number; // Added thickness property image?: THREE.Texture; } >(function CameraFrustum(props, ref) { let y = Math.tan(props.fov / 2.0); let x = y * props.aspect; let z = 1.0; const volumeScale = Math.cbrt((x * y * z) / 3.0); x /= volumeScale; y /= volumeScale; z /= volumeScale; function scaledLineSegments(points: [number, number, number][], thickness = 1.0) { points = points.map((xyz) => [xyz[0] * x, xyz[1] * y, xyz[2] * z]); return [...Array(points.length - 1).keys()].map((i) => ( )); } const lineThickness = props.thickness || 1.0; // Default to 1.0 if not provided return ( {scaledLineSegments( [ // Rectangle. [-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1], [-1, -1, 1], ], lineThickness // Pass thickness to scaledLineSegments )} {scaledLineSegments( [ // Lines to origin. [-1, -1, 1], [0, 0, 0], [1, -1, 1], ], lineThickness // Pass thickness to scaledLineSegments )} {scaledLineSegments( [ // Lines to origin. [-1, 1, 1], [0, 0, 0], [1, 1, 1], ], lineThickness // Pass thickness to scaledLineSegments )} {scaledLineSegments( [ // Up direction. [0.0, -1.2, 1.0], [0.0, -0.9, 1.0], ], lineThickness // Pass thickness to scaledLineSegments )} {props.image && ( )} ); }); function LineSegmentInstance(props: { start: THREE.Vector3; end: THREE.Vector3; color: number; thickness?: number; // Optional thickness property }) { const desiredDirection = new THREE.Vector3() .subVectors(props.end, props.start) .normalize(); const canonicalDirection = new THREE.Vector3(0.0, 1.0, 0.0); const orientation = new THREE.Quaternion().setFromUnitVectors( canonicalDirection, desiredDirection, ); const length = props.start.distanceTo(props.end); const midpoint = new THREE.Vector3() .addVectors(props.start, props.end) .divideScalar(2.0); const thickness = props.thickness || 1.0; // Default to 1.0 if not provided return ( ); } export const HoverableContext = React.createContext | null>(null); /** Outlines object, which should be placed as a child of all meshes that might * be clickable. */ export function OutlinesIfHovered( props: { alwaysMounted?: boolean; creaseAngle?: number } = { // Can be set to true for objects like meshes which may be slow to mount. // It seems better to set to False for instanced meshes, there may be some // drei or fiber-related race conditions... alwaysMounted: false, // Some thing just look better with no creasing, like camera frustum objects. creaseAngle: Math.PI, }, ) { const groupRef = React.useRef(null); const hoveredRef = React.useContext(HoverableContext); const [mounted, setMounted] = React.useState(true); useFrame(() => { if (hoveredRef === null) return; if (props.alwaysMounted) { if (groupRef.current === null) return; groupRef.current.visible = hoveredRef.current; } else if (hoveredRef.current != mounted) { setMounted(hoveredRef.current); } }); return hoveredRef === null || !mounted ? null : ( ); } ================================================ FILE: viser/src/viser/client/src/Titlebar.tsx ================================================ import { ViewerContext } from "./App"; import { ThemeConfigurationMessage } from "./WebsocketMessages"; import { Burger, Button, Container, Group, Paper, Box, useMantineColorScheme, Portal, } from "@mantine/core"; import { IconBrandGithub, IconFileDescription, IconKeyboard, } from "@tabler/icons-react"; import { useDisclosure } from "@mantine/hooks"; import { useContext } from "react"; // Type helpers. 
type ArrayElement = ArrayType extends readonly (infer ElementType)[] ? ElementType : never; type TitlebarContent = NonNullable< ThemeConfigurationMessage["titlebar_content"] >; function assertUnreachable(x: never): never { throw new Error("Didn't expect to get here", x); } function getIcon( icon: ArrayElement>["icon"], ) { let Icon = null; switch (icon) { case null: break; case "GitHub": Icon = IconBrandGithub; break; case "Description": Icon = IconFileDescription; break; case "Keyboard": Icon = IconKeyboard; break; default: assertUnreachable(icon); } return Icon; } // We inherit props directly from message contents. export function TitlebarButton( props: ArrayElement>, ) { const Icon = getIcon(props.icon); return ( ); } export function MobileTitlebarButton( props: ArrayElement>, ) { const Icon = getIcon(props.icon); return ( ); } export function TitlebarImage( props: NonNullable, colorScheme: string, ) { let imageSource: string; if (props.image_url_dark == null || colorScheme === "light") { imageSource = props.image_url_light; } else { imageSource = props.image_url_dark; } const image = ( {props.image_alt} ); if (props.href == null) { return image; } return ( {image} ); } export function Titlebar() { const viewer = useContext(ViewerContext)!; const content = viewer.useGui((state) => state.theme.titlebar_content); const colorScheme = useMantineColorScheme().colorScheme; const [burgerOpen, burgerHandlers] = useDisclosure(false); if (content == null) { return null; } const buttons = content.buttons; const imageData = content.image; return ( ({ height: "100%", display: "flex", alignItems: "center", })} > ({ marginRight: "auto" })}> {imageData !== null ? TitlebarImage(imageData, colorScheme) : null} ({ flexWrap: "nowrap", overflowX: "scroll", msOverflowStyle: "none", scrollbarWidth: "none", })} > {buttons?.map((btn, index) => ( ))} {buttons?.map((btn, index) => ( ))} ); } ================================================ FILE: viser/src/viser/client/src/Utils.ts ================================================ // Drag Utils export interface DragEvents { move: "touchmove" | "mousemove"; end: "touchend" | "mouseup"; } export const touchEvents: DragEvents = { move: "touchmove", end: "touchend" }; export const mouseEvents: DragEvents = { move: "mousemove", end: "mouseup" }; export function isTouchEvent( event: TouchEvent | MouseEvent, ): event is TouchEvent { return event.type === "touchmove"; } export function isMouseEvent( event: TouchEvent | MouseEvent, ): event is MouseEvent { return event.type === "mousemove"; } ================================================ FILE: viser/src/viser/client/src/WebsocketFunctions.tsx ================================================ import React from "react"; import * as THREE from "three"; import { Message } from "./WebsocketMessages"; import { ViewerContext, ViewerContextContents } from "./App"; /** Easier, hook version of makeThrottledMessageSender. */ export function useThrottledMessageSender(throttleMilliseconds: number) { const viewer = React.useContext(ViewerContext)!; return makeThrottledMessageSender(viewer, throttleMilliseconds); } /** Returns a function for sending messages, with automatic throttling. 
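 *
 * The returned sender forwards a message immediately when the channel is
 * idle, then suppresses sends for `throttleMilliseconds`; if newer messages
 * arrive during that window, only the most recent one is flushed once the
 * window expires.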
*/ export function makeThrottledMessageSender( viewer: ViewerContextContents, throttleMilliseconds: number, ) { let readyToSend = true; let stale = false; let latestMessage: Message | null = null; function send(message: Message) { if (viewer.sendMessageRef.current === null) return; latestMessage = message; if (readyToSend) { viewer.sendMessageRef.current(message); stale = false; readyToSend = false; setTimeout(() => { readyToSend = true; if (!stale) return; latestMessage && send(latestMessage); }, throttleMilliseconds); } else { stale = true; } } return send; } /** Type guard for threejs textures. Meant to be used with `scene.background`. */ export function isTexture( background: | THREE.Color | THREE.Texture | THREE.CubeTexture | null | undefined, ): background is THREE.Texture { return ( background !== null && background !== undefined && (background as THREE.Texture).isTexture !== undefined ); } ================================================ FILE: viser/src/viser/client/src/WebsocketInterface.tsx ================================================ import WebsocketServerWorker from "./WebsocketServerWorker?worker"; import React, { useContext } from "react"; import { ViewerContext } from "./App"; import { syncSearchParamServer } from "./SearchParamsUtils"; import { WsWorkerIncoming, WsWorkerOutgoing } from "./WebsocketServerWorker"; /** Component for handling websocket connections. */ export function WebsocketMessageProducer() { const messageQueueRef = useContext(ViewerContext)!.messageQueueRef; const viewer = useContext(ViewerContext)!; const server = viewer.useGui((state) => state.server); const resetGui = viewer.useGui((state) => state.resetGui); syncSearchParamServer(server); React.useEffect(() => { const worker = new WebsocketServerWorker(); worker.onmessage = (event) => { const data: WsWorkerOutgoing = event.data; if (data.type === "connected") { resetGui(); viewer.useGui.setState({ websocketConnected: true }); viewer.sendMessageRef.current = (message) => { postToWorker({ type: "send", message: message }); }; } else if (data.type === "closed") { resetGui(); viewer.useGui.setState({ websocketConnected: false }); viewer.sendMessageRef.current = (message) => { console.log( `Tried to send ${message.type} but websocket is not connected!`, ); }; } else if (data.type === "message_batch") { messageQueueRef.current.push(...data.messages); } }; function postToWorker(data: WsWorkerIncoming) { worker.postMessage(data); } postToWorker({ type: "set_server", server: server }); return () => { postToWorker({ type: "close" }); viewer.sendMessageRef.current = (message) => console.log( `Tried to send ${message.type} but websocket is not connected!`, ); viewer.useGui.setState({ websocketConnected: false }); }; }, [server, resetGui]); return <>; } ================================================ FILE: viser/src/viser/client/src/WebsocketMessages.tsx ================================================ // AUTOMATICALLY GENERATED message interfaces, from Python dataclass definitions. // This file should not be manually modified. /** Message for running some arbitrary Javascript on the client. * We use this to set up the Plotly.js package, via the plotly.min.js source * code. * * (automatically generated) */ export interface RunJavascriptMessage { type: "RunJavascriptMessage"; source: string; } /** Notification message. 
* * (automatically generated) */ export interface NotificationMessage { type: "NotificationMessage"; mode: "show" | "update"; id: string; title: string; body: string; loading: boolean; with_close_button: boolean; auto_close: number | false; color: | "dark" | "gray" | "red" | "pink" | "grape" | "violet" | "indigo" | "blue" | "cyan" | "green" | "lime" | "yellow" | "orange" | "teal" | null; } /** Remove a specific notification. * * (automatically generated) */ export interface RemoveNotificationMessage { type: "RemoveNotificationMessage"; id: string; } /** Message for a posed viewer camera. * Pose is in the form T_world_camera, OpenCV convention, +Z forward. * * (automatically generated) */ export interface ViewerCameraMessage { type: "ViewerCameraMessage"; wxyz: [number, number, number, number]; position: [number, number, number]; fov: number; aspect: number; look_at: [number, number, number]; up_direction: [number, number, number]; } /** Message for a raycast-like pointer in the scene. * origin is the viewing camera position, in world coordinates. * direction is the vector if a ray is projected from the camera through the clicked pixel, * * * (automatically generated) */ export interface ScenePointerMessage { type: "ScenePointerMessage"; event_type: "click" | "rect-select"; ray_origin: [number, number, number] | null; ray_direction: [number, number, number] | null; screen_pos: [number, number][]; } /** Message to enable/disable scene click events. * * (automatically generated) */ export interface ScenePointerEnableMessage { type: "ScenePointerEnableMessage"; enable: boolean; event_type: "click" | "rect-select"; } /** Variant of CameraMessage used for visualizing camera frustums. * * OpenCV convention, +Z forward. * * (automatically generated) */ export interface CameraFrustumMessage { type: "CameraFrustumMessage"; name: string; fov: number; aspect: number; scale: number; color: number; thickness: number; image_media_type: "image/jpeg" | "image/png" | null; image_binary: Uint8Array | null; } /** GlTF message. * * (automatically generated) */ export interface GlbMessage { type: "GlbMessage"; name: string; glb_data: Uint8Array; scale: number; } /** Coordinate frame message. * * (automatically generated) */ export interface FrameMessage { type: "FrameMessage"; name: string; show_axes: boolean; axes_length: number; axes_radius: number; origin_radius: number; } /** Batched axes message. * * Positions and orientations should follow a `T_parent_local` convention, which * corresponds to the R matrix and t vector in `p_parent = [R | t] p_local`. * * (automatically generated) */ export interface BatchedAxesMessage { type: "BatchedAxesMessage"; name: string; wxyzs_batched: Uint8Array; positions_batched: Uint8Array; axes_length: number; axes_radius: number; } /** Grid message. Helpful for visualizing things like ground planes. * * (automatically generated) */ export interface GridMessage { type: "GridMessage"; name: string; width: number; height: number; width_segments: number; height_segments: number; plane: "xz" | "xy" | "yx" | "yz" | "zx" | "zy"; cell_color: number; cell_thickness: number; cell_size: number; section_color: number; section_thickness: number; section_size: number; } /** Add a 2D label to the scene. * * (automatically generated) */ export interface LabelMessage { type: "LabelMessage"; name: string; text: string; } /** Add a 3D gui element to the scene. 
* * (automatically generated) */ export interface Gui3DMessage { type: "Gui3DMessage"; order: number; name: string; container_id: string; } /** Point cloud message. * * Positions are internally canonicalized to float32, colors to uint8. * * Float color inputs should be in the range [0,1], int color inputs should be in the * range [0,255]. * * (automatically generated) */ export interface PointCloudMessage { type: "PointCloudMessage"; name: string; points: Uint8Array; colors: Uint8Array; point_size: number; point_ball_norm: number; } /** Message for a bone of a skinned mesh. * * (automatically generated) */ export interface MeshBoneMessage { type: "MeshBoneMessage"; name: string; } /** Mesh message. * * Vertices are internally canonicalized to float32, faces to uint32. * * (automatically generated) */ export interface MeshMessage { type: "MeshMessage"; name: string; vertices: Uint8Array; faces: Uint8Array; color: number | null; vertex_colors: Uint8Array | null; wireframe: boolean; opacity: number | null; flat_shading: boolean; side: "front" | "back" | "double"; material: "standard" | "toon3" | "toon5"; } /** Mesh message. * * Vertices are internally canonicalized to float32, faces to uint32. * * (automatically generated) */ export interface SkinnedMeshMessage { type: "SkinnedMeshMessage"; name: string; vertices: Uint8Array; faces: Uint8Array; color: number | null; vertex_colors: Uint8Array | null; wireframe: boolean; opacity: number | null; flat_shading: boolean; side: "front" | "back" | "double"; material: "standard" | "toon3" | "toon5"; bone_wxyzs: [number, number, number, number][]; bone_positions: [number, number, number][]; skin_indices: Uint8Array; skin_weights: Uint8Array; } /** Server -> client message to set a skinned mesh bone's orientation. * * As with all other messages, transforms take the `T_parent_local` convention. * * (automatically generated) */ export interface SetBoneOrientationMessage { type: "SetBoneOrientationMessage"; name: string; bone_index: number; wxyz: [number, number, number, number]; } /** Server -> client message to set a skinned mesh bone's position. * * As with all other messages, transforms take the `T_parent_local` convention. * * (automatically generated) */ export interface SetBonePositionMessage { type: "SetBonePositionMessage"; name: string; bone_index: number; position: [number, number, number]; } /** Message for transform gizmos. * * (automatically generated) */ export interface TransformControlsMessage { type: "TransformControlsMessage"; name: string; scale: number; line_width: number; fixed: boolean; auto_transform: boolean; active_axes: [boolean, boolean, boolean]; disable_axes: boolean; disable_sliders: boolean; disable_rotations: boolean; translation_limits: [[number, number], [number, number], [number, number]]; rotation_limits: [[number, number], [number, number], [number, number]]; depth_test: boolean; opacity: number; } /** Server -> client message to set the camera's position. * * (automatically generated) */ export interface SetCameraPositionMessage { type: "SetCameraPositionMessage"; position: [number, number, number]; } /** Server -> client message to set the camera's up direction. * * (automatically generated) */ export interface SetCameraUpDirectionMessage { type: "SetCameraUpDirectionMessage"; position: [number, number, number]; } /** Server -> client message to set the camera's look-at point. 
* * (automatically generated) */ export interface SetCameraLookAtMessage { type: "SetCameraLookAtMessage"; look_at: [number, number, number]; } /** Server -> client message to set the camera's field of view. * * (automatically generated) */ export interface SetCameraFovMessage { type: "SetCameraFovMessage"; fov: number; } /** Server -> client message to set a scene node's orientation. * * As with all other messages, transforms take the `T_parent_local` convention. * * (automatically generated) */ export interface SetOrientationMessage { type: "SetOrientationMessage"; name: string; wxyz: [number, number, number, number]; } /** Server -> client message to set a scene node's position. * * As with all other messages, transforms take the `T_parent_local` convention. * * (automatically generated) */ export interface SetPositionMessage { type: "SetPositionMessage"; name: string; position: [number, number, number]; } /** Client -> server message when a transform control is updated. * * As with all other messages, transforms take the `T_parent_local` convention. * * (automatically generated) */ export interface TransformControlsUpdateMessage { type: "TransformControlsUpdateMessage"; name: string; wxyz: [number, number, number, number]; position: [number, number, number]; } /** Message for rendering a background image. * * (automatically generated) */ export interface BackgroundImageMessage { type: "BackgroundImageMessage"; media_type: "image/jpeg" | "image/png"; rgb_bytes: Uint8Array; depth_bytes: Uint8Array | null; } /** Message for rendering 2D images. * * (automatically generated) */ export interface ImageMessage { type: "ImageMessage"; name: string; media_type: "image/jpeg" | "image/png"; data: Uint8Array; render_width: number; render_height: number; } /** Remove a particular node from the scene. * * (automatically generated) */ export interface RemoveSceneNodeMessage { type: "RemoveSceneNodeMessage"; name: string; } /** Set the visibility of a particular node in the scene. * * (automatically generated) */ export interface SetSceneNodeVisibilityMessage { type: "SetSceneNodeVisibilityMessage"; name: string; visible: boolean; } /** Set the clickability of a particular node in the scene. * * (automatically generated) */ export interface SetSceneNodeClickableMessage { type: "SetSceneNodeClickableMessage"; name: string; clickable: boolean; } /** Message for clicked objects. * * (automatically generated) */ export interface SceneNodeClickMessage { type: "SceneNodeClickMessage"; name: string; instance_index: number | null; ray_origin: [number, number, number]; ray_direction: [number, number, number]; screen_pos: [number, number]; } /** Reset scene. * * (automatically generated) */ export interface ResetSceneMessage { type: "ResetSceneMessage"; } /** Reset GUI. 
* * (automatically generated) */ export interface ResetGuiMessage { type: "ResetGuiMessage"; } /** GuiAddFolderMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', expand_by_default: 'bool', visible: 'bool') * * (automatically generated) */ export interface GuiAddFolderMessage { type: "GuiAddFolderMessage"; order: number; id: string; label: string; container_id: string; expand_by_default: boolean; visible: boolean; } /** GuiAddMarkdownMessage(order: 'float', id: 'str', markdown: 'str', container_id: 'str', visible: 'bool') * * (automatically generated) */ export interface GuiAddMarkdownMessage { type: "GuiAddMarkdownMessage"; order: number; id: string; markdown: string; container_id: string; visible: boolean; } /** GuiAddProgressBarMessage(order: 'float', id: 'str', value: 'float', animated: 'bool', color: 'Optional[Color]', container_id: 'str', visible: 'bool') * * (automatically generated) */ export interface GuiAddProgressBarMessage { type: "GuiAddProgressBarMessage"; order: number; id: string; value: number; animated: boolean; color: | "dark" | "gray" | "red" | "pink" | "grape" | "violet" | "indigo" | "blue" | "cyan" | "green" | "lime" | "yellow" | "orange" | "teal" | null; container_id: string; visible: boolean; } /** GuiAddPlotlyMessage(order: 'float', id: 'str', plotly_json_str: 'str', aspect: 'float', container_id: 'str', visible: 'bool') * * (automatically generated) */ export interface GuiAddPlotlyMessage { type: "GuiAddPlotlyMessage"; order: number; id: string; plotly_json_str: string; aspect: number; container_id: string; visible: boolean; } /** GuiAddTabGroupMessage(order: 'float', id: 'str', container_id: 'str', tab_labels: 'Tuple[str, ...]', tab_icons_html: 'Tuple[Union[str, None], ...]', tab_container_ids: 'Tuple[str, ...]', visible: 'bool') * * (automatically generated) */ export interface GuiAddTabGroupMessage { type: "GuiAddTabGroupMessage"; order: number; id: string; container_id: string; tab_labels: string[]; tab_icons_html: (string | null)[]; tab_container_ids: string[]; visible: boolean; } /** Base message type containing fields commonly used by GUI inputs. 
* * (automatically generated) */ export interface _GuiAddInputBase { type: "_GuiAddInputBase"; order: number; id: string; label: string; container_id: string; hint: string | null; value: any; visible: boolean; disabled: boolean; } /** GuiAddButtonMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'bool', visible: 'bool', disabled: 'bool', color: 'Optional[Color]', icon_html: 'Optional[str]') * * (automatically generated) */ export interface GuiAddButtonMessage { type: "GuiAddButtonMessage"; order: number; id: string; label: string; container_id: string; hint: string | null; value: boolean; visible: boolean; disabled: boolean; color: | "dark" | "gray" | "red" | "pink" | "grape" | "violet" | "indigo" | "blue" | "cyan" | "green" | "lime" | "yellow" | "orange" | "teal" | null; icon_html: string | null; } /** GuiAddUploadButtonMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Any', visible: 'bool', disabled: 'bool', color: 'Optional[Color]', icon_html: 'Optional[str]', mime_type: 'str') * * (automatically generated) */ export interface GuiAddUploadButtonMessage { type: "GuiAddUploadButtonMessage"; order: number; id: string; label: string; container_id: string; hint: string | null; value: any; visible: boolean; disabled: boolean; color: | "dark" | "gray" | "red" | "pink" | "grape" | "violet" | "indigo" | "blue" | "cyan" | "green" | "lime" | "yellow" | "orange" | "teal" | null; icon_html: string | null; mime_type: string; } /** GuiAddSliderMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'float', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', precision: 'int', marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) * * (automatically generated) */ export interface GuiAddSliderMessage { type: "GuiAddSliderMessage"; order: number; id: string; label: string; container_id: string; hint: string | null; value: number; visible: boolean; disabled: boolean; min: number; max: number; step: number | null; precision: number; marks: { value: number; label?: string }[] | null; } /** GuiAddMultiSliderMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Any', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', min_range: 'Optional[float]', precision: 'int', fixed_endpoints: 'bool' = False, marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) * * (automatically generated) */ export interface GuiAddMultiSliderMessage { type: "GuiAddMultiSliderMessage"; order: number; id: string; label: string; container_id: string; hint: string | null; value: any; visible: boolean; disabled: boolean; min: number; max: number; step: number | null; min_range: number | null; precision: number; fixed_endpoints: boolean; marks: { value: number; label?: string }[] | null; } /** GuiAddNumberMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'float', visible: 'bool', disabled: 'bool', precision: 'int', step: 'float', min: 'Optional[float]', max: 'Optional[float]') * * (automatically generated) */ export interface GuiAddNumberMessage { type: "GuiAddNumberMessage"; order: number; id: string; label: string; container_id: string; hint: string | null; value: number; visible: boolean; disabled: boolean; precision: number; step: number; min: number | null; max: number | null; } /** GuiAddRgbMessage(order: 'float', id: 'str', 
label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[int, int, int]', visible: 'bool', disabled: 'bool') * * (automatically generated) */ export interface GuiAddRgbMessage { type: "GuiAddRgbMessage"; order: number; id: string; label: string; container_id: string; hint: string | null; value: [number, number, number]; visible: boolean; disabled: boolean; } /** GuiAddRgbaMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[int, int, int, int]', visible: 'bool', disabled: 'bool') * * (automatically generated) */ export interface GuiAddRgbaMessage { type: "GuiAddRgbaMessage"; order: number; id: string; label: string; container_id: string; hint: string | null; value: [number, number, number, number]; visible: boolean; disabled: boolean; } /** GuiAddCheckboxMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'bool', visible: 'bool', disabled: 'bool') * * (automatically generated) */ export interface GuiAddCheckboxMessage { type: "GuiAddCheckboxMessage"; order: number; id: string; label: string; container_id: string; hint: string | null; value: boolean; visible: boolean; disabled: boolean; } /** GuiAddVector2Message(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[float, float]', visible: 'bool', disabled: 'bool', min: 'Optional[Tuple[float, float]]', max: 'Optional[Tuple[float, float]]', step: 'float', precision: 'int') * * (automatically generated) */ export interface GuiAddVector2Message { type: "GuiAddVector2Message"; order: number; id: string; label: string; container_id: string; hint: string | null; value: [number, number]; visible: boolean; disabled: boolean; min: [number, number] | null; max: [number, number] | null; step: number; precision: number; } /** GuiAddVector3Message(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[float, float, float]', visible: 'bool', disabled: 'bool', min: 'Optional[Tuple[float, float, float]]', max: 'Optional[Tuple[float, float, float]]', step: 'float', precision: 'int') * * (automatically generated) */ export interface GuiAddVector3Message { type: "GuiAddVector3Message"; order: number; id: string; label: string; container_id: string; hint: string | null; value: [number, number, number]; visible: boolean; disabled: boolean; min: [number, number, number] | null; max: [number, number, number] | null; step: number; precision: number; } /** GuiAddTextMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'str', visible: 'bool', disabled: 'bool') * * (automatically generated) */ export interface GuiAddTextMessage { type: "GuiAddTextMessage"; order: number; id: string; label: string; container_id: string; hint: string | null; value: string; visible: boolean; disabled: boolean; } /** GuiAddDropdownMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'str', visible: 'bool', disabled: 'bool', options: 'Tuple[str, ...]') * * (automatically generated) */ export interface GuiAddDropdownMessage { type: "GuiAddDropdownMessage"; order: number; id: string; label: string; container_id: string; hint: string | null; value: string; visible: boolean; disabled: boolean; options: string[]; } /** GuiAddButtonGroupMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'str', visible: 'bool', disabled: 'bool', options: 'Tuple[str, 
...]') * * (automatically generated) */ export interface GuiAddButtonGroupMessage { type: "GuiAddButtonGroupMessage"; order: number; id: string; label: string; container_id: string; hint: string | null; value: string; visible: boolean; disabled: boolean; options: string[]; } /** GuiModalMessage(order: 'float', id: 'str', title: 'str') * * (automatically generated) */ export interface GuiModalMessage { type: "GuiModalMessage"; order: number; id: string; title: string; } /** GuiCloseModalMessage(id: 'str') * * (automatically generated) */ export interface GuiCloseModalMessage { type: "GuiCloseModalMessage"; id: string; } /** Sent server->client to remove a GUI element. * * (automatically generated) */ export interface GuiRemoveMessage { type: "GuiRemoveMessage"; id: string; } /** Sent client<->server when any property of a GUI component is changed. * * (automatically generated) */ export interface GuiUpdateMessage { type: "GuiUpdateMessage"; id: string; updates: Partial; } /** Message from server->client to configure parts of the GUI. * * (automatically generated) */ export interface ThemeConfigurationMessage { type: "ThemeConfigurationMessage"; titlebar_content: { buttons: | { text: string | null; icon: "GitHub" | "Description" | "Keyboard" | null; href: string | null; }[] | null; image: { image_url_light: string; image_url_dark: string | null; image_alt: string; href: string | null; } | null; } | null; control_layout: "floating" | "collapsible" | "fixed"; control_width: "small" | "medium" | "large"; show_logo: boolean; show_share_button: boolean; dark_mode: boolean; colors: | [ string, string, string, string, string, string, string, string, string, string, ] | null; } /** Message from server->client carrying Catmull-Rom spline information. * * (automatically generated) */ export interface CatmullRomSplineMessage { type: "CatmullRomSplineMessage"; name: string; positions: [number, number, number][]; curve_type: "centripetal" | "chordal" | "catmullrom"; tension: number; closed: boolean; line_width: number; color: number; segments: number | null; } /** Message from server->client carrying Cubic Bezier spline information. * * (automatically generated) */ export interface CubicBezierSplineMessage { type: "CubicBezierSplineMessage"; name: string; positions: [number, number, number][]; control_points: [number, number, number][]; line_width: number; color: number; segments: number | null; } /** Message from server->client carrying splattable Gaussians. * * (automatically generated) */ export interface GaussianSplatsMessage { type: "GaussianSplatsMessage"; name: string; buffer: Uint8Array; } /** Message from server->client requesting a render of the current viewport. * * (automatically generated) */ export interface GetRenderRequestMessage { type: "GetRenderRequestMessage"; format: "image/jpeg" | "image/png"; height: number; width: number; quality: number; } /** Message from client->server carrying a render. * * (automatically generated) */ export interface GetRenderResponseMessage { type: "GetRenderResponseMessage"; payload: Uint8Array; } /** Signal that a file is about to be sent. * * (automatically generated) */ export interface FileTransferStart { type: "FileTransferStart"; source_component_id: string | null; transfer_uuid: string; filename: string; mime_type: string; part_count: number; size_bytes: number; } /** Send a file for clients to download or upload files from client. 
* * (automatically generated) */ export interface FileTransferPart { type: "FileTransferPart"; source_component_id: string | null; transfer_uuid: string; part: number; content: Uint8Array; } /** Send a file for clients to download or upload files from client. * * (automatically generated) */ export interface FileTransferPartAck { type: "FileTransferPartAck"; source_component_id: string | null; transfer_uuid: string; transferred_bytes: number; total_bytes: number; } /** Message from client->server to connect to the share URL server. * * (automatically generated) */ export interface ShareUrlRequest { type: "ShareUrlRequest"; } /** Message from server->client to indicate that the share URL has been updated. * * (automatically generated) */ export interface ShareUrlUpdated { type: "ShareUrlUpdated"; share_url: string | null; } /** Message from client->server to disconnect from the share URL server. * * (automatically generated) */ export interface ShareUrlDisconnect { type: "ShareUrlDisconnect"; } /** Message from server->client to set the label of the GUI panel. * * (automatically generated) */ export interface SetGuiPanelLabelMessage { type: "SetGuiPanelLabelMessage"; label: string | null; } export type Message = | RunJavascriptMessage | NotificationMessage | RemoveNotificationMessage | ViewerCameraMessage | ScenePointerMessage | ScenePointerEnableMessage | CameraFrustumMessage | GlbMessage | FrameMessage | BatchedAxesMessage | GridMessage | LabelMessage | Gui3DMessage | PointCloudMessage | MeshBoneMessage | MeshMessage | SkinnedMeshMessage | SetBoneOrientationMessage | SetBonePositionMessage | TransformControlsMessage | SetCameraPositionMessage | SetCameraUpDirectionMessage | SetCameraLookAtMessage | SetCameraFovMessage | SetOrientationMessage | SetPositionMessage | TransformControlsUpdateMessage | BackgroundImageMessage | ImageMessage | RemoveSceneNodeMessage | SetSceneNodeVisibilityMessage | SetSceneNodeClickableMessage | SceneNodeClickMessage | ResetSceneMessage | ResetGuiMessage | GuiAddFolderMessage | GuiAddMarkdownMessage | GuiAddProgressBarMessage | GuiAddPlotlyMessage | GuiAddTabGroupMessage | _GuiAddInputBase | GuiAddButtonMessage | GuiAddUploadButtonMessage | GuiAddSliderMessage | GuiAddMultiSliderMessage | GuiAddNumberMessage | GuiAddRgbMessage | GuiAddRgbaMessage | GuiAddCheckboxMessage | GuiAddVector2Message | GuiAddVector3Message | GuiAddTextMessage | GuiAddDropdownMessage | GuiAddButtonGroupMessage | GuiModalMessage | GuiCloseModalMessage | GuiRemoveMessage | GuiUpdateMessage | ThemeConfigurationMessage | CatmullRomSplineMessage | CubicBezierSplineMessage | GaussianSplatsMessage | GetRenderRequestMessage | GetRenderResponseMessage | FileTransferStart | FileTransferPart | FileTransferPartAck | ShareUrlRequest | ShareUrlUpdated | ShareUrlDisconnect | SetGuiPanelLabelMessage; export type GuiAddComponentMessage = | GuiAddFolderMessage | GuiAddMarkdownMessage | GuiAddProgressBarMessage | GuiAddPlotlyMessage | GuiAddTabGroupMessage | GuiAddButtonMessage | GuiAddUploadButtonMessage | GuiAddSliderMessage | GuiAddMultiSliderMessage | GuiAddNumberMessage | GuiAddRgbMessage | GuiAddRgbaMessage | GuiAddCheckboxMessage | GuiAddVector2Message | GuiAddVector3Message | GuiAddTextMessage | GuiAddDropdownMessage | GuiAddButtonGroupMessage; ================================================ FILE: viser/src/viser/client/src/WebsocketServerWorker.ts ================================================ import { encode, decode } from "@msgpack/msgpack"; import { Message } from "./WebsocketMessages"; 
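// This file runs as a Web Worker: the main thread posts WsWorkerIncoming
// commands ("set_server" to open a connection, "send" to transmit an encoded
// message, "close" to shut down), and the worker answers with WsWorkerOutgoing
// events ("connected", "closed", and msgpack-decoded "message_batch" payloads,
// whose underlying ArrayBuffers are handed back as transferables to avoid
// copies).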
import AwaitLock from "await-lock";

export type WsWorkerIncoming =
  | { type: "send"; message: Message }
  | { type: "set_server"; server: string }
  | { type: "close" };

export type WsWorkerOutgoing =
  | { type: "connected" }
  | { type: "closed" }
  | { type: "message_batch"; messages: Message[] };

// Helper function to collect all ArrayBuffer objects. This is used for
// postMessage() move semantics.
function collectArrayBuffers(obj: any, buffers: Set<ArrayBuffer>) {
  if (obj instanceof ArrayBuffer) {
    buffers.add(obj);
  } else if (obj instanceof Uint8Array) {
    buffers.add(obj.buffer);
  } else if (obj && typeof obj === "object") {
    for (const key in obj) {
      if (Object.prototype.hasOwnProperty.call(obj, key)) {
        collectArrayBuffers(obj[key], buffers);
      }
    }
  }
  return buffers;
}

{
  let server: string | null = null;
  let ws: WebSocket | null = null;
  const orderLock = new AwaitLock();

  const postOutgoing = (
    data: WsWorkerOutgoing,
    transferable?: Transferable[],
  ) => {
    // @ts-ignore
    self.postMessage(data, transferable);
  };

  const tryConnect = () => {
    if (ws !== null) ws.close();
    ws = new WebSocket(server!);

    // Timeout is necessary when we're connecting to an SSH/tunneled port.
    const retryTimeout = setTimeout(() => {
      ws?.close();
    }, 5000);

    ws.onopen = () => {
      postOutgoing({ type: "connected" });
      clearTimeout(retryTimeout);
      console.log(`Connected! ${server}`);
    };

    ws.onclose = (event) => {
      postOutgoing({ type: "closed" });
      console.log(`Disconnected! ${server} code=${event.code}`);
      clearTimeout(retryTimeout);
      // Try to reconnect.
      if (server !== null) setTimeout(tryConnect, 1000);
    };

    ws.onmessage = async (event) => {
      // Reduce websocket backpressure.
      const messagePromise = new Promise<Message[]>((resolve) => {
        (event.data.arrayBuffer() as Promise<ArrayBuffer>).then((buffer) => {
          resolve(decode(new Uint8Array(buffer)) as Message[]);
        });
      });

      // Try our best to handle messages in order. If this takes more than 1
      // second, we give up. :)
      await orderLock.acquireAsync({ timeout: 1000 }).catch(() => {
        console.log("Order lock timed out.");
        orderLock.release();
      });
      try {
        const messages = await messagePromise;
        const arrayBuffers = collectArrayBuffers(
          messages,
          new Set<ArrayBuffer>(),
        );
        postOutgoing(
          { type: "message_batch", messages: messages },
          Array.from(arrayBuffers),
        );
      } finally {
        orderLock.acquired && orderLock.release();
      }
    };
  };

  self.onmessage = (e) => {
    const data: WsWorkerIncoming = e.data;
    if (data.type === "send") {
      ws!.send(encode(data.message));
    } else if (data.type === "set_server") {
      server = data.server;
      tryConnect();
    } else if (data.type == "close") {
      server = null;
      ws !== null && ws.close();
      self.close();
    } else {
      console.log(
        `WebSocket worker: got ${data}, not sure what to do with it!`,
      );
    }
  };
}

================================================ FILE: viser/src/viser/client/src/WorldTransformUtils.ts ================================================
import { ViewerContextContents } from "./App";
import * as THREE from "three";

/** Helper for computing the transformation between the three.js world and the
 * Python-exposed world frames. This is useful for things like switching
 * between +Y and +Z up directions for the world frame. */
export function computeT_threeworld_world(viewer: ViewerContextContents) {
  const wxyz = viewer.nodeAttributesFromName.current[""]!.wxyz!;
  const position = viewer.nodeAttributesFromName.current[""]!.position ??
[ 0, 0, 0, ]; return new THREE.Matrix4() .makeRotationFromQuaternion( new THREE.Quaternion(wxyz[1], wxyz[2], wxyz[3], wxyz[0]), ) .setPosition(position[0], position[1], position[2]); } /** Helper for converting a ray from the three.js world frame to the Python * world frame. Applies the transformation from computeT_threeworld_world. */ export function rayToViserCoords( viewer: ViewerContextContents, ray: THREE.Ray, ): THREE.Ray { const T_world_threeworld = computeT_threeworld_world(viewer).invert(); const origin = ray.origin.clone().applyMatrix4(T_world_threeworld); // Compute just the rotation term without new memory allocation; this // will mutate T_world_threeworld! const R_world_threeworld = T_world_threeworld.setPosition(0.0, 0.0, 0); const direction = ray.direction.clone().applyMatrix4(R_world_threeworld); return new THREE.Ray(origin, direction); } ================================================ FILE: viser/src/viser/client/src/components/Button.tsx ================================================ import { GuiAddButtonMessage } from "../WebsocketMessages"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { Box } from "@mantine/core"; import { Button } from "@mantine/core"; import React from "react"; import { htmlIconWrapper } from "./ComponentStyles.css"; export default function ButtonComponent({ id, visible, disabled, label, ...otherProps }: GuiAddButtonMessage) { const { messageSender } = React.useContext(GuiComponentContext)!; const { color, icon_html } = otherProps; if (!(visible ?? true)) return <>; return ( ); } ================================================ FILE: viser/src/viser/client/src/components/ButtonGroup.tsx ================================================ import * as React from "react"; import { Button, Flex } from "@mantine/core"; import { ViserInputComponent } from "./common"; import { GuiAddButtonGroupMessage } from "../WebsocketMessages"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; export default function ButtonGroupComponent({ id, hint, label, visible, disabled, options, }: GuiAddButtonGroupMessage) { const { messageSender } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( {options.map((option, index) => ( ))} ); } ================================================ FILE: viser/src/viser/client/src/components/Checkbox.tsx ================================================ import * as React from "react"; import { ViserInputComponent } from "./common"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { GuiAddCheckboxMessage } from "../WebsocketMessages"; import { Box, Checkbox, Tooltip } from "@mantine/core"; export default function CheckboxComponent({ id, disabled, visible, hint, label, value, }: GuiAddCheckboxMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; let input = ( { setValue(id, value.target.checked); }} disabled={disabled} /> ); if (hint !== null && hint !== undefined) { // For checkboxes, we want to make sure that the wrapper // doesn't expand to the full width of the parent. This will // de-center the tooltip. 
input = ( {input} ); } return {input}; } ================================================ FILE: viser/src/viser/client/src/components/ComponentStyles.css.ts ================================================ import { globalStyle, style } from "@vanilla-extract/css"; export const htmlIconWrapper = style({ height: "1em", width: "1em", position: "relative", }); globalStyle(`${htmlIconWrapper} svg`, { height: "auto", width: "1em", position: "absolute", top: "50%", transform: "translateY(-50%)", }); // Class for sliders with default min/max marks. We use this for aestheticn // its; global styles are used to shift the min/max mark labels to stay closer // within the bounds of the slider. export const sliderDefaultMarks = style({}); globalStyle( `${sliderDefaultMarks} .mantine-Slider-markWrapper:first-of-type div:nth-of-type(2)`, { transform: "translate(-0.1rem, 0.03rem) !important", }, ); globalStyle( `${sliderDefaultMarks} .mantine-Slider-markWrapper:last-of-type div:nth-of-type(2)`, { transform: "translate(-85%, 0.03rem) !important", }, ); ================================================ FILE: viser/src/viser/client/src/components/Dropdown.tsx ================================================ import * as React from "react"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { ViserInputComponent } from "./common"; import { GuiAddDropdownMessage } from "../WebsocketMessages"; import { Select } from "@mantine/core"; export default function DropdownComponent({ id, hint, label, value, disabled, visible, options, }: GuiAddDropdownMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( ))} ); }); MultiSlider.classes = classes; MultiSlider.displayName = "MultiSlider"; ================================================ FILE: viser/src/viser/client/src/components/MultiSliderPrimitive/Slider.context.ts ================================================ import { createSafeContext, GetStylesApi } from "@mantine/core"; export type SliderStylesNames = | "root" | "label" | "thumb" | "trackContainer" | "track" | "bar" | "markWrapper" | "mark" | "markLabel"; export type SliderCssVariables = { root: | "--slider-size" | "--slider-color" | "--slider-thumb-size" | "--slider-radius"; }; interface SliderContextValue { getStyles: GetStylesApi<{ stylesNames: SliderStylesNames; props: any; ref: any; vars: any; variant: any; }>; } export const [SliderProvider, useSliderContext] = createSafeContext("SliderProvider was not found in tree"); ================================================ FILE: viser/src/viser/client/src/components/MultiSliderPrimitive/Slider.module.css ================================================ .root { --slider-size-xs: rem(4px); --slider-size-sm: rem(6px); --slider-size-md: rem(8px); --slider-size-lg: rem(10px); --slider-size-xl: rem(12px); --slider-size: var(--slider-size-md); --slider-radius: rem(1000px); --slider-color: var(--mantine-primary-color-filled); -webkit-tap-highlight-color: transparent; outline: none; height: calc(var(--slider-size) * 2); padding-inline: var(--slider-size); display: flex; flex-direction: column; align-items: center; touch-action: none; position: relative; @mixin light { --slider-track-bg: var(--mantine-color-gray-2); --slider-track-disabled-bg: var(--mantine-color-gray-4); } @mixin dark { --slider-track-bg: var(--mantine-color-dark-4); --slider-track-disabled-bg: var(--mantine-color-dark-3); } } .label { position: absolute; top: rem(-36px); font-size: var(--mantine-font-size-xs); color: 
var(--mantine-color-white); padding: calc(var(--mantine-spacing-xs) / 2); border-radius: var(--mantine-radius-sm); white-space: nowrap; pointer-events: none; user-select: none; touch-action: none; @mixin where-light { background-color: var(--mantine-color-gray-9); } @mixin where-dark { background-color: var(--mantine-color-dark-4); } } .thumb { position: absolute; display: flex; height: var(--slider-thumb-size); width: var(--slider-thumb-size); border: rem(4px) solid; transform: translate(-50%, -50%); color: var(--slider-color); top: 50%; cursor: pointer; border-radius: var(--slider-radius); align-items: center; justify-content: center; transition: box-shadow 100ms ease, transform 100ms ease; z-index: 3; user-select: none; touch-action: none; outline-offset: rem(2px); left: var(--slider-thumb-offset); @mixin where-rtl { left: auto; right: calc(var(--slider-thumb-offset) - var(--slider-thumb-size)); } fieldset:disabled &, &:where([data-disabled]) { display: none; } &:where([data-dragging]) { transform: translate(-50%, -50%) scale(1.05); box-shadow: var(--mantine-shadow-sm); } @mixin where-light { border-color: var(--slider-color); background-color: var(--mantine-color-white); } @mixin where-dark { border-color: var(--mantine-color-white); background-color: var(--slider-color); } } .trackContainer { display: flex; align-items: center; width: 100%; height: calc(var(--slider-size) * 2); cursor: pointer; fieldset:disabled &, &:where([data-disabled]) { cursor: not-allowed; } } .track { position: relative; width: 100%; height: var(--slider-size); &:where([data-inverted]:not([data-disabled])) { --track-bg: var(--slider-color); } fieldset:disabled &:where([data-inverted]), &:where([data-inverted][data-disabled]) { --track-bg: var(--slider-track-disabled-bg); } &::before { content: ""; position: absolute; top: 0; bottom: 0; border-radius: var(--slider-radius); inset-inline: calc(var(--slider-size) * -1); background-color: var(--track-bg, var(--slider-track-bg)); z-index: 0; } } .bar { position: absolute; z-index: 1; top: 0; bottom: 0; background-color: var(--slider-color); border-radius: var(--slider-radius); width: var(--slider-bar-width); inset-inline-start: var(--slider-bar-offset); &:where([data-inverted]) { background-color: var(--slider-track-bg); } fieldset:disabled &:where(:not([data-inverted])), &:where([data-disabled]:not([data-inverted])) { @mixin where-light { background-color: var(--mantine-color-gray-4); } @mixin where-dark { background-color: var(--mantine-color-dark-3); } } } .markWrapper { position: absolute; inset-inline-start: calc(var(--mark-offset) - var(--slider-size) / 2); top: 0; z-index: 2; height: 0; pointer-events: none; } .mark { border: rem(2px) solid; height: var(--slider-size); width: var(--slider-size); border-radius: rem(1000px); transform: translateX((calc(var(--slider-size) / -2))); background-color: var(--mantine-color-white); pointer-events: none; @mixin where-light { border-color: var(--mantine-color-gray-2); } @mixin where-dark { border-color: var(--mantine-color-dark-4); } &:where([data-filled]) { border-color: var(--slider-color); &:where([data-disabled]) { @mixin where-light { border-color: var(--mantine-color-gray-4); } @mixin where-dark { border-color: var(--mantine-color-dark-3); } } } } .markLabel { transform: translate( calc(-50% + var(--slider-size) / 2), calc(var(--mantine-spacing-xs) / 2) ); font-size: var(--mantine-font-size-sm); white-space: nowrap; cursor: pointer; user-select: none; @mixin where-light { color: var(--mantine-color-gray-6); } 
@mixin where-dark { color: var(--mantine-color-dark-2); } } ================================================ FILE: viser/src/viser/client/src/components/MultiSliderPrimitive/SliderRoot/SliderRoot.tsx ================================================ import React, { forwardRef } from "react"; import { Box, BoxProps, ElementProps, MantineColor, MantineRadius, MantineSize, } from "@mantine/core"; import { useSliderContext } from "../Slider.context"; export interface SliderRootProps extends BoxProps, ElementProps<"div"> { size: MantineSize | (string & NonNullable) | number; children: React.ReactNode; color: MantineColor | undefined; disabled: boolean | undefined; variant?: string; thumbSize: string | number | undefined; radius: MantineRadius | undefined; } export const SliderRoot = forwardRef( ({ size, variant, ...others }: SliderRootProps, ref) => { const { getStyles } = useSliderContext(); return ( ); }, ); SliderRoot.displayName = "@mantine/core/SliderRoot"; ================================================ FILE: viser/src/viser/client/src/components/MultiSliderPrimitive/Thumb/Thumb.tsx ================================================ import React, { forwardRef, useState } from "react"; import { Box } from "@mantine/core"; import { Transition, TransitionOverride } from "@mantine/core"; import { useSliderContext } from "../Slider.context"; export interface ThumbProps { max: number; min: number; value: number; position: number; dragging: boolean; draggingThisThumb: boolean; label: React.ReactNode; onKeyDownCapture?: (event: React.KeyboardEvent) => void; onMouseDown?: ( event: React.MouseEvent | React.TouchEvent, ) => void; labelTransitionProps: TransitionOverride | undefined; labelAlwaysOn: boolean | undefined; thumbLabel: string | undefined; onFocus?: () => void; onBlur?: () => void; showLabelOnHover: boolean | undefined; isHovered?: boolean; children?: React.ReactNode; disabled: boolean | undefined; className?: string; style?: React.CSSProperties; } export const Thumb = forwardRef( ( { max, min, value, position, label, dragging, draggingThisThumb, onMouseDown, onKeyDownCapture, labelTransitionProps, labelAlwaysOn, thumbLabel, onFocus, onBlur, showLabelOnHover, isHovered, children = null, disabled, }: ThumbProps, ref, ) => { const { getStyles } = useSliderContext(); const [focused, setFocused] = useState(false); const isVisible = labelAlwaysOn || dragging || focused || (showLabelOnHover && isHovered); return ( tabIndex={0} role="slider" aria-label={thumbLabel} aria-valuemax={max} aria-valuemin={min} aria-valuenow={value} ref={ref} __vars={{ "--slider-thumb-offset": `${position}%` }} {...getStyles("thumb", { focusable: true, style: { /* Put active thumb + its label in front of others. */ ...(draggingThisThumb ? { zIndex: 1000 } : {}), }, })} mod={{ dragging, disabled }} onFocus={() => { setFocused(true); typeof onFocus === "function" && onFocus(); }} onBlur={() => { setFocused(false); typeof onBlur === "function" && onBlur(); }} onTouchStart={onMouseDown} onMouseDown={onMouseDown} onKeyDownCapture={onKeyDownCapture} onClick={(event) => event.stopPropagation()} > {children} {(transitionStyles) => (
{label}
)}
); }, ); Thumb.displayName = "@mantine/core/SliderThumb"; ================================================ FILE: viser/src/viser/client/src/components/MultiSliderPrimitive/Track/Track.tsx ================================================ import React from "react"; import { Box } from "@mantine/core"; import { Marks } from "../Marks/Marks"; import { useSliderContext } from "../Slider.context"; export interface TrackProps { filled: number; offset?: number; marksOffset?: number; marks: { value: number; label?: React.ReactNode }[] | undefined; min: number; max: number; value: number; children: React.ReactNode; disabled: boolean | undefined; inverted: boolean | undefined; containerProps?: React.PropsWithRef>; } export function Track({ children, disabled, marksOffset, inverted, containerProps, ...others }: TrackProps) { const { getStyles } = useSliderContext(); return ( <> {children} ); } Track.displayName = "@mantine/core/SliderTrack"; ================================================ FILE: viser/src/viser/client/src/components/MultiSliderPrimitive/index.ts ================================================ export { MultiSlider } from "./MultiSlider/MultiSlider"; ================================================ FILE: viser/src/viser/client/src/components/MultiSliderPrimitive/utils/get-change-value/get-change-value.ts ================================================ interface GetChangeValue { value: number; containerWidth?: number; min: number; max: number; step: number; precision?: number; } export function getChangeValue({ value, containerWidth, min, max, step, precision, }: GetChangeValue) { const left = !containerWidth ? value : Math.min(Math.max(value, 0), containerWidth) / containerWidth; const dx = left * (max - min); const nextValue = (dx !== 0 ? Math.round(dx / step) * step : 0) + min; const nextValueWithinStep = Math.max(nextValue, min); if (precision !== undefined) { return Number(nextValueWithinStep.toFixed(precision)); } return nextValueWithinStep; } ================================================ FILE: viser/src/viser/client/src/components/MultiSliderPrimitive/utils/get-client-position/get-client-position.ts ================================================ export function getClientPosition(event: any) { if ("TouchEvent" in window && event instanceof window.TouchEvent) { const touch = event.touches[0]; return touch.clientX; } return event.clientX; } ================================================ FILE: viser/src/viser/client/src/components/MultiSliderPrimitive/utils/get-floating-value/get-gloating-value.ts ================================================ export function getFloatingValue(value: number, precision: number) { return parseFloat(value.toFixed(precision)); } ================================================ FILE: viser/src/viser/client/src/components/MultiSliderPrimitive/utils/get-position/get-position.ts ================================================ interface GetPosition { value: number; min: number; max: number; } export function getPosition({ value, min, max }: GetPosition) { const position = ((value - min) / (max - min)) * 100; return Math.min(Math.max(position, 0), 100); } ================================================ FILE: viser/src/viser/client/src/components/MultiSliderPrimitive/utils/get-precision/get-precision.ts ================================================ export function getPrecision(step: number) { if (!step) return 0; const split = step.toString().split("."); return split.length > 1 ? 
split[1].length : 0; } ================================================ FILE: viser/src/viser/client/src/components/NumberInput.tsx ================================================ import * as React from "react"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { GuiAddNumberMessage } from "../WebsocketMessages"; import { ViserInputComponent } from "./common"; import { NumberInput } from "@mantine/core"; export default function NumberInputComponent({ visible, id, label, hint, value, disabled, ...otherProps }: GuiAddNumberMessage) { const { setValue } = React.useContext(GuiComponentContext)!; const { precision, min, max, step } = otherProps; if (!visible) return <>; return ( { // Ignore empty values. newValue !== "" && setValue(id, newValue); }} styles={{ input: { minHeight: "1.625rem", height: "1.625rem", }, controls: { height: "1.625em", width: "0.825em", }, }} disabled={disabled} stepHoldDelay={500} stepHoldInterval={(t) => Math.max(1000 / t ** 2, 25)} /> ); } ================================================ FILE: viser/src/viser/client/src/components/PlotlyComponent.tsx ================================================ import React from "react"; import { GuiAddPlotlyMessage } from "../WebsocketMessages"; import { useDisclosure } from "@mantine/hooks"; import { Modal, Box, Paper, Tooltip } from "@mantine/core"; import { useElementSize } from "@mantine/hooks"; // When drawing border around the plot, it should be aligned with the folder's. import { folderWrapper } from "./Folder.css"; const PlotWithAspect = React.memo(function PlotWithAspect({ jsonStr, aspectRatio, staticPlot, }: { jsonStr: string; aspectRatio: number; staticPlot: boolean; }) { // Catch if the jsonStr is empty; if so, render an empty div. if (jsonStr === "") return
; // Parse json string, to construct plotly object. // Note that only the JSON string is kept as state, not the json object. const plotJson = JSON.parse(jsonStr); // This keeps the zoom-in state, etc, see https://plotly.com/javascript/uirevision/. plotJson.layout.uirevision = "true"; // Box size change -> width value change -> plot rerender trigger. const { ref, width } = useElementSize(); plotJson.layout.width = width; plotJson.layout.height = width * aspectRatio; // Make the plot non-interactable, if specified. // Ideally, we would use `staticplot`, but this has a known bug with 3D plots: // - https://github.com/plotly/plotly.js/issues/457 // In the meantime, we choose to disable all interactions. if (staticPlot) { if (plotJson.config === undefined) plotJson.config = {}; plotJson.config.displayModeBar = false; plotJson.layout.dragmode = false; plotJson.layout.hovermode = false; plotJson.layout.clickmode = "none"; } // Use React hooks to update the plotly object, when the plot data changes. // based on https://github.com/plotly/react-plotly.js/issues/242. const plotRef = React.useRef(null); React.useEffect(() => { // @ts-ignore - Plotly.js is dynamically imported with an eval() call. Plotly.react( plotRef.current!, plotJson.data, plotJson.layout, plotJson.config, ); }, [plotJson]); return (
{/* Add a div on top of the plot, to prevent interaction + cursor changes. */} {staticPlot ? (
) : null} ); }); export default function PlotlyComponent({ visible, plotly_json_str, aspect, }: GuiAddPlotlyMessage) { if (!visible) return <>; // Create a modal with the plot, and a button to open it. const [opened, { open, close }] = useDisclosure(false); return ( {/* Draw static plot in the controlpanel, which can be clicked. */} {/* Modal contents. keepMounted makes state changes (eg zoom) to the plot persistent. */} ); } ================================================ FILE: viser/src/viser/client/src/components/ProgressBar.tsx ================================================ import { Box, Progress } from "@mantine/core"; import { GuiAddProgressBarMessage } from "../WebsocketMessages"; export default function ProgressBarComponent({ visible, color, value, animated, }: GuiAddProgressBarMessage) { if (!visible) return <>; return ( ); } ================================================ FILE: viser/src/viser/client/src/components/Rgb.tsx ================================================ import * as React from "react"; import { ColorInput } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { rgbToHex, hexToRgb } from "./utils"; import { ViserInputComponent } from "./common"; import { GuiAddRgbMessage } from "../WebsocketMessages"; export default function RgbComponent({ id, label, hint, value, disabled, visible, }: GuiAddRgbMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( setValue(id, hexToRgb(v))} format="hex" // zIndex of dropdown should be >modal zIndex. // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. popoverProps={{ zIndex: 1000 }} styles={{ input: { height: "1.625rem", minHeight: "1.625rem" }, // icon: { transform: "scale(0.8)" }, }} /> ); } ================================================ FILE: viser/src/viser/client/src/components/Rgba.tsx ================================================ import * as React from "react"; import { ColorInput } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { rgbaToHex, hexToRgba } from "./utils"; import { ViserInputComponent } from "./common"; import { GuiAddRgbaMessage } from "../WebsocketMessages"; export default function RgbaComponent({ id, label, hint, value, disabled, visible, }: GuiAddRgbaMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( setValue(id, hexToRgba(v))} format="hexa" // zIndex of dropdown should be >modal zIndex. // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. 
popoverProps={{ zIndex: 1000 }} styles={{ input: { height: "1.625rem", minHeight: "1.625rem" }, }} /> ); } ================================================ FILE: viser/src/viser/client/src/components/Slider.tsx ================================================ import React from "react"; import { GuiAddSliderMessage } from "../WebsocketMessages"; import { Slider, Flex, NumberInput, useMantineColorScheme, } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { ViserInputComponent } from "./common"; import { sliderDefaultMarks } from "./ComponentStyles.css"; export default function SliderComponent({ id, label, hint, visible, disabled, value, ...otherProps }: GuiAddSliderMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; const updateValue = (value: number) => setValue(id, value); const { min, max, precision, step, marks } = otherProps; const colorScheme = useMantineColorScheme().colorScheme; const input = ( ({ thumb: { height: "0.75rem", width: "0.5rem", }, trackContainer: { zIndex: 3, position: "relative", }, markLabel: { transform: "translate(-50%, 0.03rem)", fontSize: "0.6rem", textAlign: "center", }, mark: { transform: "scale(1.95)", }, markFilled: { background: disabled ? colorScheme === "dark" ? theme.colors.dark[3] : theme.colors.gray[4] : theme.primaryColor, }, })} pt="0.3em" pb="0.2em" showLabelOnHover={false} min={min} max={max} step={step ?? undefined} precision={precision} value={value} onChange={updateValue} marks={ marks === null ? [ { value: min, label: `${parseInt(min.toFixed(6))}`, }, { value: max, label: `${parseInt(max.toFixed(6))}`, }, ] : marks } disabled={disabled} /> { // Ignore empty values. newValue !== "" && updateValue(Number(newValue)); }} size="xs" min={min} max={max} hideControls step={step ?? 
undefined} // precision={precision} style={{ width: "3rem" }} styles={{ input: { padding: "0.375em", letterSpacing: "-0.5px", minHeight: "1.875em", height: "1.875em", }, }} ml="xs" /> ); return ( {input} ); } ================================================ FILE: viser/src/viser/client/src/components/TabGroup.tsx ================================================ import * as React from "react"; import { GuiAddTabGroupMessage } from "../WebsocketMessages"; import { Tabs } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { htmlIconWrapper } from "./ComponentStyles.css"; export default function TabGroupComponent({ tab_labels, tab_icons_html, tab_container_ids, visible, }: GuiAddTabGroupMessage) { const { GuiContainer } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( {tab_labels.map((label, index) => ( ) } > {label} ))} {tab_container_ids.map((containerId, index) => ( ))} ); } ================================================ FILE: viser/src/viser/client/src/components/TextInput.tsx ================================================ import * as React from "react"; import { TextInput } from "@mantine/core"; import { ViserInputComponent } from "./common"; import { GuiAddTextMessage } from "../WebsocketMessages"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; export default function TextInputComponent(props: GuiAddTextMessage) { const { id, hint, label, value, disabled, visible } = props; const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( { setValue(id, value.target.value); }} styles={{ input: { minHeight: "1.625rem", height: "1.625rem", padding: "0 0.5em", }, }} disabled={disabled} /> ); } ================================================ FILE: viser/src/viser/client/src/components/UploadButton.tsx ================================================ import { GuiAddUploadButtonMessage } from "../WebsocketMessages"; import { v4 as uuid } from "uuid"; import { Box, Progress } from "@mantine/core"; import { Button } from "@mantine/core"; import React, { useContext } from "react"; import { ViewerContext, ViewerContextContents } from "../App"; import { IconCheck } from "@tabler/icons-react"; import { notifications } from "@mantine/notifications"; import { htmlIconWrapper } from "./ComponentStyles.css"; export default function UploadButtonComponent(conf: GuiAddUploadButtonMessage) { // Handle GUI input types. 
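// Uploading is chunked: useFileUpload (defined below) first sends a
// FileTransferStart message with the file's metadata and part count, then one
// FileTransferPart message per 512 KiB slice of the file, while the
// notification shown to the user tracks uploadedBytes / totalBytes from the
// GUI store's uploadsInProgress entry for this component.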
const viewer = useContext(ViewerContext)!; const fileUploadRef = React.useRef(null); const { isUploading, upload } = useFileUpload({ viewer, componentId: conf.id, }); const disabled = conf.disabled || isUploading; return ( { const input = e.target as HTMLInputElement; if (!input.files) return; upload(input.files[0]); }} /> ); } function useFileUpload({ viewer, componentId, }: { componentId: string; viewer: ViewerContextContents; }) { const updateUploadState = viewer.useGui((state) => state.updateUploadState); const uploadState = viewer.useGui( (state) => state.uploadsInProgress[componentId], ); const totalBytes = uploadState?.totalBytes; // Cache total bytes string const totalBytesString = React.useMemo(() => { if (totalBytes === undefined) return ""; let displaySize = totalBytes; const displayUnits = ["B", "K", "M", "G", "T", "P"]; let displayUnitIndex = 0; while (displaySize >= 100 && displayUnitIndex < displayUnits.length - 1) { displaySize /= 1024; displayUnitIndex += 1; } return `${displaySize.toFixed(1)}${displayUnits[displayUnitIndex]}`; }, [totalBytes]); // Update notification status React.useEffect(() => { if (uploadState === undefined) return; const { notificationId, filename } = uploadState; if (uploadState.uploadedBytes === 0) { // Show notification. notifications.show({ id: notificationId, title: "Uploading " + `${filename} (${totalBytesString})`, message: , autoClose: false, withCloseButton: false, loading: true, }); } else { // Update progress. const progressValue = uploadState.uploadedBytes / uploadState.totalBytes; const isDone = progressValue === 1.0; notifications.update({ id: notificationId, title: "Uploading " + `${filename} (${totalBytesString})`, message: !isDone ? ( ) : ( "File uploaded successfully." ), autoClose: isDone, withCloseButton: isDone, loading: !isDone, icon: isDone ? 
: undefined, }); } }, [uploadState, totalBytesString]); const isUploading = uploadState !== undefined && uploadState.uploadedBytes < uploadState.totalBytes; async function upload(file: File) { const chunkSize = 512 * 1024; // bytes const numChunks = Math.ceil(file.size / chunkSize); const transferUuid = uuid(); const notificationId = "upload-" + transferUuid; // Begin upload by setting initial state updateUploadState({ componentId: componentId, uploadedBytes: 0, totalBytes: file.size, filename: file.name, notificationId, }); viewer.sendMessageRef.current({ type: "FileTransferStart", source_component_id: componentId, transfer_uuid: transferUuid, filename: file.name, mime_type: file.type, size_bytes: file.size, part_count: numChunks, }); for (let i = 0; i < numChunks; i++) { const start = i * chunkSize; const end = (i + 1) * chunkSize; const chunk = file.slice(start, end); const buffer = await chunk.arrayBuffer(); viewer.sendMessageRef.current({ type: "FileTransferPart", source_component_id: componentId, transfer_uuid: transferUuid, part: i, content: new Uint8Array(buffer), }); } } return { isUploading, upload, }; } ================================================ FILE: viser/src/viser/client/src/components/Vector2.tsx ================================================ import * as React from "react"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { GuiAddVector2Message } from "../WebsocketMessages"; import { VectorInput, ViserInputComponent } from "./common"; export default function Vector2Component({ id, hint, label, visible, disabled, value, ...otherProps }: GuiAddVector2Message) { const { min, max, step, precision } = otherProps; const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( setValue(id, value)} min={min} max={max} step={step} precision={precision} disabled={disabled} /> ); } ================================================ FILE: viser/src/viser/client/src/components/Vector3.tsx ================================================ import * as React from "react"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { GuiAddVector3Message } from "../WebsocketMessages"; import { VectorInput, ViserInputComponent } from "./common"; export default function Vector3Component({ id, hint, label, visible, disabled, value, ...otherProps }: GuiAddVector3Message) { const { min, max, step, precision } = otherProps; const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( setValue(id, value)} min={min} max={max} step={step} precision={precision} disabled={disabled} /> ); } ================================================ FILE: viser/src/viser/client/src/components/common.tsx ================================================ import * as React from "react"; import { Box, Flex, Text, NumberInput, Tooltip } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; export function ViserInputComponent({ id, label, hint, children, }: { id: string; children: React.ReactNode; label?: string; hint?: string | null; }) { const { folderDepth } = React.useContext(GuiComponentContext)!; if (hint !== undefined && hint !== null) { children = // We need to add for inputs that we can't assign refs to. ( {children} ); } if (label !== undefined) children = ( ); return ( {children} ); } /** GUI input with a label horizontally placed to the left of it. 
*/ function LabeledInput(props: { id: string; label: string; input: React.ReactNode; folderDepth: number; }) { return ( {props.input} ); } export function VectorInput( props: | { id: string; n: 2; value: [number, number]; min: [number, number] | null; max: [number, number] | null; step: number; precision: number; onChange: (value: number[]) => void; disabled: boolean; } | { id: string; n: 3; value: [number, number, number]; min: [number, number, number] | null; max: [number, number, number] | null; step: number; precision: number; onChange: (value: number[]) => void; disabled: boolean; }, ) { return ( {[...Array(props.n).keys()].map((i) => ( { const updated = [...props.value]; updated[i] = v === "" ? 0.0 : Number(v); props.onChange(updated); }} size="xs" styles={{ root: { flexGrow: 1, width: 0 }, input: { paddingLeft: "0.5em", paddingRight: "1.75em", textAlign: "right", height: "1.875em", minHeight: "1.875em", }, controls: { height: "1.25em", width: "0.825em", }, }} rightSectionWidth="1em" decimalScale={props.precision} step={props.step} min={props.min === null ? undefined : props.min[i]} max={props.max === null ? undefined : props.max[i]} stepHoldDelay={500} stepHoldInterval={(t) => Math.max(1000 / t ** 2, 25)} disabled={props.disabled} /> ))} ); } ================================================ FILE: viser/src/viser/client/src/components/utils.tsx ================================================ // Color conversion helpers. export function rgbToHex([r, g, b]: [number, number, number]): string { const hexR = r.toString(16).padStart(2, "0"); const hexG = g.toString(16).padStart(2, "0"); const hexB = b.toString(16).padStart(2, "0"); return `#${hexR}${hexG}${hexB}`; } export function hexToRgb(hexColor: string): [number, number, number] { const hex = hexColor.slice(1); // Remove the # in #ffffff. const r = parseInt(hex.substring(0, 2), 16); const g = parseInt(hex.substring(2, 4), 16); const b = parseInt(hex.substring(4, 6), 16); return [r, g, b]; } export function rgbaToHex([r, g, b, a]: [ number, number, number, number, ]): string { const hexR = r.toString(16).padStart(2, "0"); const hexG = g.toString(16).padStart(2, "0"); const hexB = b.toString(16).padStart(2, "0"); const hexA = a.toString(16).padStart(2, "0"); return `#${hexR}${hexG}${hexB}${hexA}`; } export function hexToRgba(hexColor: string): [number, number, number, number] { const hex = hexColor.slice(1); // Remove the # in #ffffff. const r = parseInt(hex.substring(0, 2), 16); const g = parseInt(hex.substring(2, 4), 16); const b = parseInt(hex.substring(4, 6), 16); const a = parseInt(hex.substring(6, 8), 16); return [r, g, b, a]; } ================================================ FILE: viser/src/viser/client/src/index.css ================================================ @font-face { font-family: "Inter"; src: url("/Inter-VariableFont_slnt,wght.ttf") format("truetype"); font-weight: 1 100 200 300 400 500 600 700 800 900 1000; font-style: normal italic; } body, html { width: 100%; height: 100%; margin: 0; padding: 0; overflow: hidden; } body { font-family: "Inter", -apple-system, BlinkMacSystemFont, "Segoe UI", "Roboto", "Oxygen", "Ubuntu", "Cantarell", "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; } code { font-family: source-code-pro, Menlo, Monaco, Consolas, "Courier New", monospace; } #root { width: 100%; height: 100%; overflow: hidden; } /* Styling for threejs stats panel. 
Switches position: fixed to position: * absolute, to respect parent + for better multi-pane layout. */ .stats-panel { position: absolute !important; } /* Styling for color chips in markdown */ .gfm-color-chip { margin-left: 0.125rem; display: inline-block; height: 0.625rem; width: 0.625rem; border-radius: 9999px; border: 1px solid gray; transform: TranslateY(0.07em); } ================================================ FILE: viser/src/viser/client/src/index.tsx ================================================ import ReactDOM from "react-dom/client"; import { Root } from "./App"; import { enableMapSet } from "immer"; enableMapSet(); ReactDOM.createRoot(document.getElementById("root") as HTMLElement).render( , ); ================================================ FILE: viser/src/viser/client/src/react-app-env.d.ts ================================================ /// ================================================ FILE: viser/src/viser/client/tsconfig.json ================================================ { "compilerOptions": { "target": "ESNext", "lib": ["dom", "dom.iterable", "esnext"], "types": [ "vite/client", "vite-plugin-svgr/client", "node", "@types/wicg-file-system-access" ], "allowJs": true, "skipLibCheck": true, "esModuleInterop": true, "allowSyntheticDefaultImports": true, "strict": true, "forceConsistentCasingInFileNames": true, "noFallthroughCasesInSwitch": true, "module": "esnext", "moduleResolution": "node", "resolveJsonModule": true, "isolatedModules": true, "noEmit": true, "jsx": "react-jsx" }, "include": ["src"] } ================================================ FILE: viser/src/viser/client/vite-env.d.ts ================================================ /// ================================================ FILE: viser/src/viser/client/vite.config.mts ================================================ import { defineConfig } from "vite"; import react from "@vitejs/plugin-react"; import { vanillaExtractPlugin } from "@vanilla-extract/vite-plugin"; import viteTsconfigPaths from "vite-tsconfig-paths"; import svgrPlugin from "vite-plugin-svgr"; import eslint from "vite-plugin-eslint"; import browserslistToEsbuild from "browserslist-to-esbuild"; // https://vitejs.dev/config/ export default defineConfig({ plugins: [ react(), eslint({ failOnError: false, failOnWarning: false }), viteTsconfigPaths(), svgrPlugin(), vanillaExtractPlugin(), ], server: { port: 3000, hmr: { port: 1025 }, }, worker: { format: "es", }, build: { outDir: "build", target: browserslistToEsbuild(), }, }); ================================================ FILE: viser/src/viser/extras/__init__.py ================================================ """Extra utilities. 
Used for example scripts.""" from ._record3d import Record3dFrame as Record3dFrame from ._record3d import Record3dLoader as Record3dLoader from ._record3d_customized import Record3dLoader_Customized as Record3dLoader_Customized from ._record3d_customized_megasam import Record3dLoader_Customized_Megasam as Record3dLoader_Customized_Megasam from ._urdf import ViserUrdf as ViserUrdf ================================================ FILE: viser/src/viser/extras/_record3d.py ================================================ from __future__ import annotations import dataclasses import json from pathlib import Path from typing import Tuple, cast import imageio.v3 as iio import liblzfse import numpy as np import numpy as onp import numpy.typing as onpt import skimage.transform from scipy.spatial.transform import Rotation class Record3dLoader: """Helper for loading frames for Record3D captures.""" # NOTE(hangg): Consider moving this module into # `examples/7_record3d_visualizer.py` since it is usecase-specific. def __init__(self, data_dir: Path): metadata_path = data_dir / "metadata" # Read metadata. metadata = json.loads(metadata_path.read_text()) K: onp.ndarray = np.array(metadata["K"], np.float32).reshape(3, 3).T fps = metadata["fps"] T_world_cameras: onp.ndarray = np.array(metadata["poses"], np.float32) T_world_cameras = np.concatenate( [ Rotation.from_quat(T_world_cameras[:, :4]).as_matrix(), T_world_cameras[:, 4:, None], ], -1, ) T_world_cameras = (T_world_cameras @ np.diag([1, -1, -1, 1])).astype(np.float32) self.K = K self.fps = fps self.T_world_cameras = T_world_cameras rgbd_dir = data_dir / "rgbd" self.rgb_paths = sorted(rgbd_dir.glob("*.jpg"), key=lambda p: int(p.stem)) self.depth_paths = [ rgb_path.with_suffix(".depth") for rgb_path in self.rgb_paths ] self.conf_paths = [rgb_path.with_suffix(".conf") for rgb_path in self.rgb_paths] def num_frames(self) -> int: return len(self.rgb_paths) def get_frame(self, index: int) -> Record3dFrame: # Read conf. conf: onp.ndarray = np.frombuffer( liblzfse.decompress(self.conf_paths[index].read_bytes()), dtype=np.uint8 ) if conf.shape[0] == 640 * 480: conf = conf.reshape((640, 480)) # For a FaceID camera 3D Video elif conf.shape[0] == 256 * 192: conf = conf.reshape((256, 192)) # For a LiDAR 3D Video else: assert False, f"Unexpected conf shape {conf.shape}" # Read depth. depth: onp.ndarray = np.frombuffer( liblzfse.decompress(self.depth_paths[index].read_bytes()), dtype=np.float32 ).copy() if depth.shape[0] == 640 * 480: depth = depth.reshape((640, 480)) # For a FaceID camera 3D Video elif depth.shape[0] == 256 * 192: depth = depth.reshape((256, 192)) # For a LiDAR 3D Video else: assert False, f"Unexpected depth shape {depth.shape}" # Read RGB. 
rgb = iio.imread(self.rgb_paths[index]) return Record3dFrame( K=self.K, rgb=rgb, depth=depth, mask=conf == 2, T_world_camera=self.T_world_cameras[index], ) @dataclasses.dataclass class Record3dFrame: """A single frame from a Record3D capture.""" K: onpt.NDArray[onp.float32] rgb: onpt.NDArray[onp.uint8] depth: onpt.NDArray[onp.float32] mask: onpt.NDArray[onp.bool_] T_world_camera: onpt.NDArray[onp.float32] def get_point_cloud( self, downsample_factor: int = 1 ) -> Tuple[onpt.NDArray[onp.float32], onpt.NDArray[onp.uint8]]: rgb = self.rgb[::downsample_factor, ::downsample_factor] depth = skimage.transform.resize(self.depth, rgb.shape[:2], order=0) mask = cast( onpt.NDArray[onp.bool_], skimage.transform.resize(self.mask, rgb.shape[:2], order=0), ) assert depth.shape == rgb.shape[:2] K = self.K T_world_camera = self.T_world_camera img_wh = rgb.shape[:2][::-1] grid = ( np.stack(np.meshgrid(np.arange(img_wh[0]), np.arange(img_wh[1])), 2) + 0.5 ) grid = grid * downsample_factor homo_grid = np.pad(grid[mask], np.array([[0, 0], [0, 1]]), constant_values=1) local_dirs = np.einsum("ij,bj->bi", np.linalg.inv(K), homo_grid) dirs = np.einsum("ij,bj->bi", T_world_camera[:3, :3], local_dirs) points = (T_world_camera[:, -1] + dirs * depth[mask, None]).astype(np.float32) point_colors = rgb[mask] return points, point_colors ================================================ FILE: viser/src/viser/extras/_record3d_customized.py ================================================ from __future__ import annotations import dataclasses import os import json from pathlib import Path from typing import Tuple, cast import imageio.v3 as iio import liblzfse import numpy as np import numpy as onp import numpy.typing as onpt import skimage.transform from scipy.spatial.transform import Rotation from scipy.spatial import cKDTree class Record3dLoader_Customized: """Helper for loading frames for Record3D captures.""" def __init__(self, data_dir: Path, conf_threshold: float = 1.0, foreground_conf_threshold: float = 0.1, no_mask: bool = False, xyzw=True, init_conf=False): # Read metadata. intrinsics_path = data_dir / "pred_intrinsics.txt" intrinsics = np.loadtxt(intrinsics_path) self.K: onp.ndarray = np.array(intrinsics, np.float32).reshape(-1, 3, 3) fps = 30 self.init_conf = init_conf poses_path = data_dir / "pred_traj.txt" poses = np.loadtxt(poses_path) self.T_world_cameras: onp.ndarray = np.array(poses, np.float32) self.T_world_cameras = np.concatenate( [ # Convert TUM pose to SE3 pose Rotation.from_quat(self.T_world_cameras[:, 4:]).as_matrix() if not xyzw else Rotation.from_quat(np.concatenate([self.T_world_cameras[:, 5:], self.T_world_cameras[:, 4:5]], -1)).as_matrix(), self.T_world_cameras[:, 1:4, None], ], -1, ) self.T_world_cameras = self.T_world_cameras.astype(np.float32) # Convert to homogeneous transformation matrices (ensure shape is (N, 4, 4)) num_frames = self.T_world_cameras.shape[0] ones = np.tile(np.array([0, 0, 0, 1], dtype=np.float32), (num_frames, 1, 1)) self.T_world_cameras = np.concatenate([self.T_world_cameras, ones], axis=1) self.fps = fps self.conf_threshold = conf_threshold self.foreground_conf_threshold = foreground_conf_threshold self.no_mask = no_mask # Read frames. 
self.rgb_paths = sorted(data_dir.glob("frame_*.png"), key=lambda p: int(p.stem.split("_")[-1])) self.depth_paths = sorted(data_dir.glob("frame_*.npy"), key=lambda p: int(p.stem.split("_")[-1])) if init_conf: self.init_conf_paths = sorted(data_dir.glob("init_conf_*.npy"), key=lambda p: int(p.stem.split("_")[-1])) else: self.init_conf_paths = [] self.conf_paths = sorted(data_dir.glob("conf_*.npy"), key=lambda p: int(p.stem.split("_")[-1])) self.mask_paths = sorted(data_dir.glob("enlarged_dynamic_mask_*.png"), key=lambda p: int(p.stem.split("_")[-1])) # Remove the last frame since it does not have a ground truth dynamic mask self.rgb_paths = self.rgb_paths[:-1] # Align all camera poses by the first frame T0 = self.T_world_cameras[len(self.T_world_cameras) // 2] # First camera pose (4x4 matrix) T0_inv = np.linalg.inv(T0) # Inverse of the first camera pose # Apply T0_inv to all camera poses self.T_world_cameras = np.matmul(T0_inv[np.newaxis, :, :], self.T_world_cameras) def num_frames(self) -> int: return len(self.rgb_paths) def get_frame(self, index: int) -> Record3dFrame: # Read depth. depth = np.load(self.depth_paths[index]) depth: onp.NDArray[onp.float32] = depth # Check if conf file exists, otherwise initialize with ones if len(self.conf_paths) == 0: conf = np.ones_like(depth, dtype=onp.float32) else: conf_path = self.conf_paths[index] if os.path.exists(conf_path): conf = np.load(conf_path) conf: onpt.NDArray[onp.float32] = conf # Clip confidence to avoid negative values conf = np.clip(conf, 0.0001, 99999) else: conf = np.ones_like(depth, dtype=onp.float32) # Check if init conf file exists, otherwise initialize with ones if len(self.init_conf_paths) == 0: # If init conf is not available, use conf init_conf = conf else: init_conf_path = self.init_conf_paths[index] if os.path.exists(init_conf_path): init_conf = np.load(init_conf_path) init_conf: onpt.NDArray[onp.float32] = init_conf # Clip confidence to avoid negative values init_conf = np.clip(init_conf, 0.0001, 99999) else: init_conf = np.ones_like(depth, dtype=onp.float32) # Check if mask file exists, otherwise initialize with zeros if len(self.mask_paths) == 0: mask = np.ones_like(depth, dtype=onp.bool_) else: mask_path = self.mask_paths[index] if os.path.exists(mask_path): mask = iio.imread(mask_path) > 0 mask: onpt.NDArray[onp.bool_] = mask else: mask = np.ones_like(depth, dtype=onp.bool_) if self.no_mask: mask = np.ones_like(mask).astype(np.bool_) # Read RGB. 
rgb = iio.imread(self.rgb_paths[index]) # if 4 channels, remove the alpha channel if rgb.shape[-1] == 4: rgb = rgb[..., :3] return Record3dFrame( K=self.K[index], rgb=rgb, depth=depth, mask=mask, conf=conf, init_conf=init_conf, T_world_camera=self.T_world_cameras[index], conf_threshold=self.conf_threshold, foreground_conf_threshold=self.foreground_conf_threshold, ) @dataclasses.dataclass class Record3dFrame: """A single frame from a Record3D capture.""" K: onpt.NDArray[onp.float32] rgb: onpt.NDArray[onp.uint8] depth: onpt.NDArray[onp.float32] mask: onpt.NDArray[onp.bool_] conf: onpt.NDArray[onp.float32] init_conf: onpt.NDArray[onp.float32] T_world_camera: onpt.NDArray[onp.float32] conf_threshold: float = 1.0 foreground_conf_threshold: float = 0.1 def get_point_cloud( self, downsample_factor: int = 1, bg_downsample_factor: int = 1, ) -> Tuple[onpt.NDArray[onp.float32], onpt.NDArray[onp.uint8], onpt.NDArray[onp.float32], onpt.NDArray[onp.uint8]]: rgb = self.rgb[::downsample_factor, ::downsample_factor] depth = skimage.transform.resize(self.depth, rgb.shape[:2], order=0) mask = cast( onpt.NDArray[onp.bool_], skimage.transform.resize(self.mask, rgb.shape[:2], order=0), ) assert depth.shape == rgb.shape[:2] K = self.K T_world_camera = self.T_world_camera img_wh = rgb.shape[:2][::-1] grid = ( np.stack(np.meshgrid(np.arange(img_wh[0]), np.arange(img_wh[1])), 2) + 0.5 ) grid = grid * downsample_factor conf_mask = self.conf > self.conf_threshold if self.init_conf is not None: fg_conf_mask = self.init_conf > self.foreground_conf_threshold else: fg_conf_mask = self.conf > self.foreground_conf_threshold # reshape the conf mask to the shape of the depth conf_mask = skimage.transform.resize(conf_mask, depth.shape, order=0) fg_conf_mask = skimage.transform.resize(fg_conf_mask, depth.shape, order=0) # Foreground points homo_grid = np.pad(grid[fg_conf_mask & mask], ((0, 0), (0, 1)), constant_values=1) local_dirs = np.einsum("ij,bj->bi", np.linalg.inv(K), homo_grid) dirs = np.einsum("ij,bj->bi", T_world_camera[:3, :3], local_dirs) points = (T_world_camera[:3, 3] + dirs * depth[fg_conf_mask & mask, None]).astype(np.float32) point_colors = rgb[fg_conf_mask & mask] # Background points bg_homo_grid = np.pad(grid[conf_mask & ~mask], ((0, 0), (0, 1)), constant_values=1) bg_local_dirs = np.einsum("ij,bj->bi", np.linalg.inv(K), bg_homo_grid) bg_dirs = np.einsum("ij,bj->bi", T_world_camera[:3, :3], bg_local_dirs) bg_points = (T_world_camera[:3, 3] + bg_dirs * depth[conf_mask & ~mask, None]).astype(np.float32) bg_point_colors = rgb[conf_mask & ~mask] if bg_downsample_factor > 1 and bg_points.shape[0] > 0: indices = np.random.choice( bg_points.shape[0], size=bg_points.shape[0] // bg_downsample_factor, replace=False ) bg_points = bg_points[indices] bg_point_colors = bg_point_colors[indices] return points, point_colors, bg_points, bg_point_colors ================================================ FILE: viser/src/viser/extras/_record3d_customized_megasam.py ================================================ from __future__ import annotations import dataclasses import os import json from pathlib import Path from typing import Tuple, cast import imageio.v3 as iio import liblzfse import numpy as np import numpy as onp import numpy.typing as onpt import skimage.transform from scipy.spatial.transform import Rotation from scipy.spatial import cKDTree class Record3dLoader_Customized_Megasam: """Helper for loading frames for Record3D captures directly from a NPZ file.""" def __init__(self, npz_data: dict, conf_threshold: float = 
1.0, foreground_conf_threshold: float = 0.1, no_mask: bool = False, xyzw=True, init_conf=False): # Assuming npz_data is a dictionary containing all the necessary arrays from the NPZ file self.K = np.expand_dims(npz_data['intrinsic'], 0) # (3,3) -> (1,3,3) Intrinsic matrix self.K = np.repeat(self.K, npz_data['images'].shape[0], axis=0) # (1,3,3) -> (N,3,3) self.T_world_cameras = npz_data['cam_c2w'] # (N,4,4) Camera poses (extrinsics) self.fps = 30 # Assuming a frame rate of 30 self.conf_threshold = conf_threshold self.foreground_conf_threshold = foreground_conf_threshold self.no_mask = no_mask # Initialize the other parameters self.init_conf = init_conf # Read frames from the NPZ file self.images = npz_data['images'] # (N,H,W,3) RGB images self.depths = npz_data['depths'] # (N,H,W) Depth maps self.confidences = npz_data.get('conf', []) self.init_conf_data = npz_data.get('init_conf', []) self.masks = npz_data.get('enlarged_dynamic_mask', []) # Align all camera poses by the first frame T0 = self.T_world_cameras[len(self.T_world_cameras) // 2] # First camera pose (4x4 matrix) T0_inv = np.linalg.inv(T0) # Inverse of the first camera pose # Apply T0_inv to all camera poses self.T_world_cameras = np.matmul(T0_inv[np.newaxis, :, :], self.T_world_cameras) def num_frames(self) -> int: return len(self.images) def get_frame(self, index: int) -> Record3dFrame: # Read the depth for the given frame depth = self.depths[index] depth = depth.astype(np.float32) # Check if conf file exists, otherwise initialize with ones if len(self.confidences) == 0: conf = np.ones_like(depth, dtype=np.float32) else: conf = self.confidences[index] conf = np.clip(conf, 0.0001, 99999) # Check if init conf file exists, otherwise initialize with ones if len(self.init_conf_data) == 0: init_conf = conf else: init_conf = self.init_conf_data[index] init_conf = np.clip(init_conf, 0.0001, 99999) # Check if mask exists, otherwise initialize with zeros if len(self.masks) == 0: mask = np.ones_like(depth, dtype=bool) else: mask = self.masks[index] > 0 # Assuming mask is a binary image if self.no_mask: mask = np.ones_like(mask).astype(np.bool_) # Read RGB image rgb = self.images[index] if rgb.shape[-1] == 4: rgb = rgb[..., :3] return Record3dFrame( K=self.K[index], rgb=rgb, depth=depth, mask=mask, conf=conf, init_conf=init_conf, T_world_camera=self.T_world_cameras[index], conf_threshold=self.conf_threshold, foreground_conf_threshold=self.foreground_conf_threshold, ) @dataclasses.dataclass class Record3dFrame: """A single frame from a Record3D capture.""" K: onpt.NDArray[onp.float32] rgb: onpt.NDArray[onp.uint8] depth: onpt.NDArray[onp.float32] mask: onpt.NDArray[onp.bool_] conf: onpt.NDArray[onp.float32] init_conf: onpt.NDArray[onp.float32] T_world_camera: onpt.NDArray[onp.float32] conf_threshold: float = 1.0 foreground_conf_threshold: float = 0.1 def get_point_cloud( self, downsample_factor: int = 1, bg_downsample_factor: int = 1, ) -> Tuple[onpt.NDArray[onp.float32], onpt.NDArray[onp.uint8], onpt.NDArray[onp.float32], onpt.NDArray[onp.uint8]]: rgb = self.rgb[::downsample_factor, ::downsample_factor] depth = skimage.transform.resize(self.depth, rgb.shape[:2], order=0) mask = cast( onpt.NDArray[onp.bool_], skimage.transform.resize(self.mask, rgb.shape[:2], order=0), ) assert depth.shape == rgb.shape[:2] K = self.K T_world_camera = self.T_world_camera img_wh = rgb.shape[:2][::-1] grid = ( np.stack(np.meshgrid(np.arange(img_wh[0]), np.arange(img_wh[1])), 2) + 0.5 ) grid = grid * downsample_factor conf_mask = self.conf > self.conf_threshold 
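# NOTE: the point-cloud construction below implements standard pinhole
# unprojection: a pixel (u, v) with metric depth d maps to the world point
# t + d * R @ inv(K) @ [u, v, 1], where [R | t] is the camera-to-world pose
# (T_world_camera). A minimal single-pixel sketch of the same math, with
# illustrative values:
#
#     import numpy as np
#     K = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
#     R, t = np.eye(3), np.zeros(3)      # camera-to-world rotation / translation
#     u, v, d = 100.5, 200.5, 2.0        # pixel center and its depth
#     ray = np.linalg.inv(K) @ np.array([u, v, 1.0])   # ray direction in camera frame
#     point_world = t + d * (R @ ray)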
if self.init_conf is not None: fg_conf_mask = self.init_conf > self.foreground_conf_threshold else: fg_conf_mask = self.conf > self.foreground_conf_threshold # reshape the conf mask to the shape of the depth conf_mask = skimage.transform.resize(conf_mask, depth.shape, order=0) fg_conf_mask = skimage.transform.resize(fg_conf_mask, depth.shape, order=0) # Foreground points homo_grid = np.pad(grid[fg_conf_mask & mask], ((0, 0), (0, 1)), constant_values=1) local_dirs = np.einsum("ij,bj->bi", np.linalg.inv(K), homo_grid) dirs = np.einsum("ij,bj->bi", T_world_camera[:3, :3], local_dirs) points = (T_world_camera[:3, 3] + dirs * depth[fg_conf_mask & mask, None]).astype(np.float32) point_colors = rgb[fg_conf_mask & mask] # Background points bg_homo_grid = np.pad(grid[conf_mask & ~mask], ((0, 0), (0, 1)), constant_values=1) bg_local_dirs = np.einsum("ij,bj->bi", np.linalg.inv(K), bg_homo_grid) bg_dirs = np.einsum("ij,bj->bi", T_world_camera[:3, :3], bg_local_dirs) bg_points = (T_world_camera[:3, 3] + bg_dirs * depth[conf_mask & ~mask, None]).astype(np.float32) bg_point_colors = rgb[conf_mask & ~mask] if bg_downsample_factor > 1 and bg_points.shape[0] > 0: indices = np.random.choice( bg_points.shape[0], size=bg_points.shape[0] // bg_downsample_factor, replace=False ) bg_points = bg_points[indices] bg_point_colors = bg_point_colors[indices] return points, point_colors, bg_points, bg_point_colors ================================================ FILE: viser/src/viser/extras/_urdf.py ================================================ from __future__ import annotations from functools import partial from pathlib import Path from typing import List, Tuple import numpy as onp import trimesh import yourdfpy import viser from .. import transforms as tf class ViserUrdf: """Helper for rendering URDFs in Viser. Args: target: ViserServer or ClientHandle object to add URDF to. urdf_or_path: Either a path to a URDF file or a yourdfpy URDF object. scale: Scale factor to apply to resize the URDF. root_node_name: Viser scene tree name for the root of the URDF geometry. mesh_color_override: Optional color to override the URDF's mesh colors. """ def __init__( self, target: viser.ViserServer | viser.ClientHandle, urdf_or_path: yourdfpy.URDF | Path, scale: float = 1.0, root_node_name: str = "/", mesh_color_override: tuple[float, float, float] | None = None, ) -> None: assert root_node_name.startswith("/") assert len(root_node_name) == 1 or not root_node_name.endswith("/") if isinstance(urdf_or_path, Path): urdf = yourdfpy.URDF.load( urdf_or_path, filename_handler=partial( yourdfpy.filename_handler_magic, dir=urdf_or_path.parent ), ) else: urdf = urdf_or_path assert isinstance(urdf, yourdfpy.URDF) self._target = target self._urdf = urdf self._scale = scale self._root_node_name = root_node_name # Add coordinate frame for each joint. self._joint_frames: List[viser.SceneNodeHandle] = [] for joint in self._urdf.joint_map.values(): assert isinstance(joint, yourdfpy.Joint) self._joint_frames.append( self._target.scene.add_frame( _viser_name_from_frame( self._urdf, joint.child, self._root_node_name ), show_axes=False, ) ) # Add the URDF's meshes/geometry to viser. self._meshes: List[viser.SceneNodeHandle] = [] for link_name, mesh in urdf.scene.geometry.items(): assert isinstance(mesh, trimesh.Trimesh) T_parent_child = urdf.get_transform( link_name, urdf.scene.graph.transforms.parents[link_name] ) name = _viser_name_from_frame(urdf, link_name, root_node_name) # Scale + transform the mesh. (these will mutate it!) 
# # It's important that we use apply_transform() instead of unpacking # the rotation/translation terms, since the scene graph transform # can also contain scale and reflection terms. mesh = mesh.copy() mesh.apply_scale(self._scale) mesh.apply_transform(T_parent_child) if mesh_color_override is None: self._meshes.append(target.scene.add_mesh_trimesh(name, mesh)) else: self._meshes.append( target.scene.add_mesh_simple( name, mesh.vertices, mesh.faces, color=mesh_color_override, ) ) def remove(self) -> None: """Remove URDF from scene.""" # Some of this will be redundant, since children are removed when # parents are removed. for frame in self._joint_frames: frame.remove() for mesh in self._meshes: mesh.remove() def update_cfg(self, configuration: onp.ndarray) -> None: """Update the joint angles of the visualized URDF.""" self._urdf.update_cfg(configuration) with self._target.atomic(): for joint, frame_handle in zip( self._urdf.joint_map.values(), self._joint_frames ): assert isinstance(joint, yourdfpy.Joint) T_parent_child = self._urdf.get_transform(joint.child, joint.parent) frame_handle.wxyz = tf.SO3.from_matrix(T_parent_child[:3, :3]).wxyz frame_handle.position = T_parent_child[:3, 3] * self._scale def get_actuated_joint_limits( self, ) -> dict[str, tuple[float | None, float | None]]: """Returns an ordered mapping from actuated joint names to position limits.""" out: dict[str, tuple[float | None, float | None]] = {} for joint_name, joint in zip( self._urdf.actuated_joint_names, self._urdf.actuated_joints ): assert isinstance(joint_name, str) assert isinstance(joint, yourdfpy.Joint) if joint.limit is None: out[joint_name] = (-onp.pi, onp.pi) else: out[joint_name] = (joint.limit.lower, joint.limit.upper) return out def get_actuated_joint_names(self) -> Tuple[str, ...]: """Returns a tuple of actuated joint names, in order.""" return tuple(self._urdf.actuated_joint_names) def _viser_name_from_frame( urdf: yourdfpy.URDF, frame_name: str, root_node_name: str = "/", ) -> str: """Given the (unique) name of a frame in our URDF's kinematic tree, return a scene node name for viser. For a robot manipulator with four frames, that looks like: ((shoulder)) == ((elbow)) / / |X| / / ((wrist)) ____/ /____ |X| [ ] [=======] [ base_link ] [] [] [___________] this would map a name like "elbow" to "base_link/shoulder/elbow". """ assert root_node_name.startswith("/") assert len(root_node_name) == 1 or not root_node_name.endswith("/") frames = [] while frame_name != urdf.scene.graph.base_frame: frames.append(frame_name) frame_name = urdf.scene.graph.transforms.parents[frame_name] if root_node_name != "/": frames.append(root_node_name) return "/".join(frames[::-1]) ================================================ FILE: viser/src/viser/extras/colmap/__init__.py ================================================ """Colmap utilities.""" from ._colmap_utils import read_cameras_binary as read_cameras_binary from ._colmap_utils import read_cameras_text as read_cameras_text from ._colmap_utils import read_images_binary as read_images_binary from ._colmap_utils import read_images_text as read_images_text from ._colmap_utils import read_points3d_binary as read_points3d_binary from ._colmap_utils import read_points3D_text as read_points3D_text ================================================ FILE: viser/src/viser/extras/colmap/_colmap_utils.py ================================================ # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. # All rights reserved. 
# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of # its contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. # # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch) # # Modified 5/3/2023 to add type annotations. import struct from dataclasses import dataclass from pathlib import Path from typing import Dict, Union import numpy as np @dataclass(frozen=True) class CameraModel: model_id: int model_name: str num_params: int @dataclass(frozen=True) class Camera: id: int model: str width: int height: int params: np.ndarray @dataclass(frozen=True) class BaseImage: id: int qvec: np.ndarray tvec: np.ndarray camera_id: int name: str xys: np.ndarray point3D_ids: np.ndarray @dataclass(frozen=True) class Point3D: id: int xyz: np.ndarray rgb: np.ndarray error: Union[float, np.ndarray] image_ids: np.ndarray point2D_idxs: np.ndarray class Image(BaseImage): def qvec2rotmat(self): return qvec2rotmat(self.qvec) CAMERA_MODELS = { CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), CameraModel(model_id=1, model_name="PINHOLE", num_params=4), CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), CameraModel(model_id=3, model_name="RADIAL", num_params=5), CameraModel(model_id=4, model_name="OPENCV", num_params=8), CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), CameraModel(model_id=7, model_name="FOV", num_params=5), CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), } CAMERA_MODEL_IDS = dict( [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] ) def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): """Read and unpack the next bytes from a binary file. :param fid: :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. :param endian_character: Any of {@, =, <, >, !} :return: Tuple of read and unpacked values. 
""" data = fid.read(num_bytes) return struct.unpack(endian_character + format_char_sequence, data) def read_cameras_text(path: Union[str, Path]) -> Dict[int, Camera]: """ see: src/base/reconstruction.cc void Reconstruction::WriteCamerasText(const std::string& path) void Reconstruction::ReadCamerasText(const std::string& path) """ cameras = {} with open(path, "r") as fid: while True: line = fid.readline() if not line: break line = line.strip() if len(line) > 0 and line[0] != "#": elems = line.split() camera_id = int(elems[0]) model = elems[1] width = int(elems[2]) height = int(elems[3]) params = np.array(tuple(map(float, elems[4:]))) cameras[camera_id] = Camera( id=camera_id, model=model, width=width, height=height, params=params ) return cameras def read_cameras_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Camera]: """ see: src/base/reconstruction.cc void Reconstruction::WriteCamerasBinary(const std::string& path) void Reconstruction::ReadCamerasBinary(const std::string& path) """ cameras = {} with open(path_to_model_file, "rb") as fid: num_cameras = read_next_bytes(fid, 8, "Q")[0] for camera_line_index in range(num_cameras): camera_properties = read_next_bytes( fid, num_bytes=24, format_char_sequence="iiQQ" ) camera_id = camera_properties[0] model_id = camera_properties[1] model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name width = camera_properties[2] height = camera_properties[3] num_params = CAMERA_MODEL_IDS[model_id].num_params params = read_next_bytes( fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params ) cameras[camera_id] = Camera( id=camera_id, model=model_name, width=width, height=height, params=np.array(params), ) assert len(cameras) == num_cameras return cameras def read_images_text(path: Union[str, Path]) -> Dict[int, Image]: """ see: src/base/reconstruction.cc void Reconstruction::ReadImagesText(const std::string& path) void Reconstruction::WriteImagesText(const std::string& path) """ images = {} with open(path, "r") as fid: while True: line = fid.readline() if not line: break line = line.strip() if len(line) > 0 and line[0] != "#": elems = line.split() image_id = int(elems[0]) qvec = np.array(tuple(map(float, elems[1:5]))) tvec = np.array(tuple(map(float, elems[5:8]))) camera_id = int(elems[8]) image_name = elems[9] elems = fid.readline().split() xys = np.column_stack( [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))] ) point3D_ids = np.array(tuple(map(int, elems[2::3]))) images[image_id] = Image( id=image_id, qvec=qvec, tvec=tvec, camera_id=camera_id, name=image_name, xys=xys, point3D_ids=point3D_ids, ) return images def read_images_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Image]: """ see: src/base/reconstruction.cc void Reconstruction::ReadImagesBinary(const std::string& path) void Reconstruction::WriteImagesBinary(const std::string& path) """ images = {} with open(path_to_model_file, "rb") as fid: num_reg_images = read_next_bytes(fid, 8, "Q")[0] for image_index in range(num_reg_images): binary_image_properties = read_next_bytes( fid, num_bytes=64, format_char_sequence="idddddddi" ) image_id = binary_image_properties[0] qvec = np.array(binary_image_properties[1:5]) tvec = np.array(binary_image_properties[5:8]) camera_id = binary_image_properties[8] image_name = "" current_char = read_next_bytes(fid, 1, "c")[0] while current_char != b"\x00": # look for the ASCII 0 entry image_name += current_char.decode("utf-8") current_char = read_next_bytes(fid, 1, "c")[0] num_points2D = read_next_bytes(fid, 
num_bytes=8, format_char_sequence="Q")[ 0 ] x_y_id_s = read_next_bytes( fid, num_bytes=24 * num_points2D, format_char_sequence="ddq" * num_points2D, ) xys = np.column_stack( [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] ) point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) images[image_id] = Image( id=image_id, qvec=qvec, tvec=tvec, camera_id=camera_id, name=image_name, xys=xys, point3D_ids=point3D_ids, ) return images def read_points3D_text(path: Union[str, Path]): """ see: src/base/reconstruction.cc void Reconstruction::ReadPoints3DText(const std::string& path) void Reconstruction::WritePoints3DText(const std::string& path) """ points3D = {} with open(path, "r") as fid: while True: line = fid.readline() if not line: break line = line.strip() if len(line) > 0 and line[0] != "#": elems = line.split() point3D_id = int(elems[0]) xyz = np.array(tuple(map(float, elems[1:4]))) rgb = np.array(tuple(map(int, elems[4:7]))) error = float(elems[7]) image_ids = np.array(tuple(map(int, elems[8::2]))) point2D_idxs = np.array(tuple(map(int, elems[9::2]))) points3D[point3D_id] = Point3D( id=point3D_id, xyz=xyz, rgb=rgb, error=error, image_ids=image_ids, point2D_idxs=point2D_idxs, ) return points3D def read_points3d_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Point3D]: """ see: src/base/reconstruction.cc void Reconstruction::ReadPoints3DBinary(const std::string& path) void Reconstruction::WritePoints3DBinary(const std::string& path) """ points3D = {} with open(path_to_model_file, "rb") as fid: num_points = read_next_bytes(fid, 8, "Q")[0] for point_line_index in range(num_points): binary_point_line_properties = read_next_bytes( fid, num_bytes=43, format_char_sequence="QdddBBBd" ) point3D_id = binary_point_line_properties[0] xyz = np.array(binary_point_line_properties[1:4]) rgb = np.array(binary_point_line_properties[4:7]) error = np.array(binary_point_line_properties[7]) track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 0 ] track_elems = read_next_bytes( fid, num_bytes=8 * track_length, format_char_sequence="ii" * track_length, ) image_ids = np.array(tuple(map(int, track_elems[0::2]))) point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) points3D[point3D_id] = Point3D( id=point3D_id, xyz=xyz, rgb=rgb, error=error, image_ids=image_ids, point2D_idxs=point2D_idxs, ) return points3D def qvec2rotmat(qvec): return np.array( [ [ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], ], [ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], ], [ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, ], ] ) ================================================ FILE: viser/src/viser/infra/__init__.py ================================================ """:mod:`viser.infra` provides WebSocket-based communication infrastructure. We implement abstractions for: - Launching a WebSocket+HTTP server on a shared port. - Registering callbacks for connection events and incoming messages. - Asynchronous message sending, both broadcasted and to individual clients. - Defining dataclass-based message types. - Translating Python message types to TypeScript interfaces. These are what `viser` runs on under-the-hood, and generally won't be useful unless you're building a web-based application from scratch. 
""" from ._infra import ClientId as ClientId from ._infra import WebsockClientConnection as WebsockClientConnection from ._infra import WebsockMessageHandler as WebsockMessageHandler from ._infra import WebsockServer as WebsockServer from ._messages import Message as Message from ._typescript_interface_gen import ( TypeScriptAnnotationOverride as TypeScriptAnnotationOverride, ) from ._typescript_interface_gen import ( generate_typescript_interfaces as generate_typescript_interfaces, ) ================================================ FILE: viser/src/viser/infra/_async_message_buffer.py ================================================ from __future__ import annotations import asyncio import dataclasses import threading from asyncio.events import AbstractEventLoop from typing import AsyncGenerator, Dict, List, Sequence from ._messages import Message @dataclasses.dataclass class AsyncMessageBuffer: """Async iterable for keeping a persistent buffer of messages. Uses heuristics on message names to automatically cull out redundant messages.""" event_loop: AbstractEventLoop persistent_messages: bool message_event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event) flush_event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event) message_counter: int = 0 message_from_id: Dict[int, Message] = dataclasses.field(default_factory=dict) id_from_redundancy_key: Dict[str, int] = dataclasses.field(default_factory=dict) buffer_lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) """Lock to prevent race conditions when pushing messages from different threads.""" max_window_size: int = 128 window_duration_sec: float = 1.0 / 60.0 done: bool = False def push(self, message: Message) -> None: """Push a new message to our buffer, and remove old redundant ones.""" assert isinstance(message, Message) # Add message to buffer. redundancy_key = message.redundancy_key() with self.buffer_lock: new_message_id = self.message_counter self.message_from_id[new_message_id] = message self.message_counter += 1 # If an existing message with the same key already exists in our buffer, we # don't need the old one anymore. :-) if ( redundancy_key is not None and redundancy_key in self.id_from_redundancy_key ): old_message_id = self.id_from_redundancy_key.pop(redundancy_key) self.message_from_id.pop(old_message_id) self.id_from_redundancy_key[redundancy_key] = new_message_id # Pulse message event to notify consumers that a new message is available. self.event_loop.call_soon_threadsafe(self.message_event.set) def flush(self) -> None: """Flush the message buffer; signals to yield a message window immediately.""" self.event_loop.call_soon_threadsafe(self.flush_event.set) def set_done(self) -> None: """Set the done flag. Kills the generator.""" self.done = True # Pulse message event to make sure we aren't waiting for a new message. self.event_loop.call_soon_threadsafe(self.message_event.set) # Pulse flush event to skip any windowing delay. self.event_loop.call_soon_threadsafe(self.flush_event.set) async def window_generator( self, client_id: int ) -> AsyncGenerator[Sequence[Message], None]: """Async iterator over messages. 
Loops infinitely, and waits when no messages are available.""" last_sent_id = -1 flush_wait = asyncio.create_task(self.flush_event.wait()) while not self.done: window: List[Message] = [] most_recent_message_id = self.message_counter - 1 while ( last_sent_id < most_recent_message_id and len(window) < self.max_window_size ): last_sent_id += 1 if self.persistent_messages: message = self.message_from_id.get(last_sent_id, None) else: # If we're not persisting messages, remove them from the buffer. with self.buffer_lock: message = self.message_from_id.pop(last_sent_id, None) if message is not None: redundancy_key = message.redundancy_key() self.id_from_redundancy_key.pop(redundancy_key, None) if message is not None and message.excluded_self_client != client_id: window.append(message) if len(window) > 0: # Yield a window! yield window else: # Wait for a new message to come in. await self.message_event.wait() self.message_event.clear() # Add a delay if either (a) we failed to yield or (b) there's currently no messages to send. most_recent_message_id = self.message_counter - 1 if len(window) == 0 or most_recent_message_id == last_sent_id: done, pending = await asyncio.wait( [flush_wait], timeout=self.window_duration_sec ) del pending if flush_wait in done and not self.done: self.flush_event.clear() flush_wait = asyncio.create_task(self.flush_event.wait()) ================================================ FILE: viser/src/viser/infra/_infra.py ================================================ from __future__ import annotations import abc import asyncio import contextlib import dataclasses import gzip import http import mimetypes import queue import threading from asyncio.events import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Callable, Generator, NewType, TypeVar import msgspec import rich import websockets.connection import websockets.datastructures import websockets.exceptions import websockets.server from typing_extensions import Literal, assert_never, override from websockets.legacy.server import WebSocketServerProtocol from ._async_message_buffer import AsyncMessageBuffer from ._messages import Message @dataclasses.dataclass class _ClientHandleState: # Internal state for ClientConnection objects. # message_buffer: asyncio.Queue message_buffer: AsyncMessageBuffer event_loop: AbstractEventLoop ClientId = NewType("ClientId", int) TMessage = TypeVar("TMessage", bound=Message) class RecordHandle: """**Experimental.** Handle for recording outgoing messages. Useful for logging + debugging.""" def __init__( self, handler: WebsockMessageHandler, filter: Callable[[Message], bool] ): self._handler = handler self._filter = filter self._loop_start_index: int | None = None self._time: float = 0.0 self._messages: list[tuple[float, dict[str, Any]]] = [] def _insert_message(self, message: Message) -> None: """Insert a message into the recorded file.""" # Exclude GUI messages. This is hacky. if not self._filter(message): return self._messages.append((self._time, message.as_serializable_dict())) def insert_sleep(self, duration: float) -> None: """Insert a sleep into the recorded file.""" self._time += duration def set_loop_start(self) -> None: """Mark the start of the loop. Messages sent after this point will be looped. Should only be called once.""" assert self._loop_start_index is None, "Loop start already set." self._loop_start_index = len(self._messages) def end_and_serialize(self) -> bytes: """End the recording and serialize contents. 
Returns the recording as bytes, which should generally be written to a file.""" packed_bytes = msgspec.msgpack.encode( { "loopStartIndex": self._loop_start_index, "durationSeconds": self._time, "messages": self._messages, } ) assert isinstance(packed_bytes, bytes) self._handler._record_handle = None return gzip.compress(packed_bytes, compresslevel=9) class WebsockMessageHandler: """Mix-in for adding message handling to a class.""" def __init__(self, thread_executor: ThreadPoolExecutor) -> None: self._thread_executor = thread_executor self._incoming_handlers: dict[ type[Message], list[Callable[[ClientId, Message], None]] ] = {} self._atomic_lock = threading.Lock() self._queued_messages: queue.Queue = queue.Queue() self._locked_thread_id = -1 # Set to None if not recording. self._record_handle: RecordHandle | None = None def start_recording(self, filter: Callable[[Message], bool]) -> RecordHandle: """Start recording messages that are sent. Sent messages will be serialized and can be used for playback.""" assert self._record_handle is None, "Already recording." self._record_handle = RecordHandle(self, filter) return self._record_handle def register_handler( self, message_cls: type[TMessage], callback: Callable[[ClientId, TMessage], Any], ) -> None: """Register a handler for a particular message type.""" if message_cls not in self._incoming_handlers: self._incoming_handlers[message_cls] = [] self._incoming_handlers[message_cls].append(callback) # type: ignore def unregister_handler( self, message_cls: type[TMessage], callback: Callable[[ClientId, TMessage], Any] | None = None, ): """Unregister a handler for a particular message type.""" assert ( message_cls in self._incoming_handlers ), "Tried to unregister a handler that hasn't been registered." if callback is None: self._incoming_handlers.pop(message_cls) else: self._incoming_handlers[message_cls].remove(callback) # type: ignore def _handle_incoming_message(self, client_id: ClientId, message: Message) -> None: """Handle incoming messages.""" if type(message) in self._incoming_handlers: for cb in self._incoming_handlers[type(message)]: cb(client_id, message) @abc.abstractmethod def unsafe_send_message(self, message: Message) -> None: ... def queue_message(self, message: Message) -> None: """Wrapped method for sending messages safely.""" if self._record_handle is not None: self._record_handle._insert_message(message) got_lock = self._atomic_lock.acquire(blocking=False) if got_lock: self.unsafe_send_message(message) self._atomic_lock.release() else: # Send when lock is acquirable, while retaining message order. # This could be optimized! self._queued_messages.put(message) def try_again() -> None: with self._atomic_lock: self.unsafe_send_message(self._queued_messages.get()) self._thread_executor.submit(try_again) @contextlib.contextmanager def atomic(self) -> Generator[None, None, None]: """Returns a context where: all outgoing messages are grouped and applied by clients atomically. This should be treated as a soft constraint that's helpful for things like animations, or when we want position and orientation updates to happen synchronously. Returns: Context manager. """ # If called multiple times in the same thread, we ignore inner calls. 
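# NOTE: a usage sketch for the atomic() context manager implemented here. The
# handle/attribute names mirror those used by ViserUrdf.update_cfg earlier in this
# dump; the concrete server setup is illustrative only:
#
#     server = viser.ViserServer()
#     frame = server.scene.add_frame("/example", show_axes=False)
#     with server.atomic():
#         # Both updates are grouped and applied by clients in one step.
#         frame.wxyz = (1.0, 0.0, 0.0, 0.0)
#         frame.position = (0.0, 0.0, 1.0)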
thread_id = threading.get_ident() if thread_id == self._locked_thread_id: got_lock = False else: self._atomic_lock.acquire() self._locked_thread_id = thread_id got_lock = True yield if got_lock: self._atomic_lock.release() self._locked_thread_id = -1 class WebsockClientConnection(WebsockMessageHandler): """Handle for sending messages to and listening to messages from a single connected client.""" def __init__( self, client_id: int, thread_executor: ThreadPoolExecutor, client_state: _ClientHandleState, ) -> None: self.client_id = client_id self._state = client_state super().__init__(thread_executor) @override def unsafe_send_message(self, message: Message) -> None: """Send a message to a specific client.""" self._state.message_buffer.push(message) class WebsockServer(WebsockMessageHandler): """Websocket server abstraction. Communicates asynchronously with client applications. By default, all messages are broadcasted to all connected clients. To send messages to an individual client, we can use `on_client_connect()` to retrieve client handles. Args: host: Host to bind server to. port: Port to bind server to. message_class: Base class for message types. Subclasses of the message type should have unique names. This argument is optional currently, but will be required in the future. http_server_root: Path to root for HTTP server. verbose: Toggle for print messages. client_api_version: Flag for backwards compatibility. 0 sends individual messages. 1 sends windowed messages. """ def __init__( self, host: str, port: int, message_class: type[Message] = Message, http_server_root: Path | None = None, verbose: bool = True, client_api_version: Literal[0, 1] = 0, ): super().__init__(thread_executor=ThreadPoolExecutor(max_workers=32)) # Track connected clients. self._client_connect_cb: list[Callable[[WebsockClientConnection], None]] = [] self._client_disconnect_cb: list[Callable[[WebsockClientConnection], None]] = [] self._host = host self._port = port self._message_class = message_class self._http_server_root = http_server_root self._verbose = verbose self._client_api_version: Literal[0, 1] = client_api_version self._shutdown_event = threading.Event() self._ws_server: websockets.WebSocketServer | None = None self._client_state_from_id: dict[int, _ClientHandleState] = {} def start(self) -> None: """Start the server.""" # Start server thread. ready_sem = threading.Semaphore(value=1) ready_sem.acquire() threading.Thread( target=lambda: self._background_worker(ready_sem), daemon=True, ).start() # Wait for ready signal from the background thread. ready_sem.acquire() # Broadcast buffer should be populated by the background worker. assert isinstance(self._broadcast_buffer, AsyncMessageBuffer) def stop(self) -> None: """Stop the server.""" assert self._ws_server is not None self._ws_server.close() self._ws_server = None self._thread_executor.shutdown(wait=True) def on_client_connect(self, cb: Callable[[WebsockClientConnection], Any]) -> None: """Attach a callback to run for newly connected clients.""" self._client_connect_cb.append(cb) def on_client_disconnect( self, cb: Callable[[WebsockClientConnection], Any] ) -> None: """Attach a callback to run when clients disconnect.""" self._client_disconnect_cb.append(cb) @override def unsafe_send_message(self, message: Message) -> None: """Pushes a message onto the broadcast queue. Message will be sent to all clients. Broadcasted messages are persistent: if a new client connects to the server, they will receive a buffered set of previously broadcasted messages. 
The buffer is culled using the value of `message.redundancy_key()`.""" self._broadcast_buffer.push(message) def flush(self) -> None: """Flush the outgoing message buffer for broadcasted messages. Any buffered messages will immediately be sent. (by default they are windowed)""" # TODO: we should add a flush event. self._broadcast_buffer.flush() def flush_client(self, client_id: int) -> None: """Flush the outgoing message buffer for a particular client. Any buffered messages will immediately be sent. (by default they are windowed)""" self._client_state_from_id[client_id].message_buffer.flush() def _background_worker(self, ready_sem: threading.Semaphore) -> None: host = self._host port = self._port message_class = self._message_class http_server_root = self._http_server_root # Need to make a new event loop for notebook compatbility. event_loop = asyncio.new_event_loop() asyncio.set_event_loop(event_loop) self._broadcast_buffer = AsyncMessageBuffer( event_loop, persistent_messages=True ) count_lock = asyncio.Lock() connection_count = 0 total_connections = 0 async def serve(websocket: WebSocketServerProtocol) -> None: """Server loop, run once per connection.""" async with count_lock: nonlocal connection_count client_id = ClientId(connection_count) connection_count += 1 nonlocal total_connections total_connections += 1 if self._verbose: rich.print( f"[bold](viser)[/bold] Connection opened ({client_id}," f" {total_connections} total)," f" {len(self._broadcast_buffer.message_from_id)} persistent" " messages" ) client_state = _ClientHandleState( AsyncMessageBuffer(event_loop, persistent_messages=False), event_loop, ) client_connection = WebsockClientConnection( client_id, self._thread_executor, client_state ) self._client_state_from_id[client_id] = client_state def handle_incoming(message: Message) -> None: self._thread_executor.submit( error_print_wrapper( lambda: self._handle_incoming_message(client_id, message) ) ) self._thread_executor.submit( error_print_wrapper( lambda: client_connection._handle_incoming_message( client_id, message ) ) ) # New connection callbacks. for cb in self._client_connect_cb: cb(client_connection) try: # For each client: infinite loop over producers (which send messages) # and consumers (which receive messages). await asyncio.gather( _message_producer( websocket, client_state.message_buffer, client_id, self._client_api_version, ), _message_producer( websocket, self._broadcast_buffer, client_id, self._client_api_version, ), _message_consumer(websocket, handle_incoming, message_class), ) except ( websockets.exceptions.ConnectionClosedOK, websockets.exceptions.ConnectionClosedError, ): # We use a sentinel value to signal that the client producer thread # should exit. # # This is partially cosmetic: it allows us to safely finish pending # queue get() tasks, which suppresses a "Task was destroyed but it is # pending" error. client_state.message_buffer.set_done() # Disconnection callbacks. for cb in self._client_disconnect_cb: cb(client_connection) # Cleanup. self._client_state_from_id.pop(client_id) total_connections -= 1 if self._verbose: rich.print( f"[bold](viser)[/bold] Connection closed ({client_id}," f" {total_connections} total)" ) # Host client on the same port as the websocket. file_cache: dict[Path, bytes] = {} file_cache_gzipped: dict[Path, bytes] = {} async def viser_http_server( path: str, request_headers: websockets.datastructures.Headers ) -> ( tuple[http.HTTPStatus, websockets.datastructures.HeadersLike, bytes] | None ): # Ignore websocket packets. 
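# NOTE: a minimal usage sketch for the WebsockServer defined above (host/port
# values are illustrative). The server runs its event loop in a background
# thread, can serve static files over HTTP on the same port when
# http_server_root is set, and broadcasts buffered messages to every client:
#
#     from viser.infra import WebsockServer
#
#     server = WebsockServer(host="127.0.0.1", port=8080, verbose=True)
#     server.on_client_connect(lambda client: print("connected:", client.client_id))
#     server.start()   # returns once the background worker signals it is ready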
if request_headers.get("Upgrade") == "websocket": return None # Strip out search params, get relative path. path = path.partition("?")[0] relpath = str(Path(path).relative_to("/")) if relpath == ".": relpath = "index.html" assert http_server_root is not None source_path = http_server_root / relpath if not source_path.exists(): return (http.HTTPStatus.NOT_FOUND, {}, b"404") # type: ignore use_gzip = "gzip" in request_headers.get("Accept-Encoding", "") mime_type = mimetypes.guess_type(relpath)[0] if mime_type is None: mime_type = "application/octet-stream" response_headers = { "Content-Type": mime_type, } if source_path not in file_cache: file_cache[source_path] = source_path.read_bytes() if use_gzip: response_headers["Content-Encoding"] = "gzip" if source_path not in file_cache_gzipped: file_cache_gzipped[source_path] = gzip.compress( file_cache[source_path] ) response_payload = file_cache_gzipped[source_path] else: response_payload = file_cache[source_path] # Try to read + send over file. return (http.HTTPStatus.OK, response_headers, response_payload) for _ in range(1000): try: serve_future = websockets.server.serve( serve, host, port, # Compression can be turned off to reduce client-side CPU usage. # compression=None, process_request=( viser_http_server if http_server_root is not None else None ), ) self._ws_server = serve_future.ws_server event_loop.run_until_complete(serve_future) break except OSError: # Port not available. port += 1 continue if self._ws_server is None: raise RuntimeError("Failed to bind to port!") self._port = port ready_sem.release() event_loop.run_forever() # This will run only when the event loop ends, which happens when the # websocket server is closed. rich.print("[bold](viser)[/bold] Server stopped") async def _message_producer( websocket: WebSocketServerProtocol, buffer: AsyncMessageBuffer, client_id: int, client_api_version: Literal[0, 1], ) -> None: """Infinite loop to broadcast windows of messages from a buffer.""" window_generator = buffer.window_generator(client_id) while not buffer.done: outgoing = await window_generator.__anext__() if client_api_version == 1: serialized = msgspec.msgpack.encode( tuple(message.as_serializable_dict() for message in outgoing) ) assert isinstance(serialized, bytes) await websocket.send(serialized) elif client_api_version == 0: for msg in outgoing: serialized = msgspec.msgpack.encode(msg.as_serializable_dict()) assert isinstance(serialized, bytes) await websocket.send(serialized) else: assert_never(client_api_version) async def _message_consumer( websocket: WebSocketServerProtocol, handle_message: Callable[[Message], None], message_class: type[Message], ) -> None: """Infinite loop waiting for and then handling incoming messages.""" while True: raw = await websocket.recv() assert isinstance(raw, bytes) message = message_class.deserialize(raw) handle_message(message) def error_print_wrapper(inner: Callable[[], Any]) -> Callable[[], None]: """Wrap a Callable to print error messages when they happen. This can be helpful for jobs submitted to ThreadPoolExecutor instances, which, by default, will suppress error messages until returned futures are awaited. """ def wrapped() -> None: try: inner() except Exception as e: import traceback as tb tb.print_exception(type(e), e, e.__traceback__, limit=100) return wrapped ================================================ FILE: viser/src/viser/infra/_messages.py ================================================ """Message type definitions. 
For synchronization with the TypeScript definitions, see `_typescript_interface_gen.py.`""" from __future__ import annotations import abc import functools import warnings from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, cast import msgspec import numpy as onp from typing_extensions import get_args, get_origin, get_type_hints if TYPE_CHECKING: from ._infra import ClientId else: ClientId = Any def _prepare_for_deserialization(value: Any, annotation: Type) -> Any: # If annotated as a float but we got an integer, cast to float. These # are both `number` in Javascript. if annotation is float: return float(value) elif annotation is int: return int(value) elif get_origin(annotation) is tuple: out = [] args = get_args(annotation) if len(args) >= 2 and args[1] == ...: args = (args[0],) * len(value) elif len(value) != len(args): warnings.warn(f"[viser] {value} does not match annotation {annotation}") return value for i, v in enumerate(value): out.append( # Hack to be OK with wrong type annotations. # https://github.com/nerfstudio-project/nerfstudio/pull/1805 _prepare_for_deserialization(v, args[i]) if i < len(args) else v ) return tuple(out) return value def _prepare_for_serialization(value: Any, annotation: object) -> Any: """Prepare any special types for serialization.""" if annotation is Any: annotation = type(value) # Coerce some scalar types: if we've annotated as float / int but we get an # onp.float32 / onp.int64, for example, we should cast automatically. if annotation is float or isinstance(value, onp.floating): return float(value) if annotation is int or isinstance(value, onp.integer): return int(value) # Recursively handle tuples. if isinstance(value, tuple): if isinstance(value, onp.ndarray): assert False, ( "Expected a tuple, but got an array... missing a cast somewhere?" f" {value}" ) out = [] if get_origin(annotation) is tuple: args = get_args(annotation) if len(args) >= 2 and args[1] == ...: args = (args[0],) * len(value) elif len(value) != len(args): warnings.warn(f"[viser] {value} does not match annotation {annotation}") return value else: args = [Any] * len(value) for i, v in enumerate(value): out.append( # Hack to be OK with wrong type annotations. # https://github.com/nerfstudio-project/nerfstudio/pull/1805 _prepare_for_serialization(v, args[i]) if i < len(args) else v ) return tuple(out) # For arrays, we serialize underlying data directly. The client is responsible for # reading using the correct dtype. if isinstance(value, onp.ndarray): return value.data if value.data.c_contiguous else value.copy().data if isinstance(value, dict): return {k: _prepare_for_serialization(v, Any) for k, v in value.items()} # type: ignore return value T = TypeVar("T", bound="Message") @functools.lru_cache(maxsize=None) def get_type_hints_cached(cls: Type[Any]) -> Dict[str, Any]: return get_type_hints(cls) # type: ignore class Message(abc.ABC): """Base message type for server/client communication.""" excluded_self_client: Optional[ClientId] = None """Don't send this message to a particular client. 
Useful when a client wants to send synchronization information to other clients.""" def as_serializable_dict(self) -> Dict[str, Any]: """Convert a Python Message object into bytes.""" message_type = type(self) hints = get_type_hints_cached(message_type) out = { k: _prepare_for_serialization(v, hints[k]) for k, v in vars(self).items() } out["type"] = message_type.__name__ return out @classmethod def _from_serializable_dict(cls, mapping: Dict[str, Any]) -> Dict[str, Any]: """Convert a dict message back into a Python Message object.""" hints = get_type_hints_cached(cls) mapping = { k: _prepare_for_deserialization(v, hints[k]) for k, v in mapping.items() } return mapping @classmethod def deserialize(cls, message: bytes) -> Message: """Convert bytes into a Python Message object.""" mapping = msgspec.msgpack.decode(message) # msgpack deserializes to lists by default, but all of our annotations use # tuples. def lists_to_tuple(obj: Any) -> Any: if isinstance(obj, list): return tuple(lists_to_tuple(x) for x in obj) elif isinstance(obj, dict): return {k: lists_to_tuple(v) for k, v in obj.items()} else: return obj mapping = lists_to_tuple(mapping) message_type = cls._subclass_from_type_string()[cast(str, mapping.pop("type"))] message_kwargs = message_type._from_serializable_dict(mapping) return message_type(**message_kwargs) @classmethod @functools.lru_cache(maxsize=100) def _subclass_from_type_string(cls: Type[T]) -> Dict[str, Type[T]]: subclasses = cls.get_subclasses() return {s.__name__: s for s in subclasses} @classmethod def get_subclasses(cls: Type[T]) -> List[Type[T]]: """Recursively get message subclasses.""" def _get_subclasses(typ: Type[T]) -> List[Type[T]]: out = [] for sub in typ.__subclasses__(): out.append(sub) out.extend(_get_subclasses(sub)) return out return _get_subclasses(cls) @abc.abstractmethod def redundancy_key(self) -> str: """Returns a unique key for this message, used for detecting redundant messages. For example: if we send 1000 "set value" messages for the same GUI element, we should only keep the latest message. """ ================================================ FILE: viser/src/viser/infra/_typescript_interface_gen.py ================================================ import dataclasses from collections import defaultdict from typing import Any, Type, Union, cast import numpy as onp from typing_extensions import ( Annotated, Literal, NotRequired, get_args, get_origin, get_type_hints, is_typeddict, ) try: from typing import Literal as LiteralAlt except ImportError: LiteralAlt = Literal # type: ignore from ._messages import Message _raw_type_mapping = { bool: "boolean", float: "number", int: "number", str: "string", # For numpy arrays, we directly serialize the underlying data buffer. onp.ndarray: "Uint8Array", bytes: "Uint8Array", Any: "any", None: "null", type(None): "null", } def _get_ts_type(typ: Type[Any]) -> str: origin_typ = get_origin(typ) # Look for TypeScriptAnnotationOverride in the annotations. if origin_typ is Annotated: args = get_args(typ) for arg in args[1:]: if isinstance(arg, TypeScriptAnnotationOverride): return arg.annotation # If no override is found, just use the unwrapped type. origin_typ = args[0] # Automatic Python => TypeScript conversion. 
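# NOTE: illustrative input/output pairs for the Python -> TypeScript conversion
# implemented below (using the _raw_type_mapping table above):
#
#     Tuple[float, float, float]   ->  [number, number, number]
#     Tuple[int, ...]              ->  number[]
#     Optional[str]                ->  (string | null)        # i.e. Union[str, None]
#     Literal["jpeg", "png"]       ->  'jpeg' | 'png'
#     Dict[str, float]             ->  { [key: string]: number }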
if origin_typ is tuple: args = get_args(typ) if len(args) == 2 and args[1] == ...: return _get_ts_type(args[0]) + "[]" else: return "[" + ", ".join(map(_get_ts_type, args)) + "]" elif origin_typ is list: args = get_args(typ) assert len(args) == 1 return _get_ts_type(args[0]) + "[]" elif origin_typ in (Literal, LiteralAlt): return " | ".join( map( lambda lit: repr(lit).lower() if type(lit) is bool else repr(lit), get_args(typ), ) ) elif origin_typ is Union: return ( "(" + " | ".join( map( _get_ts_type, get_args(typ), ) ) + ")" ) elif origin_typ is list: args = get_args(typ) return _get_ts_type(args[0]) + "[]" elif origin_typ is dict: args = get_args(typ) assert len(args) == 2 return "{ [key: " + _get_ts_type(args[0]) + "]: " + _get_ts_type(args[1]) + " }" elif is_typeddict(typ): hints = get_type_hints(typ) optional_keys = getattr(typ, "__optional_keys__", []) def fmt(key): val = hints[key] optional = key in optional_keys if get_origin(val) is NotRequired: val = get_args(val)[0] ret = f"'{key}'{'?' if optional else ''}" + ": " + _get_ts_type(val) return ret ret = "{" + ", ".join(map(fmt, hints)) + "}" return ret else: # Like get_origin(), but also supports numpy.typing.NDArray[dtype]. typ = cast(Any, getattr(typ, "__origin__", typ)) assert typ in _raw_type_mapping, f"Unsupported type {typ}" return _raw_type_mapping[typ] @dataclasses.dataclass(frozen=True) class TypeScriptAnnotationOverride: """Use with `typing.Annotated[]` to override the automatically-generated TypeScript annotation corresponding to a dataclass field.""" annotation: str def generate_typescript_interfaces(message_cls: Type[Message]) -> str: """Generate TypeScript definitions for all subclasses of a base message class.""" out_lines = [] message_types = message_cls.get_subclasses() tag_map = defaultdict(list) # Generate interfaces for each specific message. for cls in message_types: if cls.__doc__ is not None: docstring = "\n * ".join( map(lambda line: line.strip(), cls.__doc__.split("\n")) ) out_lines.append(f"/** {docstring}") out_lines.append(" *") out_lines.append(" * (automatically generated)") out_lines.append(" */") for tag in getattr(cls, "_tags", []): tag_map[tag].append(cls.__name__) out_lines.append(f"export interface {cls.__name__} " + "{") out_lines.append(f' type: "{cls.__name__}";') field_names = set([f.name for f in dataclasses.fields(cls)]) # type: ignore for name, typ in get_type_hints(cls, include_extras=True).items(): if name in field_names: typ = _get_ts_type(typ) else: continue out_lines.append(f" {name}: {typ};") out_lines.append("}") out_lines.append("") # Generate union type over all messages. out_lines.append("export type Message = ") for cls in message_types: out_lines.append(f" | {cls.__name__}") out_lines[-1] = out_lines[-1] + ";" # Generate union type over all tags. for tag, cls_names in tag_map.items(): out_lines.append(f"export type {tag} = ") for cls_name in cls_names: out_lines.append(f" | {cls_name}") out_lines[-1] = out_lines[-1] + ";" interfaces = "\n".join(out_lines) + "\n" # Add header and return. return ( "\n".join( [ ( "// AUTOMATICALLY GENERATED message interfaces, from Python" " dataclass definitions." 
), "// This file should not be manually modified.", "", ] ) + interfaces ) ================================================ FILE: viser/src/viser/py.typed ================================================ ================================================ FILE: viser/src/viser/scripts/__init__.py ================================================ ================================================ FILE: viser/src/viser/scripts/dev_checks.py ================================================ #!/usr/bin/env python """Runs formatting, linting, and type checking tests.""" import subprocess import sys import tyro from rich import console from rich.style import Style CONSOLE = console.Console() TYPE_TESTS = ["pyright .", "mypy ."] FORMAT_TESTS = ["ruff check --fix .", "ruff format ."] def run_command(command: str, continue_on_fail: bool = False) -> bool: """Run a command kill actions if it fails Args: command: Command to run. continue_on_fail: Whether to continue running commands if the current one fails.. """ ret_code = subprocess.call(command, shell=True) if ret_code != 0: CONSOLE.print(f"[bold red]Error: `{command}` failed.") if not continue_on_fail: sys.exit(1) return ret_code == 0 def run_code_checks( continue_on_fail: bool = False, skip_format_checks: bool = False, skip_type_checks: bool = False, ): """Runs formatting, linting, and type checking tests. Args: continue_on_fail: Whether or not to continue running actions commands if the current one fails. skip_format_checks: Whether or not to skip format tests. skip_type_checks: Whether or not to skip type tests. """ success = True assert ( not skip_format_checks or not skip_type_checks ), "Cannot skip format and type tests at the same time." tests = [] if not skip_format_checks: tests += FORMAT_TESTS if not skip_type_checks: tests += TYPE_TESTS for test in tests: CONSOLE.line() CONSOLE.rule(f"[bold green]Running: {test}") success = success and run_command(test, continue_on_fail=continue_on_fail) if success: CONSOLE.line() CONSOLE.rule(characters="=") CONSOLE.print( "[bold green]:TADA: :TADA: :TADA: ALL CHECKS PASSED :TADA: :TADA: :TADA:", justify="center", ) CONSOLE.rule(characters="=") else: CONSOLE.line() CONSOLE.rule(characters="=", style=Style(color="red")) CONSOLE.print( "[bold red]:skull: :skull: :skull: ERRORS FOUND :skull: :skull: :skull:", justify="center", ) CONSOLE.rule(characters="=", style=Style(color="red")) def entrypoint(): """Entrypoint for use with pyproject scripts.""" tyro.cli(run_code_checks) if __name__ == "__main__": entrypoint() ================================================ FILE: viser/src/viser/theme/__init__.py ================================================ """:mod:`viser.theme` provides interfaces for themeing the viser frontend from within Python. 
""" from ._titlebar import TitlebarButton as TitlebarButton from ._titlebar import TitlebarConfig as TitlebarConfig from ._titlebar import TitlebarImage as TitlebarImage ================================================ FILE: viser/src/viser/theme/_titlebar.py ================================================ from typing import Literal, Optional, Tuple, TypedDict class TitlebarButton(TypedDict): """A link-only button that appears in the Titlebar.""" text: Optional[str] icon: Optional[Literal["GitHub", "Description", "Keyboard"]] href: Optional[str] class TitlebarImage(TypedDict): """An image that appears on the titlebar.""" image_url_light: str image_url_dark: Optional[str] image_alt: str href: Optional[str] class TitlebarConfig(TypedDict): """Configure the content that appears in the titlebar.""" buttons: Optional[Tuple[TitlebarButton, ...]] image: Optional[TitlebarImage] ================================================ FILE: viser/src/viser/transforms/__init__.py ================================================ """Lie group interface for rigid transforms, ported from [jaxlie](https://github.com/brentyi/jaxlie). Used by `viser` internally and in examples. Implements SO(2), SO(3), SE(2), and SE(3) Lie groups. Rotations are parameterized via S^1 and S^3. """ from ._base import MatrixLieGroup as MatrixLieGroup from ._base import SEBase as SEBase from ._base import SOBase as SOBase from ._se2 import SE2 as SE2 from ._se3 import SE3 as SE3 from ._so2 import SO2 as SO2 from ._so3 import SO3 as SO3 ================================================ FILE: viser/src/viser/transforms/_base.py ================================================ import abc from typing import ClassVar, Generic, Tuple, TypeVar, Union, overload import numpy as onp import numpy.typing as onpt from typing_extensions import Self, final, get_args, override class MatrixLieGroup(abc.ABC): """Interface definition for matrix Lie groups.""" # Class properties. # > These will be set in `_utils.register_lie_group()`. matrix_dim: ClassVar[int] """Dimension of square matrix output from `.as_matrix()`.""" parameters_dim: ClassVar[int] """Dimension of underlying parameters, `.parameters()`.""" tangent_dim: ClassVar[int] """Dimension of tangent space.""" space_dim: ClassVar[int] """Dimension of coordinates that can be transformed.""" def __init__( # Notes: # - For the constructor signature to be consistent with subclasses, `parameters` # should be marked as positional-only. But this isn't possible in Python 3.7. # - This method is implicitly overriden by the dataclass decorator and # should _not_ be marked abstract. self, parameters: onp.ndarray, ): """Construct a group object from its underlying parameters.""" raise NotImplementedError() # Shared implementations. @overload def __matmul__(self, other: Self) -> Self: ... @overload def __matmul__( self, other: onpt.NDArray[onp.floating] ) -> onpt.NDArray[onp.floating]: ... def __matmul__( self, other: Union[Self, onpt.NDArray[onp.floating]] ) -> Union[Self, onpt.NDArray[onp.floating]]: """Overload for the `@` operator. Switches between the group action (`.apply()`) and multiplication (`.multiply()`) based on the type of `other`. """ if isinstance(other, onp.ndarray): return self.apply(target=other) elif isinstance(other, MatrixLieGroup): assert self.space_dim == other.space_dim return self.multiply(other=other) else: assert False, f"Invalid argument type for `@` operator: {type(other)}" # Factory. @classmethod @abc.abstractmethod def identity(cls, batch_axes: Tuple[int, ...] 
= ()) -> Self: """Returns identity element. Args: batch_axes: Any leading batch axes for the output transform. Returns: Identity element. """ @classmethod @abc.abstractmethod def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> Self: """Get group member from matrix representation. Args: matrix: Matrix representaiton. Returns: Group member. """ # Accessors. @abc.abstractmethod def as_matrix(self) -> onpt.NDArray[onp.floating]: """Get transformation as a matrix. Homogeneous for SE groups.""" @abc.abstractmethod def parameters(self) -> onpt.NDArray[onp.floating]: """Get underlying representation.""" # Operations. @abc.abstractmethod def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: """Applies group action to a point. Args: target: Point to transform. Returns: Transformed point. """ @abc.abstractmethod def multiply(self, other: Self) -> Self: """Composes this transformation with another. Returns: self @ other """ @classmethod @abc.abstractmethod def exp(cls, tangent: onpt.NDArray[onp.floating]) -> Self: """Computes `expm(wedge(tangent))`. Args: tangent: Tangent vector to take the exponential of. Returns: Output. """ @abc.abstractmethod def log(self) -> onpt.NDArray[onp.floating]: """Computes `vee(logm(transformation matrix))`. Returns: Output. Shape should be `(tangent_dim,)`. """ @abc.abstractmethod def adjoint(self) -> onpt.NDArray[onp.floating]: """Computes the adjoint, which transforms tangent vectors between tangent spaces. More precisely, for a transform `GroupType`: ``` GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType ``` In robotics, typically used for transforming twists, wrenches, and Jacobians across different reference frames. Returns: Output. Shape should be `(tangent_dim, tangent_dim)`. """ @abc.abstractmethod def inverse(self) -> Self: """Computes the inverse of our transform. Returns: Output. """ @abc.abstractmethod def normalize(self) -> Self: """Normalize/projects values and returns. Returns: Normalized group member. """ # @classmethod # @abc.abstractmethod # def sample_uniform(cls, key: onp.ndarray, batch_axes: Tuple[int, ...] = ()) -> Self: # """Draw a uniform sample from the group. Translations (if applicable) are in the # range [-1, 1]. # # Args: # key: PRNG key, as returned by `jax.random.PRNGKey()`. # batch_axes: Any leading batch axes for the output transforms. Each # sampled transform will be different. # # Returns: # Sampled group member. # """ @final def get_batch_axes(self) -> Tuple[int, ...]: """Return any leading batch axes in contained parameters. If an array of shape `(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will return `(100,)`.""" return self.parameters().shape[:-1] class SOBase(MatrixLieGroup): """Base class for special orthogonal groups.""" ContainedSOType = TypeVar("ContainedSOType", bound=SOBase) class SEBase(Generic[ContainedSOType], MatrixLieGroup): """Base class for special Euclidean groups. Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional translation vector. """ # SE-specific interface. @classmethod @abc.abstractmethod def from_rotation_and_translation( cls, rotation: ContainedSOType, translation: onpt.NDArray[onp.floating], ) -> Self: """Construct a rigid transform from a rotation and a translation. Args: rotation: Rotation term. translation: translation term. Returns: Constructed transformation. 
""" @final @classmethod def from_rotation(cls, rotation: ContainedSOType) -> Self: return cls.from_rotation_and_translation( rotation=rotation, translation=onp.zeros( (*rotation.get_batch_axes(), cls.space_dim), dtype=rotation.parameters().dtype, ), ) @final @classmethod def from_translation(cls, translation: onpt.NDArray[onp.floating]) -> Self: # Extract rotation class from type parameter. assert len(cls.__orig_bases__) == 1 # type: ignore return cls.from_rotation_and_translation( rotation=get_args(cls.__orig_bases__[0])[0].identity(), # type: ignore translation=translation, ) @abc.abstractmethod def rotation(self) -> ContainedSOType: """Returns a transform's rotation term.""" @abc.abstractmethod def translation(self) -> onpt.NDArray[onp.floating]: """Returns a transform's translation term.""" # Overrides. @final @override def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: return self.rotation() @ target + self.translation() # type: ignore @final @override def multiply(self, other: Self) -> Self: return type(self).from_rotation_and_translation( rotation=self.rotation() @ other.rotation(), translation=(self.rotation() @ other.translation()) + self.translation(), ) @final @override def inverse(self) -> Self: R_inv = self.rotation().inverse() return type(self).from_rotation_and_translation( rotation=R_inv, translation=-(R_inv @ self.translation()), ) @final @override def normalize(self) -> Self: return type(self).from_rotation_and_translation( rotation=self.rotation().normalize(), translation=self.translation(), ) ================================================ FILE: viser/src/viser/transforms/_se2.py ================================================ import dataclasses from typing import Tuple, cast import numpy as onp import numpy.typing as onpt from typing_extensions import override from . import _base, hints from ._so2 import SO2 from .utils import broadcast_leading_axes, get_epsilon, register_lie_group @register_lie_group( matrix_dim=3, parameters_dim=4, tangent_dim=3, space_dim=2, ) @dataclasses.dataclass(frozen=True) class SE2(_base.SEBase[SO2]): """Special Euclidean group for proper rigid transforms in 2D. Broadcasting rules are the same as for numpy. Ported to numpy from `jaxlie.SE2`. Internal parameterization is `(cos, sin, x, y)`. Tangent parameterization is `(vx, vy, omega)`. """ # SE2-specific. unit_complex_xy: onpt.NDArray[onp.floating] """Internal parameters. `(cos, sin, x, y)`. Shape should be `(*, 4)`.""" @override def __repr__(self) -> str: unit_complex = onp.round(self.unit_complex_xy[..., :2], 5) xy = onp.round(self.unit_complex_xy[..., 2:], 5) return f"{self.__class__.__name__}(unit_complex={unit_complex}, xy={xy})" @staticmethod def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2": """Construct a transformation from standard 2D pose parameters. This is not the same as integrating over a length-3 twist. """ cos = onp.cos(theta) sin = onp.sin(theta) return SE2(unit_complex_xy=onp.stack([cos, sin, x, y], axis=-1)) # SE-specific. 
@classmethod @override def from_rotation_and_translation( cls, rotation: SO2, translation: onpt.NDArray[onp.floating], ) -> "SE2": assert translation.shape[-1:] == (2,) rotation, translation = broadcast_leading_axes((rotation, translation)) return SE2( unit_complex_xy=onp.concatenate( [rotation.unit_complex, translation], axis=-1 ) ) @override def rotation(self) -> SO2: return SO2(unit_complex=self.unit_complex_xy[..., :2]) @override def translation(self) -> onpt.NDArray[onp.floating]: return self.unit_complex_xy[..., 2:] # Factory. @classmethod @override def identity(cls, batch_axes: Tuple[int, ...] = ()) -> "SE2": return SE2( unit_complex_xy=onp.broadcast_to( onp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4) ) ) @classmethod @override def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> "SE2": assert matrix.shape[-2:] == (3, 3) or matrix.shape[-2:] == (2, 3) # Currently assumes bottom row is [0, 0, 1]. return SE2.from_rotation_and_translation( rotation=SO2.from_matrix(matrix[..., :2, :2]), translation=matrix[..., :2, 2], ) # Accessors. @override def parameters(self) -> onpt.NDArray[onp.floating]: return self.unit_complex_xy @override def as_matrix(self) -> onpt.NDArray[onp.floating]: cos, sin, x, y = onp.moveaxis(self.unit_complex_xy, -1, 0) out = onp.stack( [ cos, -sin, x, sin, cos, y, onp.zeros_like(x), onp.zeros_like(x), onp.ones_like(x), ], axis=-1, ).reshape((*self.get_batch_axes(), 3, 3)) return out # Operations. @classmethod @override def exp(cls, tangent: onpt.NDArray[onp.floating]) -> "SE2": # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558 # Also see: # > http://ethaneade.com/lie.pdf assert tangent.shape[-1:] == (3,) theta = tangent[..., 2] use_taylor = onp.abs(theta) < get_epsilon(tangent.dtype) # Shim to avoid NaNs in onp.where branches, which cause failures for # reverse-mode AD in JAX. This isn't needed for vanilla numpy. safe_theta = cast( onp.ndarray, onp.where( use_taylor, onp.ones_like(theta), # Any non-zero value should do here. theta, ), ) theta_sq = theta**2 sin_over_theta = cast( onp.ndarray, onp.where( use_taylor, 1.0 - theta_sq / 6.0, onp.sin(safe_theta) / safe_theta, ), ) one_minus_cos_over_theta = cast( onp.ndarray, onp.where( use_taylor, 0.5 * theta - theta * theta_sq / 24.0, (1.0 - onp.cos(safe_theta)) / safe_theta, ), ) V = onp.stack( [ sin_over_theta, -one_minus_cos_over_theta, one_minus_cos_over_theta, sin_over_theta, ], axis=-1, ).reshape((*tangent.shape[:-1], 2, 2)) return SE2.from_rotation_and_translation( rotation=SO2.from_radians(theta), translation=onp.einsum("...ij,...j->...i", V, tangent[..., :2]), ) @override def log(self) -> onpt.NDArray[onp.floating]: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L160 # Also see: # > http://ethaneade.com/lie.pdf theta = self.rotation().log()[..., 0] cos = onp.cos(theta) cos_minus_one = cos - 1.0 half_theta = theta / 2.0 use_taylor = onp.abs(cos_minus_one) < get_epsilon(theta.dtype) # Shim to avoid NaNs in onp.where branches, which cause failures for # reverse-mode AD in JAX. This isn't needed for vanilla numpy. safe_cos_minus_one = onp.where( use_taylor, onp.ones_like(cos_minus_one), # Any non-zero value should do here. cos_minus_one, ) half_theta_over_tan_half_theta = onp.where( use_taylor, # Taylor approximation. 1.0 - theta**2 / 12.0, # Default. 
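                # (Editor's note) -(θ/2)·sinθ / (cosθ − 1) simplifies to (θ/2)·cot(θ/2);
                # near θ = 0 its expansion is 1 − θ²/12 − O(θ⁴), which is exactly the
                # Taylor branch selected above when `use_taylor` is true.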
-(half_theta * onp.sin(theta)) / safe_cos_minus_one, ) V_inv = onp.stack( [ half_theta_over_tan_half_theta, half_theta, -half_theta, half_theta_over_tan_half_theta, ], axis=-1, ).reshape((*theta.shape, 2, 2)) tangent = onp.concatenate( [ onp.einsum("...ij,...j->...i", V_inv, self.translation()), theta[..., None], ], axis=-1, ) return tangent @override def adjoint(self: "SE2") -> onpt.NDArray[onp.floating]: cos, sin, x, y = onp.moveaxis(self.unit_complex_xy, -1, 0) return onp.stack( [ cos, -sin, y, sin, cos, -x, onp.zeros_like(x), onp.zeros_like(x), onp.ones_like(x), ], axis=-1, ).reshape((*self.get_batch_axes(), 3, 3)) # @classmethod # @override # def sample_uniform( # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () # ) -> "SE2": # key0, key1 = jax.random.split(key) # return SE2.from_rotation_and_translation( # rotation=SO2.sample_uniform(key0, batch_axes=batch_axes), # translation=jax.random.uniform( # key=key1, # shape=( # *batch_axes, # 2, # ), # minval=-1.0, # maxval=1.0, # ), # ) ================================================ FILE: viser/src/viser/transforms/_se3.py ================================================ from __future__ import annotations import dataclasses from typing import Tuple, cast import numpy as onp import numpy.typing as onpt from typing_extensions import override from . import _base from ._so3 import SO3 from .utils import broadcast_leading_axes, get_epsilon, register_lie_group def _skew(omega: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: """Returns the skew-symmetric form of a length-3 vector.""" wx, wy, wz = onp.moveaxis(omega, -1, 0) zeros = onp.zeros_like(wx) return onp.stack( [zeros, -wz, wy, wz, zeros, -wx, -wy, wx, zeros], axis=-1, ).reshape((*omega.shape[:-1], 3, 3)) @register_lie_group( matrix_dim=4, parameters_dim=7, tangent_dim=6, space_dim=3, ) @dataclasses.dataclass(frozen=True) class SE3(_base.SEBase[SO3]): """Special Euclidean group for proper rigid transforms in 3D. Broadcasting rules are the same as for numpy. Ported to numpy from `jaxlie.SE3`. Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization is `(vx, vy, vz, omega_x, omega_y, omega_z)`. """ # SE3-specific. wxyz_xyz: onpt.NDArray[onp.floating] """Internal parameters. wxyz quaternion followed by xyz translation. Shape should be `(*, 7)`.""" @override def __repr__(self) -> str: quat = onp.round(self.wxyz_xyz[..., :4], 5) trans = onp.round(self.wxyz_xyz[..., 4:], 5) return f"{self.__class__.__name__}(wxyz={quat}, xyz={trans})" # SE-specific. @classmethod @override def from_rotation_and_translation( cls, rotation: SO3, translation: onpt.NDArray[onp.floating], ) -> SE3: assert translation.shape[-1:] == (3,) rotation, translation = broadcast_leading_axes((rotation, translation)) return SE3(wxyz_xyz=onp.concatenate([rotation.wxyz, translation], axis=-1)) @override def rotation(self) -> SO3: return SO3(wxyz=self.wxyz_xyz[..., :4]) @override def translation(self) -> onpt.NDArray[onp.floating]: return self.wxyz_xyz[..., 4:] # Factory. @classmethod @override def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SE3: return SE3( wxyz_xyz=onp.broadcast_to( onp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), (*batch_axes, 7) ) ) @classmethod @override def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE3: assert matrix.shape[-2:] == (4, 4) or matrix.shape[-2:] == (3, 4) # Currently assumes bottom row is [0, 0, 0, 1]. 
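        # (Editor's note) The bottom row is not validated: a 4x4 input whose last row
        # differs from [0, 0, 0, 1] is silently truncated to its upper [R|t] block.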
return SE3.from_rotation_and_translation( rotation=SO3.from_matrix(matrix[..., :3, :3]), translation=matrix[..., :3, 3], ) # Accessors. @override def as_matrix(self) -> onpt.NDArray[onp.floating]: out = onp.zeros((*self.get_batch_axes(), 4, 4)) out[..., :3, :3] = self.rotation().as_matrix() out[..., :3, 3] = self.translation() out[..., 3, 3] = 1.0 return out @override def parameters(self) -> onpt.NDArray[onp.floating]: return self.wxyz_xyz # Operations. @classmethod @override def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE3: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761 # (x, y, z, omega_x, omega_y, omega_z) assert tangent.shape[-1:] == (6,) rotation = SO3.exp(tangent[..., 3:]) theta_squared = onp.sum(onp.square(tangent[..., 3:]), axis=-1) use_taylor = theta_squared < get_epsilon(theta_squared.dtype) # Shim to avoid NaNs in onp.where branches, which cause failures for # reverse-mode AD in JAX. This isn't needed for vanilla numpy. theta_squared_safe = cast( onp.ndarray, onp.where( use_taylor, onp.ones_like(theta_squared), # Any non-zero value should do here. theta_squared, ), ) del theta_squared theta_safe = onp.sqrt(theta_squared_safe) skew_omega = _skew(tangent[..., 3:]) V = onp.where( use_taylor[..., None, None], rotation.as_matrix(), ( onp.eye(3) + ((1.0 - onp.cos(theta_safe)) / (theta_squared_safe))[..., None, None] * skew_omega + ( (theta_safe - onp.sin(theta_safe)) / (theta_squared_safe * theta_safe) )[..., None, None] * onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) ), ) return SE3.from_rotation_and_translation( rotation=rotation, translation=onp.einsum("...ij,...j->...i", V, tangent[..., :3]), ) @override def log(self) -> onpt.NDArray[onp.floating]: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223 omega = self.rotation().log() theta_squared = onp.sum(onp.square(omega), axis=-1) use_taylor = theta_squared < get_epsilon(theta_squared.dtype) skew_omega = _skew(omega) # Shim to avoid NaNs in onp.where branches, which cause failures for # reverse-mode AD in JAX. This isn't needed for vanilla numpy. theta_squared_safe = onp.where( use_taylor, onp.ones_like(theta_squared), # Any non-zero value should do here. 
theta_squared, ) del theta_squared theta_safe = onp.sqrt(theta_squared_safe) half_theta_safe = theta_safe / 2.0 V_inv = onp.where( use_taylor[..., None, None], onp.eye(3) - 0.5 * skew_omega + onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) / 12.0, ( onp.eye(3) - 0.5 * skew_omega + ( ( 1.0 - theta_safe * onp.cos(half_theta_safe) / (2.0 * onp.sin(half_theta_safe)) ) / theta_squared_safe )[..., None, None] * onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) ), ) return onp.concatenate( [onp.einsum("...ij,...j->...i", V_inv, self.translation()), omega], axis=-1 ) @override def adjoint(self) -> onpt.NDArray[onp.floating]: R = self.rotation().as_matrix() return onp.concatenate( [ onp.concatenate( [R, onp.einsum("...ij,...jk->...ik", _skew(self.translation()), R)], axis=-1, ), onp.concatenate( [onp.zeros((*self.get_batch_axes(), 3, 3)), R], axis=-1 ), ], axis=-2, ) # @classmethod # @override # def sample_uniform( # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () # ) -> SE3: # key0, key1 = jax.random.split(key) # return SE3.from_rotation_and_translation( # rotation=SO3.sample_uniform(key0, batch_axes=batch_axes), # translation=jax.random.uniform( # key=key1, shape=(*batch_axes, 3), minval=-1.0, maxval=1.0 # ), # ) ================================================ FILE: viser/src/viser/transforms/_so2.py ================================================ from __future__ import annotations import dataclasses from typing import Tuple import numpy as onp import numpy.typing as onpt from typing_extensions import override from . import _base, hints from .utils import broadcast_leading_axes, register_lie_group @register_lie_group( matrix_dim=2, parameters_dim=2, tangent_dim=1, space_dim=2, ) @dataclasses.dataclass(frozen=True) class SO2(_base.SOBase): """Special orthogonal group for 2D rotations. Broadcasting rules are the same as for numpy. Ported to numpy from `jaxlie.SO2`. Internal parameterization is `(cos, sin)`. Tangent parameterization is `(omega,)`. """ # SO2-specific. unit_complex: onpt.NDArray[onp.floating] """Internal parameters. `(cos, sin)`. Shape should be `(*, 2)`.""" @override def __repr__(self) -> str: unit_complex = onp.round(self.unit_complex, 5) return f"{self.__class__.__name__}(unit_complex={unit_complex})" @staticmethod def from_radians(theta: hints.Scalar) -> SO2: """Construct a rotation object from a scalar angle.""" cos = onp.cos(theta) sin = onp.sin(theta) return SO2(unit_complex=onp.stack([cos, sin], axis=-1)) def as_radians(self) -> onpt.NDArray[onp.floating]: """Compute a scalar angle from a rotation object.""" radians = self.log()[..., 0] return radians # Factory. @classmethod @override def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO2: return SO2( unit_complex=onp.stack( [onp.ones(batch_axes), onp.zeros(batch_axes)], axis=-1 ) ) @classmethod @override def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SO2: assert matrix.shape[-2:] == (2, 2) return SO2(unit_complex=onp.asarray(matrix[..., :, 0])) # Accessors. @override def as_matrix(self) -> onpt.NDArray[onp.floating]: cos_sin = self.unit_complex out = onp.stack( [ # [cos, -sin], cos_sin * onp.array([1, -1]), # [sin, cos], cos_sin[..., ::-1], ], axis=-2, ) assert out.shape == (*self.get_batch_axes(), 2, 2) return out # type: ignore @override def parameters(self) -> onpt.NDArray[onp.floating]: return self.unit_complex # Operations. 
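    # Example (editor's sketch, not in the original source): round-tripping an angle
    # through the operations defined below.
    #   >>> R = SO2.from_radians(onp.pi / 2)
    #   >>> R @ onp.array([1.0, 0.0])   # ~[0, 1], via apply()
    #   >>> R.as_radians()              # ~pi / 2, via log()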
@override def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: assert target.shape[-1:] == (2,) self, target = broadcast_leading_axes((self, target)) return onp.einsum("...ij,...j->...i", self.as_matrix(), target) @override def multiply(self, other: SO2) -> SO2: return SO2( unit_complex=onp.einsum( "...ij,...j->...i", self.as_matrix(), other.unit_complex ) ) @classmethod @override def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SO2: assert tangent.shape[-1] == 1 cos = onp.cos(tangent) sin = onp.sin(tangent) return SO2(unit_complex=onp.concatenate([cos, sin], axis=-1)) @override def log(self) -> onpt.NDArray[onp.floating]: return onp.arctan2( self.unit_complex[..., 1, None], self.unit_complex[..., 0, None] ) @override def adjoint(self) -> onpt.NDArray[onp.floating]: return onp.ones((*self.get_batch_axes(), 1, 1)) @override def inverse(self) -> SO2: return SO2(unit_complex=self.unit_complex * onp.array([1, -1])) @override def normalize(self) -> SO2: return SO2( unit_complex=self.unit_complex / onp.linalg.norm(self.unit_complex, axis=-1, keepdims=True) ) # @classmethod # @override # def sample_uniform( # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () # ) -> SO2: # out = SO2.from_radians( # jax.random.uniform( # key=key, shape=batch_axes, minval=0.0, maxval=2.0 * onp.pi) # ) # assert out.get_batch_axes() == batch_axes # return out ================================================ FILE: viser/src/viser/transforms/_so3.py ================================================ from __future__ import annotations import dataclasses from typing import Tuple import numpy as onp import numpy.typing as onpt from typing_extensions import override from . import _base, hints from .utils import broadcast_leading_axes, get_epsilon, register_lie_group @dataclasses.dataclass(frozen=True) class RollPitchYaw: """Struct containing roll, pitch, and yaw Euler angles.""" roll: onpt.NDArray[onp.floating] pitch: onpt.NDArray[onp.floating] yaw: onpt.NDArray[onp.floating] @register_lie_group( matrix_dim=3, parameters_dim=4, tangent_dim=3, space_dim=3, ) @dataclasses.dataclass(frozen=True) class SO3(_base.SOBase): """Special orthogonal group for 3D rotations. Broadcasting rules are the same as for numpy. Ported to numpy from `jaxlie.SO3`. Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is `(omega_x, omega_y, omega_z)`. """ wxyz: onpt.NDArray[onp.floating] """Internal parameters. `(w, x, y, z)` quaternion. Shape should be `(*, 4)`.""" @override def __repr__(self) -> str: wxyz = onp.round(self.wxyz, 5) return f"{self.__class__.__name__}(wxyz={wxyz})" @staticmethod def from_x_radians(theta: hints.Scalar) -> SO3: """Generates a x-axis rotation. Args: angle: X rotation, in radians. Returns: Output. """ zeros = onp.zeros_like(theta) return SO3.exp(onp.stack([theta, zeros, zeros], axis=-1)) @staticmethod def from_y_radians(theta: hints.Scalar) -> SO3: """Generates a y-axis rotation. Args: angle: Y rotation, in radians. Returns: Output. """ zeros = onp.zeros_like(theta) return SO3.exp(onp.stack([zeros, theta, zeros], axis=-1)) @staticmethod def from_z_radians(theta: hints.Scalar) -> SO3: """Generates a z-axis rotation. Args: angle: Z rotation, in radians. Returns: Output. """ zeros = onp.zeros_like(theta) return SO3.exp(onp.stack([zeros, zeros, theta], axis=-1)) @staticmethod def from_rpy_radians( roll: hints.Scalar, pitch: hints.Scalar, yaw: hints.Scalar, ) -> SO3: """Generates a transform from a set of Euler angles. 
Uses the ZYX mobile robot convention. Args: roll: X rotation, in radians. Applied first. pitch: Y rotation, in radians. Applied second. yaw: Z rotation, in radians. Applied last. Returns: Output. """ return ( SO3.from_z_radians(yaw) @ SO3.from_y_radians(pitch) @ SO3.from_x_radians(roll) ) @staticmethod def from_quaternion_xyzw(xyzw: onpt.NDArray[onp.floating]) -> SO3: """Construct a rotation from an `xyzw` quaternion. Note that `wxyz` quaternions can be constructed using the default dataclass constructor. Args: xyzw: xyzw quaternion. Shape should be (*, 4). Returns: Output. """ assert xyzw.shape[-1:] == (4,) return SO3(onp.roll(xyzw, axis=-1, shift=1)) def as_quaternion_xyzw(self) -> onpt.NDArray[onp.floating]: """Grab parameters as xyzw quaternion.""" return onp.roll(self.wxyz, axis=-1, shift=-1) def as_rpy_radians(self) -> RollPitchYaw: """Computes roll, pitch, and yaw angles. Uses the ZYX mobile robot convention. Returns: Named tuple containing Euler angles in radians. """ return RollPitchYaw( roll=self.compute_roll_radians(), pitch=self.compute_pitch_radians(), yaw=self.compute_yaw_radians(), ) def compute_roll_radians(self) -> onpt.NDArray[onp.floating]: """Compute roll angle. Uses the ZYX mobile robot convention. Returns: Euler angle in radians. """ # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion q0, q1, q2, q3 = onp.moveaxis(self.wxyz, -1, 0) return onp.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1**2 + q2**2)) def compute_pitch_radians(self) -> onpt.NDArray[onp.floating]: """Compute pitch angle. Uses the ZYX mobile robot convention. Returns: Euler angle in radians. """ # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion q0, q1, q2, q3 = onp.moveaxis(self.wxyz, -1, 0) return onp.arcsin(2 * (q0 * q2 - q3 * q1)) def compute_yaw_radians(self) -> onpt.NDArray[onp.floating]: """Compute yaw angle. Uses the ZYX mobile robot convention. Returns: Euler angle in radians. """ # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion q0, q1, q2, q3 = onp.moveaxis(self.wxyz, -1, 0) return onp.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2**2 + q3**2)) # Factory. @classmethod @override def identity(cls, batch_axes: Tuple[int, ...] 
= ()) -> SO3: return SO3( wxyz=onp.broadcast_to(onp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4)) ) @classmethod @override def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SO3: assert matrix.shape[-2:] == (3, 3) # Modified from: # > "Converting a Rotation Matrix to a Quaternion" from Mike Day # > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf def case0(m): t = 1 + m[..., 0, 0] - m[..., 1, 1] - m[..., 2, 2] q = onp.stack( [ m[..., 2, 1] - m[..., 1, 2], t, m[..., 1, 0] + m[..., 0, 1], m[..., 0, 2] + m[..., 2, 0], ], axis=-1, ) return t, q def case1(m): t = 1 - m[..., 0, 0] + m[..., 1, 1] - m[..., 2, 2] q = onp.stack( [ m[..., 0, 2] - m[..., 2, 0], m[..., 1, 0] + m[..., 0, 1], t, m[..., 2, 1] + m[..., 1, 2], ], axis=-1, ) return t, q def case2(m): t = 1 - m[..., 0, 0] - m[..., 1, 1] + m[..., 2, 2] q = onp.stack( [ m[..., 1, 0] - m[..., 0, 1], m[..., 0, 2] + m[..., 2, 0], m[..., 2, 1] + m[..., 1, 2], t, ], axis=-1, ) return t, q def case3(m): t = 1 + m[..., 0, 0] + m[..., 1, 1] + m[..., 2, 2] q = onp.stack( [ t, m[..., 2, 1] - m[..., 1, 2], m[..., 0, 2] - m[..., 2, 0], m[..., 1, 0] - m[..., 0, 1], ], axis=-1, ) return t, q # Compute four cases, then pick the most precise one. # Probably worth revisiting this! case0_t, case0_q = case0(matrix) case1_t, case1_q = case1(matrix) case2_t, case2_q = case2(matrix) case3_t, case3_q = case3(matrix) cond0 = matrix[..., 2, 2] < 0 cond1 = matrix[..., 0, 0] > matrix[..., 1, 1] cond2 = matrix[..., 0, 0] < -matrix[..., 1, 1] t = onp.where( cond0, onp.where(cond1, case0_t, case1_t), onp.where(cond2, case2_t, case3_t), ) q = onp.where( cond0[..., None], onp.where(cond1[..., None], case0_q, case1_q), onp.where(cond2[..., None], case2_q, case3_q), ) # We can also choose to branch, but this is slower. # t, q = jax.lax.cond( # matrix[2, 2] < 0, # true_fun=lambda matrix: jax.lax.cond( # matrix[0, 0] > matrix[1, 1], # true_fun=case0, # false_fun=case1, # operand=matrix, # ), # false_fun=lambda matrix: jax.lax.cond( # matrix[0, 0] < -matrix[1, 1], # true_fun=case2, # false_fun=case3, # operand=matrix, # ), # operand=matrix, # ) return SO3(wxyz=q * 0.5 / onp.sqrt(t[..., None])) # Accessors. @override def as_matrix(self) -> onpt.NDArray[onp.floating]: norm_sq = onp.sum(onp.square(self.wxyz), axis=-1, keepdims=True) q = self.wxyz * onp.sqrt(2.0 / norm_sq) # (*, 4) q_outer = onp.einsum("...i,...j->...ij", q, q) # (*, 4, 4) return onp.stack( [ 1.0 - q_outer[..., 2, 2] - q_outer[..., 3, 3], q_outer[..., 1, 2] - q_outer[..., 3, 0], q_outer[..., 1, 3] + q_outer[..., 2, 0], q_outer[..., 1, 2] + q_outer[..., 3, 0], 1.0 - q_outer[..., 1, 1] - q_outer[..., 3, 3], q_outer[..., 2, 3] - q_outer[..., 1, 0], q_outer[..., 1, 3] - q_outer[..., 2, 0], q_outer[..., 2, 3] + q_outer[..., 1, 0], 1.0 - q_outer[..., 1, 1] - q_outer[..., 2, 2], ], axis=-1, ).reshape(*q.shape[:-1], 3, 3) @override def parameters(self) -> onpt.NDArray[onp.floating]: return self.wxyz # Operations. @override def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: assert target.shape[-1:] == (3,) self, target = broadcast_leading_axes((self, target)) # Compute using quaternion multiplys. 
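        # (Editor's note) This is the standard conjugation p' = q ⊗ (0, p) ⊗ q⁻¹: the
        # point is embedded as a pure quaternion (zero scalar part), conjugated by the
        # rotation quaternion, and the vector part of the result is returned below.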
padded_target = onp.concatenate( [onp.zeros((*self.get_batch_axes(), 1)), target], axis=-1 ) return (self @ SO3(wxyz=padded_target) @ self.inverse()).wxyz[..., 1:] @override def multiply(self, other: SO3) -> SO3: w0, x0, y0, z0 = onp.moveaxis(self.wxyz, -1, 0) w1, x1, y1, z1 = onp.moveaxis(other.wxyz, -1, 0) return SO3( wxyz=onp.stack( [ -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1, x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1, -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1, x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1, ], axis=-1, ) ) @classmethod @override def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SO3: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583 assert tangent.shape[-1:] == (3,) theta_squared = onp.sum(onp.square(tangent), axis=-1) theta_pow_4 = theta_squared * theta_squared use_taylor = theta_squared < get_epsilon(tangent.dtype) # Shim to avoid NaNs in onp.where branches, which cause failures for # reverse-mode AD in JAX. This isn't needed for vanilla numpy. safe_theta = onp.sqrt( onp.where( use_taylor, onp.ones_like(theta_squared), # Any constant value should do here. theta_squared, ) ) safe_half_theta = 0.5 * safe_theta real_factor = onp.where( use_taylor, 1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0, onp.cos(safe_half_theta), ) imaginary_factor = onp.where( use_taylor, 0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0, onp.sin(safe_half_theta) / safe_theta, ) return SO3( wxyz=onp.concatenate( [ real_factor[..., None], imaginary_factor[..., None] * tangent, ], axis=-1, ) ) @override def log(self) -> onpt.NDArray[onp.floating]: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247 w = self.wxyz[..., 0] norm_sq = onp.sum(onp.square(self.wxyz[..., 1:]), axis=-1) use_taylor = norm_sq < get_epsilon(norm_sq.dtype) # Shim to avoid NaNs in onp.where branches, which cause failures for # reverse-mode AD in JAX. This isn't needed for vanilla numpy. norm_safe = onp.sqrt( onp.where( use_taylor, 1.0, # Any non-zero value should do here. norm_sq, ) ) w_safe = onp.where(use_taylor, w, 1.0) atan_n_over_w = onp.arctan2( onp.where(w < 0, -norm_safe, norm_safe), onp.abs(w), ) atan_factor = onp.where( use_taylor, 2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3, onp.where( onp.abs(w) < get_epsilon(w.dtype), onp.where(w > 0, 1.0, -1.0) * onp.pi / norm_safe, 2.0 * atan_n_over_w / norm_safe, ), ) return atan_factor[..., None] * self.wxyz[..., 1:] # type: ignore @override def adjoint(self) -> onpt.NDArray[onp.floating]: return self.as_matrix() @override def inverse(self) -> SO3: # Negate complex terms. return SO3(wxyz=self.wxyz * onp.array([1, -1, -1, -1])) @override def normalize(self) -> SO3: return SO3(wxyz=self.wxyz / onp.linalg.norm(self.wxyz, axis=-1, keepdims=True)) # @classmethod # @override # def sample_uniform( # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () # ) -> SO3: # # Uniformly sample over S^3. 
# # > Reference: http://planning.cs.uiuc.edu/node198.html # u1, u2, u3 = onp.moveaxis( # jax.random.uniform( # key=key, # shape=(*batch_axes, 3), # minval=onp.zeros(3), # maxval=onp.array([1.0, 2.0 * onp.pi, 2.0 * onp.pi]), # ), # -1, # 0, # ) # a = onp.sqrt(1.0 - u1) # b = onp.sqrt(u1) # # return SO3( # wxyz=onp.stack( # [ # a * onp.sin(u2), # a * onp.cos(u2), # b * onp.sin(u3), # b * onp.cos(u3), # ], # axis=-1, # ) # ) ================================================ FILE: viser/src/viser/transforms/hints/__init__.py ================================================ from typing import Union import numpy as onp import numpy.typing as onpt # Type aliases Numpy arrays; primarily for function inputs. Scalar = Union[float, onpt.NDArray[onp.floating]] """Type alias for `Union[float, Array]`.""" __all__ = [ "Scalar", ] ================================================ FILE: viser/src/viser/transforms/utils/__init__.py ================================================ from ._utils import broadcast_leading_axes, get_epsilon, register_lie_group __all__ = ["get_epsilon", "register_lie_group", "broadcast_leading_axes"] ================================================ FILE: viser/src/viser/transforms/utils/_utils.py ================================================ from typing import TYPE_CHECKING, Callable, Tuple, Type, TypeVar, Union, cast import numpy as onp if TYPE_CHECKING: from .._base import MatrixLieGroup T = TypeVar("T", bound="MatrixLieGroup") def get_epsilon(dtype: onp.dtype) -> float: """Helper for grabbing type-specific precision constants. Args: dtype: Datatype. Returns: Output float. """ if dtype == onp.float32: return 1e-5 elif dtype == onp.float64: return 1e-10 else: assert False def register_lie_group( *, matrix_dim: int, parameters_dim: int, tangent_dim: int, space_dim: int, ) -> Callable[[Type[T]], Type[T]]: """Decorator for registering Lie group dataclasses. Sets dimensionality class variables, and marks all methods for JIT compilation. """ def _wrap(cls: Type[T]) -> Type[T]: # Register dimensions as class attributes. cls.matrix_dim = matrix_dim cls.parameters_dim = parameters_dim cls.tangent_dim = tangent_dim cls.space_dim = space_dim return cls return _wrap TupleOfBroadcastable = TypeVar( "TupleOfBroadcastable", bound="Tuple[Union[MatrixLieGroup, onp.ndarray], ...]", ) def broadcast_leading_axes(inputs: TupleOfBroadcastable) -> TupleOfBroadcastable: """Broadcast leading axes of arrays. Takes tuples of either: - an array, which we assume has shape (*, D). - a Lie group object.""" from .._base import MatrixLieGroup array_inputs = [ ( (x.parameters(), (x.parameters_dim,)) if isinstance(x, MatrixLieGroup) else (x, x.shape[-1:]) ) for x in inputs ] for array, shape_suffix in array_inputs: assert array.shape[-len(shape_suffix) :] == shape_suffix batch_axes = onp.broadcast_shapes( *[array.shape[: -len(suffix)] for array, suffix in array_inputs] ) broadcasted_arrays = tuple( onp.broadcast_to(array, batch_axes + shape_suffix) for (array, shape_suffix) in array_inputs ) return cast( TupleOfBroadcastable, tuple( array if not isinstance(inp, MatrixLieGroup) else type(inp)(array) for array, inp in zip(broadcasted_arrays, inputs) ), ) ================================================ FILE: viser/sync_message_defs.py ================================================ """Generate typescript message definitions from Python dataclasses.""" import pathlib import subprocess import viser.infra from viser._messages import Message if __name__ == "__main__": # Generate typescript source. 
defs = viser.infra.generate_typescript_interfaces(Message) # Write to file. target_path = pathlib.Path(__file__).parent / pathlib.Path( "src/viser/client/src/WebsocketMessages.tsx" ) assert target_path.exists() target_path.write_text(defs) print(f"Wrote to {target_path}") # Run prettier. subprocess.run(args=["npx", "prettier", "-w", str(target_path)], check=False) ================================================ FILE: viser/visualize_megasam.py ================================================ import time import sys import argparse from pathlib import Path import numpy as onp import tyro from tqdm.auto import tqdm import viser import viser.extras import viser.transforms as tf import matplotlib.cm as cm # For colormap import os import cv2 import numpy as np import argparse def main( data: Path = "./demo_tmp/NULL.npz", downsample_factor: int = 1, max_frames: int = 300, share: bool = False, conf_threshold: float = 0., foreground_conf_threshold: float = 0., point_size: float = 0.001, camera_frustum_scale: float = 0.02, no_mask: bool = True, xyzw: bool = True, axes_scale: float = 0.25, bg_downsample_factor: int = 1, init_conf: bool = False, cam_thickness: float = 1.5, ) -> None: from pathlib import Path # <-- Import Path here if not already imported data = np.load(data) server = viser.ViserServer() if share: server.request_share_url() server.scene.set_up_direction('-z') if no_mask: # not using dynamic / static mask init_conf = True # must use init_conf map, to avoid depth cleaning fg_conf_thre = conf_threshold # now fg_conf_thre is the same as conf_thre print("Loading frames!") loader = viser.extras.Record3dLoader_Customized_Megasam( data, conf_threshold=conf_threshold, foreground_conf_threshold=foreground_conf_threshold, no_mask=no_mask, xyzw=xyzw, init_conf=init_conf, ) num_frames = min(max_frames, loader.num_frames()) # Add playback UI. with server.gui.add_folder("Playback"): gui_timestep = server.gui.add_slider( "Timestep", min=0, max=num_frames - 1, step=1, initial_value=0, disabled=True, ) gui_next_frame = server.gui.add_button("Next Frame", disabled=True) gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True) gui_playing = server.gui.add_checkbox("Playing", True) gui_framerate = server.gui.add_slider( "FPS", min=1, max=60, step=0.1, initial_value=loader.fps ) gui_framerate_options = server.gui.add_button_group( "FPS options", ("10", "20", "30", "60") ) gui_show_all_frames = server.gui.add_checkbox("Show all frames", False) gui_stride = server.gui.add_slider( "Stride", min=1, max=num_frames, step=1, initial_value=1, disabled=True, # Initially disabled ) # Add recording UI. with server.gui.add_folder("Recording"): gui_record_scene = server.gui.add_button("Record Scene") # Frame step buttons. @gui_next_frame.on_click def _(_) -> None: gui_timestep.value = (gui_timestep.value + 1) % num_frames @gui_prev_frame.on_click def _(_) -> None: gui_timestep.value = (gui_timestep.value - 1) % num_frames # Disable frame controls when we're playing. @gui_playing.on_update def _(_) -> None: gui_timestep.disabled = gui_playing.value or gui_show_all_frames.value gui_next_frame.disabled = gui_playing.value or gui_show_all_frames.value gui_prev_frame.disabled = gui_playing.value or gui_show_all_frames.value # Toggle frame visibility when the timestep slider changes. 
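    # (Editor's note) Playback keeps exactly one per-timestep frame node visible: when
    # the slider moves, the newly selected frame is shown and the previously visible
    # one is hidden inside a single atomic update, so viewers never see two timesteps
    # of the point cloud at once.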
@gui_timestep.on_update def _(_) -> None: nonlocal prev_timestep current_timestep = gui_timestep.value if not gui_show_all_frames.value: with server.atomic(): frame_nodes[current_timestep].visible = True frame_nodes[prev_timestep].visible = False prev_timestep = current_timestep server.flush() # Optional! # Show or hide all frames based on the checkbox. @gui_show_all_frames.on_update def _(_) -> None: gui_stride.disabled = not gui_show_all_frames.value # Enable/disable stride slider if gui_show_all_frames.value: # Show frames with stride stride = gui_stride.value with server.atomic(): for i, frame_node in enumerate(frame_nodes): frame_node.visible = (i % stride == 0) # Disable playback controls gui_playing.disabled = True gui_timestep.disabled = True gui_next_frame.disabled = True gui_prev_frame.disabled = True else: # Show only the current frame current_timestep = gui_timestep.value with server.atomic(): for i, frame_node in enumerate(frame_nodes): frame_node.visible = i == current_timestep # Re-enable playback controls gui_playing.disabled = False gui_timestep.disabled = gui_playing.value gui_next_frame.disabled = gui_playing.value gui_prev_frame.disabled = gui_playing.value # Update frame visibility when the stride changes. @gui_stride.on_update def _(_) -> None: if gui_show_all_frames.value: # Update frame visibility based on new stride stride = gui_stride.value with server.atomic(): for i, frame_node in enumerate(frame_nodes): frame_node.visible = (i % stride == 0) # Recording handler @gui_record_scene.on_click def _(_): gui_record_scene.disabled = True # Save the original frame visibility state original_visibility = [frame_node.visible for frame_node in frame_nodes] rec = server._start_scene_recording() rec.set_loop_start() # Determine sleep duration based on current FPS sleep_duration = 1.0 / gui_framerate.value if gui_framerate.value > 0 else 0.033 # Default to ~30 FPS if gui_show_all_frames.value: # Record all frames according to the stride stride = gui_stride.value frames_to_record = [i for i in range(num_frames) if i % stride == 0] else: # Record the frames in sequence frames_to_record = range(num_frames) for t in frames_to_record: # Update the scene to show frame t with server.atomic(): for i, frame_node in enumerate(frame_nodes): frame_node.visible = (i == t) if not gui_show_all_frames.value else (i % gui_stride.value == 0) server.flush() rec.insert_sleep(sleep_duration) # set all invisible with server.atomic(): for frame_node in frame_nodes: frame_node.visible = False # Finish recording bs = rec.end_and_serialize() # Save the recording to a file output_path = Path(f"./viser_result/recording_{str(data).split('/')[-1]}.viser") # make sure the output directory exists output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_bytes(bs) print(f"Recording saved to {output_path.resolve()}") # Restore the original frame visibility state with server.atomic(): for frame_node, visibility in zip(frame_nodes, original_visibility): frame_node.visible = visibility server.flush() gui_record_scene.disabled = False # Load in frames. server.scene.add_frame( "/frames", wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz, position=(0, 0, 0), show_axes=False, ) frame_nodes: list[viser.FrameHandle] = [] bg_positions = [] bg_colors = [] for i in tqdm(range(num_frames)): frame = loader.get_frame(i) position, color, bg_position, bg_color = frame.get_point_cloud(downsample_factor, bg_downsample_factor) bg_positions.append(bg_position) bg_colors.append(bg_color) # Add base frame. 
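        # (Editor's note) Each timestep gets its own node under "/frames"; the point
        # cloud and camera frustum added below are children of "/frames/t{i}", so
        # toggling the per-frame node's visibility shows or hides both together.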
frame_nodes.append(server.scene.add_frame(f"/frames/t{i}", show_axes=False)) # Place the point cloud in the frame. server.scene.add_point_cloud( name=f"/frames/t{i}/point_cloud", points=position, colors=color, point_size=point_size, point_shape="rounded", ) # Compute color for frustum based on frame index. norm_i = i / (num_frames - 1) if num_frames > 1 else 0 # Normalize index to [0, 1] color_rgba = cm.viridis(norm_i) # Get RGBA color from colormap color_rgb = color_rgba[:3] # Use RGB components # Place the frustum with the computed color. fov = 2 * onp.arctan2(frame.rgb.shape[0] / 2, frame.K[0, 0]) aspect = frame.rgb.shape[1] / frame.rgb.shape[0] server.scene.add_camera_frustum( f"/frames/t{i}/frustum", fov=fov, aspect=aspect, scale=camera_frustum_scale, image=frame.rgb[::downsample_factor, ::downsample_factor], wxyz=tf.SO3.from_matrix(frame.T_world_camera[:3, :3]).wxyz, position=frame.T_world_camera[:3, 3], color=color_rgb, # Set the color for the frustum thickness=cam_thickness, ) # Add some axes. (Commented out to hide coordinate axes) # server.scene.add_frame( # f"/frames/t{i}/frustum/axes", # axes_length=camera_frustum_scale * axes_scale * 10, # axes_radius=camera_frustum_scale * axes_scale, # ) # Initialize frame visibility. for i, frame_node in enumerate(frame_nodes): if gui_show_all_frames.value: frame_node.visible = (i % gui_stride.value == 0) else: frame_node.visible = i == gui_timestep.value # Add background frame. bg_positions = onp.concatenate(bg_positions, axis=0) bg_colors = onp.concatenate(bg_colors, axis=0) server.scene.add_point_cloud( name=f"/frames/background", points=bg_positions, colors=bg_colors, point_size=point_size, point_shape="rounded", ) # Playback update loop. prev_timestep = gui_timestep.value while True: if gui_playing.value and not gui_show_all_frames.value: gui_timestep.value = (gui_timestep.value + 1) % num_frames time.sleep(1.0 / gui_framerate.value) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: viser/visualize_pose.py ================================================ #!/usr/bin/env python3 """ Camera Pose Visualization Module This module provides comprehensive tools for visualizing camera poses and trajectories in 3D space using Plotly. It supports both static and animated visualizations with automatic camera view optimization. Adapted from: https://huggingface.co/datasets/nvidia/dynpose-100k/blob/main/scripts/visualize_pose.py """ import argparse import matplotlib import matplotlib.pyplot as plt import numpy as np import os import plotly.graph_objs as go import plotly.io as pio from tqdm import tqdm import einops import torch # Use non-interactive backend for matplotlib to avoid display issues matplotlib.use("agg") class Pose: """ A class of operations on camera poses (numpy arrays with shape [...,3,4]). Each [3,4] camera pose takes the form of [R|t]. """ def __call__(self, R=None, t=None): """ Construct a camera pose from the given rotation matrix R and/or translation vector t. 
""" assert R is not None or t is not None if R is None: if not isinstance(t, np.ndarray): t = np.array(t) R = np.eye(3).repeat(*t.shape[:-1], 1, 1) elif t is None: if not isinstance(R, np.ndarray): R = np.array(R) t = np.zeros(R.shape[:-1]) else: if not isinstance(R, np.ndarray): R = np.array(R) if not isinstance(t, np.ndarray): t = np.array(t) assert R.shape[:-1] == t.shape and R.shape[-2:] == (3, 3) R = R.astype(np.float32) t = t.astype(np.float32) pose = np.concatenate([R, t[..., None]], axis=-1) # [...,3,4] assert pose.shape[-2:] == (3, 4) return pose def invert(self, pose, use_inverse=False): """ Invert a camera pose. """ R, t = pose[..., :3], pose[..., 3:] R_inv = np.linalg.inv(R) if use_inverse else R.transpose(0, 2, 1) t_inv = (-R_inv @ t)[..., 0] pose_inv = self(R=R_inv, t=t_inv) return pose_inv def compose(self, pose_list): """ Compose a sequence of poses together. pose_new(x) = poseN o ... o pose2 o pose1(x) """ pose_new = pose_list[0] for pose in pose_list[1:]: pose_new = self.compose_pair(pose_new, pose) return pose_new def compose_pair(self, pose_a, pose_b): """ Compose two poses together. """ R_a, t_a = pose_a[..., :3], pose_a[..., 3:] R_b, t_b = pose_b[..., :3], pose_b[..., 3:] R_new = R_b @ R_a t_new = (R_b @ t_a + t_b)[..., 0] pose_new = self(R=R_new, t=t_new) return pose_new def scale_center(self, pose, scale): """ Scale the camera center from the origin. 0 = R@c+t --> c = -R^T@t (camera center in world coordinates) 0 = R@(sc)+t' --> t' = -R@(sc) = -R@(-R^T@st) = st """ R, t = pose[..., :3], pose[..., 3:] pose_new = np.concatenate([R, t * scale], axis=-1) return pose_new def to_hom(X): """ Convert points to homogeneous coordinates by appending ones. """ X_hom = np.concatenate([X, np.ones_like(X[..., :1])], axis=-1) return X_hom def cam2world(X, pose): """ Transform points from camera coordinates to world coordinates. """ X_hom = to_hom(X) pose_inv = Pose().invert(pose) return X_hom @ pose_inv.transpose(0, 2, 1) def get_camera_mesh(pose, depth=1): """ Create a 3D mesh representation of camera frustums for visualization. """ # Define camera frustum geometry: 4 corners of image plane + camera center vertices = ( np.array( [[-0.5, -0.5, 1], [0.5, -0.5, 1], [0.5, 0.5, 1], [-0.5, 0.5, 1], [0, 0, 0]] ) * depth ) # Shape: [5, 3] - 4 image plane corners + camera center # Define triangular faces for the camera frustum mesh faces = np.array( [[0, 1, 2], [0, 2, 3], [0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]] ) # Shape: [6, 3] - 6 triangular faces forming the pyramid # Transform vertices from camera space to world space vertices = cam2world(vertices[None], pose) # Shape: [N, 5, 3] # Create wireframe lines connecting: corners -> center -> next corner wireframe = vertices[:, [0, 1, 2, 3, 0, 4, 1, 2, 4, 3]] # Shape: [N, 10, 3] return vertices, faces, wireframe # def merge_xyz_indicators_plotly(xyz): # """Merge xyz coordinate indicators for plotly visualization.""" # xyz = xyz[:, [[-1, 0], [-1, 1], [-1, 2]]] # [N,3,2,3] # xyz_0, xyz_1 = unbind_np(xyz, axis=2) # [N,3,3] # xyz_dummy = xyz_0 * np.nan # xyz_merged = np.stack([xyz_0, xyz_1, xyz_dummy], axis=2) # [N,3,3,3] # xyz_merged = xyz_merged.reshape(-1, 3) # return xyz_merged # def get_xyz_indicators(pose, length=0.1): # """Get xyz coordinate axis indicators for a camera pose.""" # xyz = np.eye(4, 3)[None] * length # xyz = cam2world(xyz, pose) # return xyz def merge_wireframes_plotly(wireframe): """ Merge camera wireframes for efficient Plotly visualization. 
""" wf_dummy = wireframe[:, :1] * np.nan # Create NaN separators wireframe_merged = np.concatenate([wireframe, wf_dummy], axis=1).reshape(-1, 3) return wireframe_merged def merge_meshes(vertices, faces): """ Merge multiple camera meshes into a single mesh for efficient rendering. """ mesh_N, vertex_N = vertices.shape[:2] # Adjust face indices for each mesh by adding vertex offset faces_merged = np.concatenate([faces + i * vertex_N for i in range(mesh_N)], axis=0) # Flatten all vertices into single array vertices_merged = vertices.reshape(-1, vertices.shape[-1]) return vertices_merged, faces_merged def unbind_np(array, axis=0): """ Split numpy array along specified axis into a list of arrays. """ if axis == 0: return [array[i, :] for i in range(array.shape[0])] elif axis == 1 or (len(array.shape) == 2 and axis == -1): return [array[:, j] for j in range(array.shape[1])] elif axis == 2 or (len(array.shape) == 3 and axis == -1): return [array[:, :, j] for j in range(array.shape[2])] else: raise ValueError("Invalid axis. Use 0 for rows, 1 for columns, or 2 for depth.") def plotly_visualize_pose( poses, vis_depth=0.5, xyz_length=0.5, center_size=2, xyz_width=5, mesh_opacity=0.05 ): """ Create comprehensive Plotly visualization traces for camera poses. """ N = len(poses) # Calculate camera centers in world coordinates centers_cam = np.zeros([N, 1, 3]) # Camera centers in camera space (origin) centers_world = cam2world(centers_cam, poses) # Transform to world space centers_world = centers_world[:, 0] # Remove extra dimension [N, 3] # Generate camera frustum geometry vertices, faces, wireframe = get_camera_mesh(poses, depth=vis_depth) # Merge all camera meshes into single arrays for efficient rendering vertices_merged, faces_merged = merge_meshes(vertices, faces) wireframe_merged = merge_wireframes_plotly(wireframe) # Extract x, y, z coordinates for Plotly wireframe_x, wireframe_y, wireframe_z = unbind_np(wireframe_merged, axis=-1) centers_x, centers_y, centers_z = unbind_np(centers_world, axis=-1) vertices_x, vertices_y, vertices_z = unbind_np(vertices_merged, axis=-1) # Set up rainbow color mapping for trajectory progression color_map = plt.get_cmap("gist_rainbow") # red -> yellow -> green -> blue -> purple center_color = [] faces_merged_color = [] wireframe_color = [] # Determine quarter positions for emphasis (start, 1/3, 2/3, end) quarter_indices = set([0]) # Always include start if N >= 3: quarter_indices.add(N // 3) quarter_indices.add(2 * N // 3) quarter_indices.add(N - 1) # Always include end # Apply colors with emphasis on key trajectory points for i in range(N): # Emphasize quarter positions with higher opacity and brightness is_quarter = i in quarter_indices alpha = 6.0 if is_quarter else 0.4 # Higher opacity for key points # Generate color from rainbow colormap r, g, b, _ = color_map(i / (N - 1)) rgb = np.array([r, g, b]) * (1.2 if is_quarter else 0.8) # Brighten key points rgba = np.concatenate([rgb, [alpha]]) # Apply colors to all visualization elements wireframe_color += [rgba] * 11 # 11 line segments per camera wireframe center_color += [rgba] faces_merged_color += [rgba] * 6 # 6 triangular faces per camera frustum # Create Plotly trace objects plotly_traces = [ # Camera wireframe outlines go.Scatter3d( x=wireframe_x, y=wireframe_y, z=wireframe_z, mode="lines", line=dict(color=wireframe_color, width=1), name="Camera Wireframes", ), # Camera center points go.Scatter3d( x=centers_x, y=centers_y, z=centers_z, mode="markers", marker=dict(color=center_color, size=center_size, 
opacity=1), name="Camera Centers", ), # Camera frustum mesh faces go.Mesh3d( x=vertices_x, y=vertices_y, z=vertices_z, i=[f[0] for f in faces_merged], j=[f[1] for f in faces_merged], k=[f[2] for f in faces_merged], facecolor=faces_merged_color, opacity=mesh_opacity, name="Camera Frustums", ), ] return plotly_traces def compute_optimal_camera_view(poses): """ Compute optimal camera view parameters to ensure the entire trajectory is visible and aesthetically pleasing. """ # Calculate all camera positions in world coordinates centers_cam = np.zeros([len(poses), 1, 3]) centers_world = cam2world(centers_cam, poses)[:, 0] # Compute bounding box of the trajectory min_coords = np.min(centers_world, axis=0) max_coords = np.max(centers_world, axis=0) ranges = max_coords - min_coords # Calculate trajectory center point trajectory_center = (min_coords + max_coords) / 2 # Calculate maximum range for adaptive scaling max_range = np.max(ranges) # Set minimum range to avoid division by zero for very small trajectories if max_range < 1e-6: max_range = 1.0 ranges = np.ones(3) # Calculate principal direction of trajectory using PCA (Principal Component Analysis) if len(centers_world) > 1: # Center the points by subtracting the mean centered_points = centers_world - trajectory_center # Compute covariance matrix for PCA cov_matrix = np.cov(centered_points.T) # Calculate eigenvalues and eigenvectors eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix) # Sort by eigenvalues in descending order idx = np.argsort(eigenvalues)[::-1] eigenvalues = eigenvalues[idx] eigenvectors = eigenvectors[:, idx] # Main direction is the first eigenvector (highest variance) main_direction = eigenvectors[:, 0] # Ensure main direction points towards trajectory's positive direction start_to_end = centers_world[-1] - centers_world[0] if np.dot(main_direction, start_to_end) < 0: main_direction = -main_direction else: # Default direction for single pose or insufficient data main_direction = np.array([1, 0, 0]) # Calculate optimal camera distance # Based on trajectory range and field of view, using smaller factor for better screen filling fov_factor = ( 0.8 # Reduced field of view factor to make trajectory occupy more screen space ) base_distance = max_range * fov_factor # Consider trajectory aspect ratio and adjust distance accordingly aspect_ratios = ranges / max_range distance_scale = 1.0 + 0.1 * np.std( aspect_ratios ) # Reduced distance adjustment magnitude camera_distance = base_distance * distance_scale # Calculate optimal camera position # Method 1: Diagonal viewing angle based on main direction up_vector = np.array([0, 0, 1]) # World up direction (Z-axis) # Adjust strategy if main direction is nearly vertical if abs(np.dot(main_direction, up_vector)) > 0.9: # Main direction is nearly vertical, use side view view_direction = np.cross(main_direction, np.array([1, 0, 0])) if np.linalg.norm(view_direction) < 0.1: view_direction = np.cross(main_direction, np.array([0, 1, 0])) view_direction = view_direction / np.linalg.norm(view_direction) else: # Calculate diagonal view direction perpendicular to main direction # Combine horizontal component of main direction with tilt angle horizontal_component = ( main_direction - np.dot(main_direction, up_vector) * up_vector ) horizontal_component = horizontal_component / ( np.linalg.norm(horizontal_component) + 1e-8 ) # Add some tilt angles for better 3D perspective elevation_angle = np.pi / 6 # 30 degrees elevation angle azimuth_offset = np.pi / 4 # 45 degrees azimuth offset # Create tilted 
view direction for optimal 3D perspective view_direction = ( horizontal_component * np.cos(azimuth_offset) * np.cos(elevation_angle) + np.cross(horizontal_component, up_vector) * np.sin(azimuth_offset) * np.cos(elevation_angle) + up_vector * np.sin(elevation_angle) ) # Calculate camera eye position camera_eye = trajectory_center + view_direction * camera_distance # Fine-tune camera position to ensure entire trajectory is within view # Calculate vectors from camera position to all trajectory points view_vectors = centers_world - camera_eye view_distances = np.linalg.norm(view_vectors, axis=1) # Adjust camera distance moderately if some points are too close min_distance = camera_distance * 0.3 # Reduced minimum distance ratio if np.min(view_distances) < min_distance: distance_adjustment = min_distance / np.min(view_distances) # Limit adjustment magnitude to avoid excessive scaling distance_adjustment = min( distance_adjustment, 1.2 ) # Further limit adjustment range camera_eye = ( trajectory_center + view_direction * camera_distance * distance_adjustment ) # Calculate adaptive parameters with appropriate proportions auto_vis_depth = max_range * 0.08 # Moderately reduced camera frustum size auto_center_size = max_range * 1.5 # Moderately reduced center point size # Ensure parameters are within reasonable bounds auto_vis_depth = max(0.01, min(auto_vis_depth, max_range * 0.2)) auto_center_size = max(0.1, min(auto_center_size, max_range * 2.0)) return { "camera_eye": camera_eye, "trajectory_center": trajectory_center, "auto_vis_depth": auto_vis_depth, "auto_center_size": auto_center_size, "max_range": max_range, "ranges": ranges, "main_direction": main_direction, } def compute_multiple_camera_views(poses): """ Compute multiple optimized camera view angles, providing different viewing options. """ base_params = compute_optimal_camera_view(poses) trajectory_center = base_params["trajectory_center"] max_range = base_params["max_range"] main_direction = base_params["main_direction"] # Calculate multiple view options views = {} # 1. Best automatic view (original optimal view) views["optimal"] = base_params # 2. Top-down bird's eye view top_distance = max_range * 1.5 # Further reduced top-down view distance views["top"] = { **base_params, "camera_eye": trajectory_center + np.array([0, 0, top_distance]), "description": "Top-down view", } # 3. Side view perspective side_distance = max_range * 1.3 # Further reduced side view distance side_direction = np.cross(main_direction, np.array([0, 0, 1])) if np.linalg.norm(side_direction) < 0.1: side_direction = np.array([1, 0, 0]) else: side_direction = side_direction / np.linalg.norm(side_direction) views["side"] = { **base_params, "camera_eye": trajectory_center + side_direction * side_distance, "description": "Side view", } # 4. Diagonal view (45-degree elevation) diagonal_distance = max_range * 1.4 # Further reduced diagonal view distance elevation = np.pi / 4 # 45 degrees elevation azimuth = np.pi / 4 # 45 degrees azimuth angle diagonal_direction = np.array( [ np.cos(elevation) * np.cos(azimuth), np.cos(elevation) * np.sin(azimuth), np.sin(elevation), ] ) views["diagonal"] = { **base_params, "camera_eye": trajectory_center + diagonal_direction * diagonal_distance, "description": "Diagonal view (45° elevation)", } # 5. 
def add_view_selector_to_html(html_str, views):
    """
    Add interactive view selector to HTML visualization.

    This function injects JavaScript code into the HTML to provide an interactive
    interface for switching between different camera views and enabling auto-rotation.

    Args:
        html_str: Original HTML string containing the Plotly visualization
        views: Dictionary of view configurations

    Returns:
        str: Enhanced HTML string with view selector and controls
    """
    # Generate JavaScript code for the view selector
    view_selector_js = """
"""

    # Add the view selector to the beginning of the HTML
    return view_selector_js + html_str


def write_html(poses, file, vis_depth=1, xyz_length=0.2, center_size=0.01, xyz_width=2):
    """
    Write camera pose visualization to HTML file with optimized camera view.
    """
    # Calculate basic optimal view parameters
    base_view = compute_optimal_camera_view(poses)

    # Extract trajectory information
    trajectory_center = base_view["trajectory_center"]
    max_range = base_view["max_range"]
    ranges = base_view["ranges"]
    auto_vis_depth = base_view["auto_vis_depth"]
    auto_center_size = base_view["auto_center_size"]

    # Calculate optimal view to see the entire trajectory.
    # Use a larger distance to ensure the entire trajectory is visible at a good angle.
    optimal_distance = (
        max_range * 1.8 * 10
    )  # Increase distance by 10x for better overall view

    # Choose an ideal angle that can see the full trajectory:
    # a combination of 45-degree elevation and azimuth gives a good 3D perspective.
    elevation = np.pi / 4  # 45-degree elevation angle
    azimuth = np.pi / 4  # 45-degree azimuth angle

    # Calculate optimal viewing direction
    optimal_direction = np.array(
        [
            np.cos(elevation) * np.cos(azimuth),
            np.cos(elevation) * np.sin(azimuth),
            np.sin(elevation),
        ]
    )

    # Calculate optimal camera position
    camera_eye = trajectory_center + optimal_direction * optimal_distance

    # Verify view coverage - ensure all trajectory points are within reasonable distance
    centers_cam = np.zeros([len(poses), 1, 3])
    centers_world = cam2world(centers_cam, poses)[:, 0]

    # Calculate distances from the optimal camera position to all trajectory points
    distances_to_points = np.linalg.norm(centers_world - camera_eye, axis=1)
    max_distance_to_point = np.max(distances_to_points)
    min_distance_to_point = np.min(distances_to_points)

    # If distance variation is too large, the view might not be ideal; adjust accordingly
    if max_distance_to_point / min_distance_to_point > 3.0:
        # Recalculate a more balanced distance
        optimal_distance = max_range * 2.2 * 10  # Further increase distance (10x)
        camera_eye = trajectory_center + optimal_direction * optimal_distance

    # Create view dictionary with only the optimal view for Auto Rotate
    views = {
        "fit_all": {
            "camera_eye": camera_eye,
            "trajectory_center": trajectory_center,
            "auto_vis_depth": auto_vis_depth,
            "auto_center_size": auto_center_size,
            "max_range": max_range,
            "ranges": ranges,
            "description": "Optimal view to see entire trajectory",
        }
    }

    print(f"Trajectory ranges: x={ranges[0]:.3f}, y={ranges[1]:.3f}, z={ranges[2]:.3f}")
    print(f"Max range: {max_range:.3f}")
    print(f"Auto vis_depth: {auto_vis_depth:.3f}, center_size: {auto_center_size:.3f}")
    print(
        f"Trajectory center: ({trajectory_center[0]:.3f}, {trajectory_center[1]:.3f}, {trajectory_center[2]:.3f})"
    )
    print(
        f"Optimal camera position for full trajectory view: ({camera_eye[0]:.3f}, {camera_eye[1]:.3f}, {camera_eye[2]:.3f})"
    )
    print(f"Camera distance from trajectory center: {optimal_distance:.3f}")
    print(
        f"Distance range to trajectory points: {min_distance_to_point:.3f} - {max_distance_to_point:.3f}"
    )

    xyz_length = xyz_length / 3
    vis_depth = auto_vis_depth  # Use automatically computed depth
    center_size = auto_center_size  # Use automatically computed size

    traces_poses = plotly_visualize_pose(
        poses,
        vis_depth=vis_depth,
        xyz_length=xyz_length,
        center_size=center_size,
        xyz_width=xyz_width,
        mesh_opacity=0.05,
    )
    traces_all2 = traces_poses

    layout2 = go.Layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            dragmode="orbit",
            aspectratio=dict(x=1, y=1, z=1),
            aspectmode="data",
            # Set initial camera view to fully see the trajectory with optimized positioning
            camera=dict(
                eye=dict(x=camera_eye[0], y=camera_eye[1], z=camera_eye[2]),
                center=dict(
                    x=trajectory_center[0],
                    y=trajectory_center[1],
                    z=trajectory_center[2],
                ),
                up=dict(x=0, y=0, z=1),
            ),
        ),
        height=800,
        width=1200,
        showlegend=False,
    )
    fig2 = go.Figure(data=traces_all2, layout=layout2)
    html_str2 = pio.to_html(fig2, full_html=False)

    # Add real-time camera view display functionality
    camera_info_html = """

Camera Info

Eye:
x: 2.000
y: 2.000
z: 1.000

Center:
x: 0.000
y: 0.000
z: 0.000

Up:
x: 0.000
y: 0.000
z: 1.000

"""

    # Add view selector and camera info to HTML
    enhanced_html = add_view_selector_to_html(camera_info_html + html_str2, views)
    file.write(enhanced_html)
    print(f"Enhanced visualized poses are saved to {file.name}")
    # Removed redundant view options printing


def plotly_visualize_pose_animated(
    poses_full,
    vis_depth=0.5,
    xyz_length=0.5,
    center_size=2,
    xyz_width=5,
    mesh_opacity=0.05,
):
    """
    Create plotly visualization traces for camera poses, frame by frame for animation.
    Now shows the full trajectory with future poses as completely transparent.
    """
    N_total = len(poses_full)
    plotly_frames = []

    # Pre-compute data for all poses to ensure a consistent layout
    centers_cam = np.zeros([N_total, 1, 3])
    centers_world = cam2world(centers_cam, poses_full)
    centers_world = centers_world[:, 0]

    # Get the camera wireframes for all poses
    vertices, faces, wireframe = get_camera_mesh(poses_full, depth=vis_depth)
    vertices_merged, faces_merged = merge_meshes(vertices, faces)
    wireframe_merged = merge_wireframes_plotly(wireframe)

    # Break up (x, y, z) coordinates.
    wireframe_x, wireframe_y, wireframe_z = unbind_np(wireframe_merged, axis=-1)
    centers_x, centers_y, centers_z = unbind_np(centers_world, axis=-1)
    vertices_x, vertices_y, vertices_z = unbind_np(vertices_merged, axis=-1)

    # Initial frame showing all poses with appropriate transparency
    initial_data = []
    for i in tqdm(range(1, N_total + 1), desc="Generating animation frames"):
        current_frame = i - 1  # Current frame index (0-based)

        # Set the color map for the camera trajectory
        color_map = plt.get_cmap("gist_rainbow")
        center_color = []
        faces_merged_color = []
        wireframe_color = []
        for k in range(N_total):  # Process all poses
            # Set the camera pose colors (with a smooth gradient color map).
            r, g, b, _ = color_map(k / (N_total - 1))
            rgb = np.array([r, g, b]) * 0.8
            # Set transparency based on the current frame
            if k < current_frame:
                # Past poses - visible with reduced opacity.
                # Transparency depends on temporal distance: more distant = more transparent.
                time_distance = (current_frame - k) / max(current_frame, 1)
                alpha = 0.15 + 0.25 * (1 - time_distance)  # Transparency range 0.15-0.4
                wireframe_alpha = alpha
                mesh_alpha = alpha * 0.4
            elif k == current_frame:
                # Current pose - fully visible
                alpha = 0.8  # Fully opaque, dark display
                wireframe_alpha = 0.8
                mesh_alpha = 0.6
            else:
                # Future poses - completely transparent
                alpha = 0.0  # Completely transparent
                wireframe_alpha = 0.0
                mesh_alpha = 0.0
            # Set colors and transparency
            wireframe_color += [np.concatenate([rgb, [wireframe_alpha]])] * 11
            center_color += [np.concatenate([rgb, [alpha]])]
            faces_merged_color += [np.concatenate([rgb, [mesh_alpha]])] * 6

        frame_data = [
            go.Scatter3d(
                x=wireframe_x,
                y=wireframe_y,
                z=wireframe_z,
                mode="lines",
                line=dict(color=wireframe_color, width=1),
            ),
            go.Scatter3d(
                x=centers_x,
                y=centers_y,
                z=centers_z,
                mode="markers",
                marker=dict(color=center_color, size=center_size),
            ),
            go.Mesh3d(
                x=vertices_x,
                y=vertices_y,
                z=vertices_z,
                i=[f[0] for f in faces_merged],
                j=[f[1] for f in faces_merged],
                k=[f[2] for f in faces_merged],
                facecolor=faces_merged_color,
                opacity=0.6,  # Set base opacity for mesh
            ),
        ]
        if i == 1:  # Set initial data for the first frame
            initial_data = frame_data
        plotly_frames.append(go.Frame(data=frame_data, name=str(i)))

    return initial_data, plotly_frames

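# Worked example of the transparency ramp inside plotly_visualize_pose_animated
# (numbers follow directly from the formula; the frame indices are arbitrary):
# with current_frame = 5,
#     k = 4 -> time_distance = 1/5 = 0.2 -> alpha = 0.15 + 0.25 * 0.8 = 0.35
#     k = 0 -> time_distance = 5/5 = 1.0 -> alpha = 0.15 + 0.25 * 0.0 = 0.15
# so the most recent past pose sits near the top of the 0.15-0.4 range, the oldest
# fades to 0.15, and every k > current_frame stays fully transparent (alpha = 0.0).
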
def write_html_animated(
    poses, file, vis_depth=1, xyz_length=0.2, center_size=0.01, xyz_width=2
):
    """
    Write camera pose visualization with animation to HTML file with optimized camera view.
    """
    # Calculate basic optimal view parameters
    base_view = compute_optimal_camera_view(poses)

    # Extract trajectory information
    trajectory_center = base_view["trajectory_center"]
    max_range = base_view["max_range"]
    ranges = base_view["ranges"]
    auto_vis_depth = base_view["auto_vis_depth"]
    auto_center_size = base_view["auto_center_size"]

    # Calculate optimal view to see the entire trajectory.
    # Use a larger distance to ensure the entire trajectory is visible at a good angle.
    optimal_distance = (
        max_range * 1.8 * 10
    )  # Increase distance by 10x for better overall view

    # Choose an ideal angle that can see the full trajectory:
    # a combination of 45-degree elevation and azimuth gives a good 3D perspective.
    elevation = np.pi / 4  # 45-degree elevation angle
    azimuth = np.pi / 4  # 45-degree azimuth angle

    # Calculate optimal viewing direction
    optimal_direction = np.array(
        [
            np.cos(elevation) * np.cos(azimuth),
            np.cos(elevation) * np.sin(azimuth),
            np.sin(elevation),
        ]
    )

    # Calculate optimal camera position
    camera_eye = trajectory_center + optimal_direction * optimal_distance

    # Verify view coverage - ensure all trajectory points are within reasonable distance
    centers_cam = np.zeros([len(poses), 1, 3])
    centers_world = cam2world(centers_cam, poses)[:, 0]

    # Calculate distances from the optimal camera position to all trajectory points
    distances_to_points = np.linalg.norm(centers_world - camera_eye, axis=1)
    max_distance_to_point = np.max(distances_to_points)
    min_distance_to_point = np.min(distances_to_points)

    # If distance variation is too large, the view might not be ideal; adjust accordingly
    if max_distance_to_point / min_distance_to_point > 3.0:
        # Recalculate a more balanced distance
        optimal_distance = max_range * 2.2 * 10  # Further increase distance (10x)
        camera_eye = trajectory_center + optimal_direction * optimal_distance

    # Adjust parameters for the animation
    xyz_length = xyz_length / 3
    vis_depth = auto_vis_depth  # Use automatically computed depth
    center_size = auto_center_size  # Use automatically computed size

    print(
        f"Animation - Trajectory ranges: x={ranges[0]:.3f}, y={ranges[1]:.3f}, z={ranges[2]:.3f}"
    )
    print(f"Animation - Max range: {max_range:.3f}")
    print(
        f"Animation - Auto vis_depth: {auto_vis_depth:.3f}, center_size: {auto_center_size:.3f}"
    )
    print(
        f"Animation - Trajectory center: ({trajectory_center[0]:.3f}, {trajectory_center[1]:.3f}, {trajectory_center[2]:.3f})"
    )
    print(
        f"Animation - Optimal camera position for full trajectory view: ({camera_eye[0]:.3f}, {camera_eye[1]:.3f}, {camera_eye[2]:.3f})"
    )
    print(f"Animation - Camera distance from trajectory center: {optimal_distance:.3f}")
    print(
        f"Animation - Distance range to trajectory points: {min_distance_to_point:.3f} - {max_distance_to_point:.3f}"
    )

    initial_data, plotly_frames = plotly_visualize_pose_animated(
        poses,
        vis_depth=vis_depth,
        xyz_length=xyz_length,
        center_size=center_size,
        xyz_width=xyz_width,
        mesh_opacity=0.05,
    )

    layout = go.Layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            dragmode="orbit",
            aspectratio=dict(x=1, y=1, z=1),
            aspectmode="data",
            # Use optimized camera view settings (same 10x distance as write_html)
            camera=dict(
                eye=dict(x=camera_eye[0], y=camera_eye[1], z=camera_eye[2]),
                center=dict(
                    x=trajectory_center[0],
                    y=trajectory_center[1],
                    z=trajectory_center[2],
                ),
                up=dict(x=0, y=0, z=1),
            ),
        ),
        height=800,  # Increased height for better animation display
        width=1200,  # Increased width for better animation display
        showlegend=False,
        updatemenus=[
            dict(
                type="buttons",
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        args=[
                            None,
                            {
                                "frame": {"duration": 50, "redraw": True},
                                "fromcurrent": True,
                                "transition": {"duration": 0},
                            },
                        ],
                    )
                ],
            )
        ],
    )

    fig = go.Figure(data=initial_data, layout=layout, frames=plotly_frames)
    html_str = pio.to_html(fig, full_html=False)
    file.write(html_str)
    print(f"Visualized poses are saved to {file.name}")


def quaternion_to_matrix(quaternions, eps: float = 1e-8):
    """
    Convert 4-dimensional quaternions to 3x3 rotation matrices.

    Reference:
    https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
    """
    # Order changed to match the scipy format: (i, j, k, r)
    i, j, k, r = torch.unbind(quaternions, dim=-1)
    two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)

    # Construct rotation matrix elements using quaternion algebra
    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),  # R[0,0]
            two_s * (i * j - k * r),  # R[0,1]
            two_s * (i * k + j * r),  # R[0,2]
            two_s * (i * j + k * r),  # R[1,0]
            1 - two_s * (i * i + k * k),  # R[1,1]
            two_s * (j * k - i * r),  # R[1,2]
            two_s * (i * k - j * r),  # R[2,0]
            two_s * (j * k + i * r),  # R[2,1]
            1 - two_s * (i * i + j * j),  # R[2,2]
        ),
        -1,
    )
    return einops.rearrange(o, "... (i j) -> ... i j", i=3, j=3)


def pose_from_quaternion(pose):
    """
    Convert quaternion-based pose representation to 3x4 world-to-camera transformation matrices.

    Reference:
    https://github.com/pointrix-project/Geomotion/blob/6ab0c364f1b44ab4ea190085dbf068f62b42727c/geomotion/model/cameras.py#L6
    """
    # Convert numpy array to torch tensor if needed
    if type(pose) == np.ndarray:
        pose = torch.tensor(pose)
    # Add a batch dimension if the input is 1D
    if len(pose.shape) == 1:
        pose = pose[None]

    # Extract translation and quaternion components
    quat_t = pose[..., :3]  # Translation components [tx, ty, tz]
    quat_r = pose[..., 3:]  # Quaternion components [qi, qj, qk, qr]

    # Initialize world-to-camera transformation matrix
    w2c_matrix = torch.zeros((*list(pose.shape)[:-1], 3, 4), device=pose.device)
    w2c_matrix[..., :3, 3] = quat_t  # Set translation part
    w2c_matrix[..., :3, :3] = quaternion_to_matrix(quat_r)  # Set rotation part
    return w2c_matrix

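# Quick sanity-check sketch (assumes the (tx, ty, tz, qi, qj, qk, qr) layout handled
# by pose_from_quaternion): an identity quaternion should produce an identity rotation
# block with the translation copied into the last column.
#
#     p = np.array([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0]])  # hypothetical single pose
#     w2c = pose_from_quaternion(p)        # torch tensor of shape (1, 3, 4)
#     # w2c[0, :3, :3] ~ identity, w2c[0, :3, 3] == tensor([1., 2., 3.])
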
") # Load pose data from file pose = np.load(pth) # Convert quaternion poses to transformation matrices poses = pose_from_quaternion(pose) # Input: (N,7), Output: (N,3,4) w2c matrices poses = poses.cpu().numpy() # Scale camera positions to reduce distance between camera frustums for better visualization scale_factor = getattr( args, "scale_factor", 0.3 ) # Default scale factor 0.3, adjustable via command line parameter # Apply scaling to translation part (camera positions) while keeping rotation unchanged # Create scaled copy of poses poses_scaled = poses.copy() poses_scaled[..., :3, 3] = poses[..., :3, 3] * scale_factor print(f"Original poses shape: {poses.shape}") print(f"Applied scale factor: {scale_factor}") # Generate visualization based on dynamic flag if args.dynamic: write_html_animated(poses_scaled, file, vis_depth=args.vis_depth) else: write_html(poses_scaled, file, vis_depth=args.vis_depth) if __name__ == "__main__": # Set up command-line argument parser parser = argparse.ArgumentParser( description="Visualize camera poses with interactive 3D plots", formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument( "--datas", type=str, nargs="+", required=True, help="List of pose file paths (.npy format) to visualize.", ) parser.add_argument( "--vis_depth", type=float, default=0.2, help="Depth of camera frustum visualization (default: 0.2).", ) parser.add_argument( "--scale_factor", type=float, default=0.3, help="Scale factor to reduce distance between cameras - smaller values bring cameras closer together (default: 0.3).", ) parser.add_argument( "--outdir", type=str, default="./visualize", help="Output directory to save HTML visualization files (default: ./visualize).", ) parser.add_argument( "--dynamic", action="store_true", help="Create animated visualization showing camera trajectory progression over time.", ) # Parse command-line arguments args = parser.parse_args() # Create output directory and process pose files os.makedirs(args.outdir, exist_ok=True) print(f"Processing {len(args.datas)} pose file(s)...") print(f"Output directory: {args.outdir}") print(f"Visualization type: {'Animated' if args.dynamic else 'Static'}") with open(f"{args.outdir}/visualize.html", "w") as file: for i, pth in enumerate(tqdm(args.datas, desc="Processing pose files")): if not os.path.exists(pth): print(f"Warning: Path {pth} does not exist, skipping.") continue print(f"Processing: {pth} (#{i+1})") viz_poses(i, pth, file, args) print( f"Visualization complete! Open {args.outdir}/visualize.html in your browser to view results." )
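
# Example invocation (script name and pose paths are placeholders):
#
#     python visualize_pose.py \
#         --datas clips/scene_0001/poses.npy clips/scene_0002/poses.npy \
#         --outdir ./visualize --scale_factor 0.3 --dynamic
#
# This writes a single ./visualize/visualize.html with one figure per input pose file;
# --dynamic switches from the static layout to the animated, per-frame one.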